train and infer scripts for emu3_gen #2610

Merged 2 commits on Dec 10, 2024
9 changes: 9 additions & 0 deletions examples/infer/pt/all_to_all.sh
@@ -0,0 +1,9 @@
# 18GiB
CUDA_VISIBLE_DEVICES=0 \
swift infer \
--model BAAI/Emu3-Gen \
--infer_backend pt \
--stream False \
--use_chat_template False \
--top_k 2048 \
--max_new_tokens 40960
26 changes: 26 additions & 0 deletions examples/train/all_to_all/train.sh
@@ -0,0 +1,26 @@
# 70 GiB * 2
nproc_per_node=2
NPROC_PER_NODE=$nproc_per_node \
CUDA_VISIBLE_DEVICES=0,2 \
max_position_embeddings=10240 \
image_area=518400 \
swift sft \
--model BAAI/Emu3-Gen \
--train_type lora \
--dataset swift/TextCaps#40 \
--loss_scale react \
--tools_prompt react_zh \
--torch_dtype bfloat16 \
--num_train_epochs 10 \
--per_device_train_batch_size 1 \
--learning_rate 1e-5 \
--gradient_accumulation_steps 4 \
--warmup_ratio 0.03 \
--eval_steps 500 \
--save_steps 500 \
--save_total_limit 2 \
--logging_steps 5 \
--max_length 1024 \
--max_steps 1000 \
--weight_decay 0.1 \
--gradient_checkpointing_kwargs '{"use_reentrant": false}'
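Note that `max_position_embeddings` and `image_area` in the script above are exported as environment variables in front of the `swift sft` command rather than passed as CLI flags, so they are presumably read from the process environment by the Emu3 template code. A minimal sketch of that pattern (the variable names come from the script; the default values below are placeholders, not ms-swift's actual defaults):

```python
import os

# Optional overrides exported by the launch script, e.g.:
#   max_position_embeddings=10240 image_area=518400 swift sft ...
# The fallback defaults here are illustrative placeholders only.
max_position_embeddings = int(os.environ.get('max_position_embeddings', 9216))
image_area = int(os.environ.get('image_area', 720 * 720))
print(max_position_embeddings, image_area)
```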
3 changes: 2 additions & 1 deletion swift/llm/template/template/emu3.py
@@ -45,7 +45,8 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
     prompt = inputs.to_history()['response']
     image = self.smart_resize(inputs.images[0].convert('RGB'))
     with torch.no_grad():
-        image = self.processor.image_processor(image, return_tensors='pt')['pixel_values'].cuda()
+        image = self.processor.image_processor(
+            image, return_tensors='pt')['pixel_values'].to(device=self.processor.vision_tokenizer.device)
         image_token_ids = self.processor.vision_tokenizer.encode(image).squeeze(0)
     encoded = self._process_prompt_train(prompt, image_token_ids)
 else:
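The change in emu3.py replaces a hard-coded `.cuda()` call with `.to(device=...)` keyed off the vision tokenizer's own device, so the input pixel values always land on whatever device the tokenizer actually occupies instead of the default CUDA device. A pure-Python sketch of why this matters (mock classes stand in for `torch.Tensor` and the vision tokenizer; all names here are illustrative, not from the PR):

```python
class MockTensor:
    """Stands in for torch.Tensor: records which device it lives on."""

    def __init__(self, device='cpu'):
        self.device = device

    def to(self, device):
        # Simplified torch semantics: return a copy on the target device.
        return MockTensor(device=device)


class MockVisionTokenizer:
    """Stands in for the Emu3 vision tokenizer module."""

    def __init__(self, device):
        self.device = device

    def encode(self, image):
        # A real module would fail if its input lived on another device.
        assert image.device == self.device, 'input must be co-located'
        return MockTensor(device=self.device)


def encode_image(image, tokenizer):
    # Old pattern: image = image.cuda() -- breaks on CPU-only machines and
    # targets the wrong GPU when the tokenizer sits on, say, cuda:1.
    # New pattern: move the input to wherever the tokenizer actually lives.
    image = image.to(device=tokenizer.device)
    return tokenizer.encode(image)


tokenizer = MockVisionTokenizer(device='cuda:1')
tokens = encode_image(MockTensor('cpu'), tokenizer)
print(tokens.device)  # cuda:1
```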