…d, heads) to (None, heads)
…o prisha/ltx2_opt
Perseus14
reviewed
Apr 19, 2026
Comment on lines 365 to 368
# Out kernel: [in_features (heads), out_features (embed)]
out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed"))
# Out bias: [out_features (embed)]
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))
Collaborator
Since qkv_kernel was changed to (None, "heads"), should we do the same for out_kernel? Maybe change it to ("heads", None) and out_bias to (None,).
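A minimal sketch of what the suggested change would look like (axis names "heads"/"embed" come from the snippet above; the shapes below and the use of plain `jax.nn.initializers` instead of the NNX wrappers are illustrative assumptions):

```python
import jax
import jax.numpy as jnp

# Suggested partitioning: the out projection consumes the "heads" axis and
# produces replicated embed features, mirroring qkv's (None, "heads").
heads_dim, embed_dim = 8, 16                # illustrative sizes, not the model's

out_kernel_sharding = ("heads", None)       # [in_features (heads), out_features (embed)]
out_bias_sharding = (None,)                 # [out_features (embed)]

# A partition spec needs one entry per array dimension; sanity-check that here.
kernel = jax.nn.initializers.lecun_normal()(jax.random.PRNGKey(0), (heads_dim, embed_dim))
bias = jnp.zeros((embed_dim,))
assert kernel.ndim == len(out_kernel_sharding)
assert bias.ndim == len(out_bias_sharding)
```

In the actual code this would mean swapping the tuples passed to `nnx.with_partitioning` in the snippet above, keeping the initializers unchanged.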
Comment on lines +1652 to +1656
final_carry, _ = nnx.scan(
    scan_body,
    in_axes=(nnx.Carry, 0, None),
    out_axes=(nnx.Carry, 0),
)(initial_carry, timesteps_jax, transformer)
Collaborator
We have a config param scan_layers (which regulates transformer blocks, not the iterative pipeline loop) as well as this nnx.scan diffusion loop. Could this change confuse developers about what the scan_layers config controls? Perhaps we can add a comment to the scan_layers config noting that it controls only transformer blocks.
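The suggested clarification might look something like this where the config field is defined (the field's exact location, type, and default here are assumptions, not the actual repo layout):

```python
# Hypothetical config fragment; the real field lives wherever scan_layers is defined.
# scan_layers: applies nnx.scan across the *transformer blocks* only.
# It is unrelated to the nnx.scan over diffusion timesteps in the pipeline loop.
scan_layers: bool = True
```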
This PR adds features that yield performance gains for the LTX-2 model, along with a fix for the LTX-2 Upsampler, which is broken in main.
Performance enhancement features:
Sharding fix in NNXSImpleFeedForward: Data is sharded along the sequence dimension, so each device holds a subset of tokens but the full feature channels. Because the input data had replicated features while the weights expected features sharded on the same physical axis (context), XLA was forced to insert an All-Gather on the sequence dimension to resolve the layout conflict, wasting significant time. Our fix removes this layout conflict, and with it the All-Gather.
QKV Projection Sharding fix: Profiling showed that the input data was being All-Gathered along the sequence dimension, triggered by the QKV projection step. Because the weights were sharded on the dimension that needed to be summed over (features), a single device could not complete the matrix multiplication using only its local shard of the data, so XLA automatically inserted an All-Gather to replicate the sequence dimension across all devices before performing the multiplication. We changed the weight sharding in attention_ltx2.py to remove sharding on the input feature dimension.
Batching in text encoder: With CFG enabled, the text encoder previously ran twice, once each for the positive and negative prompts. When Classifier-Free Guidance is enabled, we now concatenate the positive and negative prompts and run a single text-encoder pass instead of two.
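The core of the sharding fix can be sketched in plain NumPy (this simulates the shard arithmetic, not the actual XLA passes): sharding a weight on the output dimension needs no communication, while sharding it on the contracted feature dimension leaves each device with only a partial product, forcing XLA to insert collective communication.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))      # [tokens, features], features replicated
w = rng.standard_normal((8, 6))      # [features, out], e.g. a QKV projection
full = x @ w

# Output-dim sharding, as in (None, "heads"): each device multiplies the full
# feature axis by its own column shard -- no cross-device reduction needed.
cols = np.split(w, 2, axis=1)
by_cols = np.concatenate([x @ c for c in cols], axis=1)
assert np.allclose(full, by_cols)

# Contraction-dim sharding: each device holds only a slice of the summed-over
# feature axis, so its local product is incomplete; partial results must be
# combined across devices (the All-Gather/All-Reduce seen in the profile).
x_sh = np.split(x, 2, axis=1)
w_sh = np.split(w, 2, axis=0)
partials = [xs @ ws for xs, ws in zip(x_sh, w_sh)]
assert not np.allclose(full, partials[0])            # one shard alone is wrong
assert np.allclose(full, partials[0] + partials[1])  # reduction recovers it
```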
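The CFG batching change relies on the fact that an encoder's per-example outputs do not depend on batch composition, so one batched pass equals two separate passes. A sketch with a stand-in linear "encoder" (the real model is the LTX-2 text encoder; `encode`, `proj`, and all shapes here are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
proj = rng.standard_normal((32, 64))

def encode(token_embeddings):
    # Stand-in encoder: [batch, seq, 32] -> [batch, seq, 64].
    return token_embeddings @ proj

pos = rng.standard_normal((1, 16, 32))   # positive-prompt embeddings
neg = rng.standard_normal((1, 16, 32))   # negative-prompt embeddings

# Before: two encoder passes, one per prompt.
two_pass = np.concatenate([encode(pos), encode(neg)], axis=0)

# After: a single pass over the concatenated batch, split afterwards.
one_pass = encode(np.concatenate([pos, neg], axis=0))
cond, uncond = one_pass[:1], one_pass[1:]

assert np.allclose(two_pass, one_pass)
```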
JITting Diffusion Loop: The previous implementation used a Python for loop to iterate over diffusion timesteps. This created a "Python dispatch wall": the TPU sat idle between consecutive forward passes while waiting for the host CPU to dispatch the next step. We refactored the entire denoising loop to use nnx.scan.
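The refactor can be sketched with jax.lax.scan, the primitive that nnx.scan wraps; `step` below is a toy stand-in for one denoising step, not the actual pipeline code. Under jit, the scanned loop compiles to a single program, so there is no per-step Python dispatch gap.

```python
import jax
import jax.numpy as jnp

def step(latents, timestep):
    # Toy update rule standing in for one denoising forward pass.
    new_latents = latents - 0.1 * timestep * latents
    return new_latents, None                  # (carry, per-step output)

timesteps = jnp.linspace(1.0, 0.0, 8)
x0 = jnp.ones((4,))

# Python for-loop: one host->device dispatch per step.
loop_out = x0
for t in timesteps:
    loop_out, _ = step(loop_out, t)

# lax.scan: the whole loop is traced once and runs as one compiled program.
scan_out, _ = jax.lax.scan(step, x0, timesteps)
assert jnp.allclose(loop_out, scan_out)
```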
LTX2 Upsampler fix:
Results
We also tested the WAN I2V pipelines to ensure no regressions were introduced there; no quality regression or increased latency was observed.