import numpy as np
import matplotlib.pyplot as plt 
from matplotlib import rcParams
from scipy.integrate import solve_ivp

# Definition der Funktionen g_A und g_B
def gA(x, y, a, b, c, d):
    """Return dx/dt of the replicator system for population A.

    The value equals x(1-x) * [(a-c)*y + (b-d)*(1-y)], i.e. the payoff
    advantage of A's first strategy over its second (payoff entries read as
    [[a, b], [c, d]]) against a B-population whose first-strategy share is y,
    scaled by x(1-x). Works element-wise on numpy arrays as well as scalars.

    Parameters: x, y -- population shares; a, b, c, d -- payoff entries of A.
    """
    coupling = a + d - c - b
    advantage = coupling * y + (b - d)
    return advantage * (x - x * x)

def gB(x, y, a, b, c, d):
    """Return dy/dt of the replicator system for population B.

    The value equals y(1-y) * [(a-b)*x + (c-d)*(1-x)], i.e. the payoff
    advantage of B's first strategy over its second against an A-population
    whose first-strategy share is x, scaled by y(1-y). Works element-wise on
    numpy arrays as well as scalars.

    Parameters: x, y -- population shares; a, b, c, d -- payoff entries of B.
    """
    coupling = a + d - c - b
    advantage = coupling * x + (c - d)
    return advantage * (y - y * y)

# Right-hand side of the ODE system
def DGLsys(t, vx):
    """Return [dx/dt, dy/dt] for the state vx = (x, y), as solve_ivp expects.

    The system is autonomous, so t is unused. The payoff entries are read from
    the module-level globals Aa..Ad (population A) and Ba..Bd (population B).
    """
    x, y = vx
    return [
        gA(x, y, Aa, Ab, Ac, Ad),
        gB(x, y, Ba, Bb, Bc, Bd),
    ]

# Global figure style: render all text through LaTeX and enlarge the fonts of
# the title, the axis labels and the tick labels.
rcParams["text.usetex"] = True
rcParams["axes.titlesize"] = 22
rcParams["axes.labelsize"] = 20
rcParams["xtick.labelsize"] = 20
rcParams["ytick.labelsize"] = 20

# Payoff entries of the asymmetric (2x2) game, initial population and end time.
# Aa..Ad are the payoff entries passed to gA (population A), Ba..Bd those passed
# to gB (population B); tend is the end time of the integration and (x0, y0)
# the initial population.
# Instead of commenting/uncommenting parameter blocks, every example game class
# is kept as a preset; pick one by setting GAME_CLASS.
GAME_PRESETS = {
    # Corner class games
    "corner": {"Aa": 10, "Ab": 4, "Ac": 12, "Ad": 5,
               "Ba": 10, "Bb": 12, "Bc": 7, "Bd": 5,
               "tend": 7, "x0": 0.9, "y0": 0.2},
    # Saddle class games
    "saddle": {"Aa": 10, "Ab": 4, "Ac": 9, "Ad": 5,
               "Ba": 10, "Bb": 7, "Bc": 4, "Bd": 5,
               "tend": 12, "x0": 0.6, "y0": 0.1},
    # Center class games
    "center": {"Aa": 10, "Ab": 4, "Ac": 7, "Ad": 5,
               "Ba": 10, "Bb": 12, "Bc": 9, "Bd": 5,
               "tend": 8.5, "x0": 0.6, "y0": 0.01},
}
GAME_CLASS = "saddle"

# Expose the selected preset under the module-level names the rest of the
# script uses (DGLsys reads Aa..Bd as globals).
_game = GAME_PRESETS[GAME_CLASS]
Aa, Ab, Ac, Ad = _game["Aa"], _game["Ab"], _game["Ac"], _game["Ad"]
Ba, Bb, Bc, Bd = _game["Ba"], _game["Bb"], _game["Bc"], _game["Bd"]
tend, x0, y0 = _game["tend"], _game["x0"], _game["y0"]

