from matplotlib import pyplot as plt
from PIL import Image, ImageDraw
import torch
import torchvision as tv

# Numeric constants stored as 0-dim tensors on the torch namespace so the rest
# of the script can write torch.pi / torch.sqrt2.
torch.pi = torch.atan(torch.tensor(1.)) * 4.
torch.sqrt2 = torch.sqrt(torch.tensor(2.))

# Size parameter of the reference track; position() scales the curve by
# sqrt(2) * radius.
radius = 2.

def position(t):
    """Return points of the reference track for curve parameter ``t``.

    The track is a lemniscate of Bernoulli scaled by ``sqrt(2) * radius``:

        x = scale * cos(t) / (1 + sin(t)^2)
        y = scale * cos(t) * sin(t) / (1 + sin(t)^2)

    Args:
        t: Tensor of parameter values. The script calls this with a 1-D
            tensor of length n.

    Returns:
        The stacked ``[x, y]`` tensor transposed with ``.T``; for a 1-D ``t``
        of length n this has shape ``(n, 2)``.
    """
    sin_t = t.sin()
    cos_t = t.cos()
    denominator = 1. + sin_t.square()
    scale = torch.sqrt2 * radius
    x = scale * cos_t / denominator
    y = scale * cos_t * sin_t / denominator
    return torch.stack([x, y]).T

def angle(t):
    """Return the reference angle ``3 * atan(sin(t))`` for curve parameter ``t``.

    Presumably the heading of the track at ``position(t)`` in the convention
    used by car()/ground(); confirm against the track geometry.

    Args:
        t: Tensor of parameter values.

    Returns:
        Tensor with the same shape as ``t``.
    """
    return torch.atan(torch.sin(t)) * 3.

# Sample the reference track at n parameter values covering [0, 2*pi].
# linspace includes both end points, so the first and last samples coincide
# on the closed curve.
n = 100
t = torch.linspace(0., 2. * torch.pi, steps=n)

# r[i] is the 2-D track point and a[i] the reference angle for t[i].
r, a = position(t), angle(t)

def car(r, r0, a0):
    """Transform ground-frame point(s) into the frame of a car at pose (r0, a0).

    Computes ``(r - r0) @ [[cos a0, -sin a0], [sin a0, cos a0]]``; ground()
    applies the inverse mapping.

    Args:
        r: Tensor of ground-frame coordinates, shape ``(2,)`` or ``(m, 2)``.
        r0: Car position in the ground frame, shape ``(2,)``.
        a0: Car heading as a one-element tensor.

    Returns:
        Tensor of car-frame coordinates with the same shape as ``r``.
    """
    cos_a = a0.cos()
    sin_a = a0.sin()
    # torch.tensor copies the scalar values into a fresh 2x2 matrix, so the
    # result is not connected to a0 in the autograd graph.
    rotation = torch.tensor([[cos_a, -sin_a],
                             [sin_a, cos_a]])
    return torch.matmul(r - r0, rotation)

def ground(p, r0, a0):
    """Transform car-frame point(s) back into the ground frame.

    Computes ``r0 + p @ [[cos a0, sin a0], [-sin a0, cos a0]]``, the inverse
    of car() for the same pose.

    Args:
        p: Tensor of car-frame coordinates, shape ``(2,)`` or ``(m, 2)``.
        r0: Car position in the ground frame, shape ``(2,)``.
        a0: Car heading as a one-element tensor.

    Returns:
        Tensor of ground-frame coordinates with the same shape as ``p``.
    """
    cos_a = a0.cos()
    sin_a = a0.sin()
    # Built from copied scalar values; carries no autograd link to a0.
    rotation = torch.tensor([[cos_a, sin_a],
                             [-sin_a, cos_a]])
    return r0 + torch.matmul(p, rotation)

def scan(r0, a0):
    """Render the track as seen from a car at pose (r0, a0).

    The sampled track ``r`` is transformed into the car frame, drawn as a
    thick polyline on a 64x64 grayscale canvas, and downsampled to 16x16.
    Car-frame coordinates in [-1, 1] map onto the canvas, with the second
    coordinate flipped so that it points up in the image.

    Args:
        r0: Car position in the ground frame, shape ``(2,)``.
        a0: Car heading as a one-element tensor.

    Returns:
        The image converted by torchvision's ``to_tensor`` (for an 'L' image
        this is a float tensor of shape ``(1, 16, 16)``; the training code
        flattens it to 256 values).
    """
    p = car(r, r0, a0)
    # Car frame -> pixel coordinates: centre at (32, 32), 32 px per unit,
    # y axis inverted because image rows grow downwards.
    pixels = 32. + torch.tensor([32, -32]) * p
    # Hand Pillow plain Python floats rather than 0-dim tensors; Pillow does
    # its own arithmetic on the points when joint='curve'.
    points = [tuple(xy) for xy in pixels.tolist()]
    image = Image.new('L', (64, 64))
    draw = ImageDraw.Draw(image)
    draw.line(points, fill=255, width=8, joint='curve')
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
    # and is available in both old and new Pillow releases.
    image.thumbnail((16, 16), Image.LANCZOS)
    return tv.transforms.functional.to_tensor(image)

