Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 47 additions & 13 deletions src/diffusers/hooks/context_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,26 @@ def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None)
return args[index], False, index


def _get_cp_pad_state(parallel_config: ContextParallelConfig) -> dict[int, int]:
if not hasattr(parallel_config, "_cp_pad_state"):
parallel_config._cp_pad_state = {}
return parallel_config._cp_pad_state


def _pad_tensor_for_context_parallel(
tensor: torch.Tensor, dim: int, world_size: int, pad_value: float | int
) -> torch.Tensor:
seq_len = tensor.size(dim)
pad_len = (world_size - seq_len % world_size) % world_size
if pad_len == 0:
return tensor

pad_width = [0] * (2 * tensor.dim())
pad_idx = tensor.dim() - 1 - dim
pad_width[2 * pad_idx + 1] = pad_len
return torch.nn.functional.pad(tensor, tuple(pad_width), mode="constant", value=pad_value)


def apply_context_parallel(
module: torch.nn.Module,
parallel_config: ContextParallelConfig,
Expand Down Expand Up @@ -156,7 +176,7 @@ def pre_forward(self, module, *args, **kwargs):
# The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
# the output instead of input for a particular layer by setting split_output=True
if isinstance(input_val, torch.Tensor):
input_val = self._prepare_cp_input(input_val, cpm)
input_val = self._prepare_cp_input(input_val, cpm, name)
elif isinstance(input_val, (list, tuple)):
if len(input_val) != len(cpm):
raise ValueError(
Expand Down Expand Up @@ -198,23 +218,32 @@ def post_forward(self, module, output):
if index >= len(output):
raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
current_output = output[index]
current_output = self._prepare_cp_input(current_output, cpm)
current_output = self._prepare_cp_input(current_output, cpm, str(index))
output[index] = current_output

return output[0] if is_tensor else tuple(output)

def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput, name: str = "") -> torch.Tensor:
if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
logger.warning_once(
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions, split will not be applied."
)
return x
else:
if self.parallel_config.ulysses_anything or self.parallel_config.ring_anything:
return PartitionAnythingSharder.shard_anything(
x, cp_input.split_dim, self.parallel_config._flattened_mesh
)
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)

mesh = self.parallel_config._flattened_mesh
if self.parallel_config.ulysses_anything or self.parallel_config.ring_anything:
return PartitionAnythingSharder.shard_anything(x, cp_input.split_dim, mesh)

dim = cp_input.split_dim
world_size = mesh.size()
seq_len = x.size(dim)
if world_size > 1 and seq_len % world_size != 0:
pad_value = 0 if "mask" in name.lower() else 0.0
x = _pad_tensor_for_context_parallel(x, dim, world_size, pad_value)
pad_state = _get_cp_pad_state(self.parallel_config)
pad_state.setdefault(dim, seq_len)

return EquipartitionSharder.shard(x, dim, mesh)


class ContextParallelGatherHook(ModelHook):
Expand All @@ -240,13 +269,18 @@ def post_forward(self, module, output):
if cpm is None:
continue
if self.parallel_config.ulysses_anything or self.parallel_config.ring_anything:
output[i] = PartitionAnythingSharder.unshard_anything(
x = PartitionAnythingSharder.unshard_anything(
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
)
else:
output[i] = EquipartitionSharder.unshard(
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
)
x = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)

pad_state = _get_cp_pad_state(self.parallel_config)
original_s = pad_state.pop(cpm.gather_dim, None)
if original_s is not None:
x = x.narrow(cpm.gather_dim, 0, original_s)

output[i] = x

return output[0] if is_tensor else tuple(output)

Expand Down
82 changes: 82 additions & 0 deletions tests/hooks/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
# limitations under the License.

import gc
import unittest
from unittest.mock import patch

import pytest
import torch

from diffusers.hooks import HookRegistry, ModelHook
from diffusers.hooks.context_parallel import ContextParallelGatherHook, ContextParallelSplitHook
from diffusers.training_utils import free_memory
from diffusers.utils.logging import get_logger

Expand Down Expand Up @@ -372,3 +375,82 @@ def test_invocation_order_stateful_last(self):
.replace("\n", "")
)
assert output == expected_invocation_order_log


class _DummyMesh:
def __init__(self, size: int):
self._size = size

def size(self):
return self._size


class _DummyParallelConfig:
def __init__(self, mesh_size: int):
self._flattened_mesh = _DummyMesh(mesh_size)
self.ulysses_anything = False
self.ring_anything = False


class _DummyCPInput:
def __init__(self, split_dim: int, expected_dims: int | None = None, split_output: bool = False):
self.split_dim = split_dim
self.expected_dims = expected_dims
self.split_output = split_output


class _DummyCPOutput:
def __init__(self, gather_dim: int, expected_dims: int | None = None):
self.gather_dim = gather_dim
self.expected_dims = expected_dims


class ContextParallelHooksTests(unittest.TestCase):
def setUp(self):
self.parallel_config = _DummyParallelConfig(mesh_size=3)
self.hook = ContextParallelSplitHook(metadata={}, parallel_config=self.parallel_config)
self.module = DummyModel(in_features=1, hidden_features=1, out_features=1, num_layers=1)
self.hook.initialize_hook(self.module)

def test_prepare_cp_input_pads_hidden_states_and_stores_original(self):
x = torch.randn(1, 7, 16)
cp_input = _DummyCPInput(split_dim=1, expected_dims=3, split_output=False)

with patch.object(self.EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t):
out = self.hook._prepare_cp_input(x, cp_input, name="hidden_states")

self.assertEqual(out.shape[1], 9)
self.assertEqual(self.parallel_config._cp_pad_state[1], 7)

def test_prepare_cp_input_pads_mask_with_zeros(self):
mask = torch.ones(1, 7, dtype=torch.long)
cp_input = _DummyCPInput(split_dim=1, expected_dims=2, split_output=False)

with patch.object(self.EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t):
out_mask = self.hook._prepare_cp_input(mask, cp_input, name="encoder_hidden_states_mask")

self.assertEqual(out_mask.shape[1], 9)
self.assertTrue(torch.equal(out_mask[:, -2:], torch.zeros(1, 2, dtype=torch.long)))

def test_prepare_cp_input_no_pad_when_divisible(self):
x = torch.randn(1, 6, 16)
cp_input = _DummyCPInput(split_dim=1, expected_dims=3, split_output=False)

with patch.object(self.EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t):
out = self.hook._prepare_cp_input(x, cp_input, name="hidden_states")

self.assertEqual(out.shape[1], 6)
self.assertNotIn(1, getattr(self.parallel_config, "_cp_pad_state", {}))

def test_gather_hook_trims_padded_output(self):
gather_hook = ContextParallelGatherHook(
metadata=[_DummyCPOutput(gather_dim=1, expected_dims=3)], parallel_config=self.parallel_config
)
self.parallel_config._cp_pad_state = {1: 7}

padded_output = torch.randn(1, 9, 16)
with patch.object(self.EquipartitionSharder, "unshard", side_effect=lambda t, dim, mesh: t):
result = gather_hook.post_forward(self.module, padded_output)

self.assertEqual(result.shape[1], 7)
self.assertNotIn(1, self.parallel_config._cp_pad_state)
Loading