#!/usr/bin/env python3

# Generate and store a delayed "XOR" time series (addition modulo N,
# which reduces to XOR when N == 2):
# x(t) = ( x(t-tau) + x(t-tau-1) ) % N

import random
from datasets import Dataset, DatasetDict, load_from_disk

class GP:
  """Global parameters of the generated time series and dataset splits."""

  # recurrence: x(t) = ( x(t-tau) + x(t-tau-1) ) % N
  tau: int = 1        # delay of the recurrence
  N: int = 10         # number of distinct symbols, drawn from range(N)

  # split sizes (number of series) and series lengths (tokens per series)
  nTrain: int = 20    # number of training series
  nTest: int = 10     # number of test series
  lenTrain: int = 7   # tokens per training series
  lenTest: int = 6    # tokens per test series


def oneStep_NT(x, N, delay):
  """Return the next token of the delayed-sum (NT) time series.

  Implements x(t) = ( x(t-delay) + x(t-delay-1) ) % N, where t = len(x)
  is the index of the token about to be generated.

  Args:
    x:     current sequence (indexable, e.g. a list of ints)
    N:     number of symbols; for int tokens and positive N the result
           lies in range(N)
    delay: recurrence delay tau; must be >= 1

  Raises:
    ValueError: if delay < 1 (the formula would then refer to x(t) itself),
      or if x has fewer than delay+1 elements (not enough history; the old
      slice-based lookup silently wrapped around to the end of x here).
  """
  if delay < 1:
    raise ValueError(f"delay must be >= 1, got {delay}")
  t = len(x)
  if t < delay + 1:
    raise ValueError(
      f"sequence of length {t} is too short for delay {delay}; "
      f"need at least {delay + 1} elements")
  return (x[t - delay] + x[t - delay - 1]) % N


def oneSequence(length, N=None, tau=None):
  """Return a time series of exactly `length` tokens.

  The first tau+1 tokens are drawn with random.sample, i.e. they are
  distinct symbols from range(N). Every later token follows
  x(t) = ( x(t-tau) + x(t-tau-1) ) % N via oneStep_NT.

  Args:
    length: number of tokens to return; values <= 0 give an empty list
    N:      number of symbols (defaults to GP.N)
    tau:    recurrence delay (defaults to GP.tau)

  Returns:
    list of ints with max(length, 0) elements.

  Raises:
    ValueError: from random.sample if tau+1 > N.
  """
  if N is None:
    N = GP.N
  if tau is None:
    tau = GP.tau
  sequence = random.sample(range(N), tau + 1)
  while len(sequence) < length:
    sequence.append(oneStep_NT(sequence, N, tau))
  # the random seed alone (tau+1 tokens) may exceed a very short request;
  # trim so the returned length always honours `length`
  return sequence[:max(length, 0)]


def _makeSplit(nRows, length, myType):
  """Return a Dataset of `nRows` fresh series of `length` tokens, each tagged `myType`."""
  rows = [{"type": myType, "series": oneSequence(length)}
          for _ in range(nRows)]
  return Dataset.from_list(rows)


def makeDataset():
  """Build the train/test DatasetDict of randomly seeded NT time series.

  Every row is {"type": "N<N>T<tau>", "series": [...]}, where "type" records
  the GP.N and GP.tau used for generation. The 'train' split holds
  GP.nTrain series of GP.lenTrain tokens; the 'test' split holds GP.nTest
  series of GP.lenTest tokens.

  Returns:
    DatasetDict with keys 'train' and 'test'.
  """
  myType = f"N{GP.N}T{GP.tau}"
  # dict-literal values are evaluated in order, so the train split consumes
  # random numbers before the test split
  return DatasetDict({
    'train': _makeSplit(GP.nTrain, GP.lenTrain, myType),
    'test':  _makeSplit(GP.nTest, GP.lenTest, myType),
  })

def testPrintingDataset(DS, nExamples=2):
  """Print the size and the first `nExamples` series of every split in DS.

  Args:
    DS:        mapping of split name -> dataset (e.g. a DatasetDict); each
               value must support len() and .select(indices), yielding rows
               that have a "series" key.
    nExamples: maximum number of series printed per split (default 2).
  """
  print()
  print("# =============")
  print("# test printing")
  print("# =============")
  for split_name, split_dataset in DS.items():
    size = len(split_dataset)
    print(f"Split: {split_name}, Size: {size}")
    # clamp: selecting indices past the end of a small split raises IndexError
    nShow = max(0, min(nExamples, size))
    for example in split_dataset.select(range(nShow)):
      print(example["series"])
    print()

#
# main: build the dataset, save it, reload it and print a sample
#
dataset_dict = makeDataset()
print(dataset_dict)

# save the entire DatasetDict and load it back from the same directory
datasetPath = './my_dataset'
dataset_dict.save_to_disk(datasetPath)
loaded_dataset = load_from_disk(datasetPath)

# test printing of the reloaded copy
testPrintingDataset(loaded_dataset)
