
Abnormal learning rate updates in the PPO stage #424

Closed

hannlp opened this issue Aug 9, 2023 · 15 comments
Labels
solved This problem has been already solved

Comments

@hannlp
Contributor

hannlp commented Aug 9, 2023

In stage 3 (PPO) I use --lr_scheduler_type cosine, and the learning rate changes during training as follows:

[screenshot: learning rate curve oscillating up and down over training steps]

However, the default num_cycles is 0.5, so the curve should not oscillate. The figure below shows the curve I would consider normal:

[screenshot: smoothly decaying cosine learning rate curve]

Where might this bug come from?

Some hyperparameters:
--per_device_train_batch_size 1
--gradient_accumulation_steps 16
export CUDA_VISIBLE_DEVICES=0,1,2,3
accelerate launch

@hiyouga
Owner

hiyouga commented Aug 10, 2023

My single-GPU test looks normal, so I suspect it is a multi-GPU issue.

--per_device_train_batch_size 1
--gradient_accumulation_steps 4

[screenshot: smoothly decaying learning rate curve from the single-GPU test]

@hannlp
Contributor Author

hannlp commented Aug 11, 2023

@hiyouga Did you launch the single-GPU test directly with python? I also tested on a single GPU, but launched it via accelerate (export CUDA_VISIBLE_DEVICES=0 accelerate launch), and the oscillation still appeared. I wonder whether accelerate is the cause.

@hiyouga
Owner

hiyouga commented Aug 11, 2023

OK, I will test it again.

@mmbwf
Contributor

mmbwf commented Sep 27, 2023

Hi @hannlp, there are several reasons why the learning rate can fluctuate. The main one is that AcceleratedScheduler defaults to step_with_optimizer=True and split_batches=False, so the lr_scheduler steps together with the optimizer and performs num_processes steps per training step.

def step(self, *args, **kwargs):
    if not self.step_with_optimizer:
        # No link between scheduler and optimizer -> just step
        self.scheduler.step(*args, **kwargs)
        return

    # Otherwise, first make sure the optimizer was stepped.
    if not self.gradient_state.sync_gradients:
        if self.gradient_state.adjust_scheduler:
            self.scheduler._step_count += 1
        return

    for opt in self.optimizers:
        if opt.step_was_skipped:
            return
    if self.split_batches:
        # Split batches -> the training dataloader batch size is not changed so one step per training step
        self.scheduler.step(*args, **kwargs)
    else:
        # Otherwise the training dataloader batch size was multiplied by `num_processes`, so we need to do
        # num_processes steps per training step
        num_processes = AcceleratorState().num_processes
        for _ in range(num_processes):
            # Special case when using OneCycle and `drop_last` was not used
            if hasattr(self.scheduler, "total_steps"):
                if self.scheduler._step_count <= self.scheduler.total_steps:
                    self.scheduler.step(*args, **kwargs)
            else:
                self.scheduler.step(*args, **kwargs)
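To see why the curve in the first screenshot bounces back up, here is a minimal sketch (my own simplified cosine-with-warmup lambda written for illustration, not code from this repo): once the scheduler's step count runs past num_training_steps, which is exactly what happens when it is stepped num_processes times per optimizer step, the cosine keeps going and the learning rate climbs again instead of staying near zero.

import math

# Simplified cosine-with-warmup lambda with num_cycles=0.5; illustrative only.
def cosine_lr(current_step, num_warmup_steps, num_training_steps, num_cycles=0.5):
    if current_step < num_warmup_steps:
        return current_step / max(1, num_warmup_steps)
    progress = (current_step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))

# Hypothetical numbers: the schedule was planned for 100 steps, but with 4 GPUs the
# scheduler is stepped 4 times per optimizer step and ends up taking 400 steps.
for step in (50, 100, 150, 200, 300, 400):
    print(step, round(cosine_lr(step, 0, 100), 3))
# 50 -> 0.5, 100 -> 0.0, 150 -> 0.5, 200 -> 1.0, ... the lr rises again, producing the waves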

Since trl==0.5.0, the whole PPO step logic has been refactored and several new batch-size concepts have been introduced.

for i in range(ppo_epochs):
    for j in range(batch_size // backward_batch_size):
        for k in range(backward_batch_size // mini_batch_size):
            with self.accelerator.accumulate(self.model):
                ...
                self.accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

...

lr_scheduler.step()

where backward_batch_size = mini_batch_size * gradient_accumulation_steps and batch_size = n * backward_batch_size.
So the total training steps calculated in this repository might not match the actual number of optimizer steps; they are only equal when batch_size == backward_batch_size and ppo_epochs == 1.

ppo_config = PPOConfig(
        model_name=model_args.model_name_or_path,
        learning_rate=training_args.learning_rate,
        mini_batch_size=training_args.per_device_train_batch_size,
        batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        ppo_epochs=1,
        max_grad_norm=training_args.max_grad_norm,
        seed=training_args.seed,
        optimize_cuda_cache=True,
    )

...

total_train_batch_size = (
        training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
    )
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
lr_scheduler = get_scheduler(
        training_args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
        num_training_steps=num_training_steps
    )
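As a back-of-the-envelope check (hypothetical numbers, not taken from this issue), you can compare the step count handed to get_scheduler with the number of scheduler steps that actually happen under the default step_with_optimizer=True, split_batches=False behaviour:

import math

# Hypothetical settings resembling the ones reported above.
per_device_train_batch_size = 1          # becomes mini_batch_size
gradient_accumulation_steps = 16
world_size = 4                           # number of GPUs
ppo_epochs = 1
dataset_len = 10_000

backward_batch_size = per_device_train_batch_size * gradient_accumulation_steps   # samples per optimizer.step()
batch_size = per_device_train_batch_size * gradient_accumulation_steps            # PPOConfig.batch_size as set above

total_train_batch_size = batch_size * world_size
planned_steps = math.ceil(dataset_len / total_train_batch_size)                   # what get_scheduler is told
optimizer_steps = planned_steps * ppo_epochs * (batch_size // backward_batch_size)
scheduler_steps = optimizer_steps * world_size   # one scheduler step per process per optimizer step

print(planned_steps, optimizer_steps, scheduler_steps)   # 157 157 628: the scheduler overshoots ~4x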

If you want to change batch_size or ppo_epochs, you can set step_scheduler_with_optimizer=False; then the lr_scheduler only steps when you call it explicitly.

ppo_config = PPOConfig(
        model_name=model_args.model_name_or_path,
        learning_rate=training_args.learning_rate,
        mini_batch_size=training_args.per_device_train_batch_size,
        batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        accelerator_kwargs={"step_scheduler_with_optimizer": False},
        ppo_epochs=4,
        max_grad_norm=training_args.max_grad_norm,
        seed=training_args.seed,
        optimize_cuda_cache=True,
    )
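With step_scheduler_with_optimizer=False, the scheduler only advances when it is called explicitly, for example once per PPO step. A minimal sketch of such a loop (make_experience is a hypothetical placeholder for the generation and reward code, not a function from trl or this repo):

for batch in dataloader:
    # rollout: generate responses and compute rewards (details omitted)
    queries, responses, rewards = make_experience(batch)
    stats = ppo_trainer.step(queries, responses, rewards)   # runs all optimizer steps internally
    lr_scheduler.step()                                     # exactly one scheduler step per PPO step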

@hiyouga hiyouga added solved This problem has been already solved and removed pending This problem is yet to be addressed labels Sep 27, 2023
@hiyouga
Owner

hiyouga commented Sep 27, 2023

@mmbwf Thanks for pointing it out! It has been fixed in 35fa947

@hannlp
Contributor Author

hannlp commented Nov 9, 2023

The original intention of gradient accumulation is to simulate a larger batch size with limited GPU memory: if a GPU holds per_device_train_batch_size samples at a time and gradients are accumulated for gradient_accumulation_steps steps, it simulates a batch size of per_device_train_batch_size * gradient_accumulation_steps. However, in the code, batch_size in PPOConfig is set to training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps, which means a single GPU directly holds that many samples. This doesn't seem to align with the original purpose of gradient accumulation. It feels like the more changes are made to TRL, the more problems arise. @mmbwf

@hiyouga
Owner

hiyouga commented Nov 9, 2023

@hannlp TRL uses mini_batch_size to indicate the batch size of each forward pass. Therefore, with gradient accumulation the total batch size is mini_batch_size * gradient_accumulation_steps [1]. The batch_size in PPOConfig only represents the number of examples processed in a single ppo_trainer.step() call.

@mmbwf
Contributor

mmbwf commented Nov 10, 2023

Hi, @hannlp, @hiyouga is right. backward_batch_size = mini_batch_size * gradient_accumulation_steps is the real number of samples optimized in an optimizer.step() call.

The following is a quote from the original author's explanation; I only made some adjustments to align with the current variable names. For further details, please refer to this PR huggingface/trl#546 (comment).

I guess this is more of a stylistic thing. We really have three levels of "batch sizes":

  1. batch_size is the amount of rollout data, e,g., 8 data points
  2. backward_batch_size is the amount of data that you are actually doing a zero_grad(), loss.backward(), and optimizer.step(), e.g., 4 data points
  3. mini_batch_size is used if the backward_batch_size is too large to fit in memory: we zero_grad(), partition the backward_batch, do multiple loss.backward() calls, and then optimizer.step().

Having the terminology like this makes it clear what the real backward_batch_size is, so that we do not confound it with gradient_accumulation_steps.

From the user's perspective, one set of hyperparameters should always reliably reproduce similar learning curves. With the new API design in this PR, batch_size=8, backward_batch_size=4 will reliably reproduce similar learning curves regardless of how many gradient_accumulation_steps there are, so when doing experiment management / analysis we can just group / filter by batch_size and backward_batch_size. The existing implementation, however, would require us to calculate the real size of the data we perform an optimizer.step() on.
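To make the three levels concrete, here is a small sketch using the numbers from the quote (batch_size=8, backward_batch_size=4) plus an assumed mini_batch_size=2; it only mirrors the loop structure shown earlier and is not code from trl or this repo:

batch_size = 8            # rollout data collected per ppo_trainer.step()
backward_batch_size = 4   # samples consumed by each optimizer.step()
mini_batch_size = 2       # samples per forward/backward pass
gradient_accumulation_steps = backward_batch_size // mini_batch_size   # 2

rollout = list(range(batch_size))                 # stand-in for the collected rollout data
for start in range(0, batch_size, backward_batch_size):
    backward_batch = rollout[start:start + backward_batch_size]
    # optimizer.zero_grad() would go here
    for mb_start in range(0, len(backward_batch), mini_batch_size):
        mini_batch = backward_batch[mb_start:mb_start + mini_batch_size]
        # forward pass + loss.backward() on mini_batch would go here
    # optimizer.step() would go here -> 2 optimizer steps per PPO step in this example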

@hannlp
Contributor Author

hannlp commented Nov 11, 2023

@mmbwf @hiyouga Thank you for your answers! But I have another question: in my environment (A100 40GB), why does per_device_train_batch_size=8 with gradient_accumulation_steps=2 not cause OOM, while per_device_train_batch_size=4 with gradient_accumulation_steps=16 does? Logically, shouldn't the number of samples on a GPU be determined by per_device_train_batch_size? This was also the cause of my previous question. Also, under the current circumstances, how can we simulate a larger batch size?

@mmbwf
Contributor

mmbwf commented Nov 13, 2023

Hi @hannlp, the OOM does not come from the PPO training stage, but from the stage where the model generates responses and rewards. In this inference stage the total batch size is training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps.

So the memory cannot hold such a large batch for inference. If you want to increase the batch size, you can split the large batch into mini-batches for inference and then merge the results.

The following is a rough implementation, just for reference:

# Get inputs
queries, responses, rewards = [], [], []
for mini_batch_start in range(0, self.config.batch_size, self.config.mini_batch_size):
    mini_batch_end = mini_batch_start + self.config.mini_batch_size
    mini_batch_queries, mini_batch_responses = self.get_inputs(batch[mini_batch_start:mini_batch_end])
    queries.extend(mini_batch_queries)
    responses.extend(mini_batch_responses)

    self.tokenizer.padding_side = "right" # change padding side for rewards inference.
    mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
    rewards.extend(mini_batch_rewards)
    self.tokenizer.padding_side = "left" # change padding side for responses inference.

@hannlp
Contributor Author

hannlp commented Nov 13, 2023

@mmbwf You're absolutely right, thank you so much for your patient response and code examples! Hope you have a blast in your life!

hiyouga added a commit that referenced this issue Nov 13, 2023
@hiyouga
Owner

hiyouga commented Nov 13, 2023

A more accurate version is in 87390ae @hannlp

@luyuntao92

@hiyouga Hi, with the updated code, the multi-GPU learning rate in the PPO stage now follows a full cosine period instead of a quarter period.

[screenshot: learning rate curve following a full cosine period]

Since ppo_epochs does not default to 1, shouldn't the total training_steps be multiplied by ppo_epochs?

@hiyouga
Owner

hiyouga commented Jan 4, 2024

@luyuntao92 In theory it cannot be changed that way, since it would break the single-GPU case. Please try accelerate rather than DeepSpeed.
