38 lines
1.0 KiB
Python
38 lines
1.0 KiB
Python
from matplotlib import pyplot as plt
|
|
import torch
|
|
import os
|
|
import time
|
|
from PIL.Image import Image
|
|
from torchvision import transforms
|
|
from torchvision.datasets import MNIST
|
|
from torch.utils.data import DataLoader
|
|
|
|
from main import Net
|
|
|
|
|
|
def get_test_loader(
    *,
    batch_size: int = 15,
    shuffle: bool = True,
    root: str = "",
) -> "DataLoader[tuple[torch.Tensor, int]]":
    """Return a DataLoader over the MNIST test split.

    Every sample is converted with ``transforms.ToTensor()``, so the loader
    yields ``(images, labels)`` batches of tensors rather than PIL images.
    The dataset is downloaded into ``root`` when it is not already there.

    Args:
        batch_size: Number of samples per batch (default 15).
        shuffle: Whether to shuffle the samples (default True, so repeated
            runs look at different digits).
        root: Directory where MNIST is stored/downloaded; "" means the
            current working directory.

    Returns:
        A ``DataLoader`` over the MNIST test set.
    """
    data_set = MNIST(
        root,
        train=False,  # test split
        transform=transforms.ToTensor(),
        download=True,
    )
    return DataLoader(data_set, batch_size=batch_size, shuffle=shuffle)
|
|
|
|
def main():
    """Predict digits for a few MNIST test images and save them as PNG files.

    Loads the ``Net`` weights from ``model.pth``. It then takes the first image
    of each of the first 16 batches from ``get_test_loader()`` and predicts its
    digit. Each image is written to ``.result/<time_ns>_predict-<digit>.png``.
    """
    path = '.result'
    # exist_ok replaces the former bare ``except: pass``, which also hid
    # unrelated failures (e.g. permission errors) behind a later crash.
    os.makedirs(path, exist_ok=True)

    net = Net()
    model = torch.load('model.pth', weights_only=True)
    net.load_state_dict(model)
    # Assumes Net is a torch.nn.Module (it exposes load_state_dict): switch
    # to inference mode so dropout / batch-norm layers, if any, behave
    # deterministically.
    net.eval()

    test_data = get_test_loader()

    # No gradients are needed for prediction; this also lets .numpy() be
    # called on the image below without detaching.
    with torch.no_grad():
        for n, (x, _) in enumerate(test_data):
            if n > 15:
                break

            image = x[0]
            # Call the module rather than .forward() so registered hooks run.
            predict = int(torch.argmax(net(image.view(-1, 28 * 28))))

            # int(...) above keeps the filename "..._predict-7.png" instead of
            # the previous "..._predict-tensor(7).png".
            file_name = f'{time.time_ns()}_predict-{predict}.png'
            # imsave writes the raw 28x28 array; it ignores figures and titles,
            # so the former plt.figure()/plt.title() calls were dropped.
            plt.imsave(os.path.join(path, file_name), image.view(28, 28).numpy())
|
|
|
|
if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main()