Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions docs/source/en/training/distributed_inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,40 @@ We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention and Ulys

From the above table, it is clear that Ulysses Anything Attention offers better compatibility with arbitrary sequence lengths while maintaining the same performance as the standard Ulysses Attention.


### Ring Anything Attention

The default Ring Attention requires the sequence length of hidden states to be evenly divisible across the ring degree. Ring Anything Attention is a variant of Ring Attention that supports arbitrary (non-evenly divisible) sequence lengths. It pads each rank's local KV to the global maximum sequence length, all-gathers the padded KV buffer, and slices back to each rank's true length before running attention.

[`ContextParallelConfig`] supports Ring Anything Attention by specifying both `ring_degree` and `ring_anything`. Please note that Ring Anything Attention is not currently supported by Unified Attention. Pass the [`ContextParallelConfig`] with `ring_degree` set to a value greater than 1 and `ring_anything=True` to [`~ModelMixin.enable_parallelism`].

```py
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2, ring_anything=True))
```

> [!TIP]
> To avoid multiple forced CUDA syncs caused by H2D and D2H transfers, add the **gloo** backend in `init_process_group`. This significantly reduces communication latency.

> [!NOTE]
> Backward is not implemented yet; this mode is currently inference-only.
> `attn_mask` must be `None`; non-None attention masks are not supported.

We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention, Ulysses Anything Attention, and Ring Anything Attention on a node of 4 RTX 4090 (48GB) GPUs. The results are summarized as follows:

| CP Backend | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) | Shape (HxW)|
|--------------------|------------------|-------------|------------------|------------|
| ulysses | 259.07 | 3.86 | 33.83 | 1024x1024 |
| ring | 338.98 | 2.95 | 33.83 | 1024x1024 |
| unified_balanced | 321.54 | 3.11 | 33.83 | 1024x1024 |
| ulysses_anything | 259.07 | 3.86 | 33.83 | 1024x1024 |
| ring_anything | 340.14 | 2.94 | 33.83 | 1024x1024 |
| ulysses | failed | failed | failed | 1008x1008 |
| ring | failed | failed | failed | 1008x1008 |
| unified_balanced | failed | failed | failed | 1008x1008 |
| ulysses_anything | 253.16 | 3.95 | 33.75 | 1008x1008 |
| ring_anything | 335.57 | 2.98 | 33.75 | 1008x1008 |

From the above table, Ring Anything Attention offers compatibility with arbitrary sequence lengths while maintaining performance comparable to the standard Ring Attention.

### parallel_config

Pass `parallel_config` during model initialization to enable context parallelism.
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/hooks/context_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) ->
)
return x
else:
if self.parallel_config.ulysses_anything:
if self.parallel_config.ulysses_anything or self.parallel_config.ring_anything:
return PartitionAnythingSharder.shard_anything(
x, cp_input.split_dim, self.parallel_config._flattened_mesh
)
Expand Down Expand Up @@ -239,7 +239,7 @@ def post_forward(self, module, output):
for i, cpm in enumerate(self.metadata):
if cpm is None:
continue
if self.parallel_config.ulysses_anything:
if self.parallel_config.ulysses_anything or self.parallel_config.ring_anything:
output[i] = PartitionAnythingSharder.unshard_anything(
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
)
Expand Down
12 changes: 12 additions & 0 deletions src/diffusers/models/_modeling_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class ContextParallelConfig:
Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that
are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and
`ring_degree` must be 1.
ring_anything (`bool`, *optional*, defaults to `False`):
Whether to enable "Ring Anything" mode, which supports arbitrary sequence lengths. When enabled, `ring_degree`
must be greater than 1 and `ulysses_degree` must be 1.
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
creating a new one. This is useful when combining context parallelism with other parallelism strategies
Expand All @@ -82,6 +85,8 @@ class ContextParallelConfig:
# Whether to enable ulysses anything attention to support
# any sequence lengths and any head numbers.
ulysses_anything: bool = False
# Whether to enable ring anything attention to support any sequence lengths.
ring_anything: bool = False

