##################################
# Python-Programm "Spatial Games"
##################################

import networkx as nx
import matplotlib.pyplot as plt
from random import randint, uniform
import numpy as np
from matplotlib import rcParams
import matplotlib.gridspec as gridspec
from scipy.integrate import solve_ivp
import os

# Build a periodic 2D lattice graph in which every cell is linked to its 8 surrounding cells
def create_grid(width, height):
    """Build a wrap-around 2D lattice graph with 8-cell neighbourhoods.

    Cell ``(i, j)`` is stored as node ``i * height + j``. Neighbour indices
    are taken modulo ``width`` / ``height``, so cells on one border are
    connected to the cells on the opposite border.

    Args:
        width: Number of cells along the first axis.
        height: Number of cells along the second axis.

    Returns:
        A ``networkx.Graph`` with one node per cell and an edge between each
        cell and every distinct cell surrounding it (no self-loops).
    """
    # (di, dj) steps: upper-left, up, upper-right, right, lower-right, low,
    # lower-left, left.  The sequence is kept fixed because it determines the
    # order in which edges are inserted into the graph.
    steps = (
        (-1, -1), (0, -1), (1, -1), (1, 0),
        (1, 1), (0, 1), (-1, 1), (-1, 0),
    )
    cells = [(i, j) for i in range(width) for j in range(height)]

    lattice = nx.Graph()
    # Register every cell first so that cells without any edge still exist.
    lattice.add_nodes_from(i * height + j for i, j in cells)

    for i, j in cells:
        here = i * height + j
        for di, dj in steps:
            there = ((i + di) % width) * height + (j + dj) % height
            if there == here:
                # On very small lattices the wrap-around can map a neighbour
                # back onto the cell itself; skip it to avoid a self-loop.
                continue
            lattice.add_edge(here, there)
    return lattice

# Initialisiert Spieler-Array P mit Positionen und Strategien
def initialize_players(g, width, height, x_0=0.3):
    """Create the player table with positions and random initial strategies.

    Each row describes one player; the columns are
    ``[id, x, y, current_strat, payoff, next_strat]``.

    Args:
        g: Unused here; accepted so the call signature stays the same.
        width: Number of cells along the first axis (stored in column 1).
        height: Number of cells along the second axis (stored in column 2).
        x_0: Threshold for the uniform draw in [0, 1]; a player starts with
            strategy 1 when its draw is below ``x_0``, otherwise strategy 0.

    Returns:
        A float array of shape ``(width * height, 6)``.
    """
    table = np.zeros((width * height, 6))
    positions = ((i, j) for i in range(width) for j in range(height))
    for row, (i, j) in enumerate(positions):
        # One random draw per player, taken in row order.
        strategy = 1 if uniform(0, 1) < x_0 else 0
        table[row, :4] = (row, i, j, strategy)
        # The pending strategy starts out equal to the current one.
        table[row, 5] = strategy
    return table

# Payoff of the symmetric 2x2 game, valid for pure (s_a, s_b in {0, 1}) and mixed strategies
def payoff(s_a, s_b, a, b, c, d):
    """Return the payoff to player A in a symmetric 2x2 game.

    The result is bilinear in the two strategies, so it gives the matrix
    entry for pure strategies (0 or 1) and the weighted mix for values in
    between.

    Args:
        s_a: Strategy of player A (weight on strategy 1).
        s_b: Strategy of player B (weight on strategy 1).
        a: Payoff when both play 1.
        b: Payoff when A plays 1 and B plays 0.
        c: Payoff when A plays 0 and B plays 1.
        d: Payoff when both play 0.

    Returns:
        The payoff for player A.
    """
    not_a = 1 - s_a
    not_b = 1 - s_b
    both_one = a * s_a * s_b
    only_a_one = b * s_a * not_b
    only_b_one = c * not_a * s_b
    both_zero = d * not_a * not_b
    return both_one + only_a_one + only_b_one + both_zero

# Berechnet Auszahlungen für alle Spieler
def compute_payoffs(g, p, a, b, c, d):
    """Add the payoff of every pairwise game to the players' payoff column.

    Each edge of ``g`` is one game between its two end nodes; both players
    receive their own payoff from it.  Payoffs are accumulated in place in
    column 4 of ``p`` (nothing is reset here).

    Args:
        g: Graph whose ``edges()`` yields pairs of player indices.
        p: Player table; column 3 holds the current strategy, column 4 the
            accumulated payoff.
        a, b, c, d: Entries of the game matrix, passed through to ``payoff``.
    """
    game = (a, b, c, d)
    for first, second in g.edges():
        strat_first = p[first, 3]
        strat_second = p[second, 3]
        p[first, 4] += payoff(strat_first, strat_second, *game)
        p[second, 4] += payoff(strat_second, strat_first, *game)

