#!/usr/bin/env python3

import math                        # math
import matplotlib.pyplot as plt    # plotting
import numpy as np
from numpy.linalg import inv       # inverse matrix
from numpy import linalg as LA     # linear algebra


# ***
# *** covariance matrix from data
# ***

def covarianceMatrix(data):
  '''Mean vector and normalized (population) covariance matrix of data.

  Parameters
  ----------
  data : array_like, shape (nRow, nCol)
      One data point per row; a list of lists is accepted as well.

  Returns
  -------
  mean : ndarray, shape (nCol,)
      Component-wise mean of the data points.
  coVar : ndarray, shape (nCol, nCol)
      Covariance matrix, sum of outer products of the shifted points
      divided by nRow (not nRow-1).

  Raises
  ------
  ValueError
      If data is not a non-empty 2D array.
  '''
  points = np.asarray(data, dtype=float)
  if points.ndim != 2 or points.shape[0] == 0:
    raise ValueError("covarianceMatrix: expected a non-empty 2D array")
  nRow = points.shape[0]                 # number of data points
  mean = points.mean(axis=0)             # vector-mean
  shifted = points - mean                # shifted data points
  coVar = shifted.T @ shifted / nRow     # sum of outer products / nRow
  return mean, coVar                     # returning both

# ***
# *** generate data for plotting an ellipse
# ***

def plotEllipseData(data, nPoints=101):
  """Generate the points of an ellipse from the leading 2x2 block of a matrix.

  The upper-left 2x2 slice of `data` is treated as a symmetric
  (covariance) matrix whose lower-left entry defines the off-diagonal
  element. The ellipse is centered at the origin and has semi-axes
  sqrt(eigenvalue) along the corresponding eigenvectors.

  Parameters
  ----------
  data : array_like, at least 2x2
      Matrix whose upper-left 2x2 block is used; it is NOT modified.
  nPoints : int, optional
      Number of points on the closed curve (first point == last point),
      default 101.

  Returns
  -------
  x, y : list of float
      Coordinates of the ellipse points. Both lists are empty when the
      2x2 block has a negative eigenvalue (not a valid variance).

  Raises
  ------
  ValueError
      If nPoints < 2.
  """
  if nPoints < 2:
    raise ValueError("plotEllipseData: nPoints must be at least 2")
  slice22 = np.array(data, dtype=float)[:2, :2]   # private copy of the slice
  slice22[0, 1] = slice22[1, 0]                   # symmetrize from lower-left
  # eigh is the solver for symmetric matrices: guaranteed real eigenvalues
  # and orthonormal eigenvectors (columns), unlike the general-purpose eig
  eigenValues, eigenVectors = LA.eigh(slice22)
  if (eigenValues < 0.0).any():
    print("# plotEllipseData: only positive eigenvalues (variance)")
    return [], []          # keeps `x, y = plotEllipseData(...)` unpackable
  a = math.sqrt(eigenValues[0])
  b = math.sqrt(eigenValues[1])
  cTheta = eigenVectors[0, 0]      # first principal axis = (cos, sin)
  sTheta = eigenVectors[1, 0]
  tt = np.linspace(0.0, 2.0*math.pi, nPoints)     # full loop
  cc = np.cos(tt)
  ss = np.sin(tt)
  x = a*cTheta*cc - b*sTheta*ss    # rotate the axis-aligned ellipse
  y = a*sTheta*cc + b*cTheta*ss
  return x.tolist(), y.tolist()

# ***
# *** generate 2D test data
# ***

def testData(angle, var1, var2, nData, startMean=[0.0,0.0]):
  """2D Gaussian for a given angle and main variances.
     A = \sum_i \lambda_i |lambda_i><lambda_i|"""
#
  eigen1 = [math.cos(angle),-math.sin(angle)]
  eigen2 = [math.sin(angle), math.cos(angle)]
  startCoVar  = var1*np.outer(eigen1,eigen1)
  startCoVar += var2*np.outer(eigen2,eigen2)
#  print("\n covariance matrix: data generation \n",startCoVar)
  return np.random.multivariate_normal(startMean, startCoVar, nData)

# *******
# *** SVM 
# *******

