#!/usr/bin/env python3

import numpy as np
import math
import matplotlib.pyplot as plt

# global parameters
nData = 2000                     # number of training pairs
nIter = 2000                     # number training iterations
nPar  =    4                     # number of fit parameters

learning_rate = 0.5e-2/nData     # relative learning rate

# random initial guess for the nPar polynomial coefficients
fitPar = [np.random.randn() for _ in range(nPar)]
print(fitPar)

# fitting function
def fitFunction(x, coeffs=None):
  """Evaluate the fit polynomial  sum_i c_i * x**i  at x.

  x      : scalar or NumPy array; evaluated element-wise
  coeffs : optional sequence of polynomial coefficients; defaults to the
           global fitPar, so existing fitFunction(x) calls are unchanged
  returns: scalar or array of polynomial values

  Fixes: no longer shadows the builtin `sum`, and iterates the
  coefficient sequence itself instead of relying on the global nPar
  (avoids a silent mismatch if len(fitPar) != nPar).
  """
  c = coeffs if coeffs is not None else fitPar
  total = 0.0
  for i, ci in enumerate(c):
    total += ci * (x**i)
  return total

# np.linspace returns a NumPy array (not a list)
# training data: y = sin(x)
x = np.linspace(-math.pi, math.pi, nData)
y = np.sin(x)

# gradient-descent training: minimise the summed squared error
for step in range(nIter):
  prediction = fitFunction(x)                # ndarray; polynomial evaluated element-wise
  loss = ((prediction - y) ** 2).sum()       # sum of squared residuals

  if step % 100 == 99:                       # periodic progress printout
    print(f'{step:5d}  {loss:10.6f}')

  # dLoss/dPrediction, back-propagated to each coefficient by hand
  residual = 2.0 * (prediction - y)
  for i in range(nPar):
    fitPar[i] -= learning_rate * (residual * (x**i)).sum()

# showing result
# visual comparison of the trained polynomial against the target curve
plt.plot(x, np.sin(x), 'b', label="sin(x)")
plt.plot(x, fitFunction(x), 'r', label="fit")
plt.plot(x, np.zeros_like(x), '--k')   # zero baseline for reference
plt.legend()
plt.show()