# Neue Strategien-Wahl
def update_strategy(p, g, rule):
    """Choose each player's next strategy by imitating a neighbour.

    The decision is written to column 5 (next strategy) of ``p``; the
    current strategy in column 3 is only read, so all players decide on the
    basis of the same state.  Players without neighbours are left unchanged.

    Args:
        p: Player table; column 3 is the current strategy, column 4 the
            accumulated payoff, column 5 the next strategy.
        g: Graph providing ``neighbors(k)`` for player index ``k``.
        rule: ``0`` to copy the neighbour with the highest payoff, ``1`` to
            copy one randomly drawn neighbour.  In both cases the copy only
            happens if that neighbour's payoff is strictly higher than the
            player's own.

    Raises:
        ValueError: If ``rule`` is neither 0 nor 1.
    """
    # Validate once, up front: otherwise an invalid rule would go unnoticed
    # whenever no player has any neighbour (e.g. an empty table).
    if rule not in (0, 1):
        raise ValueError(f"Ungültige Regel: {rule}. Verwende 0 oder 1.")

    for k in range(len(p)):
        neighbors = list(g.neighbors(k))
        if not neighbors:
            continue

        if rule == 0:
            # Rule 0: imitate the best-performing neighbour.
            neigh_payoffs = p[neighbors, 4]
            best_idx = np.argmax(neigh_payoffs)
            if neigh_payoffs[best_idx] > p[k, 4]:
                p[k, 5] = p[neighbors[best_idx], 3]
        else:
            # Rule 1: draw one neighbour at random, imitate it if it did better.
            chosen = neighbors[randint(0, len(neighbors) - 1)]
            if p[chosen, 4] > p[k, 4]:
                p[k, 5] = p[chosen, 3]

# Apply the pending strategies and reset the payoffs; returns the mean strategy and the mean payoff per neighbour
def apply_updates(p, n_neighbors=8):
    """Commit the pending strategies and reset the payoff column.

    The mean payoff is taken *before* the payoffs are cleared; the mean
    strategy is taken *after* the pending strategies became current.

    Args:
        p: Player table, modified in place: column 4 (payoff) is set to 0
            and column 3 (current strategy) is overwritten with column 5
            (next strategy).
        n_neighbors: Number of games each player takes part in per round;
            the mean accumulated payoff is divided by it.  Defaults to 8.

    Returns:
        A tuple ``(mean_strategy, mean_payoff)``.
    """
    mean_payoff = np.mean(p[:, 4]) / n_neighbors
    p[:, 4] = 0
    p[:, 3] = p[:, 5]
    return np.mean(p[:, 3]), mean_payoff

# Reference solution of the replicator ODE dx/dt = g(x), obtained by numerical integration (solve_ivp)
def analytical_solution(a, b, c, d, t_span, initial_x, num_points=500):
    """Integrate the replicator equation for the symmetric 2x2 game.

    The right-hand side is
    ``dx/dt = ((a - c) * (x - x**2) + (b - d) * (1 - 2*x + x**2)) * x``.
    Despite the function name, the curve is computed numerically with
    ``scipy.integrate.solve_ivp`` using its default settings.

    Args:
        a, b, c, d: Entries of the game matrix.
        t_span: Two-element sequence ``(t_start, t_end)``.
        initial_x: Value of ``x`` at ``t_start``.
        num_points: Number of equally spaced output times across ``t_span``.

    Returns:
        A tuple ``(t, x)`` of arrays with the output times and the solution.
    """
    def replicator_rhs(_t, x):
        # Autonomous ODE: the time argument is required by solve_ivp only.
        return ((a - c) * (x - x**2) + (b - d) * (1 - 2*x + x**2)) * x

    start, stop = t_span[0], t_span[1]
    sample_times = np.linspace(start, stop, num_points)
    result = solve_ivp(replicator_rhs, t_span, [initial_x], t_eval=sample_times)
    return result.t, result.y[0]

# Draw the current state of the simulation (mean-strategy curve and lattice snapshot)
def plot_simulation(ax1, ax2, it, av_strat, p, width, height):
    """Draw the mean-strategy history and the current lattice snapshot.

    Args:
        ax1: Axes receiving the mean-strategy curve (one more line is added
            on every call; the axes is not cleared here).
        ax2: Axes receiving the lattice scatter plot.
        it: Current iteration index.  Not used for drawing; kept so the
            signature stays compatible with existing callers.
        av_strat: Sequence of mean strategy values recorded so far.
        p: Player table; columns 1 and 2 are the positions, column 3 the
            current strategy and column 5 the next strategy.
        width: Lattice extent along x, used for the axis limits.
        height: Lattice extent along y, used for the axis limits.
    """
    # Marker size handed to scatter; it shrinks as the number of players
    # grows so the squares keep tiling the axes.
    marker_size = np.sqrt(2822400 / len(p))

    # Strategy 0 is drawn red, anything else blue.
    colors = ['r' if s == 0 else 'b' for s in p[:, 3]]
    # Players whose pending strategy differs from the current one are drawn
    # half-transparent.
    alphas = [1 if cur == nxt else 0.5 for cur, nxt in zip(p[:, 3], p[:, 5])]

    # Mean of the population vector over time.
    ax1.plot(range(len(av_strat)), av_strat, c="black")

    # Lattice snapshot.
    ax2.scatter(p[:, 1], p[:, 2], s=marker_size, c=colors, marker="s",
                alpha=alphas, edgecolor='none')
    ax2.set_xlim(-0.5, width - 0.5)
    ax2.set_ylim(-0.5, height - 0.5)
    ax2.set_aspect('equal', adjustable='box')