# Further settings: number of output samples along the trajectory.
numpoints = 1000
t_val = np.linspace(0, tend, numpoints)

# Solve the ODE system for the initial population (x0, y0).
# solve_ivp's default tolerances (rtol=1e-3, atol=1e-6) are too loose for this
# plot: for center-class games the exact trajectories are closed orbits, and a
# loose integration makes them visibly spiral. Tighter tolerances keep the
# computed curve on the orbit.
Loes = solve_ivp(DGLsys, [0, tend], [x0, y0], t_eval=t_val, rtol=1e-8, atol=1e-10)
# On failure Loes.y would hold fewer than numpoints samples, and the arrow
# indexing further down would fail with an unrelated IndexError; report the
# solver's own message instead.
if not Loes.success:
    raise RuntimeError(f"solve_ivp failed: {Loes.message}")

# Vector field for the streamline plot: gA and gB evaluated on a 100 x 100 grid
# covering [0, 1] x [0, 1] (the complex step 100j makes mgrid include both end
# points). mgrid varies its first index along the rows, so SY holds the y and SX
# the x coordinates, which is the layout streamplot expects.
SY, SX = np.mgrid[0:1:100j, 0:1:100j]
SgA, SgB = (
    gA(SX, SY, Aa, Ab, Ac, Ad),
    gB(SX, SY, Ba, Bb, Bc, Bd),
)
# Streamline colour encodes how fast the population state changes: the norm of
# (gA, gB), scaled so that the largest value on the grid is 1.
squared_norm = SgA * SgA + SgB * SgB
speed = np.sqrt(squared_norm)
colorspeed = speed / np.max(speed)

# Draw the figure: coloured streamlines of the vector field, the computed
# trajectory (black), its starting point (grey dot) and 10 direction arrows.
fig, ax = plt.subplots()
strm = ax.streamplot(
    SX, SY, SgA, SgB,
    density=[2, 2],
    linewidth=1,
    color=colorspeed,
    cmap=plt.cm.cool,
)
traj_x, traj_y = Loes.y
ax.plot(traj_x, traj_y, c="black", linewidth=1.5, linestyle="-")
ax.plot(x0, y0, marker="o", color="grey", markersize=8)
# Each arrow runs from trajectory sample k to sample k+1, so it points in the
# direction of motion. The largest k is numpoints-2, keeping k+1 in range as
# long as the solver returned all numpoints samples.
for i in np.linspace(numpoints / 30, numpoints - 2, 10):
    k = int(i)
    ax.arrow(
        traj_x[k], traj_y[k],
        traj_x[k + 1] - traj_x[k], traj_y[k + 1] - traj_y[k],
        head_width=0.02, head_length=0.03, fc="black", ec="black",
    )
# Colour bar beside the axes. Its values are the speed normalised to the grid
# maximum (colorspeed).
cbar = fig.colorbar(strm.lines, ax=ax, pad=0.02)
cbar.set_label(r"$\sqrt{{g_A}^2 + {g_B}^2}$", size=20)

# Axis range [0, 1] on both axes, a tick every 0.1 with a label on every second
# tick, and the axis labels.
plt.xlim(0, 1)
plt.ylim(0, 1)
tick_positions = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
tick_labels = ["$0$", "", "$0.2$", "", "$0.4$", "", "$0.6$", "", "$0.8$", "", "$1$"]
plt.xticks(tick_positions, tick_labels)
plt.yticks(tick_positions, tick_labels)
plt.ylabel(r"$\rm y$")
plt.xlabel(r"$\rm x$")

# Save the figure as PNG (with text.usetex enabled this needs dvipng on Linux)
# and as PDF, then show it. The PNG is written at 100 dpi; the PDF uses
# savefig's default.
for saveFig, fig_format, extra_kwargs in (
    ("./bimatrix.png", "png", {"dpi": 100}),
    ("./bimatrix.pdf", "pdf", {}),
):
    plt.savefig(saveFig, bbox_inches="tight", pad_inches=0.05, format=fig_format, **extra_kwargs)
plt.show()
