import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
import matplotlib.gridspec as gridspec
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

# Use one uniform 10 pt font size for titles, axis labels, tick labels and
# legends in every figure created below.
rcParams.update({
    key: 10
    for key in (
        'axes.titlesize',
        'axes.labelsize',
        'xtick.labelsize',
        'ytick.labelsize',
        'legend.fontsize',
    )
})

set_m = [1, 1, 1]  # Body masses -- change these if other mass values are used
set_frames_N = 100  # Number of frames in the movie

# Raw simulation output, one row per stored state. np.genfromtxt returns a
# 1-D array when the file holds a single row; np.atleast_2d keeps it 2-D so
# column slicing (data_C[:, k]) works and N counts rows, not columns.
data_C = np.atleast_2d(np.genfromtxt("./N-BodyProblem.dat"))
N = len(data_C)

# Coordinate transformation: centre of mass, r, v_r
def Transf_S(data, m):
    """Compute centre-of-mass position, distances and radial velocities.

    Body ``j`` occupies the six consecutive columns of ``data`` starting at
    column ``2 + 6*j``, in the order ``x, v_x, y, v_y, z, v_z``. Columns 0
    and 1 are not read here. The number of bodies is ``len(m)``; for three
    masses the results are identical to the former hard-coded version.

    Parameters
    ----------
    data : numpy.ndarray
        2-D array with one row per stored state and at least
        ``2 + 6*len(m)`` columns.
    m : sequence of float
        Mass of each body.

    Returns
    -------
    r_S : list of numpy.ndarray
        ``[x_S, y_S, z_S]``, the centre-of-mass position for every row.
    R : list of numpy.ndarray
        Distance of each body from the centre of mass.
    Vr : list of numpy.ndarray
        Radial velocity of each body, i.e. its velocity relative to the
        centre of mass projected onto the unit vector from the centre of
        mass to the body. This is NaN (0/0) for rows where the body sits
        exactly on the centre of mass.

    Raises
    ------
    IndexError
        If ``m`` is empty or ``data`` has too few columns.
    """
    total_mass = sum(m)
    n_bodies = len(m)

    def column(j, offset):
        # Column of body j; offset 0..5 selects x, v_x, y, v_y, z, v_z.
        return data[:, 2 + 6 * j + offset]

    def mass_weighted_mean(offset):
        # Accumulate left to right (m0*q0 + m1*q1 + ...) / M, the same
        # floating-point summation order as the original explicit sum.
        acc = m[0] * column(0, offset)
        for j in range(1, n_bodies):
            acc = acc + m[j] * column(j, offset)
        return acc / total_mass

    # Centre-of-mass position (x, y, z) and velocity (v_x, v_y, v_z)
    r_S = [mass_weighted_mean(offset) for offset in (0, 2, 4)]
    v_S = [mass_weighted_mean(offset) for offset in (1, 3, 5)]

    R, Vr = [], []
    for j in range(n_bodies):
        # Position and velocity of body j relative to the centre of mass
        dx = column(j, 0) - r_S[0]
        dy = column(j, 2) - r_S[1]
        dz = column(j, 4) - r_S[2]
        dvx = column(j, 1) - v_S[0]
        dvy = column(j, 3) - v_S[1]
        dvz = column(j, 5) - v_S[2]
        r = np.sqrt(dx**2 + dy**2 + dz**2)
        R.append(r)
        # v_r = (dr . dv) / |dr|
        Vr.append((dx * dvx + dy * dvy + dz * dvz) / r)
    return r_S, R, Vr

