import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
import matplotlib.colors as colors
from scipy.integrate import solve_ivp
from sympy import *

def xy(vx):
    """Map barycentric population shares onto the 2-D plotting triangle.

    The simplex vertices land at (0, 0) for strategy 1, (1, 0) for
    strategy 2 and (0.5, 1) for strategy 3.

    Parameters
    ----------
    vx : sequence of length 3, or array whose first axis has length 3
        Population shares (x_1, x_2, x_3); a (3, T) array of a whole
        trajectory is mapped column by column.

    Returns
    -------
    numpy.ndarray
        Stacked Cartesian coordinates ``[x, y]``.
    """
    share_2, share_3 = vx[1], vx[2]
    return np.array([share_2 + share_3 / 2, share_3])

def g(xy, D):
    """Replicator vector field expressed in the 2-D plotting coordinates.

    Converts the Cartesian point(s) back to population shares, evaluates
    dx_i/dt = x_i (D x)_i - (x^T D x) x_i for the three strategies and maps
    the derivative with the same linear map that ``xy()`` uses.

    Parameters
    ----------
    xy : sequence ``[X, Y]``
        Cartesian coordinates in the triangle; scalars or equally shaped
        arrays (e.g. an ``np.mgrid`` grid) are accepted.
        NOTE: this parameter name shadows the module-level ``xy`` function;
        it is kept unchanged for interface compatibility.
    D : (3, 3) array
        Payoff matrix.

    Returns
    -------
    list
        ``[g_x, g_y]``, the velocity of the point in Cartesian coordinates.
    """
    m = 3
    X, Y = xy[0], xy[1]
    # Inverse of the xy() mapping: recover the three population shares.
    x = np.array([1 - X - Y / 2, X - Y / 2, Y])
    # The mean payoff x^T D x does not depend on i, so compute it once
    # instead of once per strategy (same summation order as before).
    mean_payoff = sum(sum(D[k, j] * x[k] * x[j] for j in range(m)) for k in range(m))
    dx_dt = [
        sum(D[i, j] * x[i] * x[j] for j in range(m)) - mean_payoff * x[i]
        for i in range(m)
    ]
    # Same linear map as xy(): (dx_2 + dx_3 / 2, dx_3).
    return [dx_dt[1] + dx_dt[2] / 2, dx_dt[2]]

def DGLsys(t, x, D):
    """Right-hand side of the replicator equation, for ``solve_ivp``.

    dx_i/dt = x_i * ((D x)_i - x^T D x)

    Parameters
    ----------
    t : float
        Time; not used, but required by the ``solve_ivp`` call signature.
    x : array_like
        Current population vector.
    D : 2-D array
        Payoff matrix.

    Returns
    -------
    numpy.ndarray
        Time derivative of ``x``.
    """
    state = np.asarray(x)
    fitness = np.matmul(D, state)
    average_fitness = state.dot(fitness)
    return state * (fitness - average_fitness)

# Alternative (loop-based, non-vectorized) definition of the ODE system, kept for reference
#def DGLsys(t,x,D):
#    m=3
#    dx_dt = []
#    for i in range(m):
#        dx_dt.append(sum(D[i,j]*x[i]*x[j] for j in range(m)) - sum(sum(D[k,j]*x[k]*x[j] for j in range(m)) for k in range(m))*x[i])
#    return dx_dt

