Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/diffusers/models/transformers/transformer_ltx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,11 @@ def get_mod_params(
scale_shift_table: torch.Tensor, temb: torch.Tensor, batch_size: int
) -> tuple[torch.Tensor, ...]:
num_ada_params = scale_shift_table.shape[0]
ada_values = scale_shift_table[None, None].to(temb.device) + temb.reshape(
# Cast to temb's dtype at the use site (matching the original implementation):
# checkpoints store the scale_shift tables in fp32 alongside bf16 weights, so
# without the cast the fp32 tables promote the modulated hidden states and the
# following linear layers fail on mixed dtypes.
ada_values = scale_shift_table[None, None].to(device=temb.device, dtype=temb.dtype) + temb.reshape(
batch_size, temb.shape[1], num_ada_params, -1
)
ada_params = ada_values.unbind(dim=2)
Expand Down Expand Up @@ -1620,14 +1624,19 @@ def forward(
)

# 6. Output layers (including unpatchification)
scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
scale_shift_values = (
self.scale_shift_table[None, None].to(embedded_timestep.dtype) + embedded_timestep[:, :, None]
)
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]

hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states * (1 + scale) + shift
output = self.proj_out(hidden_states)

audio_scale_shift_values = self.audio_scale_shift_table[None, None] + audio_embedded_timestep[:, :, None]
audio_scale_shift_values = (
self.audio_scale_shift_table[None, None].to(audio_embedded_timestep.dtype)
+ audio_embedded_timestep[:, :, None]
)
audio_shift, audio_scale = audio_scale_shift_values[:, :, 0], audio_scale_shift_values[:, :, 1]

audio_hidden_states = self.audio_norm_out(audio_hidden_states)
Expand Down
18 changes: 12 additions & 6 deletions src/diffusers/pipelines/ltx2/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,13 +302,19 @@ def forward(
if binary_attn_mask.ndim == 4:
binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L]

# Replace padding positions with learned registers using vectorized masking
mask = binary_attn_mask.unsqueeze(-1) # [B, L, 1]
# Move the valid tokens to the front in their original order and fill the tail
# with registers indexed by absolute position, matching the original LTX
# implementation (`_replace_padded_with_learnable_registers`). A stable argsort
# of the inverted mask gathers valid tokens first while preserving their order.
order = torch.argsort(1 - binary_attn_mask, dim=1, stable=True) # [B, L]
front_aligned = torch.gather(
hidden_states, 1, order.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1])
)
num_valid = binary_attn_mask.sum(dim=1, keepdim=True) # [B, 1]
positions = torch.arange(seq_len, device=hidden_states.device).unsqueeze(0) # [1, L]
front_mask = (positions < num_valid).unsqueeze(-1) # [B, L, 1]
registers_expanded = registers.unsqueeze(0).expand(batch_size, -1, -1) # [B, L, D]
hidden_states = mask * hidden_states + (1 - mask) * registers_expanded

# Flip sequence: embeddings move to front, registers to back (from left padding layout)
hidden_states = torch.flip(hidden_states, dims=[1])
hidden_states = torch.where(front_mask, front_aligned, registers_expanded.to(hidden_states.dtype))

# Overwrite attention_mask with an all-zeros mask if using registers.
attention_mask = torch.zeros_like(attention_mask)
Expand Down
25 changes: 25 additions & 0 deletions tests/models/transformers/test_models_transformer_ltx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,31 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
class TestLTX2Transformer(LTX2TransformerTesterConfig, ModelTesterMixin):
    """Core model tests for LTX2 Video Transformer."""

    def test_fp32_scale_shift_tables_match_uniform_dtype(self):
        """Outputs must not change when the scale_shift tables are stored in fp32.

        Published LTX-2 checkpoints keep the AdaLN ``scale_shift_table``
        parameters in fp32 next to bf16 weights. The model casts those tables
        to the activation dtype where they are used (as the original
        implementation does), so a mixed-dtype model has to run and has to
        reproduce the outputs of a model whose tables are bf16 throughout.
        """
        torch.manual_seed(0)
        model = self.model_class(**self.get_init_dict())
        model = model.to(torch.bfloat16).to(torch_device).eval()

        # Only floating-point tensors are cast to bf16; every other input is
        # forwarded exactly as the dummy-input factory produced it.
        inputs = {}
        for key, value in self.get_dummy_inputs().items():
            if isinstance(value, torch.Tensor) and value.is_floating_point():
                value = value.to(torch.bfloat16)
            inputs[key] = value

        # Baseline: every parameter, tables included, is bf16.
        with torch.no_grad():
            reference = model(**inputs)

        # Upcast just the scale_shift tables to fp32 to mimic a natively
        # loaded checkpoint, leaving all other parameters in bf16.
        table_params = [
            param for name, param in model.named_parameters() if "scale_shift_table" in name
        ]
        for param in table_params:
            param.data = param.data.float()

        with torch.no_grad():
            mixed = model(**inputs)

        # Both outputs (indices 0 and 1) must match exactly, not just
        # approximately.
        for index in (0, 1):
            assert torch.equal(reference[index], mixed[index])


class TestLTX2TransformerMemory(LTX2TransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for LTX2 Video Transformer."""
Expand Down
Loading