import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
from scipy.integrate import solve_ivp
from sympy import *

# DGL bestimmende Funktion g_A und g_B
def gA(x,y,D):
    return ((D[0,0]+D[1,1]-D[0,1]-D[1,0])*y + (D[0,1]-D[1,1]))*(x-x*x)

def gB(x,y,D):
    return ((D[0,0]+D[1,1]-D[0,1]-D[1,0])*x + (D[1,0]-D[1,1]))*(y-y*y)

# Definition des DGL Systems
def DGLsys(t,vx,DA,DB):
    x, y = vx
    dxdt = gA(x,y,DA)
    dydt = gB(x,y,DB)
    return [dxdt,dydt]

# Berechnung der Nashgleichgewichte
def find_nash_equilibria(DA,DB):
    equilibria = []
    # Berechnung der reinen Nash-Gleichgewichte
    for i in range(2):
        for j in range(2):
            if DA[i, j] == max(DA[:, j]) and DB[i, j] == max(DB[i, :]):
                equilibria.append({'type': 'pure', 's': (i+1,j+1)})
    #Berechnung der gemischten Nash-Gleichgewichte
    x, y = symbols('x, y')
    Dollar_A = DA[0,0]*x*y + DA[0,1]*x*(1-y) + DA[1,0]*(1-x)*y + DA[1,1]*(1-x)*(1-y)
    Dollar_B = DB[0,0]*x*y + DB[0,1]*x*(1-y) + DB[1,0]*(1-x)*y + DB[1,1]*(1-x)*(1-y)
    GlGemNashA = Eq(Dollar_A.diff(x), 0)
    GlGemNashB = Eq(Dollar_B.diff(y), 0)
    if GlGemNashA != False and GlGemNashB != False:
        s_A = solve(GlGemNashB,x)[0]
        s_B = solve(GlGemNashA,y)[0]
        if 0 < s_A < 1 and 0 < s_B < 1:
            equilibria.append({'type': 'mixed', 's*': (s_A,s_B)})
    return equilibria

# Loesen der DGL und plotten der Populationsentwicklung
def solve_and_plot(DA, DB, xy0, t_end=7, n_points=1000):
    # Groessenfestlegung der Labels usw. im Bild
    rcParams.update({
        'text.usetex': True,
        'axes.titlesize': 22,
        'axes.labelsize': 20,
        'xtick.labelsize': 20,
        'ytick.labelsize': 20
    })

    # Loesung der DGL fuer eine Anfangspopulation xy0=(x0,y0)
    t_eval = np.linspace(0, t_end, n_points)
    Loes = solve_ivp(DGLsys, [0, t_end], xy0, args=(DA,DB, ), t_eval=t_eval)

    # Fuer die Darstellung des Feldliniendiagramms streamplot
    SY,SX = np.mgrid[0:1:100j,0:1:100j]
    SgA = gA(SX,SY,DA)
    SgB = gB(SX,SY,DB)
    # Die Farbe wird die Geschwindigkeit der Aenderung des Populationsvektors anzeigen
    speed = np.sqrt(SgA*SgA + SgB*SgB)
    colorspeed = speed/speed.max()

    # Plotten des Bildes
    fig, ax = plt.subplots(figsize=(10, 8))
    strm = ax.streamplot(SX,SY,SgA,SgB,density=[2, 2], linewidth=1,color=colorspeed, cmap=plt.cm.cool)
    ax.plot(Loes.y[0],Loes.y[1],c="black", linewidth=1.5, linestyle='-')
    ax.plot(xy0[0],xy0[1], marker='o', color='grey', markersize=8)
    for i in np.linspace(n_points/30,n_points-2,10):
        ax.arrow(Loes.y[0][int(i)],Loes.y[1][int(i)], Loes.y[0][int(i)+1] - Loes.y[0][int(i)], Loes.y[1][int(i)+1]
                 - Loes.y[1][int(i)], head_width=0.02, head_length=0.03, fc='black', ec='black')
    # Erzeugung der nebenstehenden Farblegende colorbar
    cbar = fig.colorbar(strm.lines,ax=ax,pad=0.02)
    cbar.set_label(r'$\sqrt{{g_A}^2 + {g_B}^2}$',size=20)

    # Achsenbeschriftungen usw.
    plt.xlim(0,1)
    plt.ylim(0,1)
    plt.xticks([0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1],["$0$","","$0.2$","","$0.4$","","$0.6$","","$0.8$","","$1$"])
    plt.yticks([0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1],["$0$","","$0.2$","","$0.4$","","$0.6$","","$0.8$","","$1$"])
    plt.ylabel(r"$\rm y$")
    plt.xlabel(r"$\rm x$")

    # Berechnung und Terminalausgabe der Nashgleichgewichte
    equilibria = find_nash_equilibria(D_A,D_B)
    print("Nash-Gleichgewichte:")
    for eq in equilibria:
        if eq['type'] == 'pure':
            strat = list(eq["s"])
            print(f"Reines: Spieler 1 Strategie {strat[0]}, Spieler 2 Strategie {strat[1]}")
            if strat[0] == 2:
                strat[0] = 0
            if strat[1] == 2:
                strat[1] = 0
            # Kennzeichnung der reinen Nashgleichgewichte im Bild
            ax.scatter(strat[0], strat[1], s=150, marker='h', c="red")
        else:
            print(f"Gemischtes: Spieler 1 s*={eq["s*"][0]}, Spieler 2 s*={eq["s*"][1]}")
            # Kennzeichnung des gemischten Nashgleichgewichts im Bild
            ax.scatter(eq["s*"][0], eq["s*"][1], s=100, marker='^', c="red")

    #Speichern der Bilder als (.png , benötigt dvipng unter Linux)- und .pdf-Datei
    saveFig="./bimatrix.png"
    plt.savefig(saveFig, dpi=100,bbox_inches="tight",pad_inches=0.05,format="png")
    saveFig="./bimatrix.pdf"
    plt.savefig(saveFig,bbox_inches="tight",pad_inches=0.05,format="pdf")
    plt.show()

#Festlegung der Auszahlungsmatrix des unsymmetrischen (2x2)-Spiels, Anfangspopulation, Endzeit
# Klasse der Eckenspiele (Corner Class Games)
#D_A = np.array([[10,4],[12,5]])
#D_B = np.array([[10,12],[7,5]])
#te = 7
#xy_0 = [0.9, 0.2]
# Klasse der Sattelspiele (Saddle Class Games)
#D_A = np.array([[10,4],[9,5]])
#D_B = np.array([[10,7],[4,5]])
#te = 12
#xy_0 = [0.6, 0.1]
# Klasse der Zentrumsspiele (Center Class Games)
D_A = np.array([[10,4],[7,5]])
D_B = np.array([[10,12],[9,5]])
te = 8.5
xy_0 = [0.6, 0.01]

solve_and_plot(D_A, D_B, xy_0, te)
