
LTX2 Performance tuning#385

Draft
prishajain1 wants to merge 20 commits into main from prisha/ltx2_opt

Conversation

Collaborator

@prishajain1 prishajain1 commented Apr 19, 2026

This PR adds features that improve the performance of the LTX-2 model, along with a fix for the LTX-2 Upsampler, which is currently broken on main.

Performance enhancement features:

  • Sharding fix in NNXSimpleFeedForward: Data is sharded along the sequence dimension, so each device holds a subset of tokens but the full feature channels. Because the input data had replicated features while the weights expected features sharded on the same physical axis (context), XLA was forced to insert an All-Gather on the sequence dimension to resolve the layout conflict, wasting significant time on communication. With our fix:

    • Overall % of wasted time in all-gathers went from 52.56% to 38.07%
    • Generation time per video dropped from 20s to 16.7s
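The layout conflict behind this all-gather can be sketched in a few lines of plain Python (the helper and axis names below are illustrative, not from the codebase):

```python
def needs_allgather(x_spec, w_spec):
    """Illustrative check: does y = x @ W force a gather before the matmul?

    x_spec: (seq_axis, feat_axis) mesh axes for the activations (None = replicated)
    w_spec: (in_feat_axis, out_feat_axis) mesh axes for the weight
    A conflict arises when the weight's contracting (in_features) dim is sharded
    on a mesh axis that the activations use for a different, non-contracting
    dimension: no device then holds the full feature slice it needs locally.
    """
    seq_axis, feat_axis = x_spec
    in_feat_axis, _ = w_spec
    return in_feat_axis is not None and in_feat_axis != feat_axis

# Before the fix: weight in_features sharded on "context", which the
# activations use for the sequence dim -> all-gather on the sequence dim.
print(needs_allgather(("context", None), ("context", "embed")))  # True
# After the fix: weight in_features replicated -> purely local matmul.
print(needs_allgather(("context", None), (None, "embed")))  # False
```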
  • QKV Projection sharding fix: Profiling showed that the QKV Projection step triggered an All-Gather of the input data along the sequence dimension. Because the weights were sharded on the dimension that is summed over (features), a single device could not complete the matrix multiplication using only its local shard of the data, so XLA automatically inserted an All-Gather to replicate the sequence dimension across all devices before performing the multiplication. We changed the weight sharding in attention_ltx2.py to remove sharding on the input feature dimension.

    • Overall % of wasted time in all-gathers went from 38.07% to 19.39%
    • Generation time per video dropped from 16.7s to 13.84s
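A single-device sketch of the fix's intent (the mesh, shapes, and specs below are illustrative; the real change lives in attention_ltx2.py): activations stay sharded on the sequence axis, while the kernel's contracting input-feature dimension is left unsharded so each device can multiply its local tokens without gathering.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-device mesh for illustration; on a real TPU slice the "context"
# axis would span many devices.
mesh = Mesh(np.array(jax.devices()[:1]), ("context",))

seq, embed, qkv_dim = 8, 4, 6
# Activations: tokens split across "context", feature channels replicated.
x = jax.device_put(jnp.ones((seq, embed)), NamedSharding(mesh, P("context", None)))
# After the fix: the contracting in_features dim is NOT sharded on "context",
# so the QKV matmul needs no sequence-dimension all-gather.
w = jax.device_put(jnp.ones((embed, qkv_dim)), NamedSharding(mesh, P(None, None)))

y = x @ w  # each device computes its token shard locally
print(y.shape)  # (8, 6)
```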
  • Batching in the text encoder: With Classifier-Free Guidance (CFG) enabled, the text encoder previously ran twice, once each for the positive and negative prompts. We now concatenate the positive and negative prompts along the batch dimension and run a single text encoder pass instead of two.

    • Text encoder time reduced from 3.54s to 3.06s
    • Generation time per video dropped from 13.84s to 13.38s
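The batching change amounts to concatenating before the forward pass and splitting after; `encode` below is a hypothetical stand-in for the real text-encoder call:

```python
import numpy as np

def encode(tokens):
    """Placeholder for the text-encoder forward pass (illustrative only)."""
    return tokens * 2.0

pos = np.ones((1, 16))   # positive-prompt tokens
neg = np.zeros((1, 16))  # negative-prompt tokens

# Before: two separate encoder passes
# pos_emb, neg_emb = encode(pos), encode(neg)

# After: concatenate along the batch dim, run one pass, then split the result
both = encode(np.concatenate([pos, neg], axis=0))
pos_emb, neg_emb = both[:1], both[1:]
print(both.shape)  # (2, 16)
```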
  • JITting the diffusion loop: The previous implementation used a Python for loop to iterate over diffusion timesteps. This created a "Python dispatch wall": the TPU sat idle between consecutive forward passes while waiting for the host CPU to dispatch the next step. We refactored the entire denoising loop to use nnx.scan.

    • The total diffusion time across 40 steps dropped from 7.84s to 7.28s
    • Generation time per video dropped from 13.38s to 12.5s
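The loop refactor has the shape below. The PR uses nnx.scan; jax.lax.scan, shown here, has the same carry/xs structure, and `denoise_step` is a hypothetical stand-in for one transformer forward plus scheduler update:

```python
import jax
import jax.numpy as jnp

def denoise_step(latents, t):
    """Illustrative update rule standing in for one denoising step."""
    new_latents = latents - 0.01 * t * latents
    return new_latents, None  # (carry, per-step output)

timesteps = jnp.linspace(1.0, 0.0, 40)
latents = jnp.ones((1, 8))

# Before: a Python for loop dispatched each of the 40 steps separately,
# leaving the accelerator idle between steps.
# After: a single compiled program runs all 40 steps back to back.
final_latents, _ = jax.lax.scan(denoise_step, latents, timesteps)
print(final_latents.shape)  # (1, 8)
```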

LTX2 Upsampler fix:

  • The current LTX2 Upsampler pipeline raises a ValueError: blur_down is the name of the submodule in the PyTorch state dict from the Hugging Face checkpoint, but the MaxDiffusion Flax implementation names that layer blur. Because our weight loader didn't rename it, nnx.update tried to update a non-existent blur_down attribute.
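A minimal sketch of the required remapping (helper name and keys below are illustrative): checkpoint keys containing blur_down are rewritten to blur before the state is handed to nnx.update.

```python
def rename_blur_keys(state_dict):
    """Illustrative key remap: PyTorch checkpoint name -> Flax module name."""
    return {k.replace("blur_down", "blur"): v for k, v in state_dict.items()}

ckpt = {"upsampler.blur_down.weight": 1.0, "upsampler.conv.weight": 2.0}
fixed = rename_blur_keys(ckpt)
print(sorted(fixed))  # ['upsampler.blur.weight', 'upsampler.conv.weight']
```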

Results

| Version | Execution Time | Status |
| --- | --- | --- |
| Current Main | 20.01s | Video Link |
| After Fix | 12.50s | Video Link |

We also tested the WAN I2V pipelines to ensure no regressions were introduced there; no quality regression or increased latency was observed.

@prishajain1 prishajain1 requested a review from entrpn as a code owner April 19, 2026 07:11

@prishajain1 prishajain1 marked this pull request as draft April 19, 2026 09:00
Comment thread src/maxdiffusion/configs/ltx2_video.yml
Comment on lines 365 to 368
```python
# Out kernel: [in_features (heads), out_features (embed)]
out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed"))
# Out bias: [out_features (embed)]
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))
```
Collaborator

@Perseus14 Perseus14 Apr 19, 2026


Since qkv_kernel is changed to (None, "heads"), should we do the same for the out_kernel? Maybe change it to ("heads", None) and out_bias to (None,)

Comment on lines +1652 to +1656
```python
final_carry, _ = nnx.scan(
    scan_body,
    in_axes=(nnx.Carry, 0, None),
    out_axes=(nnx.Carry, 0),
)(initial_carry, timesteps_jax, transformer)
```
Collaborator


We have a config param scan_layers (which regulates transformer blocks, not the iterative pipeline loop) as well as this nnx.scan diffusion loop. Could this change confuse what the scan_layers config controls for developers? Perhaps we can add a comment to the scan_layers config that this controls only transformer blocks

