import torch

# Seed the RNG so the random MDP and the random initial policy are reproducible.
torch.manual_seed(0)
torch.set_printoptions(precision=2, linewidth=120)

# Sizes of the finite Markov decision process.
states = 10
actions = 5
rewards = 7

# system[s', r, s, a]: joint probability of moving to state s' with reward
# index r after taking action a in state s.  Dividing by the sum over the
# (s', r) axes turns every (s, a) slice into a probability distribution.
system = torch.rand(states, rewards, states, actions)
system /= system.sum((0, 1))
# State 0 is absorbing: for every action it returns to itself with reward
# index 0 with probability 1.
system[:, :, 0, :] = 0.
system[0, 0, 0, :] = 1.

# policy[a, s]: probability of choosing action a in state s; each column
# (fixed s) is normalised to sum to 1.
policy = torch.rand(actions, states)
policy /= policy.sum(0)

# Loop-invariant column of reward indices, shape (rewards, 1).  It is
# subtracted from the state values below, so reward index r acts as a cost r.
costs = torch.arange(rewards)[:, None]

# Fixed-point sweeps for the action values of the current policy.  No
# discount factor is applied; 100 sweeps is a fixed budget, not a
# convergence test.
quality = torch.zeros(states, actions)
for counter in range(100):
    # value[s] = sum_a policy[a, s] * quality[s, a].  Computed element-wise
    # instead of contracting to a (states, states) matrix and keeping only
    # its diagonal.
    value = (policy.t() * quality).sum(1)
    # quality[s, a] = sum_{s', r} system[s', r, s, a] * (value[s'] - r);
    # `value - costs` broadcasts to shape (rewards, states).
    quality = torch.tensordot(system, value - costs, ([0, 1], [1, 0]))

# Greedy step: pick the highest-valued action in each state and rebuild the
# policy as a one-hot (actions, states) float matrix.
action = quality.argmax(1)
policy = torch.eq(action, torch.arange(actions)[:, None]).float()
