import numpy as np
from sklearn import datasets
import torch

# Split sizes: rows [0, samples0) form partition 0 (iterated by the training
# loop) and samples1 bounds the loop over partition 1 (the remaining rows).
# NOTE(review): assumes the dataset holds samples0 + samples1 rows -- confirm.
samples0 = 1500
samples1 = 297
# Input width and number of output classes; they size the parameter tensors.
features = 64
classes = 10

# Load the scikit-learn digits data and cut inputs and labels at row samples0.
source = datasets.load_digits()
data0, data1 = np.split(source.data, [samples0])
target0, target1 = np.split(source.target, [samples0])

# Convert each partition to tensors: float32 inputs, int64 class labels.
DATA0, DATA1 = (
    torch.tensor(part, dtype=torch.float32) for part in (data0, data1)
)
TARGET0, TARGET1 = (
    torch.tensor(part, dtype=torch.int64) for part in (target0, target1)
)

# Width of the hidden layer.
size1 = 128
# Standard deviation of the random weight initialisation.
init_scale = 0.01
# Private, seeded generator: keeps runs reproducible without touching the
# global torch RNG state.
init_generator = torch.Generator().manual_seed(0)

# Biases start at zero, but the weight matrices must NOT. With both weight
# matrices at zero the hidden softmax is uniform, every row of WEIGHT receives
# the same gradient, and the gradient flowing back through the softmax is
# constant across hidden units -- which the softmax Jacobian maps to exactly
# zero. WEIGHT1/BIAS1 would then never change and the network could only
# learn a constant (input-independent) output. Small random weights break
# that symmetry.
BIAS1 = torch.zeros(1, size1, requires_grad=True)
WEIGHT1 = (
    init_scale * torch.randn(features, size1, generator=init_generator)
).requires_grad_()
BIAS = torch.zeros(1, classes, requires_grad=True)
WEIGHT = (
    init_scale * torch.randn(size1, classes, generator=init_generator)
).requires_grad_()
# Trainable leaf tensors handed to the optimizer.
variables = [BIAS, WEIGHT, BIAS1, WEIGHT1]

def model(DATA):
    """Compute class scores for a batch with the two-layer network.

    The batch is passed through an affine layer (WEIGHT1, BIAS1), a softmax
    taken along dimension 1, and a second affine layer (WEIGHT, BIAS).

    Args:
        DATA: 2-D tensor whose second dimension matches the first dimension
            of WEIGHT1.

    Returns:
        The output of the second affine layer, one row per input row. No
        normalisation is applied to the returned scores.
    """
    hidden_input = BIAS1 + torch.matmul(DATA, WEIGHT1)
    hidden_output = torch.softmax(hidden_input, dim=1)
    return BIAS + torch.matmul(hidden_output, WEIGHT)

# Mini-batch size shared by the training and evaluation passes.
batch = 100
optimizer = torch.optim.SGD(variables, lr=0.1)
for epoch in range(100):
    # ---- Training pass: consecutive mini-batches of DATA0 / TARGET0. ----
    LOSS0 = torch.zeros(())
    ACCURACY0 = torch.zeros(())
    count0 = 0
    for index in range(0, samples0, batch):
        optimizer.zero_grad()
        DATA = DATA0[index : index + batch]
        TARGET = TARGET0[index : index + batch]
        count = TARGET.size(0)
        ACTIVATION = model(DATA)
        # Mean loss over the batch; this is the quantity that is optimised.
        LOSS = torch.nn.functional.cross_entropy(ACTIVATION, TARGET)
        # Accumulate a DETACHED copy, weighted by the batch size so that the
        # epoch figure is a per-sample mean. Adding LOSS itself would turn
        # LOSS0 into a non-leaf tensor that keeps every batch's autograd
        # graph alive until the end of the epoch.
        LOSS0 += LOSS.detach() * count
        VALUE = torch.argmax(ACTIVATION, 1)
        ACCURACY0 += torch.sum(VALUE == TARGET)
        count0 += count
        LOSS.backward()
        optimizer.step()
    LOSS0 /= count0
    ACCURACY0 /= count0

    # ---- Evaluation pass: DATA1 / TARGET1, no gradients recorded. ----
    with torch.no_grad():
        LOSS1 = torch.zeros(())
        ACCURACY1 = torch.zeros(())
        count1 = 0
        for index in range(0, samples1, batch):
            DATA = DATA1[index : index + batch]
            TARGET = TARGET1[index : index + batch]
            ACTIVATION = model(DATA)
            # Summed (not averaged) loss, so dividing by count1 below gives
            # the per-sample mean even when the last batch is short.
            LOSS1 += torch.nn.functional.cross_entropy(
                ACTIVATION, TARGET, reduction="sum"
            )
            VALUE = torch.argmax(ACTIVATION, 1)
            ACCURACY1 += torch.sum(VALUE == TARGET)
            count1 += TARGET.size(0)
        LOSS1 /= count1
        ACCURACY1 /= count1

    # Columns: epoch, partition-0 loss and accuracy, partition-1 loss and
    # accuracy.
    print(
        "%4d %12.3f %4.3f %12.3f %4.3f"
        % (epoch, LOSS0, ACCURACY0, LOSS1, ACCURACY1),
        flush=True,
    )
