import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from random import randint, uniform, random, choice
from math import isclose
import os
from matplotlib import rcParams
import matplotlib.gridspec as gridspec

# Class holding the relevant per-player properties of the spatial game
class Players:
    """Per-player state of the spatial game, kept in a single NumPy array.

    ``self.data`` has one row per player and the seven columns
    ``[id, x, y, current_strat, payoff, next_strat, last_opponent_strat]``.
    Strategies are 0/1; strategy 1 counts as cooperation (it is the one
    drawn with ``initial_coop_prob`` and counted by the Tit-for-Tat rule).
    """

    def __init__(self, n_nodes, initial_coop_prob=0.3, width=None, height=None):
        """Create ``n_nodes`` players on a ``width`` x ``height`` grid.

        Args:
            n_nodes: Number of players; must equal ``width * height``.
            initial_coop_prob: Probability that a player starts with
                strategy 1.
            width, height: Grid dimensions. Pass both or neither. When both
                are omitted the grid is assumed square, so ``n_nodes`` must
                be a perfect square. Player ``k`` sits at
                ``(x, y) = (k // height, k % height)``, i.e. node id
                ``x * height + y``.

        Raises:
            ValueError: If only one dimension is given, or the grid does not
                contain exactly ``n_nodes`` cells.
        """
        if (width is None) != (height is None):
            raise ValueError("width and height must be given together")
        if width is None:
            # Nearest integer side length; the check below rejects any
            # n_nodes that is not a perfect square.
            width = height = int(round(np.sqrt(n_nodes)))
        if width * height != n_nodes:
            raise ValueError(
                f"n_nodes={n_nodes} does not fit a {width}x{height} grid"
            )

        self.n_nodes = n_nodes
        self.width = width
        self.height = height
        self.data = np.zeros((n_nodes, 7))

        ids = np.arange(n_nodes)
        # max(..., 1) only matters for the empty grid (n_nodes == 0).
        xs, ys = np.divmod(ids, max(height, 1))
        self.data[:, 0] = ids
        self.data[:, 1] = xs
        self.data[:, 2] = ys
        # One uniform draw per player in id order, so a seeded `random`
        # module reproduces the same initial configuration.
        for k in range(n_nodes):
            self.data[k, 3] = 1 if uniform(0, 1) < initial_coop_prob else 0
        self.data[:, 5] = self.data[:, 3]
        # Placeholder; overwritten once opponents are recorded.
        self.data[:, 6] = 1 - self.data[:, 3]

    # Public member functions
    def reset_payoffs(self):
        """Set every player's accumulated payoff (column 4) to zero."""
        self.data[:, 4] = 0.0

    def set_next_strategy(self, node_idx, strategy):
        """Store ``strategy`` as the next strategy (column 5) of a player."""
        self.data[node_idx, 5] = strategy

    def apply_next_strategies(self):
        """Make every player's next strategy (column 5) its current one."""
        self.data[:, 3] = self.data[:, 5]

    def get_current_strategy(self, node_idx):
        """Return the current strategy (column 3) of a player."""
        return self.data[node_idx, 3]

    def get_last_opponent_strategy(self, node_idx):
        """Return the recorded last-opponent strategy (column 6) of a player."""
        return self.data[node_idx, 6]

    def update_last_opponent(self, node_idx, opponent_strategy):
        """Record ``opponent_strategy`` in column 6 of a player."""
        self.data[node_idx, 6] = opponent_strategy

    def mean_strategy(self):
        """Return the mean current strategy (fraction of strategy-1 players)."""
        return np.mean(self.data[:, 3])

    def _imitate_best(self, node, neighbors, my_payoff):
        """Adopt the best-paid neighbour's strategy if it strictly beats ``my_payoff``."""
        neigh_payoffs = self.data[neighbors, 4]
        max_p = np.max(neigh_payoffs)
        # Require a real improvement so float noise cannot trigger a switch.
        if max_p > my_payoff and not isclose(max_p, my_payoff):
            best_idx = neighbors[np.argmax(neigh_payoffs)]
            self.set_next_strategy(node, self.get_current_strategy(best_idx))

    # Public member function: strategy update according to different rules
    def update_strategies(self, graph, rule, fermi_K=0.5, mu=0.01):
        """Compute every player's next strategy (column 5) from this round.

        Reads current strategies (column 3) and accumulated payoffs
        (column 4); players without neighbours are skipped.

        Args:
            graph: Object providing ``neighbors(node)`` (e.g. a networkx
                graph) whose nodes are the player ids.
            rule: 0 imitate the best neighbour; 1 imitate one random
                neighbour if it earned more; 2 majority Tit-for-Tat;
                3 Fermi rule; 4 imitate the best, then mutate with
                probability ``mu``.
            fermi_K: Noise temperature of the Fermi rule (rule 3).
            mu: Mutation probability of rule 4.

        Raises:
            ValueError: For an unknown ``rule``.
        """
        if rule not in (0, 1, 2, 3, 4):
            raise ValueError(f"Unbekannte Update-Regel: {rule}")

        for node in range(self.n_nodes):
            neighbors = list(graph.neighbors(node))
            if not neighbors:
                continue

            my_payoff = self.data[node, 4]
            my_current = self.get_current_strategy(node)
            # Default: keep the current strategy unless a rule changes it.
            self.set_next_strategy(node, my_current)

            if rule == 0:
                # Imitate the best neighbour.
                self._imitate_best(node, neighbors, my_payoff)

            elif rule == 1:
                # Compare with one random neighbour; copy it if it earned more.
                neigh = choice(neighbors)
                neigh_payoff = self.data[neigh, 4]
                if neigh_payoff > my_payoff and not isclose(neigh_payoff, my_payoff):
                    self.set_next_strategy(node, self.get_current_strategy(neigh))

            elif rule == 2:
                # Majority Tit-for-Tat: cooperate iff at least half of the
                # neighbours cooperated this round. A neighbour's move against
                # this player is its current strategy (column 3). Column 6 of a
                # neighbour holds what *that neighbour's* opponent played, so
                # it must not be used here.
                coop_count = np.sum(self.data[neighbors, 3])
                new_strat = 1 if coop_count >= len(neighbors) / 2 else 0
                self.set_next_strategy(node, new_strat)

            elif rule == 3:
                # Fermi rule: adopt a random differing neighbour's strategy
                # with probability 1 / (1 + exp((P_self - P_neigh) / K)).
                neigh = choice(neighbors)
                neigh_strat = self.get_current_strategy(neigh)
                if my_current != neigh_strat:
                    delta = my_payoff - self.data[neigh, 4]
                    prob = 1 / (1 + np.exp(delta / fermi_K))
                    if random() < prob:
                        self.set_next_strategy(node, neigh_strat)

            else:
                # rule == 4: imitate the best, then mutate with probability mu.
                self._imitate_best(node, neighbors, my_payoff)
                if random() < mu:
                    self.set_next_strategy(node, choice([0, 1]))

