from PIL import Image, ImageDraw
import torch
import torchvision as tv

class Model(torch.nn.Module):
    """Small CNN with a shared trunk and two 2-unit linear heads.

    The trunk is: 5x5 convolution (1 -> 32 channels, no padding), ReLU,
    2x2 max-pooling, batch normalisation, flatten, and a 256-unit fully
    connected layer with ReLU. Two independent linear heads then map the
    256 features to a 2-vector each (``line3p`` and ``line3c``).

    Args:
        image_size: Side length, in pixels, of the square single-channel
            input images. Defaults to 64, which yields the 32 * 30 * 30
            flattened features the network was originally written for.

    Raises:
        ValueError: If ``image_size`` is too small to leave at least one
            pixel after the convolution and pooling stages.
    """

    def __init__(self, image_size=64):
        super().__init__()
        # A 5x5 convolution without padding shrinks each side by 4 pixels;
        # the 2x2 max-pool then halves it (rounding down).
        pooled_side = (image_size - 4) // 2
        if pooled_side < 1:
            raise ValueError(
                'image_size must be at least 6, got %r' % (image_size,)
            )
        # Attribute names and creation order are kept stable so state_dict
        # keys and the random initialisation sequence do not change.
        self.conv0 = torch.nn.Conv2d(1, 32, 5)
        self.relu0 = torch.nn.ReLU()
        self.pool0 = torch.nn.MaxPool2d(2)
        self.norm1 = torch.nn.BatchNorm2d(32)
        self.flat1 = torch.nn.Flatten()
        self.line2 = torch.nn.Linear(32 * pooled_side * pooled_side, 256)
        self.relu2 = torch.nn.ReLU()
        self.line3p = torch.nn.Linear(256, 2)
        self.line3c = torch.nn.Linear(256, 2)

    def forward(self, SIGNAL):
        """Run the network on a batch of images.

        Args:
            SIGNAL: Tensor of shape ``(N, 1, image_size, image_size)``.

        Returns:
            A tuple ``(POSITION, ACTIVATION)`` of two ``(N, 2)`` tensors:
            the output of the ``line3p`` head and the output of the
            ``line3c`` head, both computed from the same 256 features.
        """
        # Convolutional feature extractor.
        SIGNAL = self.norm1(self.pool0(self.relu0(self.conv0(SIGNAL))))
        # Shared fully connected representation.
        SIGNAL = self.relu2(self.line2(self.flat1(SIGNAL)))
        # Both heads read the same shared features.
        POSITION = self.line3p(SIGNAL)
        ACTIVATION = self.line3c(SIGNAL)
        return POSITION, ACTIVATION

# Train the two-headed model on synthetic data generated on the fly. Every
# sample is a 64x64 grayscale image containing one filled shape whose
# bounding box extends 4 pixels either side of a random centre: a rectangle
# for class 0, an ellipse for class 1. The model regresses the centre
# (MSE loss) and classifies the shape (cross-entropy loss).
model = Model()

batch = 1024
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(256):
    optimizer.zero_grad()

    # Fresh random targets each epoch: centres uniform in [0, 64), labels 0/1.
    DATA0 = torch.empty(batch, 1, 64, 64)
    POSITION0 = 64. * torch.rand(batch, 2)
    CLASS0 = torch.randint(2, (batch,))

    # Compute the bounding boxes in torch, then convert boxes and labels to
    # plain Python numbers once per batch. PIL receives ordinary floats
    # instead of 0-dim tensors, and the loop no longer indexes tensors.
    CORNERS0 = torch.cat([POSITION0 - 4., POSITION0 + 4.], dim=1).tolist()
    LABELS0 = CLASS0.tolist()
    for sample, ((x0, y0, x1, y1), label) in enumerate(zip(CORNERS0, LABELS0)):
        image = Image.new('L', (64, 64))
        draw = ImageDraw.Draw(image)
        box = [(x0, y0), (x1, y1)]
        if label == 0:
            draw.rectangle(box, fill=255)
        else:
            draw.ellipse(box, fill=255)
        DATA0[sample] = tv.transforms.functional.to_tensor(image)

    POSITION1, ACTIVATION1 = model(DATA0)
    LOSSP = torch.nn.functional.mse_loss(POSITION1, POSITION0)
    LOSSC = torch.nn.functional.cross_entropy(ACTIVATION1, CLASS0)
    LOSS = LOSSP + LOSSC

    # Reporting metrics only; keep them out of the autograd graph.
    with torch.no_grad():
        CLASS1 = ACTIVATION1.argmax(1)
        ACCURACY = torch.eq(CLASS1, CLASS0).float().mean()
        RMSE = LOSSP.sqrt()

    LOSS.backward()
    optimizer.step()
    # Columns: epoch, position RMSE in pixels, classification accuracy.
    print('%4d %12.3f %4.3f' % (epoch, RMSE.item(), ACCURACY.item()), flush = True)
