-
Notifications
You must be signed in to change notification settings - Fork 787
Description
WANDB_API_KEY=f3cea53b10ea690ec762a06376b1c8e8c9fa4cc0
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5
NPROC_PER_NODE=6
swift rlhf
--rlhf_type grpo
--model Qwen2.5-VL-7B-Instruct
--external_plugins examples/train/grpo/plugin/plugin.py
--reward_funcs format
--use_vllm true
--vllm_mode server
--vllm_server_host 127.0.0.1
--vllm_server_port 8000
--train_type lora
--torch_dtype bfloat16
--dataset train_data.json
--max_completion_length 1024
--num_train_epochs 1
--per_device_train_batch_size 8
--per_device_eval_batch_size 8
--learning_rate 1e-6
--gradient_accumulation_steps 2
--save_strategy 'steps'
--eval_strategy 'steps'
--eval_steps 1000
--save_steps 1000
--save_total_limit 10
--logging_steps 1
--output_dir output/GRPO
--warmup_ratio 0.01
--dataloader_num_workers 4
--num_generations 24
--temperature 1.0
--system 'examples/train/grpo/doctor_prompt.txt'
--log_completions true
--report_to wandb
--num_iterations 1
--async_generate false
--deepspeed zero2
--beta 0.001
vllm 0.8.3 ms_swift 3.7.0.dev0

对数据进行了shuffle操作 有的数据能跑通 有的数据就卡住
尝试降低vllm版本又会报其他的错
