from matplotlib import pyplot as plt
import torch

# Coefficients of the ground-truth polynomial, lowest power first:
# y = 1 + 0.5*x - 0.2*x**2 + 0.1*x**3.
TRUE_PARAM = torch.tensor([1., 0.5, -0.2, 0.1])


def design_matrix(x, num_coeffs):
    """Return the polynomial design matrix for the 1-D tensor *x*.

    Column ``k`` holds ``x**k`` for ``k = 0 .. num_coeffs - 1``, so
    ``design_matrix(x, len(c)) @ c`` evaluates the polynomial with
    coefficients ``c`` (lowest power first) at every point of *x*.
    The result has shape ``(len(x), num_coeffs)``.
    """
    return torch.stack([x**power for power in range(num_coeffs)], dim=1)


def sample_noisy(coeffs, num_samples, *, low=-10., high=10., noise=10.,
                 generator=None):
    """Draw noisy samples of the polynomial with coefficients *coeffs*.

    The abscissae are uniform on ``[low, high)`` and each ordinate gets
    additive noise uniform on ``[-noise, noise)``. Pass a seeded
    ``torch.Generator`` to make the draw reproducible.

    Returns:
        ``(x, y)``, two 1-D tensors of length *num_samples*.
    """
    x = (high - low) * torch.rand(num_samples, generator=generator) + low
    clean = design_matrix(x, len(coeffs)) @ coeffs
    jitter = noise * (2. * torch.rand(num_samples, generator=generator) - 1.)
    return x, clean + jitter


def least_squares_fit(x, y, num_coeffs):
    """Return the exact least-squares polynomial coefficients for (x, y).

    Serves as a reference for how close the gradient-descent fit gets.
    """
    design = design_matrix(x, num_coeffs)
    # lstsq expects a matrix right-hand side, so add and drop a unit column.
    return torch.linalg.lstsq(design, y.unsqueeze(-1)).solution.squeeze(-1)


def fit_polynomial(x, y, num_coeffs, *, lr=1e-7, epochs=1000, log_every=100):
    """Fit polynomial coefficients to (x, y) by full-batch SGD on the SSE.

    Starts from all-zero coefficients and minimises
    ``sum((design_matrix(x, num_coeffs) @ c - y)**2)``.

    Args:
        lr: SGD step size.
        epochs: number of gradient steps.
        log_every: print ``epoch loss`` every this many epochs (and on the
            last one); ``0`` disables logging.

    Returns:
        The fitted coefficients as a detached 1-D tensor, lowest power first.

    NOTE: with raw powers of x on [-10, 10] the design-matrix columns span
    very different scales (x**3 reaches 1000 while x**0 is 1). The loss is
    then badly conditioned: the step size must be tiny to keep the x**3
    direction stable, and at that step size the low-order coefficients move
    very little. Compare against ``least_squares_fit``.
    """
    design = design_matrix(x, num_coeffs)
    coeffs = torch.zeros(num_coeffs, requires_grad=True)
    optimizer = torch.optim.SGD([coeffs], lr=lr)
    for epoch in range(epochs):
        optimizer.zero_grad()
        loss = torch.sum((design @ coeffs - y) ** 2)
        loss.backward()
        optimizer.step()
        if log_every and (epoch % log_every == 0 or epoch == epochs - 1):
            print(epoch, loss.item())
    return coeffs.detach()


def main(seed=0, num_samples=10):
    """Plot the true curve, noisy samples, and the SGD and exact fits."""
    generator = torch.Generator().manual_seed(seed)
    num_coeffs = len(TRUE_PARAM)

    grid = torch.linspace(-10., 10., 101)
    grid_design = design_matrix(grid, num_coeffs)
    plt.plot(grid, grid_design @ TRUE_PARAM, label="true")

    x, y = sample_noisy(TRUE_PARAM, num_samples, generator=generator)
    plt.plot(x, y, 'o', label="samples")

    fitted = fit_polynomial(x, y, num_coeffs)
    plt.plot(grid, grid_design @ fitted, label="SGD fit")

    exact = least_squares_fit(x, y, num_coeffs)
    plt.plot(grid, grid_design @ exact, '--', label="least-squares fit")
    print("SGD coefficients:", fitted.tolist())
    print("least-squares coefficients:", exact.tolist())

    plt.legend()
    plt.show()


if __name__ == "__main__":
    main()
