#!/usr/bin/env python3

import torch                     # PyTorch instead of NumPy
import math
import matplotlib.pyplot as plt


myType   = torch.float
myDevice = torch.device("cpu")   # "cuda:0" would select a GPU; not used here

# global parameters
nData = 2000                     # number of training pairs
nIter = 2000                     # number of training iterations
nPar  =    4                     # number of fit parameters

learning_rate = 0.5e-2/nData     # learning rate; scaled by 1/nData since the loss sums over all data points
fitPar = []                      # empty list; fit parameters
for i in range(nPar):            # randn() : normal distribution
  fitPar.append(torch.randn((), device=myDevice, dtype=myType))
print(fitPar)

def fitFunction(x):              # polynomial of degree nPar-1
  value = 0.0                    # value = p0 + p1*x + ... + p_{nPar-1}*x**(nPar-1)
  for i in range(nPar):
    value += fitPar[i]*(x**i)    # element-wise on the tensor x
  return value
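
# An equivalent vectorized evaluation, shown only for comparison; a sketch,
# not used by the training loop below (fitFunctionVec is an illustrative name):
def fitFunctionVec(x):
  powers = torch.stack([x**i for i in range(nPar)])  # shape (nPar, nData)
  return torch.stack(fitPar) @ powers                # weighted sum of the powers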

# linspace returns a 1-D tensor of nData evenly spaced points
x = torch.linspace(-math.pi, math.pi, nData, device=myDevice, dtype=myType)
y = torch.sin(x)                 # target function y = sin(x)

# training iteration
for iIter in range(nIter):
  y_pred = fitFunction(x)                  # tensor of predictions; element-wise
  loss   = torch.square(y_pred - y).sum()  # sum of squared elements

  if iIter % 100 == 99:                    # print loss every 100 iterations
    print(f'{iIter:5d}  {loss:10.6f}')

  grad_y_pred = 2.0 * (y_pred - y)           # dLoss/dy_pred, the error signal
  for i in range(nPar):                      # gradient-descent update per parameter
    gradient = ( grad_y_pred*(x**i) ).sum()  # dLoss/dfitPar[i]
    fitPar[i] -= learning_rate * gradient

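# The update above uses the hand-coded gradient
#   dLoss/dfitPar[i] = sum_k 2*(y_pred_k - y_k)*x_k**i
# The same fit could use PyTorch autograd instead; a minimal sketch, assuming
# the parameters had been created with requires_grad=True:
#
#   loss = torch.square(fitFunction(x) - y).sum()
#   loss.backward()                      # fills fitPar[i].grad
#   with torch.no_grad():
#     for p in fitPar:
#       p -= learning_rate*p.grad        # gradient-descent step
#       p.grad = None                    # reset for the next iteration
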
# plotting the result
plt.plot(x, torch.sin(x)             , 'b', label="sin(x)")
plt.plot(x, fitFunction(x)           , 'r', label="polynomial fit")
plt.plot(x, 0.0*x                    , '--k')
plt.legend()
plt.show()
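
# Note: with myDevice = torch.device("cuda:0") the tensors would live on the GPU
# and matplotlib needs CPU copies, e.g. plt.plot(x.cpu(), fitFunction(x).cpu(), 'r').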
