from PIL import Image, ImageDraw
import torchvision as tv
import torch

# Tiny CNN classifier over single-channel 8x8 images (the size built in the
# training loop). The 7x7 convolution shrinks the 8x8 input to a 2x2 map.
# The 2x2 max-pool reduces it to 1x1, so Flatten yields 4 features (one per
# conv channel). The final Linear layer maps those to 2 output logits.
_layers = [
    torch.nn.Conv2d(1, 4, 7),       # spatial 8x8 -> 2x2
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),          # spatial 2x2 -> 1x1
    torch.nn.Flatten(),
    torch.nn.Linear(4 * 1 * 1, 2),  # 4 channels * 1 * 1 spatial -> 2 logits
]
model = torch.nn.Sequential(*_layers)

# Train `model` for 100 epochs. Each epoch draws one freshly generated
# synthetic batch of `batch` images and takes a single Adam step on it.
batch = 100
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(100):
    optimizer.zero_grad()
    # Build the batch. Each 8x8 grayscale image is either blank (label 0)
    # or contains one filled white square covering pixels 1..7 (label 1).
    DATA = torch.empty(batch, 1, 8, 8)
    TARGET = torch.empty(batch, dtype = torch.int64)
    for sample in range(batch):
        count = torch.randint(2, ()).item()  # 0 or 1
        image = Image.new('L', (8, 8))
        draw = ImageDraw.Draw(image)
        if count:
            # Redrawing the identical rectangle would not change any pixel,
            # so a single draw is all that is ever needed.
            draw.rectangle([(1, 1), (7, 7)], fill = 255)
        DATA[sample] = tv.transforms.functional.to_tensor(image)
        TARGET[sample] = count
    ACTIVATION = model(DATA)
    LOSS = torch.nn.functional.cross_entropy(ACTIVATION, TARGET)
    # Accuracy on this training batch, measured with the weights as they
    # were before this epoch's optimizer step.
    VALUE = ACTIVATION.argmax(1)
    ACCURACY = torch.eq(VALUE, TARGET).float().mean()
    LOSS.backward()
    optimizer.step()
    # Convert the 0-d tensors to Python floats explicitly before formatting.
    print("%4d %12.3f %4.3f" % (epoch, LOSS.item(), ACCURACY.item()), flush = True)
