from matplotlib import pyplot as plt
from scipy import misc
import torch

# Load SciPy's "ascent" sample image. ``scipy.misc.ascent`` was deprecated in
# SciPy 1.10 and removed in 1.12 in favour of ``scipy.datasets.ascent`` (which
# needs the optional ``pooch`` package). Only fall back to the new accessor when
# the old one is gone, so older SciPy keeps loading its bundled copy offline.
try:
    data = misc.ascent()
except AttributeError:
    from scipy import datasets
    data = datasets.ascent()

# conv2d expects (N, C, H, W): add batch and channel axes to the 2-D image and
# cast to float32 so it matches the float kernels it is convolved with.
DATA = torch.tensor(data).float()[None, None, :, :]

print(DATA.dtype, DATA.shape, flush=True)

# Give the input its own figure so later plots do not draw over it.
plt.figure()
plt.imshow(DATA[0, 0], cmap='gray')
plt.title('input')

# 3x3 Sobel kernels, each shaped (out_channels=1, in_channels=1, 3, 3) for
# conv2d. Both are built as separable outer products of a central-difference
# vector [-1, 0, 1] (or its reverse) and a smoothing vector [1, 2, 1]:
#   FILTER0 = bottom row minus top row    -> responds to vertical change
#   FILTER1 = left column minus right col -> responds to horizontal change
FILTER0 = torch.outer(torch.tensor([-1., 0., 1.]),
                      torch.tensor([1., 2., 1.]))[None, None]
FILTER1 = torch.outer(torch.tensor([1., 2., 1.]),
                      torch.tensor([1., 0., -1.]))[None, None]

# Apply both 3x3 kernels. torch's conv2d is a cross-correlation (the kernel is
# not flipped); with the default padding=0 and stride=1 each output loses a
# 1-pixel border, so the result is (1, 1, H-2, W-2).
ACTIVATION0 = torch.nn.functional.conv2d(DATA, FILTER0)
ACTIVATION1 = torch.nn.functional.conv2d(DATA, FILTER1)

# Per-pixel gradient magnitude: Euclidean norm of the two filter responses.
ACTIVITY = torch.sqrt(ACTIVATION0**2 + ACTIVATION1**2)

print(ACTIVITY.shape, flush=True)

# Open a fresh figure. Without it, imshow would draw into the current axes on
# top of the previously shown input image, which would then never be visible.
plt.figure()
plt.imshow(ACTIVITY[0, 0], cmap='gray')
plt.title('gradient magnitude')

plt.show()
