#!/usr/bin/env python3

import torch
import math
import numpy as np
import matplotlib.pyplot as plt

#
# global variables
#
dimOutput      = 1            # only 1 implemented
dimHidden      = 40
dimInput       = 1            # only 1 implemented
nData          = 20           # number training pairs
nBatch         = 20
nEpoch         = 1000
learningRate   = 4.0e-2
xMax           = 3.0          # for data / plotting

#
# general layer
#
class MyLayer(torch.nn.Module):    # inheritance
  def __init__(self, dim1, dim2):  # constructor
    super().__init__()
    # plain tensors, not nn.Parameter: gradients are applied by hand in update()
    self.weights = torch.zeros(dim1,dim2,requires_grad=True)
    self.bias    = torch.zeros(dim1,     requires_grad=True)

    mySigma = 1.0/math.sqrt(dim2)  # 1/sqrt(fan-in) scaling of weights
    torch.nn.init.normal_(self.weights, mean=0.0, std=mySigma)
  def forward(self, x):            # tanh unit
    return torch.tanh(torch.matmul(self.weights,x)-self.bias)

  def forward_linear(self, x):     # linear unit
    return torch.matmul(self.weights,x) - self.bias

  def update(self, eps):           # manual SGD step
    with torch.no_grad():          # the update itself is not tracked
      self.weights -= eps*self.weights.grad
      self.bias    -= eps*self.bias.grad
      self.weights.grad = None     # reset accumulated gradients
      self.bias.grad    = None
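
#
# aside: the same layer with registered parameters (a minimal sketch, not
# used below); nn.Parameter would let torch.optim handle the update step,
# whereas MyLayer keeps plain tensors so the SGD step stays explicit
#
class MyLayerParam(torch.nn.Module):
  def __init__(self, dim1, dim2):
    super().__init__()
    self.weights = torch.nn.Parameter(torch.randn(dim1,dim2)/math.sqrt(dim2))
    self.bias    = torch.nn.Parameter(torch.zeros(dim1))

  def forward(self, x):
    return torch.tanh(torch.matmul(self.weights,x) - self.bias)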

#
# target: Bell curve and beyond
#
def target_curve(x):
  return torch.exp(-0.5*x.pow(2)) / math.sqrt(2.0*math.pi)
# return torch.sin(x.pow(2)) + torch.cos(x)   # alternative target ("and beyond")
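
# optional sanity check for the Bell-curve target (drop it when switching
# to the alternative above): the peak at x=0 is 1/sqrt(2*pi) ~ 0.39894
assert abs(target_curve(torch.tensor(0.0)).item() - 1.0/math.sqrt(2.0*math.pi)) < 1e-6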

#
# fixed training data
#
dataInput = torch.zeros((nData,dimInput))
dataInput[:,0] = torch.linspace(-xMax,xMax,nData)
dataValue = target_curve( dataInput[:,0] )
# print("\n# dataInput", dataInput.shape, "\n", dataInput)
# print("\n# dataValue", dataValue.shape, "\n", dataValue)

#
# instantiate model, define forward pass
#
layerHidden = MyLayer(dimHidden,dimInput)
layerOutput = MyLayer(dimOutput,dimHidden)

def modelForward(myInput):
  hidden = layerHidden(myInput)              # nn.Module.__call__ invokes forward
  return layerOutput.forward_linear(hidden)  # linear output units
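
# aside: a batched variant (a sketch, not used by the loop below), which
# stacks inputs as (nBatch, dimInput) and transposes the weight matrices
# instead of feeding single samples
def modelForwardBatch(myInputs):
  hidden = torch.tanh(torch.matmul(myInputs, layerHidden.weights.T) - layerHidden.bias)
  return torch.matmul(hidden, layerOutput.weights.T) - layerOutput.bias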

#
# training loop
#
for iEpoch in range(nEpoch):                          # training loop
  randIntArray = np.random.randint(nData,size=nBatch) # random sampling
# print("\n# randIntArray\n", randIntArray)
#
  for iBatch in range(nBatch):
    batchInput = dataInput[randIntArray[iBatch],:]
    batchValue = dataValue[randIntArray[iBatch]]
    output = modelForward(batchInput)                # forward pass
    trainingLoss = (output-batchValue).pow(2).sum()  # squared error
    trainingLoss.backward()        # backward pass; gradients accumulate

  layerHidden.update(learningRate/nBatch)     # SGD step; eps divided by
  layerOutput.update(learningRate/nBatch)     # nBatch averages the gradients
#   print("# ", iIter, trainingLoss.tolist())
  tenPercent = int(nEpoch/10)
  if (iEpoch%tenPercent==0):
    print(f'{iEpoch:7d} {trainingLoss:9.5f}')
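
#
# aside: with registered parameters (see MyLayerParam above) and the two
# layers wrapped in a single nn.Module called, say, model, the manual
# update()/gradient clearing could be replaced by the standard idiom:
#
#   optimizer = torch.optim.SGD(model.parameters(), lr=learningRate/nBatch)
#   optimizer.zero_grad()
#   trainingLoss.backward()
#   optimizer.step()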

#
# testing
#
nPlot = 100
xPlot = [-xMax + iPlot*2.0*xMax/(nPlot-1) for iPlot in range(nPlot)]  # endpoints included
yPlot = [0.0 for _ in range(nPlot)]
zPlot = [0.0 for _ in range(nPlot)]

testInput = torch.zeros(dimInput)
with torch.no_grad():                     # no gradients needed at test time
  for iPlot in range(nPlot):
    testInput[0] = xPlot[iPlot]
    testOutput = modelForward(testInput)  # forward pass with test data

    yPlot[iPlot] = target_curve( testInput[0] ).item()
    zPlot[iPlot] = testOutput[0].item()
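
# simple quality measure: mean squared error between target and inference
# over the dense grid (a rough check; the grid overlaps the training range)
testMSE = sum((y - z)**2 for y, z in zip(yPlot, zPlot)) / nPlot
print(f'# grid MSE {testMSE:9.6f}')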

if False:                 # flip to True to print the raw curves
  for iPlot in range(nPlot):
    print(xPlot[iPlot],yPlot[iPlot],zPlot[iPlot])

xPoints = dataInput[:,0].tolist()   # training points as plain floats
yPoints = dataValue.tolist()

#
# plotting
#
plt.plot(xPlot,   yPlot,   'k',   label="target curve")
plt.plot(xPoints, yPoints, '.r',  label="data points", markersize=8)
plt.plot(xPlot,   zPlot,   '--b', label="inference", linewidth=3.0)
plt.legend()
plt.xlabel('input activity')
plt.ylabel('output activity')
plt.savefig('foo.svg')
plt.show()
