##################################
# Python-Programm "Spatial (2x3)-QuantumGames"
##################################

import networkx as nx
import matplotlib.pyplot as plt
from random import randint, uniform
from math import isclose
import numpy as np
from matplotlib import rcParams
import matplotlib.gridspec as gridspec
from scipy.integrate import solve_ivp
import os
from sympy import symbols, Matrix, Eq, transpose, solve, N, exp, sin, cos, I, re
import matplotlib.colors as colors

# Expectation values of the observable strategies s_1 and s_2
# (closed-form expressions taken from QuantumGame.ipynb).
def calc_P(s_A, s_B, gamma):
    """Return the 2x2 matrix of joint-outcome expectation values for one strategy pair.

    Parameters
    ----------
    s_A, s_B : int
        Strategy index of player A and player B: 0 = s_1, 1 = s_2,
        2 = quantum strategy Q (mapped to angles via the table ``S`` below).
    gamma : float
        Entanglement parameter of the quantum game.

    Returns
    -------
    numpy.ndarray of shape (2, 2)
        ``[[p_11, p_12], [p_21, p_22]]``; ``p_kl`` belongs to the outcome in
        which A is observed playing s_k and B is observed playing s_l.
    """
    # Symbolic (sympy) expression for each joint outcome (k, l); each player's
    # strategy is parameterised by the angles (theta, phi).
    def P_11(theta_A,phi_A,theta_B,phi_B,gamma):
        f = (exp(2*I*(phi_A + phi_B))*cos(gamma/2)**2 + sin(gamma/2)**2)*(sin(gamma/2)**2 + exp(-2*I*(phi_A + phi_B))*cos(gamma/2)**2)*cos(theta_A/2)**2*cos(theta_B/2)**2
        return f
    def P_12(theta_A,phi_A,theta_B,phi_B,gamma):
        f = (-exp(I*phi_A)*sin(gamma/2)**2*sin(theta_B/2)*cos(theta_A/2) + I*exp(I*phi_B)*sin(gamma/2)*sin(theta_A/2)*cos(gamma/2)*cos(theta_B/2) - I*exp(-I*phi_B)*sin(gamma/2)*sin(theta_A/2)*cos(gamma/2)*cos(theta_B/2) - exp(-I*phi_A)*sin(theta_B/2)*cos(gamma/2)**2*cos(theta_A/2))*(-exp(I*phi_A)*sin(theta_B/2)*cos(gamma/2)**2*cos(theta_A/2) + I*exp(I*phi_B)*sin(gamma/2)*sin(theta_A/2)*cos(gamma/2)*cos(theta_B/2) - I*exp(-I*phi_B)*sin(gamma/2)*sin(theta_A/2)*cos(gamma/2)*cos(theta_B/2) - exp(-I*phi_A)*sin(gamma/2)**2*sin(theta_B/2)*cos(theta_A/2))
        return f
    def P_21(theta_A,phi_A,theta_B,phi_B,gamma):
        f = (I*exp(I*phi_A)*sin(gamma/2)*sin(theta_B/2)*cos(gamma/2)*cos(theta_A/2) - exp(I*phi_B)*sin(gamma/2)**2*sin(theta_A/2)*cos(theta_B/2) - exp(-I*phi_B)*sin(theta_A/2)*cos(gamma/2)**2*cos(theta_B/2) - I*exp(-I*phi_A)*sin(gamma/2)*sin(theta_B/2)*cos(gamma/2)*cos(theta_A/2))*(I*exp(I*phi_A)*sin(gamma/2)*sin(theta_B/2)*cos(gamma/2)*cos(theta_A/2) - exp(I*phi_B)*sin(theta_A/2)*cos(gamma/2)**2*cos(theta_B/2) - exp(-I*phi_B)*sin(gamma/2)**2*sin(theta_A/2)*cos(theta_B/2) - I*exp(-I*phi_A)*sin(gamma/2)*sin(theta_B/2)*cos(gamma/2)*cos(theta_A/2))
        return f
    def P_22(theta_A,phi_A,theta_B,phi_B,gamma):
        f = (-I*exp(-I*phi_A - I*phi_B)*sin(gamma/2)*cos(gamma/2)*cos(theta_A/2)*cos(theta_B/2) + I*exp(I*phi_A + I*phi_B)*sin(gamma/2)*cos(gamma/2)*cos(theta_A/2)*cos(theta_B/2) + sin(theta_A/2)*sin(theta_B/2))**2
        return f

    # Strategies: (theta, phi) for s = 0, 1, 2 (s_1, s_2, quantum strategy Q)
    S = [(0, 0), (np.pi, 0), (0, np.pi/2)]
    theta_A, phi_A = S[s_A]
    theta_B, phi_B = S[s_B]
    # Evaluate numerically and keep only the real part as a Python float
    # (presumably to discard round-off imaginary residue -- confirm against
    # QuantumGame.ipynb that the exact expressions are real).
    p_11 = float(re(N(P_11(theta_A, phi_A,theta_B, phi_B,gamma))))
    p_12 = float(re(N(P_12(theta_A, phi_A,theta_B, phi_B,gamma))))
    p_21 = float(re(N(P_21(theta_A, phi_A,theta_B, phi_B,gamma))))
    p_22 = float(re(N(P_22(theta_A, phi_A,theta_B, phi_B,gamma))))

    return np.array([[p_11,p_12],[p_21,p_22]])

