from matplotlib import pyplot as plt
from PIL import Image, ImageDraw
import torch
import torchvision as tv

# Fully-convolutional network: 1 input channel in, 2 output channels out
# (used downstream as per-pixel class scores). Every convolution is 3x3
# with padding 1 and the default stride, so the spatial size is preserved.
_layers = [
    torch.nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(in_channels=16, out_channels=2, kernel_size=3, padding=1),
]
# Positional Sequential keeps the submodule names '0'..'4' (same state_dict keys).
model = torch.nn.Sequential(*_layers)
del _layers

# Train the model on synthetic images generated on the fly: each image is
# uniform noise with 0-9 random-intensity discs drawn on it, and the target
# mask marks every pixel covered by a disc as class 1 (background is 0).
batch = 1024
optimizer = torch.optim.Adam(model.parameters())


def _render_sample():
    """Draw one random (image, mask) pair.

    Returns the image as a float tensor of shape (1, 64, 64) produced by
    ``to_tensor`` and the mask as an int64 tensor of shape (64, 64) holding
    0 (background) or 1 (disc). Random draws happen in the same order as in
    the original inline code: disc count, positions, noise, per-disc fill.
    """
    count = torch.randint(10, ()).item()
    POSITIONS = torch.randint(64, (count, 2))
    data = tv.transforms.functional.to_pil_image(torch.rand(64, 64))
    target = Image.new('L', (64, 64))
    data_draw = ImageDraw.Draw(data)
    target_draw = ImageDraw.Draw(target)
    for x, y in POSITIONS.tolist():
        # Plain Python ints for PIL; the same bounding box for image and mask.
        box = [(x - 8, y - 8), (x + 8, y + 8)]
        data_draw.ellipse(box, fill = torch.randint(256, ()).item())
        target_draw.ellipse(box, fill = 255)
    DATA = tv.transforms.functional.to_tensor(data)
    # to_tensor maps the 0/255 mask to 0.0/1.0; the cast yields class indices.
    TARGET = tv.transforms.functional.to_tensor(target)[0].type(torch.int64)
    return DATA, TARGET


for epoch in range(1000):
    optimizer.zero_grad()
    DATA0 = torch.empty(batch, 1, 64, 64)
    # Class-index targets, no channel axis: the layout cross_entropy expects.
    TARGET0 = torch.empty(batch, 64, 64, dtype = torch.int64)
    for sample in range(batch):
        DATA0[sample], TARGET0[sample] = _render_sample()
    ACTIVATION0 = model(DATA0)
    # Summed per-pixel cross-entropy. This equals
    # sum(-log_softmax(ACTIVATION0, 1) * one_hot(TARGET0)) without
    # materialising a (batch, 2, 64, 64) int64 one-hot tensor every epoch.
    LOSS0 = torch.nn.functional.cross_entropy(ACTIVATION0, TARGET0, reduction = 'sum')
    VALUE0 = ACTIVATION0.argmax(1)
    ACCURACY0 = torch.eq(VALUE0, TARGET0).float().mean()
    LOSS0.backward()
    optimizer.step()
    print("%4d %12.3f %4.3f" % (epoch, LOSS0.item(), ACCURACY0.item()), flush = True)

# Interactive check: repeatedly render a fresh random image (same recipe as
# the training data, without a mask), show it and the model's per-pixel
# argmax prediction in two figures, and wait for a line on stdin before
# drawing the next one.
while True:
    count = torch.randint(10, ()).item()
    POSITIONS = torch.randint(64, (count, 2))
    data = tv.transforms.functional.to_pil_image(torch.rand(64, 64))
    draw = ImageDraw.Draw(data)
    for x, y in POSITIONS.tolist():
        # Plain Python ints for PIL rather than 0-dim tensors.
        draw.ellipse([(x - 8, y - 8), (x + 8, y + 8)],
                     fill = torch.randint(256, ()).item())
    DATA = tv.transforms.functional.to_tensor(data)
    # Inference only: don't build an autograd graph for this forward pass.
    with torch.no_grad():
        ACTIVATION = model(DATA[None, :])
    VALUE = ACTIVATION.argmax(1)
    plt.figure()
    plt.imshow(DATA[0], cmap = 'gray')
    plt.figure()
    plt.imshow(VALUE[0], cmap = 'gray')
    plt.show()
    input()
    # Release this iteration's figures so they don't accumulate
    # (e.g. on non-blocking or interactive backends).
    plt.close('all')

# Reported accuracy: 0.995 after 5000 epochs (the training loop above runs 1000).
