Describe the bug
What the bug is, and how to reproduce, better with screenshots
Running DPO on minicpm-Llama3-V-2_5 with a custom dataset raises the following error:
  File "/workspace/miniconda3/envs/env_name/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3427, in apply_function_on_filtered_inputs
    processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
  File "/workspace/miniconda3/envs/env_name/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 85, in _tokenize
    prompt = features["prompt"]
  File "/workspace/miniconda3/envs/env_name/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 277, in __getitem__
    value = self.data[key]
KeyError: 'prompt'
The dataset strictly follows the required custom dataset format, for example {"query": "11111", "response": "22222", "rejected_response": "33333", "images": ["image_path"]}.
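To rule out a malformed file, a quick check along these lines can confirm that every record carries the keys shown in the example above. This is only a sketch: the helper name `find_bad_records` is hypothetical, and treating `images` as optional is an assumption, not something the dataset spec in this report confirms.

```python
import json

# Keys taken from the example record in this report.
# Assumption: "images" is optional, so it is not checked here.
REQUIRED_KEYS = {"query", "response", "rejected_response"}


def find_bad_records(lines):
    """Return (line_number, missing_keys) for each record lacking a required key."""
    problems = []
    for number, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue  # skip blank lines in the JSONL file
        record = json.loads(line)
        missing = REQUIRED_KEYS - set(record)
        if missing:
            problems.append((number, sorted(missing)))
    return problems


sample = [
    '{"query": "11111", "response": "22222", "rejected_response": "33333", "images": ["image_path"]}',
    '{"query": "11111", "response": "22222"}',
]
print(find_bad_records(sample))
```

For a real file, the same function can be fed the open file handle, for example `find_bad_records(open("/DPO/data/dpo_data.jsonl", encoding="utf-8"))`.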
The training script is:
nproc_per_node=2
CUDA_VISIBLE_DEVICES=0,1 \
NPROC_PER_NODE=$nproc_per_node \
MASTER_PORT=29500 \
swift rlhf \
    --rlhf_type dpo \
    --model_type minicpm-v-v2_5-chat \
    --model_id_or_path /MiniCPM-V/merged_MiniCPM-Llama3-V-2_5 \
    --ref_model_type minicpm-v-v2_5-chat \
    --ref_model_id_or_path /MiniCPM-V/merged_MiniCPM-Llama3-V-2_5 \
    --model_revision master \
    --sft_type lora \
    --tuner_backend swift \
    --dtype AUTO \
    --output_dir output/minicpm_dpo \
    --dataset /DPO/data/dpo_data.jsonl \
    --num_train_epochs 4 \
    --max_length 1024 \
    --max_prompt_length 512 \
    --check_dataset_strategy none \
    --lora_rank 8 \
    --lora_alpha 32 \
    --lora_dropout 0.05 \
    --lora_target_modules ALL \
    --gradient_checkpointing true \
    --batch_size 1 \
    --weight_decay 0.1 \
    --learning_rate 5e-5 \
    --gradient_accumulation_steps $(expr 16 / $nproc_per_node) \
    --max_grad_norm 1.0 \
    --warmup_ratio 0.03 \
    --eval_steps 2000 \
    --save_steps 100 \
    --save_total_limit 2 \
    --logging_steps 10
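While debugging, I found that the keys of features at this point are query, response, and rejected_response (the keys of the custom dataset), whereas trl's dpo_trainer expects the keys prompt, chosen, and rejected. However, changing either side leads to new errors.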
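The mismatch described above can be illustrated with a small renaming step. This is a sketch of the key correspondence only, not a confirmed fix: as noted, renaming on either side alone produced new errors in this setup. The names `KEY_MAP` and `to_trl_keys` are hypothetical and do not come from swift or trl.

```python
# Correspondence between the custom-dataset keys and the keys that
# trl's dpo_trainer reads in _tokenize, as described in this report.
KEY_MAP = {
    "query": "prompt",
    "response": "chosen",
    "rejected_response": "rejected",
}


def to_trl_keys(record):
    """Rename custom-dataset keys to the trl names; other keys pass through unchanged."""
    return {KEY_MAP.get(key, key): value for key, value in record.items()}


row = {
    "query": "11111",
    "response": "22222",
    "rejected_response": "33333",
    "images": ["image_path"],
}
print(to_trl_keys(row))
```

With the example record, the result has the keys prompt, chosen, rejected, and images, so `features["prompt"]` would no longer raise a KeyError at that line, although later stages may still expect the original names.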
Your hardware and system info
Write your system info like CUDA version/system/GPU/torch version here
The training environment is two RTX 6000 Ada GPUs.
Additional context
Add any other context about the problem here