##################################
# Python-Programm "Network (2x3)-Games"
##################################

import networkx as nx
import matplotlib.pyplot as plt
from random import randint, uniform
import numpy as np
from matplotlib import rcParams
import matplotlib.gridspec as gridspec
from scipy.integrate import solve_ivp
from math import isclose
import os
from sympy import symbols, Matrix, Eq, transpose, solve
import matplotlib.colors as colors

# Barycentric triangle coordinate system
def xy(vx):
    """Project a three-component vector onto 2D triangle coordinates.

    Only components 1 and 2 of ``vx`` are read; component 0 is not used.

    Args:
        vx: Indexable with at least three entries (scalars or arrays).

    Returns:
        A two-element list ``[vx[1] + vx[2] / 2, vx[2]]``.
    """
    second = vx[1]
    third = vx[2]
    return [second + third / 2, third]

# Compute the Nash equilibria of the game given by payoff matrix D
def find_nash_equilibria(D):
    """Collect pure and mixed equilibrium candidates for payoff matrix ``D``.

    Args:
        D: Square payoff matrix indexed as ``D[i, j]``. The symbolic part
            below is written for exactly three strategies.

    Returns:
        A list of dicts. Pure entries look like
        ``{'type': 'pure', 's': (i+1, j+1)}`` (1-based strategy indices),
        mixed entries like ``{'type': 'mixed', 's*': (y1, y2, y3)}``.
    """
    equilibria = []
    # Pure equilibria: (i, j) is kept when D[i, j] is the maximum of column j
    # and D[j, i] is the maximum of column i.
    # NOTE(review): this presumably treats the game as symmetric (second
    # player's payoff is the transpose of D) - confirm.
    for i in range(D.shape[0]):
        for j in range(D.shape[0]):
            if D[i, j] == max(D[:, j]):
                if D[j, i] == max(D[:, i]):
                    equilibria.append({'type': 'pure', 's': (i+1,j+1)})
    # Mixed equilibria: solve symbolic indifference conditions
    Loes_GN = []
    x1,x2,x3,y1,y2,y3 = symbols('x_1,x_2,x_3,y_1,y_2,y_3')
    xs = Matrix([x1,x2,x3])
    ys = Matrix([y1,y2,y3])
    # Expected payoff x^T * D * y as a 1x1 symbolic matrix
    Dollar_A = transpose(xs)*D*ys
    # Interior case: eliminate x3 and y3 via the normalisation constraints
    Dollar_As = Dollar_A.subs(x3,1-x1-x2).subs(y3,1-y1-y2)[0]
    # Edge cases: one x component is set to zero, the remaining two sum to one
    Dollar_As_1 = Dollar_A.subs(x1,0).subs(x3,1-x2).subs(y3,1-y1-y2)[0]
    Dollar_As_2 = Dollar_A.subs(x2,0).subs(x3,1-x1).subs(y3,1-y1-y2)[0]
    Dollar_As_3 = Dollar_A.subs(x3,0).subs(x2,1-x1).subs(y3,1-y1-y2)[0]
    # Stationarity conditions: derivative of the payoff w.r.t. the free x
    # components must vanish
    GemNash_Eq1 = Eq(Dollar_As.diff(x1),0)
    GemNash_Eq2 = Eq(Dollar_As.diff(x2),0)
    GemNash_Eq_1 = Eq(Dollar_As_1.diff(x2),0)
    GemNash_Eq_2 = Eq(Dollar_As_2.diff(x1),0)
    GemNash_Eq_3 = Eq(Dollar_As_3.diff(x1),0)
    # Interior system together with y1 + y2 + y3 = 1
    Bed=Eq(1,y1+y2+y3)
    Loes_GN.append(solve([GemNash_Eq1,GemNash_Eq2,Bed]))
    # Edge with y1 = 0
    Bed_a=Eq(0,y1)
    Bed_b=Eq(1,y2+y3)
    Loes_GN.append(solve([GemNash_Eq_1,Bed_a,Bed_b]))
    # Edge with y2 = 0
    Bed_a=Eq(0,y2)
    Bed_b=Eq(1,y1+y3)
    Loes_GN.append(solve([GemNash_Eq_2,Bed_a,Bed_b]))
    # Edge with y3 = 0
    Bed_a=Eq(0,y3)
    Bed_b=Eq(1,y1+y2)
    Loes_GN.append(solve([GemNash_Eq_3,Bed_a,Bed_b]))
    # Keep non-empty solutions whose components all lie in [0, 1); the strict
    # upper bound drops solutions that put all weight on one strategy.
    # NOTE(review): assumes each solve() result is a mapping holding numeric
    # values for y1, y2 and y3; an under-determined system would presumably
    # fail on the lookup or the comparison - confirm.
    # NOTE(review): only the value range is checked here; edge solutions are
    # not tested against the strategy that was set to zero.
    for l in Loes_GN:
        if l and 0 <= l[y1] < 1 and 0 <= l[y2] < 1 and 0 <= l[y3] < 1 :
            equilibria.append({'type': 'mixed', 's*': (l[y1],l[y2],l[y3])})
    return equilibria