# Outcome-probability tensor (3x3x2x2) for every pair of the three strategies
def PQ(gamma):
    """Return ``calc_P(i, j, gamma)`` for all strategy pairs as one array.

    The result has shape (3, 3, 2, 2): the first two axes select the
    strategies of player A and player B, the last two the observed outcome.
    """
    return np.array(
        [[calc_P(row, col, gamma) for col in range(3)] for row in range(3)]
    )

# Payoffs of the quantum-extended 2x2 game: a 3x3 payoff matrix whose third
# strategy is the quantum strategy.
def set_Dollar(gamma, a, b, c, d):
    """Return the 3x3 payoff matrix of the quantum game.

    Each entry is the classical payoff (a, b, c, d for the outcomes
    CC, CD, DC, DD) weighted with the outcome probabilities from ``PQ``.
    """
    probs = PQ(gamma)  # all outcome probabilities, computed once
    p_cc, p_cd = probs[:, :, 0, 0], probs[:, :, 0, 1]
    p_dc, p_dd = probs[:, :, 1, 0], probs[:, :, 1, 1]
    return a * p_cc + b * p_cd + c * p_dc + d * p_dd

# Barycentric triangle coordinates: population shares -> plane coordinates
def xy(vx):
    """Map a population vector (x_1, x_2, x_3) onto the simplex triangle.

    The pure strategies land on (0, 0), (1, 0) and (0.5, 1). Only ``vx[1]``
    and ``vx[2]`` are used; they may be scalars, numpy arrays (element-wise)
    or sympy expressions.
    """
    share_2 = vx[1]
    share_3 = vx[2]
    return [share_2 + share_3 / 2, share_3]

