Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
Here `gather_idx < 2` represents `is_first_all2all`. During the first all2all, `uneven_head_all2all` will be called if either `num_heads % seq_world_size != 0` or `get_num_kv_heads() is None`. During the second all2all, it'll return return `uneven_head_all2all` if and only if `get_num_kv_heads() is None` which is always set during the first uneven all2all. This means that there will no longer be issue where `uneven_head_all2all ` is returned for the second all2all because of `num_heads % seq_world_size != 0`. Fixes: #6774 --------- Co-authored-by: Logan Adams <[email protected]>
- Loading branch information