
KeyError: 'prompt' when running DPO on an mLLM with a custom dataset #1922

Closed
SparrowZheyuan18 opened this issue Sep 3, 2024 · 2 comments

Comments

@SparrowZheyuan18

Describe the bug
What the bug is, and how to reproduce, better with screenshots
When running DPO on minicpm-Llama3-V-2_5 with a custom dataset, the following error is raised:
File "/workspace/miniconda3/envs/env_name/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3427, in apply_function_on_filtered_inputs
processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
File "/workspace/miniconda3/envs/env_name/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 85, in _tokenize
prompt = features["prompt"]
File "/workspace/miniconda3/envs/env_name/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 277, in getitem
value = self.data[key]
KeyError: 'prompt'
The dataset strictly follows the documented custom-dataset format, e.g. {"query": "11111", "response": "22222", "rejected_response": "33333", "images": ["image_path"]}.
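
For reference, a minimal sketch of how one such record can be written to the file passed to --dataset below (field names and path are exactly those quoted above):

# Minimal sketch: write one swift-style DPO record to the JSONL file used below.
import json

sample = {
    "query": "11111",
    "response": "22222",
    "rejected_response": "33333",
    "images": ["image_path"],
}
with open("/DPO/data/dpo_data.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps(sample, ensure_ascii=False) + "\n")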

The training script is:
nproc_per_node=2

CUDA_VISIBLE_DEVICES=0,1 \
NPROC_PER_NODE=$nproc_per_node \
MASTER_PORT=29500 \
swift rlhf \
    --rlhf_type dpo \
    --model_type minicpm-v-v2_5-chat \
    --model_id_or_path /MiniCPM-V/merged_MiniCPM-Llama3-V-2_5 \
    --ref_model_type minicpm-v-v2_5-chat \
    --ref_model_id_or_path /MiniCPM-V/merged_MiniCPM-Llama3-V-2_5 \
    --model_revision master \
    --sft_type lora \
    --tuner_backend swift \
    --dtype AUTO \
    --output_dir output/minicpm_dpo \
    --dataset /DPO/data/dpo_data.jsonl \
    --num_train_epochs 4 \
    --max_length 1024 \
    --max_prompt_length 512 \
    --check_dataset_strategy none \
    --lora_rank 8 \
    --lora_alpha 32 \
    --lora_dropout 0.05 \
    --lora_target_modules ALL \
    --gradient_checkpointing true \
    --batch_size 1 \
    --weight_decay 0.1 \
    --learning_rate 5e-5 \
    --gradient_accumulation_steps $(expr 16 / $nproc_per_node) \
    --max_grad_norm 1.0 \
    --warmup_ratio 0.03 \
    --eval_steps 2000 \
    --save_steps 100 \
    --save_total_limit 2 \
    --logging_steps 10

While debugging, I found that features here holds the custom dataset's keys (query, response, and rejected_response), whereas trl's dpo_trainer expects the keys prompt, chosen, and rejected. However, renaming the keys on either side only produces new errors.
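
For anyone hitting the same mismatch, here is a minimal sketch of the rename I mean (a hypothetical workaround, not the actual fix: it maps the swift-style columns onto the names trl's dpo_trainer looks up, but as noted above it only shifts the error elsewhere):

# Hypothetical workaround sketch: rename swift-style columns to the
# prompt/chosen/rejected names that trl's dpo_trainer indexes.
from datasets import load_dataset

ds = load_dataset("json", data_files="/DPO/data/dpo_data.jsonl", split="train")
ds = ds.rename_columns({
    "query": "prompt",
    "response": "chosen",
    "rejected_response": "rejected",
})
# The multimodal "images" column is left untouched, which is presumably
# why this rename alone does not make the run succeed.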

Your hardware and system info
Write your system info like CUDA version/system/GPU/torch version here
The training environment is two RTX 6000 Ada GPUs.

Additional context
Add any other context about the problem here

@hjh0119
Collaborator

hjh0119 commented Sep 3, 2024

Please pip install trl==0.9.6; a related fix is in progress: #1885
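
A quick way to confirm the pin took effect after installing (a minimal check, assuming trl exposes __version__ as usual):

# Run after `pip install trl==0.9.6` to confirm the pinned version is active.
import trl
print(trl.__version__)  # expect: 0.9.6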

@SparrowZheyuan18
Author

solved! thanks a lot!