# Führt die Simulation auf dem 2D-Gitter aus
def run_simulation(width=105, height=105, nit=30, rule=1, a=3, b=4, c=1, d=5, x_0=0.3, output_dir="./output_2"):
    """Run the spatial game on the 2D lattice and save one PNG per iteration.

    Every iteration accumulates the payoffs, lets each player pick its next
    strategy, draws the frame, and then commits the new strategies.  Frames
    are written as ``img-000.png``, ``img-001.png``, ... into ``output_dir``.

    Note that this updates matplotlib's global ``rcParams`` (including
    ``text.usetex=True``) and prints progress to stdout.

    Args:
        width: Lattice extent along the first axis.
        height: Lattice extent along the second axis.
        nit: Number of iterations (and frames).
        rule: Update rule passed to ``update_strategy`` (0 or 1).
        a, b, c, d: Entries of the symmetric game matrix.
        x_0: Initial fraction parameter passed to ``initialize_players``.
        output_dir: Directory for the PNG frames; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Set up the network and the players.
    g = create_grid(width, height)
    p = initialize_players(g, width, height, x_0)
    av_strat = [np.mean(p[:, 3])]
    av_dollar = []

    # Reference curve starting from the realised initial mean strategy ...
    t_span = [0, nit]
    t_ana, x_ana = analytical_solution(a, b, c, d, t_span, av_strat[0])
    # ... plus curves for a fixed set of other starting values.  Each curve
    # keeps its own time axis instead of relying on a leaked loop variable.
    reference_curves = [
        analytical_solution(a, b, c, d, t_span, init_x)
        for init_x in (0.9, 0.7, 0.5, 0.3, 0.1)
    ]

    # Plot setup.
    rcParams.update({
        'figure.figsize': [7.5, 10],
        'text.usetex': True,
        'legend.fontsize': 12
    })
    fig = plt.figure(0)
    try:
        gs = gridspec.GridSpec(2, 1, height_ratios=[1, 3.0], hspace=0.1)
        ax1 = fig.add_subplot(gs[0])
        ax2 = fig.add_subplot(gs[1])

        ax1.set_xlim(0, nit - 1)
        ax1.set_ylim(0, 1)
        ax1.set_ylabel(r'$\rm x(t)$')

        # Reference curves.
        ax1.plot(t_ana, x_ana, c="grey", linewidth=1.5, linestyle=':')
        ax1.plot(0, av_strat[0], marker='o', color='grey', markersize=4)
        for t_ref, x_ref in reference_curves:
            ax1.plot(t_ref, x_ref, c="lightblue", linewidth=1, linestyle=':')

        # Simulation loop.
        for it in range(nit):
            print(f"Iteration {it} ---------------------------------------------------")
            compute_payoffs(g, p, a, b, c, d)
            update_strategy(p, g, rule)
            # Draw before committing, while current and pending strategies
            # are both still available in the table.
            plot_simulation(ax1, ax2, it, av_strat, p, width, height)
            mean_strat, mean_dollar = apply_updates(p)
            av_strat.append(mean_strat)
            av_dollar.append(mean_dollar)

            # Save the frame, then clear only the lattice axes so the
            # history in ax1 keeps accumulating.
            frame_path = os.path.join(output_dir, f"img-{it:0>3d}.png")
            fig.savefig(frame_path, dpi=100, bbox_inches="tight", pad_inches=0.05)
            ax2.clear()
    finally:
        # Release the figure even if an iteration fails.
        plt.close(fig)

    print("Mittelwerte der Auszahlungen: ", [round(x, 3) for x in av_dollar])
    print("Simulation abgeschlossen. Plots in", output_dir, "gespeichert.")

# Simulation parameter presets: the uncommented call below is the one that runs;
# the commented-out calls are alternative parameter sets.
# Dominant-strategy game
# run_simulation(width=105, height=105, nit=30, rule=1, a=-7, b=-1, c=-9, d=-3, x_0=0.1, output_dir="./output")

# Coordination game
# run_simulation(width=105, height=105, nit=30, rule=1, a=3, b=4, c=1, d=5, x_0=0.4, output_dir="./output")

# Anti-coordination game
run_simulation(width=105, height=105, nit=30, rule=1, a=-3, b=4, c=1, d=2, x_0=0.8, output_dir="./output")