def find_nash_equilibria(D):
    """Compute Nash equilibria of the symmetric two-player game with payoff matrix D.

    Pure equilibria (i, j) are found by checking mutual best replies.
    Symmetric mixed equilibria y = (y1, y2, y3) are searched on the interior
    of the strategy simplex and on its three edges. For each candidate
    support, the supported strategies must earn equal payoff against y
    (indifference). A Nash equilibrium also requires that no strategy outside
    the support earns strictly more, and that condition is checked too.

    Parameters
    ----------
    D : (3, 3) array_like
        Payoff matrix; D[i, j] is the payoff of strategy i+1 against j+1.

    Returns
    -------
    list of dict
        ``{'type': 'pure', 's': (i, j)}`` with 1-based strategy indices, or
        ``{'type': 'mixed', 's*': (y1, y2, y3)}`` with exact sympy numbers.
        Degenerate games with a continuum of mixed equilibria yield no
        mixed entry for that support.
    """
    D = np.asarray(D)
    equilibria = []

    # Pure equilibria: i must be a best reply to j (column j of D) and
    # j a best reply to i (column i of D).
    n_strategies = D.shape[0]
    for i in range(n_strategies):
        for j in range(n_strategies):
            if D[i, j] == max(D[:, j]) and D[j, i] == max(D[:, i]):
                equilibria.append({'type': 'pure', 's': (i + 1, j + 1)})

    # Symmetric mixed equilibria: interior first, then the edges without
    # strategy 1, without strategy 2 and without strategy 3.
    ys = symbols('y_1:4')
    payoff = Matrix(D.tolist()) * Matrix(ys)  # (D y)_i: payoff of strategy i against y
    supports = [(0, 1, 2), (1, 2), (0, 2), (0, 1)]
    seen = set()
    for support in supports:
        # Indifference between all strategies in the support ...
        equations = [payoff[support[0]] - payoff[k] for k in support[1:]]
        # ... zero weight outside the support, and weights summing to one.
        equations += [ys[k] for k in range(3) if k not in support]
        equations.append(sum(ys) - 1)

        for solution in linsolve(equations, list(ys)):
            candidate = tuple(solution)
            # A parametric solution (continuum of candidates in a degenerate
            # game) cannot be reported as a single point, so skip it.
            if any(v.free_symbols for v in candidate):
                continue
            # Every component must be a valid probability; a component equal
            # to 1 is a pure strategy and is covered by the pure search.
            if not all(0 <= v < 1 for v in candidate):
                continue
            # Nash condition: no strategy may beat the (common) payoff of the
            # supported strategies. Small tolerance guards float payoffs.
            point = dict(zip(ys, candidate))
            payoffs = [float(p.subs(point)) for p in payoff]
            if max(payoffs) - payoffs[support[0]] > 1e-12:
                continue
            # An interior solution lying on an edge is found twice.
            if candidate in seen:
                continue
            seen.add(candidate)
            equilibria.append({'type': 'mixed', 's*': candidate})
    return equilibria


def solve_and_plot(D, x_init, t_end=15, n_points=1000):
    """Integrate the replicator dynamics and draw the phase portrait.

    Solves ``DGLsys`` from ``x_init``, draws the projected vector field ``g``
    as a streamplot colored by speed on the strategy triangle, overlays the
    computed trajectory, prints and marks the equilibria returned by
    ``find_nash_equilibria``, saves the figure to ``./2x3-Spiel.png`` and
    ``./2x3-Spiel.pdf`` and finally shows it.

    Parameters
    ----------
    D : (3, 3) array
        Payoff matrix.
    x_init : sequence of 3 floats
        Initial population shares.
    t_end : float, optional
        End time of the integration.
    n_points : int, optional
        Number of time samples of the plotted trajectory.
    """
    # Label/tick sizes; 'text.usetex' needs a working LaTeX installation.
    rcParams.update({
        'text.usetex': True,
        'figure.figsize': [8, 6],
        'axes.titlesize': 14,
        'axes.labelsize': 16,
        'xtick.labelsize': 14,
        'ytick.labelsize': 14,
    })

    # Very tight relative and absolute solver tolerances.
    tolerance = 10**(-13)
    t_eval = np.linspace(0, t_end, n_points)

    # Integrate the replicator ODE starting from x_init.
    solution = solve_ivp(DGLsys, [0, t_end], x_init, args=(D,), t_eval=t_eval,
                         rtol=tolerance, atol=tolerance)

    # Projected vector field on a 100x100 grid for the streamplot.
    Y, X = np.mgrid[0:1:100j, 0:1:100j]
    gXY = g([X, Y], D)
    # Line color encodes the speed of the population change.
    colorspeed = np.sqrt(gXY[0]**2 + gXY[1]**2)

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.set_xlabel(r"$x$")
    ax.set_ylabel(r"$y$")
    strm = ax.streamplot(X, Y, gXY[0], gXY[1], linewidth=1, density=[2, 2],
                         norm=colors.Normalize(vmin=0., vmax=0.4),
                         color=colorspeed, cmap=plt.cm.jet)
    # Black fills mask the grid outside the triangle (0,0)-(1,0)-(0.5,1).
    ax.fill([0, 0.5, 0], [0, 1, 1], facecolor='black')
    ax.fill([1, 0.5, 1], [0, 1, 1], facecolor='black')
    ax.fill([0, 1, 1, 0], [0, 0, -0.02, -0.02], facecolor='black')
    ax.fill([0, 1, 1, 0], [1, 1, 1.02, 1.02], facecolor='black')

    start = xy(x_init)
    ax.scatter(start[0], start[1], s=50, marker='o', c="black")
    trajectory = xy(solution.y[0:3])
    ax.plot(trajectory[0], trajectory[1], c="black", linewidth=2)

    # Colorbar next to the plot.
    cbar = fig.colorbar(strm.lines, ax=ax, aspect=20)
    cbar.set_label(r'$\left| \vec{v} \right| =\sqrt{{g_x}^2 + {g_y}^2}$', size=20)
    ax.set_xlim(0, 1)
    ax.set_ylim(-0.02, 1.02)

    # Triangle vertex of each pure strategy (same mapping as xy()).
    vertex = {1: (0, 0), 2: (1, 0), 3: (0.5, 1)}

    # Compute, print and mark the Nash equilibria.
    equilibria = find_nash_equilibria(D)
    print("Nash-Gleichgewichte:")
    for eq in equilibria:
        if eq['type'] == 'pure':
            s1, s2 = eq['s']
            print(f"Reines: Spieler 1 Strategie {s1}, Spieler 2 Strategie {s2}")
            # Only symmetric pure equilibria correspond to a triangle vertex.
            if s1 == s2 and s1 in vertex:
                vx, vy = vertex[s1]
                ax.scatter(vx, vy, s=150, marker='h', c="red")
        else:
            print(f"Gemischtes: s*={eq['s*']}")
            # Convert (possibly sympy) numbers to plain floats for matplotlib.
            px, py = xy([float(v) for v in eq['s*']])
            ax.scatter(px, py, s=100, marker='^', c="red")

    # Save as PNG (with usetex this needs dvipng on Linux) and as PDF.
    fig.savefig("./2x3-Spiel.png", dpi=100, bbox_inches="tight", pad_inches=0.05, format="png")
    fig.savefig("./2x3-Spiel.pdf", bbox_inches="tight", pad_inches=0.05, format="pdf")
    plt.show()


