#!/usr/bin/env python3

import torch
import math
import matplotlib.pyplot as plt

#
# wraps the squared norm sum((x + p)^2) inside a torch.nn.Module
#
class MyLayer(torch.nn.Module):
    """Module computing the squared Euclidean norm of ``x + p``.

    The forward pass returns the scalar ``sum((x + p) ** 2)``, which for a 1-D
    ``x`` equals the dot product ``(x + p) . (x + p)``.

    Args:
        p: Offset added element-wise to the input before squaring.
    """

    def __init__(self, p):
        super().__init__()  # let nn.Module initialise its internal bookkeeping
        self.p = p

    def forward(self, x):
        """Return ``sum((x + p) ** 2)`` as a 0-dim tensor.

        Unlike ``torch.dot`` (which accepts only 1-D tensors), this reduces
        over every element, so ``x`` may have any shape.
        """
        shifted = x + self.p  # evaluate the offset once, reuse it below
        return shifted.pow(2).sum()

#
# main start
#
def main():
    """Run MyLayer(2.0) on a vector of ones and print input, output and gradient."""
    layer = MyLayer(2.0)                   # instantiation
    x = torch.ones(3, requires_grad=True)  # track gradients w.r.t. x
    output = layer(x)                      # forward pass via nn.Module.__call__
    output.backward()                      # fills x.grad with d(output)/dx
    print("\n# input")
    print(x)
    print("\n# output")
    print(output)
    print("\n# input.grad")
    print(x.grad)


if __name__ == "__main__":
    main()
