hayden-donnelly/mnist-webdataset-png
MNIST WebDataset PNG
Dataset Overview
- Task categories:
  - Image classification
  - Unconditional image generation
- Dataset size:
  - 10K < n < 100K
Data Format
- Samples are stored as PNG images
- Samples are compiled into the WebDataset format
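
To make the format concrete, here is a minimal sketch of how a WebDataset shard is laid out: a plain tar archive in which files that share a base name form one sample. It assumes each sample consists of a `.png` member and a `.cls` member holding the label as an ASCII digit, which is what the example pipeline on this card reads. The sample keys (`000000`, `000001`), the helper names, and the placeholder image bytes are hypothetical, and the sketch builds a toy shard in memory rather than reading the real `data.tar`.

```python
import io
import tarfile

# Placeholder payload: only the 8-byte PNG signature, not a decodable image.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def add_member(tar, name, payload):
    """Append one in-memory file to an open tar archive."""
    info = tarfile.TarInfo(name=name)
    info.size = len(payload)
    tar.addfile(info, io.BytesIO(payload))


def build_toy_shard(samples):
    """Write (key, png_bytes, digit) tuples as <key>.png / <key>.cls members."""
    buffer = io.BytesIO()
    with tarfile.open(fileobj=buffer, mode="w") as tar:
        for key, png_bytes, digit in samples:
            add_member(tar, f"{key}.png", png_bytes)
            add_member(tar, f"{key}.cls", str(digit).encode("ascii"))
    return buffer.getvalue()


def read_shard(shard_bytes):
    """Group tar members into samples by the part of the name before the first dot."""
    samples = {}
    with tarfile.open(fileobj=io.BytesIO(shard_bytes), mode="r") as tar:
        for member in tar:
            if not member.isfile():
                continue
            key, _, ext = member.name.partition(".")
            samples.setdefault(key, {})[ext] = tar.extractfile(member).read()
    return samples


def decode_label(cls_bytes):
    """Turn an ASCII digit byte into an integer label: ord("0") == 48."""
    return cls_bytes[0] - 48


toy_shard = build_toy_shard([
    ("000000", PNG_SIGNATURE, 5),
    ("000001", PNG_SIGNATURE, 0),
])
toy_samples = read_shard(toy_shard)
toy_labels = [decode_label(sample["cls"]) for sample in toy_samples.values()]
print(sorted(toy_samples))  # sample keys found in the shard
print(toy_labels)           # integer labels recovered from the .cls members
```

Splitting on the first dot is a simplification of WebDataset's key rule and is enough for flat file names like the ones used here. Running `read_shard` on the bytes of an actual shard should list its samples in the same way, provided the shard follows the layout assumed above.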
Example Code
The following code shows how to load this dataset into JAX arrays using DALI:
```python
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.plugin.jax import DALIGenericIterator
from nvidia.dali.plugin.base_iterator import LastBatchPolicy


def get_data_iterator(batch_size, dataset_path):
    @pipeline_def(batch_size=batch_size, num_threads=4, device_id=0)
    def wds_pipeline():
        # Read the raw PNG bytes and the ASCII-encoded label of each sample.
        raw_image, ascii_label = fn.readers.webdataset(
            paths=dataset_path,
            ext=["png", "cls"],
            missing_component_behavior="error",
        )
        # Decode the PNG bytes into an image tensor.
        image = fn.decoders.image(raw_image)
        # Labels are stored as ASCII digits; subtracting ord("0") == 48
        # converts them to integer class IDs.
        ascii_shift = types.Constant(48).uint8()
        label = ascii_label - ascii_shift
        return image, label

    data_pipeline = wds_pipeline()
    data_iterator = DALIGenericIterator(
        pipelines=[data_pipeline],
        output_map=["x", "y"],
        last_batch_policy=LastBatchPolicy.DROP,
    )
    return data_iterator


data_iterator = get_data_iterator(
    batch_size=32,
    dataset_path="data/mnist_webdataset_numpy_flat_9/data.tar",
)
batch = next(data_iterator)
x = batch["x"]
y = batch["y"]
print("x shape:", x.shape)
print("y shape:", y.shape)
print("y:", y[:, 0])
```
Output:
```
x shape: (32, 28, 28, 3)
y shape: (32, 1)
y: [5 0 4 1 9 2 1 3 1 4 3 5 3 6 1 7 2 8 6 9 4 0 9 1 1 2 4 3 2 7 3 8]
```
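
The batch shapes above are usually not what a classifier consumes directly: the images carry three channels and the labels have a trailing axis of length 1. The following is a minimal preprocessing sketch, not part of the original card. It assumes the three channels are identical copies of the grayscale digit, so that keeping one loses nothing, and it uses NumPy with stand-in arrays in place of a real DALI batch. The function name `preprocess_batch` and the choice of `num_classes=10` are illustrative.

```python
import numpy as np


def preprocess_batch(x, y, num_classes=10):
    """Convert a (N, 28, 28, 3) uint8 batch and (N, 1) labels to model inputs."""
    x = np.asarray(x)
    y = np.asarray(y)
    # Keep a single channel (assumes R == G == B) and scale pixels to [0, 1].
    images = x[..., :1].astype(np.float32) / 255.0
    # Drop the trailing axis so labels have shape (N,).
    labels = y[:, 0].astype(np.int32)
    # Index an identity matrix to build one-hot targets of shape (N, num_classes).
    one_hot = np.eye(num_classes, dtype=np.float32)[labels]
    return images, labels, one_hot


# Stand-in arrays with the same shapes and dtypes as the batch printed above.
fake_x = np.zeros((32, 28, 28, 3), dtype=np.uint8)
fake_y = (np.arange(32, dtype=np.uint8) % 10).reshape(32, 1)

images, labels, one_hot = preprocess_batch(fake_x, fake_y)
print(images.shape, labels.shape, one_hot.shape)
```

With a real batch, `x` and `y` from the iterator can be passed in place of the stand-in arrays. The same indexing and casting operations are also available in `jax.numpy` if the data should stay on the accelerator.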



