import torch

torch.set_printoptions(precision=2, linewidth=120)

# Toy deterministic MDP: a ring of `states` cells where action `a` moves
# forward `a` cells ((state + a) % states; action 0 stays put). Every ordinary
# move costs 1, and state 0 is absorbing and free, so the optimal value of a
# state is minus the number of moves needed to reach state 0.
states = 10
actions = 5
rewards = 2

# system[next_state, reward_index, state, action]
#     = P(next_state, reward_index | state, action)
system = torch.zeros(states, rewards, states, actions)
for action in range(actions):
    for state in range(states):
        # Ordinary move: deterministic successor, reward index 1 (worth -1).
        system[(action + state) % states, 1, state, action] = 1.
# Make state 0 absorbing: remove its ordinary transitions and loop back to
# itself with reward index 0 (worth 0) under every action.
system[:, :, 0, :] = 0.
system[0, 0, 0, :] = 1.

# Reward obtained for each reward index: index r is worth -r (a cost of r).
reward_values = -torch.arange(rewards)
# Discount factor. With 1.0 (undiscounted) values stay finite only because
# state 0 is reachable from every state and is absorbing at zero cost.
discount = 1.
epochs = 100

quality = torch.zeros(states, actions)
for _ in range(epochs):
    # Greedy policy and its state value: V(s) = max_a Q(s, a), with `action`
    # holding the greedy action index per state. Taking the row max directly
    # avoids building a (states x states) matrix only to read its diagonal.
    value, action = quality.max(1)
    # Bellman backup:
    #   Q(s, a) = sum_{s', r} P(s', r | s, a) * (reward_values[r] + discount * V(s'))
    # `target` has shape (rewards, states): target[r, s'] = reward_values[r] + discount * V(s').
    target = reward_values[:, None] + discount * value
    quality = torch.tensordot(system, target, ([0, 1], [1, 0]))
    print(value, flush=True)
