feat: support ring attention with arbitrary KV sequence lengths #13545
songh11 wants to merge 5 commits into huggingface:main from
Conversation
Thanks for your PR. Could you also provide some visual output examples and comparison with other backends we support (latency, memory overhead, etc.)?
@sayakpaul Thanks for the review! I just pushed two follow-up commits:
FLUX.1-dev benchmarks on 4 × RTX 4090 (48GB)
Visual outputs: I'll attach side-by-side generated images (same prompt + seed: single-GPU baseline vs. `ring_anything`).
Thanks for the updates -- they are very comprehensive. It seems like ulysses_anything is still the best but users who might prefer ring would want to use ring_anything for non-divisible shapes? Is that the motivation here?
Thanks for the review! You're spot on — the primary motivation is to provide a seamless path for users on a ring topology to handle non-divisible sequence lengths, similar to how ulysses_anything works. While ulysses/ulysses_anything is often the go-to, I believe ring_anything is a valuable addition: Ulysses-style sharding is constrained by the number of attention heads, whereas ring sharding splits along the sequence dimension, so setups that are head-constrained or already built around a ring topology still need an "anything" variant on the ring path.

So ulysses/ulysses_anything remains the recommended choice for most users, while ring_anything covers head-constrained or ring-first setups. The focus for now is correctness, with clear headroom for P2P overlap as a next step.
Your reasoning is spot on! Let's include them in the docs too.
@claude could you do a review here?
Claude finished @sayakpaul's task in 2m 29s

Review of PR #13545: Ring Anything Attention

Overall this is a well-structured PR that cleanly mirrors the existing patterns. The core algorithm (pad → all-gather → slice → online-softmax merge) is correct, and the config validation + hook routing are consistent with the existing `ulysses_anything` path.

Duplicate hook branches should be merged

```py
if self.parallel_config.ulysses_anything:
    return PartitionAnythingSharder.shard_anything(...)
elif self.parallel_config.ring_anything:
    return PartitionAnythingSharder.shard_anything(...)
```

Since the bodies are identical, these should be merged:

```py
if self.parallel_config.ulysses_anything or self.parallel_config.ring_anything:
    return PartitionAnythingSharder.shard_anything(...)
```

Same for the

Incomplete docstring comment for
@sayakpaul Thanks! Addressed the feedback by merging duplicate hooks, adding the exclusivity check, and polishing comments. Pushed in the latest commit.
sayakpaul left a comment
Thanks for the updates. I left some further comments.
```py
_parallel_config: "ParallelConfig" | None = None,
):
    # Ring attention for arbitrary sequence lengths.
    if attn_mask is not None:
```
Seems like a pretty big limitation no? This would make it incompatible with models like QwenImage, right?
Yeah, that's fair. For this PR I'm keeping ring_anything scoped to the attn_mask=None path, which covers the FLUX/Wan cases I tested. QwenImage masks should be supportable, but I'd prefer to add that in a follow-up with proper validation.
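For readers of the thread, the guard this scoping implies might look roughly like the following; this is a hypothetical sketch, not the PR's exact check or error message:

```py
# Hypothetical guard for the ring_anything path (illustrative only).
if attn_mask is not None:
    raise ValueError(
        "`ring_anything` currently supports only `attn_mask=None`; "
        "masked attention (e.g. QwenImage-style masks) is left for a follow-up."
    )
```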
```py
    pad_shape[1] = pad_len
    return torch.cat([t, torch.zeros(pad_shape, dtype=t.dtype, device=t.device)], dim=1)


key_padded = pad_to_s_max(key)
```
Would add a small explainer comment.
Added a comment here.
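For readers of the thread, the padding helper discussed here works roughly as follows. This is a minimal sketch: the `torch.cat`-based padding matches the quoted diff, while the function signature, tensor layout, and toy values are assumptions.

```py
import torch

def pad_to_s_max(t: torch.Tensor, s_max: int) -> torch.Tensor:
    # Pad the sequence dimension (dim=1) of a local K/V tensor up to the global
    # maximum sequence length so every rank contributes an equal-sized buffer to
    # the all-gather; the padded tail is sliced off again on the receiving side.
    pad_len = s_max - t.shape[1]
    if pad_len == 0:
        return t
    pad_shape = list(t.shape)
    pad_shape[1] = pad_len
    return torch.cat([t, torch.zeros(pad_shape, dtype=t.dtype, device=t.device)], dim=1)

# Toy example: a rank holding 1000 KV tokens when the global maximum is 1024.
key = torch.randn(1, 1000, 24, 128)  # [batch, seq, heads, head_dim] (assumed layout)
key_padded = pad_to_s_max(key, s_max=1024)
print(key_padded.shape)  # torch.Size([1, 1024, 24, 128])
```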
```py
for i in range(world_size):
    if i > 0:
        true_seq_len = all_kv_seq_lens[next_rank]
        kv = kv_buffer[next_rank]
        # Reshape to padded shape, then slice to true sequence length
        key = kv[:kv_padded_numel].reshape_as(key_padded)[:, :true_seq_len]
        value = kv[kv_padded_numel:].reshape_as(value_padded)[:, :true_seq_len]
        next_rank = (next_rank + 1) % world_size
    else:
        # i == 0: use local (unpadded) key/value
        key = key_padded[:, :kv_seq_len]
        value = value_padded[:, :kv_seq_len]

    out, lse = forward_op(
        ctx,
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        enable_gqa,
        True,
```
@claude can we use torch.where here for a better conditional flow graph?
Thanks for pointing this out. I tried removing the branch, but it regressed on PCIe since the local chunk has to be read back from the all-gather buffer. Keeping the local fast path for now, but happy to revisit if there’s a cleaner way to express it.
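For readers of the thread, the `(out, lse)` pairs produced per KV chunk are combined with the online-softmax (log-sum-exp) merge mentioned in the review above. A minimal sketch of that merge, with illustrative names and an assumed `[batch, seq_q, heads]` LSE layout:

```py
import torch

def merge_partials(outs, lses):
    # outs[i]: [B, S_q, H, D] partial attention output for KV chunk i
    # lses[i]: [B, S_q, H]    log-sum-exp of the attention scores for chunk i
    merged_out, merged_lse = outs[0], lses[0]
    for out, lse in zip(outs[1:], lses[1:]):
        new_lse = torch.logaddexp(merged_lse, lse)             # combined normalizer
        w_old = torch.exp(merged_lse - new_lse).unsqueeze(-1)  # rescale previous partial
        w_new = torch.exp(lse - new_lse).unsqueeze(-1)         # weight of the new partial
        merged_out = w_old * merged_out + w_new * out
        merged_lse = new_lse
    return merged_out, merged_lse
```

Because each chunk only needs its own output and LSE, the local fast path and the gathered chunks can be merged in any order without changing the result.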
> [!TIP]
> To avoid multiple forced CUDA sync caused by H2D and D2H transfers, please add the **gloo** backend in [`init_process_group`](https://docs.pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group).
>
> ```py
> import torch.distributed as dist
>
> dist.init_process_group(backend="cpu:gloo,cuda:nccl")
> ```
We shouldn't have ">" in here no?
The default [Ring Attention](https://huggingface.co/papers/2310.01889) requires the sequence length of hidden states to be evenly divisible across the ring degree. [Ring Anything Attention](https://github.com/huggingface/diffusers/pull/13545#issuecomment-4302195582) is a variant of Ring Attention that supports arbitrary (non-evenly divisible) sequence lengths. It pads each rank's local KV to the global maximum sequence length, all-gathers the padded KV buffer, and slices back to each rank's true length before running attention.
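For readers of the thread, the all-gather and slice steps described in this paragraph might look roughly like the sketch below. Variable names (`key_padded`, `kv_buffer`, `all_kv_seq_lens`, ...) follow the snippets quoted earlier in this review; the gathering of per-rank lengths and the flattening scheme are assumptions, not the exact implementation.

```py
import torch
import torch.distributed as dist

# Assume key_padded / value_padded were padded to the global max length (s_max),
# and all_kv_seq_lens holds every rank's true KV length (e.g. gathered once on
# CPU via the gloo backend, per the tip above).
world_size = dist.get_world_size()
kv_flat = torch.cat([key_padded.flatten(), value_padded.flatten()])
kv_buffer = torch.empty(world_size * kv_flat.numel(), dtype=kv_flat.dtype, device=kv_flat.device)
dist.all_gather_into_tensor(kv_buffer, kv_flat)
kv_buffer = kv_buffer.view(world_size, -1)

# Slice each rank's contribution back to its true (unpadded) sequence length
# before running attention against it.
kv_padded_numel = key_padded.numel()
for r in range(world_size):
    k_r = kv_buffer[r, :kv_padded_numel].reshape_as(key_padded)[:, :all_kv_seq_lens[r]]
    v_r = kv_buffer[r, kv_padded_numel:].reshape_as(value_padded)[:, :all_kv_seq_lens[r]]
```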
[`ContextParallelConfig`] supports Ring Anything Attention by specifying both `ring_degree` and `ring_anything`. Please note that Ring Anything Attention is not currently supported by Unified Attention. Pass the [`ContextParallelConfig`] with `ring_degree` set to bigger than 1 and `ring_anything=True` to [`~ModelMixin.enable_parallelism`].
Suggested change:

Ring Anything Attention is not supported by Unified Attention. Set `ring_degree > 1` and `ring_anything=True` to enable Ring Anything Attention.
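For context, enabling the mode from user code would look roughly like this (a sketch assuming the `ring_anything` flag added by this PR; the model and checkpoint are illustrative):

```py
import torch
from diffusers import ContextParallelConfig, FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
# ring_degree > 1 selects the ring path; ring_anything=True (this PR) lifts the
# requirement that the KV sequence length divide evenly across the ring.
transformer.enable_parallelism(
    config=ContextParallelConfig(ring_degree=4, ring_anything=True)
)
```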
> [!TIP]
> To avoid multiple forced CUDA sync caused by H2D and D2H transfers, please add the **gloo** backend in [`init_process_group`](https://docs.pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group).
Suggested change:

> Add the `gloo` backend to [init_process_group](https://docs.pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) to avoid multiple forced CUDA syncs from H2D and D2H transfers.
> [!NOTE]
> Backward is not implemented yet; this mode is currently inference-only.
> `attn_mask` must be `None`; non-None attention masks are not supported.
This might be better after the first paragraph
Suggested change:

> [!NOTE]
> Ring Anything Attention only currently supports inference and non-`None` attention masks aren't supported. `attn_mask` must be `None`.
We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention, Ulysses Anything Attention, and Ring Anything Attention on a node of 4 RTX 4090 (48GB) GPUs. The results are summarized as follows:
Suggested change:

See the benchmarks below on a node of 4 RTX 4090 (48GB) GPUs.







What does this PR do?
Adds a new "Ring Anything" context-parallel attention mode that supports arbitrary
(non-evenly divisible) KV sequence lengths across ring-degree workers.
Motivation
Existing `TemplatedRingAttention` requires KV to be equipartitioned across ranks, which is impractical for real-world workloads where per-rank sequence lengths can differ (e.g., variable-length prompts, packed batches, token pruning). This PR mirrors the existing `ulysses_anything` design but applies it to the ring path.

Changes

- `ContextParallelConfig`: add `ring_anything` flag with validation (`ring_degree > 1` and `ulysses_degree == 1`).
- `TemplatedRingAnythingAttention`: new autograd Function that pads each rank's local KV to the global maximum length, all-gathers the padded buffer, slices back to true lengths, and merges the partial outputs with an online softmax.
- `_templated_context_parallel_attention`: dispatch to the new class when `ring_anything` is enabled.
- `ContextParallelSplitHook`: route through `PartitionAnythingSharder.shard_anything` when `ring_anything` is set.

Reproducible example

Launch
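The full reproducible example and launch command are not shown here; a minimal sketch of what such a script might look like, assuming the `ring_anything` flag on [`ContextParallelConfig`] described above (checkpoint, resolution, and prompt are illustrative):

```py
# demo_ring_anything.py -- illustrative sketch, not the PR's exact example.
# Launched with something like: torchrun --nproc_per_node=4 demo_ring_anything.py
import torch
import torch.distributed as dist
from diffusers import ContextParallelConfig, FluxPipeline

dist.init_process_group(backend="cpu:gloo,cuda:nccl")  # gloo added per the tip above
rank = dist.get_rank()
torch.cuda.set_device(rank)

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to(f"cuda:{rank}")
pipe.transformer.enable_parallelism(
    config=ContextParallelConfig(ring_degree=4, ring_anything=True)
)

# 1008x1008 gives 63 * 63 = 3969 image tokens, which is NOT evenly divisible by
# ring_degree=4 -- the case this PR adds support for.
image = pipe(
    "a photo of a cat", height=1008, width=1008, num_inference_steps=28,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
if rank == 0:
    image.save("ring_anything.png")
dist.destroy_process_group()
```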
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@yiyixuxu @asomoza @sayakpaul