#!/usr/bin/env python3

import torch                       # pytorch; ML
import math                        # math
import matplotlib.pyplot as plt    # plotting
import numpy as np
from numpy.linalg import inv       # inverse matrix
from numpy import linalg as LA     # linear algebra


# ***
# *** covariance matrix from data
# ***

def covarianceMatrix(data):
  '''Mean vector and covariance matrix of the input data.

  data : one data point per row, shape (nRow, nCol); an ndarray or a
         plain list of lists.
  returns (mean, coVar):
    mean  : ndarray (nCol,), the vector-mean of the rows
    coVar : ndarray (nCol, nCol), sum_i (x_i-mean)(x_i-mean)^T / nRow,
            i.e. normalized by nRow (not the unbiased 1/(nRow-1) form)
  raises ValueError when data is not a non-empty 2D collection.
  '''
  points = np.asarray(data, dtype=float)     # lists work as well as arrays
  if points.ndim != 2 or points.shape[0] == 0:
    raise ValueError("covarianceMatrix: expected a non-empty 2D data set")
  nRow = points.shape[0]                     # number of data points
#
  mean    = points.mean(axis=0)              # vector-mean
  shifted = points - mean                    # shifted data points
  coVar   = shifted.T @ shifted / nRow       # sum of outer products / nRow
#
  return mean, coVar                         # returning both

# ***
# *** generate data for plotting an ellipse
# ***

def plotEllipseData(data, nPoints=101):
  """Generate the ellipse of the symmetric 2x2 upper-left slice of a matrix.

  The ellipse is centred at the origin; its semi-axes point along the
  eigenvectors of the slice and have lengths sqrt(eigenvalue), so for a
  covariance matrix it is the one-sigma contour.

  data    : square matrix (nested lists or ndarray, at least 2x2); only the
            upper-left 2x2 slice is used.  The slice is symmetrized by
            using the lower off-diagonal element for both off-diagonal
            entries.  The caller's matrix is never modified.
  nPoints : number of points on the closed curve (first point == last).
  returns : (x, y), two lists of floats; or None, after printing a
            message, when the slice has a negative eigenvalue.
  """
  slice22 = np.array(data, dtype=float)[:2, :2]   # private copy of the slice
  slice22[0, 1] = slice22[1, 0]                   # symmetrize (lower wins)
  eigenValues, eigenVectors = LA.eigh(slice22)    # real, orthonormal columns
#
  if (eigenValues < 0.0).any():
    print("# plotEllipseData: only positive eigenvalues (variance)")
    return None
#
  semiAxes = np.sqrt(eigenValues)
  tt = np.linspace(0.0, 2.0*math.pi, nPoints)     # full loop
# point(t) = a*cos(t)*v_0 + b*sin(t)*v_1 with v_i the eigenvector columns;
# valid for any sign/ordering convention of the eigenvectors
  points = (np.outer(semiAxes[0]*eigenVectors[:, 0], np.cos(tt))
            + np.outer(semiAxes[1]*eigenVectors[:, 1], np.sin(tt)))
  return points[0].tolist(), points[1].tolist()

# ***
# *** generate 2D test data
# ***

def testData(angle, var1, var2, nData, startMean=[0.0,0.0]):
  """2D Gaussian for a given angle and main variances.
     A = \sum_i \lambda_i |lambda_i><lambda_i|"""
#
  eigen1 = [math.cos(angle),-math.sin(angle)]
  eigen2 = [math.sin(angle), math.cos(angle)]
  startCoVar  = var1*np.outer(eigen1,eigen1)
  startCoVar += var2*np.outer(eigen2,eigen2)
#  print("\n covariance matrix: data generation \n",startCoVar)
  return np.random.multivariate_normal(startMean, startCoVar, nData)

# ******************
# *** SVM with torch
# ******************

def SVM_torch(Data, Label, C, nIter=10000, epsilon=0.01, printEvery=100):
  """Linear SVM by direct minimization of the squared-hinge loss

      Loss = (1-C)*|w|^2/2
           + C * sum_i max[0, 1 - y_i*(x_i.w - b)]^2

  with plain steepest descent at a fixed update rate, starting from a
  random (torch.randn) W, B.

  Data       : data points, shape (nData, nDim); a tensor or anything
               torch.as_tensor accepts (converted to double)
  Label      : labels y_i = +1 / -1, shape (nData,)
  C          : weight in [0,1] of the hinge term vs. the regularization
  nIter      : number of update iterations
  epsilon    : update rate
  printEvery : print iteration, W, B and loss every printEvery iterations;
               0 or None disables the progress output
  returns (w, b): plain Python lists of length nDim and 1
  """
  Data  = torch.as_tensor(Data,  dtype=torch.double)   # W, B are double;
  Label = torch.as_tensor(Label, dtype=torch.double)   # dtypes must match
  nDim  = Data.shape[1]

# adaptable variables W, B
  W = torch.randn(nDim, requires_grad=True, dtype=torch.double)
  B = torch.randn(1,    requires_grad=True, dtype=torch.double)

  for iIter in range(nIter):

# forward pass: element-wise addition, multiplication, power,
# ReLU(x) = max(0,x) as the hinge
    loss = (1.0-C)*0.5*torch.dot(W, W) \
         + C*torch.relu(1.0 - Label*(torch.matmul(Data, W) - B[0])).pow(2).sum()

    loss.backward()                          # backward pass

    with torch.no_grad():                    # off computational graph
      W -= epsilon*W.grad                    # steepest descent
      B -= epsilon*B.grad
      W.grad = None                          # remove gradients
      B.grad = None
