#!/usr/bin/env python3

import torch
import random

def normalizeVector(x):
    """Return ``x`` scaled to unit Euclidean (L2) norm.

    The norm is taken over *all* entries of ``x`` (the tensor is treated as
    one flat vector), so the sum of squares of the result is 1.

    A zero tensor has no direction; instead of dividing by zero (which
    yields NaN everywhere), it is returned as zeros. The guard uses
    ``torch.where`` rather than a Python ``if`` so the function stays
    usable under ``torch.vmap``, which cannot branch on tensor values.

    Args:
        x: tensor to normalize. Integer tensors are promoted to floating
            point by ``torch.sqrt`` and the true division.

    Returns:
        Floating-point tensor with the same shape as ``x``.
    """
    norm = torch.sqrt(x.pow(2).sum())
    scaled = x / norm
    # ``x / norm`` is NaN where norm == 0; select zeros in that case instead.
    return torch.where(norm > 0, scaled, torch.zeros_like(scaled))

# main

print("original tensor")
x = torch.rand(4, 2)
# Alternative integer input: x = torch.arange(8).view(4, 2)
print(x)
print("#====================\n")

# Normalize the entire tensor as one flat vector: after this the sum of
# squares over all 8 entries is 1.

print("fully normalized tensor")
allNormal = normalizeVector(x)
print("|allNormal|^2 : ", torch.sum(allNormal.pow(2)).item())
print(allNormal)
print("#====================\n")

# Normalize each row of the original tensor `x` with an explicit Python
# loop; this is the non-vectorized counterpart of the vmap section below.

print("looping over rows manually")
for ii, row in enumerate(x):
    rowVector = normalizeVector(row)
    print("|rowVector|^2 : ", ii, torch.sum(rowVector.pow(2)).item())
print("#====================\n")

# Normalize every row with torch.vmap: vmap maps normalizeVector over
# dimension 0 of `x`, so each row is normalized independently in one call.

print("row by row normalization with vmap")
g = torch.vmap(normalizeVector)
rowNormal = g(x)
for ii, row in enumerate(rowNormal):
    print("|rowNormal|^2 : ", ii, torch.sum(row.pow(2)).item())
print(rowNormal)
print("#====================\n")
