from matplotlib import pyplot as plt
from PIL import Image, ImageDraw
import torch
import torchvision as tv

# Small fully-convolutional segmentation network: 1 input channel ->
# 2 output channels. Every Conv2d uses a 3x3 kernel with padding 1, so the
# output keeps the input's height and width. The display loop takes argmax
# over the 2 output channels, i.e. they are treated as per-pixel class scores
# (presumably background vs. disc -- confirm against the training script).
#
# The layer sequence must match the one used to write "segment.pt" exactly:
# load_state_dict is strict by default and the state-dict keys are the
# Sequential indices ("0.weight", "1.weight", ...).
model = torch.nn.Sequential(
    torch.nn.BatchNorm2d(1),
    torch.nn.Conv2d(1, 16, 3, padding = 1),
    torch.nn.ReLU(),
    torch.nn.BatchNorm2d(16),
    torch.nn.Conv2d(16, 16, 3, padding = 1),
    torch.nn.ReLU(),
    torch.nn.BatchNorm2d(16),
    torch.nn.Conv2d(16, 2, 3, padding = 1))
# map_location="cpu": all inputs in this script are CPU tensors, and this lets
# a checkpoint written from GPU parameters load on a machine without CUDA.
model.load_state_dict(torch.load("segment.pt", map_location = "cpu"))

# Inference mode: BatchNorm layers use their stored running statistics
# instead of the statistics of the current batch.
model.eval()
# Interactive demo: repeatedly synthesize a random test image, run the
# segmentation model on it, and show the input next to the predicted mask.
# There is no exit condition; stop it with Ctrl+C.
while True:
    # Random number of discs in [0, 10) with centres in [0, 64) x [0, 64).
    count = torch.randint(10, ()).item()
    POSITIONS = torch.randint(64, (count, 2))
    # Background: 64x64 uniform noise in [0, 1), converted to a PIL image so
    # the discs can be drawn with ImageDraw.
    DATA = torch.rand(64, 64)
    data = tv.transforms.functional.to_pil_image(DATA)
    draw = ImageDraw.Draw(data)
    for x, y in POSITIONS.tolist():
        # Bounding box of a radius-8 disc filled with a random gray level in
        # [0, 255]. Plain ints (via tolist) rather than 0-d tensors keep PIL's
        # coordinate parsing on the types it documents.
        draw.ellipse([(x - 8, y - 8), (x + 8, y + 8)],
                     fill = torch.randint(256, ()).item())
    DATA = tv.transforms.functional.to_tensor(data)
    # Inference only: do not build an autograd graph for the forward pass.
    with torch.no_grad():
        ACTIVATION = model(DATA[None, :])
    # Per-pixel predicted class index (argmax over the channel dimension).
    VALUE = ACTIVATION.argmax(1)
    plt.figure()
    plt.imshow(DATA[0], cmap = 'gray', vmin = 0, vmax = 1)
    plt.figure()
    # Pin the colour range to the class indices; otherwise a uniform mask
    # (every pixel the same class) is auto-scaled to solid black and
    # all-background is indistinguishable from all-foreground.
    plt.imshow(VALUE[0], cmap = 'gray', vmin = 0, vmax = ACTIVATION.shape[1] - 1)
    plt.show()
    # Wait for Enter before generating the next sample.
    input()
    # Release the figures: with a non-interactive backend plt.show() does not
    # close them, and they would otherwise pile up across iterations.
    plt.close('all')
