from matplotlib import pyplot as plt
from sklearn import datasets
import torch

# Two independently drawn "concentric circles" datasets: set 0 is used for
# training and set 1 for evaluation. Neither call fixes random_state, so the
# points differ from run to run.
samples0 = 200
samples1 = 200
data0, target0 = datasets.make_circles(n_samples = samples0, noise = 0.25, factor = 0.1)
data1, target1 = datasets.make_circles(n_samples = samples1, noise = 0.25, factor = 0.1)

# Features become float32 tensors; labels become int64 class indices, the
# form torch.nn.functional.cross_entropy expects for its target argument.
DATA0, DATA1 = (torch.tensor(points, dtype = torch.float32) for points in (data0, data1))
TARGET0, TARGET1 = (torch.tensor(labels, dtype = torch.int64) for labels in (target0, target1))

# Classifier: one affine layer mapping the 2 input features to 2 class logits.
# There is no hidden layer, so the argmax decision boundary is a straight line.
model = torch.nn.Sequential(torch.nn.Linear(in_features = 2, out_features = 2))
# Lazy iterator over weight and bias; it is consumed once, by the optimizer.
variables = model.parameters()

# Mini-batch SGD on set 0, evaluated on set 1 after every epoch.
# Each epoch prints: epoch, train loss, train accuracy, eval loss, eval accuracy.
batch = 200
optimizer = torch.optim.SGD(variables, lr = 0.1)
for epoch in range(10000):
    # --- Training pass over set 0 (batches taken in order, no shuffling). ---
    # The reported metrics come from each batch's outputs *before* its
    # optimizer step.
    LOSS0 = torch.zeros(())
    ACCURACY0 = torch.zeros(())
    count0 = 0
    for index in range(0, samples0, batch):
        optimizer.zero_grad()
        DATA = DATA0[index : index + batch]
        TARGET = TARGET0[index : index + batch]
        count = TARGET.size(0)  # the final batch may be shorter than `batch`
        ACTIVATION = model(DATA)
        LOSS = torch.nn.functional.cross_entropy(ACTIVATION, TARGET)
        # The running total is only for reporting. Without detach(), LOSS0
        # would join the autograd graph (requires_grad=True) and keep every
        # batch's graph nodes alive until it is reassigned next epoch.
        # Multiplying by count turns the batch-mean loss back into a sum.
        LOSS0 += LOSS.detach() * count
        VALUE = ACTIVATION.argmax(1)
        ACCURACY0 += torch.eq(VALUE, TARGET).sum()
        count0 += count
        LOSS.backward()
        optimizer.step()
    LOSS0 /= count0
    ACCURACY0 /= count0
    # --- Evaluation pass over set 1: no gradients, no parameter updates. ---
    with torch.no_grad():
        LOSS1 = torch.zeros(())
        ACCURACY1 = torch.zeros(())
        count1 = 0
        for index in range(0, samples1, batch):
            DATA = DATA1[index : index + batch]
            TARGET = TARGET1[index : index + batch]
            ACTIVATION = model(DATA)
            # reduction="sum" gives the per-sample sum directly, matching LOSS0.
            LOSS1 += torch.nn.functional.cross_entropy(ACTIVATION, TARGET, reduction = "sum")
            VALUE = ACTIVATION.argmax(1)
            ACCURACY1 += torch.eq(VALUE, TARGET).sum()
            count1 += TARGET.size(0)
        LOSS1 /= count1
        ACCURACY1 /= count1
    print("%4d %12.3f %4.3f %12.3f %4.3f"
          % (epoch, LOSS0, ACCURACY0, LOSS1, ACCURACY1), flush = True)

# One 5x5-inch scatter figure per dataset (training set 0, then evaluation
# set 1), with points coloured by class label. Both figures appear on show().
for points, labels in ((data0, target0), (data1, target1)):
    plt.figure(figsize = (5, 5))
    plt.xlim(-2.0, 2.0)
    plt.ylim(-2.0, 2.0)
    plt.scatter(points[:, 0], points[:, 1], marker = '.', c = labels)

plt.show()
