##################################
# Python-Programm "Spatial (2x3)-Games"
##################################

import networkx as nx
import matplotlib.pyplot as plt
from random import randint, uniform
from math import isclose
import numpy as np
from matplotlib import rcParams
import matplotlib.gridspec as gridspec
from scipy.integrate import solve_ivp
import os
from sympy import symbols, Matrix, Eq, transpose, solve
import matplotlib.colors as colors

# baryzentrisches Dreiecks-Koordinatensystem
def xy(vx):
    """Map a population vector (x1, x2, x3) to Cartesian simplex coordinates.

    The pure states e1, e2 and e3 land on (0, 0), (1, 0) and (0.5, 1).
    Only the second and third components are read; x1 is implied by them.
    Works element-wise when the components are numpy arrays (e.g. rows of an
    ODE solution).
    """
    second = vx[1]
    third = vx[2]
    return [second + third / 2, third]

# Funktion zur Berechnung der Nashgleichgewichte
def find_nash_equilibria(D):
    """Return the Nash equilibria of the symmetric 3x3 game with payoff matrix D.

    ``D[i, j]`` is the payoff of strategy i against strategy j.

    Returns a list of dicts:

    * ``{'type': 'pure', 's': (i, j)}`` for every pure strategy profile
      (1-based strategy numbers) in which each strategy is a best response
      to the other one;
    * ``{'type': 'mixed', 's*': (y1, y2, y3)}`` for every symmetric mixed
      equilibrium found in the interior or on an edge of the simplex. The
      components are sympy numbers in [0, 1). Each point is reported once,
      and only if no pure strategy earns strictly more against it than the
      mix itself (the Nash condition).

    Candidates that the linear systems leave undetermined (degenerate
    payoff matrices) are skipped instead of raising.
    """
    equilibria = []
    n = D.shape[0]
    # Pure equilibria: i must be a best response to j and, because the game
    # is symmetric, j must be a best response to i.
    for i in range(n):
        for j in range(n):
            if D[i, j] == max(D[:, j]) and D[j, i] == max(D[:, i]):
                equilibria.append({'type': 'pure', 's': (i+1, j+1)})

    # Mixed equilibria: against the opponent's mix y every strategy in the
    # support must earn the same payoff, i.e. the expected payoff x^T D y is
    # stationary in x on that support. Solve these indifference conditions
    # for y on the interior and on each edge of the simplex.
    Loes_GN = []
    x1, x2, x3, y1, y2, y3 = symbols('x_1,x_2,x_3,y_1,y_2,y_3')
    xs = Matrix([x1, x2, x3])
    ys = Matrix([y1, y2, y3])
    Dollar_A = transpose(xs)*D*ys
    # Support {1, 2, 3}: interior of the simplex.
    Dollar_As = Dollar_A.subs(x3, 1-x1-x2).subs(y3, 1-y1-y2)[0]
    GemNash_Eq1 = Eq(Dollar_As.diff(x1), 0)
    GemNash_Eq2 = Eq(Dollar_As.diff(x2), 0)
    Loes_GN.append(solve([GemNash_Eq1, GemNash_Eq2, Eq(1, y1+y2+y3)]))
    # Support {2, 3}: edge with y1 = 0.
    Dollar_As_1 = Dollar_A.subs(x1, 0).subs(x3, 1-x2).subs(y3, 1-y1-y2)[0]
    Loes_GN.append(solve([Eq(Dollar_As_1.diff(x2), 0), Eq(0, y1), Eq(1, y2+y3)]))
    # Support {1, 3}: edge with y2 = 0.
    Dollar_As_2 = Dollar_A.subs(x2, 0).subs(x3, 1-x1).subs(y3, 1-y1-y2)[0]
    Loes_GN.append(solve([Eq(Dollar_As_2.diff(x1), 0), Eq(0, y2), Eq(1, y1+y3)]))
    # Support {1, 2}: edge with y3 = 0.
    Dollar_As_3 = Dollar_A.subs(x3, 0).subs(x2, 1-x1).subs(y3, 1-y1-y2)[0]
    Loes_GN.append(solve([Eq(Dollar_As_3.diff(x1), 0), Eq(0, y3), Eq(1, y1+y2)]))

    payoff = np.asarray(D, dtype=float)
    seen = set()
    for result in Loes_GN:
        # solve() gives [] for an inconsistent system and a dict for a
        # unique solution; accept a list of dicts as well.
        for sol in (result if isinstance(result, list) else [result]):
            if not isinstance(sol, dict):
                continue
            s = tuple(sol.get(v) for v in (y1, y2, y3))
            # Degenerate systems leave components missing or symbolic.
            if any(v is None or not v.is_number for v in s):
                continue
            if not all(0 <= v < 1 for v in s):
                continue
            # The interior and an edge system can yield the same point.
            if s in seen:
                continue
            # Nash condition: no pure strategy may beat s* against s*.
            s_num = np.array([float(v) for v in s])
            pure_payoffs = payoff @ s_num
            if pure_payoffs.max() > s_num @ pure_payoffs + 1e-9:
                continue
            seen.add(s)
            equilibria.append({'type': 'mixed', 's*': s})
    return equilibria

