import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
from scipy.integrate import solve_ivp

# DGL bestimmende Funktion g
def g(x, a, b, c, d):
    return ((a-c)*(x-x*x) + (b-d)*(1-2*x+x*x))*x

# Definition der DGL
def DGL(t, x, a, b, c, d):
    return g(x, a, b, c, d)

# Klassifizierung des Spiels, Festlegung der Farbe und des Titels
def classify_game(a, b, c, d):
    if (a > c and b > d) or (a < c and b < d):
        return plt.cm.Greys, r'$\rm Dominantes\, Spiel $'
    elif a > c and b < d:
        return plt.cm.Blues, r'$\rm Koordinationsspiel $'
    elif a < c and b > d:
        return plt.cm.Oranges, r'$\rm Anti-Koordinationsspiel $'
    else:
        return plt.cm.Greens, r'$\rm ?\, Spielklasse $'

# Berechnung der Nashgleichgewichte
def find_nash_equilibria(a, b, c, d):
    equilibria = []
    # Auszahlungsmatrix (symmetrisches 2x2-Spiel)
    D_A = np.array([[a,b],[c,d]])
    # Berechnung der reinen Nash-Gleichgewichte
    for i in range(2):
        for j in range(2):
            if D_A[i, j] == np.max(D_A[:, j]) and D_A[j, i] == np.max(D_A[:, i]):
                equilibria.append({'type': 'pure', 's': (i+1, j+1)})
    # Berechnung der gemischten Nash-Gleichgewichte (mittels sympy)
    Nenner = a - b - c + d
    if abs(Nenner) > 1e-10:
        s = (d - b) / Nenner
        if 0 < s < 1:
            equilibria.append({'type': 'mixed', 's_star': s})
    return equilibria

# Loesen der DGL und plotten der Populationsentwicklung
def solve_and_plot(a, b, c, d, t_end=6, num_x0=30, n_points=500):
    # Groessenfestlegung der Labels usw.
    rcParams.update({
        'text.usetex': True,
        'axes.titlesize': 22,
        'axes.labelsize': 20,
        'xtick.labelsize': 20,
        'ytick.labelsize': 20
    })

    # Mehrere Anfangswerte der Population
    x0 = np.linspace(0.01, 0.99, num_x0)
    t_eval = np.linspace(0, t_end, n_points)

    # Loesung der DGL
    Loes = solve_ivp(DGL, [0, t_end], x0, args=(a,b,c,d, ), t_eval=t_eval)

    # Festlegung der Farbe und des Titels
    cmap, title_text = classify_game(a, b, c, d)
    line_colors = cmap(np.linspace(0, 1, num_x0))
    line_width = 1.5

    # Plotten der Losungen
    for j in range(num_x0):
        plt.plot(Loes.t,Loes.y[j], c=line_colors[j], linewidth=line_width, linestyle='-')

    # Berechnung der Nash-Gleichgewichte
    equilibria = find_nash_equilibria(a, b, c, d)

    # Terminalausgabe der Nash-Gleichgewichte und Kennzeichnung der Nash-Gleichgewichte im Bild
    print("Nash-Gleichgewichte:")
    for eq in equilibria:
        if eq['type'] == 'pure':
            print(f"Reines: Spieler A Strategie {eq["s"][0]}, Spieler B Strategie {eq["s"][1]}")
            # Kennzeichnung der reinen Nash-Gleichgewichte im Bild
            if eq["s"] == (1, 1):
                plt.scatter([0,t_end], [1,1], s=150, marker='h', c="red", alpha=0.5)
                plt.plot([0,t_end],[1,1], c="red", alpha=0.5, linewidth=line_width, linestyle=':')
            if eq["s"] == (2, 2):
                plt.scatter([0,t_end], [0,0], s=150, marker='h', c="red", alpha=0.5)
                plt.plot([0,t_end],[0,0], c="red", alpha=0.5, linewidth=line_width, linestyle=':')
        else:
            print(f"Gemischtes: s*={eq["s_star"]}")
            # Kennzeichnung des gemischten Nash-Gleichgewichts im Bild
            plt.scatter([0,t_end], [eq["s_star"],eq["s_star"]], s=100, marker='^', c="red", alpha=0.5)
            plt.plot([0,t_end],[eq["s_star"],eq["s_star"]], c="red", alpha=0.5, linewidth=line_width, linestyle=':')

    # Plotten der Spielmatrix in das Bild
    textstr1 = r'$\hat{\bf {\cal \$}} = \left( \begin{array}[c]{cc} a & b \\ \ c & d \end{array} \right)\\a='+str(a)+', b='+str(b)+'$'
    textstr2 = r'$\\c='+str(c)+', d='+str(d)+'$'
    props = dict(boxstyle='round', facecolor='white', alpha=0.92)
    plt.text(t_end-t_end/50.0, 0.98, textstr1+textstr2, fontsize=16, verticalalignment='top', horizontalalignment='right', bbox=props)

    # Achsenbeschriftungen usw.
    plt.ylim(0,1)
    plt.yticks([0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1],["$0$","","$0.2$","","$0.4$","","$0.6$","","$0.8$","","$1$"])
    plt.ylabel(r"$\rm x(t)$")
    plt.xlabel(r"$\rm t$")
    plt.title(title_text)

    # Speichern der Bilder als .png und .pdf-Datei
    plt.savefig('evol.png', dpi=100, bbox_inches='tight', pad_inches=0.05)
    plt.savefig('evol.pdf', bbox_inches='tight', pad_inches=0.05)
    plt.show()

# Festlegung der Auszahlungsmatrix des symmetrischen (2x2)-Spiels
# Dominantes Spiel
# a_, b_, c_, d_ = -7, -1, -9, -3
# solve_and_plot(a_, b_, c_, d_, t_end=4)

# Koordinationsspiel
a_, b_, c_, d_ = 2, 4, 0, 5
solve_and_plot(a_, b_, c_, d_, t_end=6)

# Anti-Koordinationsspiel
# a_, b_, c_, d_ = -10, 2, 0, 1
# solve_and_plot(a_, b_, c_, d_, t_end=3)
