#!/usr/bin/env python3

import torch
import math
import matplotlib.pyplot as plt

#
# relu = max(0,x)  layer (rectified linear)
#
class MyLayer(torch.nn.Module):   # inheritance
  """A single square linear layer with a selectable non-linearity.

  Holds a (dim, dim) weight matrix W and maps an input vector x to
  relu(W @ x) (default forward) or tanh(W @ x) (forward_tanh).
  """

  def __init__(self, dim):        # constructor
    """dim -- input/output dimension of the square weight matrix."""
    super().__init__()
    # nn.Parameter (requires_grad=True by default) registers the weights
    # with the module, so .parameters(), state_dict() and .to(device)
    # see them -- a raw tensor with requires_grad=True would not be.
    self.weights = torch.nn.Parameter(torch.randn(dim, dim))

  def forward(self, x):           # default forward pass
    """Return relu(W @ x)."""
    return torch.relu(torch.matmul(self.weights, x))

  def forward_tanh(self, x):      # alternative forward pass
    """Return tanh(W @ x)."""
    return torch.tanh(torch.matmul(self.weights, x))

  def update(self, eps):          # updating weights
    """Plain gradient-descent step with learning rate eps.

    Subtracts eps * grad in-place (outside autograd) and clears the
    gradient so the next backward() starts from zero.
    """
    with torch.no_grad():
      self.weights -= eps * self.weights.grad
      self.weights.grad = None

#
# a single training pair  (myData,myValue)
#
dim           = 4                            # vector / matrix dimension
nIter         = 1000                         # number of gradient steps
learningRate  = 1.0e-2
myLayerObject = MyLayer(dim)                 # instantiation

# a single random (input, target) pair; uniform_() samples from [0, 1)
myData  = torch.FloatTensor(dim).uniform_()
myValue = torch.relu(torch.FloatTensor(dim).uniform_())

print("\n# output")
for iIter in range(nIter):                    # training loop
#
  output = myLayerObject(myData)              # forward pass (implicit)
# output = myLayerObject.forward(myData)      # forward pass (explicit)
# output = myLayerObject.forward_tanh(myData) # forward pass (specific)
#
  loss   = (output-myValue).pow(2).sum()      # squared-error loss
  loss.backward()                             # backward pass
  myLayerObject.update(learningRate)          # weight updating
  # .detach() is the supported way to drop the autograd graph for
  # printing; the .data attribute is deprecated and unsafe.
  print(output.detach())
print("\n# myValue")
print(myValue)
