hayden-donnelly/mnist-webdataset-png

Source: Hugging Face | Updated: 2024-03-06 | Indexed: 2024-06-22
Download link:
https://hf-mirror.com/datasets/hayden-donnelly/mnist-webdataset-png
Resource description:
---
task_categories:
- image-classification
- unconditional-image-generation
size_categories:
- 10K<n<100K
---

# MNIST WebDataset PNG

The MNIST dataset with samples stored as PNG images and compiled into the WebDataset format.

## DALI/JAX Example

The following code shows how this dataset can be loaded into JAX arrays by DALI.

```python
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.plugin.jax import DALIGenericIterator
from nvidia.dali.plugin.base_iterator import LastBatchPolicy

def get_data_iterator(batch_size, dataset_path):
    @pipeline_def(batch_size=batch_size, num_threads=4, device_id=0)
    def wds_pipeline():
        # Read each sample's .png and .cls members from the WebDataset tar.
        raw_image, ascii_label = fn.readers.webdataset(
            paths=dataset_path,
            ext=['png', 'cls'],
            missing_component_behavior='error',
        )
        # Decode the PNG bytes; the default RGB output gives 3 channels.
        image = fn.decoders.image(raw_image)
        # Each .cls file holds an ASCII digit; subtracting 48 (ord('0'))
        # turns it into the integer label.
        ascii_shift = types.Constant(48).uint8()
        label = ascii_label - ascii_shift
        return image, label

    data_pipeline = wds_pipeline()
    data_iterator = DALIGenericIterator(
        pipelines=[data_pipeline],
        output_map=['x', 'y'],
        # Drop a final partial batch so every batch has batch_size samples.
        last_batch_policy=LastBatchPolicy.DROP
    )
    return data_iterator

data_iterator = get_data_iterator(
    batch_size=32,
    dataset_path='data/mnist_webdataset_numpy_flat_9/data.tar'
)
batch = next(data_iterator)
x = batch['x']
y = batch['y']
print('x shape:', x.shape)
print('y shape:', y.shape)
print('y:', y[:, 0])
```

Output:

```
x shape: (32, 28, 28, 3)
y shape: (32, 1)
y: [5 0 4 1 9 2 1 3 1 4 3 5 3 6 1 7 2 8 6 9 4 0 9 1 1 2 4 3 2 7 3 8]
```

## Acknowledgements

- Yann LeCun, Courant Institute, NYU
- Corinna Cortes, Google Labs, New York
- Christopher J.C. Burges, Microsoft Research, Redmond
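The DALI pipeline in the README recovers each label by subtracting 48, the code point of the character `0`, from the raw bytes of the `.cls` file. A minimal NumPy sketch of the same arithmetic follows, assuming (as the `(32, 1)` label shape in the output suggests) that each `.cls` file holds a single ASCII digit. The `raw_cls` payloads are hypothetical stand-ins, not bytes read from the dataset.

```python
import numpy as np

# Hypothetical raw .cls payloads, one ASCII digit per sample (b"5" is byte 53).
raw_cls = [b"5", b"0", b"4", b"1"]

# Stack the bytes into a (batch, 1) uint8 array, like DALI's ascii_label output.
ascii_label = np.frombuffer(b"".join(raw_cls), dtype=np.uint8).reshape(-1, 1)

# Subtract ord("0") == 48 as a uint8 so the result stays uint8,
# mirroring types.Constant(48).uint8() in the DALI pipeline.
label = ascii_label - np.uint8(48)

print(label.shape)  # (4, 1)
print(label[:, 0])  # [5 0 4 1]
```

Keeping the constant as `uint8` avoids promoting the labels to a wider integer type; indexing with `[:, 0]` then gives the flat label vector printed in the README output.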
Provider:
hayden-donnelly
Summary of the original information

MNIST WebDataset PNG

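Dataset overview

  • Task categories:
    • Image classification
    • Unconditional image generation
  • Dataset size:
    • 10K<n<100K

Data format

  • Samples stored as PNG images
  • Compiled into the WebDataset format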
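In the WebDataset format, each sample is a group of consecutive files in a tar archive that share a key (the file name up to the first dot) and differ only in extension; here that means a `.png` image plus a `.cls` label, matching `ext=['png', 'cls']` in the DALI example. The following is a minimal standard-library sketch of that grouping, not the `webdataset` library or DALI's reader. The helper names `split_key`, `finish_sample`, `iter_samples`, and `build_tar`, the sample keys, and the placeholder PNG bytes are all hypothetical, and the label decoding assumes each `.cls` file starts with one ASCII digit.

```python
import io
import posixpath
import tarfile
from typing import Dict, Iterator, List, Optional, Tuple


def split_key(member_name: str) -> Tuple[str, str]:
    """Split 'shard/000123.png' into ('shard/000123', 'png') at the first dot of the basename."""
    directory, base = posixpath.split(member_name)
    stem, _, ext = base.partition(".")
    return posixpath.join(directory, stem), ext


def finish_sample(key: str, parts: Dict[str, bytes]) -> Tuple[str, bytes, int]:
    """Turn the files collected for one key into (key, png_bytes, label)."""
    # Fail loudly on incomplete samples, in the spirit of missing_component_behavior='error'.
    if "png" not in parts or "cls" not in parts:
        raise ValueError(f"sample {key!r} is missing its .png or .cls file")
    # Assumes the .cls payload starts with one ASCII digit: b"5" is byte 53, and 53 - 48 = 5.
    label = parts["cls"][0] - ord("0")
    return key, parts["png"], label


def iter_samples(tar_bytes: bytes) -> Iterator[Tuple[str, bytes, int]]:
    """Yield one sample per run of consecutive tar members that share a key."""
    current_key: Optional[str] = None
    parts: Dict[str, bytes] = {}
    with tarfile.open(fileobj=io.BytesIO(tar_bytes), mode="r") as tar:
        for member in tar:
            if not member.isfile():
                continue
            key, ext = split_key(member.name)
            if key != current_key:
                # A new key closes the previous sample.
                if current_key is not None:
                    yield finish_sample(current_key, parts)
                current_key, parts = key, {}
            parts[ext] = tar.extractfile(member).read()
    if current_key is not None:
        yield finish_sample(current_key, parts)


def build_tar(samples: List[Tuple[str, bytes, int]]) -> bytes:
    """Write (key, png_bytes, label) triples as key.png / key.cls pairs into an in-memory tar."""
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tar:
        for key, png, label in samples:
            for ext, data in (("png", png), ("cls", str(label).encode("ascii"))):
                info = tarfile.TarInfo(name=f"{key}.{ext}")
                info.size = len(data)
                tar.addfile(info, io.BytesIO(data))
    return buf.getvalue()


# Usage: placeholder bytes stand in for real PNG files.
fake_png = b"\x89PNG\r\n\x1a\n"
data = build_tar([("000000", fake_png, 5), ("000001", fake_png, 0)])
for key, png, label in iter_samples(data):
    print(key, len(png), label)  # prints "000000 8 5", then "000001 8 0"
```

Grouping only consecutive members keeps memory use to one sample at a time, which is why WebDataset shards are written with each sample's files stored next to each other. Decoding the PNG bytes into pixels would need an image library and is left out of this sketch.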
