Machine Learning Primer -- Python Tutorial




Claudius Gros, WS 2025/26

Institut für theoretische Physik
Goethe-University Frankfurt a.M.

PyTorch Basics

PyTorch example

Copy Copy to clipboard
Download Download
#!/usr/bin/env python3

import torch                     # PyTorch instead of NumPy
import math
import matplotlib.pyplot as plt


# numeric type and device shared by every tensor of this example
myType   = torch.float
myDevice = torch.device("cpu")   # switch to "cuda:0" to run on a GPU

# global parameters
nData = 2000                     # number of (x, y) training pairs
nIter = 2000                     # number of training iterations
nPar  = 4                        # number of polynomial coefficients

# step size, scaled by 1/nData since the loss is a sum over all data points
learning_rate = 0.5e-2 / nData

# fit parameters: nPar scalar (0-dim) tensors, drawn from a normal distribution
fitPar = [torch.randn((), device=myDevice, dtype=myType)
          for _ in range(nPar)]
print(fitPar)

def fitFunction(x):              # polynomial fitting function
  """Evaluate the polynomial  sum_i fitPar[i] * x**i.

  Reads the module-level globals nPar and fitPar. For a tensor x the
  polynomial is evaluated element-wise and a tensor is returned.
  """
  # builtin sum() with start value 0.0 (the local is no longer named 'sum',
  # which used to shadow the builtin)
  return sum((fitPar[i]*(x**i) for i in range(nPar)), 0.0)

# linspace returns a 1-D tensor of nData evenly spaced points
x = torch.linspace(-math.pi, math.pi, nData, device=myDevice, dtype=myType)
y = torch.sin(x)                 # target function y = sin(x)

# training iteration: plain gradient descent with hand-coded gradients
for iIter in range(nIter):
  y_pred = fitFunction(x)                  # tensor; element-wise
  loss   = torch.square(y_pred - y).sum()  # sum of squared errors

  if iIter % 100 == 99:                    # test printout
    print(f'{iIter:5d}  {loss.item():10.6f}')

  # d loss / d fitPar[i] = sum_n 2*(y_pred_n - y_n) * x_n**i
  grad_y_pred = 2.0 * (y_pred - y)         # error signal
  for i in range(nPar):                    # least-square fit
    gradient = ( grad_y_pred*(x**i) ).sum()
    fitPar[i] -= learning_rate * gradient

# showing result; copied to the CPU so that plotting also works
# when myDevice is set to a GPU
xPlot = x.cpu()
plt.plot(xPlot, torch.sin(xPlot)             , 'b', label="sin(x)")
plt.plot(xPlot, fitFunction(x).cpu()         , 'r', label="polynomial fit")
plt.plot(xPlot, 0.0*xPlot                    , '--k')
plt.legend()
plt.show()

printing PyTorch tensors

Copy Copy to clipboard
Download Download
#!/usr/bin/env python3

import torch
import numpy as np

def twoDecimals(val):
  """Format one float with two decimals, width 4 (numpy print formatter)."""
  return f"{val:4.2f}"

# numpy arrays print every float through twoDecimals(); torch tensors do not
np.set_printoptions(formatter={'float_kind': twoDecimals})

x = torch.rand(4, 5)

print("# non formatted ")
print(x)                          # torch's own tensor printing

print("\n# formatted with np ")
print(x.numpy())                  # numpy printing, uses the formatter

# tensors with requires_grad=True must be detached before .numpy():
# print(x.detach().numpy())

tensor shapes

Copy Copy to clipboard
Download Download
#!/usr/bin/env python3

#
# tuples can be used to access tensors
# with previously unknown dimensions
#

import torch
import math
import numpy as np
import matplotlib.pyplot as plt


def generateTensor(D=2, N=3):
  """Return a tensor of shape (N, N, ..., N) with D axes.

  The entries are the integers 0 .. N**D - 1 (dtype torch.int) in
  row-major order. The shape tuple and element count are printed.
  """
  shape = (N,) * D                 # tuple built by repetition, e.g. (3, 3)
  count = N ** len(shape)          # number of entries of the tensor
  print("genTen: dimTuple :", shape)
  print("genTen: nElements:", count)
  return torch.arange(count, dtype=torch.int).view(shape)