#
    if printEvery and iIter % printEvery == 0:
      wText = " ".join(f"{w:8.4f}" for w in W.tolist())
      print(f'{iIter:5d} {wText} {B.item():8.4f} {loss.item():8.4f} ')
#
  return W.tolist(), B.tolist()              # returning non-tensors

# ********
# *** main
# ********

# --- test data: class A (label +1) around x=+4, class B (label -1) around x=-4
listData_A = testData(0.2*math.pi,  1.0,  9.0, 10, [4.0, 0.0])
listData_B = testData(-0.3*math.pi, 1.0, 12.0, 10, [-4.0, 0.0])

# --- measured mean/covariance of each class and its covariance ellipse
mean_A, coVar_A = covarianceMatrix(listData_A)
mean_B, coVar_B = covarianceMatrix(listData_B)
x_Ellipse_A, y_Ellipse_A = plotEllipseData(coVar_A)
x_Ellipse_B, y_Ellipse_B = plotEllipseData(coVar_B)

#
# --- one data set holding both classes, labels +1 (A) / -1 (B)
#
listLabel_A = [1.0] * len(listData_A)
listLabel_B = [-1.0] * len(listData_B)
list_allData = np.vstack((listData_A, listData_B))
list_allLabel = np.hstack((listLabel_A, listLabel_B))

# fixed, non-adaptable tensors handed to the SVM
tensor_allData, tensor_allLabel = map(torch.tensor, (list_allData, list_allLabel))

torch_C = 0.9        # weight of the hinge term vs. |w|^2, in [0,1]
data_ww, data_BB = SVM_torch(tensor_allData, tensor_allLabel, torch_C)

# --- results
print("\n# list_allLabel  \n", list_allLabel)
print("\n# data_ww        \n", data_ww)
print("\n# data_BB        \n", data_BB)

#
# --- SVM plane and margins
#

def _svmLine(ww, BB, offset, length):
  """End points of a segment of the 2D line  ww.x - BB = offset.

  The segment is centred on the line's point closest to the origin,
  (BB + offset)*ww/|ww|^2, and reaches `length` to either side along the
  line direction (ww[1], -ww[0])/|ww|.
  returns ([x0, x1], [y0, y1]); raises ValueError when ww is zero.
  """
  norm = math.sqrt(ww[0]*ww[0] + ww[1]*ww[1])
  if norm == 0.0:
    raise ValueError("SVM weight vector is zero: no separating line")
  scale = (BB + offset)/(norm*norm)
  cx, cy = scale*ww[0], scale*ww[1]                # foot point of the line
  dx, dy = length*ww[1]/norm, length*ww[0]/norm    # half-segment along line
  return [cx + dx, cx - dx], [cy - dy, cy + dy]


r   = math.sqrt(data_ww[0]*data_ww[0] + data_ww[1]*data_ww[1])   # |w|
rr  = r*r
Len = 6.0                                          # half-length of segments

# decision plane  w.x - b = 0
svmPlane_x, svmPlane_y = _svmLine(data_ww, data_BB[0], 0.0, Len)
# margins: SVM_torch penalizes y*(w.x - b) < 1, so class A (label +1) sits
# on  w.x - b = +1  and class B (label -1) on  w.x - b = -1
svmMar_A_x, svmMar_A_y = _svmLine(data_ww, data_BB[0], +1.0, Len)
svmMar_B_x, svmMar_B_y = _svmLine(data_ww, data_BB[0], -1.0, Len)

#
# --- printing, including covariance ellipse
#

# 95% region of a 2D Gaussian: the one-sigma ellipse scaled by
# sqrt(chi^2 quantile for 2 degrees of freedom at p=0.95) = sqrt(5.991)
Z_95 = math.sqrt(5.991)
xE_95_A = [mean_A[0] + Z_95*u for u in x_Ellipse_A]
yE_95_A = [mean_A[1] + Z_95*v for v in y_Ellipse_A]
xE_95_B = [mean_B[0] + Z_95*u for u in x_Ellipse_B]
yE_95_B = [mean_B[1] + Z_95*v for v in y_Ellipse_B]

# segment joining the two class means
x_connectMean, y_connectMean = [mean_A[0], mean_B[0]], [mean_A[1], mean_B[1]]

# ellipses: black = class A, blue = class B
plt.plot(xE_95_A, yE_95_A, "k", label="95%", linewidth=2.0)
plt.plot(xE_95_B, yE_95_B, "b", label="95%", linewidth=2.0)

# raw data points
plt.plot(listData_A[:, 0], listData_A[:, 1], "ok", markersize=8)
plt.plot(listData_B[:, 0], listData_B[:, 1], "ob", markersize=8)

# connection of the means, as line and end markers
plt.plot(x_connectMean, y_connectMean, "r", linewidth=3.0)
plt.plot(x_connectMean, y_connectMean, "or", markersize=9.0)

# SVM plane (solid) and margins (dashed), all in gold
_svmStyle = {"color": [1.0, 0.84, 0.0], "linewidth": 4.0}
plt.plot(svmPlane_x, svmPlane_y, label="SVM plane", **_svmStyle)
plt.plot(svmMar_A_x, svmMar_A_y, linestyle="dashed", label="SVM margin", **_svmStyle)
plt.plot(svmMar_B_x, svmMar_B_y, linestyle="dashed", label="", **_svmStyle)

plt.legend(loc="upper left")
#plt.axis('square')                           # square plot
plt.savefig('foo.svg')                        # export figure, then display
plt.show()
