#!/usr/bin/env python3

#
# tuples can be used to index tensors
# whose number of dimensions is not known in advance
#

import torch
import math
import numpy as np
import matplotlib.pyplot as plt


def generateTensor(D=2, N=3):
  """Build an int tensor of shape (N, N, ..) with D dimensions.

  The entries are 0, 1, .., nElements-1 laid out in row-major order
  (torch.arange reshaped with view).  The shape tuple and the element
  count are printed as a diagnostic before the tensor is returned.
  """
  dimTuple = tuple(N for _ in range(D))    # D copies of N, e.g. (3, 3)
  nElements = 1
  for size in dimTuple:                    # total count = product of sizes
    nElements *= size

  print("genTen: dimTuple :", dimTuple)
  print("genTen: nElements:", nElements)
  flat = torch.arange(nElements, dtype=torch.int)
  return flat.view(dimTuple)

def doSomething(inTensor):
  """Set one randomly chosen element of a tensor to -1, in place.

  A random index is drawn for every dimension with np.random.randint
  and collected into a tuple; indexing with that tuple works no matter
  how many dimensions the tensor has.

  Parameters
  ----------
  inTensor : torch.Tensor
      Tensor of any rank (including 0-d).  Each index is drawn from its
      own dimension size, so shapes need not be (N, N, ..).  Every
      dimension must have size >= 1, otherwise np.random.randint raises
      ValueError.

  Returns
  -------
  torch.Tensor
      The same tensor object that was passed in, after modification.
  """
  # use the argument's own shape (previously read from a global tensor)
  inShape = inTensor.shape
  D = len(inShape)
  # first dimension size, kept for the diagnostic print; a 0-d tensor has none
  N = inShape[0] if D > 0 else None
  # one random index per dimension; same draws as before for (N, N, ..)
  accessTuple = tuple(np.random.randint(size) for size in inShape)
  inTensor[accessTuple] = -1      # access via tuple
  print("doSome:  inShape :", inShape)
  print("doSome:     D, N :", D, N)
  print("doSome:  accTuple:", accessTuple)
  return inTensor

#
# main: build the default 3x3 tensor, then overwrite one random entry
#
myTensor = generateTensor(D=2, N=3)
print("  main:    shape :", myTensor.shape)
print(myTensor)

# doSomething modifies its argument in place and returns it, so
# outTensor and myTensor refer to the same tensor object
outTensor = doSomething(myTensor)
print(outTensor)