def create_animation_short(data, R, Vr, r_S, N, frames_N=100):
    """Animate the three bodies in configuration space and phase space.

    Layout on a 2 x 3 grid:
    - left column: trajectories in the (x, y) plane, a moving marker per
      body and a grey line from the centre of mass to each body;
    - three small panels: each body's curve (r_j, v_{r_j}) with a moving
      marker;
    - last panel: the centre-of-mass path (r_S[0], r_S[1]).

    Parameters
    ----------
    data : numpy.ndarray
        Raw simulation rows; body ``j`` has x at column ``2 + 6*j`` and y at
        column ``2 + 6*j + 2``.
    R, Vr : sequence of array-like
        Per-body distance from, and radial velocity relative to, the centre
        of mass. Only the first three entries are used.
    r_S : sequence of array-like
        Centre-of-mass coordinates; ``r_S[0]`` and ``r_S[1]`` are plotted.
    N : int
        Number of rows in ``data``; frames are spread over rows 0 .. N-1.
    frames_N : int, optional
        Number of animation frames (default 100).

    Returns
    -------
    matplotlib.animation.FuncAnimation
        The animation object (its figure has already been closed).
    """
    # Data row shown in each frame: evenly spaced over 0 .. N-1. With a
    # single frame the spacing formula would divide by zero, so show row 0.
    if frames_N == 1:
        indices = [0]
    else:
        indices = [int(f * (N - 1) / (frames_N - 1)) for f in range(frames_N)]
    fig = plt.figure(figsize=(11, 5))
    gs = gridspec.GridSpec(2, 3, width_ratios=[2, 1, 1], hspace=0.3, wspace=0.4)
    # Axes setup: trajectories span the whole left column; the three
    # phase-space panels and the centre-of-mass panel fill the other cells.
    ax_main = fig.add_subplot(gs[:, 0])
    axes_sub = [fig.add_subplot(gs[0, 1]), fig.add_subplot(gs[0, 2]), fig.add_subplot(gs[1, 1])]
    ax_cm = fig.add_subplot(gs[1, 2])
    colors = ['red', 'blue', 'green']
    labels = [r"$\rm \vec{r}_1(t)$", r"$\rm \vec{r}_2(t)$", r"$\rm \vec{r}_3(t)$"]
    dots_main, dots_sub, lines_cm = [], [], []
    # Main plot: trajectories in configuration space (x, y). The faint full
    # trajectory is static; the marker and the grey centre-of-mass line start
    # empty and are filled in per frame by `update`.
    ax_main.axis('equal')
    for j in range(3):
        idx_x, idx_y = 2 + j * 6, 2 + j * 6 + 2
        ax_main.plot(data[:, idx_x], data[:, idx_y], color=colors[j], label=labels[j], alpha=0.3)
        dots_main.append(ax_main.scatter([], [], color=colors[j], s=30))
        lines_cm.append(ax_main.plot([], [], color='grey', alpha=0.2)[0])
    ax_main.legend(loc='upper right')
    # Phase-space plots (r_j vs v_{r_j})
    for j, ax in enumerate(axes_sub):
        ax.plot(R[j], Vr[j], color=colors[j], alpha=0.3)
        dots_sub.append(ax.scatter([], [], color=colors[j], s=30))
        ax.set_xlabel(rf"$r_{j+1}$")
        ax.set_ylabel(rf"$v_{{r_{j+1}}}$")
    # Centre-of-mass path
    ax_cm.plot(r_S[0], r_S[1], color="black", label="Schwerpunkt")
    ax_cm.legend(loc='upper left')

    def update(frame):
        """Move every per-frame artist to data row ``indices[frame]``.

        Returns the modified artists, which blitting needs in order to know
        what to redraw.
        """
        i = indices[frame]
        changed_objects = []
        for j in range(3):
            # Body marker in the main plot
            x, y = data[i, 2 + j * 6], data[i, 2 + j * 6 + 2]
            dots_main[j].set_offsets([[x, y]])
            # Connecting line from the centre of mass to the body
            lines_cm[j].set_data([r_S[0][i], x], [r_S[1][i], y])
            # Marker in the phase-space plot
            dots_sub[j].set_offsets([[R[j][i], Vr[j][i]]])
            changed_objects.extend([dots_main[j], lines_cm[j], dots_sub[j]])
        return changed_objects

    ani = FuncAnimation(fig, update, frames=frames_N, interval=80, blit=True)
    # Presumably closed so that a notebook does not also display the static
    # figure in addition to the animation.
    plt.close(fig)
    return ani

# Centre-of-mass position, plus each body's distance from it and radial velocity
r_S, R, Vr = Transf_S(data=data_C, m=set_m)
# Build the animation and render it to an MP4 file via the ffmpeg writer
ani = create_animation_short(
    data=data_C,
    R=R,
    Vr=Vr,
    r_S=r_S,
    N=N,
    frames_N=set_frames_N,
)
ani.save(filename='N-BodyProblem.mp4', writer='ffmpeg')