def doSomething(inTensor):
  """Set one randomly chosen element of a tensor of arbitrary shape to -1.

  One coordinate per axis is drawn with np.random.randint, uniformly
  over that axis' extent, so shapes like (N, N, ..) as well as
  non-cubic shapes such as (2, 5) are handled. The tensor is modified
  in place and also returned. Diagnostic information is printed.
  """
  inShape = inTensor.shape        # shape of the argument itself, not a global
  D = len(inShape)
  N = inShape[0] if D > 0 else None       # extent of the first axis
  # one random coordinate per axis, drawn in axis order
  accessTuple = tuple(np.random.randint(size) for size in inShape)
  inTensor[accessTuple] = -1      # access via tuple
  print("doSome:  inShape :", inShape)
  print("doSome:     D, N :", D, N)
  print("doSome:  accTuple:", accessTuple)
  return inTensor

#
# main: build a 3x3 tensor, then overwrite one random entry with -1
#
myTensor = generateTensor(D=2, N=3)          # the default arguments, spelled out
print("  main:    shape :", myTensor.shape)
print(myTensor)                              # before modification
outTensor = doSomething(inTensor=myTensor)   # modified in place and returned
print(outTensor)                             # after modification

automatic gradient evaluation




[Data Hacker]
Copy Copy to clipboard
Download Download
#!/usr/bin/env python3

import torch                     # PyTorch needs to be installed

dim = 2
eps = 0.1                        # step size of the manual update
x = torch.ones(dim, requires_grad=True)  # leaf of computational graph
print("x           : ",x)
print("x.data      : ",x.data)
print("x[0]        : ",x[0])
print("x[0].item() : ",x[0].item())
print()

y = x + 2
out = torch.dot(y,y)             # scalar product, out = sum_i (x_i+2)^2
print("y      : ",y)
print("out    : ",out)
print()

out.backward()                   # backward pass --> x.grad = 2*(x+2)
print("x.grad : ",x.grad)

with torch.no_grad():            # operations here are not recorded in the graph
  x -= eps*x.grad                # updating parameter tensor in place
  x.grad = None                  # flush; the next backward() would accumulate

print("x      : ",x.data)

print("\n#---")
print("#--- .backward() adds new gradient to old gradient")
print("#---             convenient for batch updating")
print("#---\n")

# every backward() adds d/dy (y+1).(y+1) = 2*(y+1) = [2, 2] to y.grad,
# so the printed gradient grows as [2,2], [4,4], [6,6], [8,8]
y = torch.zeros(dim, requires_grad=True)
for _ in range(4):
  torch.dot(y+1,y+1).backward()
  print("y.grad : ",y.grad)

least-squares fit

$$ y = \sin(x) \approx \sum_{k=0}^{{\rm nPar}-1} f_k x^k $$
Copy Copy to clipboard
Download Download
#!/usr/bin/env python3

import torch                   
import math
import matplotlib.pyplot as plt


myType   = torch.float
# pick the best available backend
myDevice = (
    "cuda"                       # NVIDIA GPUs
    if torch.cuda.is_available()
    else "mps"                   # Apple GPUs, 'Metal Performance Shaders'
    if torch.backends.mps.is_available()
    else "cpu"                   # plain old CPU
)

# global parameters
nData = 2000                     # number of training pairs
nIter = 2000                     # number training iterations
nPar  =    4                     # number of fit parameters
learning_rate = 0.5e-2/nData     # scaled by 1/nData: the loss sums over nData points

# gradients with respect to fitPar[] are to be evaluated
fitPar = []                      # list of 0-dim (scalar) tensors
for i in range(nPar):
  fitPar.append(torch.randn((), device=myDevice, dtype=myType,
                                requires_grad=True))
print(fitPar)

# the data must live on the same device as fitPar, otherwise the
# forward pass fails with a device mismatch on a GPU
x = torch.linspace(-math.pi, math.pi, nData, device=myDevice, dtype=myType)
y = torch.sin(x)                 # element-wise

def fitFunction(x):              # polynomial fitting function
  """Evaluate the polynomial  sum_i fitPar[i] * x**i  element-wise.

  Reads the module-level globals nPar and fitPar; for a tensor x the
  result is a tensor that is part of the autograd graph of fitPar.
  """
  # builtin sum() with start value 0.0 (no local named 'sum' shadowing it)
  return sum((fitPar[i]*(x**i) for i in range(nPar)), 0.0)

