dynamic vit gradient_checkpointing (#2071)
Jintao-Huang authored Sep 20, 2024
1 parent bea9867 commit 206c391
Showing 4 changed files with 65 additions and 7 deletions.
6 changes: 4 additions & 2 deletions swift/llm/sft.py
@@ -22,8 +22,8 @@
from .accelerator import ta_accelerate
from .tuner import prepare_model
from .utils import (TEMPLATE_MAPPING, LazyLLMDataset, PtArguments, RLHFArguments, SftArguments, Template, dataset_map,
-                    get_dataset, get_model_tokenizer, get_template, get_time_info, print_example, set_generation_config,
-                    sort_by_max_length, stat_dataset)
+                    dynamic_vit_gradient_checkpointing, get_dataset, get_model_tokenizer, get_template, get_time_info,
+                    print_example, set_generation_config, sort_by_max_length, stat_dataset)

logger = get_logger()

@@ -239,6 +239,8 @@ def prepare_model_template_train(args, msg: Optional[Dict[str, Any]] = None):
    model.label_names = label_names
    model.return_loss = return_loss

+    if args.is_multimodal and args.gradient_checkpointing and args.vit_use_gc:
+        dynamic_vit_gradient_checkpointing(model, args.model_type)
    # Preparing LoRA
    model, callbacks = prepare_model(model, args)

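The new call is gated on `args.is_multimodal`, `args.gradient_checkpointing` and the new `args.vit_use_gc` flag, and it runs just before LoRA preparation (`prepare_model`). A rough standalone sketch of the same call outside `sft.py` (the `'qwen-vl-chat'` model type is only an example, and this usage is not part of the commit):

# Illustrative sketch only -- not part of this commit.
from swift.llm import get_model_tokenizer
from swift.llm.utils import dynamic_vit_gradient_checkpointing

model_type = 'qwen-vl-chat'  # example multimodal model_type; any suitable entry in MODEL_MAPPING could be used
model, tokenizer = get_model_tokenizer(model_type)

model.gradient_checkpointing_enable()                  # standard HF checkpointing for the LLM backbone
dynamic_vit_gradient_checkpointing(model, model_type)  # this commit: also checkpoint the ViT blocks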
11 changes: 6 additions & 5 deletions swift/llm/utils/__init__.py
@@ -22,11 +22,12 @@
                       ModelList, UsageInfo, XRequestConfig, random_uuid)
from .template import (DEFAULT_SYSTEM, TEMPLATE_MAPPING, History, KTOTemplateMixin, Prompt, RLHFTemplateMixin,
                       StopWords, Template, TemplateType, get_env_args, get_template, register_template)
-from .utils import (LazyLLMDataset, LLMDataset, dataset_map, download_dataset, find_all_linears, find_embedding,
-                    find_ln, get_max_model_len, get_time_info, history_to_messages, inference, inference_stream,
-                    is_lmdeploy_available, is_megatron_available, is_quant_model, is_vllm_available,
-                    limit_history_length, messages_join_observation, messages_to_history, print_example,
-                    safe_tokenizer_decode, set_generation_config, sort_by_max_length, stat_dataset, to_device)
+from .utils import (LazyLLMDataset, LLMDataset, dataset_map, download_dataset, dynamic_vit_gradient_checkpointing,
+                    find_all_linears, find_embedding, find_ln, get_max_model_len, get_time_info, history_to_messages,
+                    inference, inference_stream, is_lmdeploy_available, is_megatron_available, is_quant_model,
+                    is_vllm_available, limit_history_length, messages_join_observation, messages_to_history,
+                    print_example, safe_tokenizer_decode, set_generation_config, sort_by_max_length, stat_dataset,
+                    to_device)

logger = get_logger()

1 change: 1 addition & 0 deletions swift/llm/utils/argument.py
@@ -780,6 +780,7 @@ class SftArguments(ArgumentsBase):
    use_liger: bool = False

    gradient_checkpointing: Optional[bool] = None
+    vit_use_gc: bool = True  # whether the ViT part also uses gradient checkpointing
    # e.g. 'default-zero3', 'default-zero2', 'ds_config/zero2.json', 'zero2-offload', 'zero3-offload'
    deepspeed: Optional[str] = None
    batch_size: int = 1
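`vit_use_gc` defaults to `True`, so ViT checkpointing is applied automatically whenever gradient checkpointing is enabled for a multimodal model; set it to `False` to restrict checkpointing to the LLM backbone. A hedged sketch of programmatic use (the model type and dataset name are examples for illustration, not prescribed by this commit):

# Illustrative sketch only -- example names, not part of this commit.
from swift.llm import SftArguments, sft_main

args = SftArguments(
    model_type='qwen-vl-chat',    # example multimodal model type
    dataset=['coco-en-mini'],     # example dataset; substitute your own
    sft_type='lora',
    gradient_checkpointing=True,  # checkpoint the LLM backbone
    vit_use_gc=False,             # new flag: skip patching the ViT blocks
)
result = sft_main(args)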
54 changes: 54 additions & 0 deletions swift/llm/utils/utils.py
@@ -10,6 +10,7 @@
from queue import Empty, Queue
from tempfile import TemporaryDirectory
from threading import Thread
+from types import MethodType
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Set, Tuple, Union

import accelerate
@@ -18,6 +19,8 @@
import requests
import torch
import torch.distributed as dist
+import torch.nn as nn
+import torch.utils.checkpoint
import transformers
from datasets import Dataset as HfDataset
from datasets import IterableDataset as HfIterableDataset
@@ -421,6 +424,57 @@ def find_ln(model: Module) -> List[str]:
    return list(module_names)


+def _find_module_list(vision_tower) -> Optional[nn.ModuleList]:
+    # Heuristic: locate the largest nn.ModuleList (the transformer blocks) inside the vision tower.
+    module_lists = []
+    for m in vision_tower.modules():
+        if hasattr(m, 'gradient_checkpointing'):
+            # The vision tower already supports gradient checkpointing natively; nothing to patch.
+            return
+        if isinstance(m, nn.ModuleList) and len(m) >= 10:
+            module_lists.append(m)
+    if module_lists:
+        return max(module_lists, key=lambda x: len(x))
+
+
+def _add_gradient_checkpointing(module_list):
+
+    def _new_forward(self, *args, **kwargs):
+        layer_ret = torch.utils.checkpoint.checkpoint(self.__old_forward, *args, **kwargs)
+        return layer_ret
+
+    for module in module_list:
+        if hasattr(module, '_old_forward'):  # forward has been wrapped by accelerate's device_map hooks
+            __old_forward = module._old_forward
+            module._old_forward = MethodType(_new_forward, module)
+        else:
+            __old_forward = module.forward
+            module.forward = MethodType(_new_forward, module)
+        module.__old_forward = __old_forward
+
+
+def deep_getattr(model, attr: str):
+    attrs = attr.split('.')
+    for a in attrs:
+        model = getattr(model, a)
+    return model
+
+
+def dynamic_vit_gradient_checkpointing(model, model_type: str) -> None:
+    # Enable gradient checkpointing on the vision tower(s) of a multimodal model by patching
+    # the forward of each block in their transformer ModuleList.
+    from swift.utils.module_mapping import MODEL_KEYS_MAPPING
+    from .model import MODEL_MAPPING
+    model_info = MODEL_MAPPING[model_type]
+    lora_target_modules = model_info.get('lora_target_modules')
+
+    if not isinstance(lora_target_modules, str):
+        return
+    vision_tower_list = MODEL_KEYS_MAPPING[lora_target_modules].vision_tower
+    for vision_tower_name in vision_tower_list:
+        vision_tower = deep_getattr(model, vision_tower_name)
+        module_list = _find_module_list(vision_tower)
+        if module_list is None:
+            continue
+        _add_gradient_checkpointing(module_list)
+
+
def find_embedding(model: Module) -> List[str]:
    return _find_layers(model, torch.nn.Embedding)

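The core technique in `_add_gradient_checkpointing` is model-agnostic: find the transformer-block `nn.ModuleList` inside the vision tower and rebind each block's `forward` so its activations are recomputed through `torch.utils.checkpoint` during backward. A minimal self-contained sketch of the same idea on a toy module (everything below is illustrative and does not depend on swift):

import torch
import torch.nn as nn
import torch.utils.checkpoint
from types import MethodType


class ToyBlock(nn.Module):

    def __init__(self, dim: int = 32):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.linear(x))


def add_gradient_checkpointing(module_list: nn.ModuleList) -> None:
    # Rebind each block's forward so its activations are recomputed in backward
    # instead of being stored, mirroring what the commit does for ViT blocks.

    def _new_forward(self, *args, **kwargs):
        return torch.utils.checkpoint.checkpoint(self._original_forward, *args, **kwargs)

    for module in module_list:
        module._original_forward = module.forward
        module.forward = MethodType(_new_forward, module)


blocks = nn.ModuleList([ToyBlock() for _ in range(12)])
add_gradient_checkpointing(blocks)

x = torch.randn(4, 32, requires_grad=True)
for block in blocks:
    x = block(x)
x.sum().backward()  # per-block activations are recomputed here instead of being kept in memory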
