import torch from PIL.Image import Image from torchvision import transforms from torchvision.datasets import MNIST from torch.utils.data import DataLoader import matplotlib.pyplot as plt class Net(torch.nn.Module): def __init__(self): super().__init__() self.fc1: torch.nn.Linear = torch.nn.Linear(28 * 28, 64) self.fc2: torch.nn.Linear = torch.nn.Linear(64, 64) self.fc3: torch.nn.Linear = torch.nn.Linear(64, 64) self.fc4: torch.nn.Linear = torch.nn.Linear(64, 10) def forward(self, x: torch.Tensor): x = torch.nn.functional.relu(self.fc1(x)) x = torch.nn.functional.relu(self.fc2(x)) x = torch.nn.functional.relu(self.fc3(x)) x = torch.nn.functional.log_softmax(self.fc4(x), dim=1) return x def get_data_loader(is_train: bool) -> DataLoader[Image]: to_tensor = transforms.Compose([transforms.ToTensor()]) data_set = MNIST("", is_train, transform=to_tensor, download=True) return DataLoader(data_set, batch_size=15, shuffle=True) def evaluate(test_data: DataLoader[Image], net: Net) -> float: n_correct: int = 0 n_total: int = 0 with torch.no_grad(): for (x, y) in test_data: outputs = net.forward(x.view(-1, 28 * 28)) for i, output in enumerate(outputs): if torch.argmax(output) == y[i]: n_correct += 1 n_total += 1 return n_correct / n_total def main(): print("=== START ===") torch.set_default_device(torch.device(type='cpu', index=0)) train_data = get_data_loader(is_train=True) test_data = get_data_loader(is_train=False) net = Net() print('initial accuracy:', evaluate(test_data, net)) optimizer = torch.optim.Adam(net.parameters(), lr=0.001) # type: ignore for epoch in range(1, 7, 1): for (x, y) in train_data: net.zero_grad() output = net.forward(x.view(-1, 28 * 28)) loss = torch.nn.functional.nll_loss(output, y) loss.backward() optimizer.step() print('epoch', epoch, 'accuracy', evaluate(test_data, net)) # test for (n, (x, _)) in enumerate(test_data): if(n > 15): break predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28))) plt.figure(n) plt.title('prediction: ' + str(int(predict))) plt.imsave(str(n) + '_predict-' + str(predict) + '.png', x[0].view(28, 28)) print('=== DONE ===') if __name__ == '__main__': main()