# Erstellen eines räumlichen 2D-Gittergraphen (8 Nachbarn, periodische Ränder)
def create_grid(width, height):
    """Build a width x height lattice graph on a torus with 8-cell neighbourhoods.

    Node ``i * height + j`` is the cell in column i, row j. Every cell is
    linked to its eight surrounding cells, wrapping around at the borders.
    On very small grids, coinciding neighbours collapse into a single edge of
    the simple graph, and self-loops are skipped.
    """
    # Neighbour offsets (di, dj) in the order the edges are inserted. This
    # order fixes the adjacency order that update_strategy iterates over:
    # upper-left, up, upper-right, right, lower-right, low, lower-left, left.
    offsets = ((-1, -1), (0, -1), (1, -1), (1, 0),
               (1, 1), (0, 1), (-1, 1), (-1, 0))
    lattice = nx.Graph()
    lattice.add_nodes_from(range(width * height))
    for col in range(width):
        for row in range(height):
            here = col * height + row
            for dc, dr in offsets:
                there = ((col + dc) % width) * height + (row + dr) % height
                if there != here:  # avoid self-loops
                    lattice.add_edge(here, there)
    return lattice

# Initialisiert Spieler-Array P mit Positionen und Strategien
def initialize_players(width, height, x_init):
    """Create the player table for a width x height grid.

    Returns a float array with one row per player and the columns
    [id, x, y, current_strat, payoff, next_strat]. Players are numbered
    column by column (id = x * height + y).

    Each player's initial strategy comes from one ``uniform(0, 1)`` draw:
    - a draw <= x_init[0] gives 0 (s_1);
    - a draw > x_init[0] + x_init[1] gives 2 (s_3);
    - anything else gives 1 (s_2).

    The payoff column starts at 0, and next_strat equals current_strat.
    """
    table = np.zeros((width * height, 6))
    cells = ((col, row) for col in range(width) for row in range(height))
    for idx, (col, row) in enumerate(cells):
        draw = uniform(0, 1)
        if draw <= x_init[0]:
            strategy = 0  # s_1
        elif draw > x_init[0] + x_init[1]:
            strategy = 2  # s_3
        else:
            strategy = 1  # s_2
        table[idx] = (idx, col, row, strategy, 0, strategy)
    return table

# Berechnet Auszahlungen für alle Spieler
def compute_payoffs(g, p, D):
    """Add one round of pairwise payoffs to column 4 of ``p`` in place.

    Each edge (u, v) of ``g`` is one game between the two players:
    - u receives ``D[s_u, s_v]``;
    - v receives ``D[s_v, s_u]``;
    where s_* is the current strategy from column 3.

    The payoff column is accumulated, not reset here.
    """
    strategies = p[:, 3]
    for a, b in g.edges():
        sa, sb = int(strategies[a]), int(strategies[b])
        p[a, 4] += D[sa, sb]
        p[b, 4] += D[sb, sa]