# training iteration
for iIter in range(nIter):
  y_pred     = fitFunction(x)              # forward pass
  lossTensor = (y_pred - y).pow(2).sum()   # element-wise pow(2)

  if iIter % 100 == 99:                    # print scalar loss value
    print(f'{iIter:5d}  {lossTensor.item():10.6f}')

  # backward pass: fills 'tensor.grad' for the tensors created with
  # requires_grad=True, here the entries of fitPar
  lossTensor.backward()

  # by-hand update, not recorded in the computation graph;
  # grads are reset afterwards, since backward() accumulates them
  with torch.no_grad():
    for i in range(nPar):                  # gradients via backward pass
      fitPar[i] -= learning_rate * fitPar[i].grad
      fitPar[i].grad = None

# plotting: detach tensors requiring gradients from the graph and copy
# them to the CPU (they may live on a GPU, see myDevice)
xPlot = x.detach().cpu()
plt.plot(xPlot, torch.sin(xPlot)                     , 'b', label="sin(x)")
plt.plot(xPlot, fitFunction(x).detach().cpu().numpy(), 'r', label="fit")
plt.plot(xPlot, 0.0*xPlot                            , '--k')
plt.legend()
plt.show()

in-place operations

Copy Copy to clipboard
Download Download
#!/usr/bin/env python3
#
# testing in-place operations
#

import torch

# --- out-of-place vs. in-place: a trailing underscore marks in-place ---
print()
x = torch.tensor([1, 2, 3])
print(x, " x original")          # tensor([1, 2, 3])
print(x.add(1), " x.add(1)")     # tensor([2, 3, 4]), a new tensor
print(x, " x current")           # tensor([1, 2, 3]), x is untouched
print(x.add_(2), " x.add_(2)")   # tensor([3, 4, 5]), x overwritten in memory
print(x, " x current")           # tensor([3, 4, 5])


def printSeparator():
  """Print a horizontal ruler framed by empty lines."""
  print()
  print("=========================")
  print()


printSeparator()

# --- in-place ops keep the object (same id) and bump its version counter ---
b = torch.rand(4)
print(b._version, " b._version, id ", id(b))
b += 2              # in-place: same object, _version increases
# b = b + 2         # out-of-place alternative: a new object is generated
print(b._version, " b._version, id ", id(b))

printSeparator()

# --- gradient through an out-of-place op on a leaf tensor ---
a = torch.rand(4, requires_grad=True)
c = a.pow(1)                     # new tensor; d(sum c)/da = 1
# c = a.pow_(1)                  # compare: in-place op on a leaf requiring grad
c.sum().backward()
print("a       ", a)
print("a.grad  ", a.grad)

Straight-Through Estimator (STE)

# straight-through estimator: (a-b) is 'detached' from the computation graph
# forward  :  y = (a - b) + b = a       (the value of y is that of a)
# backward :  dy/db = 1, dy/da = 0      (gradients flow only through b)
y = (a-b).detach() + b
Copy Copy to clipboard
Download Download
#!/usr/bin/env python3

import torch                   
import math
import matplotlib.pyplot as plt

# global parameters
nData =   200                    # number of sample points
nIter =    30                    # number of training iterations
eps   =   0.01                   # learning rate

# scale first, then mark as requiring gradients: this keeps fitPar a
# leaf tensor, so fitPar.grad is filled already by the first backward()
# (1.5*torch.randn(1, requires_grad=True) would be a non-leaf result,
#  whose .grad stays None)
fitPar = (1.5*torch.randn(1)).requires_grad_()

x = torch.linspace(-5.0, 5.0, nData)

def lossFunction(x, fitPar, steepness=4.0):
  """Squared-error loss between a smooth target and a step-function fit.

  target : y = 1/(1+exp(x))
  fit    : T = Theta(fitPar - x), via boolean -> float
  S      : surrogate 1/(1+exp(steepness*(x-fitPar)))

  The comparison defining T carries no gradient, hence a straight-through
  estimator: the forward value is T, the backward pass uses the gradient
  of S. steepness (default 4.0) sets how sharp the surrogate is.

  Returns a 0-dim tensor, the SUM (not the mean) of the squared errors.
  """
  y = 1.0/(1.0+torch.exp(x))                      # target
  T = (x < fitPar).float()                        # step, no gradient
  S = 1.0/(1.0+torch.exp(steepness*(x-fitPar)))   # surrogate
  out = S + (T - S).detach()   # STE (by hand): value of T, gradient of S
  return (y-out).pow(2).sum()  # sum of squared errors