# Create a random graph without isolated nodes
def create_random_graph(nkn, prob):
    """Build an Erdos-Renyi graph and connect every isolated node.

    Nodes are visited in index order; a node whose degree is zero at that
    moment gets one edge to a uniformly drawn other node.

    Args:
        nkn: Number of nodes.
        prob: Edge probability passed to ``nx.erdos_renyi_graph``.

    Returns:
        The generated graph. For ``nkn < 2`` it is returned unchanged,
        because no other node exists to connect to.
    """
    g = nx.erdos_renyi_graph(nkn, prob)
    if nkn < 2:
        # With a single node the only possible draw is the node itself, so
        # the retry loop below would never terminate.
        return g
    for node in range(nkn):
        if g.degree(node) != 0:
            continue
        # Redraw until the partner differs from the node (no self-loops).
        partner = randint(0, nkn - 1)
        while partner == node:
            partner = randint(0, nkn - 1)
        g.add_edge(node, partner)
    return g

# Build the network and a node layout
def create_network(net, nkn):
    """Create the interaction graph and node positions for plotting.

    Args:
        net: One of ``"random"``, ``"scale_free"`` or ``"small_world"``.
        nkn: Number of nodes.

    Returns:
        Tuple ``(g, pos)`` of the graph and the layout positions.

    Raises:
        ValueError: If ``net`` is not one of the supported names.
    """
    if net == "random":
        # Edge probability 2*nkn/nkn**2, i.e. 2/nkn
        g = create_random_graph(nkn, 2*nkn/nkn**2)
    elif net == "scale_free":
        g = nx.barabasi_albert_graph(nkn, 1)
    elif net == "small_world":
        g = nx.watts_strogatz_graph(nkn, 4, 0.01)
    else:
        # Report the offending argument; the message must reference `net`,
        # the only name available in this scope.
        raise ValueError(f"Ungültiges Netzwerk: {net}. Verwende random, scale_free oder small_world")

    # Dedicated layouts are used only up to 1000 nodes; everything else
    # (including all "random" networks) gets a short spring layout started
    # from random positions.
    if net == "scale_free" and nkn <= 1000:
        pos = nx.kamada_kawai_layout(g)
    elif net == "small_world" and nkn <= 1000:
        pos = nx.circular_layout(g)
    else:
        pos_random = nx.random_layout(g)
        pos = nx.spring_layout(g, pos=pos_random, k=0.04, iterations=5)

    return g, pos

