67 lines
2.5 KiB
Python
67 lines
2.5 KiB
Python
import torch
|
|
from PIL.Image import Image
|
|
from torchvision import transforms
|
|
from torchvision.datasets import MNIST
|
|
from torch.utils.data import DataLoader
|
|
import matplotlib.pyplot as plt
|
|
|
|
class Net(torch.nn.Module):
    """Fully connected classifier with three hidden layers.

    The network is a stack of four ``torch.nn.Linear`` layers. ReLU is
    applied after each of the first three layers and ``log_softmax``
    (over dimension 1) after the last, so the output holds
    log-probabilities suitable for ``torch.nn.functional.nll_loss``.

    Args:
        in_features: Size of each flattened input sample. Defaults to
            ``28 * 28``.
        hidden_features: Width of every hidden layer. Defaults to 64.
        n_classes: Number of output classes. Defaults to 10.
    """

    def __init__(
        self,
        in_features: int = 28 * 28,
        hidden_features: int = 64,
        n_classes: int = 10,
    ) -> None:
        super().__init__()
        # Layers are created in a fixed order (fc1..fc4) so that seeded
        # initialisation and state-dict keys stay stable.
        self.fc1: torch.nn.Linear = torch.nn.Linear(in_features, hidden_features)
        self.fc2: torch.nn.Linear = torch.nn.Linear(hidden_features, hidden_features)
        self.fc3: torch.nn.Linear = torch.nn.Linear(hidden_features, hidden_features)
        self.fc4: torch.nn.Linear = torch.nn.Linear(hidden_features, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class log-probabilities for a batch of flattened inputs.

        Args:
            x: Tensor whose last dimension equals ``in_features``. The
                final ``log_softmax`` normalises over dimension 1, so a
                2-D ``(batch, in_features)`` input is expected.

        Returns:
            Tensor of log-probabilities with ``n_classes`` entries along
            the last dimension.
        """
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.nn.functional.relu(self.fc3(x))
        x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
        return x
|
|
|
|
def get_data_loader(is_train: bool, batch_size: int = 15) -> DataLoader:
    """Build a shuffled ``DataLoader`` over the MNIST dataset.

    The dataset is created with root directory ``""`` and
    ``download=True``, and every image is converted with
    ``transforms.ToTensor()``. The loader therefore yields batches of
    tensors rather than PIL images, which is why the return type is not
    parameterised with ``Image``.

    Args:
        is_train: Passed to ``MNIST`` as ``train``; selects the training
            split when true and the test split otherwise.
        batch_size: Number of samples per batch. Defaults to 15.

    Returns:
        A ``DataLoader`` with ``shuffle=True`` over the selected split.
    """
    to_tensor = transforms.ToTensor()
    data_set = MNIST("", train=is_train, transform=to_tensor, download=True)
    return DataLoader(data_set, batch_size=batch_size, shuffle=True)
|
|
|
|
def evaluate(test_data: DataLoader, net: torch.nn.Module) -> float:
    """Return the fraction of samples in ``test_data`` that ``net`` classifies correctly.

    Each batch ``x`` is reshaped to ``(-1, 28 * 28)`` before being fed to
    the network, and the predicted class is the argmax over dimension 1
    of the network output.

    Args:
        test_data: Iterable of ``(x, y)`` batches, where ``x`` can be
            viewed as ``(-1, 28 * 28)`` and ``y`` holds one integer class
            label per sample.
        net: Model producing one row of class scores per sample.

    Returns:
        ``n_correct / n_total`` as a float, or ``0.0`` when ``test_data``
        yields no samples (instead of raising ``ZeroDivisionError``).
    """
    n_correct: int = 0
    n_total: int = 0
    # Evaluate in eval mode, but restore the caller's mode afterwards so a
    # surrounding training loop is not affected.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for (x, y) in test_data:
                # Call the module (not .forward) so registered hooks run.
                outputs = net(x.view(-1, 28 * 28))
                # Vectorised comparison replaces the per-sample Python loop.
                n_correct += int((outputs.argmax(dim=1) == y).sum())
                n_total += outputs.size(0)
    finally:
        net.train(was_training)
    if n_total == 0:
        return 0.0
    return n_correct / n_total
|
|
|
|
def main():
    """Train ``Net`` on MNIST for six epochs and save sample predictions.

    Steps:
      1. Set the default torch device to CPU and build the train/test
         loaders.
      2. Print the accuracy of the untrained network.
      3. Train with Adam (``lr=0.001``) and negative log-likelihood loss,
         printing test accuracy after every epoch.
      4. For the first 16 test batches, classify the first image of the
         batch and save it as ``<n>_predict-<digit>.png`` in the current
         working directory.
    """
    print("=== START ===")
    torch.set_default_device(torch.device(type='cpu', index=0))
    train_data = get_data_loader(is_train=True)
    test_data = get_data_loader(is_train=False)
    net = Net()

    print('initial accuracy:', evaluate(test_data, net))
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)  # type: ignore
    for epoch in range(1, 7):
        for (x, y) in train_data:
            # The optimizer owns every parameter of ``net``, so this clears
            # the same gradients as ``net.zero_grad()``.
            optimizer.zero_grad()
            output = net(x.view(-1, 28 * 28))
            loss = torch.nn.functional.nll_loss(output, y)
            loss.backward()
            optimizer.step()
        print('epoch', epoch, 'accuracy', evaluate(test_data, net))

    # test
    # No gradients are needed for inference on the sample images.
    with torch.no_grad():
        for (n, (x, _)) in enumerate(test_data):
            if n > 15:
                break
            # Convert to a plain int so the file name reads "..._predict-7.png"
            # rather than "..._predict-tensor(7).png".
            predict = int(torch.argmax(net(x[0].view(-1, 28 * 28))))
            # imsave writes the raw array and ignores the current figure, so
            # no figure/title is created here (the previous ones were never
            # shown, saved, or closed); the prediction lives in the file name.
            plt.imsave(f'{n}_predict-{predict}.png', x[0].view(28, 28).numpy())
    print('=== DONE ===')
|
|
# Run the training script only when executed directly, not on import.
if __name__ == '__main__':
    main()