try save model
This commit is contained in:
parent
07b7ac0a3c
commit
b6c8a5a3d4
22
main.py
22
main.py
@ -13,9 +13,9 @@ class Net(torch.nn.Module):
|
||||
self.fc3: torch.nn.Linear = torch.nn.Linear(64, 64)
|
||||
self.fc4: torch.nn.Linear = torch.nn.Linear(64, 10)
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = torch.nn.functional.relu(self.fc1(x))
|
||||
x = torch.nn.functional.relu(self.fc2(x))
|
||||
x = torch.nn.functional.relu(self.fc3(x))
|
||||
x = torch.nn.functional.rrelu(self.fc1(x))
|
||||
x = torch.nn.functional.rrelu(self.fc2(x))
|
||||
x = torch.nn.functional.rrelu(self.fc3(x))
|
||||
x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
|
||||
return x
|
||||
|
||||
@ -39,13 +39,16 @@ def evaluate(test_data: DataLoader[Image], net: Net) -> float:
|
||||
def main():
|
||||
print("=== START ===")
|
||||
torch.set_default_device(torch.device(type='cpu', index=0))
|
||||
torch.set_default_dtype(torch.float16)
|
||||
train_data = get_data_loader(is_train=True)
|
||||
print('train data size:',len(train_data))
|
||||
test_data = get_data_loader(is_train=False)
|
||||
print('test data size: ', len(test_data))
|
||||
net = Net()
|
||||
|
||||
print('initial accuracy:', evaluate(test_data, net))
|
||||
optimizer = torch.optim.Adam(net.parameters(), lr=0.001) # type: ignore
|
||||
for epoch in range(1, 7, 1):
|
||||
for epoch in range(1, 3, 1):
|
||||
for (x, y) in train_data:
|
||||
net.zero_grad()
|
||||
output = net.forward(x.view(-1, 28 * 28))
|
||||
@ -54,14 +57,9 @@ def main():
|
||||
optimizer.step()
|
||||
print('epoch', epoch, 'accuracy', evaluate(test_data, net))
|
||||
|
||||
# test
|
||||
for (n, (x, _)) in enumerate(test_data):
|
||||
if(n > 15):
|
||||
break
|
||||
predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28)))
|
||||
plt.figure(n)
|
||||
plt.title('prediction: ' + str(int(predict)))
|
||||
plt.imsave(str(n) + '_predict-' + str(predict) + '.png', x[0].view(28, 28))
|
||||
model_name = 'model.pth'
|
||||
print('saving model to: ', model_name)
|
||||
torch.save(net.state_dict(), model_name)
|
||||
print('=== DONE ===')
|
||||
if __name__ == '__main__':
|
||||
main()
|
40
test.py
40
test.py
@ -1,2 +1,38 @@
|
||||
def __main__():
|
||||
print('hello')
|
||||
from matplotlib import pyplot as plt
|
||||
import torch
|
||||
import os
|
||||
import time
|
||||
from PIL.Image import Image
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import MNIST
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from main import Net
|
||||
|
||||
|
||||
def get_test_loader() -> DataLoader[Image]:
|
||||
to_tensor = transforms.Compose([transforms.ToTensor()])
|
||||
data_set = MNIST("", False, transform=to_tensor, download=True)
|
||||
return DataLoader(data_set, batch_size=15, shuffle=True)
|
||||
|
||||
def main():
|
||||
path = '.result'
|
||||
try:
|
||||
os.mkdir(path)
|
||||
except:
|
||||
pass
|
||||
net = Net()
|
||||
model = torch.load('model.pth', weights_only=True)
|
||||
net.load_state_dict(model)
|
||||
|
||||
test_data = get_test_loader()
|
||||
for (n, (x, _)) in enumerate(test_data):
|
||||
if(n > 15):
|
||||
break
|
||||
predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28)))
|
||||
plt.figure(n)
|
||||
plt.title('prediction: ' + str(int(predict)))
|
||||
plt.imsave(path + '/' + str(time.time_ns()) + '_predict-' + str(predict) + '.png', x[0].view(28, 28))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue
Block a user