[LTX2] resolve flash attention block size mismatch and missing config overrides #382
Merged
copybara-service[bot] merged 1 commit into main on Apr 21, 2026
Conversation
prishajain1
reviewed
Apr 19, 2026
…nfig overrides

This commit addresses two issues in the LTX-2 pipeline:

1. Pipeline config overrides: Fixed a bug in `ltx2_pipeline.py` where the `a2v_attention_kernel` and `v2a_attention_kernel` configurations were ignored. The model previously hardcoded a fallback to "flash" because these values were not mapped from the user config into `ltx2_config`.

2. Flash attention padding mismatch: Fixed a `ValueError` (e.g., `kv_block_size=126 should divide kv_seq_len=128`) in `attention_flax.py` that occurred for specific video frame counts. A previous fix padded sequences to satisfy `shard_map` context dimension requirements, but `_select_flash_block_sizes` was calculating block sizes from the unpadded length. Moved the block size calculation to occur *after* `_reshape_data_for_flash` so that the dynamic `min()` bounds correctly align with the newly padded sequence lengths, keeping the cross-attention optimizations intact and the unit tests passing.
Perseus14
approved these changes
Apr 19, 2026
entrpn
approved these changes
Apr 21, 2026
What does this PR do?
This PR fixes a crash in LTX-2 when generating specific frame counts (e.g., `num_frames=121`) with Flash Attention, and fixes a pipeline bug that prevented users from manually overriding the cross-attention kernels.
The Root Causes & Fixes:
Bug: Passing `a2v_attention_kernel=dot_product` via CLI or YAML had no effect. The pipeline was only mapping the main attention config, dropping the cross-attention kernel parameters before initializing the transformer.
Fix: Added mapping for `a2v_attention_kernel` and `v2a_attention_kernel` into `ltx2_config` inside `create_sharded_logical_transformer` so user overrides are respected.
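The mapping fix can be sketched as follows. This is a minimal illustration, not the actual maxdiffusion API: `build_ltx2_config` and the dict shape are hypothetical stand-ins for the mapping done inside `create_sharded_logical_transformer`.

```python
# Illustrative sketch (function and key names are hypothetical, not the
# real maxdiffusion API). Before the fix, the two cross-attention kernel
# settings were never copied over, so the model always fell back to "flash".

def build_ltx2_config(user_config: dict) -> dict:
    """Map user-facing settings into the transformer config dict."""
    return {
        # The main attention kernel was already mapped correctly.
        "attention_kernel": user_config.get("attention_kernel", "flash"),
        # The fix: forward the cross-attention kernel overrides too,
        # instead of silently dropping them.
        "a2v_attention_kernel": user_config.get("a2v_attention_kernel", "flash"),
        "v2a_attention_kernel": user_config.get("v2a_attention_kernel", "flash"),
    }

# A user override now survives into the transformer config:
cfg = build_ltx2_config({"a2v_attention_kernel": "dot_product"})
```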
Bug: Generating 121 frames results in 126 audio latent tokens. A previous PR correctly padded this sequence from 126 to 128 to satisfy `shard_map` context chunking requirements. However, `_tpu_flash_attention` was calling `_select_flash_block_sizes` before the padding occurred. Because of the `min()` bounds used for cross-attention optimization, the block size was calculated as 126. This resulted in passing a padded sequence of 128 to the Splash Attention kernel but telling it to use a block size of 126, crashing because `128 % 126 != 0`.
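The arithmetic behind the crash can be shown in a few lines. The helper below is an illustrative stand-in for the padding pass, not the actual code; the numbers match the PR's example.

```python
# 121 frames -> 126 audio latent tokens; shard_map needs the context
# dimension padded up to a multiple of 128, but the block size was still
# computed from the *unpadded* length. (pad_to_multiple is illustrative.)

def pad_to_multiple(seq_len: int, multiple: int = 128) -> int:
    """Round seq_len up to the next multiple of `multiple`."""
    return ((seq_len + multiple - 1) // multiple) * multiple

unpadded = 126                      # audio latent tokens for num_frames=121
padded = pad_to_multiple(unpadded)  # 128, after the shard_map padding

# Old order: block size bounded by the unpadded length via min().
stale_block = min(512, unpadded)    # 126

# The kernel then receives a 128-token sequence with a 126 block size:
assert padded % stale_block != 0    # 128 % 126 != 0 -> ValueError
```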
Fix: Swapped the order of operations in `_tpu_flash_attention`. Sequences are now padded by `_reshape_data_for_flash` before block sizes are calculated. This ensures `_select_flash_block_sizes` sees the padded shape and calculates a divisible block size, without removing the memory optimizations needed for the cross-attention unit tests to pass.
Testing
Verified that running generation with `num_frames=121` executes cleanly on TPU with Flash Attention enabled.
Verified that `pytest src/maxdiffusion/tests/attention_test.py` passes.