#!./.mnist-keras/bin/python
# import os

# import fire
# import numpy as np
# import tensorflow as tf


# def get_data(out_dir='data'):
#     # Make dir if necessary
#     if not os.path.exists(out_dir):
#         os.mkdir(out_dir)

#     # Download data
#     (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
#     np.savez(f'{out_dir}/mnist.npz', x_train=x_train,
#              y_train=y_train, x_test=x_test, y_test=y_test)


# if __name__ == '__main__':
#     fire.Fire(get_data)


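# Current version of get_data.py: downloads CIFAR-10 instead of MNIST.
# The commented-out code above is the previous MNIST version, kept for reference.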

#!./.mnist-keras/bin/python
import os
import fire
import numpy as np
import tensorflow as tf

def get_data(out_dir='data'):
    """Download the CIFAR-10 dataset and save it as a single ``.npz`` archive.

    The train/test splits returned by
    ``tf.keras.datasets.cifar10.load_data()`` are written to
    ``<out_dir>/cifar10.npz`` under the keys ``x_train``, ``y_train``,
    ``x_test`` and ``y_test``.

    Args:
        out_dir: Directory that receives ``cifar10.npz``. It is created,
            together with any missing parent directories, if it does not
            already exist.
    """
    # makedirs(exist_ok=True) supports nested paths such as 'a/b/data' and
    # avoids the check-then-create race of os.path.exists() + os.mkdir().
    os.makedirs(out_dir, exist_ok=True)

    # Download CIFAR-10 data
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

    np.savez(
        os.path.join(out_dir, 'cifar10.npz'),
        x_train=x_train,
        y_train=y_train,
        x_test=x_test,
        y_test=y_test,
    )

# Script entry point: hand get_data to Fire so it can be invoked from the
# command line; nothing runs when this module is merely imported.
if __name__ == "__main__":
    fire.Fire(component=get_data)
