diff --git a/main.py b/main.py index 434fcf2..5825c88 100644 --- a/main.py +++ b/main.py @@ -13,9 +13,9 @@ class Net(torch.nn.Module): self.fc3: torch.nn.Linear = torch.nn.Linear(64, 64) self.fc4: torch.nn.Linear = torch.nn.Linear(64, 10) def forward(self, x: torch.Tensor): - x = torch.nn.functional.relu(self.fc1(x)) - x = torch.nn.functional.relu(self.fc2(x)) - x = torch.nn.functional.relu(self.fc3(x)) + x = torch.nn.functional.rrelu(self.fc1(x)) + x = torch.nn.functional.rrelu(self.fc2(x)) + x = torch.nn.functional.rrelu(self.fc3(x)) x = torch.nn.functional.log_softmax(self.fc4(x), dim=1) return x @@ -39,13 +39,16 @@ def evaluate(test_data: DataLoader[Image], net: Net) -> float: def main(): print("=== START ===") torch.set_default_device(torch.device(type='cpu', index=0)) + torch.set_default_dtype(torch.float16) train_data = get_data_loader(is_train=True) + print('train data size:',len(train_data)) test_data = get_data_loader(is_train=False) + print('test data size: ', len(test_data)) net = Net() print('initial accuracy:', evaluate(test_data, net)) optimizer = torch.optim.Adam(net.parameters(), lr=0.001) # type: ignore - for epoch in range(1, 7, 1): + for epoch in range(1, 3, 1): for (x, y) in train_data: net.zero_grad() output = net.forward(x.view(-1, 28 * 28)) @@ -54,14 +57,9 @@ def main(): optimizer.step() print('epoch', epoch, 'accuracy', evaluate(test_data, net)) - # test - for (n, (x, _)) in enumerate(test_data): - if(n > 15): - break - predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28))) - plt.figure(n) - plt.title('prediction: ' + str(int(predict))) - plt.imsave(str(n) + '_predict-' + str(predict) + '.png', x[0].view(28, 28)) + model_name = 'model.pth' + print('saving model to: ', model_name) + torch.save(net.state_dict(), model_name) print('=== DONE ===') if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/model.pth b/model.pth new file mode 100644 index 0000000..b71e826 Binary files /dev/null and b/model.pth differ diff --git a/test.py b/test.py index 9b071b4..e2ae983 100644 --- a/test.py +++ b/test.py @@ -1,2 +1,38 @@ -def __main__(): - print('hello') \ No newline at end of file +from matplotlib import pyplot as plt +import torch +import os +import time +from PIL.Image import Image +from torchvision import transforms +from torchvision.datasets import MNIST +from torch.utils.data import DataLoader + +from main import Net + + +def get_test_loader() -> DataLoader[Image]: + to_tensor = transforms.Compose([transforms.ToTensor()]) + data_set = MNIST("", False, transform=to_tensor, download=True) + return DataLoader(data_set, batch_size=15, shuffle=True) + +def main(): + path = '.result' + try: + os.mkdir(path) + except: + pass + net = Net() + model = torch.load('model.pth', weights_only=True) + net.load_state_dict(model) + + test_data = get_test_loader() + for (n, (x, _)) in enumerate(test_data): + if(n > 15): + break + predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28))) + plt.figure(n) + plt.title('prediction: ' + str(int(predict))) + plt.imsave(path + '/' + str(time.time_ns()) + '_predict-' + str(predict) + '.png', x[0].view(28, 28)) + +if __name__ == '__main__': + main() \ No newline at end of file