6 changes: 3 additions & 3 deletions .github/workflows/UnitTests.yml
Original file line number Diff line number Diff line change
@@ -57,11 +57,11 @@ jobs:
- name: PyTest
run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536'
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
# add_pull_ready:
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --ignore=src/maxdiffusion/kernels/ --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
# add_pull_ready
# if: github.ref != 'refs/heads/main'
# permissions:
# checks: read
# pull-requests: write
# needs: build
# uses: ./.github/workflows/AddLabel.yml
# uses: ./.github/workflows/AddLabel.yml
19 changes: 19 additions & 0 deletions README.md
@@ -17,6 +17,7 @@
[![Unit Tests](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml)

# What's new?
- **`2026/04/16`**: The Tokamax Ring Attention kernel is now supported
- **`2026/03/31`**: Wan2.2 SenCache inference is now supported for T2V and I2V (up to 1.4x speedup)
- **`2026/03/25`**: Wan2.1 and Wan2.2 Magcache inference is now supported
- **`2026/03/25`**: LTX-2 Video Inference is now supported
@@ -623,6 +624,24 @@ To generate images, run the following command:
...
```

### Ring Attention
We added ring attention support for Wan models. Below are the stats for generating one `720p` (81-frame) video (with CFG DP):
| Accelerator | Model | Attention Type | Inference Steps | Sharding | e2e Generation Time (s) |
| -- | -- | -- | -- | -- | -- |
| v7x-8 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context4-tp1 | 264.2 |
| v7x-8 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context4-tp1 | **252.4** |
| v7x-8 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context4-tp1 | 212.7 |
| v7x-8 | WAN 2.2 | Tokamax Ring | 40 | dp2-fsdp1-context4-tp1 | **201.7** |

| Accelerator | Model | Attention Type | Inference Steps | Sharding | e2e Generation Time (s) |
| -- | -- | -- | -- | -- | -- |
| v7x-16 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context8-tp1 | 146.6 |
| v7x-16 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context8-tp1 | **137.2** |
| v7x-16 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context8-tp1 | **117.8** |
| v7x-16 | WAN 2.2 | Tokamax Ring | 40 | dp2-fsdp1-context8-tp1 | 137.5 |

(*) Ring attention has known stability issues on 16 TPUs; use `tokamax_flash` attention there instead.
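
To try the ring attention path, the relevant knob is the `attention` key in the Wan config (value list taken from `base_wan_14b.yml`; the sharding layouts such as `dp2-fsdp1-context4-tp1` in the tables above are configured separately and are not shown here):

```yaml
# Sketch: select the Tokamax ring attention kernel in a Wan config.
# Supported values (per base_wan_14b.yml): dot_product, flash, tokamax_flash,
# cudnn_flash_te, ring, tokamax_ring, ulysses
attention: 'tokamax_ring'
```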

## Flux

First make sure you have permissions to access the Flux repos in Huggingface.
16 changes: 15 additions & 1 deletion src/maxdiffusion/configs/base_wan_14b.yml
Expand Up @@ -44,6 +44,7 @@ activations_dtype: 'bfloat16'

# Replicates vae across devices instead of using the model's sharding annotations for sharding.
replicate_vae: False
vae_spatial: -1 # defaults to total_devices * 2 // dp
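
Taking the comment's formula at face value, the `-1` placeholder would resolve as follows (illustrative arithmetic only, not the repo's actual resolution code):

```python
# Illustrative: how a vae_spatial of -1 might resolve per the comment's formula.
def default_vae_spatial(total_devices, dp):
    return total_devices * 2 // dp

# e.g. 8 devices with dp=2 -> vae_spatial sharding degree of 8
print(default_vae_spatial(8, 2))  # -> 8
```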

# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
# Options are "DEFAULT", "HIGH", "HIGHEST"
@@ -60,7 +61,7 @@ jit_initializers: True
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses
flash_min_seq_length: 0

# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
@@ -180,6 +181,19 @@ logical_axis_rules: [
['out_channels', 'tensor'],
['conv_out', 'context'],
]
vae_logical_axis_rules: [
['activation_batch', 'redundant'],
['activation_length', 'vae_spatial'],
['activation_heads', null],
['activation_kv_length', null],
['embed', null],
['heads', null],
['norm', null],
['conv_batch', 'redundant'],
['out_channels', 'vae_spatial'],
['conv_out', 'vae_spatial'],
['conv_in', 'vae_spatial'],
]
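
Each rule above pairs a logical axis name with a mesh axis name (or `null` to replicate that axis). As a rough illustration of how such rules resolve, here is a pure-Python sketch of the lookup — the real mapping is done by Flax's logical partitioning machinery, not this helper, and `null` becomes `None`:

```python
# Minimal sketch of logical-axis-rule resolution (illustrative only).
def resolve(rules, logical_axes):
    """Map each logical axis name to its mesh axis, or None (replicated)."""
    table = dict(rules)
    return tuple(table.get(name) for name in logical_axes)

vae_rules = [
    ("activation_batch", "redundant"),
    ("activation_length", "vae_spatial"),
    ("embed", None),
    ("out_channels", "vae_spatial"),
]

# An activation annotated (batch, length, embed) shards its batch on the
# 'redundant' mesh axis and its length on 'vae_spatial'; embed is replicated.
print(resolve(vae_rules, ("activation_batch", "activation_length", "embed")))
# -> ('redundant', 'vae_spatial', None)
```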
data_sharding: [['data', 'fsdp', 'context', 'tensor']]

# One axis for each parallelism type may hold a placeholder (-1)
15 changes: 14 additions & 1 deletion src/maxdiffusion/configs/base_wan_1_3b.yml
@@ -1,4 +1,4 @@
# Copyright 2023 Google LLC
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -157,6 +157,19 @@ logical_axis_rules: [
['out_channels', 'tensor'],
['conv_out', 'context'],
]
vae_logical_axis_rules: [
['activation_batch', 'redundant'],
['activation_length', 'vae_spatial'],
['activation_heads', null],
['activation_kv_length', null],
['embed', null],
['heads', null],
['norm', null],
['conv_batch', 'redundant'],
['out_channels', 'vae_spatial'],
['conv_out', 'vae_spatial'],
['conv_in', 'vae_spatial'],
]
data_sharding: [['data', 'fsdp', 'context', 'tensor']]

# One axis for each parallelism type may hold a placeholder (-1)
14 changes: 14 additions & 0 deletions src/maxdiffusion/configs/base_wan_27b.yml
Expand Up @@ -44,6 +44,7 @@ activations_dtype: 'bfloat16'

# Replicates vae across devices instead of using the model's sharding annotations for sharding.
replicate_vae: False
vae_spatial: 1

# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
# Options are "DEFAULT", "HIGH", "HIGHEST"
@@ -168,6 +169,19 @@ logical_axis_rules: [
['out_channels', 'tensor'],
['conv_out', 'context'],
]
vae_logical_axis_rules: [
['activation_batch', 'redundant'],
['activation_length', 'vae_spatial'],
['activation_heads', null],
['activation_kv_length', null],
['embed', null],
['heads', null],
['norm', null],
['conv_batch', 'redundant'],
['out_channels', 'vae_spatial'],
['conv_out', 'vae_spatial'],
['conv_in', 'vae_spatial'],
]
data_sharding: [['data', 'fsdp', 'context', 'tensor']]

# One axis for each parallelism type may hold a placeholder (-1)
Expand Down
13 changes: 13 additions & 0 deletions src/maxdiffusion/configs/base_wan_i2v_14b.yml
@@ -163,6 +163,19 @@ logical_axis_rules: [
['out_channels', 'tensor'],
['conv_out', 'context'],
]
vae_logical_axis_rules: [
['activation_batch', 'redundant'],
['activation_length', 'vae_spatial'],
['activation_heads', null],
['activation_kv_length', null],
['embed', null],
['heads', null],
['norm', null],
['conv_batch', 'redundant'],
['out_channels', 'vae_spatial'],
['conv_out', 'vae_spatial'],
['conv_in', 'vae_spatial'],
]
data_sharding: [['data', 'fsdp', 'context', 'tensor']]

# One axis for each parallelism type may hold a placeholder (-1)
13 changes: 13 additions & 0 deletions src/maxdiffusion/configs/base_wan_i2v_27b.yml
@@ -164,6 +164,19 @@ logical_axis_rules: [
['out_channels', 'tensor'],
['conv_out', 'context'],
]
vae_logical_axis_rules: [
['activation_batch', 'redundant'],
['activation_length', 'vae_spatial'],
['activation_heads', null],
['activation_kv_length', null],
['embed', null],
['heads', null],
['norm', null],
['conv_batch', 'redundant'],
['out_channels', 'vae_spatial'],
['conv_out', 'vae_spatial'],
['conv_in', 'vae_spatial'],
]
data_sharding: [['data', 'fsdp', 'context', 'tensor']]

# One axis for each parallelism type may hold a placeholder (-1)
2 changes: 1 addition & 1 deletion src/maxdiffusion/configuration_utils.py
@@ -394,7 +394,7 @@ def load_config(
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
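
This keyword change tracks `huggingface_hub`, which deprecated `use_auth_token` in favor of `token` in its download helpers; newer versions reject the old name. A stubbed sketch of the rename (the stub stands in for the real downloader so the example stays offline):

```python
# Stub standing in for huggingface_hub's downloader; newer versions of the
# real function accept `token=` and no longer accept `use_auth_token=`.
def hf_hub_download_stub(repo_id, filename, *, token=None, revision=None):
    return {"repo_id": repo_id, "filename": filename, "token": token}

# Before: hf_hub_download(..., use_auth_token=use_auth_token)  # TypeError on new hub versions
# After:  hf_hub_download(..., token=use_auth_token)
info = hf_hub_download_stub("org/model", "config.json", token="hf_xxx")
print(info["token"])  # -> hf_xxx
```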
Empty file.
15 changes: 15 additions & 0 deletions src/maxdiffusion/kernels/splash_attention/__init__.py
@@ -0,0 +1,15 @@
# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Splash Attention kernels."""