import gym
import torch

# Environment the loaded network is evaluated on.
env = gym.make("CartPole-v1")

# Feed-forward network: 4 inputs -> two hidden layers of 128 units with ReLU
# -> 2 outputs. The layout must match the checkpoint loaded below, otherwise
# load_state_dict raises on mismatched keys/shapes.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 2),
)
# map_location="cpu" lets a checkpoint that was saved from a CUDA device load
# on a machine without a GPU; the model itself is never moved off the CPU.
# NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load("target.pt", map_location="cpu"))
# The network is only used for inference in this script.
model.eval()

# Play one rendered episode, always taking the action whose model output is
# largest, until the environment reports the episode is done.
# NOTE(review): the unpacking below expects env.reset() to return just the
# observation and env.step() to return a 4-tuple -- presumably the older Gym
# API; confirm against the installed gym version.
done1 = False
try:
    state0 = torch.tensor(env.reset(), dtype=torch.float)
    # Inference only: skip autograd bookkeeping for the forward passes.
    with torch.no_grad():
        while not done1:
            env.render()
            quality0 = model(state0)
            # Greedy choice: index of the highest model output.
            action0 = quality0.argmax()
            state1, reward1, done1, info = env.step(action0.item())
            state1 = torch.tensor(state1, dtype=torch.float)
            state0 = state1
finally:
    # Release the environment (and its render window) even if a step fails.
    env.close()
