[Bahdanau, Cho, Bengio 2014]
#!/usr/bin/env python3
# ALiBi (Attention with Linear Biases): positional biases via broadcasting
import torch
import math
nC = 6 # context length
nHead = 2 # number of attention heads
rel_dist = torch.arange(0, nC).view(1, 1, nC) -\
torch.arange(0, nC).view(1, nC, 1)
# simplified per-head slopes; the ALiBi paper uses 2**(-8*(h+1)/nHead)
slopes = torch.tensor([1.0/(2.0**(h*1.0/nHead)) for h in range(nHead)])
biases = -slopes.view(nHead, 1, 1) * rel_dist.abs()
ALiBi_tensor = biases.exp() # exp(bias) acts multiplicatively on exp(logits)
print(rel_dist)
print(slopes)
print(biases)
print(ALiBi_tensor)
print()
print("# === testing ===")
print()
test = torch.ones(nHead,nC,nC)
print(test)
print(ALiBi_tensor*test)
print()
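The biases enter the attention logits additively before the softmax; because exp(logits + bias) = exp(logits)*exp(bias), the exponentiated tensor above may equivalently multiply the unnormalized attention weights. A minimal sketch of both routes, reusing biases and ALiBi_tensor from the listing above; the random logits are illustrative only:
# --- sketch: applying ALiBi to causal attention logits ---
logits = torch.randn(nHead, nC, nC)                # illustrative logits
causal = torch.triu(torch.ones(nC, nC)*float('-inf'), diagonal=1)
# route 1: add the biases to the logits before the softmax
alpha_add = torch.softmax(logits + biases + causal, dim=-1)
# route 2: multiply exp(biases) into the unnormalized weights
weights = torch.exp(logits + causal) * ALiBi_tensor
alpha_mul = weights / weights.sum(dim=-1, keepdim=True)
print(torch.allclose(alpha_add, alpha_mul))        # True: both routes agree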
__init__()                    trainable tensors created in the constructor
outside computation graph     parameter updates within torch.no_grad()
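A minimal sketch of this pattern (the names w, x, target and the learning rate are illustrative, not taken from the listing below): scaling first and calling requires_grad_() afterwards keeps w a leaf tensor, whose gradient is filled by backward(); the update itself must run inside torch.no_grad() so that it is not recorded in the graph.
# --- sketch: manual SGD on a leaf tensor ---
import torch
w = 0.1*torch.randn(3, 3)     # plain tensor, created outside any graph
w.requires_grad_(True)        # flagged as trainable only after scaling
x = torch.randn(3)
target = torch.randn(3)
loss = torch.square(w @ x - target).sum()
loss.backward()               # gradient accumulates in w.grad
with torch.no_grad():         # update outside the computation graph
    w -= 1.0e-3*w.grad
w.grad = None                 # clear before the next backward pass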
#!/usr/bin/env python3
# basic attention layer
# no batch processing, no layer norm
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
nLayer = 24 # number of layers
nContext = 24 # context length
dim = 2 # embedding dimension
nIter = 4000 # training iterations
learning_rate = 1.0e-3
# ==========================
# transformer bilayer module
# attention plus FFN
# ==========================
class transformerBilayer(torch.nn.Module):
def __init__(self, dim, nContext):
"""Q/K/V matrices will be broadcasted along the
context dimension
"""
super().__init__()
self.alpha = torch.zeros(nContext,nContext)
self.nContext = nContext # context length
self.dim = dim # embedding dimension
self.hidden = 4 # hidden layer expanded size
#
# Q/K/V matrices: requires_grad set after scaling
#
mySigma = 0.1/(dim*dim)
self.Q_mat = mySigma*torch.randn(1,dim,dim)
self.K_mat = mySigma*torch.randn(1,dim,dim)
self.V_mat = mySigma*torch.randn(1,dim,dim)
self.Q_mat.requires_grad_(True)
self.K_mat.requires_grad_(True)
self.V_mat.requires_grad_(True)
#
# FFN: feed forward network
# FFN_1 --> hidden --> FFN_2
#
self.FFN_1 = torch.randn(1,self.hidden*dim,dim)*mySigma
self.FFN_2 = torch.randn(1,dim,self.hidden*dim)*mySigma
self.FFN_b = torch.randn(1,self.hidden*dim,requires_grad=True)
self.FFN_1.requires_grad_(True)
self.FFN_2.requires_grad_(True)
#
# additive mask for causal attention
#
self.paddingMask = torch.zeros(nContext,nContext) # for masking
for ii in range(nContext):
for jj in range(ii+1,nContext):
self.paddingMask[ii][jj] = -1.0e9 # exp -> 0
def FFN(self, x):
"""FFN sublayer
"""
# broadcasting (1,.. ,.. ) --> (context,.. ,.. )
shape_1 = (self.nContext,self.hidden*self.dim,self.dim)
shape_2 = (self.nContext,self.dim,self.hidden*self.dim)
shape_b = (self.nContext,self.hidden*self.dim)
FFN_1_all = self.FFN_1.expand(shape_1)
FFN_2_all = self.FFN_2.expand(shape_2)
FFN_b_all = self.FFN_b.expand(shape_b)
#
xx = torch.einsum("chd,cd->ch",FFN_1_all,x)
hh = torch.tanh(xx+FFN_b_all)
yy = torch.einsum("cdh,ch->cd",FFN_2_all,hh)
return yy + x
def attention(self, x):
"""attention sublayer
"""
# broadcasting (1,dim,dim) --> (context,dim,dim)
expanded_shape = (self.nContext,self.dim,self.dim)
Q_all = self.Q_mat.expand(expanded_shape)
K_all = self.K_mat.expand(expanded_shape)
V_all = self.V_mat.expand(expanded_shape)
# c and C: context (input length)
# d and D: dim (embedding dimension)
QQ = torch.einsum("cdD,cD->cd",Q_all,x)
KK = torch.einsum("cdD,cD->cd",K_all,x)
VV = torch.einsum("cdD,cD->cd",V_all,x)
# normalized attention matrix: exp plus row normalization == softmax
# (no 1/sqrt(dim) scaling of the logits in this minimal example)
logits = torch.einsum("cd,Cd->cC",QQ,KK)
alpha = torch.exp(logits+self.paddingMask)
row_sum = alpha.sum(dim=-1, keepdim=True)
alpha = alpha / row_sum
# store detached attention matrix
self.alpha = alpha.detach()
# return with skip connections
return torch.matmul(alpha,VV) + x
def forward(self, x):
"""indexing always from the back
"""
x = self.attention(x)
x = self.FFN(x)
return x
def update(self, eps): # updating parameters
with torch.no_grad():
self.Q_mat -= eps*self.Q_mat.grad
self.K_mat -= eps*self.K_mat.grad
self.V_mat -= eps*self.V_mat.grad
self.Q_mat.grad = None
self.K_mat.grad = None
self.V_mat.grad = None
self.FFN_1 -= eps*self.FFN_1.grad
self.FFN_2 -= eps*self.FFN_2.grad
self.FFN_b -= eps*self.FFN_b.grad
self.FFN_1.grad = None
self.FFN_2.grad = None
self.FFN_b.grad = None
# ======================
# n-identical layer model
# ======================
allLayers = [transformerBilayer(dim,nContext) for iL in range(nLayer)]
def model(x):
for iLayer in range(nLayer):
x = allLayers[iLayer](x)
return x
# ====================================
# console printing of attention matrix
# ====================================
def printAttentionMatrix():
for iLayer in range(nLayer):
print()
print("# attention matrix for layer ", iLayer)
for ss in range(nContext):
for tt in range(nContext):
alpha = allLayers[iLayer].alpha[ss][tt]
print(f'{alpha:9.4f}', end="")
print()
# ======================
# standard loss function
# ======================
def lossFunction(outputActivity, targetActivity):
return torch.square(outputActivity - targetActivity).sum()
# ================================================
# training data, token at position i: x[i]
# x[i+1] = x[i] - x[i-1] (modulo normalization)
# settles into a limit cycle of period six
# initial warm-up is discarded
# ================================================
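# why period six (sketch): dropping the normalization, the linear
# recursion x[i+1] = x[i] - x[i-1] has characteristic roots
# exp(+-i*pi/3), so x[i+6] = x[i]; with the per-step normalization
# the sequence is observed to settle into the corresponding six-cycle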
def trainingSequence(seqLength, myDim):
"""generates a random vector difference sequence"""
data = torch.zeros(2*seqLength,myDim)
data[0] = F.normalize(torch.randn(myDim), p=2, dim=0)
data[1] = F.normalize(torch.randn(myDim), p=2, dim=0)
for ss in range(2,2*seqLength):
vector_sum = data[ss-1]-data[ss-2]
data[ss] = F.normalize(vector_sum, p=2, dim=0)
return data[seqLength:] # discard warm-up phase
if False: # set to True to inspect a sample training sequence
testLength = 20
training_data = trainingSequence(testLength,dim)
print("# --- training_data ---")
for ll in range(testLength):
for dd in range(dim):
print(f'{training_data[ll][dd]:8.3f}',end="")
print()
# ===========================================
# training with random sequences
# next-token prediction: targets are the inputs shifted by one
# ===========================================
print("# --- training ---")
for iIter in range(nIter):
training_all = trainingSequence(nContext+1,dim)
training_data = training_all[:-1]
training_value = training_all[1:]
loss = lossFunction(model(training_data),training_value)
if (loss<0.001):
break
loss.backward()
#
for iLayer in range(nLayer):
allLayers[iLayer].update(learning_rate)
if (iIter%200==0): # print progress
print(f'{iIter:4d} {loss.item():9.4f}')
#
# visual testing
#
test_all = trainingSequence(nContext+1,dim)
test_data = test_all[:-1]
test_value = test_all[1:]
yy = model(test_data)
print("# --- value vs. output ---")
for ll in range(nContext):
for dd in range(dim):
print(f'{test_value[ll][dd]:8.3f}',end="")
print(" |",end="")
for dd in range(dim):
print(f'{ yy[ll][dd]:8.3f}',end="")
print()
#
# print attention matrix
#
if False: # set to True to print the attention matrices
print()
printAttentionMatrix()
nn.MultiheadAttention()            the name says it all
nn.Sequential()                    sequence of modules
nn.LayerNorm()                     layer normalization
torch.nn.utils.clip_grad_norm_()   gradient clipping
self.parameters()                  all registered parameters
self.register_buffer()             non-trainable tensors stored with the module
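The listing below registers its causal mask with register_buffer(); a minimal sketch of the buffer/parameter distinction (the class Demo and its members are illustrative): buffers are saved in the state dict and follow the module across devices, but they do not appear in parameters() and are therefore never updated.
# --- sketch: parameters vs. registered buffers ---
import torch
import torch.nn as nn
class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(2, 2)                      # trainable parameters
        self.register_buffer("mask", torch.ones(2, 2))  # fixed companion tensor
demo = Demo()
print(sum(p.numel() for p in demo.parameters()))  # 6: only the Linear weight/bias
print(list(demo.state_dict().keys()))             # contains 'mask' as well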
#!/usr/bin/env python3
# attention layer built from torch.nn modules
# no batch processing
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
nLayer = 24 # number of layers
nContext = 24 # context length
dim = 2 # embedding dimension
nIter = 4000 # training iterations
learning_rate = 1.0e-3
# ==========================
# transformer block module
# using torch.nn modules
# ==========================
class transformerBlock(nn.Module):
def __init__(self, dim, nContext, hidden_mult=4):
super().__init__()
self.nContext = nContext
# LayerNorms for stability
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
# single-head attention
self.attn = nn.MultiheadAttention(embed_dim=dim,
num_heads=1, batch_first=True)
# feed-forward network
self.ffn = nn.Sequential(
nn.Linear(dim, hidden_mult * dim),
nn.Tanh(),
nn.Linear(hidden_mult * dim, dim)
)
# causal mask
mask = torch.triu(torch.ones(nContext, nContext)\
* float('-inf'), diagonal=1)
self.register_buffer("causal_mask", mask)
# store last attention weights for inspection
self.alpha = torch.zeros(nContext, nContext)
def forward(self, x):
# (nContext, dim) → (1, nContext, dim)
x = x.unsqueeze(0)
# attention with layer norm and skip connections
x_norm = self.norm1(x)
attn_out, attn_weights = self.attn(
x_norm, x_norm, x_norm,
attn_mask=self.causal_mask
)
self.alpha = attn_weights[0].detach()
x = x + attn_out
# feed-forward with layer normalization
ff_out = self.ffn(self.norm2(x))
x = x + ff_out
return x.squeeze(0) # (nContext, dim)
def update(self, eps):
"""manual SGD update with gradient clipping."""
with torch.no_grad():
torch.nn.utils.clip_grad_norm_(self.parameters(),
max_norm=1.0)
for p in self.parameters():
if p.grad is not None:
p -= eps * p.grad
p.grad = None
# ======================
# n-identical layer model
# ======================
allLayers = [transformerBlock(dim,nContext) for iL in range(nLayer)]
def model(x):
for iLayer in range(nLayer):
x = allLayers[iLayer](x)
return x
# ====================================
# console printing of attention matrix
# ====================================
def printAttentionMatrix():
for iLayer in range(nLayer):
print()
print("# attention matrix for layer ", iLayer)
for ss in range(nContext):
for tt in range(nContext):
alpha = allLayers[iLayer].alpha[ss][tt]
print(f'{alpha:9.4f}', end="")
print()
# ======================
# standard loss function
# ======================
def lossFunction(outputActivity, targetActivity):
return torch.square(outputActivity - targetActivity).sum()
# ================================================
# training data, token at position i: x[i]
# x[i+1] = x[i] - x[i-1] (modulo normalization)
# settles into a limit cycle of period six
# initial warm-up is discarded
# ================================================
def trainingSequence(seqLength, myDim):
"""generates a random vector difference sequence"""
data = torch.zeros(2*seqLength,myDim)
data[0] = F.normalize(torch.randn(myDim), p=2, dim=0)
data[1] = F.normalize(torch.randn(myDim), p=2, dim=0)
for ss in range(2,2*seqLength):
vector_sum = data[ss-1]-data[ss-2]
data[ss] = F.normalize(vector_sum, p=2, dim=0)
return data[seqLength:] # discard warm-up phase
if False: # set to True to inspect a sample training sequence
testLength = 20
training_data = trainingSequence(testLength,dim)
print("# --- training_data ---")
for ll in range(testLength):
for dd in range(dim):
print(f'{training_data[ll][dd]:8.3f}',end="")
print()
# ===========================================
# training with random sequences
# next-token prediction: targets are the inputs shifted by one
# ===========================================
print("# --- training ---")
for iIter in range(nIter):
training_all = trainingSequence(nContext+1,dim)
training_data = training_all[:-1]
training_value = training_all[1:]
loss = lossFunction(model(training_data),training_value)
if (loss<0.001):
break
loss.backward()
#
for iLayer in range(nLayer):
allLayers[iLayer].update(learning_rate)
if (iIter%200==0): # print progress
print(f'{iIter:4d} {loss.item():9.4f}')
#
# visual testing
#
test_all = trainingSequence(nContext+1,dim)
test_data = test_all[:-1]
test_value = test_all[1:]
yy = model(test_data)
print("# --- value vs. output ---")
for ll in range(nContext):
for dd in range(dim):
print(f'{test_value[ll][dd]:8.3f}',end="")
print(" |",end="")
for dd in range(dim):
print(f'{ yy[ll][dd]:8.3f}',end="")
print()
#
# print attention matrix
#
if False: # set to True to print the attention matrices
print()
printAttentionMatrix()
nn.ModuleList()      registers a list of sub-modules
torch.optim.Adam()   the name says it all
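A minimal sketch of why the layer stack below is wrapped in nn.ModuleList rather than a plain Python list (the class names are illustrative): only registered sub-modules contribute to model.parameters(), which is what torch.optim.Adam receives.
# --- sketch: plain list vs. nn.ModuleList ---
import torch.nn as nn
class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(2, 2) for _ in range(3)]        # not registered
class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(2, 2) for _ in range(3))
print(sum(p.numel() for p in PlainList().parameters()))          # 0
print(sum(p.numel() for p in WithModuleList().parameters()))     # 18 = 3*(4+2)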
#!/usr/bin/env python3
# basic transformer with batch processing and PyTorch optimizer
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
nLayer = 24 # number of layers
nContext = 24 # context length
dim = 2 # embedding dimension
batch_size = 32 # batch size for training
nIter = 4000 # training iterations
learning_rate = 1.0e-3
# ==========================
# transformer block module
# enhanced for batch processing
# ==========================
class transformerBlock(nn.Module):
def __init__(self, dim, nContext, hidden_mult=4):
super().__init__()
self.nContext = nContext
# layer norm
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
# single-head attention (num_heads=1)
self.attn = nn.MultiheadAttention(embed_dim=dim,
num_heads=1, batch_first=True)
# FFN
self.ffn = nn.Sequential(
nn.Linear(dim, hidden_mult * dim),
nn.Tanh(),
nn.Linear(hidden_mult * dim, dim)
)
# causal mask
mask = torch.triu(torch.ones(nContext, nContext) * float('-inf'),
diagonal=1)
self.register_buffer("causal_mask", mask)
def forward(self, x):
# x shape: (batch_size, nContext, dim)
# batch_size = x.size(0)
x_norm = self.norm1(x)
attn_out, attn_weights = self.attn(
x_norm, x_norm, x_norm,
attn_mask=self.causal_mask
)
x = x + attn_out
ff_out = self.ffn(self.norm2(x))
x = x + ff_out
return x
# ======================
# complete model class
# ======================
class TransformerModel(nn.Module):
def __init__(self, dim, nContext, nLayer):
super().__init__()
self.layers = nn.ModuleList([
transformerBlock(dim, nContext) for _ in range(nLayer)
])
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
# initialize model
model = TransformerModel(dim, nContext, nLayer)
# initialize optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# ======================
# loss function
# ======================
def lossFunction(outputActivity, targetActivity):
return torch.square(outputActivity - targetActivity).sum()
# ================================================
# training data generation
# x[i+1] = x[i] - x[i-1] (modulo normalization)
# settles into a limit cycle of period six
# initial warm-up is discarded
# ================================================
def trainingSequence(seqLength, myDim):
"""generates a random vector difference sequence"""
data = torch.zeros(2*seqLength, myDim)
data[0] = F.normalize(torch.randn(myDim), p=2, dim=0)
data[1] = F.normalize(torch.randn(myDim), p=2, dim=0)
for ss in range(2, 2*seqLength):
vector_sum = data[ss-1] - data[ss-2]
data[ss] = F.normalize(vector_sum, p=2, dim=0)
return data[seqLength:] # discard warm-up phase
def generateBatch(batch_size, nContext, dim):
"""generate a batch of training sequences"""
batch_data = torch.zeros(batch_size, nContext, dim)
batch_targets = torch.zeros(batch_size, nContext, dim)
for b in range(batch_size):
training_all = trainingSequence(nContext + 1, dim)
batch_data[b] = training_all[:-1]
batch_targets[b] = training_all[1:]
return batch_data, batch_targets
# ===========================================
# training with batched random sequences
# ===========================================
print("# --- training with batches ---")
for iIter in range(nIter):
# generate batch / forward pass / loss
training_data, training_targets = generateBatch(batch_size, nContext, dim)
outputs = model(training_data)
loss = lossFunction(outputs, training_targets) / batch_size
if loss < 0.001:
break
# backward pass, gradient clipping and update
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
# progress printing
if iIter % 200 == 0:
print(f'{iIter:4d} {loss.item():9.4f}')
# ===========================================
# visual testing with single sequence
# ===========================================
print("# --- testing ---")
test_all = trainingSequence(nContext + 1, dim)
test_data = test_all[:-1].unsqueeze(0) # Add batch dimension
test_value = test_all[1:]
with torch.no_grad():
yy = model(test_data).squeeze(0) # Remove batch dimension
print("# --- value vs. output ---")
for ll in range(nContext):
for dd in range(dim):
print(f'{test_value[ll][dd]:8.3f}', end="")
print(" |", end="")
for dd in range(dim):
print(f'{yy[ll][dd]:8.3f}', end="")
print()