#!/usr/bin/env python3

#
# fitting the function
# f(t) = sin(t*(1.0+1.5*sin(0.3*t)))
# with a linear predictor y[i] ~ a + b*y[i+1] + c*y[i+2],
# using Metropolis Monte-Carlo sampling of the Bayes posterior
#

import torch, math
import matplotlib.pyplot as plt

#
# reproducibility and run sizes
#
torch.manual_seed(12)
nSamples = 100         # length of each of the series y, x1, x2
nMC      = 5000        # number of MC steps

#
# synthetic data generation
#
# Sample f(t) = sin(t*(1.0+1.5*sin(0.3*t))) on the grid t = 0.2*i.  Two
# extra grid points are generated so that the two shifted copies below
# still have nSamples entries each.  The signal is noise-free.
times  = [0.2*i for i in range(nSamples+2)]
signal = torch.tensor([math.sin(t*(1.0+1.5*math.sin(0.3*t))) for t in times])

y  = signal[:-2]       # y[i]  = signal[i]
x1 = signal[1:-1]      # x1[i] = signal[i+1]
x2 = signal[2:]        # x2[i] = signal[i+2]

def log_prior(a, b, c):
    """Log-prior of the regression coefficients.

    a, b and c each get an independent Normal(0, 10) prior; the return
    value is the sum of the three log-densities.  sigma is not an
    argument, so no explicit prior on sigma is included here.
    """
    coeff_prior = torch.distributions.Normal(0, 10)
    log_p = coeff_prior.log_prob(a)
    log_p = log_p + coeff_prior.log_prob(b)
    log_p = log_p + coeff_prior.log_prob(c)
    return log_p

def log_likelihood(a, b, c, sigma, x1, x2, y):
    """Total Gaussian log-likelihood of y under the linear model.

    Every observation in y is modelled as Normal(a + b*x1 + c*x2, sigma);
    the per-observation log-densities are added up into a single value.
    """
    predictor = a + b * x1 + c * x2
    noise_model = torch.distributions.Normal(predictor, sigma)
    per_point = noise_model.log_prob(y)
    return per_point.sum()

def log_posterior(a, b, c, sigma, x1, x2, y):
    """Unnormalised log-posterior: the log-prior of (a, b, c) plus the
    log-likelihood of y given (a, b, c, sigma)."""
    prior_term = log_prior(a, b, c)
    data_term = log_likelihood(a, b, c, sigma, x1, x2, y)
    return prior_term + data_term

def metropolis_hastings(x1, x2, y, num_samples, step_size):
    """Random-walk Metropolis-Hastings sampler for (a, b, c, sigma).

    The chain starts at a = b = c = 0, sigma = 1 and targets the density
    defined by log_posterior.  Each iteration proposes

      * a, b, c : Gaussian random-walk steps of scale step_size,
      * sigma   : a multiplicative step sigma*exp(u) with u uniform in
                  [-step_size/2, step_size/2], which keeps sigma positive
                  for every step_size.

    The sigma move is symmetric in log(sigma) but not in sigma itself, so
    the acceptance test includes the Hastings correction
    log(sigma_new/sigma) = u; without it the chain would not sample the
    density given by log_posterior.

    Progress is printed every 1000 iterations.

    Args:
        x1, x2: regressor tensors, passed through to log_posterior.
        y: target tensor, passed through to log_posterior.
        num_samples: number of MC iterations; one sample is stored per
            iteration (the current state is stored again after a
            rejection).
        step_size: proposal scale shared by all four parameters.

    Returns:
        List of num_samples tuples (a, b, c, sigma) of Python floats.
    """
    samples = []

    # Initialize parameters
    a = torch.tensor(0.0)
    b = torch.tensor(0.0)
    c = torch.tensor(0.0)
    sigma = torch.tensor(1.0)

    current_log_posterior = log_posterior(a, b, c, sigma, x1, x2, y)

    for iIter in range(num_samples):
        if iIter % 1000 == 0:
            print(f'# {iIter:6d}')

        # Propose new values; the exponential keeps sigma > 0
        a_new = a + torch.randn(1) * step_size
        b_new = b + torch.randn(1) * step_size
        c_new = c + torch.randn(1) * step_size
        log_scale = (torch.rand(1) - 0.5) * step_size
        sigma_new = sigma * torch.exp(log_scale)

        # Compute the log-posterior for the proposed values
        proposed_log_posterior = log_posterior(
            a_new, b_new, c_new, sigma_new, x1, x2, y)

        # Accept/reject step, done in the log domain so that a large
        # posterior difference cannot overflow or underflow in exp().
        # log_scale is the Hastings term log(sigma_new/sigma).
        log_accept = (proposed_log_posterior - current_log_posterior
                      + log_scale)
        if torch.log(torch.rand(1)) < log_accept:
            a, b, c, sigma = a_new, b_new, c_new, sigma_new
            current_log_posterior = proposed_log_posterior

        # Store samples
        samples.append((a.item(), b.item(), c.item(), sigma.item()))

    return samples

#
# running Metropolis - MC
#
print("Running MCMC...")
samples = metropolis_hastings(x1, x2, y, num_samples=nMC, step_size=0.05)

#
# means and variances
#
# The chain is started at the fixed point a = b = c = 0, sigma = 1, not at
# a draw from the posterior.  Its first half is therefore discarded as
# burn-in; otherwise the approach to the high-probability region would
# bias the means and inflate the variances reported below.
samples_tensor = torch.tensor(samples)     # one row (a, b, c, sigma) per MC step
nBurn = samples_tensor.shape[0] // 2
kept_samples = samples_tensor[nBurn:]

a_samples = kept_samples[:, 0]
b_samples = kept_samples[:, 1]
c_samples = kept_samples[:, 2]
s_samples = kept_samples[:, 3]

mean_a = a_samples.mean().item()
mean_b = b_samples.mean().item()
mean_c = c_samples.mean().item()
mean_s = s_samples.mean().item()

var_a = a_samples.var().item()
var_b = b_samples.var().item()
var_c = c_samples.var().item()
var_s = s_samples.var().item()

print(f'a     mean/variance: {mean_a:8.4f} {var_a:8.3f}')
print(f'b     mean/variance: {mean_b:8.4f} {var_b:8.3f}')
print(f'c     mean/variance: {mean_c:8.4f} {var_c:8.3f}')
print(f'sigma mean/variance: {mean_s:8.4f} {var_s:8.3f}')

#
# plotting posterior distributions
#
# One normalised 30-bin histogram per regression coefficient, laid out
# side by side in a single row.
plt.figure(figsize=(15, 5))

panels = (
    (a_samples, "Posterior of a", "a", 'blue'),
    (b_samples, "Posterior of b", "b", 'green'),
    (c_samples, "Posterior of c", "c", 'red'),
)
for panel_no, (draws, heading, axis_name, colour) in enumerate(panels, start=1):
    plt.subplot(1, 3, panel_no)
    plt.hist(draws.numpy(), bins=30, density=True, alpha=0.7, color=colour)
    plt.title(heading)
    plt.xlabel(axis_name)
    plt.ylabel("Density")

plt.tight_layout()
plt.show()

#
# visualizing the fit
#
# Evaluate the linear model at the sample means of a, b, c and overlay the
# result on the series y it was fitted to.
predicted_y = mean_a + mean_b*x1 + mean_c*x2

ax = plt.gca()
ax.plot(y, label='True Data')
ax.plot(predicted_y.detach(), label='Model Prediction', linestyle='--')
ax.set_title('Model Fit vs. True Data')
ax.set_xlabel('time')
ax.set_ylabel('y(t)')
ax.legend()
plt.show()
