# ML_Sample/main.py — 66 lines, 2.4 KiB (Python).
# Snapshot taken from a repository blame view; earliest commit 2024-09-02 13:42:45 +08:00.
import torch
from PIL.Image import Image
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
class Net(torch.nn.Module):
    """Fully connected classifier: 28*28 inputs -> 64 -> 64 -> 64 -> 10 outputs.

    The three hidden layers use ``torch.nn.functional.rrelu`` activations. The
    final layer's output goes through ``log_softmax`` over dim 1, so ``forward``
    returns log-probabilities (the form ``nll_loss`` expects).
    """

    def __init__(self) -> None:
        super().__init__()
        width = 64
        # The attribute names fc1..fc4 become the state_dict keys, so keep them
        # stable to stay compatible with saved checkpoints. The layers are also
        # created in the same order, so random initialisation is unchanged.
        self.fc1: torch.nn.Linear = torch.nn.Linear(28 * 28, width)
        self.fc2: torch.nn.Linear = torch.nn.Linear(width, width)
        self.fc3: torch.nn.Linear = torch.nn.Linear(width, width)
        self.fc4: torch.nn.Linear = torch.nn.Linear(width, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of flattened images to per-class log-probabilities.

        ``x`` must have 28 * 28 features in its last dimension (fc1's input
        size). Because log_softmax is taken over dim 1, ``x`` is presumably
        2-D with shape (batch, 784) — TODO confirm for other callers.
        """
        # NOTE(review): rrelu is called without `training=`. Per the PyTorch
        # docs this means training=False, i.e. a fixed negative slope of
        # (lower + upper) / 2 rather than a randomized one. Confirm this is
        # intended.
        activation = torch.nn.functional.rrelu
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            x = activation(hidden_layer(x))
        logits = self.fc4(x)
        return torch.nn.functional.log_softmax(logits, dim=1)
def get_data_loader(
    is_train: bool, *, batch_size: int = 15, root: str = ""
) -> DataLoader[Image]:
    """Build a shuffled ``DataLoader`` over an MNIST split.

    The dataset is downloaded into ``root`` if it is not already there.

    Args:
        is_train: Use the training split when True and the test split when
            False.
        batch_size: Number of samples per batch. The default of 15 matches the
            previously hard-coded value.
        root: Directory passed to ``MNIST`` for storage/download. The default
            ``""`` is presumably resolved relative to the current working
            directory.

    Returns:
        A ``DataLoader`` whose images are converted with
        ``transforms.ToTensor()`` and whose order is shuffled each epoch.
    """
    # A single transform needs no Compose wrapper.
    to_tensor = transforms.ToTensor()
    data_set = MNIST(root, train=is_train, transform=to_tensor, download=True)
    # Shuffling also applies to the test split; this is harmless for accuracy.
    return DataLoader(data_set, batch_size=batch_size, shuffle=True)
def evaluate(test_data: DataLoader[Image], net: Net) -> float:
    """Return the classification accuracy of ``net`` over ``test_data``.

    Args:
        test_data: Iterable of ``(x, y)`` batches. Each ``x`` is reshaped to
            ``(-1, 28 * 28)`` before the forward pass, and ``y`` holds the
            matching integer class labels.
        net: Model whose output argmax along dim 1 is taken as the predicted
            class.

    Returns:
        The fraction (in ``[0, 1]``) of samples whose predicted class equals
        the label.

    Raises:
        ValueError: If ``test_data`` yields no samples, since accuracy is then
            undefined.
    """
    n_correct: int = 0
    n_total: int = 0
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for x, y in test_data:
            # Call the module rather than .forward() so registered hooks run.
            outputs = net(x.view(-1, 28 * 28))
            # Compare the whole batch at once instead of looping per sample.
            predictions = outputs.argmax(dim=1)
            n_correct += int((predictions == y).sum().item())
            n_total += predictions.numel()
    if n_total == 0:
        raise ValueError("test_data yielded no samples; accuracy is undefined")
    return n_correct / n_total
def main():
    """Train ``Net`` on MNIST for two epochs, then save its weights.

    Prints the test-set accuracy before training and after each epoch, and
    writes the final ``state_dict`` to ``model.pth``.
    """
    print("=== START ===")
    torch.set_default_device(torch.device(type='cpu', index=0))
    # Use float32, not float16:
    # - set_default_dtype only supports float32/float64.
    # - With float16 parameters, Adam's default eps (1e-8) rounds to 0. Any
    #   parameter whose gradient is exactly zero (e.g. fc1 weights for input
    #   pixels that are zero across a whole batch) then gets a 0/0 update and
    #   turns NaN.
    torch.set_default_dtype(torch.float32)
    train_data = get_data_loader(is_train=True)
    # Note: len() of a DataLoader is its number of batches, not samples.
    print('train data size:', len(train_data))
    test_data = get_data_loader(is_train=False)
    print('test data size: ', len(test_data))
    net = Net()
    print('initial accuracy:', evaluate(test_data, net))
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)  # type: ignore
    for epoch in range(1, 3):  # epochs 1 and 2
        for x, y in train_data:
            optimizer.zero_grad()
            # Call the module (not .forward) so any registered hooks run.
            output = net(x.view(-1, 28 * 28))
            # Net emits log-probabilities, which is what nll_loss expects.
            loss = torch.nn.functional.nll_loss(output, y)
            loss.backward()
            optimizer.step()
        print('epoch', epoch, 'accuracy', evaluate(test_data, net))
    model_name = 'model.pth'
    print('saving model to: ', model_name)
    torch.save(net.state_dict(), model_name)
    print('=== DONE ===')
if __name__ == '__main__':
    # Train only when this file is executed directly, not when it is imported.
    main()