import torch

# Size of the toy environment: `states` positions arranged on a ring and
# `actions` possible moves per step (see `step` for the transition rule).
states = 10
actions = 5

# Print tensors compactly: two decimals, wide lines, no scientific notation.
torch.set_printoptions(precision=2, linewidth=120, sci_mode=False)

def step(state0, action0):
    """Advance the environment by one transition.

    The next state is ``(state0 + action0) % states``, i.e. the action is the
    number of positions to move forward on a ring of ``states`` positions.
    Every transition yields a reward of -1, and the episode is finished once
    the next state is 0.

    Returns:
        A ``(state1, reward1, done1)`` tuple: the next state, the reward for
        this transition, and whether the next state is the terminal state 0.
    """
    next_state = (state0 + action0) % states
    return next_state, -1., next_state == 0

# Q-network: takes a vector with one entry per state and produces one value
# per action, through a single 64-unit ReLU hidden layer.
model = torch.nn.Sequential(
    torch.nn.Linear(in_features=states, out_features=64),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=64, out_features=actions),
)

# Q-learning with a neural Q-function: one episode per epoch and one SGD step
# per environment transition.
epochs = 1000
# Exploration rate of the epsilon-greedy policy; starts fully random and is
# annealed linearly down to a floor of 0.1 (see the end of the epoch loop).
epsilon = 1.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(epochs):
    # Start each episode in a random state from 1..states-1; state 0 is
    # excluded because reaching it is what ends an episode.
    state0 = torch.randint(1, states, ())
    done1 = False
    while not done1:
        # One-hot encode the current state and evaluate Q(state0, .).
        data0 = torch.eq(state0, torch.arange(states)).float()
        quality0 = model(data0)
        # Epsilon-greedy: a uniformly random action with probability epsilon,
        # otherwise the action with the highest predicted value.
        action0 = torch.randint(actions, ()) if torch.rand(()) < epsilon else quality0.argmax()
        state1, reward1, done1 = step(state0, action0)
        # The bootstrap target is a constant with respect to the parameters,
        # so the next-state forward pass is done without recording an
        # autograd graph that would only be thrown away.
        with torch.no_grad():
            data1 = torch.eq(state1, torch.arange(states)).float()
            quality1 = model(data1)
            # Undiscounted target: reward plus the best next-state value; the
            # bootstrap term is dropped when the episode has finished.
            target0 = reward1 + (not done1) * quality1.max()
        # Only the value of the action actually taken is regressed.
        loss = torch.nn.functional.mse_loss(quality0[action0], target0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        state0 = state1
        print("%4d %12.3f" % (epoch, loss), flush=True)
    # Linear decay by 1/epochs per episode, clamped so the 0.1 floor holds
    # exactly regardless of rounding in the repeated float subtraction.
    epsilon = max(0.1, epsilon - 1. / epochs)

# Query the trained network for every state in one batch: row i of the
# identity matrix is the one-hot encoding of state i.
data = torch.eye(states)
quality = model(data)
# Greedy action per state: index of the largest value in each row.
action = torch.argmax(quality, dim=1)
print(quality, action, sep="\n")
