QuyenAnhDE/Diseases_Symptoms
收藏SmallMedLM: Fine-Tuning GPT-2 for Medical Data
概述
该项目涉及在一个包含疾病和症状的数据集上微调GPT-2模型(distilgpt2)。目标是训练一个能够生成与医疗条件及其症状相关文本的语言模型。
该项目在Google Colab笔记本中实现,涵盖数据加载、预处理、模型训练和评估。最终模型被保存,并可用于根据输入查询生成医疗相关文本。
内容
设置
该项目所需的Python包包括:
torchtorchtexttransformerssentencepiecepandastqdmdatasets
可以使用以下命令安装这些包:
python !pip install torch torchtext transformers sentencepiece pandas tqdm datasets
数据准备
- 加载数据:使用的数据集是
QuyenAnhDE/Diseases_Symptoms,其中包含各种疾病及其症状的信息。 - 预处理数据:症状被格式化为逗号分隔的字符串,以便于处理。
- 创建数据集类:定义了一个自定义的
LanguageDataset类,用于以适合GPT-2训练的格式处理数据。
模型训练
- 模型选择:使用
distilgpt2模型,这是GPT-2的一个较小且更快的版本,用于微调。 - 训练循环:训练过程涉及使用CrossEntropyLoss函数和Adam优化器更新模型的权重。每个epoch都会记录训练和验证损失。
- 参数:
- 批量大小:8
- 学习率:5e-4
- 训练轮数:10
- 设备配置:模型训练可以在GPU、MPS或CPU上运行,具体取决于可用硬件。
生成预测
模型训练完成后,可以根据输入查询生成文本。例如,给定输入字符串“Kidney Failure”,模型会生成相关文本。
python input_str = "Kidney Failure" input_ids = tokenizer.encode(input_str, return_tensors=pt).to(device)
output = model.generate( input_ids, max_length=20, num_return_sequences=1, do_sample=True, top_k=8, top_p=0.95, temperature=0.5, repetition_penalty=1.2 )
decoded_output = tokenizer.decode(output[0], skip_special_tokens=True) print(decoded_output)
使用
- 克隆仓库:
bash git clone https://github.com/mshaadk/Fine-tuning-GPT2-Medical-Data.git
-
打开Colab笔记本:
- 将笔记本上传到Google Colab。
- 运行每个单元格以执行代码。
-
加载和使用模型:
- 使用保存的模型文件(
SmallMedLM.pt)进行预测或进一步训练。
- 使用保存的模型文件(
许可证
该项目根据MIT许可证授权。
联系
如有任何问题或建议,请随时联系Mohamed Shaad。




