# 鸢尾花分类 (Iris classification)
#
# 创建日期 (created): 2024-06-21
# 更新日期 (updated): 2025-02-01
'''
鸢尾花分类
'''
import torch
import pandas as pd

# Load the train/test splits. The last column holds the label; every other
# column is a feature.
train_data = pd.read_csv('./csv/iris_training.csv').to_numpy()
test_data = pd.read_csv('./csv/iris_test.csv').to_numpy()
x_train = train_data[..., :-1]
y_train = train_data[..., -1]
x_test = test_data[..., :-1]
y_test = test_data[..., -1]


def forward(inputs, w1, b1, w2):
    """Run the two-layer network and return an (N, 1) prediction column.

    Args:
        inputs: (N, n_features) float tensor.
        w1: (n_features, hidden) first-layer weights.
        b1: (hidden,) first-layer bias.
        w2: (hidden, 1) output-layer weights.
    """
    # A real matrix product, so every hidden unit sees every feature. An
    # element-wise product would tie the hidden width to the feature count
    # and give each hidden unit a single input.
    hidden = torch.sigmoid(inputs.mm(w1) + b1)
    return hidden.mm(w2)


# Labels are reshaped to an (N, 1) column to match the predictions. A 1-D
# (N,) label tensor would broadcast against (N, 1) predictions into an
# (N, N) matrix, silently comparing every prediction with every label.
x = torch.tensor(x_train, dtype=torch.float32)
y = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)

sz = 4  # hidden-layer width
weights = torch.randn((x.shape[1], sz), requires_grad=True)
bias = torch.randn(sz, requires_grad=True)
weights2 = torch.randn((sz, 1), requires_grad=True)

learning_rate = 0.001
losses = []

for i in range(1000):
    predictions = forward(x, weights, bias, weights2)
    loss = torch.mean((predictions - y) ** 2)
    losses.append(loss.item())
    if i % 100 == 0:
        print('loss:', loss)

    loss.backward()
    # Plain gradient-descent step, done outside autograd so the update itself
    # is not tracked; gradients are cleared because .backward() accumulates.
    with torch.no_grad():
        for param in (weights, bias, weights2):
            param -= learning_rate * param.grad
            param.grad.zero_()

print(x_test.shape)
print(y_test.shape)

# Evaluate on the held-out split without building an autograd graph.
x = torch.tensor(x_test, dtype=torch.float32)
y = torch.tensor(y_test, dtype=torch.float32).view(-1, 1)
with torch.no_grad():
    predictions = forward(x, weights, bias, weights2)
    loss = torch.mean((predictions - y) ** 2)
print('loss:', loss)

# The network regresses the label value, so round each prediction to the
# nearest value inside the training label range to report accuracy.
# NOTE(review): assumes labels are consecutive integer class indices — confirm
# against the CSV files.
label_min, label_max = float(y_train.min()), float(y_train.max())
predicted = predictions.round().clamp(label_min, label_max)
accuracy = (predicted == y).float().mean().item()
print('accuracy:', accuracy)