#!/usr/bin/env python3

import torch
import math
import matplotlib.pyplot as plt

#
# tanh layer
#
class MyLayer(torch.nn.Module):    # inheritance
  """Fully connected layer computing ``tanh(weights @ x - bias)``.

  ``weights`` has shape ``(dim1, dim2)`` and ``bias`` has shape ``(dim1,)``,
  so an input vector of length ``dim2`` is mapped to an output of length
  ``dim1``. Both are drawn from a standard normal distribution and registered
  as ``torch.nn.Parameter`` so they show up in ``parameters()`` /
  ``state_dict()`` and follow ``.to(device)``.

  Training is manual: call ``loss.backward()`` for every sample (gradients
  accumulate) and then ``update(eps, nBatch)`` once per batch.
  """

  def __init__(self, dim1, dim2):  # constructor
    super().__init__()
    # Parameter defaults to requires_grad=True. The draw order (weights, then
    # bias) is kept so seeded runs produce the same initial values.
    self.weights = torch.nn.Parameter(torch.randn(dim1, dim2))
    self.bias    = torch.nn.Parameter(torch.randn(dim1))

  def forward(self, x):            # define forward pass
    """Return ``tanh(weights @ x - bias)``; the bias acts as a threshold."""
    return torch.tanh(torch.matmul(self.weights, x) - self.bias)

  def update(self, eps, nBatch):   # updating weights / bias
    """Apply one plain gradient-descent step, then clear the gradients.

    Args:
      eps: learning rate.
      nBatch: number of samples whose gradients were summed by repeated
        ``backward()`` calls; dividing by it yields the batch-mean gradient.

    A parameter whose ``grad`` is ``None`` (no ``backward()`` since the last
    update) is left unchanged instead of raising ``TypeError``.
    """
    with torch.no_grad():          # the step itself must not be tracked
      for param in (self.weights, self.bias):
        if param.grad is None:
          continue
        param -= eps * param.grad / nBatch   # in-place update of the leaf
        param.grad = None                    # reset accumulation

#
# main
#
dimOutput      = 1            # output width; the printout below assumes 1
dimHidden      = 2
dimInput       = 2            # the XOR samples below have 2 inputs
nEpoch         = 1000
learningRate   = 4.0e-2
nPrintEpochs   = 8            # report predictions for the last epochs
myLayerObject  = MyLayer(dimHidden,dimInput)   # hidden layer (instantiation)
myOutputObject = MyLayer(dimOutput,dimHidden)  # output layer


# XOR for 2 inputs, encoded as +1 (true) / -1 (false)
booleanInput = torch.tensor([ [ 1.0, 1.0],
                              [ 1.0,-1.0],
                              [-1.0, 1.0],
                              [-1.0,-1.0] ])

booleanValue = torch.tensor([ [-1.0],
                              [ 1.0],
                              [ 1.0],
                              [-1.0] ])

nBatch = len(booleanInput)    # all samples form one batch per epoch

print(booleanInput)
print(booleanValue)

#
# training loop: gradients of every sample are accumulated by backward()
# and applied once per epoch by update()
#
for iEpoch in range(nEpoch):                   # training loop
  for iBatch in range(nBatch):                 # iBatch indexes one sample
    thisInput  = booleanInput[iBatch]
    thisTarget = booleanValue[iBatch]

    hidden = myLayerObject(thisInput)          # forward pass (implicit)
    output = myOutputObject(hidden)
    loss   = (output-thisTarget).pow(2).sum()  # squared-error loss

#--- alternative loss function
#--- ** just the sign has to be correct, may work only
#---    for some initial conditions (qualitatively)
#   loss   = torch.relu(0.75-output*thisTarget).pow(2).sum()

    loss.backward()                            # accumulate gradients

    if iEpoch >= nEpoch - nPrintEpochs:
      inputList = thisInput.tolist()
      print(f'{inputList[0]:7.3f}{inputList[1]:7.3f}'
            f'{thisTarget.tolist()[0]:7.3f} ||{output.tolist()[0]:7.3f}')
      if iBatch==(nBatch-1):
        print()                                # blank line between epochs

  myLayerObject.update(learningRate,nBatch)    # gradients have
  myOutputObject.update(learningRate,nBatch)   # been summed up

# end of training