# Payoff matrices of the 19 Zeeman classes of evolutionary symmetric
# (2x3) games, i.e. two players with three strategies each.
D_1 = np.array([[0, 2, -1], [-1, 0, 2], [2, -1, 0]])
D_2 = np.array([[0, 3, -1], [1, 0, 1], [3, -1, 0]])
D_3 = np.array([[0, 1, 1], [-1, 0, 3], [1, 1, 0]])
D_4 = np.array([[0, 6, -4], [-3, 0, 5], [-1, 3, 0]])
D_5 = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]])
D_6 = np.array([[0, 3, -1], [3, 0, -1], [1, 1, 0]])
D_7 = np.array([[0, 1, 3], [-1, 0, 5], [1, 3, 0]])
D_8 = np.array([[0, 1, -1], [-1, 0, 1], [-1, 1, 0]])
D_9 = np.array([[0, -1, 3], [-1, 0, 3], [1, 1, 0]])
D_10 = np.array([[0, 1, 1], [-1, 0, 1], [-1, -1, 0]])
D_11 = np.array([[0, 1, 1], [1, 0, 1], [-1, -1, 0]])
D_12 = np.array([[0, 1, -1], [1, 0, 1], [1, -1, 0]])
D_13 = np.array([[0, -1, -1], [1, 0, 1], [-1, 1, 0]])
D_14 = np.array([[0, -1, 1], [-1, 0, 1], [-1, -1, 0]])
D_15 = np.array([[0, -1, -1], [1, 0, -1], [-1, -1, 0]])
D_16 = np.array([[0, -1, -1], [1, 0, -3], [-1, -3, 0]])
D_17 = np.array([[0, 1, -1], [-3, 0, 1], [-1, 1, 0]])
D_18 = np.array([[0, 1, -3], [1, 0, -1], [-3, -1, 0]])
D_19 = np.array([[0, -3, -1], [-3, 0, -1], [-1, -1, 0]])

# Initial population shares of the three strategies and integration end time.
x_0 = [0.06, 0.9, 0.04]
t_e = 15

# Integrate the replicator dynamics for game class D_3 and plot the result.
solve_and_plot(D_3, x_0, t_end=t_e)
