from matplotlib import pyplot as plt
from scipy import misc
import torch

def _load_ascent():
    """Return SciPy's 512x512 greyscale 'ascent' sample image as an ndarray.

    ``scipy.misc.ascent`` was deprecated in SciPy 1.10 and removed later in
    favour of ``scipy.datasets.ascent``. The new function needs the optional
    ``pooch`` package and downloads the file on first use. Prefer the legacy
    function while it still exists so older installs keep working offline.
    """
    legacy_ascent = getattr(misc, "ascent", None)
    if legacy_ascent is not None:
        return legacy_ascent()
    from scipy import datasets  # SciPy >= 1.10
    return datasets.ascent()


# Reshape the image to the (N, C, H, W) layout that
# torch.nn.functional.conv2d expects: one batch item, one channel, float32.
data = _load_ascent()
DATA = torch.tensor(data).float()[None, None, :, :]

print(DATA.dtype, DATA.shape, flush = True)

# Two 3x3 Sobel kernels stacked as conv weights of shape
# (out_channels=2, in_channels=1, kH=3, kW=3):
#   channel 0: bottom row minus top row -> intensity change down the rows;
#   channel 1: left column minus right column -> intensity change across the
#              columns. Its sign is flipped relative to the textbook Gx, which
#              is irrelevant once the responses are squared below.
FILTER = torch.tensor([[[[-1., -2., -1.],
                         [ 0.,  0.,  0.],
                         [ 1.,  2.,  1.]]],

                       [[[ 1.,  0., -1.],
                         [ 2.,  0., -2.],
                         [ 1.,  0., -1.]]]])

# conv2d is a cross-correlation (no kernel flip) and pads nothing by default,
# so the result has shape (1, 2, H - 2, W - 2).
ACTIVATION = torch.nn.functional.conv2d(DATA, FILTER)

# Per-pixel gradient magnitude sqrt(g0**2 + g1**2), summed over the channel
# axis but keeping it as a singleton -> (1, 1, H - 2, W - 2).
ACTIVITY = ACTIVATION.square().sum(1, keepdim = True).sqrt()

print(ACTIVITY.shape, flush = True)

# Give each image its own axes. Two bare plt.imshow calls would draw onto the
# same current axes, and the edge map would hide the input image.
fig, (ax_input, ax_edges) = plt.subplots(1, 2, figsize = (10, 5))
ax_input.imshow(DATA[0, 0], cmap = 'gray')
ax_input.set_title('input')
ax_edges.imshow(ACTIVITY[0, 0], cmap = 'gray')
ax_edges.set_title('Sobel gradient magnitude')

plt.show()