# Determine the Nash equilibria of the symmetric 3x3 game D
def find_nash_equilibria(D):
    """Return the pure and the symmetric mixed Nash equilibria of the game ``D``.

    ``D[i, j]`` is the payoff of a player using strategy ``i`` against an
    opponent using strategy ``j``; by symmetry the opponent then receives
    ``D[j, i]``.

    Returns
    -------
    list of dict
        ``{'type': 'pure', 's': (i, j)}`` with 1-based strategy indices of
        player 1 and player 2, and ``{'type': 'mixed', 's*': (y1, y2, y3)}``
        with the (sympy) mixing probabilities of a symmetric mixed equilibrium.

    Mixed candidates come from the indifference conditions on the interior of
    the strategy simplex and on each of its three edges (one strategy unused).
    A candidate is kept when every probability lies in [0, 1) and, for an edge
    candidate, the unused strategy is not a strictly better reply (otherwise a
    player could gain by switching to it).
    """
    equilibria = []
    n = D.shape[0]

    # Pure equilibria: i must be a best reply to j (maximum of column j) and,
    # because the game is symmetric, j must be a best reply to i.
    for i in range(n):
        for j in range(n):
            if np.isclose(D[i, j], np.max(D[:, j])) and np.isclose(D[j, i], np.max(D[:, i])):
                equilibria.append({'type': 'pure', 's': (i+1, j+1)})

    # Mixed equilibria: expected payoff x^T D y of player 1, restricted to the
    # interior (x3 = 1 - x1 - x2) and to the edges x1 = 0, x2 = 0 and x3 = 0.
    x1, x2, x3, y1, y2, y3 = symbols('x_1,x_2,x_3,y_1,y_2,y_3')
    ys_list = [y1, y2, y3]
    xs = Matrix([x1, x2, x3])
    ys = Matrix(ys_list)
    Dollar_A = transpose(xs)*D*ys
    Dollar_As = Dollar_A.subs(x3, 1-x1-x2).subs(y3, 1-y1-y2)[0]
    Dollar_As_1 = Dollar_A.subs(x1, 0).subs(x3, 1-x2).subs(y3, 1-y1-y2)[0]
    Dollar_As_2 = Dollar_A.subs(x2, 0).subs(x3, 1-x1).subs(y3, 1-y1-y2)[0]
    Dollar_As_3 = Dollar_A.subs(x3, 0).subs(x2, 1-x1).subs(y3, 1-y1-y2)[0]

    # Indifference conditions (payoff stationary in player 1's own mixture),
    # each paired with the 0-based index of the unused strategy (None: interior).
    candidates = [
        (None, [Eq(Dollar_As.diff(x1), 0), Eq(Dollar_As.diff(x2), 0), Eq(1, y1+y2+y3)]),
        (0, [Eq(Dollar_As_1.diff(x2), 0), Eq(0, y1), Eq(1, y2+y3)]),
        (1, [Eq(Dollar_As_2.diff(x1), 0), Eq(0, y2), Eq(1, y1+y3)]),
        (2, [Eq(Dollar_As_3.diff(x1), 0), Eq(0, y3), Eq(1, y1+y2)]),
    ]
    for unused, system in candidates:
        # dict=True always yields a list of dicts (empty when unsolvable)
        for sol in solve(system, ys_list, dict=True):
            try:
                probs = [float(N(sol[y])) for y in ys_list]
            except (KeyError, TypeError, ValueError):
                # Underdetermined (some y_k left free) or non-real solution
                continue
            if not all(0 <= q < 1 for q in probs):
                continue
            if unused is not None:
                # Reject if switching to the unused strategy pays strictly more
                reply = D @ np.array(probs)
                best_used = max(reply[k] for k in range(3) if k != unused)
                if reply[unused] > best_used and not np.isclose(reply[unused], best_used):
                    continue
            equilibria.append({'type': 'mixed', 's*': (sol[y1], sol[y2], sol[y3])})
    return equilibria

# Build a spatial 2D lattice graph with periodic borders (8 neighbours)
def create_grid(width, height):
    """Return a width x height torus lattice with Moore (8-cell) neighbourhoods.

    Node ``k = i * height + j`` sits at column ``i`` and row ``j``. Edges wrap
    around at the borders; self-edges are skipped, so on very small lattices
    (width or height below 3) coinciding neighbours give fewer than 8 edges.
    """
    # (di, dj) offsets: upper-left, up, upper-right, right, lower-right, low,
    # lower-left, left -- this order fixes each node's neighbour order.
    moore_offsets = (
        (-1, -1), (0, -1), (1, -1), (1, 0),
        (1, 1), (0, 1), (-1, 1), (-1, 0),
    )
    lattice = nx.Graph()
    lattice.add_nodes_from(range(width * height))
    for col in range(width):
        for row in range(height):
            node = col * height + row
            for d_col, d_row in moore_offsets:
                other = ((col + d_col) % width) * height + (row + d_row) % height
                if other != node:  # avoid self-edges
                    lattice.add_edge(node, other)
    return lattice

