ML_Sample/test.py

38 lines
1.0 KiB
Python
Raw Normal View History

2024-09-03 11:29:23 +08:00
from matplotlib import pyplot as plt
import torch
import os
import time
from PIL.Image import Image
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from main import Net
def get_test_loader(batch_size: int = 15, shuffle: bool = True) -> DataLoader:
    """Build a DataLoader over the MNIST test split.

    The dataset is stored under the current working directory (``root=""``)
    and downloaded on first use (``download=True``).

    Args:
        batch_size: Number of samples per batch. Defaults to 15.
        shuffle: Whether to shuffle sample order each time the loader is
            iterated. Defaults to True.

    Returns:
        A DataLoader whose batches pair image tensors with their targets.
        Images are converted by ``transforms.ToTensor()``, so they are
        tensors rather than PIL images.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    # train=False selects the test split instead of the training split.
    data_set = MNIST(root="", train=False, transform=to_tensor, download=True)
    return DataLoader(data_set, batch_size=batch_size, shuffle=shuffle)
def main():
    """Predict digits for a few MNIST test images and save them as PNG files.

    Loads ``Net`` weights from ``model.pth``, takes the first image of each of
    the first 16 batches from ``get_test_loader()``, runs it through the
    network, and writes the image to
    ``.result/<time_ns>_predict-<digit>.png``.
    """
    from itertools import islice

    path = '.result'
    # exist_ok=True tolerates an existing directory. Unlike a bare
    # `except: pass`, other failures (e.g. permissions) still surface.
    os.makedirs(path, exist_ok=True)

    net = Net()
    # The network is never moved off the CPU, so map the checkpoint there too.
    # This lets weights saved on a GPU load on a CPU-only machine.
    model = torch.load('model.pth', map_location='cpu', weights_only=True)
    net.load_state_dict(model)
    # Inference mode: switches layers such as dropout/batch-norm (if Net has
    # any) to their evaluation behaviour.
    net.eval()

    num_images = 16  # batches 0..15, the same count as the old `n > 15` cutoff
    test_data = get_test_loader()
    with torch.no_grad():  # no autograd graph is needed for prediction
        for x, _ in islice(test_data, num_images):
            image = x[0]
            # The network is fed one flattened 28*28 vector. Call the module
            # itself rather than .forward() so registered hooks still run.
            predict = int(torch.argmax(net(image.view(-1, 28 * 28))))
            # Use the integer class in the name. str() of the raw tensor would
            # yield "..._predict-tensor(7).png".
            file_name = f'{time.time_ns()}_predict-{predict}.png'
            # imsave writes the pixel array directly (no figure or title is
            # involved); the prediction is recorded in the filename.
            plt.imsave(os.path.join(path, file_name), image.view(28, 28).numpy())
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()