# Neue Strategien-Wahl
def update_strategy(p, g, rule):
    """Choose every player's next strategy (column 5 of ``p``) by imitation.

    rule 0: copy the strategy of the highest-earning neighbour (first one in
            adjacency order on ties) if it earned strictly more.
    rule 1: pick one neighbour uniformly at random (``randint``) and copy its
            strategy if it earned strictly more.

    "Strictly more" means larger and not ``math.isclose``. Players without
    neighbours are skipped. Any other rule raises ValueError at the first
    player that has neighbours. Only column 5 is written.
    """
    for node in range(len(p)):
        nbrs = list(g.neighbors(node))
        if len(nbrs) == 0:
            continue
        own = p[node, 4]
        if rule == 0:
            # Rule 0: imitate the best neighbour.
            candidate = nbrs[int(np.argmax(p[nbrs, 4]))]
        elif rule == 1:
            # Rule 1: imitate a random neighbour if it did better.
            candidate = nbrs[randint(0, len(nbrs) - 1)]
        else:
            raise ValueError(f"Ungültige Regel: {rule}. Verwende 0 oder 1.")
        rival = p[candidate, 4]
        if rival > own and not isclose(rival, own):
            p[node, 5] = p[candidate, 3]

# Berechnet den Mittelwert des Populationsvektors (aktuelle Strategien) und den mittleren Payoff, setzt die Payoffs zurück und wendet die zukünftigen Strategien an
def apply_updates(p):
    """Close one round: summarise it, reset payoffs and commit the next strategies.

    Returns ``(point, mean_payoff)``:
    - ``point`` is ``xy`` of the shares of players currently using
      strategies 0, 1 and 2, measured before the new strategies are
      committed;
    - ``mean_payoff`` is the mean of column 4 divided by 8. The 8 is
      presumably the lattice neighbourhood size (average payoff per game);
      it assumes every player has exactly 8 neighbours.

    Afterwards column 4 is zeroed and column 3 is overwritten with column 5,
    in place.
    """
    population = len(p)
    average = np.mean(p[:, 4]) / 8
    playing = p[:, 3]
    shares = [np.count_nonzero(playing == strategy) / population for strategy in range(3)]
    p[:, 4] = 0
    p[:, 3] = p[:, 5]
    return xy(shares), average

# Definition der Funktionen g_x und g_y
def g_xy(xy, D):
    """Return the replicator-dynamics velocity [g_x, g_y] at simplex point(s) ``xy``.

    ``xy = [X, Y]`` are Cartesian simplex coordinates, scalars or equally
    shaped arrays such as a meshgrid. They are mapped back to the population
    vector x = (1 - X - Y/2, X - Y/2, Y). The replicator equation
    dx_i/dt = x_i * ((D x)_i - x^T D x) is then projected onto the same
    Cartesian frame: g_x = dx_2 + dx_3/2 and g_y = dx_3.

    Note: inside this function the parameter ``xy`` shadows the module-level
    ``xy`` function.
    """
    size = 3
    pop = np.array([1 - xy[0] - xy[1] / 2, xy[0] - xy[1] / 2, xy[1]])

    def weighted(i):
        # sum_j D[i, j] * x_i * x_j, accumulated left to right.
        return sum(D[i, j] * pop[i] * pop[j] for j in range(size))

    mean_fitness = sum(weighted(k) for k in range(size))
    rates = [weighted(i) - mean_fitness * pop[i] for i in range(size)]
    return [rates[1] + rates[2] / 2, rates[2]]

