import gym
import torch

# Environment the policy interacts with; it is reset and stepped by the
# training loop further down and closed at the end of the script.
env = gym.make("CartPole-v1")

# Policy network: a small MLP whose output is passed through a softmax by the
# training loop to obtain action probabilities.
# Layer widths from input to output; a ReLU is inserted between consecutive
# linear layers. Modules are added positionally (unnamed) so the state_dict
# keys stay "0.*", "2.*", "4.*".
layer_widths = [4, 128, 128, 2]
layers = []
for fan_in, fan_out in zip(layer_widths[:-1], layer_widths[1:]):
    if layers:
        layers.append(torch.nn.ReLU())
    layers.append(torch.nn.Linear(fan_in, fan_out))
model = torch.nn.Sequential(*layers)

NUM_EPOCHS = 16           # number of outer training epochs
EPISODES_PER_EPOCH = 128  # episodes rolled out (and updates taken) per epoch
ELITE_SIZE = 32           # highest-return episodes kept as imitation targets


def _reset_env(environment):
    """Reset *environment* and return only the initial observation.

    Legacy Gym returns the observation alone from ``reset()``; Gym 0.26+
    returns an ``(observation, info)`` pair. Both shapes are accepted.
    """
    result = environment.reset()
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    return result


def _step_env(environment, action):
    """Step *environment* once and return ``(observation, reward, done)``.

    Accepts the legacy 4-tuple ``(obs, reward, done, info)`` as well as the
    newer 5-tuple ``(obs, reward, terminated, truncated, info)``, in which
    case the episode is over when either flag is set.
    """
    result = environment.step(action)
    if len(result) == 5:
        observation, reward, terminated, truncated, _info = result
        return observation, reward, terminated or truncated
    observation, reward, done, _info = result
    return observation, reward, done


def _episode_return(episode):
    """Return the undiscounted sum of rewards over one episode's transitions."""
    return sum(reward for (_state, _action, reward) in episode)


# Training: roll out episodes with the current stochastic policy, keep the
# ELITE_SIZE highest-return episodes seen so far in the epoch, and fit the
# network to reproduce the actions taken in those episodes (cross-entropy
# loss between the network output and the recorded actions).
optimizer = torch.optim.Adam(model.parameters())
try:
    for epoch in range(NUM_EPOCHS):
        batch = []  # elite episodes retained so far in this epoch
        for _ in range(EPISODES_PER_EPOCH):
            # --- roll out one episode by sampling from the current policy ---
            episode = []
            state0 = torch.tensor(_reset_env(env), dtype=torch.float)
            done1 = False
            while not done1:
                # Acting only: no gradient is needed here, so avoid building
                # an autograd graph for every environment step.
                with torch.no_grad():
                    policy0 = torch.nn.functional.softmax(model(state0), -1)
                action0 = torch.multinomial(policy0, 1).item()
                state1, reward1, done1 = _step_env(env, action0)
                episode.append((state0, action0, reward1))
                state0 = torch.tensor(state1, dtype=torch.float)

            # --- keep only the best episodes (stable sort, highest first) ---
            batch.append(episode)
            batch = sorted(batch, key=_episode_return, reverse=True)[:ELITE_SIZE]

            # --- one supervised step towards the elite episodes' actions ---
            # NOTE(review): the update runs after every episode, not once per
            # epoch; this matches the original schedule -- confirm intended.
            transitions = [step for elite in batch for step in elite]
            states = torch.stack([state for (state, _action, _reward) in transitions])
            actions = torch.tensor([action for (_state, action, _reward) in transitions])
            loss = torch.nn.functional.cross_entropy(model(states), actions)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Per-epoch report: epoch index, mean length and mean return of the
        # elite episodes held at the end of the epoch.
        elite_rewards = torch.tensor([reward for (_state, _action, reward) in transitions])
        print("%5d %12.3f %12.3f" % (epoch,
                                     len(transitions) / len(batch),
                                     elite_rewards.sum() / len(batch)), flush=True)

    torch.save(model.state_dict(), "entropy18.pt")
finally:
    # Release the environment even if training or saving fails.
    env.close()
