import torch
import torchvision as tv

# Split sizes; not referenced by the visible code (presumably the number of
# images in the training and test splits -- confirm before relying on them).
samples0, samples1 = 60000, 10000

# Suffix 0 = training split, suffix 1 = test split. Both are stored under
# ../MNIST and downloaded on demand.
source0 = tv.datasets.MNIST(root="../MNIST", train=True, download=True)
source1 = tv.datasets.MNIST(root="../MNIST", train=False, download=True)

# Convert the raw image tensors to float, insert a size-1 axis at position 1
# (the single input channel the convolution expects) and divide by 255.
DATA0 = source0.data.float().unsqueeze(dim=1).div(255.)
DATA1 = source1.data.float().unsqueeze(dim=1).div(255.)

# Class labels for each split.
TARGET0, TARGET1 = source0.targets, source1.targets

# Small convolutional classifier: 5x5 convolution -> ReLU -> 2x2 max-pool ->
# flatten -> linear layer producing 10 scores per image.
# For 28x28 inputs the convolution yields 24x24 maps and the pooling 12x12,
# which is where the 8 * 12 * 12 input width of the linear layer comes from.
model = torch.nn.Sequential(
    torch.nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2),
    torch.nn.Flatten(),
    torch.nn.Linear(in_features=8 * 12 * 12, out_features=10),
)

# One-shot iterator over the trainable parameters of the network.
variables = model.parameters()

# Adam with its default hyper-parameters. Each of the 100 epochs performs a
# single update computed on the whole training tensor at once, then scores
# the test tensor and prints one line of metrics.
optimizer = torch.optim.Adam(variables)
for epoch in range(100):
    # Training step: forward pass, loss, back-propagation, parameter update.
    optimizer.zero_grad()
    ACTIVATION0 = model(DATA0)
    LOSS0 = torch.nn.functional.cross_entropy(ACTIVATION0, TARGET0)
    LOSS0.backward()
    optimizer.step()

    # Training accuracy, taken from the activations computed before the
    # update above (the predicted class is the index of the largest score).
    VALUE0 = ACTIVATION0.argmax(dim=1)
    ACCURACY0 = (VALUE0 == TARGET0).float().mean()

    # Evaluation on the test split; no gradients are needed here.
    with torch.no_grad():
        ACTIVATION1 = model(DATA1)
        VALUE1 = ACTIVATION1.argmax(dim=1)
        LOSS1 = torch.nn.functional.cross_entropy(ACTIVATION1, TARGET1)
        ACCURACY1 = (VALUE1 == TARGET1).float().mean()

    # Columns: epoch, training loss, training accuracy, test loss, test accuracy.
    report = (epoch, LOSS0, ACCURACY0, LOSS1, ACCURACY1)
    print("%4d %12.3f %4.3f %12.3f %4.3f" % report, flush=True)