# Numerische Lösung der Replikator-DGL (mit solve_ivp)
def analytical_solution(D, t_span, x_init, num_points=500):
    """Integrate the replicator equation dx_i/dt = x_i * ((D x)_i - x^T D x).

    Despite its name this is a numerical solution. It uses
    ``scipy.integrate.solve_ivp`` with its default method and
    rtol = atol = 1e-13. The solution is sampled at ``num_points`` equally
    spaced times in ``t_span``.

    Returns ``(t, y)``, where ``y`` has one row per strategy and one column
    per sample time.
    """
    def replicator_rhs(t, x):
        state = np.asarray(x)
        fitness = D @ state
        average = np.dot(state, fitness)
        return state * (fitness - average)

    sample_times = np.linspace(t_span[0], t_span[1], num_points)
    result = solve_ivp(replicator_rhs, t_span, x_init, t_eval=sample_times,
                       rtol=10**(-13), atol=10**(-13))
    return result.t, result.y

# Plottet den aktuellen Zustand der Simulation
def plot_simulation(ax1, ax2, av_strat, p, width, height):
    """Draw the mean-strategy trajectory on ``ax1`` and the player grid on ``ax2``.

    Parameters
    ----------
    ax1 : matplotlib Axes
        Simplex axes. The trajectory is added as a black line on top of
        whatever is already drawn there.
    ax2 : matplotlib Axes
        Grid axes. Every player is drawn as a square at its (x, y) grid
        position, coloured by its current strategy (column 3): red = s_1,
        blue = s_2, green = s_3.
    av_strat : sequence of two sequences
        x and y simplex coordinates of the mean population trajectory.
    p : numpy.ndarray
        Player table from ``initialize_players``.
    width, height : int
        Grid dimensions, used for the axis limits of ``ax2``.
    """
    n_players = len(p)
    # Marker area in points**2; 2822400 == 1680**2, so this equals
    # 1680 / sqrt(n_players) and shrinks as the grid gets denser.
    marker_area = np.sqrt(2822400 / n_players)
    strat_colors = ['r' if s == 0 else 'b' if s == 1 else 'g' for s in p[:, 3]]

    # Mean of the population vector so far.
    ax1.plot(av_strat[0], av_strat[1], c="black", linewidth=2)

    # Grid plot: one square per player.
    ax2.scatter(p[:, 1], p[:, 2], s=marker_area, c=strat_colors, marker="s", edgecolor='none')
    ax2.set_xlim(-0.5, width - 0.5)
    ax2.set_ylim(-0.5, height - 0.5)
    ax2.set_aspect('equal', adjustable='box')

# Führt die Simulation auf dem 2D-Gitter aus
def _draw_phase_portrait(ax1, D, x_init, x_ana):
    """Draw the replicator-dynamics phase portrait of ``D`` on ``ax1``.

    Adds the following to ``ax1``:
    - the stream plot of ``g_xy``, coloured by speed, with a colour bar;
    - black masks outside the simplex triangle (0,0)-(1,0)-(0.5,1);
    - the start point ``x_init``;
    - the ODE trajectory ``x_ana`` in grey.
    """
    # Vector field on a 100 x 100 grid covering the unit square.
    Y, X = np.mgrid[0:1:100j, 0:1:100j]
    gXY = g_xy([X, Y], D)
    # The colour shows how fast the population vector changes.
    colorspeed = np.sqrt(gXY[0]**2 + gXY[1]**2)

    ax1.set_xlabel(r"$\rm x$")
    ax1.set_ylabel(r"$\rm y$")
    stream = ax1.streamplot(X, Y, gXY[0], gXY[1], linewidth=1, density=[2, 2],
                            norm=colors.Normalize(vmin=0., vmax=0.4),
                            color=colorspeed, cmap=plt.cm.jet)
    # Mask the square outside the triangle, plus thin bands below and above it.
    ax1.fill([0, 0.5, 0], [0, 1, 1], facecolor='black')
    ax1.fill([1, 0.5, 1], [0, 1, 1], facecolor='black')
    ax1.fill([0, 1, 1, 0], [0, 0, -0.02, -0.02], facecolor='black')
    ax1.fill([0, 1, 1, 0], [1, 1, 1.02, 1.02], facecolor='black')
    start = xy(x_init)
    ax1.scatter(start[0], start[1], s=50, marker='o', c="black")
    trajectory = xy(x_ana[0:3])
    ax1.plot(trajectory[0], trajectory[1], c="grey", linewidth=1)
    # Colour bar for the stream-plot speeds.
    cbar = plt.colorbar(stream.lines, aspect=20, ax=ax1)
    cbar.set_label(r'$\left| \vec{v} \right| =\sqrt{{g_x}^2 + {g_y}^2}$', size=20)
    ax1.set_xlim(0, 1)
    ax1.set_ylim(-0.02, 1.02)