# Class running the spatial simulation of the game
class SpatialGameSimulation:
    """Spatial 2x2 game on a periodic ``width`` x ``height`` grid.

    Every player plays against its eight surrounding cells (Moore
    neighbourhood with wrap-around boundaries). After each round the
    strategies are updated via ``Players.update_strategies`` and a PNG
    snapshot ``img-NNN.png`` is written to ``output_dir``.
    """

    # Constructor initializing the spatial game
    def __init__(self, width=105, height=105, nit=30, rule=1,
                 payoff_params=(3,4,1,5), initial_coop=0.3, output_dir="./output"):
        """Set up grid, players and figure.

        Args:
            width, height: Grid size in cells.
            nit: Number of rounds to simulate.
            rule: Strategy update rule forwarded to
                ``Players.update_strategies``.
            payoff_params: ``(a, b, c, d)``: payoff of a player with
                strategy ``s_a`` against ``s_b`` for the pairs (1, 1),
                (1, 0), (0, 1) and (0, 0) respectively.
            initial_coop: Probability that a player starts with strategy 1.
            output_dir: Directory for the snapshots; created if missing.
        """
        self.width = width
        self.height = height
        self.nit = nit
        self.rule = rule
        self.a, self.b, self.c, self.d = payoff_params
        self.output_dir = output_dir

        os.makedirs(output_dir, exist_ok=True)

        # NetworkX graph of the spatial grid
        self.graph = self._create_spatial_grid()
        # Distinct opponents per player: 8 on grids of at least 3x3, fewer on
        # smaller grids where wrapped neighbours coincide. Clamped to 1 so an
        # isolated player (1x1 grid) does not divide by zero.
        self._degrees = np.array(
            [max(self.graph.degree(k), 1) for k in range(width * height)],
            dtype=float,
        )
        self.players = Players(width * height, initial_coop)

        self.av_strat_history = [self.players.mean_strategy()]
        self.av_payoff_history = []

        self._setup_plot()

    # Protected member function: build the spatial grid graph
    def _create_spatial_grid(self):
        """Return a graph linking each cell ``i * height + j`` to its 8 wrapped neighbours."""
        g = nx.Graph()
        # Same neighbour order as before, so edge/neighbour iteration order
        # (and hence tie-breaking and random picks) is unchanged.
        offsets = [(-1, -1), (0, -1), (1, -1), (1, 0),
                   (1, 1), (0, 1), (-1, 1), (-1, 0)]

        for i in range(self.width):
            for j in range(self.height):
                k = i * self.height + j
                g.add_node(k)
                for di, dj in offsets:
                    ni = (i + di) % self.width
                    nj = (j + dj) % self.height
                    neigh_id = ni * self.height + nj
                    # Skip self-loops that wrap-around creates on 1-wide grids.
                    if neigh_id != k:
                        g.add_edge(k, neigh_id)
        return g

    def payoff(self, s_a, s_b):
        """Return the payoff of strategy ``s_a`` against ``s_b`` (bilinear in 0/1 strategies)."""
        return (self.a * s_a * s_b +
                self.b * s_a * (1-s_b) +
                self.c * (1-s_a) * s_b +
                self.d * (1-s_a) * (1-s_b))

    def compute_all_payoffs(self):
        """Play one round: every edge is one game and both ends add their payoff.

        For rule 2 each player's column 6 is also set to an opponent's
        strategy (the last edge processed for that player wins).
        """
        self.players.reset_payoffs()

        for u, v in self.graph.edges():
            su = self.players.get_current_strategy(u)
            sv = self.players.get_current_strategy(v)

            self.players.data[u, 4] += self.payoff(su, sv)
            self.players.data[v, 4] += self.payoff(sv, su)

            if self.rule == 2:  # only needed for Tit-for-Tat
                self.players.update_last_opponent(u, sv)
                self.players.update_last_opponent(v, su)

    def update_strategies(self):
        """Forward the strategy update to ``Players`` with this graph and rule."""
        self.players.update_strategies(self.graph, self.rule)

    def _setup_plot(self):
        """Create the figure: x(t) history on top, grid snapshot below."""
        rcParams.update({'figure.figsize': [7.5, 10], 'text.usetex': True})
        self.fig = plt.figure()
        gs = gridspec.GridSpec(2, 1, height_ratios=[1, 3.0], hspace=0.1)

        self.ax1 = plt.subplot(gs[0])
        self.ax2 = plt.subplot(gs[1])

        # Guard nit == 1, where identical x-limits would trigger a warning.
        self.ax1.set_xlim(0, max(self.nit - 1, 1))
        self.ax1.set_ylim(0, 1)
        self.ax1.set_ylabel(r'$x(t)$')
        # One persistent line updated in place each round, instead of adding
        # a new full-history line to the axes every iteration.
        self.strat_line, = self.ax1.plot([], [], 'k-')

    def plot_current_state(self, iteration):
        """Draw the grid state and the x(t) history.

        Strategy-0 cells are red, all others blue. Cells whose next strategy
        differs from the current one are drawn half-transparent.
        ``iteration`` is kept for interface compatibility and is unused.
        """
        p = self.players.data
        # Marker area shrinks as 1/sqrt(n_nodes); 2822400 (= 1680**2) looks
        # like an empirical tuning constant — TODO confirm.
        sgross = np.sqrt(2822400 / self.players.n_nodes)

        colors = ['red' if s == 0 else 'blue' for s in p[:, 3]]
        alpha = [1 if curr == next_s else 0.5
                 for curr, next_s in zip(p[:, 3], p[:, 5])]

        history = self.av_strat_history
        self.strat_line.set_data(np.arange(len(history)), history)
        self.ax2.scatter(p[:,1], p[:,2], s=sgross, c=colors, marker="s",
                        alpha=alpha, edgecolor='none')

        self.ax2.set_xlim(-0.5, self.width-0.5)
        self.ax2.set_ylim(-0.5, self.height-0.5)
        self.ax2.set_aspect('equal')

    # Public member function: run the simulation
    def run(self):
        """Simulate ``nit`` rounds, saving one snapshot per round.

        Per round: accumulate payoffs, choose next strategies, draw, record
        the mean payoff per game, switch to the new strategies, and record
        their mean strategy as the next x(t) value.
        """
        for it in range(self.nit):
            print(f"Iteration {it:3d}", end=" ... ")

            self.compute_all_payoffs()
            self.update_strategies()

            self.plot_current_state(it)

            # Average payoff per game: each total divided by that player's
            # number of opponents (8 on grids of at least 3x3).
            mean_payoff = np.mean(self.players.data[:, 4] / self._degrees)
            self.av_payoff_history.append(mean_payoff)

            self.players.apply_next_strategies()
            # Measure x after the switch. The strategy level played this round
            # is already the last history entry; measuring before
            # apply_next_strategies() recorded x(0) twice and made the curve
            # lag the grid snapshot by one round.
            mean_strat = self.players.mean_strategy()
            self.av_strat_history.append(mean_strat)

            fname = f"img-{it:03d}"
            self.fig.savefig(os.path.join(self.output_dir, f"{fname}.png"),
                           dpi=120, bbox_inches="tight")
            self.ax2.clear()

            print(f"x = {mean_strat:.4f}")

        print("\nSimulation abgeschlossen.")
        # float() so NumPy 2 scalars print as plain numbers, not np.float64(...)
        print("Mittlere Auszahlungen:", [round(float(x), 3) for x in self.av_payoff_history])


if __name__ == "__main__":
    # Demo configuration: 105 x 105 grid, 40 rounds, update rule 2
    # (majority Tit-for-Tat). Each player starts with strategy 1 with
    # probability 0.4. payoff_params = (a, b, c, d) are a player's payoffs
    # for the strategy pairs (1, 1), (1, 0), (0, 1), (0, 0), own move first.
    settings = {
        "width": 105,
        "height": 105,
        "nit": 40,
        "rule": 2,
        "payoff_params": (-3, -9, -1, -7),
        "initial_coop": 0.4,
        "output_dir": "./output",
    }

    # Build the simulation object, then run it
    sim = SpatialGameSimulation(**settings)
    sim.run()
