Geometric Shapes Dataset
收藏Geometric Shapes Dataset
概述
该数据集包含几何形状的图像,适用于形状分类或图像识别等机器学习任务。数据集中的图像包括多边形,边数可变,并带有随机文本。
功能
- 生成多边形图像,边数可定制
- 为每张图像添加随机文本
- 创建包含训练、验证和测试集的数据集
- 可选地将数据集推送到Hugging Face Hub
- 使用生成的数据集训练形状分类模型
- 在验证集和测试集上评估模型
- 可选地将训练好的模型推送到Hugging Face Hub
安装
-
克隆仓库:
git clone https://github.com/0-ma/geometric-shape-detector.git cd geometric-shape-detector
-
安装依赖:
pip install -r requirements.txt
使用
生成数据集
运行以下命令生成数据集:
python generate_geometric_shapes_dataset.py [OPTIONS]
数据集生成选项
--output_dir: 数据集输出目录--nb_samples: 生成的样本总数(默认:21000)--output_hub_model_name: 推送到Hugging Face Hub的仓库名称(可选)--output_hub_token: 推送到Hugging Face Hub的令牌(可选)
数据集生成示例
-
本地生成数据集:
python generate_geometric_shapes_dataset.py --output_dir ./my_dataset --nb_samples 5000
-
生成数据集并推送到Hugging Face Hub:
python generate_geometric_shapes_dataset.py --output_dir ./my_dataset --nb_samples 5000 --push_to_hub --hub_name my-username/my-dataset
训练模型
生成数据集后,使用以下命令训练形状分类模型:
python train_shape_detector.py [OPTIONS]
模型训练选项
--dataset_name: 使用的数据集名称(必填)--base_checkpoint: 基础模型检查点(默认:"WinKawaks/vit-tiny-patch16-224")--output_hub_model_name: HuggingFace Hub的输出模型名称(可选)--output_hub_token: HuggingFace Hub的令牌(可选)--num_epochs: 训练轮数(默认:10)--learning_rate: 学习率(默认:5e-5)--batch_size: 训练和评估的批量大小(默认:16)
模型训练示例
-
使用本地数据集训练模型:
python train_shape_detector.py --dataset_name ./my_dataset
-
训练模型并推送到Hugging Face Hub:
python train_shape_detector.py --dataset_name ./my_dataset --output_hub_model_name my-username/my-model --output_hub_token your_token_here
模型训练过程
训练脚本执行以下步骤:
- 加载指定数据集
- 准备图像处理器和模型
- 设置训练参数
- 初始化训练器
- 训练模型
- 在验证集上评估模型
- 如果存在测试集,在测试集上评估模型
- 可选地将训练好的模型推送到Hugging Face Hub
脚本会自动处理数据集分割(如果未提供验证集),确保在训练期间和训练后进行适当的评估。