def _mark_nash_equilibria(ax1, D):
    """Print the Nash equilibria of ``D`` and mark the symmetric ones on ``ax1``.

    Symmetric pure equilibria (k, k) are marked as red hexagons at the
    simplex vertex of strategy k. Mixed ones are marked as red triangles.
    """
    # Simplex coordinates of the pure strategies 1, 2, 3 (as produced by xy).
    vertices = {1: (0, 0), 2: (1, 0), 3: (0.5, 1)}
    equilibria = find_nash_equilibria(D)
    print("Nash-Gleichgewichte:")
    for eq in equilibria:
        if eq['type'] == 'pure':
            s1, s2 = eq['s']
            # Single quotes inside the replacement fields keep this valid
            # before Python 3.12 (PEP 701).
            print(f"Reines: Spieler 1 Strategie {s1}, Spieler 2 Strategie {s2}")
            # Only symmetric pure profiles are points of the simplex.
            if s1 == s2 and s1 in vertices:
                ax1.scatter(*vertices[s1], s=150, marker='h', c="red")
        else:
            print(f"Gemischtes: s*={eq['s*']}")
            point = xy(eq['s*'])
            ax1.scatter(point[0], point[1], s=100, marker='^', c="red")


def run_simulation(width=105, height=105, nit=15, rule=1, D=None, x_init=None,
                   output_dir="./pics", save_pdf=False):
    """Run the spatial 3-strategy game on a periodic grid and save one image per iteration.

    Each iteration proceeds as follows:
    - every player plays ``D`` against its grid neighbours (``compute_payoffs``);
    - players pick their next strategy by imitation (``update_strategy``);
    - the state is drawn;
    - the new strategies are committed (``apply_updates``).

    The top panel shows the replicator-dynamics phase portrait with the ODE
    solution (grey), the Nash equilibria (red) and the simulated mean
    population trajectory (black). The bottom panel shows the grid coloured
    by strategy.

    Parameters
    ----------
    width, height : int
        Grid size.
    nit : int
        Number of iterations; the ODE reference solution is integrated over
        t in [0, nit].
    rule : int
        Imitation rule for ``update_strategy`` (0 or 1).
    D : numpy.ndarray, optional
        3x3 payoff matrix. Defaults to [[0,-1,-1],[1,0,-3],[-1,-3,0]].
    x_init : sequence of float, optional
        Initial strategy probabilities. Defaults to [0.25, 0.5, 0.25].
    output_dir : str
        Directory for the frames ``img-NNN.png``, created if missing.
    save_pdf : bool
        Also save every frame as ``img-NNN.pdf``.

    ``text.usetex`` is enabled, so a working LaTeX installation is required.
    """
    # None sentinels avoid shared mutable default arguments.
    if D is None:
        D = np.array([[0, -1, -1], [1, 0, -3], [-1, -3, 0]])
    if x_init is None:
        x_init = [0.25, 0.5, 0.25]
    os.makedirs(output_dir, exist_ok=True)

    # Build the network and the players.
    g = create_grid(width, height)
    p = initialize_players(width, height, x_init)
    start = xy(x_init)
    av_strat_x = [start[0]]
    av_strat_y = [start[1]]
    av_dollar = []

    # Numerical reference solution of the replicator ODE.
    t_span = [0, nit]
    _, x_ana = analytical_solution(D, t_span, x_init)

    # Plot setup.
    rcParams.update({
        'figure.figsize': [8, 14],
        'text.usetex': True,
        'legend.fontsize': 12
    })
    plt.figure(0)
    gs = gridspec.GridSpec(2, 1, height_ratios=[1, 1.5], hspace=0.2)
    ax1 = plt.subplot(gs[0])
    ax2 = plt.subplot(gs[1])

    _draw_phase_portrait(ax1, D, x_init, x_ana)
    _mark_nash_equilibria(ax1, D)

    # Simulation loop.
    for it in range(nit):
        print(f"Iteration {it} ---------------------------------------------------")
        compute_payoffs(g, p, D)
        update_strategy(p, g, rule)
        plot_simulation(ax1, ax2, [av_strat_x, av_strat_y], p, width, height)
        mean_strat, mean_dollar = apply_updates(p)
        av_strat_x.append(mean_strat[0])
        av_strat_y.append(mean_strat[1])
        av_dollar.append(mean_dollar)

        # Save the frame (PNG, optionally PDF).
        frame = os.path.join(output_dir, f"img-{it:03d}")
        plt.savefig(f"{frame}.png", dpi=100, bbox_inches="tight", pad_inches=0.05)
        if save_pdf:
            plt.savefig(f"{frame}.pdf", bbox_inches="tight", pad_inches=0.05)
        # Only the grid is redrawn; ax1 keeps accumulating.
        ax2.clear()

    plt.close()
    print("Mittelwerte der Auszahlungen: ", [round(x, 3) for x in av_dollar])
    print("Simulation abgeschlossen. Plots in", output_dir, "gespeichert.")

