From b6c8a5a3d4938264e307061d520729ec028d5f5c Mon Sep 17 00:00:00 2001 From: 05412 <2738076308@qq.com> Date: Tue, 3 Sep 2024 11:29:23 +0800 Subject: [PATCH] try save model --- main.py | 24 +++++++++++------------- model.pth | Bin 0 -> 121064 bytes test.py | 40 ++++++++++++++++++++++++++++++++++++++-- 3 files changed, 49 insertions(+), 15 deletions(-) create mode 100644 model.pth diff --git a/main.py b/main.py index 434fcf2..5825c88 100644 --- a/main.py +++ b/main.py @@ -13,9 +13,9 @@ class Net(torch.nn.Module): self.fc3: torch.nn.Linear = torch.nn.Linear(64, 64) self.fc4: torch.nn.Linear = torch.nn.Linear(64, 10) def forward(self, x: torch.Tensor): - x = torch.nn.functional.relu(self.fc1(x)) - x = torch.nn.functional.relu(self.fc2(x)) - x = torch.nn.functional.relu(self.fc3(x)) + x = torch.nn.functional.rrelu(self.fc1(x)) + x = torch.nn.functional.rrelu(self.fc2(x)) + x = torch.nn.functional.rrelu(self.fc3(x)) x = torch.nn.functional.log_softmax(self.fc4(x), dim=1) return x @@ -39,13 +39,16 @@ def evaluate(test_data: DataLoader[Image], net: Net) -> float: def main(): print("=== START ===") torch.set_default_device(torch.device(type='cpu', index=0)) + torch.set_default_dtype(torch.float16) train_data = get_data_loader(is_train=True) + print('train data size:',len(train_data)) test_data = get_data_loader(is_train=False) + print('test data size: ', len(test_data)) net = Net() print('initial accuracy:', evaluate(test_data, net)) optimizer = torch.optim.Adam(net.parameters(), lr=0.001) # type: ignore - for epoch in range(1, 7, 1): + for epoch in range(1, 3, 1): for (x, y) in train_data: net.zero_grad() output = net.forward(x.view(-1, 28 * 28)) @@ -54,14 +57,9 @@ def main(): optimizer.step() print('epoch', epoch, 'accuracy', evaluate(test_data, net)) - # test - for (n, (x, _)) in enumerate(test_data): - if(n > 15): - break - predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28))) - plt.figure(n) - plt.title('prediction: ' + str(int(predict))) - plt.imsave(str(n) + '_predict-' + str(predict) + '.png', x[0].view(28, 28)) + model_name = 'model.pth' + print('saving model to: ', model_name) + torch.save(net.state_dict(), model_name) print('=== DONE ===') if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/model.pth b/model.pth new file mode 100644 index 0000000000000000000000000000000000000000..b71e826bc86701f071567bdfd62b28c84f6708bb GIT binary patch literal 121064 zcmeI)&vO%H9LMn|O=-hYpkn=jBI56YG)=PkK@qe@jOdDlFpimCI%%?yNFaQZZ5`Ck zaB0V1+`Qt+tDHS~QLi}8DC3Mr{sqsD&y(HOY)TpEr5$*0=Sim9e810UU)iUpXSV1> z2bGG&)S6?L8dDcrHMcogt9X_8d%VK?5H(7RgGE$PAWplA6&|IU^kzI8?TYa$Zw@b?I`mi(Y$!NFBl04!>mSm4# z|EMQpC3!5kpw!9*Js#?ik|*jq$fzezc3U{p%d*#*HZh*^+dS>bzLGo>#7Oxu_J^9- z@~j`@KoH|O6XSWm{sm87EXhkj40|QU%b^Zwc_oN3z8d48ALEeU=2cG)mt-P{Vf!)S zp(Yub^kXD~7)cW&<=5Muq)Uj+Nwi z5F_o!csj~i-O=}|I8LO~y>;}hAnz~Asd~?ulzgx#r+d%XQe2cXy=T%=T9mWh zGxMFfc+nY-9bK?3OW%%W6xUr@yVX>%vgWRMZp$p0LO!!*Y4xqNJOA2jcwW;rb1xmg zy?01osTu3fJ))+2R`q0}kUzGuiGw)mUw#lk009ILKmY**5I_I{1Q0*~0R#|0009IL zKmY**5I_I{1Q0*~0R#|0009ILKmY**5I_I{1Q0*~0R#|0009ILKmY**5I_I{1Q0*~ z0R#|0009ILKmY**5I_I{1Q0*~0R#|0009ILKmY**5I_I{1Q0*~0R#|0009ILKmY** z5I_I{1Q0*~0R#|0009ILKmY**5I_I{1Q0*~0R#|0009ILKmY**5I_I{1Q0*~0R#|0 z009ILKmY**5I_I{1Q0*~0R#|0009ILKmY**5I_I{1Q0*~0R#|0009ILKmY**5I_I{ z1Q0*~0R#|0009ILKmY**5I_I{1Q0*~0R#|0009ILKmY**5I_I{1Q0*~0R#|0009IL zKmY**5I_I{1Q0*~0R#|0009ILKmY**5I_I{1Q0*~0R#|0009ILKmY**5I_I{1Q0*~ z0R#|0009ILKmY**5I_I{1Q0*~0R#|0009ILKmY**5I_I{1Q0*~0R#|0009ILKmY** z5I_I{1Q0*~0R#|0009ILKmY**5I_I{1Q0*~0R#|0009ILKmY**5I_I{1Q0*~0R#|0 z009ILKmY**5I_I{1Q0*~0R#|0009ILKmY**5I_I{1Q0*~0R#|00D*^7py&+8e*X2J zxcVZZ%qTk1L8W3bwdNR6L+WCy<~Ap56|XXxEaV5&%*0R#|0009ILKmY**5I_I{1Q0*~0R#|0009ILKmY**5O@#-%>IGePv2fw z2lVCvJ??q`fW0;L4{Uh`%+}aHu;Kj=*{zX(v%i1lM*x9`SHS%L%-{O^E45AUf79ch z_rK-V|Nm)v9vVBK#&kP?Y%?#^JeT0)y?Q`e(l^9&uz79Zo9v@_u)??q2gPs#a%6Z3E{}Sk6j8^7Qcyv z9u?JJ*}LH~Tse9zRGI3pyvv+FgevX+%Dc>YBUG90ue{5gH$#<~{>t8k)pwo0g(|cC zmA%U+Tv_}hRGI6q>|F{wL*M83tChiG>o4uSKlg=7_x!ms>77?j(HR;#98~O&POR3h oL|=Ko*}dnL`cd~Yhprno@5_Jrznj;hv(4PdYDE8-&- DataLoader[Image]: + to_tensor = transforms.Compose([transforms.ToTensor()]) + data_set = MNIST("", False, transform=to_tensor, download=True) + return DataLoader(data_set, batch_size=15, shuffle=True) + +def main(): + path = '.result' + try: + os.mkdir(path) + except: + pass + net = Net() + model = torch.load('model.pth', weights_only=True) + net.load_state_dict(model) + + test_data = get_test_loader() + for (n, (x, _)) in enumerate(test_data): + if(n > 15): + break + predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28))) + plt.figure(n) + plt.title('prediction: ' + str(int(predict))) + plt.imsave(path + '/' + str(time.time_ns()) + '_predict-' + str(predict) + '.png', x[0].view(28, 28)) + +if __name__ == '__main__': + main() \ No newline at end of file