#!/usr/bin/env python3

#
# (complete) small transformer
# :: no positional embedding
# :: single attention head
# :: no beam search, greedy prediction
# :: slow, contains loops
#
import torch
import math
import random
import pickle                      # serialization
import matplotlib.pyplot as plt

nLayer   =  3                      # number of layers
nContext =  10                     # context length
dim      =    2                    # embedding dimension, reset below to the vocabulary size (--> 71)

nBatch        =    100             # batch size
nEpoch        =     60             # number of epochs
learning_rate = 5.0e-2
load_model    = False              # load dumped model

#
# transformer bilayer: attention plus token-wise feed-forward net
#
class transformerBilayer(torch.nn.Module):
  """ The second part of the transformer bilayer,
      the token-wise feedforward net, expands and contracts:
      W_expand -> W_contract, standard by a factor 4 """
  def __init__(self, dim, nContext, expand=4):
    super().__init__()
    self.Q_mat  = torch.randn(nContext,dim,dim,requires_grad=True)
    self.K_mat  = torch.randn(nContext,dim,dim,requires_grad=True)
    self.V_mat  = torch.randn(nContext,dim,dim,requires_grad=True)
#
    mySigma = 1.0/(dim*dim)
    torch.nn.init.normal_(self.Q_mat, mean=0.0, std=mySigma)
    torch.nn.init.normal_(self.K_mat, mean=0.0, std=mySigma)
    torch.nn.init.normal_(self.V_mat, mean=0.0, std=mySigma)
#
    self.W_expand   = torch.randn(nContext,dim*expand,dim,requires_grad=True)
    self.W_contract = torch.randn(nContext,dim,dim*expand,requires_grad=True)
    mySigma = 1.0/(dim*dim*expand)
    torch.nn.init.normal_(self.W_expand  , mean=0.0, std=mySigma)
    torch.nn.init.normal_(self.W_contract, mean=0.0, std=mySigma)
#
    self.W_bias = torch.zeros(nContext,dim*expand,requires_grad=True)
#
    self.A_matrix = torch.zeros(nContext,nContext)   # stores last attention matrix
    self.nContext = nContext
    self.dim      = dim                              # embedding
    self.expand   = expand                           # FFN expansion factor
#
    self.paddingMask = torch.ones(nContext,nContext) # causal mask (no peeking ahead)
    for ii in range(nContext):
      for jj in range(ii+1,nContext):
        self.paddingMask[ii][jj] = 0.0
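#
    # equivalent loop-free construction of the causal mask (a sketch):
    if (1==2):
      self.paddingMask = torch.tril(torch.ones(nContext,nContext))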

  def layerNorm(self, x):
    mean = torch.sum(x, 0) / self.nContext    # sum over rows
    sigma = torch.sqrt(torch.square(x-mean).sum() / self.nContext)
    return (x-mean)/sigma                     # for all rows
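  # Note: mean is taken per embedding component across the context, and
  # sigma over the whole block; standard layer normalization would instead
  # act per token, across the embedding dimension, e.g. (sketch, not used):
  #   x = torch.nn.functional.layer_norm(x, (self.dim,))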

  def feedForward_subLayer(self, x):
    y = torch.zeros(self.nContext,self.dim)
    for ii in range(self.nContext):           # explicit sum
      hidden = torch.matmul(self.W_expand[ii],x[ii])\
             - self.W_bias[ii]
      hidden = torch.tanh(hidden)
      y[ii]  = torch.matmul(self.W_contract[ii],hidden)
    return y + x                              # with skip connections
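
  # for illustration only: the same sub-layer without the explicit loop
  def feedForward_subLayer_vectorized(self, x):
    """ Sketch (not called anywhere): an equivalent, loop-free version of
        feedForward_subLayer, using batched matrix products """
    hidden = torch.tanh(torch.bmm(self.W_expand, x.unsqueeze(2)).squeeze(2)
                        - self.W_bias)
    return torch.bmm(self.W_contract, hidden.unsqueeze(2)).squeeze(2) + x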

  def attention_subLayer(self, x):
    Q  = torch.zeros(self.nContext,self.dim)
    K  = torch.zeros(self.nContext,self.dim)
    V  = torch.zeros(self.nContext,self.dim)
    for ii in range(self.nContext):
      Q[ii] = torch.matmul(self.Q_mat[ii],x[ii])
      K[ii] = torch.matmul(self.K_mat[ii],x[ii])
      V[ii] = torch.matmul(self.V_mat[ii],x[ii])
# 
    alpha = torch.zeros(self.nContext,self.nContext)
    for ii in range(self.nContext):
      for jj in range(self.nContext):
        alpha[ii][jj] = self.paddingMask[ii][jj]\
                      * torch.exp( torch.dot(Q[ii],K[jj]) )
      alpha[ii] /= alpha[ii].sum()   
    if (self.storeAttention==True):
      self.A_matrix = alpha.data
    return torch.matmul(alpha,V) + x    
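
  # for illustration only: the same sub-layer without the explicit loops
  def attention_subLayer_vectorized(self, x):
    """ Sketch (not called anywhere): an equivalent, loop-free version of
        attention_subLayer; masked normalization via explicit exponentials,
        as above, and no 1/sqrt(dim) scaling """
    Q = torch.bmm(self.Q_mat, x.unsqueeze(2)).squeeze(2)
    K = torch.bmm(self.K_mat, x.unsqueeze(2)).squeeze(2)
    V = torch.bmm(self.V_mat, x.unsqueeze(2)).squeeze(2)
    alpha = self.paddingMask*torch.exp(torch.matmul(Q, K.t()))
    alpha = alpha/alpha.sum(1, keepdim=True)
    return torch.matmul(alpha, V) + x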

  def forward(self, x, storeAttention=False):
    self.storeAttention = storeAttention
    x = self.layerNorm(x)               # normalization on entry
    x = self.attention_subLayer(x)
    x = self.layerNorm(x)
    x = self.feedForward_subLayer(x)
    return x

  def update(self, eps):                # updating internal parameters
    with torch.no_grad():
      self.Q_mat -= eps*self.Q_mat.grad
      self.K_mat -= eps*self.K_mat.grad
      self.V_mat -= eps*self.V_mat.grad
      self.Q_mat.grad = None
      self.K_mat.grad = None
      self.V_mat.grad = None
      self.W_expand   -= eps*self.W_expand.grad
      self.W_contract -= eps*self.W_contract.grad
      self.W_bias     -= eps*self.W_bias.grad
      self.W_expand.grad   = None
      self.W_contract.grad = None
      self.W_bias.grad     = None
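      # note: the step above is plain stochastic gradient descent on
      # hand-managed tensors; with weights registered as torch.nn.Parameter,
      # torch.optim.SGD(..., lr=eps) plus step()/zero_grad() would do the same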

  def printAttention(self):             # print attention matrix (in percent)
    print(f'\n#     |', end="")
    for col in range(self.nContext):
      print(f'{col:3d}', end="")
    print(f'\n#      ', end="")
    for col in range(self.nContext):
      print('---', end="")
    for row in range(self.nContext):
      print(f'\n# {row:3d} |', end="")
      for col in range(self.nContext):
        intAtt = int(self.A_matrix[row][col]*100.0)
        print(f'{intAtt:3d}', end="")
    print()

#
# load and process training data
#
f_in = open("data.txt", encoding='utf-8-sig')  
trainingString = f_in.read()
f_in.close()

trainingString = trainingString.replace("\n", " ")   # cleaning
trainingString = trainingString.replace("  ", " ")
trainingString = trainingString.lower()     # reducing vocabulary size
trainingText = list(trainingString)
nToken_Data  = len(trainingText)
if (1==2):
  print("---",trainingString,"---")

vocabulary = list(set(trainingText))        # set contains unique elements
vocabulary.sort()                           # fixed order, needed when reloading a model
dim = len(vocabulary)                       # equal to embedding dimension
print("# vocabulary dimension    ", dim)
print("# number of token in data ", nToken_Data)
if (1==1):
  print(vocabulary)

#
# embedding dictionary:    token (letter) --> tensor
#
letterEmbedding = {letter: torch.zeros(dim) for letter in vocabulary}  

#
# orthogonal embedding tensors  (0, ..., 1, 0, ...)
#
count = 0
for letter in vocabulary:
  letterEmbedding[letter][count] = 1 
  count += 1
if (1==2):
  print("# ", letterEmbedding)

#
# loss function: quadratic loss on the last position only
# (full-sequence variant commented out below)
#
def lossFunction(outputActivity, targetActivity):
  lastPos = len(outputActivity)-1
  return torch.square(outputActivity[lastPos] - targetActivity[lastPos]).sum()
#  return torch.square(outputActivity - targetActivity).sum()
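
#
# alternative loss (a sketch, not used): cross entropy on the last position,
# treating the output activity as logits and the one-hot target as class
# probabilities; needs PyTorch >= 1.10 for probability-valued targets
#
def lossFunctionCrossEntropy(outputActivity, targetActivity):
  return torch.nn.functional.cross_entropy(outputActivity[-1].unsqueeze(0),
                                           targetActivity[-1].unsqueeze(0))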

#
# n-identical-layer model
#
allLayers = [transformerBilayer(dim,nContext) for iL in range(nLayer)]
def model(x, keepAttention=False):
  for iLayer in range(nLayer):
    x = allLayers[iLayer](x, storeAttention=keepAttention)
  return x

#
# dump/load  model  to/from file
#
def dumpLoad(fileName, whatToDo="load", myObject=None):
  if (whatToDo=="dump"):
    with open(fileName, 'wb') as file:
      pickle.dump(myObject, file)
      return None
#
  if (whatToDo=="load"):
    with open(fileName, 'rb') as file:
      return pickle.load(file)
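
# note: torch.save(myObject, fileName) / torch.load(fileName) would serve the
# same purpose here; recent PyTorch versions require weights_only=False when
# loading whole Python objects this way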

#
# load stored trained model
# meta parameters must be identical
#
if (load_model==True):
  allLayers = dumpLoad("allLayers.dump")

#
# train on random slices
# slices shifted within a batch by 'stride'
#
stride = 1                        
upperLimit = nToken_Data - nContext - stride*nBatch - 1
training_data  = torch.zeros(nContext,dim)
training_value = torch.zeros(nContext,dim)
#
for iEpoch in range(1,nEpoch+1):
  iStart = random.randrange(upperLimit)
  batchLoss = 0.0
  for iBatch in range(nBatch):
    inputLetters  = trainingText[iStart  :iStart+nContext]
    targetLetters = trainingText[iStart+1:iStart+nContext+1]
    iStart += stride
    if (1==2):
      print()
      print(inputLetters)
      print(targetLetters)
    for iC in range(nContext):
      training_data[iC]  = letterEmbedding[ inputLetters[iC]]
      training_value[iC] = letterEmbedding[targetLetters[iC]]
    loss = lossFunction(model(training_data),training_value)
    loss.backward()
    batchLoss += loss.data.item()
#
  for iLayer in range(nLayer):          # gradients were accumulated over the
    allLayers[iLayer].update(learning_rate/nBatch)    # batch, hence /nBatch
#
  if (iEpoch%(nEpoch//10)==0):     # dump model occasionally
    dumpLoad("allLayers.dump", whatToDo="dump", myObject=allLayers)
    print(f'{iEpoch:5d}  {batchLoss/nBatch:8.4f}')

#
# generate text, token by token, greedy prediction
#
inLetters  = trainingText[0:nContext]
inActivity = torch.zeros(nContext,dim)
for iC in range(nContext):
  inActivity[iC] = letterEmbedding[inLetters[iC]]
print('.'.join(inLetters), end="")
#
nGen = 0
while (nGen<80):
  nGen += 1
  outActivity = model(inActivity, keepAttention=True)
  lastOutToken = outActivity[nContext-1]
  if (1==2):
    print(inActivity)
    print(outActivity)
    print(lastOutToken.data)
#
  bestChar  = '@'
  bestMatch = 1.0e12
  for letter in vocabulary:
    match = torch.square(lastOutToken-letterEmbedding[letter]).sum().data.item()
    if (match<bestMatch):
      bestMatch = match
      bestChar  = letter
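#
  # since the embeddings are one-hot, the nearest-embedding search above
  # reduces to an argmax over the output activities (a sketch):
  if (1==2):
    bestChar = vocabulary[int(torch.argmax(lastOutToken))]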
  print(f'_{bestChar:1s}', end="")
  inLetters.append(bestChar)
  inActivity = inActivity.roll(-1,0)                 # cyclic left shift
  inActivity[nContext-1] = letterEmbedding[bestChar]
#
print()
print(''.join(inLetters), end="")
print()
allLayers[0].printAttention()
print()
if (1==2):
  print(vocabulary)
  print(letterEmbedding)