# Auszahlungsmatrizen des symmetrischen (2x3)-Spiels (je ein Vertreter der 19 Zeeman-Klassen)
# Zeeman_Klasse_1: 0,2,-1,-1,0,2,2,-1,0
# Zeeman_Klasse_2: 0,3,-1,1,0,1,3,-1,0
# Zeeman_Klasse_3: 0,1,1,-1,0,3,1,1,0
# Zeeman_Klasse_4: 0,6,-4,-3,0,5,-1,3,0
# Zeeman_Klasse_5: 0,1,1,1,0,1,1,1,0
# Zeeman_Klasse_6: 0,3,-1,3,0,-1,1,1,0
# Zeeman_Klasse_7: 0,1,3,-1,0,5,1,3,0
# Zeeman_Klasse_8: 0,1,-1,-1,0,1,-1,1,0
# Zeeman_Klasse_9: 0,-1,3,-1,0,3,1,1,0
# Zeeman_Klasse_10: 0,1,1,-1,0,1,-1,-1,0
# Zeeman_Klasse_11: 0,1,1,1,0,1,-1,-1,0
# Zeeman_Klasse_12: 0,1,-1,1,0,1,1,-1,0
# Zeeman_Klasse_13: 0,-1,-1,1,0,1,-1,1,0
# Zeeman_Klasse_14: 0,-1,1,-1,0,1,-1,-1,0
# Zeeman_Klasse_15: 0,-1,-1,1,0,-1,-1,-1,0
# Zeeman_Klasse_16: 0,-1,-1,1,0,-3,-1,-3,0
# Zeeman_Klasse_17: 0,1,-1,-3,0,1,-1,1,0
# Zeeman_Klasse_18: 0,1,-3,1,0,-1,-3,-1,0
# Zeeman_Klasse_19: 0,-3,-1,-3,0,-1,-1,-1,0

# Simulation parameters. The payoff entries D_ij below are the
# Zeeman_Klasse_1 row from the list above.
D11, D12, D13, D21, D22, D23, D31, D32, D33 = 0, 2, -1, -1, 0, 2, 2, -1, 0
set_D = np.array([
    [D11, D12, D13],
    [D21, D22, D23],
    [D31, D32, D33],
])
run_simulation(width=105, height=105, nit=80, rule=1, D=set_D,
               x_init=[0.15, 0.6, 0.25], output_dir="./pics_1")
