import torch

def reset():
    """Build a fresh random 4x4 board and return the player's position on it.

    The board is a (4, 4, 4) tensor of channels: 0 = player, 1 = wall,
    2 = goal, 3 = pit. The pieces are placed in the order pit, goal, wall,
    player, each on a cell that no earlier piece occupies.

    Returns:
        (row, column, board): the player's row and column (Python ints) and
        the board tensor.
    """
    board = torch.zeros(4, 4, 4)

    def free_cell():
        # Resample uniformly until the cell is empty in every channel.
        while True:
            r = torch.randint(4, ()).item()
            c = torch.randint(4, ()).item()
            if board[:, r, c].sum() == 0.:
                return r, c

    for channel in (3, 2, 1, 0):  # pit, goal, wall, player
        r, c = free_cell()
        board[channel, r, c] = 1.
    # After the loop, (r, c) is where the player was placed.
    return r, c, board

def step(row0, column0, state0, action0):
    """Apply one action to a copy of the board and score the resulting cell.

    Actions 0/1/2/3 move the player by +1 row / -1 row / +1 column / -1 column.
    A move is ignored when it would cross the 4x4 board edge in its direction
    of travel or enter a wall (channel 1). Any other action leaves the player
    in place. ``state0`` itself is not modified.

    Returns:
        (row1, column1, state1, reward1, done1): -10. and done on a pit
        (channel 3), +10. and done on the goal (channel 2), else -1. and not done.
    """
    board = state0.clone()
    row, col = row0, column0
    board[0, row, col] = 0.  # lift the player off its current cell

    # (row delta, column delta) for action codes 0, 1, 2, 3.
    for code, (d_row, d_col) in enumerate(((1, 0), (-1, 0), (0, 1), (0, -1))):
        if action0 == code:
            new_row, new_col = row + d_row, col + d_col
            # Only the edge in the direction of travel is checked.
            if d_row:
                on_board = 0 <= new_row if d_row < 0 else new_row < 4
            else:
                on_board = 0 <= new_col if d_col < 0 else new_col < 4
            if on_board and board[1, new_row, new_col] == 0.:
                row, col = new_row, new_col

    board[0, row, col] = 1.  # put the player on its (possibly unchanged) cell
    if board[3, row, col] == 1.:
        return row, col, board, -10., True
    if board[2, row, col] == 1.:
        return row, col, board, +10., True
    return row, col, board, -1., False

# Q-network: a (4, 4, 4) board as built by reset() is flattened to 64 features
# and mapped through two 128-unit ReLU layers to one value per action (4).
model = torch.nn.Sequential(
    torch.nn.Flatten(start_dim=-3, end_dim=-1),  # (..., 4, 4, 4) -> (..., 64)
    torch.nn.Linear(in_features=64, out_features=128),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=128, out_features=128),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=128, out_features=4),
)

# Train the Q-network online with one-step Q-learning. Each transition
# immediately moves Q(state, action) toward reward + gamma * max_a Q(next, a),
# and actions are chosen epsilon-greedily. There is no replay buffer and no
# separate target network.
epochs = 1000
epsilon = 1.
gamma = 0.9  # discount factor applied to the bootstrapped next-state value
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(epochs):
    row0, column0, state0 = reset()
    step1 = 0   # moves taken in this episode
    gain1 = 0.  # undiscounted sum of rewards in this episode
    done1 = False
    while not done1:
        quality0 = model(state0)
        # Random action with probability epsilon, otherwise the greedy one.
        action0 = torch.randint(4, ()) if torch.rand(()) < epsilon else quality0.argmax()
        row1, column1, state1, reward1, done1 = step(row0, column0, state0, action0)
        step1 += 1
        gain1 += reward1
        # The target is a constant for this update, so compute it without
        # recording an autograd graph. Terminal transitions get no bootstrap term.
        with torch.no_grad():
            quality1 = model(state1)
            target0 = reward1 + gamma * (not done1) * quality1.max()
        loss = torch.nn.functional.mse_loss(quality0[action0], target0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        row0, column0, state0 = row1, column1, state1
    print("%4d %4d %12.3f" % (epoch, step1, gain1), flush = True)
    # Decrease epsilon by 1/epochs per episode until it is no longer above 0.1.
    epsilon = epsilon - 1. / epochs if 0.1 < epsilon else epsilon

# Evaluate the trained network with a purely greedy policy on fresh random
# boards, and print the fraction of episodes that end on the goal.
epochs = 1000
# A greedy policy can repeat a blocked move (which leaves the state unchanged)
# forever, so episodes are capped at this many moves.
max_steps = 16
wins = 0
with torch.no_grad():  # inference only: no autograd graph is needed
    for epoch in range(epochs):
        row0, column0, state0 = reset()
        step1 = 0
        done1 = False
        while not done1 and step1 < max_steps:
            quality0 = model(state0)
            action0 = quality0.argmax()
            row1, column1, state1, reward1, done1 = step(row0, column0, state0, action0)
            step1 += 1
            row0, column0, state0 = row1, column1, state1
        # An episode counts as a win only if the player ends on the goal.
        # Falling into the pit or hitting the step cap counts as a failure.
        if state0[2, row0, column0] == 1.:
            wins += 1
print(wins / epochs)

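# Success rate observed over several runs, given as wins out of 1000 evaluation episodes (the script prints it as a fraction, e.g. 0.442): 442 498 510 546 611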
