import gym
import torch

# Single CartPole-v1 environment instance, shared by the training and
# evaluation loops below and closed at the end of the script.
env = gym.make("CartPole-v1")

# Q-network: a fully connected net mapping a 4-value state to one score per
# action (2 actions), with two hidden layers of 128 units and ReLU between
# consecutive linear layers.
_layer_sizes = (4, 128, 128, 2)
_layers = []
for _n_in, _n_out in zip(_layer_sizes, _layer_sizes[1:]):
    _layers.append(torch.nn.Linear(_n_in, _n_out))
    _layers.append(torch.nn.ReLU())
# Drop the trailing ReLU so the output layer stays linear (raw action scores).
model0 = torch.nn.Sequential(*_layers[:-1])

# --- Training: online one-step Q-learning with an epsilon-greedy policy ---
# The network is updated after every environment step from the single most
# recent transition (no replay buffer, no separate target network).
epochs = 1000      # number of training episodes
epsilon = 0.3      # probability of picking a random action
gamma = 0.99       # discount applied to the bootstrapped next-state value
step_limit = 500   # an episode ending at this step count is not treated as
                   # terminal (presumably the env's time limit -- confirm)
optimizer = torch.optim.Adam(model0.parameters())
for epoch in range(epochs):
    state0 = torch.tensor(env.reset(), dtype=torch.float)
    step1 = 0      # steps taken in this episode
    gain1 = 0.     # undiscounted sum of rewards in this episode
    done1 = False
    while not done1:
        quality0 = model0(state0)
        # Epsilon-greedy: explore with probability epsilon, otherwise take
        # the action with the highest predicted value.
        if torch.rand(()) < epsilon:
            action0 = torch.randint(2, ())
        else:
            action0 = quality0.argmax()
        state1, reward1, done1, info = env.step(action0.item())
        state1 = torch.tensor(state1, dtype=torch.float)
        step1 += 1
        gain1 += reward1
        # Only bootstrap past the end of the episode when it stopped because
        # the step limit was reached rather than before it.
        terminal1 = done1 and step1 < step_limit
        # The target is a constant for this update, so skip autograd
        # bookkeeping for the next-state forward pass entirely.
        with torch.no_grad():
            quality1 = model0(state1)
            target0 = reward1 + gamma * (not terminal1) * quality1.max()
        loss = torch.nn.functional.mse_loss(quality0[action0], target0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        state0 = state1
    # Per-episode progress line: episode index, episode length, total reward.
    print("%4d %4d %12.3f" % (epoch, step1, gain1), flush=True)
    # Linearly anneal exploration by 1/epochs per episode until it is no
    # longer above 0.1.
    epsilon = epsilon - 1. / epochs if 0.1 < epsilon else epsilon

# Persist the trained weights.
torch.save(model0.state_dict(), "target00.pt")

# --- Evaluation: run the greedy policy (no exploration, no learning) ---
epochs = 1000                 # number of evaluation episodes
stats = torch.empty(epochs)   # episode length for each evaluation episode
# Inference only: nothing is back-propagated here, so disable gradient
# tracking to avoid building an autograd graph on every forward pass.
with torch.no_grad():
    for epoch in range(epochs):
        state0 = torch.tensor(env.reset(), dtype=torch.float)
        step1 = 0
        done1 = False
        while not done1:
            action0 = model0(state0).argmax()
            # Reward and info are not needed for evaluation.
            state1, _, done1, _ = env.step(action0.item())
            state0 = torch.tensor(state1, dtype=torch.float)
            step1 += 1
        stats[epoch] = step1
# Mean and standard deviation of episode length over all evaluation episodes.
print(stats.mean().item(), stats.std().item())

env.close()

#a: 017+-19
