pd12m-full
收藏PD12M 数据集
基本信息
- 语言: 英语 (en)
- 名称: PD12M
- 许可证: CDLA-Permissive-2.0
- 标签: 图像 (image)
描述
- 该数据集是 Spawning/PD12M 的下载变体,特别兼容
webdataset。 - 数据集在获得原始作者的许可后公开发布。
使用示例
python import webdataset as wds
dataset_path = "pipe:curl -s -f -L https://huggingface.co/datasets/sayakpaul/pd12m-full/resolve/main/{00155..02480}.tar"
dataset = ( wds.WebDataset(dataset_path, handler=wds.warn_and_continue) .shuffle(690, handler=wds.warn_and_continue) .decode("pil", handler=wds.warn_and_continue) )
for sample in dataset: print(sample.keys()) print(sample["jpg"].size) print(sample["json"]) print(sample["txt"]) break
数据加载
- 提供了参考数据加载器实现,详见 dataloader.py。
数据下载
-
使用
img2dataset工具进行下载。 -
下载命令如下: bash img2dataset --url_list pd12m_full.parquet --input_format "parquet" --url_col "url" --caption_col "caption" --output_format webdataset --number_sample_per_shard=5000 --skip_reencode=True --output_folder s3://diffusion-datasets/pd12m --processes_count 16 --thread_count 64 --resize_mode no --enable_wandb True
-
下载的
webdataset分片被序列化到 S3 存储桶。 -
pd12m_full.parquet是通过合并 metadata 中的所有 parquet 文件到一个 pandas 数据框中生成的,文件位于 original_parquet/pd12m_full.parquet。
文件复制
-
使用以下脚本将文件从 S3 存储桶复制到当前仓库: python from huggingface_hub import create_repo, upload_file, dataset_info import ray import os
Change
_temp_dirpath accordingly.ray.init(num_cpus=16, _temp_dir="/scratch")
def main(): s3_fs = s3fs.S3FileSystem()
bucket_path = "s3://diffusion-datasets/pd12m" files = s3_fs.ls(bucket_path, detail=True) files = sorted([f["name"] for f in files if f["name"].endswith(".tar") and f["size"] > 0.0]) @ray.remote def fn(tar_file): # Change the paths accordingly. full_s3_tar_file = f"s3://{tar_file}" local_path = f"/scratch/{tar_file}" s3_fs.download(full_s3_tar_file, local_path) # Adjust according to what your local storage allows for. batch_size = 20 for i in range(0, len(files), batch_size): batch = files[i : i + batch_size] futures = [fn.remote(tar_file) for tar_file in batch] ray.get(futures) os.system( "huggingface-cli upload-large-folder sayakpaul/pd12m-full --repo-type=dataset /scratch/diffusion-datasets/pd12m --num-workers=16" ) os.system(f"rm -rf /scratch/diffusion-datasets/pd12m/*.tar") print("All shards have been downloaded successfully.")if name == "main": create_repo(repo_id="sayakpaul/pd12m-full", repo_type="dataset", private=True, exist_ok=True) main()