# training iteration: print parameter and loss, then one gradient step
for _ in range(nIter):
  loss = lossFunction(x, fitPar)
  print(f" {fitPar.item():10.4f} {loss.item():10.4f}")
  loss.backward()
  with torch.no_grad():
    grad = fitPar.grad
    if grad is not None:         # None while fitPar is not a leaf tensor
      fitPar -= eps*grad
      fitPar.grad = None
  # re-insert into the computation graph as a fresh leaf tensor
  fitPar = fitPar.detach().requires_grad_()

# plotting: target, final step-function fit and its surrogate
y = 1.0/(1.0+torch.exp(x))              # target
T = (x < fitPar).float()
S = 1.0/(1.0+torch.exp(4*(x-fitPar)))   # surrogate
for curve, fmt, name in ((y                 , 'b'  , "target"),
                         (T.detach().numpy(), 'r'  , "fit"),
                         (S.detach().numpy(), '--k', "surrogate")):
  plt.plot(x, curve, fmt, label=name)
plt.legend()
plt.show()

tensor reshaping

Copy Copy to clipboard
Download Download
#!/usr/bin/env python3

import torch
import math
import matplotlib.pyplot as plt

uu = torch.ones(3)                   # shape [3]
xx = torch.ones(3, 5)                # shape [3, 5]
uuColumn = torch.unsqueeze(uu, 1)    # shape [3, 1], broadcastable against xx

print("\n# uu")
print(uu)
print("\n# xx")
print(xx)
print("\n# uu unsqueezed")
print(uuColumn)
print("\n# xx plus uu unsqueezed along 1")
print(xx + uuColumn)                 # [3, 5] + [3, 1] -> [3, 5]

# reshaping keeps the elements, only the shape changes
aa = torch.arange(4.0)
print("\n# arranged\n", aa)
aa = aa.reshape(2, 2)
print("\n# reshaped\n", aa)
aa = aa.reshape(-1)                  # -1: length inferred, i.e. flattened
print("\n#   ..back\n", aa)

implicit broadcasting

Copy Copy to clipboard
Download Download
#!/usr/bin/env python3

import torch

# a batch of 4 samples (rows), each with 3 features (columns):
# [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
x = torch.arange(1.0, 13.0).reshape(4, 3)

# per-feature statistics of the input, each of shape [3]
x_mean = x.mean(dim=0)
x_std  = x.std(dim=0, unbiased=False)     # population std (divides by 4)

# target statistics: after rescaling, column k should have
# mean feature_means[k] and standard deviation feature_stds[k]
feature_means = torch.tensor([2.0, 4.0, 6.0])
feature_stds  = torch.tensor([1.0, 2.0, 4.0])

# implicit broadcasting, every [3] operand is stretched over the 4 rows:
# [4, 3] - [3] -> [4, 3]
x_centered   = x - x_mean
x_normalized = feature_means + x_centered*feature_stds/x_std

report = (
    ("# input",                          x),
    ("\n# input mean along columns",     x_mean),
    ("\n# input mean along rows",        x.mean(dim=1)),
    ("\n# feature means, the target",    feature_means),
    ("\n# feature standard deviations:", feature_stds),
    ("\n# normalized output",            x_normalized),
    ("\n# output mean",                  x_normalized.mean(dim=0)),
    ("\n# output standard deviation",    x_normalized.std(dim=0, unbiased=False)),
)
for heading, values in report:
  print(heading)
  print(values)

explicit broadcasting

Copy Copy to clipboard
Download Download
#!/usr/bin/env python3
#
# calculating the pairwise differences
# of several vectors
#

import torch

# three vectors of dimension 2, stacked into a [3, 2] tensor
x32 = torch.tensor([[1, 2],
                    [3, 4],
                    [5, 6]])
print("# original tensor")
print(x32)
print()

N, D = x32.shape                 # N = 3 vectors, D = 2 components