# Initialise the player array with ids and random strategies
def initialize_players(nkn, x_init):
    """Create the player table and draw an initial strategy for each player.

    Column layout: [id, x, y, current_strat, payoff, next_strat]. Only the
    id and the two strategy columns are filled here; the rest stays zero.

    Args:
        nkn: Number of players (rows).
        x_init: Initial strategy shares; entries 0 and 1 are used as
            thresholds for the uniform draw.

    Returns:
        ``numpy`` array of shape ``(nkn, 6)``.
    """
    players = np.zeros((nkn, 6))
    for idx in range(nkn):
        draw = uniform(0, 1)
        if draw <= x_init[0]:
            strategy = 0  # strategy s_1
        elif draw > (x_init[0] + x_init[1]):
            strategy = 2  # strategy s_3
        else:
            strategy = 1  # strategy s_2
        # The next strategy starts out equal to the current one.
        players[idx, [0, 3, 5]] = (idx, strategy, strategy)
    return players

# Accumulate the payoffs of all players
def compute_payoffs(g, p, D):
    """Add the payoff of one game per edge to both endpoints, in place.

    Args:
        g: Graph providing ``edges()`` as pairs of player indices.
        p: Player table; column 3 holds the strategy, column 4 the payoff.
        D: Payoff matrix indexed ``D[own_strategy, other_strategy]``.
    """
    for node_a, node_b in g.edges():
        strat_a = int(p[node_a, 3])
        strat_b = int(p[node_b, 3])
        p[node_a, 4] += D[strat_a, strat_b]
        p[node_b, 4] += D[strat_b, strat_a]

# Choose the next strategy of every player
def update_strategy(p, g, rule):
    """Write each player's next strategy into column 5 of ``p``.

    All comparisons use the payoff per link (column 4 divided by the node's
    degree) for the player and for the neighbour alike.

    Args:
        p: Player table; column 3 is the current strategy, column 4 the
            accumulated payoff, column 5 the next strategy (written here).
        g: Graph providing ``neighbors(node)`` and ``degree(node)``.
        rule: ``0`` imitates the best neighbour if it is strictly better;
            ``1`` imitates one randomly drawn neighbour if it is strictly
            better.

    Raises:
        ValueError: If ``rule`` is neither 0 nor 1.
    """
    if rule not in (0, 1):
        raise ValueError(f"Ungültige Regel: {rule}. Verwende 0 oder 1.")

    for k in range(len(p)):
        neighbors = list(g.neighbors(k))
        if not neighbors:
            continue
        own_payoff = p[k, 4] / g.degree(k)
        if rule == 0:
            # Rule 0: imitate the best neighbour
            neigh_degree = np.array([g.degree(n) for n in neighbors])
            neigh_payoffs = p[neighbors, 4] / neigh_degree
            best_idx = int(np.argmax(neigh_payoffs))
            candidate = neighbors[best_idx]
            candidate_payoff = neigh_payoffs[best_idx]
        else:
            # Rule 1: imitate a randomly drawn neighbour. The degree must be
            # looked up by the neighbour's node id, not by its position in
            # the neighbour list.
            candidate = neighbors[randint(0, len(neighbors) - 1)]
            candidate_payoff = p[candidate, 4] / g.degree(candidate)
        # Switch only on a strict improvement beyond floating-point noise.
        if candidate_payoff > own_payoff and not isclose(candidate_payoff, own_payoff):
            p[k, 5] = p[candidate, 3]

# Apply pending strategies, report population averages and reset payoffs
def apply_updates(p, g):
    """Commit the next strategies and return the averages of this round.

    The strategy shares and the mean payoff are taken before the table is
    modified, i.e. they describe the strategies that were just played.

    Args:
        p: Player table; column 3 current strategy, column 4 payoff,
            column 5 next strategy. Modified in place.
        g: Not used by this function.

    Returns:
        Tuple of the strategy shares mapped through ``xy`` and the mean
        payoff.
    """
    nkn = len(p)
    mean_payoff = np.mean(p[:, 4])
    played = p[:, 3]
    shares = [np.count_nonzero(played == strategy) / nkn for strategy in (0, 1, 2)]
    # Reset payoffs for the next round, then switch to the chosen strategies.
    p[:, 4] = 0
    p[:, 3] = p[:, 5]
    return xy(shares), mean_payoff

