"""Run a trained ``Net`` on a few MNIST test images and save each image to disk,
with the model's predicted digit encoded in the file name."""
import os
import time
from itertools import islice

import torch
from matplotlib import pyplot as plt
from PIL.Image import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

from main import Net

RESULT_DIR = '.result'
MODEL_PATH = 'model.pth'
# Number of batches to sample; one image (the first of each batch) is saved per batch.
# The original loop broke once the batch index exceeded 15, i.e. after 16 batches.
NUM_SAMPLES = 16


def get_test_loader(*, batch_size: int = 15) -> DataLoader[Image]:
    """Return a shuffled DataLoader over the MNIST test split.

    The dataset is downloaded into the current directory if missing.

    Args:
        batch_size: Number of samples per batch (default 15, as before).

    Returns:
        A DataLoader yielding ``(images, labels)`` batches.
        NOTE(review): because of ``ToTensor`` the images are tensors, not PIL
        ``Image`` objects; the return annotation is kept unchanged for callers.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    data_set = MNIST("", False, transform=to_tensor, download=True)
    return DataLoader(data_set, batch_size=batch_size, shuffle=True)


def main():
    """Load ``model.pth`` into ``Net`` and save predictions for NUM_SAMPLES test images.

    Each saved PNG is named ``<time_ns>_predict-<digit>.png`` inside RESULT_DIR.
    """
    # exist_ok replaces the old bare ``except: pass``; real failures such as a
    # permission error now surface instead of being silently ignored.
    os.makedirs(RESULT_DIR, exist_ok=True)

    net = Net()
    state_dict = torch.load(MODEL_PATH, weights_only=True)
    net.load_state_dict(state_dict)
    # Inference mode. Net's definition is not visible here; eval() is a no-op
    # unless Net contains dropout/batch-norm style layers.
    net.eval()

    with torch.no_grad():  # no autograd graph is needed for prediction
        for images, _ in islice(get_test_loader(), NUM_SAMPLES):
            image = images[0]
            # Calling the module (not .forward) keeps any registered hooks active.
            logits = net(image.view(-1, 28 * 28))
            predict = int(torch.argmax(logits))
            # Use the int, not the tensor, so the name is "predict-7", not "predict-tensor(7)".
            file_name = f'{time.time_ns()}_predict-{predict}.png'
            plt.imsave(os.path.join(RESULT_DIR, file_name), image.view(28, 28).numpy())


if __name__ == '__main__':
    main()