# Initialise the player array with positions and random start strategies
def initialize_players(width, height, x_init):
    """Return the (width*height, 6) player array.

    Columns: [id, x, y, current_strat, payoff, next_strat]. Each player draws
    one uniform number: strategy 0 (s_1) with probability x_init[0],
    strategy 1 (s_2) with probability x_init[1], otherwise strategy 2 (the
    quantum strategy). Payoffs start at 0; next_strat equals current_strat.
    """
    n_players = width * height
    players = np.zeros((n_players, 6))
    lower = x_init[0]
    upper = x_init[0] + x_init[1]
    for k in range(n_players):
        col, row = divmod(k, height)
        draw = uniform(0, 1)
        if draw <= lower:
            strategy = 0
        elif draw <= upper:
            strategy = 1
        else:
            strategy = 2
        players[k] = (k, col, row, strategy, 0, strategy)
    return players

# Accumulate all players' payoffs and the expectation value of the observable strategy s_1
def compute_payoffs(g, p, D, gamma):
    """Add every player's payoff against each neighbour to column 4 of ``p``.

    Column 4 is accumulated, not reset. The return value is the sum over all
    edges of the expected number of observed s_1 plays of the two endpoints,
    divided by 8 (the neighbourhood size of the lattice).
    """
    outcome_probs = PQ(gamma)
    total = 0
    for node_u, node_v in g.edges():
        strat_u = int(p[node_u, 3])
        strat_v = int(p[node_v, 3])
        p[node_u, 4] += D[strat_u, strat_v]
        p[node_v, 4] += D[strat_v, strat_u]
        probs = outcome_probs[strat_u, strat_v]
        total += 2*probs[0, 0] + probs[1, 0] + probs[0, 1]
    return total/8

# Choose each player's next strategy by imitating a neighbour
def update_strategy(p, g, rule):
    """Write each player's next strategy into column 5 of ``p``.

    rule 0: copy the neighbour with the highest payoff if it beats one's own.
    rule 1: pick one neighbour uniformly at random and copy it if it is better.
    "Better" means strictly larger and not ``math.isclose`` to one's own payoff.
    Players without neighbours are skipped; any other rule raises ValueError.
    """
    for node in range(len(p)):
        nbrs = list(g.neighbors(node))
        if len(nbrs) == 0:
            continue
        own = p[node, 4]
        if rule == 0:
            nbr_payoffs = p[nbrs, 4]
            pick = int(np.argmax(nbr_payoffs))
            candidate = nbr_payoffs[pick]
        elif rule == 1:
            pick = randint(0, len(nbrs) - 1)
            candidate = p[nbrs[pick], 4]
        else:
            raise ValueError(f"Ungültige Regel: {rule}. Verwende 0 oder 1.")
        if candidate > own and not isclose(candidate, own):
            p[node, 5] = p[nbrs[pick], 3]

# Apply the chosen strategies, compute the mean population vector and mean payoff, reset payoffs
def apply_updates(p):
    """Finish one round and return ``(xy(shares), mean_payoff)``.

    The mean payoff is the average of column 4 divided by 8. Afterwards
    column 4 is zeroed and column 3 takes the values of column 5. ``shares``
    are the fractions of players using strategy 0, 1 and 2 after the update.
    """
    n_players = len(p)
    mean_payoff = p[:, 4].mean() / 8
    p[:, 4] = 0
    p[:, 3] = p[:, 5]
    shares = [np.count_nonzero(p[:, 3] == strat) / n_players for strat in range(3)]
    return xy(shares), mean_payoff

# Replicator vector field (g_x, g_y) in triangle coordinates
def g_xy(xy, D):
    """Return the replicator velocity [g_x, g_y] at triangle point ``xy``.

    ``xy = [x_b, y_b]`` (scalars or equally shaped arrays) is converted back
    to population shares, the replicator rates dx_i/dt = x_i ((D x)_i - x.Dx)
    are evaluated, and the result is mapped to triangle coordinates again.
    Note: the parameter name shadows the module-level ``xy`` function.
    """
    n = 3
    x_b, y_b = xy[0], xy[1]
    shares = np.array([1 - x_b - y_b/2, x_b - y_b/2, y_b])
    # mean payoff of the population
    phi = sum(sum(D[k, j]*shares[k]*shares[j] for j in range(n)) for k in range(n))
    rates = []
    for i in range(n):
        own = sum(D[i, j]*shares[i]*shares[j] for j in range(n))
        rates.append(own - phi*shares[i])
    return [rates[1] + rates[2]/2, rates[2]]

