"""Train a small fully connected classifier on MNIST and save its weights."""

import torch
from PIL.Image import Image
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Number of values in one flattened 28x28 MNIST image.
IMAGE_PIXELS = 28 * 28
HIDDEN_UNITS = 64
N_CLASSES = 10
BATCH_SIZE = 15
LEARNING_RATE = 0.001
N_EPOCHS = 2
MODEL_PATH = 'model.pth'


class Net(torch.nn.Module):
    """Four-layer fully connected network producing log-probabilities.

    Expects input flattened to ``(batch, 28 * 28)`` and returns a
    ``(batch, 10)`` tensor of per-class log-probabilities.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are part of the saved state_dict keys; keep them.
        self.fc1: torch.nn.Linear = torch.nn.Linear(IMAGE_PIXELS, HIDDEN_UNITS)
        self.fc2: torch.nn.Linear = torch.nn.Linear(HIDDEN_UNITS, HIDDEN_UNITS)
        self.fc3: torch.nn.Linear = torch.nn.Linear(HIDDEN_UNITS, HIDDEN_UNITS)
        self.fc4: torch.nn.Linear = torch.nn.Linear(HIDDEN_UNITS, N_CLASSES)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return log-softmax class scores for a batch of flattened images."""
        # NOTE: the functional rrelu is called without ``training=True``, so
        # it uses its default (non-randomized) mode in both training and
        # evaluation.
        x = torch.nn.functional.rrelu(self.fc1(x))
        x = torch.nn.functional.rrelu(self.fc2(x))
        x = torch.nn.functional.rrelu(self.fc3(x))
        return torch.nn.functional.log_softmax(self.fc4(x), dim=1)


def get_data_loader(is_train: bool) -> DataLoader:
    """Build a shuffled MNIST loader, downloading the data if necessary.

    Args:
        is_train: Load the training split when true, the test split otherwise.

    Returns:
        A loader yielding ``(images, labels)`` batches of up to 15 samples,
        with the images already converted to tensors by ``ToTensor``.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    # An empty root stores/reads the dataset relative to the working directory.
    data_set = MNIST("", is_train, transform=to_tensor, download=True)
    return DataLoader(data_set, batch_size=BATCH_SIZE, shuffle=True)


def evaluate(test_data: DataLoader, net: Net) -> float:
    """Return the fraction of samples in ``test_data`` classified correctly.

    Args:
        test_data: Iterable of ``(images, labels)`` batches.
        net: The model to score.

    Returns:
        Accuracy in ``[0, 1]``; ``0.0`` when ``test_data`` yields no samples.
    """
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for x, y in test_data:
            outputs = net(x.view(-1, IMAGE_PIXELS))
            # Compare the whole batch at once instead of row by row.
            predictions = outputs.argmax(dim=1)
            n_correct += int((predictions == y).sum().item())
            n_total += outputs.shape[0]
    if n_total == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0
    return n_correct / n_total


def main():
    """Train the network for a fixed number of epochs and save its weights."""
    print("=== START ===")

    torch.set_default_device(torch.device(type='cpu', index=0))
    # Train in float32. In float16, Adam's default eps (1e-8) is smaller than
    # the smallest representable positive value and rounds to zero, which
    # removes the optimizer's guard against division by zero.
    torch.set_default_dtype(torch.float32)

    # len() of a DataLoader is the number of batches, not of samples.
    train_data = get_data_loader(is_train=True)
    print('train data size:', len(train_data))
    test_data = get_data_loader(is_train=False)
    print('test data size: ', len(test_data))

    net = Net()
    print('initial accuracy:', evaluate(test_data, net))

    optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)  # type: ignore
    for epoch in range(1, N_EPOCHS + 1):
        for x, y in train_data:
            optimizer.zero_grad()
            output = net(x.view(-1, IMAGE_PIXELS))
            loss = torch.nn.functional.nll_loss(output, y)
            loss.backward()
            optimizer.step()
        print('epoch', epoch, 'accuracy', evaluate(test_data, net))

    model_name = MODEL_PATH
    print('saving model to: ', model_name)
    torch.save(net.state_dict(), model_name)
    print('=== DONE ===')


if __name__ == '__main__':
    main()