diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md index c9a341df40c5..2aade154327f 100644 --- a/docs/source/en/training/distributed_inference.md +++ b/docs/source/en/training/distributed_inference.md @@ -371,6 +371,41 @@ We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention and Ulys From the above table, it is clear that Ulysses Anything Attention offers better compatibility with arbitrary sequence lengths while maintaining the same performance as the standard Ulysses Attention. + +### Ring Anything Attention + +The default Ring Attention requires the sequence length of hidden states to be evenly divisible across the ring degree. Ring Anything Attention is a variant of Ring Attention that supports arbitrary (non-evenly divisible) sequence lengths. It pads each rank's local KV to the global maximum sequence length, all-gathers the padded KV buffer, and slices back to each rank's true length before running attention. + +[`ContextParallelConfig`] supports Ring Anything Attention by specifying both `ring_degree` and `ring_anything`. Please note that Ring Anything Attention is not currently supported by Unified Attention. Pass the [`ContextParallelConfig`] with `ring_degree` set to greater than 1 and `ring_anything=True` to [`~ModelMixin.enable_parallelism`]. + +```py +pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2, ring_anything=True)) +``` + +> [!TIP] +> To avoid multiple forced CUDA syncs caused by H2D and D2H transfers, please add the **gloo** backend in `init_process_group`. This will significantly reduce communication latency. + +> [!NOTE] +> Backward is not implemented yet; this mode is currently inference-only. +> `attn_mask` must be `None`; non-None attention masks are not supported. + +We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention, Ulysses Anything Attention, and Ring Anything Attention on a node of 4 RTX 4090 (48GB) GPUs. 
The results are summarized as follows: + +| CP Backend | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) | Shape (HxW)| +|--------------------|------------------|-------------|------------------|------------| +| ulysses | 259.07 | 3.86 | 33.83 | 1024x1024 | +| ring | 338.98 | 2.95 | 33.83 | 1024x1024 | +| unified_balanced | 321.54 | 3.11 | 33.83 | 1024x1024 | +| ulysses_anything | 259.07 | 3.86 | 33.83 | 1024x1024 | +| ring_anything | 340.14 | 2.94 | 33.83 | 1024x1024 | +| ulysses | failed | failed | failed | 1008x1008 | +| ring | failed | failed | failed | 1008x1008 | +| unified_balanced | failed | failed | failed | 1008x1008 | +| ulysses_anything | 253.16 | 3.95 | 33.75 | 1008x1008 | +| ring_anything | 335.57 | 2.98 | 33.75 | 1008x1008 | + +From the above table, Ring Anything Attention offers compatibility with arbitrary sequence lengths while maintaining performance comparable to the standard Ring Attention. + ### parallel_config Pass `parallel_config` during model initialization to enable context parallelism. 
diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py index f6ab623a1865..cfc812509a01 100644 --- a/src/diffusers/hooks/context_parallel.py +++ b/src/diffusers/hooks/context_parallel.py @@ -210,7 +210,7 @@ def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> ) return x else: - if self.parallel_config.ulysses_anything: + if self.parallel_config.ulysses_anything or self.parallel_config.ring_anything: return PartitionAnythingSharder.shard_anything( x, cp_input.split_dim, self.parallel_config._flattened_mesh ) @@ -239,7 +239,7 @@ def post_forward(self, module, output): for i, cpm in enumerate(self.metadata): if cpm is None: continue - if self.parallel_config.ulysses_anything: + if self.parallel_config.ulysses_anything or self.parallel_config.ring_anything: output[i] = PartitionAnythingSharder.unshard_anything( output[i], cpm.gather_dim, self.parallel_config._flattened_mesh ) diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py index 8573c01ca4c7..3afb8a04c4e3 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -64,6 +64,9 @@ class ContextParallelConfig: Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and `ring_degree` must be 1. + ring_anything (`bool`, *optional*, defaults to `False`): + Whether to enable "Ring Anything" mode, which supports arbitrary sequence lengths. When enabled, `ring_degree` + must be greater than 1 and `ulysses_degree` must be 1. mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*): A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of creating a new one. 
This is useful when combining context parallelism with other parallelism strategies @@ -82,6 +85,8 @@ class ContextParallelConfig: # Whether to enable ulysses anything attention to support # any sequence lengths and any head numbers. ulysses_anything: bool = False + # Whether to enable ring anything attention to support any sequence lengths. + ring_anything: bool = False _rank: int = None _world_size: int = None @@ -114,6 +119,13 @@ def __post_init__(self): raise ValueError("ulysses_degree must be greater than 1 for ulysses_anything to be enabled.") if self.ring_degree > 1: raise ValueError("ulysses_anything cannot be enabled when ring_degree > 1.") + if self.ring_anything: + if self.ring_degree == 1: + raise ValueError("ring_degree must be greater than 1 for ring_anything to be enabled.") + if self.ulysses_degree > 1: + raise ValueError("ring_anything cannot be enabled when ulysses_degree > 1.") + if self.ulysses_anything and self.ring_anything: + raise ValueError("ulysses_anything and ring_anything cannot both be enabled.") @property def mesh_shape(self) -> tuple[int, int]: diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index b3bd55db48dd..e656ea6c711f 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -2076,6 +2076,119 @@ def backward( return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None + +class TemplatedRingAnythingAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: torch.Tensor | None, + dropout_p: float, + is_causal: bool, + scale: float | None, + enable_gqa: bool, + return_lse: bool, + forward_op, + backward_op, + _parallel_config: Optional["ParallelConfig"] = None, + ): + # Ring attention for arbitrary sequence lengths. 
+ if attn_mask is not None: + raise ValueError( + "TemplatedRingAnythingAttention does not support non-None attn_mask: " + "non-uniform sequence lengths across ranks make cross-rank mask slicing ambiguous." + ) + ring_mesh = _parallel_config.context_parallel_config._ring_mesh + group = ring_mesh.get_group() + rank = _parallel_config.context_parallel_config._ring_local_rank + world_size = _parallel_config.context_parallel_config.ring_degree + next_rank = (rank + 1) % world_size + prev_out = prev_lse = None + + ctx.forward_op = forward_op + ctx.backward_op = backward_op + ctx.q_shape = query.shape + ctx.kv_shape = key.shape + ctx._parallel_config = _parallel_config + + kv_seq_len = key.shape[1] # local S_KV (may differ across ranks) + all_kv_seq_lens = gather_size_by_comm(kv_seq_len, group) + s_max = max(all_kv_seq_lens) + + # Padding is applied on the sequence dimension (dim=1) at the end. + def pad_to_s_max(t: torch.Tensor) -> torch.Tensor: + pad_len = s_max - t.shape[1] + if pad_len == 0: + return t + pad_shape = list(t.shape) + pad_shape[1] = pad_len + return torch.cat([t, torch.zeros(pad_shape, dtype=t.dtype, device=t.device)], dim=1) + + key_padded = pad_to_s_max(key) + value_padded = pad_to_s_max(value) + + kv_buffer = torch.cat([key_padded.flatten(), value_padded.flatten()]).contiguous() + kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=group) + kv_buffer = kv_buffer.chunk(world_size) + + # numel per-rank in the padded layout + kv_padded_numel = key_padded.numel() + + for i in range(world_size): + if i > 0: + true_seq_len = all_kv_seq_lens[next_rank] + kv = kv_buffer[next_rank] + # Reshape to padded shape, then slice to true sequence length + key = kv[:kv_padded_numel].reshape_as(key_padded)[:, :true_seq_len] + value = kv[kv_padded_numel:].reshape_as(value_padded)[:, :true_seq_len] + next_rank = (next_rank + 1) % world_size + else: + # i == 0: use local (unpadded) key/value + key = key_padded[:, :kv_seq_len] + value = value_padded[:, 
:kv_seq_len] + + out, lse = forward_op( + ctx, + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + True, + _save_ctx=i == 0, + _parallel_config=_parallel_config, + ) + + if _parallel_config.context_parallel_config.convert_to_fp32: + out = out.to(torch.float32) + lse = lse.to(torch.float32) + + if is_torch_version("<", "2.9.0"): + lse = lse.unsqueeze(-1) + if prev_out is not None: + out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out) + lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse) + prev_out = out + prev_lse = lse + + out = out.to(query.dtype) + lse = lse.squeeze(-1) + + return (out, lse) if return_lse else out + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + ): + raise NotImplementedError("Backward pass for Ring Anything Attention in diffusers is not implemented yet.") + + class TemplatedUlyssesAnythingAttention(torch.autograd.Function): @staticmethod def forward( @@ -2254,20 +2367,36 @@ def _templated_context_parallel_attention( _parallel_config, ) elif _parallel_config.context_parallel_config.ring_degree > 1: - return TemplatedRingAttention.apply( - query, - key, - value, - attn_mask, - dropout_p, - is_causal, - scale, - enable_gqa, - return_lse, - forward_op, - backward_op, - _parallel_config, - ) + if _parallel_config.context_parallel_config.ring_anything: + return TemplatedRingAnythingAttention.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) + else: + return TemplatedRingAttention.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) elif _parallel_config.context_parallel_config.ulysses_degree > 1: if _parallel_config.context_parallel_config.ulysses_anything: # For Any sequence lengths and Any head 
num support