# insert a length-1 axis: [3, 1, 2] (vector index i), [1, 3, 2] (index j)
x312 = x32[:, None, :]
y132 = x32[None, :, :]

# repeat along the length-1 axes (a view, no copy); both become [3, 3, 2]
x332 = x312.expand(N, N, D)
y332 = y132.expand(N, N, D)

# identical shapes can be subtracted directly: diffs[i, j] = x32[i] - x32[j]
diffs = x332 - y332

print("# pairwise differences")
print(diffs)
# print(x332); print(y332)     # uncomment to inspect the expanded views
print()

print("# shapes")
for label, t in (("  x32   ", x32), ("  x312  ", x312), ("  y132  ", y132)):
  print(label, t.shape)
print()
for label, t in (("  x332  ", x332), ("  y332  ", y332), ("  diffs ", diffs)):
  print(label, t.shape)

dividing tensors / lists

Copy Copy to clipboard
Download Download
#!/usr/bin/env python3

import torch
import random

aa = torch.arange(10.0).reshape(2,5)
bb = torch.arange(5) + 1.0
cc = torch.divide(aa,bb)         # [2,5] / [5]: column k is divided by bb[k]
print("aa\n", aa)
print("bb\n", bb)
print("cc\n", cc)
print()

print("# ======================")
print("# row-wise normalization")
print("# ======================")
aaSum   = aa.sum(1)              # one sum per row, shape [2]
# aaSum = aa.pow(2).sum(1).sqrt()   # alternative: L2 norm, rows become unit vectors
# broadcasting aligns trailing axes, so [2,5] / [2] would fail; transpose
# to [5,2], divide column i by aaSum[i], transpose back
# (equivalent one-liner: aa / aa.sum(1, keepdim=True))
aaTrans = torch.divide(aa.transpose(0,1),aaSum)
aaNorm  = torch.transpose(aaTrans,0,1)
print("aaSum\n"  , aaSum)
print("aaTrans\n", aaTrans)
print("aaNorm\n" , aaNorm)
print()

print("# =====================================")
print("# dividing lists, checking denominator ")
print("# =====================================")
aa_list = [random.randint(0, 10) for _ in range(5)]
bb_list = [random.choice(range(10)) for _ in range(5)]    # may contain 0
# -11 marks a division by zero
cc_list = [a / b if b != 0 else -11 for a, b in zip(aa_list,bb_list)]
print("aa_list\n"  , aa_list)
print("bb_list\n"  , bb_list)
print("cc_list\n"  , cc_list)

tensor multiplication

Copy Copy to clipboard
Download Download
#!/usr/bin/env python3

import torch

# two batches of 3 matrices each: A holds 4x5 matrices, B holds 5x6 matrices
batch, rows, inner, cols = 3, 4, 5, 6
A = torch.randn(batch, rows, inner)
B = torch.randn(batch, inner, cols)

# einsum 'bij,bjk->bik': for every batch index b
#   result[b,i,k] = sum_j A[b,i,j] * B[b,j,k]
result_einsum = torch.einsum('bij,bjk->bik', A, B)

# the same contraction via batch matrix multiplication
result_bmm = torch.bmm(A, B)

for label, shape in (("A shape", A.shape),
                     ("B shape", B.shape),
                     ("Result shape", result_einsum.shape)):
  print(f"{label}: {shape}")
print(f"Results are equal: {torch.allclose(result_einsum, result_bmm)}")

fancy indexing

Copy Copy to clipboard
Download Download
#!/usr/bin/env python3

import torch

source_tensor = torch.arange(9) * 10          # [0, 10, ..., 80]
index_tensor  = torch.arange(4, 6)            # [4, 5]
result_tensor = source_tensor[index_tensor]   # entries 4 and 5
print()
print("# --- 1D example --- ")
print("# source, index, result ")
for t in (source_tensor, index_tensor, result_tensor):
  print(t)

# 2D: one index tensor per axis, the pairs (0, 2), (1, 1), (2, 0)
# are selected, i.e. the anti-diagonal of the 3x3 view
D2_source   = source_tensor.view(3, 3)
row_indices = torch.tensor([0, 1, 2])
col_indices = torch.tensor([2, 1, 0])
D2_result   = D2_source[row_indices, col_indices]