def SVM_ascent(svmData, svmL, svmSoft, nIter=10000, epsilon=0.001,
               svmThreshold=0.001, verbose=True):
  """Simple linear soft-margin SVM via projected gradient ascent on the dual.

  Each iteration computes w = sum_i L_i A_i x_i, takes a gradient step
  A_i += epsilon*(1 - L_i w.x_i), projects onto sum_i A_i L_i = 0 and
  clips every A_i to [0, svmSoft].

  Parameters
  ----------
  svmData : array_like, shape (nData, dim)
      Data points, one per row.
  svmL : array_like, shape (nData,)
      Class labels, +1.0 or -1.0.
  svmSoft : float
      Upper bound on the Lagrange parameters (soft margin).
  nIter : int, optional
      Fixed number of update iterations (default 10000).
  epsilon : float, optional
      Update rate (default 0.001).
  svmThreshold : float, optional
      Lagrange parameters above this value mark support vectors.
  verbose : bool, optional
      Print the per-point table (index, Lagrange parameter, label,
      threshold, coordinates); default True.

  Returns
  -------
  svmA : ndarray, shape (nData,)
      Lagrange parameters.
  BB_mean : float
      Threshold averaged over the support vectors; the separating plane
      is w.x = BB_mean.
  svmW : ndarray, shape (dim,)
      w-vector, consistent with the returned Lagrange parameters.

  Raises
  ------
  ValueError
      If no Lagrange parameter ends up above svmThreshold.
  """
  svmData = np.asarray(svmData, dtype=float)
  svmL    = np.asarray(svmL, dtype=float)
  nData   = len(svmL)                    # total number of labeled data points
  svmA    = np.random.rand(nData)        # random [0,1] Lagrange parameters

  for _ in range(nIter):
    svmW  = (svmL*svmA) @ svmData        # w = sum_i L_i A_i x_i
    svmA += epsilon*(1.0 - svmL*(svmData @ svmW))   # updating L-parameters
    svmA -= (svmA @ svmL)/nData*svmL     # orthogonalization (L_i^2 = 1)
    np.clip(svmA, 0.0, svmSoft, out=svmA)           # positiveness, soft bound
#
# --- inside the loop w is computed BEFORE the last update of svmA;
# --- recompute it so that w matches the returned Lagrange parameters
#
  svmW = (svmL*svmA) @ svmData
#
# --- threshold per support vector
#
  support = svmA > svmThreshold
  if not support.any():
    raise ValueError("SVM_ascent: no support vectors above svmThreshold")
  svmB = np.zeros(nData)
  svmB[support] = svmData[support] @ svmW - svmL[support]
  BB_mean = float(svmB[support].mean())

  if verbose:
    print("# SVM data ")
    for ii in range(nData):
      print(f'{ii:5d} {svmA[ii]:10.6f} {svmL[ii]:5.1f} {svmB[ii]:10.4f} ',
            end="")
      print(' '.join(f'{coord:6.2f}' for coord in svmData[ii]) + ' ')
#
  return svmA, BB_mean, svmW     # Lagrange / threshold / w-vector
 

# ********
# *** main
# ********

# --- two Gaussian clouds: A (label +1) right of the origin, B (label -1) left
dataMatrix_A = testData(0.2*math.pi, 1.0, 9.0, 20, [4.0, 0.0])
dataMatrix_B = testData(-0.3*math.pi, 1.0, 12.0, 20, [-4.0, 0.0])
mean_A, coVar_A = covarianceMatrix(dataMatrix_A)
mean_B, coVar_B = covarianceMatrix(dataMatrix_B)

#
# --- SVM preparation: one stacked point array plus the matching label vector
#
dataLabel_A = [1.0]*len(dataMatrix_A)
dataLabel_B = [-1.0]*len(dataMatrix_B)
data_SVM = np.concatenate((dataMatrix_A, dataMatrix_B), axis=0)   # rows of A, then rows of B
data_label = np.concatenate((dataLabel_A, dataLabel_B))           # labels in the same order

svmSoft = 15.0                 # upper bound on the Lagrange parameters (soft margin)
data_lagrange, data_BB, data_ww = SVM_ascent(data_SVM, data_label, svmSoft)

showInput = False              # set to True to dump the SVM input
if showInput:
  print()
  for item in (dataMatrix_A, dataLabel_A, dataMatrix_B, dataLabel_B,
               data_SVM, data_label):
    print(item)

#
# --- data points (normal, support)
#

xA_normal  = []
yA_normal  = []
xB_normal  = []
yB_normal  = []

xA_support = []
yA_support = []
xB_support = []
yB_support = []

xA_miss    = []
yA_miss    = []
xB_miss    = []
yB_miss    = []

# Every labelled point goes into exactly one bin of its class:
#   normal  : Lagrange parameter <= svmThreshold, i.e. the exact complement
#             of the `> svmThreshold` support-vector test in SVM_ascent
#             (a strict `<` here would drop points sitting on the threshold)
#   miss    : Lagrange parameter above 95% of the soft bound svmSoft
#             (points at the bound violate the margin)
#   support : everything in between
svmThreshold  = 0.001                  # same cut as used inside SVM_ascent
missThreshold = 0.95*svmSoft
pointBins = {
  "A": {"normal":  (xA_normal,  yA_normal),
        "support": (xA_support, yA_support),
        "miss":    (xA_miss,    yA_miss)},
  "B": {"normal":  (xB_normal,  yB_normal),
        "support": (xB_support, yB_support),
        "miss":    (xB_miss,    yB_miss)},
}
for point, label, lagrange in zip(data_SVM, data_label, data_lagrange):
  if label > 0:
    classBins = pointBins["A"]
  elif label < 0:
    classBins = pointBins["B"]
  else:
    continue                           # label 0 belongs to neither class
  if lagrange <= svmThreshold:
    kind = "normal"
  elif lagrange > missThreshold:
    kind = "miss"
  else:
    kind = "support"
  xs, ys = classBins[kind]
  xs.append(point[0])
  ys.append(point[1])

