#!/usr/bin/env python3

import torch                   
import math
import matplotlib.pyplot as plt


myType   = torch.float
myDevice = (
    "cuda"                       # for GPUs
    if torch.cuda.is_available()
    else "mps"                   # 'MultiProcessor Specification'
    if torch.backends.mps.is_available()
    else "cpu"                   # plain old CPU
)
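
# optional sanity check: report which backend was selected
print(f"using device: {myDevice}")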

# global parameters
nData = 2000                     # number of training pairs
nIter = 2000                     # number of training iterations
nPar  =    4                     # number of fit parameters
learning_rate = 0.5e-2/nData    

# gradients with respect to fitPar[] to be evaluated
fitPar = []                      # list of 0-dim (scalar) tensors
for i in range(nPar):
  fitPar.append(torch.randn((), device=myDevice, dtype=myType,
                                requires_grad=True))
print(fitPar)
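
# note (a possible alternative, not used here): the scalars could live in a
# single 1-D parameter tensor, e.g.
#   fitPar = torch.randn(nPar, device=myDevice, dtype=myType, requires_grad=True)
# the list form is kept below so the per-parameter updates stay explicit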

x = torch.linspace(-math.pi, math.pi, nData, device=myDevice, dtype=myType)
y = torch.sin(x)                 # element-wise

def fitFunction(x):              # polynomial fitting function
  result = 0.0                   # avoid shadowing the builtin sum()
  for i in range(nPar):
    result += fitPar[i]*(x**i)   # element-wise, x is a tensor
  return result                  # returns a tensor
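
# an equivalent vectorized sketch (assumes the list form of fitPar, untested here):
#   powers = torch.stack([x**i for i in range(nPar)])    # shape (nPar, nData)
#   y_pred = torch.stack(fitPar) @ powers                # weighted sum of powers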

# training iteration
for iIter in range(nIter):
  y_pred     = fitFunction(x)              # forward pass
  lossTensor = (y_pred - y).pow(2).sum()   # element-wise pow(2)
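  # equivalent (a sketch, not used here): the same scalar via the functional API,
  #   torch.nn.functional.mse_loss(y_pred, y, reduction='sum')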

  if iIter % 100 == 99:                    # print scalar loss value
    print(f'{iIter:5d}  {lossTensor.item():10.6f}')

# backward pass:
# populates 'tensor.grad' with the gradient of the loss
# for every tensor created with requires_grad=True
  lossTensor.backward()

# torch.no_grad() temporarily disables gradient tracking, so the
# by-hand parameter update below is not recorded in the graph;
# the values in   fitPar[i].grad   are unaffected
  with torch.no_grad():
    for i in range(nPar):                  # gradient-descent step
      fitPar[i] -= learning_rate * fitPar[i].grad
      fitPar[i].grad = None                # reset for the next backward pass

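# the by-hand update above keeps the mechanics visible; a sketch of the same
# loop with torch.optim instead (assumes fitPar as defined above, untested here):
#   optimizer = torch.optim.SGD(fitPar, lr=learning_rate)
#   for iIter in range(nIter):
#     lossTensor = (fitFunction(x) - y).pow(2).sum()
#     optimizer.zero_grad()
#     lossTensor.backward()
#     optimizer.step()
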
# "detach" tensors requiring gradients from computation graph
plt.plot(x, torch.sin(x)                    , 'b', label="sin(x)")
plt.plot(x, fitFunction(x).detach().numpy() , 'r', label="fit")
plt.plot(x, 0.0*x                           , '--k')
plt.legend()
plt.show()
