#!/usr/bin/env python3

#
# recurrent net performing a prediction task:
# given f(x), predict f(x + Delta_T*deltaX), i.e. the curve shifted ahead
#

import torch
import math
import random
import matplotlib.pyplot as plt

#
# global variables
#
dimOutput      = 1            # only 1 implemented
dimHidden      = 40
nData          = 20           # number function values
nPlot          = 20           # must equal nData (grid spacing, and hence the
                              # prediction shift, depends on the point count)
nIter          = 1000
learningRate   = 4.0e-2    
xMax           = 3.0          # for data / plotting
Delta_T        = 3            # number of time steps to predict
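
#
# optional (a sketch, not part of the original run): seed the RNGs for
# reproducible training data and weight initialization
#
# random.seed(42)
# torch.manual_seed(42)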

#
# general layer
#
class MyLayer(torch.nn.Module):    # inheritance
  def __init__(self, dim1, dim2):  # constructor
    super().__init__()            
    self.w    = torch.zeros(dim1,dim2,requires_grad=True)  # feed forward
    self.v    = torch.zeros(dim1,dim1,requires_grad=True)  # recurrent
    self.bias = torch.zeros(dim1,requires_grad=True)

    self.hidden_activity = torch.zeros(dim1)   # hidden activity of previous step

    sigma_w = 1.0/math.sqrt(dim2)       
    sigma_v = 1.0/math.sqrt(dim1)  
    torch.nn.init.normal_(self.w, mean=0.0, std=sigma_w)
    torch.nn.init.normal_(self.v, mean=0.0, std=sigma_v)

  def forward(self, x):            # default forward pass
    yy = torch.tanh(torch.matmul(self.w,x) +
                    torch.matmul(self.v,self.hidden_activity) -
                    self.bias)
    self.hidden_activity = yy.detach()  # store state; detach() cuts the graph,
    return yy                           # i.e. one-step truncated BPTT
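
  # optional helper (an addition, not used below): clear the stored hidden
  # state, e.g. at the start of a new sequence
  def reset_hidden(self):
    self.hidden_activity = torch.zeros_like(self.hidden_activity)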

  def forward_linear(self, x):     # linear unit
    return torch.matmul(self.w,x) - self.bias

  def update_hidden(self, eps):    # gradient step for w, v, bias
    with torch.no_grad():
      self.w    -= eps*self.w.grad 
      self.v    -= eps*self.v.grad 
      self.bias -= eps*self.bias.grad   
      self.w.grad    = None
      self.v.grad    = None
      self.bias.grad = None

  def update_linear(self, eps):    # no recurrent connections
    with torch.no_grad():
      self.w    -= eps*self.w.grad 
      self.bias -= eps*self.bias.grad   
      self.w.grad    = None
      self.bias.grad = None
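
#
# aside (a sketch, not used below): for a given MyLayer instance `layer`, the
# manual updates above amount to plain SGD; with torch.optim this would read
#
#   opt = torch.optim.SGD([layer.w, layer.v, layer.bias], lr=eps)
#   opt.step()                               # after trainingLoss.backward()
#   opt.zero_grad(set_to_none=True)
#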

#
# target: Bell curve and beyond
#
def target_curve(x):
  return torch.exp(-0.5*x.pow(2)) / math.sqrt(2.0*math.pi)
# return torch.sin(x.pow(2)) + torch.cos(x)

#
# new training data, using random starting point
#
def trainingData(nPoints):
  startX = -xMax + xMax*0.1*random.random()    # random offset of the grid
  endX   = startX + 2.0*xMax
  deltaX = 2.0*xMax/(nPoints-1.0)              # grid spacing
  startY = startX + Delta_T*deltaX             # target grid: shifted by Delta_T steps
  endY   =   endX + Delta_T*deltaX
#
  inputPoints    = torch.linspace(startX, endX, nPoints)
  inputFunction  = target_curve( inputPoints )
  outputPoints   = torch.linspace(startY, endY, nPoints)
  outputFunction = target_curve( outputPoints )
  return inputPoints, inputFunction, outputPoints, outputFunction
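
#
# quick consistency check (a sketch, can be dropped): the output grid is the
# input grid shifted by exactly Delta_T grid spacings
#
#   xIn, fIn, xOut, fOut = trainingData(nData)
#   shift = Delta_T*2.0*xMax/(nData-1.0)
#   assert torch.allclose(xOut - xIn, torch.full_like(xIn, shift))
#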

#
# instantiate model, define forward pass
#
layerHidden = MyLayer(dimHidden,1)          # recurrent hidden layer
layerOutput = MyLayer(dimOutput,dimHidden)  # linear readout (its v is unused)
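
# parameter count, for orientation (an added note):
#   hidden layer  dimHidden*1 + dimHidden*dimHidden + dimHidden = 1680
#   output layer  dimOutput*dimHidden + dimOutput               =   41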

def modelForward(myInput):
  hidden = layerHidden(myInput)              # calls the default forward pass
  return layerOutput.forward_linear(hidden)  # linear output units
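
#
# for comparison (an assumption, not the setup used here): the hidden layer is
# close to a built-in torch.nn.RNNCell with tanh nonlinearity
#
#   cell = torch.nn.RNNCell(1, dimHidden, nonlinearity='tanh')
#   h    = torch.zeros(1, dimHidden)
#   h    = cell(myInput.unsqueeze(0), h)     # one recurrent update
#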

#
# training loop
#
for iIter in range(nIter):                      # training loop
 
  inPoints, inFunction, outPoints, outFunction = trainingData(nData)
#
  if iIter==-1:                                 # set to 0 (say) to print a sample
    for iData in range(nData):
      print(inPoints[iData].item(), inFunction[iData].item())

  trainingLoss = 0.0                            # loss is accumulated over points
  for iData in range(nData):                    # sequence loop; the hidden
                                                # state carries over between points
# function approximation
#   trainInput = inPoints[iData].unsqueeze(0)   # add dimension
#   trainValue = inFunction[iData]

# function prediction
    trainInput = inFunction[iData].unsqueeze(0) # add dimension
    trainValue = outFunction[iData]

    output = modelForward(trainInput)           # forward pass
    trainingLoss += (output-trainValue).pow(2).sum()
#
  trainingLoss.backward()                       # backward pass
  layerHidden.update_hidden(learningRate/nData)   # step scaled by batch size
  layerOutput.update_linear(learningRate/nData)
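#
# (an aside, a sketch not in the original: recurrent nets can suffer from
#  exploding gradients; one could clip between backward() and the updates,
#    torch.nn.utils.clip_grad_norm_(
#        [layerHidden.w, layerHidden.v, layerHidden.bias], max_norm=1.0) )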
#
  tenPercent = max(1, nIter//10)                # print ten progress lines
  if (iIter%tenPercent==0):
    print(f'{iIter:7d}', trainingLoss.item())

# 
# preparing plots
#
inPoints, inFunction, outPoints, outFunction = trainingData(nPlot)
inPointsPlot    = inPoints.tolist()
outPointsPlot   = outPoints.tolist()
inferencePlot   = [0.0 for _ in range(nPlot)]
inFunctionPlot  = inFunction.tolist()
outFunctionPlot = outFunction.tolist()
for iPlot in range(nPlot):
# testInput = inPoints[iPlot].unsqueeze(0)     # function approximation
  testInput = inFunction[iPlot].unsqueeze(0)   # function prediction
  inferencePlot[iPlot] = modelForward(testInput).item()
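
#
# aside (a sketch): inference does not need gradient tracking; wrapping the
# loop body in torch.no_grad() avoids building a graph, e.g.
#
#   with torch.no_grad():
#     inferencePlot[iPlot] = modelForward(testInput).item()
#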
 
#
# plotting
#
plt.plot(inPointsPlot, inFunctionPlot,  'k', label="original curve")
plt.plot(inPointsPlot, outFunctionPlot, 'g', label="shifted curve")
plt.plot(inPointsPlot, inferencePlot,  '.r', label="inference", markersize=10)
plt.legend()
plt.xlabel('position x')
plt.ylabel('function value')
plt.savefig('foo.svg') 
plt.show()
