matrices: NumPy   |   tensors: PyTorch
#!/usr/bin/env python3
import torch                     # PyTorch instead of NumPy
import math
import matplotlib.pyplot as plt
myType   = torch.float
myDevice = torch.device("cpu")   # "cuda:0" for GPU; not activated
# global parameters
nData = 2000                     # number of training pairs
nIter = 2000                     # number of training iterations
nPar  =    4                     # number of fit parameters
learning_rate = 0.5e-2/nData     # relative learning rate
fitPar = []                      # empty list; fit parameters
for i in range(nPar):            # randn() : normal distribution
  fitPar.append(torch.randn((), device=myDevice, dtype=myType))
print(fitPar)
def fitFunction(x):              # polynomial fitting function 
  sum = 0.0
  for i in range(nPar):
    sum += fitPar[i]*(x**i)
  return sum
# linspace returns a 1D tensor
x = torch.linspace(-math.pi, math.pi, nData, device=myDevice, dtype=myType)
y = torch.sin(x)                 # target function y = sin(x)
# training iteration
for iIter in range(nIter):
  y_pred = fitFunction(x)                  # tensor; element-wise
  loss   = torch.square(y_pred - y).sum()  # sum of squared elements
  if iIter % 100 == 99:                    # test printout
    print(f'{iIter:5d}  {loss:10.6f}')
  grad_y_pred = 2.0 * (y_pred - y)         # error signal
  for i in range(nPar):                    # least-square fit
    gradient = ( grad_y_pred*(x**i) ).sum()
    fitPar[i] -= learning_rate * gradient
# showing result
plt.plot(x, torch.sin(x)             , 'b', label="sin(x)")
plt.plot(x, fitFunction(x)           , 'r', label="polynomial fit")
plt.plot(x, 0.0*x                    , '--k')
plt.legend()
plt.show()
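Side note, not part of the original listing: NumPy arrays and PyTorch tensors convert into each other; torch.from_numpy() shares memory with the source array. A minimal sketch:
#!/usr/bin/env python3
# minimal sketch: moving data between NumPy and PyTorch
import numpy as np
import torch
a_np = np.linspace(0.0, 1.0, 5)     # NumPy array, float64
a_pt = torch.from_numpy(a_np)       # tensor sharing memory with a_np
b_np = a_pt.numpy()                 # back to NumPy, still shared
a_np[0] = -1.0                      # change shows up in a_pt and b_np
print(a_pt)
print(torch.tensor(a_np))           # independent copy, no sharing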
np.set_printoptions()   formatted printing via a lambda function
#!/usr/bin/env python3
import torch
import numpy as np
# on output, substitute  x  ->  f'{x:4.2f}'
np.set_printoptions(formatter=\
   {'float_kind':lambda x: f"{x:4.2f}"})
x = torch.rand(4, 5)
print("# non formatted ")
print(x)
print("\n# formatted with np ")
print(x.numpy())
# print(x.detach().numpy())     # for parameters
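PyTorch also has native print formatting; a minimal sketch using torch.set_printoptions() instead of the NumPy detour:
#!/usr/bin/env python3
# minimal sketch: PyTorch-native print formatting
import torch
torch.set_printoptions(precision=2, sci_mode=False)
x = torch.rand(4, 5)
print(x)                         # two digits after the decimal point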
tensor.shape       tuple (dim1, dim2, ...) of dimensions
aa[i][j]           accessing a tensor; or aa[tt] via a tuple tt = (i, j)
aa.view(tt)        tensor reshaping
torch.arange(A)    integers 0, 1, ..., A-1
#!/usr/bin/env python3
#
# tuples can be used to access tensors
# with previously unknown dimensions
#
import torch
import math
import numpy as np
import matplotlib.pyplot as plt
def generateTensor(D=2, N=3):
  "of shape (N,N,..); with D dimensions"
  dimTuple = ()                            # empty tuple
  nElements = 1                            # of tensor
  for _ in range(D):
    dimTuple = dimTuple + (N,)             # add to tuple
    nElements *= N
#
  print("genTen: dimTuple :", dimTuple)
  print("genTen: nElements:", nElements)
  return torch.arange(nElements, dtype=torch.int).view(dimTuple)
def doSomething(inTensor):
  """with a tensor of arbitrary shape;
     changing one random element,
     assuming shape (N_1, N_2, ...) with N_i == N
  """
  inShape = inTensor.shape
  D = len(inShape)
  N = inShape[0]                  # assuming (N, N, ..)
  accessTuple = ()       
  for _ in range(D):
    rrInt = np.random.randint(N)
    accessTuple += (rrInt,)
  inTensor[accessTuple] = -1      # access via tuple
#
  print("doSome:  inShape :", inShape)
  print("doSome:     D, N :", D, N)
  print("doSome:  accTuple:", accessTuple)
  return inTensor
#
# main
#
myTensor = generateTensor()
print("  main:    shape :", myTensor.shape)
print(myTensor)
outTensor = doSomething(myTensor)
print(outTensor)
requires_grad = True
with torch.no_grad():    detach (temporarily) tensors for 'by hand'
                         operations, as if requires_grad = False
#!/usr/bin/env python3
import torch                     # PyTorch needs to be installed
dim = 2
eps = 0.1
x = torch.ones(dim, requires_grad=True)  # leaf of computational graph
print("x           : ",x)
print("x.data      : ",x.data)
print("x[0]        : ",x[0])
print("x[0].item() : ",x[0].item())
print()
y = x + 2
out = torch.dot(y,y)             # scalar product
print("y      : ",y)
print("out    : ",out)
print()
out.backward()                   # backward pass --> gradients
print("x.grad : ",x.grad)
with torch.no_grad():            # detach from computational graph
  x -= eps*x.grad                # updating parameter tensor
  x.grad = None                  # flush
print("x      : ",x.data)
print("\n#---")
print("#--- .backward() adds new gradient to old gradient")
print("#---             convenient for batch updating")
print("#---\n")
y = torch.zeros(dim, requires_grad=True)  
torch.dot(y+1,y+1).backward()
print("y.grad : ",y.grad)
torch.dot(y+1,y+1).backward()
print("y.grad : ",y.grad)
torch.dot(y+1,y+1).backward()
print("y.grad : ",y.grad)
torch.dot(y+1,y+1).backward()
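Since .backward() accumulates, gradients have to be flushed between updates; a minimal sketch, not from the original listing:
#!/usr/bin/env python3
# minimal sketch: clearing accumulated gradients between updates
import torch
y = torch.zeros(2, requires_grad=True)
for _ in range(3):
  torch.dot(y+1, y+1).backward()
  print("y.grad : ", y.grad)       # stays at tensor([2., 2.])
  y.grad = None                    # flush; y.grad.zero_() is an alternative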
#!/usr/bin/env python3
import torch                   
import math
import matplotlib.pyplot as plt
myType   = torch.float
myDevice = (
    "cuda"                       # for GPUs
    if torch.cuda.is_available()
    else "mps"                   # 'MultiProcessor Specification'
    if torch.backends.mps.is_available()
    else "cpu"                   # plain old CPU
)
# global parameters
nData = 2000                     # number of training pairs
nIter = 2000                     # number of training iterations
nPar  =    4                     # number of fit parameters
learning_rate = 0.5e-2/nData    
# gradients with respect to fitPar[] to be evaluated
fitPar = []                      # list of scalar (0-dim) tensors
for i in range(nPar):      
  fitPar.append(torch.randn((), device=myDevice, dtype=myType,\
                                requires_grad=True))  
print(fitPar)
x = torch.linspace(-math.pi, math.pi, nData, device=myDevice, dtype=myType)
y = torch.sin(x)                 # element-wise
def fitFunction(x):              # polynomial fitting function
  sum = 0.0
  for i in range(nPar):
    sum += fitPar[i]*(x**i)      # element-wise, x is a tensor
  return sum                     # returns a tensor
# training iteration
for iIter in range(nIter):
  y_pred     = fitFunction(x)              # forward pass
  lossTensor = (y_pred - y).pow(2).sum()   # element-wise pow(2)
  if iIter % 100 == 99:                    # print scalar loss value
    print(f'{iIter:5d}  {lossTensor.item():10.6f}')
# backward pass 
# calculates gradients, viz 'tensor.grad',
# with respect to tensors with "requires_grad=True"
  lossTensor.backward()                    
# temporarily 'detaching' all tensors for by-hand updating
# the value of   fitPar[i].grad   is not affected
  with torch.no_grad():        
    for i in range(nPar):                  # gradients via backward pass
      fitPar[i] -= learning_rate * fitPar[i].grad
      fitPar[i].grad = None
# "detach" tensors requiring gradients from computation graph
xPlot = x.cpu()                  # matplotlib needs CPU data
plt.plot(xPlot, torch.sin(xPlot)                      , 'b', label="sin(x)")
plt.plot(xPlot, fitFunction(x).detach().cpu().numpy() , 'r', label="fit")
plt.plot(xPlot, 0.0*xPlot                             , '--k')
plt.legend()
plt.show()
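The by-hand parameter updates can be delegated to an optimizer; a minimal sketch, not part of the original listings, repeating the fit with torch.optim.SGD (same setup, shortened):
#!/usr/bin/env python3
# minimal sketch: same polynomial fit, updates via torch.optim.SGD
import torch
import math
nData, nIter, nPar = 2000, 2000, 4
x = torch.linspace(-math.pi, math.pi, nData)
y = torch.sin(x)
fitPar = [torch.randn((), requires_grad=True) for _ in range(nPar)]
optimizer = torch.optim.SGD(fitPar, lr=0.5e-2/nData)
for iIter in range(nIter):
  y_pred = sum(fitPar[i]*x**i for i in range(nPar))
  loss = (y_pred - y).pow(2).sum()
  if iIter % 100 == 99:
    print(f'{iIter:5d}  {loss.item():10.6f}')
  optimizer.zero_grad()          # clear old gradients
  loss.backward()                # compute new gradients
  optimizer.step()               # fitPar[i] -= lr * fitPar[i].grad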
tensor._version      version counter
tensor.function()    default (out-of-place) version
tensor.function_()   inplace version
#!/usr/bin/env python3
#
# testing inplace operations
#
import torch
print()
x = torch.tensor([1,2,3])
print(x," x original")          # tensor([1, 2, 3])
print(x.add(1), " x.add(1)")    # tensor([2, 3, 4])
print(x," x current")           # tensor([1, 2, 3])
print(x.add_(2), " x.add_(2)")  # tensor([3, 4, 5])  
                                # inplace operation, changes the value 
                                # of x in memory
print(x," x current")           # tensor([3, 4, 5])
print()
print("=========================")
print()
b  = torch.rand(4)
print(b._version, " b._version, id ", id(b))
b += 2              # an inplace operation
# b = b + 2         # new object generated
print(b._version, " b._version, id ", id(b))
print()
print("=========================")
print()
a = torch.rand(4, requires_grad=True)
c = a.pow(1)                   # out-of-place: works
#c = a.pow_(1)                 # inplace on a leaf requiring grad raises an error
c.sum().backward()
print("a       ", a)
print("a.grad  ", a.grad)
# 'detach' from computation graph
# forward  :  y = a - b + b = a
# backward :  gradient flows only through '+ b', as if y = const + b
y = (a-b).detach() + b
#!/usr/bin/env python3
import torch                   
import math
import matplotlib.pyplot as plt
# global parameters
nData =   200   
nIter =    30 
eps   =   0.01
fitPar = (1.5*torch.randn(1)).requires_grad_()   # leaf tensor
x = torch.linspace(-5.0, 5.0, nData)
def lossFunction(x, fitPar): 
  """fitting input 'y' with a inverse step function,
     T = Theta(fitpar-x) 
     using boolean -> float
     S : surrogate gradient
  """
  y = 1.0/(1.0+torch.exp(x))              # target
  T = (x < fitPar).float()
  S = 1.0/(1.0+torch.exp(4*(x-fitPar)))   # surrogate
  out = S + (T - S).detach()   # straight-through estimator (by hand)
  return (y-out).pow(2).sum()  # sum of squared errors
# training iteration
for iIter in range(nIter):
  loss = lossFunction(x, fitPar)
  print(f" {fitPar.item():10.4f} {loss.item():10.4f}")
  loss.backward()                    
  with torch.no_grad():        
    if fitPar.grad is not None:
      fitPar     -= eps*fitPar.grad
      fitPar.grad = None
  fitPar = fitPar.detach().requires_grad_()
# ^ re-insert into computation graph 
# plotting
y = 1.0/(1.0+torch.exp(x))              # target
T = (x < fitPar).float()
S = 1.0/(1.0+torch.exp(4*(x-fitPar)))   # surrogate
plt.plot(x, y                  , 'b'  , label="target")
plt.plot(x, T.detach().numpy() , 'r'  , label="fit")
plt.plot(x, S.detach().numpy() , '--k', label="surrogate")
plt.legend()
plt.show()
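The by-hand straight-through estimator can also be wrapped in a custom torch.autograd.Function; a minimal sketch (the class name StepSTE is made up here):
#!/usr/bin/env python3
# minimal sketch: straight-through estimator as a custom autograd.Function
import torch

class StepSTE(torch.autograd.Function):
  @staticmethod
  def forward(ctx, x):             # hard step in the forward pass
    return (x > 0.0).float()
  @staticmethod
  def backward(ctx, grad_output):  # pass the gradient straight through
    return grad_output

x = torch.linspace(-2.0, 2.0, 5, requires_grad=True)
y = StepSTE.apply(x)
y.sum().backward()
print(y)                           # 0/1 step values
print(x.grad)                      # all ones: identity surrogate gradient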
reshape, flatten, squeeze, unsqueeze, view
     unsqueeze   allows one to add a vector to each column of a matrix
                 via broadcasting
#!/usr/bin/env python3
import torch
import math
import matplotlib.pyplot as plt
uu = torch.ones(3)
xx = torch.ones(3,5)
print("\n# uu")
print(uu)
print("\n# xx")
print(xx)
print("\n# uu unsqueezed")
print(torch.unsqueeze(uu, 1))
print("\n# xx plus uu unsqueezed along 1")
print(xx+torch.unsqueeze(uu, 1))
#
aa = torch.arange(4.0)
print("\n# arranged\n", aa)
aa = torch.reshape(aa, (2, 2))
print("\n# reshaped\n", aa)
aa = torch.reshape(aa, (-1,))
print("\n#   ..back\n", aa)
tensor.mean()   tensor.std(dim=0, unbiased=False)
#!/usr/bin/env python3
import torch
# input data of shape [4, 3]
# batch of 4 samples, each with 3 features
x = torch.arange(1.0, 13.0).reshape(4, 3)
# x = torch.tensor([ [ 1.0,  2.0,  3.0],
#                    [ 4.0,  5.0,  6.0],
#                    [ 7.0,  8.0,  9.0],
#                    [10.0, 11.0, 12.0] ])  
x_mean = x.mean(dim=0)
x_std = x.std(dim=0, unbiased=False)
# each feature (column) should be normalized, with 
# respect to a given mean and standard deviation
feature_means = torch.tensor([2.0, 4.0, 6.0])
feature_stds  = torch.tensor([1.0, 2.0, 4.0])
# automatic broadcasting: 
# [4, 3] - [3] -> [4, 3]
x_normalized = feature_means\
             + (x - x_mean)*feature_stds/x_std
print("# input")
print(x)
print("\n# input mean along columns")
print(x_mean)
print("\n# input mean along rows")
print(x.mean(dim=1))
print("\n# feature means, the target")
print(feature_means)
print("\n# feature standard deviations:")
print(feature_stds)
print("\n# normalized output")
print(x_normalized)
print("\n# output mean")
print(x_normalized.mean(dim=0))
print("\n# output standard deviation")
print(x_normalized.std(dim=0, unbiased=False))
tensor.unsqueeze()   tensor.expand()
#!/usr/bin/env python3
#
# calculating the pairwise differences
# of several vectors
#
import torch
# batch of 3 vectors, each of dim 2; 
x32 = torch.tensor([ [1, 2],
                     [3, 4],
                     [5, 6] ])  
print("# original tensor")
print(x32)
print()
# shape: [3, 2]
N, D = x32.shape
# unsqueeze to get shapes [3, 1, 2] and [1, 3, 2]
x312 = x32.unsqueeze(1)  
y132 = x32.unsqueeze(0) 
# expand both tensors to [3, 3, 2] 
x332 = x312.expand(N, N, D)
y332 = y132.expand(N, N, D)
# tensors of identical shapes can be subtracted directly
diffs = x332 - y332
print("# pairwise differences")
print(diffs)
#print(x332)
#print(y332)
print()
print("# shapes")
print("  x32   ",  x32.shape)
print("  x312  ", x312.shape)
print("  y132  ", y132.shape)
print()
print("  x332  ",  x332.shape)
print("  y332  ",  y332.shape)
print("  diffs ", diffs.shape)
zip()   combine Python lists for a list comprehension
#!/usr/bin/env python3
import torch
import random
aa = torch.arange(10.0).reshape(2,5)
bb = torch.arange(5) + 1.0
cc = torch.divide(aa,bb)
print("aa\n", aa)
print("bb\n", bb)
print("cc\n", cc)
print()
print("# ======================")
print("# row-wise normalization")
print("# ======================")
aaSum   = aa.sum(1) 
# aaSum = aa.pow(2).sum(1) 
aaTrans = torch.divide(aa.transpose(0,1),aaSum)
aaNorm  = torch.transpose(aaTrans,0,1)
print("asSum\n"  , aaSum)
print("asTrans\n", aaTrans)
print("aaNorm\n" , aaNorm)
print()
print("# =====================================")
print("# dividing lists, checking denominator ")
print("# =====================================")
aa_list = [random.randint(0, 10) for _ in range(5)]  
bb_list = [random.choice(range(10)) for _ in range(5)]
cc_list = [a / b if b != 0 else -11 for a, b in zip(aa_list,bb_list)]
print("aa_list\n"  , aa_list)
print("bb_list\n"  , bb_list)
print("cc_list\n"  , cc_list)
torch.matmul(A,B)   torch.bmm(A,B)   torch.einsum("abcd,aced->abed",A,B)   torch.allclose(A,B)
#!/usr/bin/env python3
import torch
# two batches of matrices, of shapes
# 4x5 and 5x6
A = torch.randn(3, 4, 5)  
B = torch.randn(3, 5, 6)  
# batch matrix multiplication using einsum
# with 'bij,bjk->bik' 
# for each batch b do
# \sum_j A[b,i,j] * B[b,j,k] -> result[b,i,k]
result_einsum = torch.einsum('bij,bjk->bik', A, B)
# equivalent to using torch.bmm (batch matrix multiply)
result_bmm = torch.bmm(A, B)
print(f"A shape: {A.shape}")
print(f"B shape: {B.shape}")
print(f"Result shape: {result_einsum.shape}")
print(f"Results are equal: {torch.allclose(result_einsum, result_bmm)}")
#!/usr/bin/env python3
import torch
source_tensor = 10*torch.arange(9)
index_tensor  = torch.arange(4,6)
result_tensor = source_tensor[index_tensor]
print()
print("# --- 1D example --- ")
print("# source, index, result ")
print(source_tensor)
print(index_tensor)
print(result_tensor)
#
# selecting (0, 2), (1, 1), (2, 0)
# from a 2D tensor
D2_source = source_tensor.view(3,3)
row_indices = torch.tensor([0, 1, 2])
col_indices = torch.tensor([2, 1, 0])
D2_result = D2_source[row_indices, col_indices]
print()
print("# --- 2D example --- ")
print("# source, indices, result ")
print(D2_source)
print(row_indices, col_indices)
print(D2_result)
# masking (boolean indexing)
D2_bool = D2_source[ (D2_source>5)&(D2_source<35) ]
print()
print("# --- boolean indexing --- ")
print(D2_bool)
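Selection along a single dimension can also be written with torch.index_select; a minimal sketch:
#!/usr/bin/env python3
# minimal sketch: torch.index_select along one dimension
import torch
D2_source = (10*torch.arange(9)).view(3, 3)
idx = torch.tensor([2, 0])
print(torch.index_select(D2_source, dim=0, index=idx))   # rows 2 and 0
print(torch.index_select(D2_source, dim=1, index=idx))   # columns 2 and 0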
torch.gather(input, dim, index)     torch.scatter(input, dim, index, src)
torch.topk(input, K, dim)           torch.transpose()   nomen est omen
#!/usr/bin/env python3
import torch
print()
print("# --- --------- ---")
print("# --- gathering ---")
print("# --- --------- ---")
source_tensor = torch.tensor([[10, 20, 30],
                              [40, 50, 60],
                              [70, 80, 90]])
index_tensor = torch.tensor([[2, 0],
                             [1, 1],
                             [0, 1]])
# result[i,j] = source_tensor[index_tensor[i,j], j]
gather_dim_0 = torch.gather(source_tensor, dim=0, index=index_tensor)
# result[i,j] = source_tensor[i, index_tensor[i,j]]
gather_dim_1 = torch.gather(source_tensor, dim=1, index=index_tensor)
print("\n# source \n", source_tensor)
print("\n# index  \n", index_tensor)
print("\n# gather along dim=0\n", gather_dim_0)
print("\n# gather along dim=1\n", gather_dim_1)
print()
print("# --- ------------ ---")
print("# --- top K values ---")
print("# --- ------------ ---")
print()
aa = torch.randint(0,10,(3,4))
# for each column, find the top 2 values across all rows
val_dim_0, idx_dim_0 = torch.topk(aa, 2, dim=0)
# for each row   , find the top 2 values across all columns
val_dim_1, idx_dim_1 = torch.topk(aa, 2, dim=1)
print("# rand int tensor\n",aa)
print()
print("# (index,value) pairs for dim=0")
for val, idx in zip(val_dim_0, idx_dim_0):
  for ii in range(val.shape[0]):
    print(idx[ii].item(), val[ii].item())
  print()
print("# (index,value) pairs for dim=1")
for val, idx in zip(val_dim_1, idx_dim_1):
  for ii in range(val.shape[0]):
    print(idx[ii].item(), val[ii].item())
  print()
print("# --- ---------- ---")
print("# --- scattering ---")
print("# --- ---------- ---")
print()
# destination tensors, shapes
# (3,4) for dim=0 and
# (4,3) for dim=1
destination_dim_0 = torch.zeros(3, 4, dtype=torch.float)
destination_dim_1 = destination_dim_0.transpose(0,1)
# source tensor, the values to scatter.
source_dim_0 = torch.tensor([[1., 2., 3., 4.],
                             [5., 6., 7., 8.]])
source_dim_1 = torch.tensor([[1., 2.],
                             [3., 4.],
                             [5., 6.]])
index_dim_0 = torch.tensor([[0, 2, 1, 2],
                            [1, 0, 2, 1]])
index_dim_1 = torch.tensor([[0, 2],
                            [1, 0],
                            [2, 1]])
# scattering along dim=0
# destination_dim_0[index_dim_0[i, j], j] = source_dim_0[i, j].
result_dim_0 = destination_dim_0.scatter(dim=0,
               index=index_dim_0, src=source_dim_0)
# scattering along dim=1
# destination_dim_0[i, index_dim_1[i, j]] = source_dim_1[i, j]
result_dim_1 = destination_dim_1.scatter(dim=1,
               index=index_dim_1, src=source_dim_1)
print("# destination_dim_0\n", destination_dim_0)
print("# index_dim_0\n", index_dim_0)
print("# source_dim_0 \n", source_dim_0)
print("# result_dim_0\n", result_dim_0)
print()
print("# --- ---------------------- ---")
print("# --- scattering along dim=1 ---")
print("# --- ---------------------- ---")
print()
print("# destination_dim_1\n", destination_dim_1)
print("# index_dim_1\n", index_dim_1)
print("# source_dim_1 \n", source_dim_1)
print("# result_dim_0\n", result_dim_0)