#!/usr/bin/env python3

import torch
import math
import numpy as np
import matplotlib.pyplot as plt

#
# global hyper-parameters
#
dimInput       = 1        # input dimension; only 1 is implemented
dimHidden      = 40       # units per hidden layer
dimOutput      = 1        # output dimension; only 1 is implemented
nHidden        = 2        # number of hidden layers
nData          = 20       # number of training pairs
nBatch         = 20       # samples drawn per epoch (one mini-batch)
nEpoch         = 2000     # number of training epochs
learningRate   = 0.04     # eta
momemtum_mu    = 0.8      # momentum parameter for velocity updating
xMax           = 3.0      # half-width of the data / plotting interval

#
# general layer
#
class MyLayer(torch.nn.Module):
  """Fully connected layer trained by hand-rolled momentum SGD.

  Computes ``weights @ x - bias``. ``forward`` applies tanh on top of
  that; ``forward_linear`` returns it unchanged. The bias is subtracted
  (a threshold), not added.

  Args:
    dim1: number of output units (rows of ``weights``).
    dim2: number of input units (columns of ``weights``).
    mu:   momentum parameter in [0, 1]; 0 gives plain gradient descent.
  """

  def __init__(self, dim1, dim2, mu=0.0):
    super().__init__()
    # nn.Parameter registers the tensors with the Module, so that
    # parameters()/state_dict()/to() see them; they still require grad.
    self.weights = torch.nn.Parameter(torch.zeros(dim1,dim2))
    self.bias    = torch.nn.Parameter(torch.zeros(dim1))

    mySigma = 1.0/math.sqrt(dim2)  # std scales as 1/sqrt(number of inputs)
    torch.nn.init.normal_(self.weights, mean=0.0, std=mySigma)

    self.weights_v = torch.zeros(dim1,dim2)  # associated momentum
    self.bias_v    = torch.zeros(dim1)       # velocities
    self.mu = mu     # momentum update parameter [0,1]

  def forward(self, x):            # tanh unit
    """Return tanh(weights @ x - bias)."""
    return torch.tanh(torch.matmul(self.weights,x)-self.bias)

  def forward_linear(self, x):     # linear unit
    """Return weights @ x - bias (no nonlinearity)."""
    return torch.matmul(self.weights,x) - self.bias

  def update(self, eps):
    """Apply one momentum-SGD step with learning rate ``eps``, then clear grads.

    Velocity update:  v <- mu*v - eps*grad ;  parameter update: p <- p + v.
    A parameter whose ``.grad`` is None (no backward pass reached it) is
    treated as having zero gradient instead of raising TypeError.
    """
    with torch.no_grad():
      wGrad = self.weights.grad
      bGrad = self.bias.grad

      # decay the old velocities, then subtract the scaled gradient if any
      newWeightsV = self.mu*self.weights_v
      newBiasV    = self.mu*self.bias_v
      if wGrad is not None:
        newWeightsV = newWeightsV - eps*wGrad
      if bGrad is not None:
        newBiasV = newBiasV - eps*bGrad
      self.weights_v = newWeightsV
      self.bias_v    = newBiasV

      self.weights += self.weights_v
      self.bias    += self.bias_v

      # gradients accumulate across backward() calls; reset for next batch
      self.weights.grad = None
      self.bias.grad    = None

#
# target: Bell curve and beyond
#
def target_curve(x):
  """Return the regression target sin(x**2) + cos(x), evaluated elementwise.

  Accepts a tensor of any shape and returns a tensor of the same shape.
  A Gaussian bell curve, exp(-x**2/2)/sqrt(2*pi), is an alternative target.
  """
  xSquared = x.pow(2)
  return torch.sin(xSquared) + torch.cos(x)

#
# fixed training data: nData equally spaced inputs on [-xMax, xMax]
#
inputGrid = torch.linspace(-xMax, xMax, nData)
dataInput = torch.zeros(nData, dimInput)   # only column 0 is filled
dataInput[:, 0] = inputGrid
dataValue = target_curve(dataInput[:, 0])

#
# instantiate model: nHidden tanh layers followed by a linear output layer
#
allHidden = [MyLayer(dimHidden, dimInput, momemtum_mu)]      # input -> hidden
allHidden += [MyLayer(dimHidden, dimHidden, momemtum_mu)     # hidden -> hidden
              for _ in range(1, nHidden)]
layerOutput = MyLayer(dimOutput, dimHidden, momemtum_mu)     # hidden -> output

def modelForward(myInput):
  """Run myInput through every hidden tanh layer, then the linear output layer."""
  activity = myInput
  for layer in allHidden[:nHidden]:
    activity = layer(activity)
  return layerOutput.forward_linear(activity)

#
# training loop
#
# Print roughly 20 progress lines; max(1, ...) avoids a modulo-by-zero
# when nEpoch < 20 (int(nEpoch/20) would be 0).
printEvery = max(1, nEpoch//20)
for iEpoch in range(nEpoch):                 # training loop
  # mini-batch: nBatch indices drawn with replacement
  randIntArray = np.random.randint(nData, size=nBatch)
  batchLoss = 0.0                            # summed loss over the mini-batch
  for iData in randIntArray:
    batchInput = dataInput[iData,:]
    batchValue = dataValue[iData]
    output = modelForward(batchInput)        # forward pass
    trainingLoss = (output-batchValue).pow(2).sum()
    trainingLoss.backward()                  # gradients accumulate over the batch
    batchLoss += trainingLoss.item()

  # dividing by nBatch turns the summed gradients into a batch average
  for iH in range(nHidden):
    allHidden[iH].update(learningRate/nBatch)
  layerOutput.update(learningRate/nBatch)

  if iEpoch % printEvery == 0:
    # report the mean over the whole mini-batch, not just the last sample
    print(f'{iEpoch:7d} {batchLoss/nBatch:9.5f}')

#
# testing
#
nPlot = 100
# nPlot points spanning [-xMax, xMax] including both endpoints; a step of
# 2*xMax/nPlot stopped one step short of +xMax (the last data point).
xPlot = [-xMax + iPlot*2.0*xMax/(nPlot-1) for iPlot in range(nPlot)]
yPlot = [0.0 for _ in range(nPlot)]
zPlot = [0.0 for _ in range(nPlot)]

testInput = torch.zeros(dimInput)
with torch.no_grad():                   # inference only: no autograd graph needed
  for iPlot in range(nPlot):
    testInput[0] = xPlot[iPlot]
    testOutput = modelForward(testInput)  # forward pass with test data

    yPlot[iPlot] = target_curve( testInput[0] ).item()
    zPlot[iPlot] = testOutput[0].item()

printPlotData = False                   # set True to dump (x, target, model) rows
if printPlotData:
  for xVal, yVal, zVal in zip(xPlot, yPlot, zPlot):
    print(xVal, yVal, zVal)

# training points as plain Python floats (not lists of 0-d tensors)
xPoints = dataInput[:,0].tolist()
yPoints = dataValue.tolist()

#
# plotting
#
plt.plot(xPlot,   yPlot,   'k',   label="data curve")
plt.plot(xPoints, yPoints, '.r',  label="data points", markersize=8)
plt.plot(xPlot,   zPlot,   '--b', label="inference", linewidth=3.0)
plt.legend()
plt.xlabel('input activity')
plt.ylabel('output activity')
plt.savefig('foo.svg')
plt.show()
