#!./.mnist-pytorch/bin/python
import os

import fire
import torchvision


def get_data(out_dir='data'):
    """Download the MNIST train and test splits below ``out_dir``.

    The training split is requested with ``root=<out_dir>/train`` and the
    test split with ``root=<out_dir>/test``. The dataset objects are not
    kept; the function is called purely for its download side effect and
    returns ``None``.

    Args:
        out_dir: Directory that receives the ``train`` and ``test``
            sub-directories. It is created, including any missing parent
            directories, when it does not exist yet.
    """
    # makedirs with exist_ok avoids the check-then-create race of
    # exists() + mkdir() and also copes with nested paths like 'a/b/data'.
    os.makedirs(out_dir, exist_ok=True)

    # The transform has to be an *instance*: handing over the ToTensor class
    # itself would make the dataset try to construct ToTensor(img) on item
    # access instead of converting the image.
    to_tensor = torchvision.transforms.ToTensor()

    # Download data
    for split, is_train in (('train', True), ('test', False)):
        torchvision.datasets.MNIST(
            root=os.path.join(out_dir, split),
            transform=to_tensor,
            train=is_train,
            download=True,
        )


if __name__ == '__main__':
    # Script entry point: hand get_data to Fire so that its parameters
    # (e.g. out_dir) can be supplied from the command line.
    fire.Fire(component=get_data)