print("# number of support vectors ",len(xA_support),len(xB_support))
print("# miss-classfied    vectors ",len(xA_miss),len(xB_miss))


#
# --- SVM plane and margins
#

# Each line w.x = c is drawn as a segment of half-length Len. The segment is
# centred on the line's point closest to the origin, c*w/|w|^2, and runs
# along (w[1], -w[0])/|w|, perpendicular to w.
# With data_BB from SVM_ascent (b_i = w.x_i - L_i on support vectors),
# class-A (+1) support vectors satisfy w.x = data_BB + 1. So svmMar_B_*
# traces the A-side margin and svmMar_A_* the B-side one. Both are drawn
# with the same style, so the naming has no visual effect.
r  = math.sqrt(data_ww[0]*data_ww[0]+data_ww[1]*data_ww[1])
rr = r*r
Len = 4.0


def _lineSegment(offset):
  """End points ([x0, x1], [y0, y1]) of the drawn segment of w.x = offset."""
  segX = [offset*data_ww[0]/rr + Len*data_ww[1]/r,
          offset*data_ww[0]/rr - Len*data_ww[1]/r]
  segY = [offset*data_ww[1]/rr - Len*data_ww[0]/r,
          offset*data_ww[1]/rr + Len*data_ww[0]/r]
  return segX, segY


svmPlane_x, svmPlane_y = _lineSegment(data_BB-0.0)    # separating plane
svmMar_A_x, svmMar_A_y = _lineSegment(data_BB-1.0)    # margin at w.x = BB-1
svmMar_B_x, svmMar_B_y = _lineSegment(data_BB+1.0)    # margin at w.x = BB+1

#
# --- printing, including covariance ellipse
#

# plotEllipseData returns None when the covariance slice has a negative
# eigenvalue. Treat that as "no ellipse" instead of crashing on the tuple
# unpacking.
ellipse_A = plotEllipseData(coVar_A)
ellipse_B = plotEllipseData(coVar_B)
xEllipse_A, yEllipse_A = ellipse_A if ellipse_A is not None else ([], [])
xEllipse_B, yEllipse_B = ellipse_B if ellipse_B is not None else ([], [])

# 5.991 is the 95% quantile of the chi-square distribution with 2 degrees
# of freedom, so the scaled unit ellipse encloses 95% of a 2D Gaussian
Z_95 = math.sqrt(5.991)
xE_95_A = [Z_95*xx + mean_A[0] for xx in xEllipse_A]
yE_95_A = [Z_95*yy + mean_A[1] for yy in yEllipse_A]
xE_95_B = [Z_95*xx + mean_B[0] for xx in xEllipse_B]
yE_95_B = [Z_95*yy + mean_B[1] for yy in yEllipse_B]

x_connectMean = [ mean_A[0], mean_B[0] ]
y_connectMean = [ mean_A[1], mean_B[1] ]

# Paired plots (class A black, class B blue) share one legend entry; only
# the first plot of each pair carries the label, avoiding duplicate entries.
plt.plot(xE_95_A, yE_95_A, "k", linewidth=2.0, label="95%")
plt.plot(xE_95_B, yE_95_B, "b", linewidth=2.0)

plt.plot(xA_normal, yA_normal, "ok", markersize=8)
plt.plot(xB_normal, yB_normal, "ob", markersize=8)
plt.plot(xA_support, yA_support, "ok", markersize=10,
         markeredgewidth=2, mfc='none')
plt.plot(xB_support, yB_support, "ob", markersize=10,
         markeredgewidth=2, mfc='none')
plt.plot(xA_miss, yA_miss, "ok", markersize=8, markeredgewidth=2,
         mfc=[1.0,0.84,0.0], label="miss")
plt.plot(xB_miss, yB_miss, "ob", markersize=8, markeredgewidth=2,
         mfc=[1.0,0.84,0.0])

plt.plot(x_connectMean, y_connectMean, "r", linewidth=3.0)
plt.plot(x_connectMean, y_connectMean, "or", markersize=9.0)

plt.plot(svmPlane_x, svmPlane_y, color=[1.0,0.84,0.0],
         linewidth=4.0, label="SVM plane")
plt.plot(svmMar_A_x, svmMar_A_y, color=[1.0,0.84,0.0],
         linewidth=4.0, linestyle="dashed", label="SVM margin")
plt.plot(svmMar_B_x, svmMar_B_y, color=[1.0,0.84,0.0],
         linewidth=4.0, linestyle="dashed", label="")


plt.legend(loc="upper left")
#plt.axis('square')                           # square plot
plt.savefig('foo.svg')                       # export figure
plt.show()
