#!/usr/bin/env python3

import torch
import math
import matplotlib.pyplot as plt

#
# tanh layer module:  y = tanh(weights @ x - yesTheta * theta)
#
class MyLayer(torch.nn.Module):
  """Single fully-connected tanh layer with optional thresholds.

  Args:
    dim:      number of input and output units (weights are dim x dim).
    yesTheta: 1 to use the threshold vector theta, 0 to disable it.
  """

  def __init__(self, dim, yesTheta):
    super().__init__()
    # nn.Parameter registers the tensors with the module, so they appear in
    # parameters()/state_dict() and follow .to(device); requires_grad is True.
    self.weights  = torch.nn.Parameter(torch.randn(dim, dim))
    self.theta    = torch.nn.Parameter(torch.randn(dim))
    self.yesTheta = yesTheta          # 1/0 with/without thresholds

  def forward(self, x):
    """Return tanh(weights @ x - yesTheta * theta).

    x is either a single vector of shape (dim,) or a matrix whose columns
    are data points, shape (dim, nData).
    """
    threshold = self.yesTheta * self.theta   # scale by yesTheta exactly once
    if x.dim() > 1:
      # data points are columns: (dim,) -> (dim, 1) to broadcast over them
      threshold = threshold.unsqueeze(1)
    return torch.tanh(torch.matmul(self.weights, x) - threshold)

  def update(self, eps):
    """Take one gradient-descent step of size eps, then clear the gradients.

    A parameter without a gradient (e.g. update() called before backward())
    is left unchanged instead of raising.
    """
    with torch.no_grad():
      if self.weights.grad is not None:
        self.weights.sub_(eps * self.weights.grad)
      if self.theta.grad is not None:
        # yesTheta == 0 keeps the thresholds frozen
        self.theta.sub_(eps * self.theta.grad * self.yesTheta)
    self.weights.grad = None
    self.theta.grad   = None

#
# training data: nData samples of dimension dim, stored as the columns of
# myData, with the matching targets in the columns of myValue
#
dim           = 4
nData         = 3
nIter         = 1000
learningRate  = 5.0e-2
myLayerObject = MyLayer(dim, 1.0)            # instantiation, with thresholds

myData  = torch.rand(dim, nData)             # inputs,  uniform in [0,1)
myValue = torch.rand(dim, nData)             # targets, uniform in [0,1)

print("\n# loss")
for iIter in range(nIter):                   # training loop
  output = myLayerObject(myData)             # forward pass (implicit)
  loss   = (output-myValue).pow(2).sum()     # summed squared error
  loss.backward()                            # backward pass
  myLayerObject.update(learningRate)         # weight updating
  print(loss.item())

# the loop's last 'output' predates the final update: re-evaluate the
# trained layer (no graph needed for a pure evaluation)
with torch.no_grad():
  output = myLayerObject(myData)
print("\n# output")
print(output)
print("\n# myValue")
print(myValue)
