#!/usr/bin/env python3

#
# basic attention layers only; the token-wise feed-forward net of a full transformer block is not included
#
import torch
import math

nLayer   =  3                      # number of layers
tokenPerLayer = 5                  # context length
nToken   =  tokenPerLayer*nLayer   # total number of tokens (not used below)
dim      =    2                    # embedding dimension
yesMask  =  1.0                    # 1/0 masked attention on/off

b        =  0.9                    # (+b) / (-b)  : logical True/False
nIter   =   4400                   # training iterations
learning_rate = 1.5e-2

#
# attention layer module
#
class attentionLayer(torch.nn.Module):
  def __init__(self, dim, nContext, yesMask=1, yesNorm=True, myID=0):
    super().__init__()
    self.Q_mat  = torch.randn(nContext,dim,dim,requires_grad=True)   # one dim x dim query matrix per token
    self.K_mat  = torch.randn(nContext,dim,dim,requires_grad=True)   # key matrices, one per token
    self.V_mat  = torch.randn(nContext,dim,dim,requires_grad=True)   # value matrices; updated by hand in update()
#
    mySigma = 1.0/(dim*dim)    # std of initial weights; nn.init works outside the computation graph
    torch.nn.init.normal_(self.Q_mat, mean=0.0, std=mySigma)
    torch.nn.init.normal_(self.K_mat, mean=0.0, std=mySigma)
    torch.nn.init.normal_(self.V_mat, mean=0.0, std=mySigma)
#
    self.alpha  = torch.zeros(nContext,nContext)
    self.yesMask  = yesMask    # masked self attention
    self.yesNorm  = yesNorm    # layer normalization
    self.nContext = nContext
    self.dim      = dim        # embedding
    self.ID       = myID
#
    self.paddingMask = torch.zeros(nContext,nContext)   # for masking
    for ii in range(nContext):
      for jj in range(ii+1,nContext):
        self.paddingMask[ii][jj] = -1.0e9               # exp -> 0
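#   an equivalent vectorized construction (a sketch; should reproduce the
#   loop above, masking all entries with jj > ii):
#     self.paddingMask = -1.0e9*torch.triu(torch.ones(nContext,nContext), diagonal=1)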

  def layerNorm(self, x):
    mean  = torch.zeros(self.dim)             # vector mean
    sigma = torch.tensor(0.0)                 # zero-dimensional tensor
#   for ii in range(self.nContext):
#      mean += x[ii] / self.nContext
    mean = torch.sum(x, 0) / self.nContext    # sum over rows
#
#   for ii in range(self.nContext):
#      sigma += torch.square(x[ii]-mean).sum()
#   sigma = torch.sqrt(sigma/self.nContext)
    sigma = torch.sqrt(torch.square(x-mean).sum() / self.nContext)
#
#   for ii in range(self.nContext):           # layer normalization
#     x[ii] -= mean
#     x[ii] /= sigma
    x = (x-mean)/sigma                        # for all rows
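#   note: this normalizes jointly over all tokens and components, which
#   differs from torch.nn.LayerNorm (per-token normalization); a per-row
#   sketch, not used here, would be
#     x = torch.nn.functional.layer_norm(x, (self.dim,))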
    return x

  def forward(self, x, storeAttention=False):
    if (self.yesNorm):
      x = self.layerNorm(x)                # keep the normalized activities
#  Q/K/V vectors
    Q  = torch.zeros(self.nContext,self.dim)
    K  = torch.zeros(self.nContext,self.dim)
    V  = torch.zeros(self.nContext,self.dim)
    for ii in range(self.nContext):
      Q[ii] = torch.matmul(self.Q_mat[ii],x[ii])
      K[ii] = torch.matmul(self.K_mat[ii],x[ii])
      V[ii] = torch.matmul(self.V_mat[ii],x[ii])
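#   the loop above performs one matrix-vector product per token; an
#   equivalent one-call form (sketch) via einsum would be
#     Q = torch.einsum('tij,tj->ti', self.Q_mat, x)
#     K = torch.einsum('tij,tj->ti', self.K_mat, x)
#     V = torch.einsum('tij,tj->ti', self.V_mat, x)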
#  local attention matrix
    alpha = torch.zeros(self.nContext,self.nContext)
    for ii in range(self.nContext):
      for jj in range(self.nContext):
        alpha[ii][jj] = torch.exp( torch.dot(Q[ii],K[jj])\
                                 + self.yesMask*self.paddingMask[ii][jj] )
      alpha[ii] /= alpha[ii].sum()      # normalization
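#   equivalent vectorized form (a sketch): row-wise softmax of the
#   query-key overlaps plus the (optional) causal mask; note that the
#   customary 1/sqrt(dim) scaling of the overlaps is not used here
#     alpha = torch.softmax(torch.matmul(Q,K.T) + self.yesMask*self.paddingMask, dim=1)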
#  store attention matrix
    if storeAttention:
      self.alpha = alpha
    return torch.matmul(alpha,V)
#   return torch.matmul(alpha,V) + x    # with skip connections

  def update(self, eps):                # updating internal parameters
    with torch.no_grad():
      self.Q_mat -= eps*self.Q_mat.grad
      self.K_mat -= eps*self.K_mat.grad
      self.V_mat -= eps*self.V_mat.grad
      self.Q_mat.grad = None
      self.K_mat.grad = None
      self.V_mat.grad = None
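#   the step above is plain gradient descent on hand-managed tensors; an
#   alternative (sketch, not used here) would register Q/K/V as
#   torch.nn.Parameter in __init__ and let an optimizer do the stepping:
#     opt = torch.optim.SGD(self.parameters(), lr=eps)
#     opt.step(); opt.zero_grad()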

#
# model: nLayer stacked attention layers (same architecture, separate parameters)
#
allLayers = [attentionLayer(dim,tokenPerLayer,myID=iL) for iL in range(nLayer)]
def model(x, storeAttention=False):
  for iLayer in range(nLayer):
    x = allLayers[iLayer](x, storeAttention)
  return x

#
# console printing of attention matrix
#
def printAttentionMatrix():
  for iLayer in range(nLayer):
    print()
    print("# attention matrix for layer ", iLayer)
    for ss in range(tokenPerLayer):
      for tt in range(tokenPerLayer):
         alpha = allLayers[iLayer].alpha[ss][tt]
         print(f'{alpha:9.4f}', end="")
      print()

#
# test output of token activities
#
def printTokenActivities(x, myString):
  print()
  print("# activity for", myString)
  for ii in range(dim):
    for token in range(tokenPerLayer):
        print(f'{x[token][ii]:8.4f}', end="")
    print()

#
# quadratic loss: sum of squared errors
#
def lossFunction(outputActivity, targetActivity):
  return torch.square(outputActivity - targetActivity).sum()
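# equivalent built-in (sketch):
#   torch.nn.functional.mse_loss(outputActivity, targetActivity, reduction='sum')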

#
# random boolean (\pm b) mapping
#
training_data  =\
    b*(2.0*(torch.rand(tokenPerLayer,dim)>0.5)-1.0)
training_value =\
    b*(2.0*(torch.rand(tokenPerLayer,dim)>0.5)-1.0)

#
# optional test printout of training data and target
#
if (1==2):                            # set to (1==1) to enable
  print("# training_data")
  print(training_data,"\n")
  print("# training_value")
  print(training_value,"\n")
#
# training loop
#
for iIter in range(nIter):
  loss = lossFunction(model(training_data),training_value)
  if (loss<0.001):
    break
  loss.backward()
#
  for iLayer in range(nLayer):
    allLayers[iLayer].update(learning_rate)
  if (iIter%200==0):
    print(f'{iIter:4d} {loss.item():9.4f}')

#
# compare output with target
#
print()
yy = model(training_data, storeAttention=True)
printTokenActivities(training_value, "training_value")
printTokenActivities(yy            , "output activities")
#
if (1==2):                            # set to (1==1) to print the attention matrices
  print()
  printAttentionMatrix()