_rank: int = None
_world_size: int = None
Expand Down Expand Up @@ -114,6 +119,13 @@ def __post_init__(self):
raise ValueError("ulysses_degree must be greater than 1 for ulysses_anything to be enabled.")
if self.ring_degree > 1:
raise ValueError("ulysses_anything cannot be enabled when ring_degree > 1.")
if self.ring_anything:
if self.ring_degree == 1:
raise ValueError("ring_degree must be greater than 1 for ring_anything to be enabled.")
if self.ulysses_degree > 1:
raise ValueError("ring_anything cannot be enabled when ulysses_degree > 1.")
if self.ulysses_anything and self.ring_anything:
raise ValueError("ulysses_anything and ring_anything cannot both be enabled.")

@property
def mesh_shape(self) -> tuple[int, int]:
Expand Down
157 changes: 143 additions & 14 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2076,6 +2076,119 @@ def backward(
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None


class TemplatedRingAnythingAttention(torch.autograd.Function):
    """Ring attention that supports arbitrary (non-evenly divisible) KV sequence lengths.

    Unlike standard ring attention, each rank may hold a different local KV length. Each
    rank pads its local K/V to the global maximum sequence length, all-gathers the padded
    buffers once up front, and then walks the ring by slicing each peer's chunk back to
    its true length before attending. Partial outputs are merged with a numerically
    stable online log-sum-exp update.

    Inference-only: ``backward`` raises ``NotImplementedError``.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: torch.Tensor | None,
        dropout_p: float,
        is_causal: bool,
        scale: float | None,
        enable_gqa: bool,
        return_lse: bool,
        forward_op,
        backward_op,
        _parallel_config: "ParallelConfig" | None = None,
    ):
        """Run ring attention over ranks with potentially unequal local KV lengths.

        Args:
            ctx: Autograd context; also passed through to ``forward_op`` so it can save
                tensors on the first (local) iteration.
            query: Local query shard. Assumed layout has the sequence on dim 1
                (padding/slicing below operate on ``shape[1]``) — TODO confirm against
                the dispatching caller.
            key / value: Local KV shards; their ``shape[1]`` may differ across ranks.
            attn_mask: Must be ``None``; see the error below.
            dropout_p / is_causal / scale / enable_gqa: Forwarded to ``forward_op``.
            return_lse: If True, also return the merged log-sum-exp.
            forward_op: Underlying attention op; called with ``return_lse=True`` so the
                partial outputs can be combined.
            backward_op: Stored on ``ctx`` only; backward is not implemented.
            _parallel_config: Supplies the ring mesh, local rank, and ring degree.

        Returns:
            ``out`` (cast back to ``query.dtype``), or ``(out, lse)`` if ``return_lse``.

        Raises:
            ValueError: If ``attn_mask`` is not ``None``.
        """
        # Ring attention for arbitrary sequence lengths.
        if attn_mask is not None:
            raise ValueError(
                "TemplatedRingAnythingAttention does not support non-None attn_mask: "
                "non-uniform sequence lengths across ranks make cross-rank mask slicing ambiguous."
            )
        ring_mesh = _parallel_config.context_parallel_config._ring_mesh
        group = ring_mesh.get_group()
        rank = _parallel_config.context_parallel_config._ring_local_rank
        world_size = _parallel_config.context_parallel_config.ring_degree
        # Ring traversal order: self first (i == 0), then neighbors in increasing
        # rank order, wrapping around the ring.
        next_rank = (rank + 1) % world_size
        # Running merged output / log-sum-exp for the online softmax combination.
        prev_out = prev_lse = None

        # Saved for a future backward implementation; unused today (backward raises).
        ctx.forward_op = forward_op
        ctx.backward_op = backward_op
        ctx.q_shape = query.shape
        ctx.kv_shape = key.shape
        ctx._parallel_config = _parallel_config

        kv_seq_len = key.shape[1]  # local S_KV (may differ across ranks)
        # Exchange every rank's true KV length so peers' chunks can be un-padded later.
        all_kv_seq_lens = gather_size_by_comm(kv_seq_len, group)
        s_max = max(all_kv_seq_lens)

        # Padding is applied on the sequence dimension (dim=1) at the end.
        def pad_to_s_max(t: torch.Tensor) -> torch.Tensor:
            # Zero-pad `t` along dim 1 up to the global max length so every rank
            # contributes an equally sized buffer to the all-gather below.
            pad_len = s_max - t.shape[1]
            if pad_len == 0:
                return t
            pad_shape = list(t.shape)
            pad_shape[1] = pad_len
            return torch.cat([t, torch.zeros(pad_shape, dtype=t.dtype, device=t.device)], dim=1)

        key_padded = pad_to_s_max(key)
        value_padded = pad_to_s_max(value)

        # Pack padded K and V into one flat buffer so a single all-gather suffices;
        # `chunk(world_size)` then yields one per-rank [K|V] segment, indexable by rank.
        kv_buffer = torch.cat([key_padded.flatten(), value_padded.flatten()]).contiguous()
        kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=group)
        kv_buffer = kv_buffer.chunk(world_size)

        # numel per-rank in the padded layout
        # (identical on all ranks since everyone padded to s_max; splits K from V below)
        kv_padded_numel = key_padded.numel()

        for i in range(world_size):
            if i > 0:
                true_seq_len = all_kv_seq_lens[next_rank]
                kv = kv_buffer[next_rank]
                # Reshape to padded shape, then slice to true sequence length
                key = kv[:kv_padded_numel].reshape_as(key_padded)[:, :true_seq_len]
                value = kv[kv_padded_numel:].reshape_as(value_padded)[:, :true_seq_len]
                next_rank = (next_rank + 1) % world_size
            else:
                # i == 0: use local (unpadded) key/value
                key = key_padded[:, :kv_seq_len]
                value = value_padded[:, :kv_seq_len]

            out, lse = forward_op(
                ctx,
                query,
                key,
                value,
                attn_mask,
                dropout_p,
                is_causal,
                scale,
                enable_gqa,
                True,  # return_lse: partial LSEs are required for the merge below
                _save_ctx=i == 0,  # only the local step saves tensors on ctx
                _parallel_config=_parallel_config,
            )

            # Optionally accumulate in fp32 for numerical stability of the merge.
            if _parallel_config.context_parallel_config.convert_to_fp32:
                out = out.to(torch.float32)
                lse = lse.to(torch.float32)

            # Older torch returns LSE without the trailing dim; add it so it
            # broadcasts against `out` in the merge below.
            if is_torch_version("<", "2.9.0"):
                lse = lse.unsqueeze(-1)
            if prev_out is not None:
                # Online softmax merge of two attention partials:
                #   out = prev_out - sigmoid(lse - prev_lse) * (prev_out - out)
                #   lse = prev_lse - logsigmoid(prev_lse - lse)
                # i.e. a convex combination weighted by exp(lse)/(exp(lse)+exp(prev_lse)),
                # with the new LSE being log(exp(prev_lse) + exp(lse)).
                out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
                lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
            prev_out = out
            prev_lse = lse

        # Cast back from the (possibly fp32) accumulation dtype and drop the
        # broadcast dim added for the merge.
        out = out.to(query.dtype)
        lse = lse.squeeze(-1)

        return (out, lse) if return_lse else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args,
    ):
        """Not implemented: Ring Anything Attention is currently inference-only."""
        raise NotImplementedError("Backward pass for Ring Anything Attention in diffusers is not implemented yet.")


class TemplatedUlyssesAnythingAttention(torch.autograd.Function):
@staticmethod
def forward(
Expand Down Expand Up @@ -2254,20 +2367,36 @@ def _templated_context_parallel_attention(
_parallel_config,
)
elif _parallel_config.context_parallel_config.ring_degree > 1:
return TemplatedRingAttention.apply(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
if _parallel_config.context_parallel_config.ring_anything:
return TemplatedRingAnythingAttention.apply(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
else:
return TemplatedRingAttention.apply(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
elif _parallel_config.context_parallel_config.ulysses_degree > 1:
if _parallel_config.context_parallel_config.ulysses_anything:
# For Any sequence lengths and Any head num support
Expand Down
Loading