Cast LTX2 scale_shift tables to the activation dtype at use site#13918
Open
Boffee wants to merge 2 commits into
Open
Cast LTX2 scale_shift tables to the activation dtype at use site#13918Boffee wants to merge 2 commits into
Boffee wants to merge 2 commits into
Conversation
Published LTX-2 checkpoints store the AdaLN scale_shift tables in fp32 alongside bf16 weights. The original implementation casts the tables to the activation dtype at every use site, but the diffusers port casts device only, so loading a checkpoint with its native dtypes promotes the modulated hidden states to fp32 and the following linear layers raise "mat1 and mat2 must have the same dtype". Restore the use-site dtype cast; outputs are bit-identical to a model whose tables were flattened to the weight dtype at load time. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…tation The connector replaced left-padding positions with the tiled registers and then flipped the whole sequence, which put the prompt tokens at the front in reversed order and the register tile reversed within each block. The original LTX implementation (ltx-core _replace_padded_with_learnable_registers, also matched by ComfyUI) front-aligns the valid tokens in their original order and fills the tail with registers indexed by absolute position. Since the connector blocks apply RoPE, the reversed layout produces off-distribution embeddings; short prompts (e.g. negative prompts, whose context is mostly registers) are hit hardest, which manifests as overblown CFG: at cfg > 1 (or CFG++ samplers at cfg 1) the unconditional branch is computed from a mostly-register context with scrambled positions. Replace the fill+flip with a stable-argsort gather (valid tokens to the front, order preserved, per batch row) and fill the tail with the absolute-position register tile. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Fixes a mixed-dtype crash in
LTX2VideoTransformer3DModelwhen a checkpoint is loaded with its native per-tensor dtypes.The published LTX-2 checkpoints (e.g.
Lightricks/LTX-2.3,diffusers/LTX-2.3-Diffusers, and the fp8 variants) store the AdaLNscale_shift_tableparameters in fp32 alongside bf16 weights — indiffusers/LTX-2.3-Diffusers/transformerthat's 290 fp32 tensors next to 3,896 bf16 ones. The reference implementation casts the tables to the activation dtype at every use site:transformer.pyget_ada_values:scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype)model.pyoutput layer:scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)The diffusers port casts device only (
scale_shift_table[None, None].to(temb.device)), so a natively-loaded mixed-dtype model promotes the modulated hidden states to fp32, and the next linear layer raises:Repro (tiny config from the test suite):
This PR restores the use-site dtype cast in the three places the tables are consumed (
get_mod_params, which all block-level tables flow through, and the two output-layer sites). Behavior for existing users is unchanged: with uniform-dtype models the cast is a no-op, and a mixed-dtype model's outputs are bit-identical to one whose tables were flattened to the weight dtype at load time (cast-then-add is the same computation either way). A regression test asserting exactly that equality is included.Related precedent: the Wan family handles the same situation by declaring
_keep_in_fp32_modules = [..., "scale_shift_table", ...]with compensating casts in the forward. For LTX2 the reference implementation computes the modulation in the activation dtype, so the minimal use-site cast matches the original numerics without opting into fp32 modulation.Before submitting
test_fp32_scale_shift_tables_match_uniform_dtype)make style/ ruff clean on the changed files;tests/models/transformers/test_models_transformer_ltx2.py::TestLTX2Transformerpasses (15 passed, 2 skipped).Who can review?
@a-r-r-o-w @yiyixuxu
🤖 Generated with Claude Code