Skip to content

Cast LTX2 scale_shift tables to the activation dtype at use site#13918

Open
Boffee wants to merge 2 commits into
huggingface:mainfrom
Boffee:ltx2-scale-shift-table-dtype
Open

Cast LTX2 scale_shift tables to the activation dtype at use site#13918
Boffee wants to merge 2 commits into
huggingface:mainfrom
Boffee:ltx2-scale-shift-table-dtype

Conversation

@Boffee

@Boffee Boffee commented Jun 11, 2026

Copy link
Copy Markdown

What does this PR do?

Fixes a mixed-dtype crash in LTX2VideoTransformer3DModel when a checkpoint is loaded with its native per-tensor dtypes.

The published LTX-2 checkpoints (e.g. Lightricks/LTX-2.3, diffusers/LTX-2.3-Diffusers, and the fp8 variants) store the AdaLN scale_shift_table parameters in fp32 alongside bf16 weights — in diffusers/LTX-2.3-Diffusers/transformer that's 290 fp32 tensors next to 3,896 bf16 ones. The reference implementation casts the tables to the activation dtype at every use site:

The diffusers port casts device only (scale_shift_table[None, None].to(temb.device)), so a natively-loaded mixed-dtype model promotes the modulated hidden states to fp32, and the next linear layer raises:

RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16

Repro (tiny config from the test suite):

model = LTX2VideoTransformer3DModel(**tiny_init).to(torch.bfloat16).eval()
for name, p in model.named_parameters():
    if "scale_shift_table" in name:
        p.data = p.data.float()  # as stored in the published checkpoints
model(**inputs)  # RuntimeError

This PR restores the use-site dtype cast in the three places the tables are consumed (get_mod_params, which all block-level tables flow through, and the two output-layer sites). Behavior for existing users is unchanged: with uniform-dtype models the cast is a no-op, and a mixed-dtype model's outputs are bit-identical to one whose tables were flattened to the weight dtype at load time (cast-then-add is the same computation either way). A regression test asserting exactly that equality is included.

Related precedent: the Wan family handles the same situation by declaring _keep_in_fp32_modules = [..., "scale_shift_table", ...] with compensating casts in the forward. For LTX2 the reference implementation computes the modulation in the activation dtype, so the minimal use-site cast matches the original numerics without opting into fp32 modulation.

Before submitting

  • Did you read the contributor guideline?
  • Did you write any new necessary tests? (test_fp32_scale_shift_tables_match_uniform_dtype)
  • make style / ruff clean on the changed files; tests/models/transformers/test_models_transformer_ltx2.py::TestLTX2Transformer passes (15 passed, 2 skipped).

Who can review?

@a-r-r-o-w @yiyixuxu

🤖 Generated with Claude Code

Published LTX-2 checkpoints store the AdaLN scale_shift tables in fp32
alongside bf16 weights. The original implementation casts the tables to
the activation dtype at every use site, but the diffusers port casts
device only, so loading a checkpoint with its native dtypes promotes
the modulated hidden states to fp32 and the following linear layers
raise "mat1 and mat2 must have the same dtype". Restore the use-site
dtype cast; outputs are bit-identical to a model whose tables were
flattened to the weight dtype at load time.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
@github-actions github-actions Bot added size/S PR with diff < 50 LOC models tests and removed size/S PR with diff < 50 LOC labels Jun 11, 2026
…tation

The connector replaced left-padding positions with the tiled registers and
then flipped the whole sequence, which put the prompt tokens at the front in
reversed order and the register tile reversed within each block. The original
LTX implementation (ltx-core _replace_padded_with_learnable_registers, also
matched by ComfyUI) front-aligns the valid tokens in their original order and
fills the tail with registers indexed by absolute position.

Since the connector blocks apply RoPE, the reversed layout produces
off-distribution embeddings; short prompts (e.g. negative prompts, whose
context is mostly registers) are hit hardest, which manifests as overblown
CFG: at cfg > 1 (or CFG++ samplers at cfg 1) the unconditional branch is
computed from a mostly-register context with scrambled positions.

Replace the fill+flip with a stable-argsort gather (valid tokens to the
front, order preserved, per batch row) and fill the tail with the
absolute-position register tile.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
@github-actions github-actions Bot added pipelines size/M PR with diff < 200 LOC labels Jun 12, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant