#!/usr/bin/env python3

# ALiBi positional biases via broadcasting (Press et al., "Train Short, Test Long"):
# each head adds a linear penalty -slope * |i - j| to its attention scores.

import torch

nC = 6       # context length
nHead = 2    # number of attention heads

# rel_dist[0, i, j] = j - i; broadcasting the two views gives shape (1, nC, nC)
rel_dist = torch.arange(nC).view(1, 1, nC) - \
           torch.arange(nC).view(1, nC, 1)
# ALiBi slopes: the paper's geometric sequence 2**(-8*(h+1)/nHead) (nHead a power of two)
slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / nHead) for h in range(nHead)])
# Per-head additive bias -slope * |i - j|, broadcast to shape (nHead, nC, nC)
biases = -slopes.view(nHead, 1, 1) * rel_dist.abs()
# exp(score + bias) = exp(score) * exp(bias), so the exponentiated bias multiplies exponentiated scores
ALiBi_tensor = biases.exp()
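
# Sketch of the additive form used in the ALiBi paper (illustrative only; d_head,
# q and k are made-up values, not part of the original script): the biases are
# simply added to the scaled dot-product scores before the softmax.
d_head = 4
q = torch.randn(nHead, nC, d_head)
k = torch.randn(nHead, nC, d_head)
attn = torch.softmax(q @ k.transpose(-2, -1) / d_head ** 0.5 + biases, dim=-1)
# attn has shape (nHead, nC, nC); nearby positions get relatively more weight.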

print("rel_dist:", rel_dist)
print("slopes:", slopes)
print("biases:", biases)
print("ALiBi_tensor:", ALiBi_tensor)
print()
print("# === testing ===")
print()
# All-ones tensor stands in for exponentiated attention scores; the product
# just shows the per-head bias pattern broadcast over every (i, j) pair.
test = torch.ones(nHead, nC, nC)
print(test)
print(ALiBi_tensor * test)
print()
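
# Consistency check (a sketch with made-up random scores, not from the original):
# adding `biases` before softmax is equivalent to multiplying exponentiated
# scores by ALiBi_tensor and renormalizing.
scores = torch.randn(nHead, nC, nC)
additive = torch.softmax(scores + biases, dim=-1)
weighted = scores.exp() * ALiBi_tensor
multiplicative = weighted / weighted.sum(dim=-1, keepdim=True)
print(torch.allclose(additive, multiplicative, atol=1e-6))   # expect: True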
