SALSA-CLRS
收藏SALSA-CLRS 数据集
数据集加载
SALSA-CLRS 数据集支持以下算法:bfs, dfs, dijkstra, mst_prim, fast_mis, eccentricity。可以使用以下代码自动下载数据集:
python from salsaclrs import load_dataset
train_dataset = load_dataset(algorithm="bfs", split="train", local_dir="path/to/local/data/store") val_dataset = load_dataset(algorithm="bfs", split="val", local_dir="path/to/local/data/store") test_datasets = load_dataset(algorithm="bfs", split="val", local_dir="path/to/local/data/store") er_16 = test_datasets["er_16"]
返回的对象类型为 SALSACLRSDataset,是 PyG 数据集。通过 ds.specs 可以获取单个数据点的类型和规格。
数据集生成
可以生成符合自定义要求的新数据集。例如:
-
BFS 训练数据集,包含 10000 个样本,图生成器为 "er",节点数在 [16, 32] 之间,概率 p 在 (0.1, 0.3) 之间: python from salsaclrs import SALSACLRSDataset ds = SALSACLRSDataset(root=DATA_DIR, split="train", algorithm="bfs", num_samples=10000, graph_generator="er", graph_generator_kwargs={"n": [16, 32], "p_range": (0.1, 0.3)}, hints=True)
-
BFS 训练数据集,包含 10000 个样本,图生成器为 "ws",节点数在 [16, 32] 之间,k 在 [2, 4, 6] 之间,概率 p 在 (0.1, 0.3) 之间: python from salsaclrs import SALSACLRSDataset ds = SALSACLRSDataset(root=DATA_DIR, split="train", algorithm="bfs", num_samples=10000, graph_generator="ws", graph_generator_kwargs={"n": [16, 32], "k": [2, 4, 6], "p_range": (0.1, 0.3)}, hints=True)
-
MST 训练数据集,包含 10000 个样本,图生成器为 "delaunay",节点数在 [16, 32] 之间: python from salsaclrs import SALSACLRSDataset ds = SALSACLRSDataset(root=DATA_DIR, split="train", algorithm="mst_prim", num_samples=10000, graph_generator="delaunay", graph_generator_kwargs={"n": [16, 32]}, hints=True)
通过设置 hints=False 可以生成但不加载提示信息,通过设置 ignore_all_hints=True 可以生成不包含任何提示信息的数据集。
数据加载器
需要使用提供的 SALSACLRSDataLoader 而不是默认的 PyG DataLoader,以确保批次正确合并。API 保持不变:
python from salsaclrs import SALSACLRSDataLoader dl = SALSACLRSDataLoader(ds, batch_size=32, num_workers=...)
Pytorch Lightning
库提供了与 SALSACLRSDataset 数据集配合使用的 Pytorch Lightning 数据模块,支持多个验证和测试数据集。例如:
python from salsaclrs import SALSACLRSDataset, SALSACLRSDataModule import lightning.pytorch as pl
ds_train = SALSACLRSDataset(root=DATA_DIR, split="train", algorithm="bfs", num_samples=10000, graph_generator="er", ignore_all_hints=False, hints=True, graph_generator_kwargs={"n": [16, 32], "p": [0.1, 0.2, 0.3]}) ds_val = SALSACLRSDataset(root=DATA_DIR, split="val", algorithm="bfs", num_samples=100, graph_generator="er", ignore_all_hints=False, hints=True, graph_generator_kwargs={"n": [32], "p": [0.1, 0.2, 0.3]}) ds_test_small = SALSACLRSDataset(root=DATA_DIR, split="val", algorithm="bfs", num_samples=100, graph_generator="er", ignore_all_hints=False, hints=True, graph_generator_kwargs={"n": [32], "p": [0.1, 0.2, 0.3]}) ds_test_large = SALSACLRSDataset(root=DATA_DIR, split="val", algorithm="bfs", num_samples=100, graph_generator="er", ignore_all_hints=False, hints=True, graph_generator_kwargs={"n": [128], "p": [0.1, 0.2, 0.3]})
data_module = SALSACLRSDataModule(train_dataset=ds_train, val_datasets=[ds_val], test_datasets=[ds_test_small, ds_test_large]) trainer = pl.Trainer(...) trainer.fit(model, data_module)



