import torch

# Tabular Q-learning demo on a small deterministic system whose states wrap
# around modulo `states`.
torch.set_printoptions(precision=2, linewidth=120)

states = 10
actions = 5
rewards = 2

# system[next_state, reward_index, state, action] holds the sampling weight of
# landing in `next_state` with `reward_index` after taking `action` in `state`.
system = torch.zeros(states, rewards, states, actions)

# Taking `action` in `state` moves to (action + state) % states with reward
# index 1. The two index tensors broadcast to cover every (state, action) pair.
state_ids = torch.arange(states).unsqueeze(1)    # one row per state
action_ids = torch.arange(actions).unsqueeze(0)  # one column per action
system[(action_ids + state_ids) % states, 1, state_ids, action_ids] = 1.

# State 0 is overridden: every action stays in state 0 with reward index 0.
system[:, :, 0] = 0.
system[0, 0, 0] = 1.

# quality[state, action] is the learned value table, starting at zero.
quality = torch.zeros(states, actions)
for epoch in range(1000):
    # Pick a random (state, action) pair to update.
    state0 = torch.randint(states, ())
    action0 = torch.randint(actions, ())

    # Sample one joint (next state, reward index) outcome; the flattened
    # index encodes next_state * rewards + reward_index.
    weights = system[:, :, state0, action0].reshape(-1)
    index1 = torch.multinomial(weights, 1)
    state1 = index1 // rewards
    reward1 = -(index1 % rewards)  # reward index 1 -> -1, reward index 0 -> 0

    # Target is the reward plus the best value currently known for the next
    # state (no discounting). The update moves the entry all the way to the
    # target, i.e. an implicit step size of 1.
    target0 = reward1 + quality[state1].max()
    current = quality[state0, action0]
    quality[state0, action0] = current - (current - target0)

# Greedy action per state according to the learned table.
action = torch.argmax(quality, dim=1)
print(quality)
print(action)
