#!/usr/bin/env python3

import torch
import math
import numpy as np
import matplotlib.pyplot as plt

#
# global variables
#
dimOutput      = 1            # only 1 implemented
dimHidden      = 40
dimInput       = 1            # only 1 implemented
nHidden        = 2            # at least one
nData          = 20           # number of training pairs
nBatch         = 20           # mini-batch size (samples per update)
nEpoch         = 1000
learningRate   = 1.0e-3       # eta
xMax           = 3.0          # for data / plotting

#
# general layer
#
class MyLayer(torch.nn.Module):   
  def __init__(self, dim1, dim2):
    super().__init__()            
    self.weights = torch.nn.Parameter(torch.zeros(dim1,dim2))
    self.bias    = torch.nn.Parameter(torch.zeros(dim1))   # thresholds; adapted during training

    mySigma = 1.0/math.sqrt(dim2)    # weight scale ~ 1/sqrt(fan-in)
    torch.nn.init.normal_(self.weights, mean=0.0, std=mySigma)

  def forward(self, x):              # tanh unit
    return torch.tanh(torch.matmul(self.weights,x)-self.bias)

  def forward_linear(self, x):       # linear unit
    return torch.matmul(self.weights,x) - self.bias
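
#
# optional sanity check (illustrative, not used in training): a MyLayer(dim1,dim2)
# maps a dim2-vector to a dim1-vector; 'demoLayer' is a hypothetical name
#
if False:
  demoLayer = MyLayer(3, 5)              # 5 inputs -> 3 tanh units
  demoOut   = demoLayer(torch.ones(5))   # computes tanh(W x - b)
  print(demoOut.shape)                   # torch.Size([3])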

#
# target: Bell curve and beyond
#
def target_curve(x):
  # return torch.exp(-0.5*x.pow(2)) / math.sqrt(2.0*math.pi)   # Bell curve (disabled)
  return torch.sin(x.pow(2)) + torch.cos(x)
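
#
# quick worked check (illustrative): at x = 0 the target is sin(0) + cos(0) = 1
#
if False:
  print(target_curve(torch.tensor(0.0)).item())   # 1.0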

#
# fixed training data
#
dataInput = torch.zeros((nData,dimInput))
dataInput[:,0] = torch.linspace(-xMax,xMax,nData)
dataValue = target_curve( dataInput[:,0] )
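
#
# shape check (illustrative): inputs are (nData, dimInput), targets (nData,)
#
if False:
  print(dataInput.shape, dataValue.shape)   # torch.Size([20, 1]) torch.Size([20])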

#
# instantiate model
#
allHidden = [None for iH in range(nHidden)]
allHidden[0] = MyLayer(dimHidden,dimInput)
for iH in range(1,nHidden):
  allHidden[iH] = MyLayer(dimHidden,dimHidden)

layerOut = MyLayer(dimOutput,dimHidden)
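
#
# for comparison (a sketch, not used below; assumes nHidden == 2): the same
# stack via torch.nn -- note that nn.Linear computes W x + b, whereas MyLayer
# uses W x - b, so learned biases differ in sign but the model class is the same
#
if False:
  seqModel = torch.nn.Sequential(
    torch.nn.Linear(dimInput,  dimHidden), torch.nn.Tanh(),
    torch.nn.Linear(dimHidden, dimHidden), torch.nn.Tanh(),
    torch.nn.Linear(dimHidden, dimOutput))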

#
# instantiate optimizer
# SGD: stochastic gradient descent
#
allOptim = [None for iH in range(nHidden)]
for iH in range(nHidden):
  allOptim[iH] = torch.optim.SGD(allHidden[iH].parameters(),
                 lr=learningRate,momentum=0.7)

optimOut = torch.optim.Adam(layerOut.parameters(),lr=learningRate)
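
#
# alternative (a sketch): a single optimizer over all parameters instead of
# one per layer; itertools is from the standard library
#
if False:
  import itertools
  allParameters = itertools.chain(*(layer.parameters() for layer in allHidden),
                                  layerOut.parameters())
  singleOptim = torch.optim.Adam(allParameters, lr=learningRate)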

#
# define forward pass
#
def modelForward(myInput):
  hidden = allHidden[0](myInput)            
  for iH in range(1,nHidden):
    hidden = allHidden[iH](hidden)
  return layerOut.forward_linear(hidden)  
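
#
# note: the output layer is linear (forward_linear, no tanh), since the
# target leaves the interval [-1,1]; quick guarded shape check
#
if False:
  print(modelForward(torch.zeros(dimInput)).shape)   # torch.Size([1])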

#
# training loop
#
for iEpoch in range(nEpoch):                    # training loop
  randIntArray = np.random.randint(nData, size=nBatch)
  for iBatch in range(nBatch):
    batchInput = dataInput[randIntArray[iBatch],:]
    batchValue = dataValue[randIntArray[iBatch]]
    output = modelForward(batchInput)           # forward pass
    trainingLoss = (output-batchValue).pow(2).sum()
    trainingLoss.backward()                     # backward pass; grads accumulate
 
  for iH in range(nHidden):
    allOptim[iH].step()                          # adapting parameters
    allOptim[iH].zero_grad()                     # zero gradients 
  optimOut.step()
  optimOut.zero_grad()

  if (iEpoch%int(nEpoch/20)==0):
    print(f'{iEpoch:7d} {trainingLoss.item():9.5f}')   # loss of last sample
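
#
# equivalent batching (a sketch): repeated backward() calls accumulate
# gradients, so the inner loop above equals summing the per-sample losses
# into one scalar and calling backward() once
#
if False:
  randIdx   = np.random.randint(nData, size=nBatch)
  batchLoss = sum((modelForward(dataInput[ii,:]) - dataValue[ii]).pow(2).sum()
                  for ii in randIdx)
  batchLoss.backward()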

# 
# testing
#
nPlot = 100
xPlot = [-xMax + iPlot*2.0*xMax/(nPlot-1) for iPlot in range(nPlot)]   # covers [-xMax, xMax]
yPlot = [0.0 for _ in range(nPlot)]
zPlot = [0.0 for _ in range(nPlot)]
 
testInput = torch.zeros(dimInput)
for iPlot in range(nPlot):
  testInput[0] = xPlot[iPlot]
  with torch.no_grad():                   # no gradient graph needed for testing
    testOutput = modelForward(testInput)  # forward pass with test data

  yPlot[iPlot] = target_curve( testInput[0] ).item()
  zPlot[iPlot] = testOutput[0].item()

if False:                                 # debug: dump test-curve values
  for iPlot in range(nPlot):
    print(xPlot[iPlot],yPlot[iPlot],zPlot[iPlot])
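
#
# optional quantitative check (a sketch): mean squared error of the fit
# over the plotting grid
#
if False:
  mse = sum((y-z)**2 for y, z in zip(yPlot, zPlot)) / nPlot
  print(f'test MSE {mse:9.5f}')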

xPoints = [ dataInput[ii,0].item() for ii in range(nData)]
yPoints = [ dataValue[ii].item()   for ii in range(nData)]

#
# plotting
#
plt.plot(xPlot,   yPlot,   'k',   label="target curve")
plt.plot(xPoints, yPoints, '.r',  label="data points", markersize=8)
plt.plot(xPlot,   zPlot,   '--b', label="inference", linewidth=3.0)
plt.legend()
plt.xlabel('input activity')
plt.ylabel('output activity')
plt.savefig('foo.svg') 
plt.show()
