Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/maxdiffusion/configs/ltx2_video.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@
hardware: 'tpu'
skip_jax_distributed_system: False
attention: 'flash'
a2v_attention_kernel: 'flash'
a2v_attention_kernel: 'dot_product'
v2a_attention_kernel: 'dot_product'
Comment thread
prishajain1 marked this conversation as resolved.
attention_sharding_uniform: True
precision: 'bf16'

# For scanning transformer layers
scan_layers: True

# For scanning diffusion loop
scan_diffusion_loop: True

names_which_can_be_saved: []
names_which_can_be_offloaded: []
remat_policy: "NONE"
Expand Down
8 changes: 4 additions & 4 deletions src/maxdiffusion/models/attention_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,11 +287,11 @@ def _tpu_flash_attention(
) -> jax.Array:
"""TPU Flash Attention"""

block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)
num_context_shards = mesh.shape["context"]
query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
value, _ = _reshape_data_for_flash(value, heads, num_context_shards)
block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)

q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
Expand Down Expand Up @@ -892,7 +892,7 @@ def __init__(
dtype=dtype,
param_dtype=weights_dtype,
precision=precision,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "mlp")),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
)
self.act = get_activation(activation_fn)
Expand All @@ -904,8 +904,8 @@ def __init__(
dtype=dtype,
param_dtype=weights_dtype,
precision=precision,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", "embed")),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", None)),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None,)),
)

def __call__(self, hidden_states: Array) -> Array:
Expand Down
24 changes: 21 additions & 3 deletions src/maxdiffusion/models/ltx2/attention_ltx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import jax.numpy as jnp
from ... import common_types
from ..attention_flax import NNXAttentionOp
from maxdiffusion.tpu_utils import get_tpu_type, TpuType

Array = common_types.Array
Mesh = common_types.Mesh
Expand Down Expand Up @@ -349,23 +350,40 @@ def __init__(
rope_type: str = "interleaved",
flash_block_sizes: BlockSizes = None,
flash_min_seq_length: int = 4096,
qkv_sharding_spec: Optional[tuple] = None,
out_sharding_spec: Optional[tuple] = None,
out_bias_sharding_spec: Optional[tuple] = None,
):
self.heads = heads
self.rope_type = rope_type
self.dim_head = dim_head
self.inner_dim = dim_head * heads
self.dropout_rate = dropout

# Auto-detect hardware for sharding specs if not overridden
tpu_type = get_tpu_type()
is_ironwood = tpu_type == TpuType.TPU_7X

# Hardware-aware sharding: Ironwood (v7x) uses 1D sharding along the heads dimension (leaving the embedding dimension replicated)
# to minimize cross-device communication, while other hardware defaults to 2D sharding along both heads and embed dimensions.
# This has currently only been tested on Trillium (v6e) and Ironwood (v7x).
if qkv_sharding_spec is None:
Comment thread
prishajain1 marked this conversation as resolved.
qkv_sharding_spec = (None, "heads") if is_ironwood else ("embed", "heads")
if out_sharding_spec is None:
out_sharding_spec = ("heads", None) if is_ironwood else ("heads", "embed")
if out_bias_sharding_spec is None:
out_bias_sharding_spec = (None,) if is_ironwood else ("embed",)

# 1. Define Partitioned Initializers (Logical Axes)
# Q, K, V kernels: [in_features (embed), out_features (heads)]
qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads"))
qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), qkv_sharding_spec)
# Q, K, V biases: [out_features (heads)]
qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))

# Out kernel: [in_features (heads), out_features (embed)]
out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed"))
out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), out_sharding_spec)
# Out bias: [out_features (embed)]
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), out_bias_sharding_spec)

# Norm scales
norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",))
Expand Down
4 changes: 2 additions & 2 deletions src/maxdiffusion/models/ltx2/latent_upsampler_ltx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,12 @@ def __init__(self, in_channels: int, mid_channels: int = 1024, scale: float = 2.
in_channels, (num**2) * self.mid_channels, kernel_size=(3, 3), padding=((1, 1), (1, 1)), rngs=rngs
)
self.pixel_shuffle = PixelShuffleND(dims=2, upscale_factors=(num, num))
self.blur = BlurDownsample(dims=2, stride=den)
self.blur_down = BlurDownsample(dims=2, stride=den)

def __call__(self, x: jax.Array) -> jax.Array:
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.blur(x)
x = self.blur_down(x)
return x


Expand Down
Loading
Loading