import torch

# Demonstrates three ways of persisting PyTorch objects with torch.save /
# torch.load: a bare tensor, a whole pickled nn.Module, and a state_dict.

# Fall back to the CPU so the demo still runs on machines without a GPU;
# on a CUDA machine this is identical to the original hard-coded "cuda".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def build_model():
    """Return a fresh 2 -> 4 -> 2 MLP with a Tanh hidden activation.

    Every section below needs the same architecture, so it is defined once
    here; load_state_dict (strict by default) requires the parameter names
    and shapes of the saved and receiving models to match exactly.
    """
    return torch.nn.Sequential(
        torch.nn.Linear(2, 4),
        torch.nn.Tanh(),
        torch.nn.Linear(4, 2),
    )


# --- 1. A single tensor -------------------------------------------------
tensor = torch.tensor([1, 2, 3], device=device)
print(tensor)
torch.save(tensor, "tensor.pt")

# A plain tensor can be loaded with weights_only=True (no arbitrary
# unpickling); map_location keeps the load working if the file was written
# on a machine with a different device set-up.
tensor = torch.load("tensor.pt", map_location=device, weights_only=True)
print(tensor)

# --- 2. A whole pickled module -------------------------------------------
model = build_model().to(device)
print(model)
print(next(model.parameters()).device)
torch.save(model, "model.pt")

# torch.save(model) pickles the module object itself, so loading it needs
# full unpickling. Since PyTorch 2.6 torch.load defaults to
# weights_only=True and would reject this file. Unpickling can execute
# arbitrary code -- only do this for files you created or otherwise trust.
model = torch.load("model.pt", map_location=device, weights_only=False)
print(model)
print(next(model.parameters()).device)

# --- 3. A state_dict (the recommended approach) --------------------------
model = build_model()
dictionary = model.state_dict()
print(dictionary)
torch.save(dictionary, "dictionary.pt")

# A state_dict maps parameter names to tensors, so the safe loader works.
dictionary = torch.load("dictionary.pt", weights_only=True)
print(dictionary)
# load_state_dict copies the loaded (CPU) tensors into the model's existing
# parameters, so those parameters stay on `device`.
model = build_model().to(device)
model.load_state_dict(dictionary)
dictionary = model.state_dict()
print(dictionary)