print()
print("# --- 2D example --- ")
print("# source, indices, result ")
print(D2_source)
print(row_indices, col_indices)
print(D2_result)

# boolean mask: keep the elements with 5 < value < 35, as a 1D tensor
inRange = (D2_source > 5) & (D2_source < 35)
D2_bool = D2_source[inRange]
print()
print("# --- boolean indexing --- ")
print(D2_bool)

advanced indexing

Copy Copy to clipboard
Download Download
#!/usr/bin/env python3

import torch

print()
print("# --- --------- ---")
print("# --- gathering ---")
print("# --- --------- ---")

source_tensor = torch.tensor([[10, 20, 30],
                              [40, 50, 60],
                              [70, 80, 90]])
index_tensor = torch.tensor([[2, 0],
                             [1, 1],
                             [0, 1]])

# result[i,j] = source_tensor[index_tensor[i,j], j]
gather_dim_0 = torch.gather(source_tensor, dim=0, index=index_tensor)

# result[i,j] = source_tensor[i, index_tensor[i,j]]
gather_dim_1 = torch.gather(source_tensor, dim=1, index=index_tensor)

print("\n# source \n", source_tensor)
print("\n# index  \n", index_tensor)
print("\n# gather along dim=0\n", gather_dim_0)
print("\n# gather along dim=1\n", gather_dim_1)

print()
print("# --- ------------ ---")
print("# --- top K values ---")
print("# --- ------------ ---")
print()

aa = torch.randint(0,10,(3,4))
# for each column, find the top 2 values across all rows; results are (2,4)
val_dim_0, idx_dim_0 = torch.topk(aa, 2, dim=0)

# for each row   , find the top 2 values across all columns; results are (3,2)
val_dim_1, idx_dim_1 = torch.topk(aa, 2, dim=1)

print("# rand int tensor\n",aa)
print()
# transposed, so that there is one group per column, analogous to
# the one group per row printed for dim=1
print("# (index,value) pairs for dim=0")
for val, idx in zip(val_dim_0.T, idx_dim_0.T):
  for ii in range(val.shape[0]):
    print(idx[ii].item(), val[ii].item())
  print()

print("# (index,value) pairs for dim=1")
for val, idx in zip(val_dim_1, idx_dim_1):
  for ii in range(val.shape[0]):
    print(idx[ii].item(), val[ii].item())
  print()

print("# --- ---------- ---")
print("# --- scattering ---")
print("# --- ---------- ---")
print()

# destination tensors, shapes
# (3,4) for dim=0 and
# (4,3) for dim=1 (a transposed view; scatter() is not in-place,
#                  so both destinations stay zero)
destination_dim_0 = torch.zeros(3, 4, dtype=torch.float)
destination_dim_1 = destination_dim_0.transpose(0,1)

# source tensor, the values to scatter.
source_dim_0 = torch.tensor([[1., 2., 3., 4.],
                             [5., 6., 7., 8.]])

source_dim_1 = torch.tensor([[1., 2.],
                             [3., 4.],
                             [5., 6.]])

index_dim_0 = torch.tensor([[0, 2, 1, 2],
                            [1, 0, 2, 1]])

index_dim_1 = torch.tensor([[0, 2],
                            [1, 0],
                            [2, 1]])

# scattering along dim=0
# destination_dim_0[index_dim_0[i, j], j] = source_dim_0[i, j].

result_dim_0 = destination_dim_0.scatter(dim=0,
               index=index_dim_0, src=source_dim_0)

# scattering along dim=1
# destination_dim_1[i, index_dim_1[i, j]] = source_dim_1[i, j]

result_dim_1 = destination_dim_1.scatter(dim=1,
               index=index_dim_1, src=source_dim_1)

print("# destination_dim_0\n", destination_dim_0)
print("# index_dim_0\n", index_dim_0)
print("# source_dim_0 \n", source_dim_0)
print("# result_dim_0\n", result_dim_0)

print()
print("# --- ---------------------- ---")
print("# --- scattering along dim=1 ---")
print("# --- ---------------------- ---")
print()

print("# destination_dim_1\n", destination_dim_1)
print("# index_dim_1\n", index_dim_1)
print("# source_dim_1 \n", source_dim_1)
print("# result_dim_1\n", result_dim_1)