#!/usr/bin/env python3

import torch

# Demonstration: multiply nMatrix independent pairs of matrices,
# (nRow x nCol) times (nCol x nS), either with a single batched tensor call
# or with explicit Python loops.  Both paths fill YY with the same products.
nMatrix = 50                                    # number of matrix pairs
nRow    = 40                                    # rows of each left matrix
nCol    = 30                                    # shared inner dimension
nS      = 20                                    # columns of each right matrix

# Set to False to run the explicit element-by-element loops instead of
# the single torch.matmul call.
USE_TENSOR_OP = True

AA = torch.randn(nMatrix, nRow, nCol)           # combine many matrices
BB = torch.randn(nMatrix, nCol, nS)             # to a single tensor

if USE_TENSOR_OP:
    # For 3-D inputs matmul multiplies the trailing two dimensions and
    # treats the leading one as a batch index.
    YY = torch.matmul(AA, BB)                   # tensor operation
else:
    # The accumulator is only needed on this path, so allocate it here.
    YY = torch.zeros(nMatrix, nRow, nS)
    for ii in range(nMatrix):                   # looping over all matrices
        for nn in range(nRow):
            for mm in range(nS):
                for ll in range(nCol):          # explicit multiplication
                    YY[ii, nn, mm] += AA[ii, nn, ll] * BB[ii, ll, mm]

print()
print("AA", AA.shape)
print("BB", BB.shape)
print("YY", YY.shape)
print()
print(f'we did multiply {nMatrix:d} matrices')
print(f'of type ({nRow:d}x{nCol:d}) and ({nCol:d}x{nS:d})')
