SurgicAI
A Fine-grained Platform for Data Collection and Benchmarking in Surgical Policy Learning
System Requirements
- Ubuntu 20.04
- Gymnasium 0.29.1
- Stable Baselines3 2.2.1
- ROS Noetic
- Python 3.8
- Torch 2.1.0
- ambf 2.0
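Once the Python packages are installed, their versions can be compared against this list from a script. The sketch below is a hypothetical helper, not part of SurgicAI. It covers only the pip-installable entries; ROS Noetic and ambf are not pip packages and must be checked separately.

```python
"""Compare installed package versions with the SurgicAI requirement list.

A minimal sketch: the pinned versions are copied from the requirements list;
packages that are not installed are reported as missing.
"""
import sys
from importlib import metadata  # available since Python 3.8

# pip distribution names and the versions listed in the requirements.
REQUIRED = {
    "gymnasium": "0.29.1",
    "stable-baselines3": "2.2.1",
    "torch": "2.1.0",
}


def check_requirements(required=REQUIRED):
    """Return {name: (expected, installed_or_None, ok)} for each package."""
    report = {}
    for name, expected in required.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        # Torch wheels may carry a local suffix such as "2.1.0+cu118"; ignore it.
        ok = installed is not None and installed.split("+")[0] == expected
        report[name] = (expected, installed, ok)
    return report


def python_ok():
    """The requirements list targets Python 3.8."""
    return sys.version_info[:2] == (3, 8)


if __name__ == "__main__":
    print(f"Python {sys.version.split()[0]} (3.8 expected): {'ok' if python_ok() else 'mismatch'}")
    for name, (expected, installed, ok) in check_requirements().items():
        status = "ok" if ok else "MISMATCH"
        print(f"{name:20s} expected {expected:8s} installed {installed or 'missing':10s} {status}")
```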
Installation
- Install Gymnasium:
```bash
pip install gymnasium
```
- Configure PyTorch and CUDA (if you have an NVIDIA GPU).
- Install Stable Baselines3:
```bash
pip install stable-baselines3
```
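If you configured CUDA, a quick check confirms that PyTorch can actually see the GPU. This is a minimal sketch: `cuda_status` is a hypothetical helper name, while the calls it makes (`torch.cuda.is_available`, `torch.cuda.get_device_name`, `torch.version.cuda`) are standard PyTorch APIs. It reports instead of failing when torch is not installed yet.

```python
"""Report whether PyTorch is installed and can see a CUDA device."""
import importlib.util


def cuda_status():
    """Return a short human-readable status string."""
    if importlib.util.find_spec("torch") is None:
        return "torch not installed"
    import torch  # imported lazily so the check also runs without torch

    if torch.cuda.is_available():
        return (
            f"torch {torch.__version__}, CUDA {torch.version.cuda}, "
            f"device: {torch.cuda.get_device_name(0)}"
        )
    return f"torch {torch.__version__}, CUDA not available (CPU only)"


if __name__ == "__main__":
    print(cuda_status())
```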
Verify the Installation
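Run the following script to verify the installation:
```python
import gymnasium
import stable_baselines3

# Briefly train PPO on a standard Gymnasium task to confirm both packages work together.
env = gymnasium.make("CartPole-v1")
model = stable_baselines3.PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000)
```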
RL Training
Run the SRC (Surgical Robotics Challenge) Environment
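Make sure ROS and the SRC are running before moving on. First start the ROS master:
```bash
roscore
```
Then, in a second terminal, launch the AMBF simulator with the SRC scene:
```bash
~/ambf/bin/lin-x86_64/ambf_simulator --launch_file ~/ambf/surgical_robotics_challenge/launch.yaml -l 0,1,3,4,13,14 -p 200 -t 1 --override_max_comm_freq 120
```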
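Since `roscore` stays in the foreground, a small Python launcher is one way to start both processes from a single place. This is a sketch under assumptions: AMBF lives in `~/ambf` as in the command above, the simulator flags are forwarded unchanged (the script does not interpret them), and the three-second wait for the ROS master is an arbitrary choice. `simulator_argv` and `launch` are hypothetical helpers, not part of SurgicAI or AMBF.

```python
"""Start roscore and the AMBF simulator from one script (hypothetical helper)."""
import os
import subprocess
import time

AMBF_DIR = os.path.expanduser("~/ambf")  # assumption: same layout as the command above


def simulator_argv(ambf_dir=AMBF_DIR):
    """Build the simulator command from the README as an argument list."""
    return [
        os.path.join(ambf_dir, "bin", "lin-x86_64", "ambf_simulator"),
        "--launch_file", os.path.join(ambf_dir, "surgical_robotics_challenge", "launch.yaml"),
        "-l", "0,1,3,4,13,14",
        "-p", "200",
        "-t", "1",
        "--override_max_comm_freq", "120",
    ]


def launch(ambf_dir=AMBF_DIR, ros_wait_s=3.0):
    """Start roscore, wait briefly, then start the simulator; return both handles."""
    ros = subprocess.Popen(["roscore"])
    time.sleep(ros_wait_s)  # arbitrary delay so the ROS master is up first
    sim = subprocess.Popen(simulator_argv(ambf_dir))
    return ros, sim
```

`launch()` returns the two `Popen` handles; call `sim.wait()` to block until the simulator exits, and `terminate()` on both when training is finished.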
Register the Gymnasium Environment
```python
import gymnasium as gym
import numpy as np
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.utils import set_random_seed

from Approach_env import SRC_approach  # environment class from the repo's Approach_env module
from RL_algo.PPO import PPO  # PPO from the repo's RL_algo package (not stable_baselines3.PPO)

# Register the SRC environment under an ID, then create an instance of it.
gym.envs.register(id="TD3_HER_BC", entry_point=SRC_approach)
env = gym.make("TD3_HER_BC", render_mode="human", reward_type="sparse")
```
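Conceptually, `gym.envs.register` stores the entry point under the given ID, and `gym.make` later calls it with the keyword arguments you pass (here `render_mode` and `reward_type`), so `SRC_approach` has to accept them. The toy registry below models only that lookup-and-call step; Gymnasium's real `make` also merges kwargs stored in the environment spec and wraps the result (for example with an order-enforcing wrapper and a passive environment checker). `DummyApproachEnv` is a hypothetical stand-in for `SRC_approach`.

```python
"""A simplified model of register/make (illustration only, not Gymnasium's code)."""
import warnings

_REGISTRY = {}


def register(id, entry_point):
    """Map an environment id to the callable that constructs it."""
    if id in _REGISTRY:
        # This toy version warns and overwrites an existing registration.
        warnings.warn(f"overriding environment {id!r} already in registry")
    _REGISTRY[id] = entry_point


def make(id, **kwargs):
    """Look up the id and call its entry point with the given keyword arguments."""
    if id not in _REGISTRY:
        raise KeyError(f"no environment registered under {id!r}")
    return _REGISTRY[id](**kwargs)


class DummyApproachEnv:
    """Hypothetical stand-in for SRC_approach that only stores its constructor arguments."""

    def __init__(self, render_mode=None, reward_type=None):
        self.render_mode = render_mode
        self.reward_type = reward_type


register(id="TD3_HER_BC", entry_point=DummyApproachEnv)
env = make("TD3_HER_BC", render_mode="human", reward_type="sparse")
print(env.render_mode, env.reward_type)  # human sparse
```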
Initialize and Train the Model
Train with the Proximal Policy Optimization (PPO) algorithm:
```python
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="./First_version/")

# Save a checkpoint every 10,000 steps as ./First_version/Model_temp/SRC_<steps>_steps.zip
checkpoint_callback = CheckpointCallback(
    save_freq=10000,
    save_path="./First_version/Model_temp",
    name_prefix="SRC",
)
model.learn(total_timesteps=1_000_000, progress_bar=True, callback=checkpoint_callback)
model.save("SRC")  # the final model is written to SRC.zip
```
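`CheckpointCallback` writes one file per save, named `<name_prefix>_<num_timesteps>_steps.zip` inside `save_path`. The helper below (a hypothetical name) lists the files the run above should produce. It assumes a single, non-vectorized environment, so that the callback's call counter advances once per environment step, and it ignores any small overshoot past `total_timesteps` caused by PPO collecting fixed-length rollouts.

```python
"""Predict the checkpoint files written by the training run (single-env assumption)."""
import os


def expected_checkpoints(total_timesteps, save_freq, save_path, name_prefix):
    """List checkpoint paths written every `save_freq` steps up to `total_timesteps`."""
    return [
        os.path.join(save_path, f"{name_prefix}_{step}_steps.zip")
        for step in range(save_freq, total_timesteps + 1, save_freq)
    ]


paths = expected_checkpoints(1_000_000, 10_000, "./First_version/Model_temp", "SRC")
print(len(paths))  # 100
print(paths[0])    # ./First_version/Model_temp/SRC_10000_steps.zip
print(paths[-1])   # ./First_version/Model_temp/SRC_1000000_steps.zip
```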
Load the Model
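```python
# PPO.load is called on the class and builds a new model from the saved file,
# so there is no need to construct a PPO instance first.
model_path = "./Model/SRC_10000_steps.zip"  # point this at a saved checkpoint
model = PPO.load(model_path)
model.set_env(env=env)
```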
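Rather than editing `model_path` by hand after every run, the newest checkpoint can be located by parsing the step count from its filename. This is a hypothetical helper (not part of SurgicAI) that assumes the `SRC_<steps>_steps.zip` naming produced by the checkpoint callback in the training step.

```python
"""Find the most recent checkpoint saved by CheckpointCallback (hypothetical helper)."""
import re
from pathlib import Path


def checkpoint_step(filename, name_prefix="SRC"):
    """Return the step count encoded in a checkpoint filename, or None if it doesn't match."""
    match = re.fullmatch(rf"{re.escape(name_prefix)}_(\d+)_steps\.zip", filename)
    return int(match.group(1)) if match else None


def latest_checkpoint(directory, name_prefix="SRC"):
    """Return the Path with the largest step count in `directory`, or None."""
    best = None  # (step, path) of the newest checkpoint seen so far
    for path in Path(directory).glob("*.zip"):
        step = checkpoint_step(path.name, name_prefix)
        if step is not None and (best is None or step > best[0]):
            best = (step, path)
    return None if best is None else best[1]
```

For example, `model = PPO.load(str(latest_checkpoint("./First_version/Model_temp")))` loads the newest checkpoint from the training directory used above.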
Test Model Predictions
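```python
obs, info = env.reset()
print(obs)
for i in range(10000):
    # Take the policy's deterministic action instead of sampling from it.
    action, _state = model.predict(obs, deterministic=True)
    print(action)
    obs, reward, terminated, truncated, info = env.step(action)
    print(info)
    env.render()
    # Start a new episode once the current one terminates or is truncated.
    if terminated or truncated:
        obs, info = env.reset()
```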
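The loop above prints every step; for evaluation it is often more useful to collect per-episode statistics. The sketch below relies only on the Gymnasium signatures used above (`reset()` returns `(obs, info)`, `step()` returns `(obs, reward, terminated, truncated, info)`). `run_episodes`, `ToyEnv` and the constant policies are hypothetical, so the example runs without the simulator.

```python
"""Summarize evaluation episodes for any Gymnasium-style env and policy callable."""


def run_episodes(env, policy, n_episodes):
    """Return one (episode_return, episode_length, terminated) tuple per episode."""
    results = []
    for _ in range(n_episodes):
        obs, info = env.reset()
        episode_return, length = 0.0, 0
        while True:  # assumes the env eventually terminates or truncates
            obs, reward, terminated, truncated, info = env.step(policy(obs))
            episode_return += reward
            length += 1
            if terminated or truncated:
                results.append((episode_return, length, terminated))
                break
    return results


class ToyEnv:
    """Hypothetical 1-D walk: reaching position 3 terminates; 5 steps truncate."""

    def reset(self, seed=None):
        self.pos, self.t = 0, 0
        return self.pos, {}

    def step(self, action):
        self.pos += action
        self.t += 1
        terminated = self.pos >= 3
        truncated = self.t >= 5 and not terminated
        reward = 1.0 if terminated else 0.0  # sparse reward on success only
        return self.pos, reward, terminated, truncated, {}


results = run_episodes(ToyEnv(), policy=lambda obs: 1, n_episodes=2)
print(results)  # [(1.0, 3, True), (1.0, 3, True)]
```

With the trained model, pass the registered `env` and `policy=lambda obs: model.predict(obs, deterministic=True)[0]`. Stable Baselines3's `evaluate_policy` (imported in the registration step) provides a similar summary, returning the mean and standard deviation of episode rewards.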




