try save model

This commit is contained in:
05412 2024-09-03 11:29:23 +08:00
parent 07b7ac0a3c
commit b6c8a5a3d4
3 changed files with 49 additions and 15 deletions

24
main.py
View File

@ -13,9 +13,9 @@ class Net(torch.nn.Module):
self.fc3: torch.nn.Linear = torch.nn.Linear(64, 64)
self.fc4: torch.nn.Linear = torch.nn.Linear(64, 10)
def forward(self, x: torch.Tensor):
x = torch.nn.functional.relu(self.fc1(x))
x = torch.nn.functional.relu(self.fc2(x))
x = torch.nn.functional.relu(self.fc3(x))
x = torch.nn.functional.rrelu(self.fc1(x))
x = torch.nn.functional.rrelu(self.fc2(x))
x = torch.nn.functional.rrelu(self.fc3(x))
x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
return x
@ -39,13 +39,16 @@ def evaluate(test_data: DataLoader[Image], net: Net) -> float:
def main():
print("=== START ===")
torch.set_default_device(torch.device(type='cpu', index=0))
torch.set_default_dtype(torch.float16)
train_data = get_data_loader(is_train=True)
print('train data size:',len(train_data))
test_data = get_data_loader(is_train=False)
print('test data size: ', len(test_data))
net = Net()
print('initial accuracy:', evaluate(test_data, net))
optimizer = torch.optim.Adam(net.parameters(), lr=0.001) # type: ignore
for epoch in range(1, 7, 1):
for epoch in range(1, 3, 1):
for (x, y) in train_data:
net.zero_grad()
output = net.forward(x.view(-1, 28 * 28))
@ -54,14 +57,9 @@ def main():
optimizer.step()
print('epoch', epoch, 'accuracy', evaluate(test_data, net))
# test
for (n, (x, _)) in enumerate(test_data):
if(n > 15):
break
predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28)))
plt.figure(n)
plt.title('prediction: ' + str(int(predict)))
plt.imsave(str(n) + '_predict-' + str(predict) + '.png', x[0].view(28, 28))
model_name = 'model.pth'
print('saving model to: ', model_name)
torch.save(net.state_dict(), model_name)
print('=== DONE ===')
if __name__ == '__main__':
main()
main()

BIN
model.pth Normal file

Binary file not shown.

40
test.py
View File

@ -1,2 +1,38 @@
def __main__():
print('hello')
from matplotlib import pyplot as plt
import torch
import os
import time
from PIL.Image import Image
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from main import Net
def get_test_loader() -> DataLoader[Image]:
to_tensor = transforms.Compose([transforms.ToTensor()])
data_set = MNIST("", False, transform=to_tensor, download=True)
return DataLoader(data_set, batch_size=15, shuffle=True)
def main():
path = '.result'
try:
os.mkdir(path)
except:
pass
net = Net()
model = torch.load('model.pth', weights_only=True)
net.load_state_dict(model)
test_data = get_test_loader()
for (n, (x, _)) in enumerate(test_data):
if(n > 15):
break
predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28)))
plt.figure(n)
plt.title('prediction: ' + str(int(predict)))
plt.imsave(path + '/' + str(time.time_ns()) + '_predict-' + str(predict) + '.png', x[0].view(28, 28))
if __name__ == '__main__':
main()