#!/usr/bin/env python3

import math                        # math
import matplotlib.pyplot as plt    # plotting
import numpy as np 
from numpy.linalg import inv       # inverse matrix
import ML_covarianceMatrix as CM   # loading user-defined module


#
# --- data generation
#

# Two synthetic point clouds generated from mirrored settings: cloud A uses
# the positive sign, cloud B the negative one, for both the first argument
# and the x-component of the last argument.
# NOTE(review): the meaning of the testData arguments lives in
# ML_covarianceMatrix -- confirm there before changing any value.
dataMatrix_A, dataMatrix_B = (
    CM.testData(sign * 0.4 * math.pi, 1.0, 9.0, 30, [sign * 6.0, 0.0])
    for sign in (1.0, -1.0)
)

# Mean and covariance of each cloud, evaluated in the order A, then B.
(mean_A, coVar_A), (mean_B, coVar_B) = map(
    CM.covarianceMatrix, (dataMatrix_A, dataMatrix_B)
)

#
# --- LDA
#

# LDA direction: w = (coVar_A + coVar_B)^{-1} (mean_B - mean_A).
# NOTE(review): the arithmetic below assumes CM.covarianceMatrix returns
# NumPy arrays for both the mean and the covariance -- confirm.
S_inverse = inv(coVar_A + coVar_B)                 # built-in matrix inverse
w_vector = np.matmul(S_inverse, mean_B - mean_A)   # LDA

# Unit-length direction and a vector perpendicular to it (2D only).
w_normalized = w_vector / np.linalg.norm(w_vector)
w_orthogonal = [w_normalized[1], -w_normalized[0]]

# Line segment of half-length LL through the midpoint of the two means,
# running along w_orthogonal; it is drawn later as the "LDA plane".
LL = 5.0
midPoint = 0.5 * (mean_B + mean_A)
x_lower = midPoint[0] - LL * w_orthogonal[0]
y_lower = midPoint[1] - LL * w_orthogonal[1]
x_upper = midPoint[0] + LL * w_orthogonal[0]
y_upper = midPoint[1] + LL * w_orthogonal[1]
x_LDA_plane = [x_lower, x_upper]
y_LDA_plane = [y_lower, y_upper]

# Diagnostic output; off by default. Set the flag to True to inspect the
# geometry of the separating line.
DEBUG_LDA = False
if DEBUG_LDA:
    print("\n#main w_orthogonal\n", w_orthogonal)
    print("\n#main midPoint    \n", midPoint    )

#
# --- plotting
#

def _column(rows, index):
    """Return entry *index* of every row in *rows* as a list."""
    return [row[index] for row in rows]


def _scaled_and_shifted(values, scale, offset):
    """Return ``scale * v + offset`` for every v in *values* as a list."""
    return [scale * v + offset for v in values]


# Covariance-ellipse outlines for both clouds; they are scaled and moved
# onto the respective means further down.
xEllipse_A, yEllipse_A = CM.plotEllipseData(coVar_A)
xEllipse_B, yEllipse_B = CM.plotEllipseData(coVar_B)

# Split the data matrices into x and y columns for plotting.
xData_A = _column(dataMatrix_A, 0)
yData_A = _column(dataMatrix_A, 1)
xData_B = _column(dataMatrix_B, 0)
yData_B = _column(dataMatrix_B, 1)

# Scale factor for the 95% confidence ellipse (5.991 is the 95% quantile
# of the chi-square distribution with two degrees of freedom).
Z_95 = math.sqrt(5.991)

xE_95_A = _scaled_and_shifted(xEllipse_A, Z_95, mean_A[0])
yE_95_A = _scaled_and_shifted(yEllipse_A, Z_95, mean_A[1])
xE_95_B = _scaled_and_shifted(xEllipse_B, Z_95, mean_B[0])
yE_95_B = _scaled_and_shifted(yEllipse_B, Z_95, mean_B[1])

# End points of the segment joining the two means.
x_connectMean = [mean_A[0], mean_B[0]]
y_connectMean = [mean_A[1], mean_B[1]]

# Draw order: ellipses, data points, mean connector, LDA line.
for xs, ys, colour in ((xE_95_A, yE_95_A, "k"), (xE_95_B, yE_95_B, "b")):
    plt.plot(xs, ys, colour, linewidth=2.0, label="95%")
for xs, ys, marker in ((xData_A, yData_A, "ok"), (xData_B, yData_B, "ob")):
    plt.plot(xs, ys, marker, markersize=5)

plt.plot(x_connectMean, y_connectMean, "r", linewidth=3.0)
plt.plot(x_connectMean, y_connectMean, "or", markersize=9.0)

plt.plot(x_LDA_plane, y_LDA_plane, "g", linewidth=4.0, label="LDA plane")

plt.legend(loc="upper left")
# Optional: uncomment for equal axis scaling (square plot).
#plt.axis('square')
plt.savefig("foo.svg")   # export the figure before the window is shown
plt.show()
