#!/usr/bin/env python3

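# fitting the function
# f(t) = sin(t*(1.0+1.5*sin(0.3*t)))
# via linear time-series prediction: every sample is predicted as a
# constant plus a linear combination of the samples preceding it
#
# straightforward least-squares fit,
# minimized using PyTorch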

import torch, math
import matplotlib.pyplot as plt

# length of the synthetic time series
nSamples = 100

# number of gradient-descent steps
nIter = 1000

# number of fit coefficients: one constant offset plus nPar-1 lag weights
nPar = 3

# gradient-descent step size
eps = 0.05

# fit coefficients, drawn from a standard normal distribution;
# a leaf tensor tracked by autograd so that loss.backward() fills its .grad
paraMeters = torch.randn(nPar).requires_grad_(True)

def allResiduals(f, parMet):
  """Return the residuals of a linear prediction of f from its own past.

  Sample f[i] is predicted as

      parMet[0] + sum_{pp=1}^{nPar-1} parMet[pp] * f[i-pp]

  for every index i that has nPar-1 predecessors, i.e. for
  i = nPar-1, ..., len(f)-1.

  Args:
    f:      1-D time series (tensor, or anything torch.as_tensor accepts).
    parMet: 1-D tensor of nPar coefficients; entry 0 is the constant
            offset, entry pp >= 1 weights the sample pp steps back.

  Returns:
    1-D tensor of length len(f)+1-nPar; entry k holds f[i] minus its
    prediction for i = k+nPar-1.  The result follows the dtype and device
    of the inputs and is differentiable with respect to parMet.

  Raises:
    ValueError: if parMet is empty, or f has fewer than nPar-1 samples.
  """
  f    = torch.as_tensor(f)
  nF   = len(f)
  nPar = len(parMet)
  if nPar < 1:
    raise ValueError("parMet must contain at least the constant offset")
  if nF < nPar-1:
    raise ValueError(
      f"need at least {nPar-1} samples for {nPar} parameters, got {nF}")

  # All targets are handled at once: for lag pp, the slice
  # f[nPar-1-pp : nF-pp] lines up f[i-pp] with target f[i].
  # Terms are accumulated in the order offset, lag 1, lag 2, ...
  predicted = parMet[0]
  for pp in range(1, nPar):
    predicted = predicted + parMet[pp]*f[nPar-1-pp:nF-pp]
  return f[nPar-1:] - predicted

#
# main: synthetic data; generation and plotting
#
# Sample f(t) = sin(t*(1.0+1.5*sin(0.3*t))) at t = 0.2*i, i = 0 ... nSamples-1.
# The values are computed with math.sin on Python floats and then stored in
# a tensor of PyTorch's default floating-point dtype.
f = torch.tensor([math.sin(x*(1.0+1.5*math.sin(0.3*x)))
                  for x in (0.2*i for i in range(nSamples))])

# set to True to inspect the raw data before fitting
showData = False
if showData:
  # the curve carries a label so that plt.legend() has an entry to show
  plt.plot(f, label='f(t)')
  plt.title('f(t) = sin(t*(1.0+1.5*sin(0.3*t)))')
  plt.xlabel('time')
  plt.ylabel('f(t)')
  plt.legend()
  plt.show()

#
# main: minimizing least-square loss
#
# plain gradient descent on the mean squared residual
for iIter in range(nIter):
  errors = allResiduals(f, paraMeters)
  loss = errors.pow(2).sum() / errors.numel()

  # progress report: iteration number and current loss, every 100 steps
  if iIter % 100 == 0:
    print(f'{iIter:5d} {loss.item():8.4f}')

  loss.backward()

  # the parameter update itself must not be recorded by autograd
  with torch.no_grad():
    paraMeters.sub_(eps * paraMeters.grad)
    # discard the gradient; otherwise the next backward() would add to it
    paraMeters.grad = None

#
# main: output
#
# Report every fitted coefficient, not just the first three, so the output
# stays correct when nPar is changed.  Coefficients are labelled a, b, c, ...
# in index order (index 0 is the constant offset); beyond 26 coefficients
# the label falls back to p<index>.
print("Estimated paraMeters:")
for k, value in enumerate(paraMeters.detach().tolist()):
  label = chr(ord('a') + k) if k < 26 else f'p{k}'
  print(f'# {label} = {value:7.3f}')

#
# visualizing the fit
#
# Recover the model's predictions from the residuals: since
# residual = data - prediction, the prediction is data - residual.
errors = allResiduals(f, paraMeters)
# the first nPar-1 samples have no complete history and are not predicted
small_f = f[nPar-1:]
predicted_f = small_f - errors

# draw data and prediction on one set of axes
fig, ax = plt.subplots()
ax.plot(small_f, label='True Data')
# detach: the prediction still carries autograd history from paraMeters
ax.plot(predicted_f.detach(), linestyle='--', label='Model Prediction')
ax.set_title('Model Fit vs. True Data')
ax.set_xlabel('time')
ax.set_ylabel('f(t)')
ax.legend()
plt.show()