# Vector field (g_x, g_y) of the replicator dynamics in triangle coordinates
def g_xy(xy,D):
    """Evaluate the replicator right-hand side in 2D triangle coordinates.

    Args:
        xy: Pair ``[x, y]`` of triangle coordinates (scalars or arrays).
        D: 3x3 payoff matrix indexed ``D[i, j]``.

    Returns:
        Two-element list with the x and y components of the velocity.
    """
    m = 3
    # Convert the triangle coordinates back to the three strategy shares.
    shares = np.array([1 - xy[0] - xy[1] / 2, xy[0] - xy[1] / 2, xy[1]])

    def payoff_term(row):
        # Share of strategy `row` weighted by its payoff against the population
        return sum(D[row, col] * shares[row] * shares[col] for col in range(m))

    # Average payoff of the population; identical for every strategy.
    mean_payoff = sum(payoff_term(row) for row in range(m))
    rates = [payoff_term(row) - mean_payoff * shares[row] for row in range(m)]
    return [rates[1] + rates[2] / 2, rates[2]]

# Solution of the replicator ODE (integrated numerically with solve_ivp)
def analytical_solution(D, t_span, x_init, num_points=500):
    """Integrate the replicator equation ``x' = x * (D x - x.D x)``.

    Args:
        D: Payoff matrix.
        t_span: ``[t_start, t_end]`` integration interval.
        x_init: Initial population vector.
        num_points: Number of equally spaced output times.

    Returns:
        Tuple ``(t, y)`` as delivered by ``solve_ivp``: the output times and
        the state at those times.
    """
    def replicator_rhs(t, state):
        state = np.asarray(state)
        fitness = D @ state
        mean_fitness = np.dot(state, fitness)
        return state * (fitness - mean_fitness)

    tolerance = 10**(-13)
    sample_times = np.linspace(t_span[0], t_span[1], num_points)
    result = solve_ivp(
        replicator_rhs,
        t_span,
        x_init,
        t_eval=sample_times,
        rtol=tolerance,
        atol=tolerance,
    )
    return result.t, result.y

# Plot the current state of the simulation
def plot_simulation(ax1, ax2, av_strat, p, g, pos):
    """Draw the averaged trajectory on ``ax1`` and the network on ``ax2``.

    Args:
        ax1: Axes receiving the trajectory of the mean population vector.
        ax2: Axes receiving the network drawing.
        av_strat: Pair ``[x_values, y_values]`` of the trajectory so far.
        p: Player table; column 3 (current strategy) selects the node colour.
        g: Graph to draw.
        pos: Node positions for the drawing.
    """
    sgross = 15
    # Red, blue and green mark strategies 0, 1 and 2.
    col = ['r' if s == 0 else 'b' if s == 1 else 'g' for s in p[:, 3]]

    # Mean of the population vector
    ax1.plot(av_strat[0], av_strat[1], c="black", linewidth=2)

    # Network plot: target ax2 explicitly instead of relying on whichever
    # axes happens to be current.
    nx.draw_networkx_nodes(g, pos, ax=ax2, node_size=sgross, node_color=col, edgecolors="none")
    nx.draw_networkx_edges(g, pos, ax=ax2, alpha=0.3, width=0.4, edge_color="grey")
    ax2.axis("off")

