#!/usr/bin/env python3

#
# Bayesian inference via Metropolis-Hastings MCMC for a 2D linear model
# y = true_a + true_b*x1 + true_c*x2 + noise
#

import torch
import matplotlib.pyplot as plt

# Generate synthetic data
torch.manual_seed(42)

# Parameters
true_a = 1.0
true_b = 2.0
true_c = -1.5
n_samples = 1000

# Generate predictors and response
x1 = torch.rand(n_samples) * 10  # Random values between 0 and 10
x2 = torch.rand(n_samples) * 5   # Random values between 0 and 5
noise = torch.randn(n_samples) * 0.5
y = true_a + true_b * x1 + true_c * x2 + noise
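
# Quick sanity check (illustrative, separate from the Bayesian model):
# an ordinary least-squares fit on the same design matrix should land
# close to true_a, true_b, true_c, giving a reference point for the
# MCMC posterior means computed below.
X = torch.stack([torch.ones_like(x1), x1, x2], dim=1)
ols = torch.linalg.lstsq(X, y.unsqueeze(1)).solution.squeeze()
print(f'OLS check: a={ols[0].item():.3f} '
      f'b={ols[1].item():.3f} c={ols[2].item():.3f}')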

def log_prior(a, b, c, sigma):
    """Log-prior distribution:
       independent Normal(0, 10) priors for a, b, c and
       a HalfNormal(10) prior for sigma."""
    prior_a = torch.distributions.Normal(0, 10).log_prob(a)
    prior_b = torch.distributions.Normal(0, 10).log_prob(b)
    prior_c = torch.distributions.Normal(0, 10).log_prob(c)
    prior_sigma = torch.distributions.HalfNormal(10).log_prob(sigma)
    return prior_a + prior_b + prior_c + prior_sigma

def log_likelihood(a, b, c, sigma, x1, x2, y):
    """Log-likelihood of the data."""
    mean = a + b * x1 + c * x2
    likelihood = torch.distributions.Normal(mean, sigma).log_prob(y)
    return likelihood.sum()
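
# Quick intuition check (illustrative, reusing the known generating
# values): the log-likelihood at the true parameters should far exceed
# its value at the sampler's all-zero starting point.
ll_true = log_likelihood(torch.tensor(true_a), torch.tensor(true_b),
                         torch.tensor(true_c), torch.tensor(0.5),
                         x1, x2, y)
ll_init = log_likelihood(torch.tensor(0.0), torch.tensor(0.0),
                         torch.tensor(0.0), torch.tensor(1.0),
                         x1, x2, y)
print(f'log-likelihood at truth: {ll_true.item():.1f}, '
      f'at the zero start: {ll_init.item():.1f}')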

def log_posterior(a, b, c, sigma, x1, x2, y):
    """Log-posterior = log-prior + log-likelihood (up to a constant)."""
    return log_prior(a, b, c, sigma) + \
           log_likelihood(a, b, c, sigma, x1, x2, y)

def metropolis_hastings(x1, x2, y, num_samples=5000, step_size=0.1):
    """Metropolis-Hastings sampler."""
    samples = []
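    # Random-walk Metropolis: propose a symmetric perturbation of the
    # current state and accept with probability
    # min(1, posterior_new / posterior_current), which needs only the
    # unnormalized log-posterior. Track accepted moves as a tuning
    # diagnostic; random-walk samplers typically aim for roughly
    # 20-50% acceptance.
    accepted = 0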
    
    # Initialize parameters
    a = torch.tensor(0.0)
    b = torch.tensor(0.0)
    c = torch.tensor(0.0)
    sigma = torch.tensor(1.0)
    
    current_log_posterior = log_posterior(a, b, c, sigma, x1, x2, y)
    
    for i in range(num_samples):
        if i % 1000 == 0:
            print(f'# {i:6d}')
        # Propose new values with a symmetric random walk; reflecting
        # sigma at zero keeps it positive while leaving the proposal
        # symmetric, so no Hastings correction is needed
        a_new = a + torch.randn(1) * step_size
        b_new = b + torch.randn(1) * step_size
        c_new = c + torch.randn(1) * step_size
        sigma_new = torch.abs(sigma + torch.randn(1) * step_size)
        
        # Compute the log-posterior for the proposed values
        proposed_log_posterior =\
           log_posterior(a_new, b_new, c_new, sigma_new, x1, x2, y)
        
        # Accept/reject step, compared in log space to avoid overflow
        # when exponentiating large log-posterior differences
        log_alpha = proposed_log_posterior - current_log_posterior
        if torch.log(torch.rand(1)) < log_alpha:
            a, b, c, sigma = a_new, b_new, c_new, sigma_new
            current_log_posterior = proposed_log_posterior
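            accepted += 1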
        
        # Store samples
        samples.append((a.item(), b.item(), c.item(), sigma.item()))
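
    # Report the overall acceptance rate as a tuning diagnostic
    print(f'acceptance rate: {accepted / num_samples:.3f}')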
    
    return samples

# Run MCMC
print("Running MCMC...")
samples = metropolis_hastings(x1, x2, y, num_samples=5000, step_size=0.05)

# Extract samples, discarding an initial burn-in period so the chain's
# transient from the all-zero start does not bias the estimates
# (the burn-in length of 1000 draws is a judgment call, not tuned)
burn_in = 1000
samples_tensor = torch.tensor(samples[burn_in:])
a_samples = samples_tensor[:, 0]
b_samples = samples_tensor[:, 1]
c_samples = samples_tensor[:, 2]
s_samples = samples_tensor[:, 3]

# Compute Means and Variances
mean_a = a_samples.mean().item()
mean_b = b_samples.mean().item()
mean_c = c_samples.mean().item()
mean_s = s_samples.mean().item()

var_a = a_samples.var().item()
var_b = b_samples.var().item()
var_c = c_samples.var().item()
var_s = s_samples.var().item()

# Print results: posterior mean and variance, followed by the true
# value used to generate the data (for sigma, the noise scale 0.5)
print(f'a     mean/variance: {mean_a:8.4f} {var_a:8.3f} | {true_a:5.2f}')
print(f'b     mean/variance: {mean_b:8.4f} {var_b:8.3f} | {true_b:5.2f}')
print(f'c     mean/variance: {mean_c:8.4f} {var_c:8.3f} | {true_c:5.2f}')
print(f'sigma mean/variance: {mean_s:8.4f} {var_s:8.3f} | {0.5:5.2f}')
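
# Beyond point estimates, the retained draws give credible intervals
# directly; a minimal sketch using empirical quantiles for b:
b_lo, b_hi = torch.quantile(b_samples, torch.tensor([0.025, 0.975]))
print(f'b 95% credible interval: [{b_lo.item():.3f}, {b_hi.item():.3f}]')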

# Plot Posterior Distributions
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.hist(a_samples.numpy(), bins=30, density=True, alpha=0.7, color='blue')
plt.title("Posterior of a")
plt.xlabel("a")
plt.ylabel("Density")

plt.subplot(1, 3, 2)
plt.hist(b_samples.numpy(), bins=30, density=True, alpha=0.7, color='green')
plt.title("Posterior of b")
plt.xlabel("b")
plt.ylabel("Density")

plt.subplot(1, 3, 3)
plt.hist(c_samples.numpy(), bins=30, density=True, alpha=0.7, color='red')
plt.title("Posterior of c")
plt.xlabel("c")
plt.ylabel("Density")

plt.tight_layout()
plt.show()
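
# Posterior predictive sketch (illustrative): propagate parameter
# uncertainty to a prediction at a hypothetical new input
# (x1, x2) = (5.0, 2.0), drawing noise with each sampled sigma.
x1_new, x2_new = 5.0, 2.0
y_pred = (a_samples + b_samples * x1_new + c_samples * x2_new
          + torch.randn(len(s_samples)) * s_samples)
print(f'posterior predictive at (5.0, 2.0): '
      f'{y_pred.mean().item():.3f} +/- {y_pred.std().item():.3f}')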
