'''
鸢尾花分类
'''
import torch
import pandas as pd
# Load both splits as plain 2-D arrays. In each row, every column except the
# last is a feature, and the last column is used as the regression target.
train_data = pd.read_csv('./csv/iris_training.csv').to_numpy()
test_data = pd.read_csv('./csv/iris_test.csv').to_numpy()

x_train, y_train = train_data[:, :-1], train_data[:, -1]
x_test, y_test = test_data[:, :-1], test_data[:, -1]

# Training tensors (float32) consumed by the training loop below.
x = torch.FloatTensor(x_train)
y = torch.FloatTensor(y_train)
# Hidden width. The training step computes `x * weights` with weights of
# shape (1, sz), so sz must match the number of feature columns in x.
sz = 4
learning_rate = 0.001  # step size for plain gradient descent
losses = []  # one training-loss entry per iteration

# Trainable parameters, drawn from N(0, 1) in this fixed order:
#   weights  -> (1, sz)  per-feature input scale
#   bias     -> (sz,)    hidden-unit bias
#   weights2 -> (sz, 1)  hidden-to-output projection
weights = torch.randn(1, sz, requires_grad=True)
bias = torch.randn(sz, requires_grad=True)
weights2 = torch.randn(sz, 1, requires_grad=True)
# Full-batch gradient descent on the mean-squared error between the network
# output and the float-encoded label column.
#
# NOTE(review): `x * weights` is an element-wise, per-feature scale. It is not
# a dense layer, so each hidden unit only ever sees a single input feature.
# A conventional hidden layer would be `x.mm(W)` with W of shape
# (n_features, sz).
#
# `predictions` has shape (N, 1) while `y` is (N,). Subtracting them directly
# would broadcast to an (N, N) matrix, so the loss would compare every
# prediction with every label. The labels are therefore reshaped to (N, 1)
# once, up front.
targets = y.reshape(-1, 1)
for i in range(1000):
    hidden = torch.sigmoid(x * weights + bias)  # (N, sz)
    predictions = hidden.mm(weights2)  # (N, 1)
    loss = torch.mean((predictions - targets) ** 2)
    losses.append(loss.item())  # store a plain float, detached from the graph
    if i % 100 == 0:
        print('loss:', loss.item())
    loss.backward()
    # SGD step. The in-place parameter updates must not be recorded by
    # autograd, and the gradients are cleared so they do not accumulate
    # across iterations.
    with torch.no_grad():
        for param in (weights, bias, weights2):
            param -= learning_rate * param.grad
            param.grad.zero_()
# Evaluate the trained parameters on the held-out split, using the same
# forward pass as training.
print(x_test.shape)
print(y_test.shape)
x = torch.FloatTensor(x_test)
y = torch.FloatTensor(y_test)
# Inference only: no autograd graph is needed.
with torch.no_grad():
    hidden = torch.sigmoid(x * weights + bias)  # (N, sz)
    predictions = hidden.mm(weights2)  # (N, 1)
    # Reshape the labels from (N,) to (N, 1). Without this, the subtraction
    # broadcasts to (N, N) and the reported MSE is meaningless.
    loss = torch.mean((predictions - y.reshape(-1, 1)) ** 2)
print('loss:', loss.item())
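# 鸢尾花分类 (Iris classification)
# 创建日期：2024-06-21 (created)
# 更新日期：2025-02-01 (updated)
# 简介 (About)
# 一个来自三线小城市的程序员开发经验总结。
# (A summary of development experience from a programmer in a small third-tier city.)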