#!/usr/bin/env python3

# IMDB movie rating dataset
# 25,000 train / 25,000 test reviews, ~31/32 MByte
# https://huggingface.co/datasets/stanfordnlp/imdb

from datasets import load_dataset, load_from_disk
import os

# Load the IMDB dataset

def download_data(dataSetName="imdb"):
  """Download from Huggingface if not available locally."""
  localDatasetPath = dataSetName + "_all"
  yesLocal = os.path.exists(localDatasetPath)
#
  if yesLocal:
    print(f'# loading {dataSetName} from file')
    MyDataset = load_from_disk(localDatasetPath)
  else:
    print(f'# downloading {dataSetName} from Huggingface')
    MyDataset = load_dataset(dataSetName)
    MyDataset.save_to_disk(localDatasetPath)
#
  print(MyDataset)
  for split in MyDataset.keys():          # splits and features
    print("# split   ", split)
    print("# features ", MyDataset[split].features)
    print()
#
  if not yesLocal:                        # specific to IMDB
    for split in MyDataset.keys():
      fileName = dataSetName + "." + split
      with open(fileName, 'w', encoding='utf-8') as f:
        for example in MyDataset[split]:
          f.write(f"{example['label']:3d} {example['text']}\n")
#
  return MyDataset
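
# A minimal sketch of reading an exported split back in, assuming the
# "<label> <text>" line format written above and that each review stays on
# one line (IMDB reviews mark breaks with <br /> tags, not raw newlines).
def read_exported_split(fileName="imdb.train"):
  """Yield (label, text) pairs from a file written by download_data()."""
  with open(fileName, encoding='utf-8') as f:
    for line in f:
      label, text = line.split(maxsplit=1)
      yield int(label), text.rstrip('\n')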

#
# main
#

main_dataset = download_data()
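
# Quick sanity check on one example; the 'text' and 'label' feature names
# match the IMDB schema printed by download_data() above.
sample = main_dataset["train"][0]
print("# first train label:", sample["label"])
print("# first train text :", sample["text"][:80], "...")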

# available datasets (disabled: this writes every dataset id on the Hub)
if False:
  from huggingface_hub import list_datasets
  with open("huggingface.sets", 'w') as f:
    for dataset in list_datasets():
      f.write(f"{dataset.id}\n")
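
# A targeted query is usually more practical than the full listing; a
# minimal sketch, assuming a recent huggingface_hub whose list_datasets()
# accepts the search= and limit= keyword arguments.
if False:
  from huggingface_hub import list_datasets
  for dataset in list_datasets(search="imdb", limit=10):
    print(dataset.id)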