# Numerical reference solution of the 3-strategy replicator ODE
def analytical_solution(D, t_span, x_init, num_points=500):
    """Integrate dx/dt = x * (D x - x.D x) starting from ``x_init``.

    Uses ``solve_ivp`` with rtol = atol = 1e-13 and returns ``(t, y)`` sampled
    at ``num_points`` equally spaced times in ``t_span``; ``y`` has one row
    per component of ``x_init``.
    """
    tolerance = 10**(-13)

    def replicator_rhs(t, x):
        shares = np.asarray(x)
        fitness = D @ shares
        mean_fitness = np.dot(shares, fitness)
        return shares * (fitness - mean_fitness)

    sample_times = np.linspace(t_span[0], t_span[1], num_points)
    result = solve_ivp(replicator_rhs, t_span, x_init, t_eval=sample_times,
                       rtol=tolerance, atol=tolerance)
    return result.t, result.y

# Numerical solution of the 2x2 replicator ODE dx/dt = g(x)
def analytical_solution_2x2(a, b, c, d, t_span, initial_x, num_points=500):
    """Integrate the one-dimensional replicator equation of a 2x2 game.

    g(x) = ((a - c)(x - x^2) + (b - d)(1 - 2x + x^2)) x, where x is the share
    of the first strategy. Returns ``(t, x)`` sampled at ``num_points``
    equally spaced times in ``t_span`` (solve_ivp default tolerances).
    """
    times = np.linspace(t_span[0], t_span[1], num_points)

    def rhs(_t, x):
        return ((a - c) * (x - x**2) + (b - d) * (1 - 2*x + x**2)) * x

    result = solve_ivp(rhs, t_span, [initial_x], t_eval=times)
    return result.t, result.y[0]

# Plot the current state of the simulation
def plot_simulation(ax1, ax2, ax3, av_strat, p, width, height, av_x):
    """Draw the current simulation state onto the three axes.

    Parameters
    ----------
    ax1 : matplotlib Axes
        Simplex panel; the trajectory ``av_strat = [x_values, y_values]`` of
        the mean population vector is drawn onto it.
    ax2 : matplotlib Axes
        Lattice panel; one square marker per player, coloured by its current
        strategy (column 3 of ``p``): blue = s_1, red = s_2, green = quantum
        strategy.
    ax3 : matplotlib Axes
        Time-series panel; ``av_x`` is plotted against its index.
    p : numpy.ndarray
        Player array as built by ``initialize_players``.
    width, height : int
        Lattice size, used for the axis limits of ``ax2``.
    """
    nkn = len(p)
    # Marker area (points^2) shrinks with the number of players; the constant
    # is presumably tuned to the figure size used in run_simulation.
    sgross = np.sqrt(2822400 / nkn)
    col = ['r' if s == 1 else 'b' if s == 0 else 'g' for s in p[:, 3]]

    # Trajectory of the mean population vector and the s_1 time series
    ax1.plot(av_strat[0], av_strat[1], c="black", linewidth=3)
    ax3.plot(range(len(av_x)), av_x, c="black", linewidth=2)

    # Lattice of players
    ax2.scatter(p[:, 1], p[:, 2], s=sgross, c=col, marker="s", edgecolor='none')
    ax2.set_xlim(-0.5, width - 0.5)
    ax2.set_ylim(-0.5, height - 0.5)
    ax2.set_aspect('equal', adjustable='box')

