#!/usr/bin/env python3

#
# Monte Carlo (Metropolis) Bayesian inference of the biases of two coins that are flipped at random
#

import numpy as np
import matplotlib.pyplot as plt

def simulate_data(num_trials, b1, b2):
    """Simulate `num_trials` flips, each made with a coin picked uniformly at random.

    Parameters
    ----------
    num_trials : int
        Number of flips to generate.
    b1, b2 : float
        Success probabilities of coin 1 and coin 2.

    Returns
    -------
    (outcomes, picks) : tuple of numpy arrays
        outcomes[i] is True with probability equal to the bias of the coin
        used for flip i; picks[i] is 0 when coin 1 (bias b1) was used and 1
        when coin 2 (bias b2) was used.
    """
    biases = (b1, b2)
    picks = []
    outcomes = []
    for _ in range(num_trials):
        # Pick the coin first, then flip it -- this draw order is what
        # fixes the random stream for a given seed.
        which = np.random.choice([0, 1])  # 0 -> coin1, 1 -> coin2
        picks.append(which)
        outcomes.append(np.random.rand() < biases[which])
    return np.array(outcomes), np.array(picks)

def log_likelihood(data, b1, b2, coins=None):
    """Log likelihood of the observed flips given the two coin biases.

    The result is deterministic: the same (data, b1, b2) always gives the
    same value, which the Metropolis acceptance test requires.

    Parameters
    ----------
    data : array_like of bool
        Flip outcomes; True means the flip succeeded (probability = bias).
    b1, b2 : float
        Success probabilities of coin 1 and coin 2.
    coins : array_like of int, optional
        Which coin produced each flip (0 -> coin 1, 1 -> coin 2), e.g. the
        second array returned by ``simulate_data``.  When omitted, the coin
        identity is treated as unknown and marginalised out.  Each flip then
        used either coin with probability 1/2, as in ``simulate_data``, so the
        per-flip success probability is ``(b1 + b2) / 2``.

    Returns
    -------
    float
        Sum over flips of log P(outcome | b1, b2); 0.0 for empty data.
    """
    outcomes = np.asarray(data, dtype=bool)
    if coins is None:
        # Mixture over the unknown coin.  Note: the likelihood then depends
        # on b1 and b2 only through their mean.
        p = np.full(outcomes.shape, 0.5 * (b1 + b2))
    else:
        p = np.where(np.asarray(coins) == 0, b1, b2)
    # Probability assigned to the outcome actually observed for each flip.
    prob_observed = np.where(outcomes, p, 1.0 - p)
    return float(np.sum(np.log(prob_observed)))

class VariationalMetropolis:
    """Random-walk Metropolis sampler for the posterior of (b1, b2).

    Despite the name, nothing here is variational.  It is a plain Metropolis
    random walk.  Proposals are confined to [eps, 1 - eps]^2, and acceptance
    uses only the likelihood ratio, so the implied prior is flat on that box.
    """

    # Proposals stay inside [_EPS, 1 - _EPS] so log(p) and log(1 - p)
    # remain finite.
    _EPS = 0.001

    def __init__(self, data, initial_b1=0.5, initial_b2=0.5,
                 proposal_scale=0.05, log_lik=None):
        """Set up the chain.

        Parameters
        ----------
        data : sequence of bool
            Observed flip outcomes, passed unchanged to the log likelihood.
        initial_b1, initial_b2 : float
            Starting state of the chain.
        proposal_scale : float
            Standard deviation of the Gaussian random-walk step.
        log_lik : callable, optional
            ``log_lik(data, b1, b2) -> float``.  Defaults to the module-level
            ``log_likelihood``.  Pass e.g. a ``functools.partial`` to supply
            extra arguments.
        """
        self.data = data
        self.b1 = initial_b1
        self.b2 = initial_b2
        self.proposal_scale = proposal_scale
        self.log_lik = log_likelihood if log_lik is None else log_lik
        self.samples = []

    @staticmethod
    def _reflect(x, lo, hi):
        """Fold x back into [lo, hi] by mirroring it at the bounds."""
        width = hi - lo
        y = np.mod(x - lo, 2.0 * width)
        return lo + (2.0 * width - y if y > width else y)

    def propose(self):
        """Draw a candidate (b1, b2): Gaussian step, reflected into [eps, 1 - eps].

        Reflection, unlike clipping, keeps the proposal symmetric,
        q(new | old) == q(old | new), which the plain Metropolis rule in
        ``acceptance_ratio`` assumes.  Clipping would pile probability mass
        onto the bounds and break that symmetry.
        """
        lo, hi = self._EPS, 1.0 - self._EPS
        step_b1 = np.random.normal(0, self.proposal_scale)
        step_b2 = np.random.normal(0, self.proposal_scale)
        new_b1 = self._reflect(self.b1 + step_b1, lo, hi)
        new_b2 = self._reflect(self.b2 + step_b2, lo, hi)
        return new_b1, new_b2

    def acceptance_ratio(self, new_b1, new_b2):
        """Metropolis acceptance probability min(1, L(new) / L(old)).

        The log-likelihood difference is capped at 0 before exponentiating.
        This avoids overflow when the candidate is far more likely than the
        current state.  A NaN difference propagates, so the move is rejected.
        """
        old_ll = self.log_lik(self.data, self.b1, self.b2)
        new_ll = self.log_lik(self.data, new_b1, new_b2)
        return float(np.exp(np.minimum(0.0, new_ll - old_ll)))

    def run(self, iterations):
        """Advance the chain `iterations` steps, recording the state after each.

        Returns ``self.samples``, the list of (b1, b2) tuples.  The list keeps
        growing across repeated calls.
        """
        # Report progress roughly every 10 %.  max(1, ...) avoids a modulo
        # by zero when iterations < 10.
        report_every = max(1, iterations // 10)
        for ii in range(iterations):
            if ii % report_every == 0:
                print(f'# walking {ii:6d}')
            new_b1, new_b2 = self.propose()
            alpha = self.acceptance_ratio(new_b1, new_b2)
            if np.random.rand() < alpha:
                self.b1, self.b2 = new_b1, new_b2
            self.samples.append((self.b1, self.b2))
        return self.samples

#
# main
#
true_b1 = 0.2
true_b2 = 0.2
num_trials = 100
n_MC       = 10000
# The first part of the chain still reflects the starting point (the sampler
# starts at b1 = b2 = 0.5 by default); discard it before building histograms.
n_burn = n_MC // 10

# generating data
print(f'# tossing coins {num_trials:d} times')
observed_data, chosen_coins = simulate_data(num_trials, true_b1, true_b2)

# running the Metropolis sampler
vm = VariationalMetropolis(observed_data)
print('# Monte Carlo walk in parameter space')
samples = vm.run(n_MC)

# extracting samples after burn-in
samples = np.array(samples)
b1_samples = samples[n_burn:, 0]
b2_samples = samples[n_burn:, 1]

# plotting results: one histogram panel per bias
plt.figure(figsize=(12, 6))
panels = [
    (b1_samples, true_b1, 'b1', 'blue'),
    (b2_samples, true_b2, 'b2', 'green'),
]
for idx, (draws, truth, name, color) in enumerate(panels, start=1):
    plt.subplot(1, 2, idx)
    plt.hist(draws, bins=50, density=True, color=color,
             alpha=0.7, label=f'Posterior of {name}')
    plt.axvline(truth, color='red', linestyle='--', label=f'True {name}')
    plt.xlabel(name)
    plt.ylabel('Density')
    plt.legend()

plt.tight_layout()
plt.show()
