#!/usr/bin/env python3

import torch

# global parameters
nData    =  4                      # number of training pairs
nLayer   =  2                      # number of layers
unitsPerLayer = 2
b        =  0.9                    # (+b) / (-b)  : logical True/False

nIter   =    6000                  # training iterations
learning_rate = 1.5e-2

#
# tanh layer module; the weight matrix need not be square
# (dim_out may differ from dim_in)
#
class TanhLayer(torch.nn.Module):            # inherits from torch.nn.Module
  def __init__(self, dim_out, dim_in):       # constructor
    super().__init__()            
    self.weights  = torch.randn(dim_out,dim_in,requires_grad=True)
    self.theta    = torch.randn(dim_out,       requires_grad=True)

  def forward(self, x):            # define forward pass
    return torch.tanh(torch.matmul(self.weights,x)-self.theta)

  def update(self, eps):           # updating internal parameters
    with torch.no_grad():
      self.weights     -= eps*self.weights.grad
      self.theta       -= eps*self.theta.grad
      self.weights.grad = None
      self.theta.grad   = None
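
#
# quick shape check (illustrative sketch, not part of the training run):
# a layer maps a dim_in vector x to the dim_out vector tanh(W x - theta)
#
_probe = TanhLayer(unitsPerLayer, unitsPerLayer)
assert _probe(torch.zeros(unitsPerLayer)).shape == (unitsPerLayer,)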

#
# model of nLayer identical layers
#
allLayers = [TanhLayer(unitsPerLayer, unitsPerLayer) for _ in range(nLayer)]
def model(x):
  for iLayer in range(nLayer):     
    x = allLayers[iLayer](x)
  return x
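
#
# equivalence sketch (illustrative): the same stack via the standard
# torch.nn.Sequential container; the explicit loop above is kept so
# the data flow stays visible
#
_seqModel = torch.nn.Sequential(*allLayers)
_xProbe   = torch.zeros(unitsPerLayer)
assert torch.allclose(model(_xProbe), _seqModel(_xProbe))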

#
# ( nData | unitsPerLayer )  tensor of training data
# element-wise mapping of uniform distribution [0,1] to binary +-b
# automatic casting of boolean  (..>..)  to  0/1
#
allTraining_data   = torch.FloatTensor(nData,unitsPerLayer).uniform_()
allTraining_value  = torch.FloatTensor(nData,unitsPerLayer).uniform_()
for iData in range(nData):
  for unit in range(2):                   # boolean first two units
    allTraining_data[iData][unit] =\
      b*(2.0*(allTraining_data[iData][unit]>0.5)-1.0)
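
#
# vectorized equivalent (illustrative): for b > 0.5 the mapping is
# idempotent on +-b, so re-applying it must leave the data unchanged
#
_before = torch.clone(allTraining_data)
allTraining_data[:,:2] = b*(2.0*(allTraining_data[:,:2]>0.5)-1.0)
assert torch.equal(_before, allTraining_data)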

#
# (cross-identity | XOR) for the first two units:
# output 0 copies input 1, output 1 is the XOR of the two inputs
#

if (unitsPerLayer>1) and (nData>0) and True:   # True/False: pattern toggle
  allTraining_data[0][0]  =  b
  allTraining_data[0][1]  =  b
  allTraining_value[0][0] =  b 
  allTraining_value[0][1] = -b 

if (unitsPerLayer>1) and (nData>1) and True:
  allTraining_data[1][0]  =  b
  allTraining_data[1][1]  = -b
  allTraining_value[1][0] = -b
  allTraining_value[1][1] =  b

if (unitsPerLayer>1) and (nData>2) and True:
  allTraining_data[2][0]  = -b
  allTraining_data[2][1]  =  b
  allTraining_value[2][0] =  b
  allTraining_value[2][1] =  b

if (unitsPerLayer>1) and (nData>3) and True:
  allTraining_data[3][0]  = -b
  allTraining_data[3][1]  = -b
  allTraining_value[3][0] = -b
  allTraining_value[3][1] = -b

if True:                                       # True/False: printing toggle
  print("\n# training data / target values")
  print(allTraining_data)
  print(allTraining_value)

#
# explicit element-wise sum allows for experiments,
# e.g. restricting the loss to selected output components
#
def lossFunction(outputActivity, targetActivity):
  loss = torch.zeros(1)
  for ii in range(outputActivity.size(0)):             # loop over components
# for ii in range(2):                                  # for testing
    loss += ( outputActivity[ii] - targetActivity[ii] ).pow(2)
  return loss
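
#
# sanity check (illustrative): the explicit sum reproduces the closed
# form sum((out-target)^2), the unreduced squared error
#
_out    = torch.tensor([ 1.0, -1.0])
_target = torch.tensor([ 0.5,  0.5])
assert abs(lossFunction(_out,_target).item()
           - (_out-_target).pow(2).sum().item()) < 1e-6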

#
# iterating repeatedly over the identical batch of training data,
# one pattern per iteration (plain SGD with batch size 1)
#
batchLoss = 0.0                        # loss summed over one pass
for iIter in range(nIter):
  iData = iIter%nData                  # go through all training data
  training_data = torch.clone(allTraining_data[iData])
  loss = lossFunction(model(training_data), allTraining_value[iData])
  loss.backward()
#
  batchLoss += loss.item()
  if (iData==0):                       # report/reset once per pass
    if iIter % 200 == 0:
      print(f'{iIter:5d}  {batchLoss:10.6f}')
    batchLoss = 0
#
  for iLayer in range(nLayer):         # parameter updating
    allLayers[iLayer].update(learning_rate)
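
#
# hedged alternative: the manual update() corresponds to one step of
# plain SGD; with the standard optimizer it would read, e.g.
#   params    = [p for l in allLayers for p in (l.weights, l.theta)]
#   optimizer = torch.optim.SGD(params, lr=learning_rate)
#   ...  loss.backward(); optimizer.step(); optimizer.zero_grad()
#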

#
# performance testing
#
print("\n#  x_in   y_in | x_tgt  y_tgt || x_out  y_out")
for iData in range(nData):
  training_data = torch.clone(allTraining_data[iData])
  output        = model(training_data)
#
  xIn  = training_data[0]
  yIn  = training_data[1]
  xVal = allTraining_value[iData][0].item()
  yVal = allTraining_value[iData][1].item()
  xOut =        output[0].item()
  yOut =        output[1].item()
  print(f'{xIn:6.3f} {yIn:6.3f} | ', end="")
  print(f'{xVal:6.3f} {yVal:6.3f} || ', end="")
  print(f'{xOut:6.3f} {yOut:6.3f}')
