import gym
import torch

# CartPole-v1 environment; the same instance is used by both the training
# loop and the evaluation loop below.
env = gym.make("CartPole-v1")

# Policy network: a fully connected net taking a 4-feature input and
# producing 2 outputs, with two hidden ReLU layers of width 128. The outputs
# are treated as unnormalised action preferences (a softmax is applied by the
# code that calls the model).
model = torch.nn.Sequential(
    torch.nn.Linear(in_features=4, out_features=128),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=128, out_features=128),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=128, out_features=2),
)

# Training.
#
# Each epoch rolls out one episode by sampling actions from the softmax of
# the model's outputs. The episode is used for a gradient step only when it
# is longer than `length`, a running average of the lengths of previously
# accepted episodes. The update is a cross-entropy step that makes the
# actions taken in that episode more likely.
#
# NOTE(review): this relies on the classic Gym API, where `env.reset()`
# returns the observation and `env.step()` returns a 4-tuple.
length = 0.
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(2048):
    episode = []
    state0 = torch.tensor(env.reset(), dtype=torch.float)
    done1 = False
    while not done1:
        # Action selection needs no gradients: the loss below recomputes the
        # forward pass on the stacked states, so building an autograd graph
        # here would only waste time and memory.
        with torch.no_grad():
            preference0 = model(state0)
            policy0 = torch.nn.functional.softmax(preference0, -1)
            action0 = torch.multinomial(policy0, 1).item()
        state1, reward1, done1, info = env.step(action0)
        state1 = torch.tensor(state1, dtype=torch.float)
        episode.append((state0, action0, reward1))
        state0 = state1
    # Rewards are only summed for the progress report below.
    REWARD1 = torch.tensor([reward for (_, _, reward) in episode])
    if length < len(episode):
        # Build the training batch only for episodes that are actually used.
        STATE0 = torch.stack([state for (state, _, _) in episode])
        ACTION0 = torch.tensor([action for (_, action, _) in episode])
        PREFERENCE0 = model(STATE0)
        loss = torch.nn.functional.cross_entropy(PREFERENCE0, ACTION0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Exponential moving average (factor 1/16), updated only when an
        # episode is accepted, so the acceptance bar rises gradually.
        length += (len(episode) - length) / 16.
    # Columns: epoch, running-average length (truncated by %d), episode return.
    print("%4d %4d %12.3f" % (epoch, length, REWARD1.sum()), flush=True)

# Persist the trained weights (the model's state_dict only; the optimizer
# state is not saved).
torch.save(model.state_dict(), "cross21.pt")

# Evaluation.
#
# Run `epochs` episodes with the trained model, sampling actions from the
# softmax of its outputs (the same stochastic action selection as in
# training), and record the number of steps each episode lasts. Finally
# print the mean and standard deviation of those episode lengths.
#
# NOTE(review): this relies on the classic Gym API, where `env.reset()`
# returns the observation and `env.step()` returns a 4-tuple.
epochs = 1000
stats = torch.empty(epochs)
# No parameter updates happen here, so disable autograd for the whole loop
# to avoid building graphs that are never used.
with torch.no_grad():
    for epoch in range(epochs):
        state0 = torch.tensor(env.reset(), dtype=torch.float)
        step1 = 0
        done1 = False
        while not done1:
            preference0 = model(state0)
            policy0 = torch.nn.functional.softmax(preference0, -1)
            action0 = torch.multinomial(policy0, 1).item()
            # The reward and info values are not needed for the length count.
            state1, _, done1, _ = env.step(action0)
            state1 = torch.tensor(state1, dtype=torch.float)
            step1 += 1
            state0 = state1
        stats[epoch] = step1
print(stats.mean().item(), stats.std().item())

# Close the environment once training and evaluation are finished.
env.close()

# a: 487 +- 52  (recorded result; presumably the mean +- std of evaluation episode length printed above)