# Run the simulation on the network
def run_simulation(net="small_world", nkn=500, nit=80, rule=1, D = np.array([[0,2,-1],[-1,0,2],[2,-1,0]]), x_init=[0.15,0.6,0.25], output_dir="./pics"):
    """Simulate the network game and save one PNG per iteration.

    The upper panel shows the replicator flow in triangle coordinates, the
    integrated replicator trajectory, the equilibrium markers and the
    simulated mean population vector; the lower panel shows the network.

    Args:
        net: Network type passed to ``create_network``.
        nkn: Number of players.
        nit: Number of iterations (also the end time of the ODE solution).
        rule: Update rule passed to ``update_strategy``.
        D: 3x3 payoff matrix. The default array is only read in this
            function.
        x_init: Initial strategy shares. The default list is only read in
            this function.
        output_dir: Directory for the images; created if missing.

    Note:
        ``text.usetex`` is switched on, so matplotlib is asked to render
        text through LaTeX.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Initialise network and players
    g, pos = create_network(net, nkn)
    p = initialize_players(nkn, x_init)
    av_strat_x = [xy(x_init)[0]]
    av_strat_y = [xy(x_init)[1]]
    av_dollar = []

    # Replicator ODE solution for comparison
    t_span = [0, nit]
    t_ana, x_ana = analytical_solution(D, t_span, x_init)

    # Plot setup
    rcParams.update({
        'figure.figsize': [8,14],
        'text.usetex': True,
        'legend.fontsize': 12
    })
    plt.figure(0)
    gs = gridspec.GridSpec(2, 1, height_ratios=[1, 1.5], hspace=0.08)
    ax1 = plt.subplot(gs[0])
    ax2 = plt.subplot(gs[1])

    # Grid for the streamplot of the replicator flow
    Y, X = np.mgrid[0:1:100j, 0:1:100j]
    gXY = g_xy([X,Y],D)
    # The colour encodes the speed at which the population vector changes
    colorspeed = np.sqrt(gXY[0]**2 + gXY[1]**2)

    ax1.set_xlabel(r"$\rm x$")
    ax1.set_ylabel(r"$\rm y$")
    figure = ax1.streamplot(X, Y, gXY[0], gXY[1], linewidth=1,density=[2, 2],norm=colors.Normalize(vmin=0.,vmax=0.4), color=colorspeed, cmap=plt.cm.jet)
    # Black patches cover everything outside the triangle
    ax1.fill([0,0.5,0], [0,1,1], facecolor='black')
    ax1.fill([1,0.5,1], [0,1,1], facecolor='black')
    ax1.fill([0,1,1,0], [0,0,-0.02,-0.02], facecolor='black')
    ax1.fill([0,1,1,0], [1,1,1.02,1.02], facecolor='black')
    ax1.scatter(xy(x_init)[0],xy(x_init)[1], s=50, marker='o', c="black")
    ax1.plot(xy(x_ana[0:3])[0], xy(x_ana[0:3])[1],c="grey",linewidth=1)
    # Colour bar next to the flow diagram
    cbar=plt.colorbar(figure.lines, aspect=20, ax=ax1)
    cbar.set_label(r'$\left| \vec{v} \right| =\sqrt{{g_x}^2 + {g_y}^2}$',size=20)
    ax1.set_xlim(0, 1)
    ax1.set_ylim(-0.02, 1.02)

    # Nash equilibria: print them and mark them in the upper panel
    equilibria = find_nash_equilibria(D)
    print("Nash-Gleichgewichte:")
    # Triangle corners belonging to the symmetric pure profiles (k, k)
    corner_of_strategy = {1: (0, 0), 2: (1, 0), 3: (0.5, 1)}
    for eq in equilibria:
        if eq['type'] == 'pure':
            # Single quotes inside the replacement fields keep the f-string
            # valid on Python versions before 3.12.
            print(f"Reines: Spieler 1 Strategie {eq['s'][0]}, Spieler 2 Strategie {eq['s'][1]}")
            # Only symmetric pure equilibria are marked in the picture
            if eq['s'][0] == eq['s'][1] and eq['s'][0] in corner_of_strategy:
                corner = corner_of_strategy[eq['s'][0]]
                ax1.scatter(corner[0], corner[1], s=150, marker='h', c="red")
        else:
            print(f"Gemischtes: s*={eq['s*']}")
            ax1.scatter(xy(eq['s*'])[0],xy(eq['s*'])[1], s=100, marker='^', c="red")

    # Simulation loop
    for it in range(0, nit):
        print(f"Iteration {it} ---------------------------------------------------")
        compute_payoffs(g, p, D)
        update_strategy(p, g, rule)
        # Plot before the update is applied so the current strategies show
        plot_simulation(ax1, ax2, [av_strat_x,av_strat_y], p, g, pos)
        mean_strat, mean_dollar = apply_updates(p,g)
        av_strat_x.append(mean_strat[0])
        av_strat_y.append(mean_strat[1])
        av_dollar.append(mean_dollar)

        # Save the frame as PNG with a zero-padded iteration number
        filename = f"img-{it:0>3d}"
        plt.savefig(os.path.join(output_dir, f"{filename}.png"), dpi=100, bbox_inches="tight", pad_inches=0.05)
        # Only the network panel is redrawn each iteration
        ax2.clear()

    plt.close()
    print("Mittelwerte der Auszahlungen: ", [round(float(x), 3) for x in av_dollar])
    print("Simulation abgeschlossen. Plots in", output_dir, "gespeichert.")

# Payoff matrices of the symmetric (2x3) game (19 Zeeman classes), entries listed row by row: D11,D12,D13,D21,D22,D23,D31,D32,D33
# Zeeman_Klasse_1: 0,2,-1,-1,0,2,2,-1,0
# Zeeman_Klasse_2: 0,3,-1,1,0,1,3,-1,0
# Zeeman_Klasse_3: 0,1,1,-1,0,3,1,1,0
# Zeeman_Klasse_4: 0,6,-4,-3,0,5,-1,3,0
# Zeeman_Klasse_5: 0,1,1,1,0,1,1,1,0
# Zeeman_Klasse_6: 0,3,-1,3,0,-1,1,1,0
# Zeeman_Klasse_7: 0,1,3,-1,0,5,1,3,0
# Zeeman_Klasse_8: 0,1,-1,-1,0,1,-1,1,0
# Zeeman_Klasse_9: 0,-1,3,-1,0,3,1,1,0
# Zeeman_Klasse_10: 0,1,1,-1,0,1,-1,-1,0
# Zeeman_Klasse_11: 0,1,1,1,0,1,-1,-1,0
# Zeeman_Klasse_12: 0,1,-1,1,0,1,1,-1,0
# Zeeman_Klasse_13: 0,-1,-1,1,0,1,-1,1,0
# Zeeman_Klasse_14: 0,-1,1,-1,0,1,-1,-1,0
# Zeeman_Klasse_15: 0,-1,-1,1,0,-1,-1,-1,0
# Zeeman_Klasse_16: 0,-1,-1,1,0,-3,-1,-3,0
# Zeeman_Klasse_17: 0,1,-1,-3,0,1,-1,1,0
# Zeeman_Klasse_18: 0,1,-3,1,0,-1,-3,-1,0
# Zeeman_Klasse_19: 0,-3,-1,-3,0,-1,-1,-1,0

# Simulation parameters: entries of the payoff matrix, row by row
D11,D12,D13,D21,D22,D23,D31,D32,D33 = 0,2,-1,-1,0,2,2,-1,0
set_D = np.array([[D11,D12,D13],[D21,D22,D23],[D31,D32,D33]])
# Positional arguments: network type, number of players, iterations,
# update rule, payoff matrix, initial strategy shares, output directory.
run_simulation("random", 10000, 120, 1, set_D, [0.15,0.6,0.25],"./pics_1_random")
# Alternative run on a scale-free network (disabled)
#run_simulation("scale_free", 1000, 80, 1, set_D, [0.15,0.6,0.25],"./pics_1_scale_free")

