#!/usr/bin/env python3

# conditional tensor operations

import torch

# Conditional mapping: draw six standard-normal samples and turn each one
# into 1.0 when it is strictly positive and 0.0 otherwise.
xIn = torch.randn(6)
xOut = torch.where(xIn.gt(0), 1.0, 0.0)

# Show input and output, each caption on its own line above the tensor.
for _caption, _values in (("xIn \n", xIn), ("xOut\n", xOut)):
    print(_caption, _values)
print()

# Conditional substitution: keep the even entries of yIn and replace each
# odd entry with the corresponding entry of yOnes.
yIn    = torch.arange(10)                 # integer values 0..9
# ones_like copies yIn's shape and dtype, so both branches of torch.where
# share one dtype and no implicit type promotion is involved (unlike
# torch.ones(10).int(), which is int32 while arange gives int64), and the
# length cannot drift out of sync with yIn.
yOnes  = torch.ones_like(yIn)
yOut = torch.where(yIn % 2 == 0, yIn, yOnes)  # conditional substitution

print("yIn   \n", yIn)
print("yOnes \n", yOnes)
print("yOut  \n", yOut)
print()

# Element-wise stochastic operations, driven by a random 0/1 mask.
zIn = torch.arange(10)
# torch.rand samples uniformly from [0, 1); thresholding at 0.5 gives each
# position a 1 with probability one half and a 0 otherwise.
randMask = torch.where(torch.rand(10).lt(0.5), 1, 0)
# Even entries of zIn are kept; odd entries are overwritten by the mask.
zOut = torch.where(zIn.remainder(2).eq(0), zIn, randMask)

for _caption, _values in (
    ("zIn      \n", zIn),
    ("randMask \n", randMask),
    ("zOut     \n", zOut),
):
    print(_caption, _values)