# Run the simulation on the 2D lattice
def run_simulation(width=105, height=105, nit=15, rule=1, gamma=0.0001, a=1, b=0, c=1.62, d=0.01, x_init=(0.25, 0.5, 0.25), output_dir="./pics"):
    """Run the spatial (2x3) quantum game and save one PNG per iteration.

    Parameters
    ----------
    width, height : int
        Size of the periodic lattice (number of players = width * height).
    nit : int
        Number of update iterations; also the end time of the analytic
        replicator solutions drawn for comparison.
    rule : int
        Imitation rule passed to ``update_strategy`` (0: best neighbour,
        1: random better neighbour).
    gamma : float
        Entanglement parameter of the quantum game.
    a, b, c, d : float
        Payoffs of the symmetric 2x2 base game for the outcomes
        (s_1, s_1), (s_1, s_2), (s_2, s_1), (s_2, s_2) (row player's view).
    x_init : sequence of 3 floats
        Initial population shares of s_1, s_2 and the quantum strategy.
    output_dir : str
        Directory for the per-iteration images (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)

    # Payoff matrix of the (2x3) quantum game
    D = set_Dollar(gamma, a, b, c, d)

    # Initialise network and players
    g = create_grid(width, height)
    p = initialize_players(width, height, x_init)
    start_xy = xy(x_init)
    av_strat_x = [start_xy[0]]
    av_strat_y = [start_xy[1]]
    # Mean expectation value of the observable strategy s_1; this call also
    # accumulates the payoffs used by the first strategy update.
    av_s = [compute_payoffs(g, p, D, gamma)/(width*height)]
    av_dollar = []

    # Analytic solution of the (2x3) quantum game
    t_span = [0, nit]
    t_ana, x_ana = analytical_solution(D, t_span, x_init)

    # Analytic solution of the (2x2) game, started at the simulated initial value
    t_ana_1, x_ana_1 = analytical_solution_2x2(a, b, c, d, t_span, av_s[0])

    # Further (2x2) reference solutions for additional initial conditions
    initials = [0.9, 0.7, 0.5, 0.3, 0.1]
    ana_seq = [analytical_solution_2x2(a, b, c, d, t_span, init_x) for init_x in initials]

    # Plot setup
    rcParams.update({
        'text.usetex': True,
        'axes.titlesize' : 14,
        'axes.labelsize' : 14,
        'xtick.labelsize' : 14 ,
        'ytick.labelsize' : 14
    })

    fig = plt.figure(figsize=(12, 14))
    gs = gridspec.GridSpec(2, 2, height_ratios=[1, 1.5], width_ratios=[1.5, 1], hspace=0.2)

    ax1 = fig.add_subplot(gs[0, 0])  # top left: simplex / phase portrait
    ax2 = fig.add_subplot(gs[1, :])  # bottom, spanning both columns: lattice
    ax3 = fig.add_subplot(gs[0, 1])  # top right: time series

    # Stream plot of the replicator vector field
    Y, X = np.mgrid[0:1:100j, 0:1:100j]
    gXY = g_xy([X, Y], D)
    # The colour encodes the speed of change of the population vector
    colorspeed = np.sqrt(gXY[0]**2 + gXY[1]**2)

    ax1.set_xlabel(r"$\rm x_b$")
    ax1.set_ylabel(r"$\rm y_b$")
    stream = ax1.streamplot(X, Y, gXY[0], gXY[1], linewidth=1, density=[2, 2], norm=colors.Normalize(vmin=0., vmax=0.4), color=colorspeed, cmap=plt.cm.jet)
    # Black masks outside the simplex triangle and above/below it
    ax1.fill([0, 0.5, 0], [0, 1, 1], facecolor='black')
    ax1.fill([1, 0.5, 1], [0, 1, 1], facecolor='black')
    ax1.fill([0, 1, 1, 0], [0, 0, -0.02, -0.02], facecolor='black')
    ax1.fill([0, 1, 1, 0], [1, 1, 1.02, 1.02], facecolor='black')
    ax1.scatter(start_xy[0], start_xy[1], s=50, marker='o', c="black")
    ana_xy = xy(x_ana[0:3])
    ax1.plot(ana_xy[0], ana_xy[1], c="grey", linewidth=1)
    # Colour bar next to the phase portrait
    cbar = plt.colorbar(stream.lines, aspect=20, ax=ax1)
    cbar.set_label(r'$\left| \vec{v} \right| =\sqrt{{g_x}^2 + {g_y}^2}$', size=14)
    ax1.set_xlim(0, 1)
    ax1.set_ylim(-0.02, 1.02)

    ax3.set_xlim(0, nit - 1)
    ax3.set_ylim(0, 1)
    ax3.set_ylabel(r'$\rm x(t)$')
    ax3.plot(t_ana_1, x_ana_1, c="grey", linewidth=1.5, linestyle=':')
    ax3.scatter(0, av_s[0], s=50, marker='o', c="black")
    for t_seq, x_seq in ana_seq:
        ax3.plot(t_seq, x_seq, c="lightblue", linewidth=1, linestyle=':')

    # Nash equilibria
    equilibria = find_nash_equilibria(D)
    print("Nash-Gleichgewichte:")
    # Simplex vertices of the symmetric pure profiles (1,1), (2,2), (3,3)
    vertex = {1: (0, 0), 2: (1, 0), 3: (0.5, 1)}
    for eq in equilibria:
        if eq['type'] == 'pure':
            s_A, s_B = eq['s']
            # Single quotes inside the f-string keep this valid before Python 3.12
            print(f"Reines: Spieler 1 Strategie {s_A}, Spieler 2 Strategie {s_B}")
            if s_A == s_B and s_A in vertex:
                ax1.scatter(*vertex[s_A], s=150, marker='h', c="red")
        else:
            print(f"Gemischtes: s*={eq['s*']}")
            star_xy = xy(eq['s*'])
            ax1.scatter(star_xy[0], star_xy[1], s=100, marker='^', c="red")

    # Simulation loop
    for it in range(nit):
        print(f"Iteration {it} ---------------------------------------------------")
        update_strategy(p, g, rule)
        plot_simulation(ax1, ax2, ax3, [av_strat_x, av_strat_y], p, width, height, av_s)
        mean_strat, mean_dollar = apply_updates(p)
        av_strat_x.append(mean_strat[0])
        av_strat_y.append(mean_strat[1])
        av_dollar.append(mean_dollar)
        av_s.append(compute_payoffs(g, p, D, gamma)/(width*height))

        # Save the frame (PNG; uncomment the PDF line for vector output)
        filename = f"img-{it:03d}"
        plt.savefig(os.path.join(output_dir, f"{filename}.png"), dpi=100, bbox_inches="tight", pad_inches=0.05)
        # plt.savefig(os.path.join(output_dir, f"{filename}.pdf"), bbox_inches="tight", pad_inches=0.05)
        ax2.clear()

    plt.close()
    print("Mittelwerte der Auszahlungen: ", [round(x, 3) for x in av_dollar])
    print("Simulation abgeschlossen. Plots in", output_dir, "gespeichert.")

# Quantum entanglement parameter gamma (valid range: [0, pi/2])
set_gamma = np.pi / 2

# Symmetric (2x2) base game lifted to the (2x3) quantum game: a prisoner's
# dilemma with payoffs a=10, b=4, c=12, d=5 (cooperation = strategy s_1),
# see QuantumGame.ipynb. Start the simulation with these parameters.
run_simulation(
    width=105,
    height=105,
    nit=25,
    rule=1,
    gamma=set_gamma,
    a=10,
    b=4,
    c=12,
    d=5,
    x_init=[0.65, 0.3, 0.05],
    output_dir="./V_max",
)

# Hanauske, Matthias, et al. "Quantum game theory and open access publishing." Physica A: Statistical Mechanics and its Applications 382.2 (2007): 650-664.
# Open Access Game as a Prisoners' Dilemma
#run_simulation(105, 105, 13, 1, set_gamma, 4, 1, 5, 3, [0.2,0.65,0.15],"./OA_PD_")
# Open Access as a “Stag Hunt”
#run_simulation(105, 105, 15, 1, set_gamma, 4, 1, 3, 3, [0.2,0.65,0.15],"./OA_SH_")

# Hanauske, Matthias, et al. "Doves and hawks in economics revisited: An evolutionary quantum game theory based analysis of financial crises." Physica A: Statistical Mechanics and its Applications 389.21 (2010): 5084-5102.
# Falke-Taube (hawk-dove) game, high-risk variant
#run_simulation(105, 105, 25, 1, set_gamma, 1.5, 0, 5, -7.5, [0.2,0.65,0.15],"./FT_")
