#!/usr/bin/env python3

#
# basic transformer bilayer
#
import torch
import math
import matplotlib.pyplot as plt

nLayer        =  3                    # number of layers
tokenPerLayer =  5                    # context length
dim           =  2                    # embedding dimension

b             =  0.9                  # (+b) / (-b)  : logical True/False
nIter         =  4400                 # training iterations
learning_rate =  1.5e-2

#
# transformer bilayer: attention plus token-wise feed-forward net
#
class transformerBilayer(torch.nn.Module):
  """ The second part of the transformer bilayer,
      the token-wise feedforward net, expands and contracts:
      W_expand -> W_contract, standard by a factor 4 """
  def __init__(self, dim, nContext, expand=4):
    super().__init__()
    self.Q_mat  = torch.randn(nContext,dim,dim,requires_grad=True)
    self.K_mat  = torch.randn(nContext,dim,dim,requires_grad=True)
    self.V_mat  = torch.randn(nContext,dim,dim,requires_grad=True)
#
    mySigma = 1.0/(dim*dim)
    torch.nn.init.normal_(self.Q_mat, mean=0.0, std=mySigma)
    torch.nn.init.normal_(self.K_mat, mean=0.0, std=mySigma)
    torch.nn.init.normal_(self.V_mat, mean=0.0, std=mySigma)
#
    self.W_expand   = torch.randn(nContext,dim*expand,dim,requires_grad=True)
    self.W_contract = torch.randn(nContext,dim,dim*expand,requires_grad=True)
    mySigma = 1.0/(dim*dim*expand)
    torch.nn.init.normal_(self.W_expand  , mean=0.0, std=mySigma)
    torch.nn.init.normal_(self.W_contract, mean=0.0, std=mySigma)
#
    self.W_bias = torch.zeros(nContext,dim*expand,requires_grad=True)
#
    self.nContext = nContext
    self.dim      = dim        # embedding
    self.expand   = expand     # FFN expansion factor
#
    self.paddingMask = torch.ones(nContext,nContext)   # causal mask: block attention to future tokens
    for ii in range(nContext):
      for jj in range(ii+1,nContext):
        self.paddingMask[ii][jj] = 0.0

  def layerNorm(self, x):
    """ simplified layer norm: per-dimension mean over the tokens,
        followed by a single global scale factor sigma """
    mean  = torch.sum(x, 0) / self.nContext   # mean over tokens (dim 0)
    sigma = torch.sqrt(torch.square(x-mean).sum() / self.nContext)
    return (x-mean)/sigma                     # same sigma for all rows

  def feedForward_subLayer(self, x):
    y = torch.zeros(self.nContext,self.dim)
    for ii in range(self.nContext):           # explicit loop over tokens
      hidden = torch.matmul(self.W_expand[ii],x[ii])\
             - self.W_bias[ii]
      hidden = torch.tanh(hidden)
      y[ii]  = torch.matmul(self.W_contract[ii],hidden)
    return y + x                              # with skip connections

  def attention_subLayer(self, x):
    Q  = torch.zeros(self.nContext,self.dim)
    K  = torch.zeros(self.nContext,self.dim)
    V  = torch.zeros(self.nContext,self.dim)
    for ii in range(self.nContext):
      Q[ii] = torch.matmul(self.Q_mat[ii],x[ii])
      K[ii] = torch.matmul(self.K_mat[ii],x[ii])
      V[ii] = torch.matmul(self.V_mat[ii],x[ii])
# 
    alpha = torch.zeros(self.nContext,self.nContext)
    for ii in range(self.nContext):
      for jj in range(self.nContext):
        alpha[ii][jj] = self.paddingMask[ii][jj]\
                      * torch.exp( torch.dot(Q[ii],K[jj]) )
      alpha[ii] /= alpha[ii].sum()      # normalization
    return torch.matmul(alpha,V) + x    # with skip connections

  def forward(self, x):
    x = self.layerNorm(x)               # normalization on entry
    x = self.attention_subLayer(x)
    x = self.layerNorm(x)               # normalization before the feed-forward net
    x = self.feedForward_subLayer(x)
    return x

  def update(self, eps):                # updating internal parameters
    with torch.no_grad():
      self.Q_mat -= eps*self.Q_mat.grad
      self.K_mat -= eps*self.K_mat.grad
      self.V_mat -= eps*self.V_mat.grad
      self.Q_mat.grad = None
      self.K_mat.grad = None
      self.V_mat.grad = None
      self.W_expand   -= eps*self.W_expand.grad
      self.W_contract -= eps*self.W_contract.grad
      self.W_bias     -= eps*self.W_bias.grad
      self.W_expand.grad   = None
      self.W_contract.grad = None
      self.W_bias.grad     = None
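
#
# optional sanity check, not used by the model below: the explicit loops in
# attention_subLayer should agree with a vectorized masked softmax; this is
# a minimal sketch assuming the same per-token Q/K/V matrices and mask
#
def attention_vectorized(layer, x):
  """ vectorized equivalent of layer.attention_subLayer (for testing only) """
  Q = torch.einsum('tij,tj->ti', layer.Q_mat, x)      # per-token projections
  K = torch.einsum('tij,tj->ti', layer.K_mat, x)
  V = torch.einsum('tij,tj->ti', layer.V_mat, x)
  scores = torch.matmul(Q, K.T)                       # (nContext, nContext)
  scores = scores.masked_fill(layer.paddingMask==0, float('-inf'))
  alpha  = torch.softmax(scores, dim=1)               # row-wise masked softmax
  return torch.matmul(alpha, V) + x                   # with skip connections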

#
# n-layer model: a stack of structurally identical transformer bilayers
#
allLayers = [transformerBilayer(dim,tokenPerLayer) for iL in range(nLayer)]
def model(x):
  for iLayer in range(nLayer):
    x = allLayers[iLayer](x)
  return x
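
#
# quick shape check (illustrative only): a (tokenPerLayer, dim) input is
# mapped by the stacked bilayers to an output of the same shape
#
if (1==2):                              # flip to (1==1) to run the check
  xTest = torch.randn(tokenPerLayer, dim)
  print("# model output shape:", model(xTest).shape)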

#
# printing token activities (one row per embedding dimension)
#
def printTokenActivities(x, myString):
  print()
  print("# activity for", myString)
  for ii in range(dim):
    for token in range(tokenPerLayer):
        print(f'{x[token][ii]:8.4f}', end="")
    print()

#
# quadratic loss function (summed squared error)
#
def lossFunction(outputActivity, targetActivity):
  return torch.square(outputActivity - targetActivity).sum()
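
#
# for reference: the same value could equivalently be computed with
# torch.nn.functional.mse_loss(outputActivity, targetActivity, reduction='sum')
#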

#
# random boolean (+/- b) mapping
#
training_data  =\
    b*(2.0*(torch.FloatTensor(tokenPerLayer,dim).uniform_()>0.5)-1.0)
training_value =\
    b*(2.0*(torch.FloatTensor(tokenPerLayer,dim).uniform_()>0.5)-1.0)
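
#
# each entry of the data and of the target is +b or -b with equal
# probability; for b=0.9 the model thus learns to map one random pattern
# of +/-0.9 token activities onto another
#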

#
# training
#
if (1==2):                              # flip to (1==1) to print the training data
  print("# training_data")
  print(training_data,"\n")
  print("# training_value")
  print(training_value,"\n")
#
for iIter in range(nIter):
  loss = lossFunction(model(training_data),training_value)
  if (loss<0.0001):                     # stop early once converged
    break
  loss.backward()
#
  for iLayer in range(nLayer):
    allLayers[iLayer].update(learning_rate)
  if (iIter%200==0):
    print(f'{iIter:4d} {loss.item():9.4f}')
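
#
# note: update() performs plain gradient descent by hand; wrapping the
# weight tensors in torch.nn.Parameter and stepping with torch.optim.SGD
# at the same learning rate would be the idiomatic PyTorch alternative
# (not done here)
#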

#
# compare output with target
#
print()
yy = model(training_data)
printTokenActivities(training_value, "training_value")
printTokenActivities(yy            , "output activities")