def go(r0, a0, k):
    """Advance the car pose by one fixed step with curvature command ``k``.

    In the car frame the displacement is ``[s*s*k/2, s]`` with step length
    ``s = 0.5`` (index 0 lateral, index 1 forward), and the heading decreases
    by ``s*k``. The lateral term matches the second-order expansion of an arc
    of curvature ``k``.

    Args:
        r0: Current position in the ground frame, shape ``(2,)``.
        a0: Current heading as a one-element tensor.
        k: Curvature command as a one-element tensor.

    Returns:
        Tuple ``(new_position, new_heading)``.
    """
    step = torch.tensor(0.5)
    lateral = step * step * k / 2
    # torch.tensor copies the two scalar values into a new (2,) tensor.
    displacement = torch.tensor([lateral, step])
    turn = step * k
    new_position = ground(displacement, r0, a0)
    new_heading = a0 - turn
    return new_position, new_heading

# Cursor into the sampled track `r`. kappa() advances it (wrapping modulo n)
# to select the current look-ahead point, so it persists across calls.
index = 0

def kappa(r0, a0):
    """Return the reference steering curvature for a car at pose (r0, a0).

    Advances the module-level cursor ``index`` along the sampled track ``r``
    (wrapping modulo ``n``) until the selected point is at least 1.0 away
    from ``r0``, then returns ``2 * x / y**2`` where ``(x, y)`` is that point
    in the car frame. This is the curvature for which go() would produce
    lateral offset ``x`` at forward distance ``y``.

    Side effects:
        Mutates the global ``index``.

    Args:
        r0: Car position in the ground frame, shape ``(2,)``.
        a0: Car heading as a one-element tensor.

    Returns:
        0-dim tensor holding the curvature. Not finite when the look-ahead
        point has zero forward coordinate.
    """
    global index
    lookahead = 1.
    # Skip track samples that are closer than the look-ahead distance.
    while torch.norm(r[index] - r0) < lookahead:
        index = (index + 1) % n
    target = car(r[index], r0, a0)
    lateral = target[0]
    forward = target[1]
    return 2. * lateral / forward / forward

# Fully connected regressor 256 -> 128 -> 64 -> 32 -> 1 with a tanh after each
# hidden layer. Layers are instantiated front to back, so parameter
# initialisation consumes the RNG in the same order as writing them out by hand.
model = torch.nn.Sequential(
    *(layer
      for fan_in, fan_out in ((256, 128), (128, 64), (64, 32))
      for layer in (torch.nn.Linear(fan_in, fan_out), torch.nn.Tanh())),
    torch.nn.Linear(32, 1))

# Initial pose of the simulated car: the track sample at the cursor
# (index is still 0 at this point) and its reference angle.
r0 = r[index]
a0 = a[index]

# Training. Every epoch generates a fresh batch by driving the car along the
# track with the reference controller kappa(), then performs one Adam step.
batch = 1024
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(128):
    optimizer.zero_grad()
    # One flattened scan (256 values) per row, and the matching kappa() target.
    DATA = torch.empty(batch, 256)
    TARGET = torch.empty(batch)
    for counter in range(batch):
        # Random pose perturbation, each component drawn from [-0.1, 0.1),
        # so the model also sees scans taken slightly off the driven path.
        dr = 0.2 * torch.rand(2) - 0.1
        da = 0.2 * torch.rand(()) - 0.1
        DATA[counter] = scan(r0 + dr, a0 + da).flatten()
        TARGET[counter] = kappa(r0 + dr, a0 + da)
        # Advance the unperturbed pose with the reference curvature. Both
        # kappa() calls may move the global track cursor `index`, so their
        # order matters.
        r0, a0 = go(r0, a0, kappa(r0, a0))
    VALUE = model(DATA)
    # VALUE has shape (batch, 1) from the final Linear(32, 1); unsqueeze the
    # targets to the same shape so mse_loss does not broadcast.
    LOSS = torch.nn.functional.mse_loss(VALUE, TARGET.unsqueeze(1))
    LOSS.backward()
    optimizer.step()
    print('%4d %12.3f' % (epoch, LOSS), flush = True)

# Closed-loop rollout: let the trained model steer the car from its current
# pose (r0, a0) and record the visited positions.
steps = 28
rr = torch.empty(steps, 2)

# Inference only, so autograd is disabled. Otherwise the heading a0 would
# accumulate a computation graph that grows with every step.
with torch.no_grad():
    for counter in range(steps):
        rr[counter] = r0
        # The model returns a one-element tensor of shape (1,); reduce it to
        # 0-dim so the heading a0 keeps its scalar shape.
        k = model(scan(r0, a0).flatten()).squeeze()
        r0, a0 = go(r0, a0, k)

# Reference track as a line, driven positions as markers.
plt.plot(r[:, 0], r[:, 1])
plt.plot(rr[:, 0], rr[:, 1], 'o')
plt.show()
