five

SALSA-CLRS

收藏
arXiv2023-11-20 更新2024-07-30 收录
下载链接:
https://github.com/jkminder/SALSA-CLRS
下载链接
链接失效反馈
官方服务:
资源简介:
SALSA-CLRS是一个针对算法学习基准的扩展,特别关注可扩展性和稀疏表示的利用。它包括了原始CLRS基准的适应算法以及来自分布式和随机算法的新问题。
创建时间:
2023-09-22
原始信息汇总

SALSA-CLRS 数据集

数据集加载

SALSA-CLRS 数据集支持以下算法:bfs, dfs, dijkstra, mst_prim, fast_mis, eccentricity。可以使用以下代码自动下载数据集:

python from salsaclrs import load_dataset

train_dataset = load_dataset(algorithm="bfs", split="train", local_dir="path/to/local/data/store") val_dataset = load_dataset(algorithm="bfs", split="val", local_dir="path/to/local/data/store") test_datasets = load_dataset(algorithm="bfs", split="val", local_dir="path/to/local/data/store") er_16 = test_datasets["er_16"]

返回的对象类型为 SALSACLRSDataset,是 PyG 数据集。通过 ds.specs 可以获取单个数据点的类型和规格。

数据集生成

可以生成符合自定义要求的新数据集。例如:

  • BFS 训练数据集,包含 10000 个样本,图生成器为 "er",节点数在 [16, 32] 之间,概率 p 在 (0.1, 0.3) 之间: python from salsaclrs import SALSACLRSDataset ds = SALSACLRSDataset(root=DATA_DIR, split="train", algorithm="bfs", num_samples=10000, graph_generator="er", graph_generator_kwargs={"n": [16, 32], "p_range": (0.1, 0.3)}, hints=True)

  • BFS 训练数据集,包含 10000 个样本,图生成器为 "ws",节点数在 [16, 32] 之间,k 在 [2, 4, 6] 之间,概率 p 在 (0.1, 0.3) 之间: python from salsaclrs import SALSACLRSDataset ds = SALSACLRSDataset(root=DATA_DIR, split="train", algorithm="bfs", num_samples=10000, graph_generator="ws", graph_generator_kwargs={"n": [16, 32], "k": [2, 4, 6], "p_range": (0.1, 0.3)}, hints=True)

  • MST 训练数据集,包含 10000 个样本,图生成器为 "delaunay",节点数在 [16, 32] 之间: python from salsaclrs import SALSACLRSDataset ds = SALSACLRSDataset(root=DATA_DIR, split="train", algorithm="mst_prim", num_samples=10000, graph_generator="delaunay", graph_generator_kwargs={"n": [16, 32]}, hints=True)

通过设置 hints=False 可以生成但不加载提示信息,通过设置 ignore_all_hints=True 可以生成不包含任何提示信息的数据集。

数据加载器

需要使用提供的 SALSACLRSDataLoader 而不是默认的 PyG DataLoader,以确保批次正确合并。API 保持不变:

python from salsaclrs import SALSACLRSDataLoader dl = SALSACLRSDataLoader(ds, batch_size=32, num_workers=...)

Pytorch Lightning

库提供了与 SALSACLRSDataset 数据集配合使用的 Pytorch Lightning 数据模块,支持多个验证和测试数据集。例如:

python from salsaclrs import SALSACLRSDataset, SALSACLRSDataModule import lightning.pytorch as pl

ds_train = SALSACLRSDataset(root=DATA_DIR, split="train", algorithm="bfs", num_samples=10000, graph_generator="er", ignore_all_hints=False, hints=True, graph_generator_kwargs={"n": [16, 32], "p": [0.1, 0.2, 0.3]}) ds_val = SALSACLRSDataset(root=DATA_DIR, split="val", algorithm="bfs", num_samples=100, graph_generator="er", ignore_all_hints=False, hints=True, graph_generator_kwargs={"n": [32], "p": [0.1, 0.2, 0.3]}) ds_test_small = SALSACLRSDataset(root=DATA_DIR, split="val", algorithm="bfs", num_samples=100, graph_generator="er", ignore_all_hints=False, hints=True, graph_generator_kwargs={"n": [32], "p": [0.1, 0.2, 0.3]}) ds_test_large = SALSACLRSDataset(root=DATA_DIR, split="val", algorithm="bfs", num_samples=100, graph_generator="er", ignore_all_hints=False, hints=True, graph_generator_kwargs={"n": [128], "p": [0.1, 0.2, 0.3]})

data_module = SALSACLRSDataModule(train_dataset=ds_train, val_datasets=[ds_val], test_datasets=[ds_test_small, ds_test_large]) trainer = pl.Trainer(...) trainer.fit(model, data_module)

5,000+
优质数据集
54 个
任务类型
进入经典数据集
二维码
社区交流群

面向社区/商业的数据集话题

二维码
科研交流群

面向高校/科研机构的开源数据集话题

数据驱动未来

携手共赢发展

商业合作