diff --git a/docs/source/en/api/models/motif_video_transformer_3d.md b/docs/source/en/api/models/motif_video_transformer_3d.md
new file mode 100644
index 000000000000..46344f841b10
--- /dev/null
+++ b/docs/source/en/api/models/motif_video_transformer_3d.md
@@ -0,0 +1,32 @@
+
+
+# MotifVideoTransformer3DModel
+
+A Diffusion Transformer model for 3D video-like data was introduced in [Motif-Video](https://arxiv.org/abs/2604.16503) by the Motif Technologies Team.
+
+The model uses a three-stage architecture with 12 dual-stream + 16 single-stream + 8 DDT decoder layers and rotary positional embeddings (RoPE) for video generation.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+
+from diffusers import MotifVideoTransformer3DModel
+
+transformer = MotifVideoTransformer3DModel.from_pretrained("MotifTechnologies/Motif-Video-2B", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
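+
+The transformer is also registered for single-file loading, so it should be loadable from a single checkpoint file. The snippet below is a sketch; the checkpoint path is a placeholder, not a published file.
+
+```python
+import torch
+
+from diffusers import MotifVideoTransformer3DModel
+
+# Hypothetical local checkpoint path; replace with an actual single-file checkpoint.
+transformer = MotifVideoTransformer3DModel.from_single_file(
+    "path/to/motif_video_transformer.safetensors", torch_dtype=torch.bfloat16
+)
+```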
+
+## MotifVideoTransformer3DModel
+
+[[autodoc]] MotifVideoTransformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/pipelines/motif_video.md b/docs/source/en/api/pipelines/motif_video.md
new file mode 100644
index 000000000000..2ec549b9cfd4
--- /dev/null
+++ b/docs/source/en/api/pipelines/motif_video.md
@@ -0,0 +1,147 @@
+
+
+# Motif-Video
+
+[Technical Report](https://arxiv.org/abs/2604.16503)
+
+Motif-Video is a 2B-parameter diffusion transformer for text-to-video and image-to-video generation. It features a three-stage architecture with 12 dual-stream + 16 single-stream + 8 DDT decoder layers, Shared Cross-Attention for stable text-video alignment over long video sequences, a T5Gemma2 text encoder, and rectified flow matching for velocity prediction.
+
+## Text-to-Video Generation
+
+Use `MotifVideoPipeline` for text-to-video generation:
+
+```python
+import torch
+from diffusers import AdaptiveProjectedGuidance, MotifVideoPipeline
+from diffusers.utils import export_to_video
+
+guider = AdaptiveProjectedGuidance(
+ guidance_scale=8.0,
+ adaptive_projected_guidance_rescale=12.0,
+ adaptive_projected_guidance_momentum=0.1,
+ use_original_formulation=True,
+ normalization_dims="spatial",
+)
+
+pipe = MotifVideoPipeline.from_pretrained(
+ "MotifTechnologies/Motif-Video-2B",
+ torch_dtype=torch.bfloat16,
+ guider=guider,
+)
+pipe.to("cuda")
+
+prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair."
+negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+video = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ width=1280,
+ height=736,
+ num_frames=121,
+ num_inference_steps=50,
+).frames[0]
+export_to_video(video, "output.mp4", fps=24)
+```
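+
+The examples on this page pass an `AdaptiveProjectedGuidance` guider with `normalization_dims="spatial"`, which normalizes the guidance update over the channel and spatial dimensions of each frame instead of over the entire 5D video latent. If you prefer plain classifier-free guidance, the sketch below assumes `ClassifierFreeGuidance` is importable from `diffusers` in the same way as `AdaptiveProjectedGuidance`, and the `guidance_scale` value is an illustrative starting point rather than a tuned default:
+
+```python
+import torch
+from diffusers import ClassifierFreeGuidance, MotifVideoPipeline
+
+# Plain CFG in place of adaptive projected guidance; the scale below is an assumption.
+guider = ClassifierFreeGuidance(guidance_scale=6.0)
+pipe = MotifVideoPipeline.from_pretrained(
+    "MotifTechnologies/Motif-Video-2B",
+    torch_dtype=torch.bfloat16,
+    guider=guider,
+)
+pipe.to("cuda")
+```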
+
+## Image-to-Video Generation
+
+Use `MotifVideoImage2VideoPipeline` for image-to-video generation:
+
+```python
+import torch
+from diffusers import AdaptiveProjectedGuidance, MotifVideoImage2VideoPipeline
+from diffusers.utils import export_to_video, load_image
+
+guider = AdaptiveProjectedGuidance(
+ guidance_scale=8.0,
+ adaptive_projected_guidance_rescale=12.0,
+ adaptive_projected_guidance_momentum=0.1,
+ use_original_formulation=True,
+ normalization_dims="spatial",
+)
+
+pipe = MotifVideoImage2VideoPipeline.from_pretrained(
+ "MotifTechnologies/Motif-Video-2B",
+ torch_dtype=torch.bfloat16,
+ guider=guider,
+)
+pipe.to("cuda")
+
+image = load_image("input_image.png")
+prompt = "A cinematic scene with vivid colors."
+negative_prompt = "worst quality, blurry, jittery, distorted"
+
+video = pipe(
+ image=image,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ width=1280,
+ height=736,
+ num_frames=121,
+ num_inference_steps=50,
+).frames[0]
+export_to_video(video, "i2v_output.mp4", fps=24)
+```
+
+## Memory-efficient Inference
+
+For GPUs with less than 30 GB of VRAM (e.g., an RTX 4090), enable model CPU offloading. Setting `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` also helps reduce CUDA memory fragmentation:
+
+```bash
+export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+```
+
+```python
+import torch
+from diffusers import AdaptiveProjectedGuidance, MotifVideoPipeline
+from diffusers.utils import export_to_video
+
+guider = AdaptiveProjectedGuidance(
+ guidance_scale=8.0,
+ adaptive_projected_guidance_rescale=12.0,
+ adaptive_projected_guidance_momentum=0.1,
+ use_original_formulation=True,
+ normalization_dims="spatial",
+)
+
+pipe = MotifVideoPipeline.from_pretrained(
+ "MotifTechnologies/Motif-Video-2B",
+ torch_dtype=torch.bfloat16,
+ guider=guider,
+)
+pipe.enable_model_cpu_offload()
+
+prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair."
+negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+video = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ width=1280,
+ height=736,
+ num_frames=121,
+ num_inference_steps=50,
+).frames[0]
+export_to_video(video, "output.mp4", fps=24)
+```
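+
+If VAE decoding is the remaining memory bottleneck, tiled decoding can further reduce peak usage. This is a minimal sketch, assuming the loaded VAE exposes the standard `enable_tiling` method (the pipeline uses `AutoencoderKLWan`, which provides it in recent diffusers releases):
+
+```python
+# Decode the video latents in tiles to lower peak VRAM during VAE decoding.
+pipe.vae.enable_tiling()
+```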
+
+## MotifVideoPipeline
+
+[[autodoc]] MotifVideoPipeline
+ - all
+ - __call__
+
+## MotifVideoImage2VideoPipeline
+
+[[autodoc]] MotifVideoImage2VideoPipeline
+ - all
+ - __call__
+
+## MotifVideoPipelineOutput
+
+[[autodoc]] pipelines.motif_video.pipeline_output.MotifVideoPipelineOutput
\ No newline at end of file
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 2cbfd6e29305..ebd0c9e8ffde 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -263,6 +263,7 @@
"LuminaNextDiT2DModel",
"MochiTransformer3DModel",
"ModelMixin",
+ "MotifVideoTransformer3DModel",
"MotionAdapter",
"MultiAdapter",
"MultiControlNetModel",
@@ -625,6 +626,9 @@
"MarigoldIntrinsicsPipeline",
"MarigoldNormalsPipeline",
"MochiPipeline",
+ "MotifVideoImage2VideoPipeline",
+ "MotifVideoPipeline",
+ "MotifVideoPipelineOutput",
"MusicLDMPipeline",
"NucleusMoEImagePipeline",
"OmniGenPipeline",
@@ -1073,6 +1077,7 @@
LuminaNextDiT2DModel,
MochiTransformer3DModel,
ModelMixin,
+ MotifVideoTransformer3DModel,
MotionAdapter,
MultiAdapter,
MultiControlNetModel,
@@ -1410,6 +1415,9 @@
MarigoldIntrinsicsPipeline,
MarigoldNormalsPipeline,
MochiPipeline,
+ MotifVideoImage2VideoPipeline,
+ MotifVideoPipeline,
+ MotifVideoPipelineOutput,
MusicLDMPipeline,
NucleusMoEImagePipeline,
OmniGenPipeline,
diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py
index b210cb3e67aa..55d38a0d32e7 100644
--- a/src/diffusers/guiders/adaptive_projected_guidance.py
+++ b/src/diffusers/guiders/adaptive_projected_guidance.py
@@ -48,6 +48,12 @@ class AdaptiveProjectedGuidance(BaseGuidance):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+ normalization_dims (`str` or `list[int]` or `None`, defaults to `None`):
+ Dimensions to normalize over for the guidance computation. Can be:
+ - `None` (default): Normalize over all non-batch dimensions (e.g., [C, H, W] for 4D, [C, T, H, W] for 5D)
+ - `"spatial"`: Spatial-only normalization - normalize over [C, H, W] per frame for 5D tensors, standard for
+ 4D
+ - `list[int]`: Custom dimensions to normalize over (e.g., `[-1, -2, -4]` for [W, H, C])
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
@@ -65,6 +71,7 @@ def __init__(
eta: float = 1.0,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
+ normalization_dims: str | list[int] | None = None,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
@@ -77,6 +84,7 @@ def __init__(
self.eta = eta
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
+ self.normalization_dims = normalization_dims
self.momentum_buffer = None
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
@@ -117,6 +125,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = No
self.eta,
self.adaptive_projected_guidance_rescale,
self.use_original_formulation,
+ self.normalization_dims,
)
if self.guidance_rescale > 0.0:
@@ -210,9 +219,25 @@ def normalized_guidance(
eta: float = 1.0,
norm_threshold: float = 0.0,
use_original_formulation: bool = False,
+ normalization_dims: str | list[int] | None = None,
):
diff = pred_cond - pred_uncond
- dim = [-i for i in range(1, len(diff.shape))]
+
+ # Determine normalization dimensions
+ if normalization_dims == "spatial":
+ # Spatial-only normalization: normalize over [C, H, W] per frame for 5D tensors
+ if len(diff.shape) == 5:
+ # [B, C, T, H, W] -> normalize over W(-1), H(-2), C(-4), skip T(-3)
+ dim = [-1, -2, -4]
+ else:
+ # [B, C, H, W] -> standard behavior
+ dim = [-i for i in range(1, len(diff.shape))]
+ elif normalization_dims is None:
+ # Default: normalize over all non-batch dimensions
+ dim = [-i for i in range(1, len(diff.shape))]
+ else:
+ # Custom dimensions provided by user
+ dim = normalization_dims
if momentum_buffer is not None:
momentum_buffer.update(diff)
diff --git a/src/diffusers/guiders/adaptive_projected_guidance_mix.py b/src/diffusers/guiders/adaptive_projected_guidance_mix.py
index 559e30d2aabe..d9187672f157 100644
--- a/src/diffusers/guiders/adaptive_projected_guidance_mix.py
+++ b/src/diffusers/guiders/adaptive_projected_guidance_mix.py
@@ -48,6 +48,12 @@ class AdaptiveProjectedMixGuidance(BaseGuidance):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+ normalization_dims (`str` or `list[int]` or `None`, defaults to `None`):
+ Dimensions to normalize over for the guidance computation. Can be:
+ - `None` (default): Normalize over all non-batch dimensions (e.g., [C, H, W] for 4D, [C, T, H, W] for 5D)
+ - `"spatial"`: Spatial-only normalization - normalize over [C, H, W] per frame for 5D tensors, standard for
+ 4D
+ - `list[int]`: Custom dimensions to normalize over (e.g., `[-1, -2, -4]` for [W, H, C])
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which the classifier-free guidance starts.
stop (`float`, defaults to `1.0`):
@@ -71,6 +77,7 @@ def __init__(
adaptive_projected_guidance_rescale: float = 10.0,
eta: float = 0.0,
use_original_formulation: bool = False,
+ normalization_dims: str | list[int] | None = None,
start: float = 0.0,
stop: float = 1.0,
adaptive_projected_guidance_start_step: int = 5,
@@ -86,6 +93,7 @@ def __init__(
self.eta = eta
self.adaptive_projected_guidance_start_step = adaptive_projected_guidance_start_step
self.use_original_formulation = use_original_formulation
+ self.normalization_dims = normalization_dims
self.momentum_buffer = None
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
@@ -138,6 +146,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = No
self.eta,
self.adaptive_projected_guidance_rescale,
self.use_original_formulation,
+ self.normalization_dims,
)
if self.guidance_rescale > 0.0:
@@ -269,6 +278,7 @@ def normalized_guidance(
eta: float = 1.0,
norm_threshold: float = 0.0,
use_original_formulation: bool = False,
+ normalization_dims: str | list[int] | None = None,
):
if momentum_buffer is not None:
update_momentum_buffer(pred_cond, pred_uncond, momentum_buffer)
@@ -276,7 +286,21 @@ def normalized_guidance(
else:
diff = pred_cond - pred_uncond
- dim = [-i for i in range(1, len(diff.shape))]
+ # Determine normalization dimensions
+ if normalization_dims == "spatial":
+ # Spatial-only normalization: normalize over [C, H, W] per frame for 5D tensors
+ if len(diff.shape) == 5:
+ # [B, C, T, H, W] -> normalize over W(-1), H(-2), C(-4), skip T(-3)
+ dim = [-1, -2, -4]
+ else:
+ # [B, C, H, W] -> standard behavior
+ dim = [-i for i in range(1, len(diff.shape))]
+ elif normalization_dims is None:
+ # Default: normalize over all non-batch dimensions
+ dim = [-i for i in range(1, len(diff.shape))]
+ else:
+ # Custom dimensions provided by user
+ dim = normalization_dims
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index c7bb2de4437a..43fc8d897fe6 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -21,7 +21,11 @@
from typing_extensions import Self
from .. import __version__
-from ..models.model_loading_utils import _caching_allocator_warmup, _determine_device_map, _expand_device_map
+from ..models.model_loading_utils import (
+ _caching_allocator_warmup,
+ _determine_device_map,
+ _expand_device_map,
+)
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache
@@ -194,6 +198,10 @@
"checkpoint_mapping_fn": convert_ltx2_audio_vae_to_diffusers,
"default_subfolder": "audio_vae",
},
+ "MotifVideoTransformer3DModel": {
+ "checkpoint_mapping_fn": lambda checkpoint, **kwargs: checkpoint,
+ "default_subfolder": "transformer",
+ },
}
@@ -336,7 +344,11 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: str | None = No
disable_mmap = kwargs.pop("disable_mmap", False)
device_map = kwargs.pop("device_map", None)
- user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
+ user_agent = {
+ "diffusers": __version__,
+ "file_type": "single_file",
+ "framework": "pytorch",
+ }
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
if quantization_config is not None:
user_agent["quant"] = quantization_config.quant_method.value
@@ -393,7 +405,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: str | None = No
config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs)
diffusers_model_config = config_mapping_fn(
- original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
+ original_config=original_config,
+ checkpoint=checkpoint,
+ **config_mapping_kwargs,
)
else:
if config is not None:
@@ -465,7 +479,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: str | None = No
if _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint):
diffusers_format_checkpoint = checkpoint_mapping_fn(
- config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
+ config=diffusers_model_config,
+ checkpoint=checkpoint,
+ **checkpoint_mapping_kwargs,
)
else:
diffusers_format_checkpoint = checkpoint
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index ba9b7810e054..d9a9611b6fd3 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -119,6 +119,7 @@
_import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
+ _import_structure["transformers.transformer_motif_video"] = ["MotifVideoTransformer3DModel"]
_import_structure["transformers.transformer_nucleusmoe_image"] = ["NucleusMoEImageTransformer2DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_ovis_image"] = ["OvisImageTransformer2DModel"]
@@ -243,6 +244,7 @@
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
+ MotifVideoTransformer3DModel,
NucleusMoEImageTransformer2DModel,
OmniGenTransformer2DModel,
OvisImageTransformer2DModel,
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index d4ac6ff4301e..d79c84fba784 100755
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -42,6 +42,7 @@
from .transformer_ltx2 import LTX2VideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel
+ from .transformer_motif_video import MotifVideoTransformer3DModel
from .transformer_nucleusmoe_image import NucleusMoEImageTransformer2DModel
from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_ovis_image import OvisImageTransformer2DModel
diff --git a/src/diffusers/models/transformers/transformer_motif_video.py b/src/diffusers/models/transformers/transformer_motif_video.py
new file mode 100644
index 000000000000..0de30c2925a8
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_motif_video.py
@@ -0,0 +1,1014 @@
+# Copyright 2026 Motif Technologies and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import FeedForward
+from ..attention_processor import Attention, AttentionProcessor
+from ..cache_utils import CacheMixin
+from ..embeddings import (
+ PixArtAlphaTextProjection,
+ TimestepEmbedding,
+ Timesteps,
+ apply_rotary_emb,
+ get_1d_rotary_pos_embed,
+)
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import (
+ AdaLayerNormContinuous,
+ AdaLayerNormZero,
+ AdaLayerNormZeroSingle,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class CrossAttentionInputs:
+ """Inputs for cross-attention mode where query is pre-projected externally.
+
+ Args:
+ query: Pre-projected query tensor [B, L, D]
+ key: Key tensor for attention [B, L, D]
+ value: Value tensor for attention [B, L, D]
+ """
+
+ query: torch.Tensor
+ key: torch.Tensor
+ value: torch.Tensor
+
+
+class MotifVideoAttnProcessor2_0:
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "MotifVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        cross_attn_inputs: Optional[CrossAttentionInputs] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """
+ Process attention with support for cross-attention mode.
+
+ Args:
+ attn: Attention layer
+ hidden_states: Input hidden states [B, L, D]
+ encoder_hidden_states: Optional encoder states [B, E, D]
+ attention_mask: Optional attention mask [B, 1, 1, N]
+ image_rotary_emb: Optional rotary embeddings
+ cross_attn_inputs: Optional pre-projected cross-attention inputs
+
+ Returns:
+ Tuple of (hidden_states, encoder_hidden_states)
+ """
+ if cross_attn_inputs is not None:
+ return self._handle_cross_attention_mode(attn, cross_attn_inputs, attention_mask, image_rotary_emb)
+
+ # Standard attention mode
+ if attn.add_q_proj is None and encoder_hidden_states is not None:
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+ query, key, value = self._project_qkv(attn, hidden_states)
+ query, key = self._normalize_qk(attn, query, key)
+ query, key = self._apply_rope(attn, query, key, encoder_hidden_states, image_rotary_emb)
+ query, key, value = self._add_encoder_conditioning(attn, query, key, value, encoder_hidden_states)
+ hidden_states = self._compute_attention(query, key, value, attention_mask)
+ return self._project_output(attn, hidden_states, encoder_hidden_states)
+
+ def _handle_cross_attention_mode(
+ self,
+ attn: Attention,
+ cross_attn_inputs: CrossAttentionInputs,
+ attention_mask: Optional[torch.Tensor],
+ image_rotary_emb: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, None]:
+ """Handle cross-attention mode with pre-projected query.
+
+ Query is already projected externally (cross_attn_query_proj + norm), so we skip to_q and only apply reshape +
+ norm_q + RoPE. K/V use to_k/to_v as normal.
+ """
+ query = cross_attn_inputs.query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ key = attn.to_k(cross_attn_inputs.key)
+ value = attn.to_v(cross_attn_inputs.value)
+
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+ return hidden_states, None
+
+ def _project_qkv(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Project hidden states to Q, K, V and reshape for attention."""
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ return query, key, value
+
+ def _normalize_qk(
+ self,
+ attn: Attention,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Apply QK normalization if present."""
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+ return query, key
+
+ def _apply_rope(
+ self,
+ attn: Attention,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor],
+ image_rotary_emb: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Apply rotary positional embeddings to query and key."""
+ if image_rotary_emb is not None:
+ if attn.add_q_proj is None and encoder_hidden_states is not None:
+ split_idx = -encoder_hidden_states.shape[1]
+ query = torch.cat(
+ [
+ apply_rotary_emb(query[:, :, :split_idx], image_rotary_emb),
+ query[:, :, split_idx:],
+ ],
+ dim=2,
+ )
+ key = torch.cat(
+ [
+ apply_rotary_emb(key[:, :, :split_idx], image_rotary_emb),
+ key[:, :, split_idx:],
+ ],
+ dim=2,
+ )
+ else:
+ query = apply_rotary_emb(query, image_rotary_emb)
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ return query, key
+
+ def _add_encoder_conditioning(
+ self,
+ attn: Attention,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Add encoder conditioning QKV projections and normalization."""
+ if attn.add_q_proj is not None and encoder_hidden_states is not None:
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_query = attn.norm_added_q(encoder_query)
+ if attn.norm_added_k is not None:
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([query, encoder_query], dim=2)
+ key = torch.cat([key, encoder_key], dim=2)
+ value = torch.cat([value, encoder_value], dim=2)
+
+ return query, key, value
+
+ def _compute_attention(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+ """Compute scaled dot-product attention."""
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+ return hidden_states
+
+ def _project_output(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Apply output projections and split encoder states."""
+ if encoder_hidden_states is not None:
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : -encoder_hidden_states.shape[1]],
+ hidden_states[:, -encoder_hidden_states.shape[1] :],
+ )
+
+ if getattr(attn, "to_out", None) is not None:
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if getattr(attn, "to_add_out", None) is not None:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+
+
+class MotifVideoPatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: Union[int, Tuple[int, int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ ) -> None:
+ super().__init__()
+
+ patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.proj(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC
+ return hidden_states
+
+
+class MotifVideoAdaNorm(nn.Module):
+ def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
+ super().__init__()
+
+ out_features = out_features or 2 * in_features
+ self.linear = nn.Linear(in_features, out_features)
+ self.nonlinearity = nn.SiLU()
+
+ def forward(self, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ temb = self.linear(self.nonlinearity(temb))
+ gate_msa, gate_mlp = temb.chunk(2, dim=1)
+ gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
+ return gate_msa, gate_mlp
+
+
+class MotifVideoConditionEmbedding(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ ):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ ) -> torch.Tensor:
+ timesteps_proj = self.time_proj(timestep)
+ compute_dtype = next(
+ (p.dtype for p in self.timestep_embedder.parameters() if p.is_floating_point()),
+ torch.float32, # safe fallback
+ )
+ conditioning = self.timestep_embedder(timesteps_proj.to(compute_dtype)) # (N, D)
+
+ return conditioning
+
+
+class MotifVideoRotaryPosEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: int,
+ patch_size_t: int,
+ rope_dim: List[int],
+ theta: float = 256.0,
+ ):
+ """
+ Rotary Positional Embedding (RoPE) for video latents.
+
+ Args:
+ patch_size (`int`): Spatial patch size.
+ patch_size_t (`int`): Temporal patch size.
+ rope_dim (`List[int]`): Dimensions for RoPE across [Time, Height, Width] axes.
+ theta (`float`, *optional*, defaults to 256.0): Base frequency for rotary embeddings.
+ """
+ super().__init__()
+
+ self.patch_size = patch_size
+ self.patch_size_t = patch_size_t
+ self.rope_dim = rope_dim
+ self.theta = theta
+
+    def forward(
+        self, hidden_states: torch.Tensor, timestep: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ rope_sizes = [
+ num_frames // self.patch_size_t,
+ height // self.patch_size,
+ width // self.patch_size,
+ ]
+
+ axes_grids = []
+ for i in range(3):
+ grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
+ axes_grids.append(grid)
+ grid = torch.meshgrid(*axes_grids, indexing="ij")
+ grid = torch.stack(grid, dim=0)
+
+ freqs = []
+ for i in range(3):
+ freq = get_1d_rotary_pos_embed(
+ dim=self.rope_dim[i],
+ pos=grid[i].reshape(-1),
+ theta=self.theta,
+ use_real=True,
+ freqs_dtype=torch.float64,
+ )
+ freqs.append(freq)
+
+ freqs_cos = torch.cat([f[0] for f in freqs], dim=1)
+ freqs_sin = torch.cat([f[1] for f in freqs], dim=1)
+ return freqs_cos, freqs_sin
+
+
+class MotifVideoImageProjection(nn.Module):
+ def __init__(self, in_features: int, hidden_size: int):
+ super().__init__()
+ self.norm_in = nn.LayerNorm(in_features)
+ self.linear_1 = nn.Linear(in_features, in_features)
+ self.act_fn = nn.GELU()
+ self.linear_2 = nn.Linear(in_features, hidden_size)
+ self.norm_out = nn.LayerNorm(hidden_size)
+
+ def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.norm_in(image_embeds)
+ hidden_states = self.linear_1(hidden_states)
+ hidden_states = self.act_fn(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ hidden_states = self.norm_out(hidden_states)
+ return hidden_states
+
+
+class MotifVideoSingleTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_ratio: float = 4.0,
+ qk_norm: str = "rms_norm",
+ norm_type: str = "layer_norm",
+ enable_text_cross_attention: bool = False,
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+ mlp_dim = int(hidden_size * mlp_ratio)
+
+ self.attn = Attention(
+ query_dim=hidden_size,
+ cross_attention_dim=None,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=hidden_size,
+ bias=True,
+ processor=MotifVideoAttnProcessor2_0(),
+ qk_norm=qk_norm,
+ eps=1e-6,
+ pre_only=True,
+ )
+
+ self.enable_text_cross_attention = enable_text_cross_attention
+ if enable_text_cross_attention:
+ self.cross_attn_query_proj = nn.Linear(hidden_size, hidden_size)
+ self.cross_attn_query_norm = nn.LayerNorm(hidden_size, eps=1e-6)
+ self.cross_attn_out_proj = nn.Linear(hidden_size, hidden_size)
+ nn.init.zeros_(self.cross_attn_out_proj.weight)
+ nn.init.zeros_(self.cross_attn_out_proj.bias)
+
+ self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type=norm_type)
+ self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
+ self.act_mlp = nn.GELU(approximate="tanh")
+ self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ image_embed_seq_len: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+ video_tokens = hidden_states.shape[1]
+ encoder_seq_length = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+ residual = hidden_states
+
+ # 1. Input normalization
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+
+ norm_hidden_states, norm_encoder_hidden_states = (
+ norm_hidden_states[:, :-encoder_seq_length, :],
+ norm_hidden_states[:, -encoder_seq_length:, :],
+ )
+
+ # 2. Attention
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # Text cross-attention: Q=proj(attn_output), K/V=normed text, reuse self.attn weights
+ if self.enable_text_cross_attention:
+ txt_kv = norm_encoder_hidden_states[:, image_embed_seq_len:, :]
+ text_mask = None
+ if attention_mask is not None:
+ text_mask = attention_mask[:, :, :, video_tokens + image_embed_seq_len :]
+ cross_q = self.cross_attn_query_proj(attn_output)
+ cross_output, _ = self.attn(
+ hidden_states=cross_q,
+ cross_attn_inputs=CrossAttentionInputs(
+ query=cross_q,
+ key=txt_kv,
+ value=txt_kv,
+ ),
+ attention_mask=text_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+ attn_output = attn_output + self.cross_attn_out_proj(cross_output)
+
+ attn_output = torch.cat([attn_output, context_attn_output], dim=1)
+
+ # 3. Modulation and residual connection
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
+ hidden_states = hidden_states + residual
+
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, :-encoder_seq_length, :],
+ hidden_states[:, -encoder_seq_length:, :],
+ )
+ return hidden_states, encoder_hidden_states
+
+
+class MotifVideoTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_ratio: float,
+ qk_norm: str = "rms_norm",
+ norm_type: str = "layer_norm",
+ enable_text_cross_attention: bool = False,
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+
+ self.norm1 = AdaLayerNormZero(hidden_size, norm_type=norm_type)
+ self.norm1_context = AdaLayerNormZero(hidden_size, norm_type=norm_type)
+
+ self.attn = Attention(
+ query_dim=hidden_size,
+ cross_attention_dim=None,
+ added_kv_proj_dim=hidden_size,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=hidden_size,
+ context_pre_only=False,
+ bias=True,
+ processor=MotifVideoAttnProcessor2_0(),
+ qk_norm=qk_norm,
+ eps=1e-6,
+ )
+
+ self.enable_text_cross_attention = enable_text_cross_attention
+ if enable_text_cross_attention:
+ self.cross_attn_query_proj = nn.Linear(hidden_size, hidden_size)
+ self.cross_attn_query_norm = nn.LayerNorm(hidden_size, eps=1e-6)
+ self.cross_attn_out_proj = nn.Linear(hidden_size, hidden_size)
+ nn.init.zeros_(self.cross_attn_out_proj.weight)
+ nn.init.zeros_(self.cross_attn_out_proj.bias)
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+
+ self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+ self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ image_embed_seq_len: int = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # 1. Input normalization
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb
+ )
+
+ # 2. Joint attention
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # 3. Modulation and residual connection
+ hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
+
+ # Text cross-attention: Q=proj(attn_output), K/V=normed text, reuse self.attn weights
+ if self.enable_text_cross_attention:
+ txt_kv = norm_encoder_hidden_states[:, image_embed_seq_len:, :]
+ text_mask = None
+ if attention_mask is not None:
+ text_mask = attention_mask[:, :, :, hidden_states.shape[1] + image_embed_seq_len :]
+ cross_q = self.cross_attn_query_proj(attn_output)
+ cross_output, _ = self.attn(
+ hidden_states=cross_q,
+ cross_attn_inputs=CrossAttentionInputs(
+ query=cross_q,
+ key=txt_kv,
+ value=txt_kv,
+ ),
+ attention_mask=text_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+ hidden_states = hidden_states + self.cross_attn_out_proj(cross_output)
+
+ encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ # 4. Feed-forward
+ ff_output = self.ff(norm_hidden_states)
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+ return hidden_states, encoder_hidden_states
+
+
+TransformerBlockRegistry.register(
+ model_class=MotifVideoTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+)
+TransformerBlockRegistry.register(
+ model_class=MotifVideoSingleTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+)
+
+
+class MotifVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+ r"""
+ A Transformer model for video-like data used in the Motif-Video model.
+
+ Args:
+ in_channels (`int`, defaults to `33`):
+ The number of channels in the input.
+ out_channels (`int`, defaults to `16`):
+ The number of channels in the output.
+ num_attention_heads (`int`, defaults to `24`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of channels in each head.
+ num_layers (`int`, defaults to `20`):
+ The number of layers of dual-stream blocks to use.
+ num_single_layers (`int`, defaults to `40`):
+ The number of layers of single-stream blocks to use.
+ num_decoder_layers (`int`, defaults to `0`):
+ The number of decoder layers in single-stream blocks.
+ mlp_ratio (`float`, defaults to `4.0`):
+ The ratio of the hidden layer size to the input size in the feedforward network.
+ patch_size (`int`, defaults to `2`):
+ The size of the spatial patches to use in the patch embedding layer.
+ patch_size_t (`int`, defaults to `1`):
+ The size of the temporal patches to use in the patch embedding layer.
+ qk_norm (`str`, defaults to `rms_norm`):
+ The normalization to use for the query and key projections in the attention layers.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ image_embed_dim (`int`, *optional*):
+ Input dimension of image embeddings from a vision encoder. If provided, enables image conditioning.
+ rope_theta (`float`, defaults to `256.0`):
+ The value of theta to use in the RoPE layer.
+ rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
+ The dimensions of the axes to use in the RoPE layer.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
+ _no_split_modules = [
+ "MotifVideoTransformerBlock",
+ "MotifVideoSingleTransformerBlock",
+ "MotifVideoPatchEmbed",
+ ]
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 33,
+ out_channels: int = 16,
+ num_attention_heads: int = 24,
+ attention_head_dim: int = 128,
+ num_layers: int = 20,
+ num_single_layers: int = 40,
+ num_decoder_layers: int = 0,
+ mlp_ratio: float = 4.0,
+ patch_size: int = 2,
+ patch_size_t: int = 1,
+ qk_norm: str = "rms_norm",
+ norm_type: str = "layer_norm",
+ text_embed_dim: int = 4096,
+ image_embed_dim: int | None = None,
+ rope_theta: float = 256.0,
+ rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
+ enable_text_cross_attention_dual: bool = False,
+ enable_text_cross_attention_single: bool = False,
+ ) -> None:
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels or in_channels
+
+ # 1. Latent and condition embedders
+ self.x_embedder = MotifVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
+ self.context_embedder = PixArtAlphaTextProjection(in_features=text_embed_dim, hidden_size=inner_dim)
+
+ # First frame conditioning: Image conditioning embedders
+ self.image_embed_dim = image_embed_dim
+ if image_embed_dim is not None:
+ self.image_embedder = MotifVideoImageProjection(in_features=image_embed_dim, hidden_size=inner_dim)
+
+ self.time_text_embed = MotifVideoConditionEmbedding(inner_dim)
+
+ # 2. RoPE
+ self.rope = MotifVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
+
+ # Cross-attention config
+ self.enable_text_cross_attention_dual = enable_text_cross_attention_dual
+ self.enable_text_cross_attention_single = enable_text_cross_attention_single
+
+ # 3. Dual stream transformer blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ MotifVideoTransformerBlock(
+ num_attention_heads,
+ attention_head_dim,
+ mlp_ratio=mlp_ratio,
+ qk_norm=qk_norm,
+ norm_type=norm_type,
+ enable_text_cross_attention=enable_text_cross_attention_dual,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Single stream transformer blocks
+ # Encoder blocks get cross-attention; decoder blocks do not (no text stream in decoder)
+ num_encoder_single = num_single_layers - num_decoder_layers
+ self.single_transformer_blocks = nn.ModuleList(
+ [
+ MotifVideoSingleTransformerBlock(
+ num_attention_heads,
+ attention_head_dim,
+ mlp_ratio=mlp_ratio,
+ qk_norm=qk_norm,
+ norm_type=norm_type,
+ enable_text_cross_attention=enable_text_cross_attention_single
+ if i < num_encoder_single
+ else False,
+ )
+ for i in range(num_single_layers)
+ ]
+ )
+
+ # 5. Output projection
+ self.norm_out = AdaLayerNormContinuous(
+ inner_dim,
+ inner_dim,
+ elementwise_affine=False,
+ eps=1e-6,
+ norm_type=norm_type,
+ )
+ self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
+
+ # Verify cross-attention config matches actual block state.
+ # Catches silent misconfiguration (e.g. checkpoint config with renamed keys).
+ for i, block in enumerate(self.transformer_blocks):
+ if block.enable_text_cross_attention != enable_text_cross_attention_dual:
+ raise ValueError(
+ f"transformer_blocks[{i}].enable_text_cross_attention="
+ f"{block.enable_text_cross_attention}, expected {enable_text_cross_attention_dual}. "
+ f"Check checkpoint config.json key names match __init__ parameters."
+ )
+ for i, block in enumerate(self.single_transformer_blocks):
+ expected = enable_text_cross_attention_single if i < num_encoder_single else False
+ if block.enable_text_cross_attention != expected:
+ raise ValueError(
+ f"single_transformer_blocks[{i}].enable_text_cross_attention="
+ f"{block.enable_text_cross_attention}, expected {expected}. "
+ f"Check checkpoint config.json key names match __init__ parameters."
+ )
+
+ self.gradient_checkpointing = False
+ self.num_decoder_layers = num_decoder_layers
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight name.
+ """
+ processors = {}
+
+ def fn_recursive_add_processors(
+ name: str,
+ module: torch.nn.Module,
+ processors: Dict[str, AttentionProcessor],
+ ):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def _maybe_gradient_checkpoint_block(self, block, *args):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ return self._gradient_checkpointing_func(block, *args)
+ return block(*args)
+
+ def _create_attention_mask(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_attention_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ attention_mask = F.pad(
+ encoder_attention_mask.to(torch.bool),
+ (hidden_states.shape[1], 0),
+ value=True,
+ )
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
+ return attention_mask
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_attention_mask: torch.Tensor | None = None,
+ image_embeds: torch.Tensor | None = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
+ """
+ Forward pass of the MotifVideoTransformer3DModel.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ Input latent tensor of shape `(batch_size, channels, num_frames, height, width)`.
+ timestep (`torch.LongTensor`):
+ Diffusion timesteps of shape `(batch_size,)`.
+ encoder_hidden_states (`torch.Tensor`):
+ Text conditioning of shape `(batch_size, sequence_length, embed_dim)`.
+            encoder_attention_mask (`torch.Tensor`, *optional*):
+ Mask for text conditioning of shape `(batch_size, sequence_length)`.
+ image_embeds (`torch.Tensor`, *optional*):
+ Image embeddings from vision encoder of shape `(batch_size, num_tokens, embed_dim)`.
+ attention_kwargs (`dict`, *optional*):
+ Additional arguments for attention processors.
+ return_dict (`bool`, defaults to `True`):
+ Whether to return a [`~models.modeling_outputs.Transformer2DModelOutput`].
+
+ Returns:
+ [`~models.modeling_outputs.Transformer2DModelOutput`] or `tuple`:
+ The predicted samples.
+ """
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, _, num_frames, height, width = hidden_states.shape
+ p, p_t = self.config.patch_size, self.config.patch_size_t
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p
+ post_patch_width = width // p
+
+ # 1. RoPE
+ image_rotary_emb = self.rope(hidden_states, timestep=timestep)
+
+ # 2. Conditional embeddings
+ temb = self.time_text_embed(timestep)
+ hidden_states = self.x_embedder(hidden_states)
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ # First frame conditioning: Image embeddings from vision encoder
+ if image_embeds is not None:
+ image_embeds = self.image_embedder(image_embeds)
+ encoder_hidden_states = torch.cat([image_embeds, encoder_hidden_states], dim=1)
+ if encoder_attention_mask is not None:
+ image_mask = torch.ones(
+ image_embeds.shape[0],
+ image_embeds.shape[1],
+ device=encoder_attention_mask.device,
+ dtype=encoder_attention_mask.dtype,
+ )
+ encoder_attention_mask = torch.cat([image_mask, encoder_attention_mask], dim=1)
+
+ # image_embed_seq_len: used by cross-attention blocks to slice text from encoder_hidden_states
+ image_embed_seq_len = image_embeds.shape[1] if image_embeds is not None else 0
+
+ decoder_hidden_states = hidden_states.clone()
+
+ if encoder_attention_mask is not None:
+ attention_mask = self._create_attention_mask(
+ hidden_states=hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ attention_mask = None
+
+ # 3. Dual stream transformer blocks
+ for block in self.transformer_blocks:
+ hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ attention_mask,
+ image_rotary_emb,
+ image_embed_seq_len,
+ )
+
+ # 4. Single stream transformer blocks (Encoder)
+ single_transformer_blocks = self.single_transformer_blocks
+
+ for block in single_transformer_blocks[: len(single_transformer_blocks) - self.num_decoder_layers]:
+ hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ attention_mask,
+ image_rotary_emb,
+ image_embed_seq_len,
+ )
+
+ # 5. Single stream transformer blocks (Decoder)
+ if self.num_decoder_layers > 0:
+ encoder_hidden_states = hidden_states
+ attention_mask = None
+
+ for block in single_transformer_blocks[-self.num_decoder_layers :]:
+ decoder_hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
+ block,
+ decoder_hidden_states,
+ encoder_hidden_states,
+ temb,
+ attention_mask,
+ image_rotary_emb,
+ )
+
+ hidden_states = decoder_hidden_states
+
+ # 6. Output projection
+ hidden_states = self.norm_out(hidden_states, temb)
+ hidden_states = self.proj_out(hidden_states)
+
+ hidden_states = hidden_states.reshape(
+ batch_size,
+ post_patch_num_frames,
+ post_patch_height,
+ post_patch_width,
+ -1,
+ p_t,
+ p,
+ p,
+ )
+ hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
+ hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if USE_PEFT_BACKEND:
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (hidden_states,)
+
+ return Transformer2DModelOutput(
+ sample=hidden_states,
+ )
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index ae1849a587e8..6c71ab9dd3f8 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -340,6 +340,12 @@
]
)
_import_structure["mochi"] = ["MochiPipeline"]
+ _import_structure["motif_video"] = [
+ "MotifVideoPipeline",
+ "MotifVideoImage2VideoPipeline",
+ "MotifVideoPipelineOutput",
+ ]
+ _import_structure["musicldm"] = ["MusicLDMPipeline"]
_import_structure["omnigen"] = ["OmniGenPipeline"]
_import_structure["ernie_image"] = ["ErnieImagePipeline"]
_import_structure["ovis_image"] = ["OvisImagePipeline"]
@@ -778,6 +784,12 @@
MarigoldNormalsPipeline,
)
from .mochi import MochiPipeline
+ from .motif_video import (
+ MotifVideoImage2VideoPipeline,
+ MotifVideoPipeline,
+ MotifVideoPipelineOutput,
+ )
+ from .musicldm import MusicLDMPipeline
from .nucleusmoe_image import NucleusMoEImagePipeline
from .omnigen import OmniGenPipeline
from .ovis_image import OvisImagePipeline
diff --git a/src/diffusers/pipelines/motif_video/__init__.py b/src/diffusers/pipelines/motif_video/__init__.py
new file mode 100644
index 000000000000..ee1d7c72ee65
--- /dev/null
+++ b/src/diffusers/pipelines/motif_video/__init__.py
@@ -0,0 +1,50 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_motif_video"] = ["MotifVideoPipeline"]
+ _import_structure["pipeline_motif_video_image2video"] = ["MotifVideoImage2VideoPipeline"]
+ _import_structure["pipeline_output"] = ["MotifVideoPipelineOutput"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_motif_video import MotifVideoPipeline
+ from .pipeline_motif_video_image2video import MotifVideoImage2VideoPipeline
+ from .pipeline_output import MotifVideoPipelineOutput
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/motif_video/pipeline_motif_video.py b/src/diffusers/pipelines/motif_video/pipeline_motif_video.py
new file mode 100644
index 000000000000..099a55e23d90
--- /dev/null
+++ b/src/diffusers/pipelines/motif_video/pipeline_motif_video.py
@@ -0,0 +1,868 @@
+# Copyright 2026 Motif Technologies, Inc. and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ...utils import is_transformers_version
+
+
+# Check transformers version before importing T5Gemma2Encoder
+if not is_transformers_version(">=", "5.1.0"):
+ import transformers
+
+ raise ImportError(
+ f"MotifVideoPipeline requires transformers>=5.1.0. "
+ f"Found: {transformers.__version__}. "
+ "Please upgrade transformers: pip install transformers --upgrade"
+ )
+
+from transformers import BatchEncoding, PreTrainedTokenizerBase, SiglipImageProcessor, T5Gemma2Encoder
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...guiders import AdaptiveProjectedGuidance, ClassifierFreeGuidance, SkipLayerGuidance
+from ...models import AutoencoderKLWan
+from ...models.transformers import MotifVideoTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import MotifVideoPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import MotifVideoPipeline
+ >>> from diffusers.utils import export_to_video
+
+ >>> # Load the Motif-Video pipeline
+ >>> motif_video_model_id = "MotifTechnologies/Motif-Video-2B"
+ >>> pipe = MotifVideoPipeline.from_pretrained(motif_video_model_id, torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
+ >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+ >>> video = pipe(
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... width=640,
+ ... height=352,
+ ... num_frames=65,
+ ... num_inference_steps=50,
+ ... ).frames[0]
+ >>> export_to_video(video, "output.mp4", fps=16)
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+def get_linear_quadratic_sigmas(
+ num_inference_steps: int,
+ linear_quadratic_emulating_steps: int = 250,
+) -> np.ndarray:
+ """
+ Compute a linear-quadratic sigma schedule for flow matching.
+
+ This schedule combines:
+ - First half: Linear interpolation from high noise to medium noise (slow denoising)
+ - Second half: Quadratic interpolation from medium noise to clean (faster denoising)
+
+ Convention:
+ - sigma=1.0 represents pure noise
+ - sigma=0.0 represents clean image
+ - Output sigmas are in descending order (1.0 → ~0)
+
+ Args:
+ num_inference_steps: Total number of denoising steps (must be even).
+        linear_quadratic_emulating_steps: Controls the slope of the linear interpolation.
+            Higher values result in a gentler slope in the first half.
+
+ Returns:
+ np.ndarray: Array of sigma values with shape (num_inference_steps,).
+ The scheduler will append a terminal 0.
+
+ Raises:
+ ValueError: If num_inference_steps is not even.
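+
+    Example:
+        Illustrative values: with `num_inference_steps=50` and the default
+        `linear_quadratic_emulating_steps=250`, the first 25 sigmas fall linearly from 1.0 to 0.904 in
+        steps of 1/250, and the remaining 25 follow `0.9 * (1 - x**2)`, ending near 0.07 before the
+        scheduler appends the terminal 0.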
+ """
+ if num_inference_steps % 2 != 0:
+ raise ValueError(
+ f"num_inference_steps must be even for linear-quadratic schedule, but got {num_inference_steps}"
+ )
+
+ steps = num_inference_steps
+ N = linear_quadratic_emulating_steps
+ half_steps = steps // 2
+
+ # First half: linear interpolation from 1 toward 0
+ linear_part = np.linspace(1.0, 0.0, N + 1)[:half_steps]
+
+ # Second half: quadratic interpolation
+ x = np.linspace(0.0, 1.0, half_steps + 1)
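+    # With half_steps < N, scale_factor is negative, so quadratic_part = (1 - half_steps / N) * (1 - x**2):
+    # it picks up one linear step below where the linear part stops and decays quadratically to 0.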
+ scale_factor = half_steps / N - 1
+ quadratic_part = x**2 * scale_factor - scale_factor
+
+ # Concatenate and exclude the last 0 (scheduler appends terminal 0)
+ sigmas = np.concatenate([linear_part, quadratic_part])
+ sigmas = sigmas[:-1]
+
+ return sigmas.astype(np.float32)
+
+
+# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps with support for
+# the linear-quadratic sigma schedule
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ use_linear_quadratic_schedule: bool = False,
+ linear_quadratic_emulating_steps: int = 250,
+ **kwargs,
+):
+ r"""
+ Retrieve timesteps from the scheduler.
+
+ Args:
+ scheduler: The noise scheduler to use.
+ num_inference_steps: Number of denoising steps.
+ device: Device to place timesteps on.
+ timesteps: Custom timestep values (mutually exclusive with sigmas).
+ sigmas: Custom sigma values (mutually exclusive with timesteps).
+ use_linear_quadratic_schedule: If True, use linear-quadratic sigma schedule.
+ linear_quadratic_emulating_steps: Controls the linear portion slope.
+ **kwargs: Additional arguments passed to scheduler.set_timesteps().
+
+ Returns:
+ Tuple of (timesteps, num_inference_steps).
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+
+ if use_linear_quadratic_schedule:
+ if sigmas is not None:
+ raise ValueError(
+ "Cannot use both `sigmas` and `use_linear_quadratic_schedule`. "
+ "The linear-quadratic schedule computes sigmas automatically."
+ )
+ if num_inference_steps is None:
+ raise ValueError("`num_inference_steps` must be provided when using `use_linear_quadratic_schedule`.")
+ sigmas = get_linear_quadratic_sigmas(
+ num_inference_steps=num_inference_steps,
+ linear_quadratic_emulating_steps=linear_quadratic_emulating_steps,
+ )
+
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class MotifVideoPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-video generation using Motif-Video.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ transformer ([`MotifVideoTransformer3DModel`]):
+ Conditional Transformer architecture to denoise the encoded video latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ text_encoder ([`T5Gemma2Encoder`]):
+ Primary text encoder for encoding text prompts into embeddings.
+ tokenizer ([`PreTrainedTokenizerBase`]):
+ Tokenizer corresponding to the primary text encoder.
+        guider ([`ClassifierFreeGuidance`], [`SkipLayerGuidance`], or [`AdaptiveProjectedGuidance`], *optional*):
+            The guidance method to use; defaults to `ClassifierFreeGuidance` when not provided. For video
+            generation with `AdaptiveProjectedGuidance`, use
+ `normalization_dims="spatial"` for spatial-only normalization that preserves temporal quality by
+ normalizing over [C, H, W] per frame instead of collapsing the temporal dimension.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _optional_components = ["feature_extractor"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLWan,
+ text_encoder: T5Gemma2Encoder,
+ tokenizer: PreTrainedTokenizerBase,
+ transformer: MotifVideoTransformer3DModel,
+        guider: Optional[Union[ClassifierFreeGuidance, SkipLayerGuidance, AdaptiveProjectedGuidance]] = None,
+ feature_extractor: Optional[SiglipImageProcessor] = None,
+ ):
+ super().__init__()
+
+ if guider is None:
+ guider = ClassifierFreeGuidance()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ guider=guider,
+ feature_extractor=feature_extractor,
+ )
+
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+
+ self.transformer_spatial_patch_size = (
+ self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 2
+ )
+ self.transformer_temporal_patch_size = (
+ self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 512
+ )
+
+ def _get_default_embeds(
+ self,
+ text_encoder,
+ tokenizer: PreTrainedTokenizerBase,
+ prompt: Union[str, List[str]],
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ dtype = dtype or text_encoder.dtype
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_inputs = BatchEncoding(
+ {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in text_inputs.items()}
+ )
+
+ prompt_embeds = text_encoder(**text_inputs)[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, text_inputs.attention_mask
+
+ def _get_prompt_embeds(
+ self,
+ text_encoder: T5Gemma2Encoder,
+ tokenizer: PreTrainedTokenizerBase,
+ prompt: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ prompt_embeds_kwargs = {
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "prompt": prompt,
+ "max_sequence_length": max_sequence_length,
+ "device": device,
+ "dtype": dtype,
+ }
+ prompt_embeds, prompt_attention_mask = self._get_default_embeds(**prompt_embeds_kwargs)
+
+ return prompt_embeds, prompt_attention_mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be encoded.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos to generate per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for text embeddings.
+ max_sequence_length (`int`, defaults to 512):
+ Maximum sequence length for the tokenizer.
+ device (`torch.device`, *optional*):
+ Device to place tensors on.
+ dtype (`torch.dtype`, *optional*):
+ Data type for tensors.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_prompt_embeds(
+ text_encoder=self.text_encoder,
+ tokenizer=self.tokenizer,
+ prompt=prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ seq_len = prompt_embeds.shape[1]
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ prompt_attention_mask = prompt_attention_mask.bool()
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat_interleave(num_videos_per_prompt, dim=0)
+
+ return (
+ prompt_embeds,
+ prompt_attention_mask,
+ )
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ batch_size,
+ callback_on_step_end_tensor_inputs=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % self.vae_scale_factor_spatial != 0 or width % self.vae_scale_factor_spatial != 0:
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor_spatial} but are {height} and {width}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None:
+ if not isinstance(negative_prompt, (str, list)):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+ if isinstance(negative_prompt, list) and len(negative_prompt) != batch_size:
+ raise ValueError(
+ f"`negative_prompt` list length ({len(negative_prompt)}) must match batch_size ({batch_size})."
+ )
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ def _prepare_negative_prompt(
+ self,
+ negative_prompt: Optional[Union[str, List[str]]],
+ batch_size: int,
+ ) -> List[str]:
+ """Prepare negative_prompt to match batch_size."""
+ if negative_prompt is None:
+ return [""] * batch_size
+ if isinstance(negative_prompt, str):
+ return [negative_prompt] * batch_size
+ return negative_prompt
+
+ @staticmethod
+ def _normalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
+ ) -> torch.Tensor:
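+        # Per-channel normalization with the VAE's latent statistics, broadcast over batch, time, and space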
+ latents_mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = (latents - latents_mean) / latents_std
+ return latents
+
+ @staticmethod
+ def _denormalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
+ ) -> torch.Tensor:
+ latents_mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = latents * latents_std + latents_mean
+ return latents
+
+ def prepare_latents(
+ self,
+ batch_size: int = 1,
+ num_channels_latents: int = 16,
+ height: int = 736,
+ width: int = 1280,
+ num_frames: int = 121,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
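+        # The video VAE keeps the first frame and compresses the remaining frames temporally by
+        # `vae_scale_factor_temporal`, and both spatial dimensions by `vae_scale_factor_spatial`.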
+ shape = (
+ batch_size,
+ num_channels_latents,
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 736,
+ width: int = 1280,
+ num_frames: int = 121,
+ frame_rate: int = 25,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ use_linear_quadratic_schedule: bool = True,
+ linear_quadratic_emulating_steps: int = 250,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ use_attention_mask: bool = True,
+        vae_batch_size: Optional[int] = None,
+ ):
+ r"""
+ The call function to the pipeline for text-to-video generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance.
+            height (`int`, defaults to `736`):
+                The height in pixels of the generated video.
+            width (`int`, defaults to `1280`):
+                The width in pixels of the generated video.
+            num_frames (`int`, defaults to `121`):
+                The number of video frames to generate.
+ frame_rate (`int`, defaults to `25`):
+ Frame rate for the output video.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process.
+ use_linear_quadratic_schedule (`bool`, defaults to `True`):
+ Whether to use a linear-quadratic sigma schedule instead of the default linear schedule. Requires
+ `num_inference_steps` to be even.
+ linear_quadratic_emulating_steps (`int`, defaults to `250`):
+ Controls the slope of linear interpolation in the first half of the linear-quadratic schedule. Only
+ used when `use_linear_quadratic_schedule=True`.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ PyTorch Generator object(s) for deterministic generation.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings.
+ negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated video. Choose between `"pil"`, `"np"`, or `"latent"`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~MotifVideoPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ Arguments passed to the attention processor.
+ callback_on_step_end (`Callable`, *optional*):
+ A function or subclass of `PipelineCallback` or `MultiPipelineCallbacks` called at the end of each
+ denoising step.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function.
+ max_sequence_length (`int`, defaults to `512`):
+ Maximum sequence length for the tokenizer.
+ use_attention_mask (`bool`, defaults to `True`):
+ Whether to use attention masks for text embeddings.
+            vae_batch_size (`int`, *optional*):
+                Batch size for VAE decoding. If provided and the batch size of the latents is larger, decoding is
+                done in chunks of this size to reduce peak memory usage.
+
+ Examples:
+
+ Returns:
+ [`~MotifVideoPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~MotifVideoPipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list of generated video frames.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 2. Check inputs
+ self.check_inputs(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=height,
+ width=width,
+ batch_size=batch_size,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ )
+
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+ self._current_timestep = None
+
+ device = self._execution_device
+
+ # 3. Prepare text embeddings
+ prompt_embeds, prompt_attention_mask = self.encode_prompt(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.guider._enabled and self.guider.num_conditions > 1:
+ negative_prompt = self._prepare_negative_prompt(negative_prompt, batch_size)
+ negative_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=negative_prompt_attention_mask,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ # 4. Prepare latents
+ num_channels_latents = self.vae.config.z_dim
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ self.transformer.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare timesteps
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+ latent_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
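+        # Token sequence length seen by the transformer after spatio-temporal patchification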
+ packed_latent_height = latent_height // self.transformer_spatial_patch_size
+ packed_latent_width = latent_width // self.transformer_spatial_patch_size
+ packed_latent_num_frames = latent_num_frames // self.transformer_temporal_patch_size
+ video_sequence_length = packed_latent_num_frames * packed_latent_height * packed_latent_width
+
+ if use_linear_quadratic_schedule:
+ sigmas = None
+ else:
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+
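+        # Resolution-dependent timestep shift: longer token sequences map linearly to a larger `mu`,
+        # which is forwarded to the scheduler's `set_timesteps` via `retrieve_timesteps`.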
+ mu = calculate_shift(
+ video_sequence_length,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas=sigmas,
+ use_linear_quadratic_schedule=use_linear_quadratic_schedule,
+ linear_quadratic_emulating_steps=linear_quadratic_emulating_steps,
+ mu=mu,
+ )
+
+ # Prepare conditioning tensors (T2V mode: no first-frame conditioning)
+ batch_size, latent_channels, latent_num_frames, latent_height, latent_width = latents.shape
+ latent_condition = torch.zeros(
+ batch_size,
+ latent_channels,
+ latent_num_frames,
+ latent_height,
+ latent_width,
+ device=latents.device,
+ dtype=latents.dtype,
+ )
+ latent_mask = torch.zeros(
+ batch_size,
+ 1,
+ latent_num_frames,
+ latent_height,
+ latent_width,
+ device=latents.device,
+ dtype=latents.dtype,
+ )
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ # Concatenate current latents with conditioning: [latents | latent_condition | latent_mask]
+ hidden_states = torch.cat([latents, latent_condition, latent_mask], dim=1)
+
+ timestep = t.expand(latents.shape[0])
+
+ # Guider: collect model inputs
+ guider_inputs = {
+ "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
+ }
+ if use_attention_mask:
+ guider_inputs["encoder_attention_mask"] = (
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ )
+
+ self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
+ guider_state = self.guider.prepare_inputs(guider_inputs)
+
+ # Sigma injection for guiders that support sigma-based gating
+ if hasattr(self.guider, "_current_sigma") and hasattr(self.scheduler, "sigmas"):
+ self.guider._current_sigma = float(self.scheduler.sigmas[i])
+
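+                # Run the transformer once per guidance branch (e.g. conditional / unconditional);
+                # the guider combines the per-branch predictions after the loop.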
+ for guider_state_batch in guider_state:
+ self.guider.prepare_models(self.transformer)
+
+ cond_kwargs = {
+ input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
+ }
+
+ context_name = getattr(guider_state_batch, self.guider._identifier_key)
+ with self.transformer.cache_context(context_name):
+ noise_pred = self.transformer(
+ hidden_states=hidden_states,
+ timestep=timestep,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ **cond_kwargs,
+ )[0].clone()
+
+ guider_state_batch.noise_pred = noise_pred
+ self.guider.cleanup_models(self.transformer)
+
+ noise_pred = self.guider(guider_state)[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ if "negative_prompt_embeds" in callback_outputs:
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds")
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if output_type == "latent":
+ video = latents
+ else:
+ latents = latents.to(self.vae.dtype)
+ latents = self._denormalize_latents(latents, self.vae.config.latents_mean, self.vae.config.latents_std)
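+            # Optionally decode in chunks along the batch dimension to bound peak VAE memory usage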
+ if vae_batch_size is not None and latents.shape[0] > vae_batch_size:
+ video_chunks = []
+ for i in range(0, latents.shape[0], vae_batch_size):
+ chunk = latents[i : i + vae_batch_size]
+ video_chunks.append(self.vae.decode(chunk, return_dict=False)[0])
+ video = torch.cat(video_chunks, dim=0)
+ del video_chunks
+ else:
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return MotifVideoPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/motif_video/pipeline_motif_video_image2video.py b/src/diffusers/pipelines/motif_video/pipeline_motif_video_image2video.py
new file mode 100644
index 000000000000..e8bf594da0a7
--- /dev/null
+++ b/src/diffusers/pipelines/motif_video/pipeline_motif_video_image2video.py
@@ -0,0 +1,921 @@
+# Copyright 2026 Motif Technologies, Inc. and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ...utils import is_transformers_version
+
+
+# Check transformers version before importing T5Gemma2Encoder
+if not is_transformers_version(">=", "5.1.0"):
+ import transformers
+
+ raise ImportError(
+        "MotifVideoImage2VideoPipeline requires transformers>=5.1.0. "
+ f"Found: {transformers.__version__}. "
+ "Please upgrade transformers: pip install transformers --upgrade"
+ )
+
+from transformers import BatchEncoding, PreTrainedTokenizerBase, SiglipImageProcessor, T5Gemma2Encoder
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...guiders import AdaptiveProjectedGuidance, ClassifierFreeGuidance, SkipLayerGuidance
+from ...image_processor import PipelineImageInput
+from ...models import AutoencoderKLWan
+from ...models.transformers import MotifVideoTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import MotifVideoPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import MotifVideoImage2VideoPipeline
+ >>> from diffusers.utils import export_to_video, load_image
+
+ >>> # Load the Motif-Video image-to-video pipeline
+ >>> motif_video_model_id = "MotifTechnologies/Motif-Video-2B"
+ >>> pipe = MotifVideoImage2VideoPipeline.from_pretrained(motif_video_model_id, torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> # Load an image
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.png"
+ ... )
+
+ >>> prompt = "An astronaut is walking on the moon surface, kicking up dust with each step"
+ >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+ >>> video = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... width=640,
+ ... height=352,
+ ... num_frames=65,
+ ... num_inference_steps=50,
+ ... ).frames[0]
+ >>> export_to_video(video, "output.mp4", fps=16)
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.motif_video.pipeline_motif_video.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.motif_video.pipeline_motif_video.get_linear_quadratic_sigmas
+def get_linear_quadratic_sigmas(
+ num_inference_steps: int,
+ linear_quadratic_emulating_steps: int = 250,
+) -> np.ndarray:
+ if num_inference_steps % 2 != 0:
+ raise ValueError(
+ f"num_inference_steps must be even for linear-quadratic schedule, but got {num_inference_steps}"
+ )
+
+ steps = num_inference_steps
+ N = linear_quadratic_emulating_steps
+ half_steps = steps // 2
+
+ linear_part = np.linspace(1.0, 0.0, N + 1)[:half_steps]
+
+ x = np.linspace(0.0, 1.0, half_steps + 1)
+ scale_factor = half_steps / N - 1
+ quadratic_part = x**2 * scale_factor - scale_factor
+
+ sigmas = np.concatenate([linear_part, quadratic_part])
+ sigmas = sigmas[:-1]
+
+ return sigmas.astype(np.float32)
+
+
+# Copied from diffusers.pipelines.motif_video.pipeline_motif_video.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ use_linear_quadratic_schedule: bool = False,
+ linear_quadratic_emulating_steps: int = 250,
+ **kwargs,
+):
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+
+ if use_linear_quadratic_schedule:
+ if sigmas is not None:
+ raise ValueError(
+ "Cannot use both `sigmas` and `use_linear_quadratic_schedule`. "
+ "The linear-quadratic schedule computes sigmas automatically."
+ )
+ if num_inference_steps is None:
+ raise ValueError("`num_inference_steps` must be provided when using `use_linear_quadratic_schedule`.")
+ sigmas = get_linear_quadratic_sigmas(
+ num_inference_steps=num_inference_steps,
+ linear_quadratic_emulating_steps=linear_quadratic_emulating_steps,
+ )
+
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class MotifVideoImage2VideoPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for image-to-video generation using Motif-Video with first frame conditioning.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ transformer ([`MotifVideoTransformer3DModel`]):
+ Conditional Transformer architecture to denoise the encoded video latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ text_encoder ([`T5Gemma2Encoder`]):
+ Primary text encoder for encoding text prompts into embeddings.
+ tokenizer ([`PreTrainedTokenizerBase`]):
+ Tokenizer corresponding to the primary text encoder.
+ feature_extractor ([`SiglipImageProcessor`]):
+ Image processor for the SigLIP vision encoder.
+        guider ([`ClassifierFreeGuidance`], [`SkipLayerGuidance`], or [`AdaptiveProjectedGuidance`], *optional*):
+            The guidance method to use; defaults to `ClassifierFreeGuidance` when not provided. For video
+            generation with `AdaptiveProjectedGuidance`, use
+ `normalization_dims="spatial"` for spatial-only normalization that preserves temporal quality by
+ normalizing over [C, H, W] per frame instead of collapsing the temporal dimension.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLWan,
+ text_encoder: T5Gemma2Encoder,
+ tokenizer: PreTrainedTokenizerBase,
+ transformer: MotifVideoTransformer3DModel,
+ feature_extractor: SiglipImageProcessor,
+ guider: Optional[Union[AdaptiveProjectedGuidance, ClassifierFreeGuidance, SkipLayerGuidance]] = None,
+ ):
+ super().__init__()
+
+ if guider is None:
+ guider = ClassifierFreeGuidance()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ guider=guider,
+ )
+
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+
+ self.transformer_spatial_patch_size = (
+ self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 2
+ )
+ self.transformer_temporal_patch_size = (
+ self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 512
+ )
+
+ def _get_default_embeds(
+ self,
+ text_encoder,
+ tokenizer: PreTrainedTokenizerBase,
+ prompt: Union[str, List[str]],
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ dtype = dtype or text_encoder.dtype
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_inputs = BatchEncoding(
+ {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in text_inputs.items()}
+ )
+
+ prompt_embeds = text_encoder(**text_inputs)[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, text_inputs.attention_mask
+
+ def _get_prompt_embeds(
+ self,
+ text_encoder: T5Gemma2Encoder,
+ tokenizer: PreTrainedTokenizerBase,
+ prompt: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ prompt_embeds_kwargs = {
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "prompt": prompt,
+ "max_sequence_length": max_sequence_length,
+ "device": device,
+ "dtype": dtype,
+ }
+ prompt_embeds, prompt_attention_mask = self._get_default_embeds(**prompt_embeds_kwargs)
+
+ return prompt_embeds, prompt_attention_mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be encoded.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos to generate per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for text embeddings.
+ max_sequence_length (`int`, defaults to 512):
+ Maximum sequence length for the tokenizer.
+ device (`torch.device`, *optional*):
+ Device to place tensors on.
+ dtype (`torch.dtype`, *optional*):
+ Data type for tensors.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_prompt_embeds(
+ text_encoder=self.text_encoder,
+ tokenizer=self.tokenizer,
+ prompt=prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ seq_len = prompt_embeds.shape[1]
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ prompt_attention_mask = prompt_attention_mask.bool()
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat_interleave(num_videos_per_prompt, dim=0)
+
+ return (
+ prompt_embeds,
+ prompt_attention_mask,
+ )
+
+ @property
+ def vision_encoder(self):
+ """Get the vision encoder from T5Gemma2."""
+ return self.text_encoder.vision_tower
+
+ @staticmethod
+ def _get_image_embeds(
+ image_encoder,
+ feature_extractor: SiglipImageProcessor,
+ image,
+ device: torch.device,
+ ) -> torch.Tensor:
+        """Helper to encode a single image with SigLIP."""
+ image_encoder_dtype = next(image_encoder.parameters()).dtype
+
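+        # Inputs are expected to already be scaled to [0, 1], so rescaling is skipped and only resizing,
+        # RGB conversion, and normalization are applied before the SigLIP vision encoder.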
+ if isinstance(image, torch.Tensor):
+ image = feature_extractor.preprocess(
+ images=image.float(),
+ do_resize=True,
+ do_rescale=False,
+ do_normalize=True,
+ do_convert_rgb=True,
+ return_tensors="pt",
+ )
+ else:
+ image = feature_extractor.preprocess(
+ images=image,
+ do_resize=True,
+ do_rescale=False,
+ do_normalize=True,
+ do_convert_rgb=True,
+ return_tensors="pt",
+ )
+
+ image = image.to(device=device, dtype=image_encoder_dtype)
+ return image_encoder(**image).last_hidden_state
+
+ def _prepare_first_frame_conditioning(
+ self,
+ video: torch.Tensor,
+ latents: torch.Tensor,
+ use_conditioning: bool,
+ generator: Optional[torch.Generator] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ """Prepare first frame conditioning tensors.
+
+ For I2V mode:
+ 1. Extract and VAE-encode first frame from video
+ 2. Create latent_condition with first frame latents at frame 0
+ 3. Create latent_mask with 1.0 at frame 0
+ 4. Get image_embeds from vision encoder
+
+ For T2V mode:
+ 1. Return zeros for latent_condition and latent_mask, None for image_embeds
+
+ Args:
+ video: Input video tensor [batch_size, frames, channels, height, width] in [-1, 1]
+ latents: Latents [batch_size, channels, num_frames, height, width]
+ use_conditioning: Whether to use first-frame conditioning (True for I2V)
+ generator: Optional random number generator
+
+ Returns:
+ Tuple of (latent_condition, latent_mask, image_embeds).
+ """
+ batch_size, latent_channels, latent_num_frames, latent_height, latent_width = latents.shape
+ device = latents.device
+ dtype = latents.dtype
+
+ use_conditioning = use_conditioning and (latent_num_frames > 1)
+
+ latent_condition = torch.zeros(
+ batch_size, latent_channels, latent_num_frames, latent_height, latent_width, device=device, dtype=dtype
+ )
+ latent_mask = torch.zeros(
+ batch_size, 1, latent_num_frames, latent_height, latent_width, device=device, dtype=dtype
+ )
+ image_embeds = None
+
+ if use_conditioning:
+ with torch.no_grad():
+ # video shape: [B, F, C, H, W] -> [B, C, F, H, W] for VAE
+ first_frame_latents = self.vae.encode(video[:, 0:1].permute(0, 2, 1, 3, 4)).latent_dist.sample(
+ generator=generator
+ )
+ first_frame_latents = self._normalize_latents(
+ latents=first_frame_latents,
+ latents_mean=self.vae.config.latents_mean,
+ latents_std=self.vae.config.latents_std,
+ )
+
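+            # Broadcast the single-frame latents across all latent frames, then zero out everything after
+            # frame 0 so only the first latent frame carries conditioning signal.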
+ latent_condition = first_frame_latents.repeat(1, 1, latent_num_frames, 1, 1)
+ latent_condition[:, :, 1:, :, :] = 0
+
+ latent_mask[:, :, 0] = 1.0
+
+ first_frame_vision = video[:, 0] # [B, C, H, W]
+ first_frame_vision = ((first_frame_vision + 1) / 2).clamp(0, 1)
+
+ with torch.no_grad():
+ image_embeds = self._get_image_embeds(
+ image_encoder=self.vision_encoder,
+ feature_extractor=self.feature_extractor,
+ image=first_frame_vision,
+ device=device,
+ )
+
+ return latent_condition, latent_mask, image_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ batch_size,
+ image=None,
+ callback_on_step_end_tensor_inputs=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % self.vae_scale_factor_spatial != 0 or width % self.vae_scale_factor_spatial != 0:
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor_spatial} but are {height} and {width}."
+ )
+
+ if image is not None:
+ if isinstance(image, torch.Tensor):
+ if image.dim() != 4:
+ raise ValueError(f"`image` must be a 4D tensor [B, C, H, W], got {image.dim()}D")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None:
+ if not isinstance(negative_prompt, (str, list)):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+ if isinstance(negative_prompt, list) and len(negative_prompt) != batch_size:
+ raise ValueError(
+ f"`negative_prompt` list length ({len(negative_prompt)}) must match batch_size ({batch_size})."
+ )
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ f"`prompt_embeds` and `negative_prompt_embeds` must have the same shape, "
+ f"got {prompt_embeds.shape} and {negative_prompt_embeds.shape}."
+ )
+
+ def _prepare_negative_prompt(
+ self,
+ negative_prompt: Optional[Union[str, List[str]]],
+ batch_size: int,
+ ) -> List[str]:
+ """Prepare negative_prompt to match batch_size."""
+ if negative_prompt is None:
+ return [""] * batch_size
+ if isinstance(negative_prompt, str):
+ return [negative_prompt] * batch_size
+ return negative_prompt
+
+ @staticmethod
+ def _normalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
+ ) -> torch.Tensor:
+ latents_mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = (latents - latents_mean) / latents_std
+ return latents
+
+ @staticmethod
+ def _denormalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
+ ) -> torch.Tensor:
+ latents_mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = latents * latents_std + latents_mean
+ return latents
+
+ def prepare_latents(
+ self,
+ batch_size: int = 1,
+ num_channels_latents: int = 16,
+ height: int = 736,
+ width: int = 1280,
+ num_frames: int = 121,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 736,
+ width: int = 1280,
+ num_frames: int = 121,
+ frame_rate: int = 25,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ use_linear_quadratic_schedule: bool = True,
+ linear_quadratic_emulating_steps: int = 250,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for image-to-video generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to use as the first frame for video generation.
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the video generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the video generation.
+            height (`int`, defaults to `736`):
+                The height in pixels of the generated video.
+            width (`int`, defaults to `1280`):
+                The width in pixels of the generated video.
+            num_frames (`int`, defaults to `121`):
+                The number of video frames to generate.
+ frame_rate (`int`, defaults to `25`):
+ Frame rate for the output video.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process.
+ use_linear_quadratic_schedule (`bool`, defaults to `True`):
+ Whether to use a linear-quadratic sigma schedule.
+ linear_quadratic_emulating_steps (`int`, defaults to `250`):
+ Controls the slope of linear interpolation in the first half.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ PyTorch Generator object(s) for deterministic generation.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings.
+ negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated video.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~MotifVideoPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ Arguments passed to the attention processor.
+ callback_on_step_end (`Callable`, *optional*):
+ A function or subclass of `PipelineCallback` called at the end of each denoising step.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function.
+ max_sequence_length (`int`, defaults to `512`):
+ Maximum sequence length for the tokenizer.
+
+ Examples:
+
+ Returns:
+ [`~MotifVideoPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~MotifVideoPipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list of generated video frames.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 2. Check inputs
+ self.check_inputs(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=height,
+ width=width,
+ batch_size=batch_size,
+ image=image,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ )
+
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+ self._current_timestep = None
+
+ device = self._execution_device
+
+        # 3. Preprocess image (always needed for first-frame conditioning, even when `latents` are provided)
+        if isinstance(image, torch.Tensor):
+            image = image.to(device=device, dtype=self.transformer.dtype)
+            if image.min() >= 0 and image.max() <= 1:
+                image = image * 2 - 1
+            image = image.clamp(-1, 1)
+        else:
+            image = self.video_processor.preprocess(image, height=height, width=width)
+            image = image.to(device=device, dtype=self.transformer.dtype)
+        video = image.unsqueeze(1)  # [B, 1, C, H, W]
+
+ # 4. Prepare latents
+ num_channels_latents = self.vae.config.z_dim
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ self.transformer.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare text embeddings
+ prompt_embeds, prompt_attention_mask = self.encode_prompt(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ # 6. First frame conditioning
+ latent_condition, latent_mask, image_embeds = self._prepare_first_frame_conditioning(
+ video,
+ latents,
+ use_conditioning=True,
+ generator=generator,
+ )
+
+ # Repeat conditioning tensors for each generation per prompt
+ if num_videos_per_prompt > 1:
+ latent_condition = latent_condition.repeat_interleave(num_videos_per_prompt, dim=0)
+ latent_mask = latent_mask.repeat_interleave(num_videos_per_prompt, dim=0)
+ if image_embeds is not None:
+ image_embeds = image_embeds.repeat_interleave(num_videos_per_prompt, dim=0)
+
+ if self.guider._enabled and self.guider.num_conditions > 1:
+ negative_prompt = self._prepare_negative_prompt(negative_prompt, batch_size)
+ negative_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=negative_prompt_attention_mask,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ # 7. Prepare timesteps
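+ # The packed latent dimensions give the flattened token count seen by the transformer,
+ # which drives the resolution-dependent timestep shift (mu) computed below.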
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+ latent_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ packed_latent_height = latent_height // self.transformer_spatial_patch_size
+ packed_latent_width = latent_width // self.transformer_spatial_patch_size
+ packed_latent_num_frames = latent_num_frames // self.transformer_temporal_patch_size
+ video_sequence_length = packed_latent_num_frames * packed_latent_height * packed_latent_width
+
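+ # Linearly spaced sigmas are passed for the plain flow-matching schedule; with the
+ # linear-quadratic schedule, sigma construction is deferred to retrieve_timesteps.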
+ if use_linear_quadratic_schedule:
+ sigmas = None
+ else:
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+
+ mu = calculate_shift(
+ video_sequence_length,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas=sigmas,
+ use_linear_quadratic_schedule=use_linear_quadratic_schedule,
+ linear_quadratic_emulating_steps=linear_quadratic_emulating_steps,
+ mu=mu,
+ )
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 8. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ # Concatenate: [latents | latent_condition | latent_mask]
+ hidden_states = torch.cat([latents, latent_condition, latent_mask], dim=1)
+
+ timestep = t.expand(latents.shape[0])
+
+ guider_inputs = {
+ "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
+ "encoder_attention_mask": (prompt_attention_mask, negative_prompt_attention_mask),
+ }
+
+ self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
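+ # prepare_inputs yields one batch per guidance condition (e.g. conditional and unconditional).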
+ guider_state = self.guider.prepare_inputs(guider_inputs)
+
+ for guider_state_batch in guider_state:
+ self.guider.prepare_models(self.transformer)
+
+ cond_kwargs = {
+ input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
+ }
+
+ context_name = getattr(guider_state_batch, self.guider._identifier_key)
+ with self.transformer.cache_context(context_name):
+ noise_pred = self.transformer(
+ hidden_states=hidden_states,
+ timestep=timestep,
+ image_embeds=image_embeds,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ **cond_kwargs,
+ )[0].clone()
+
+ guider_state_batch.noise_pred = noise_pred
+ self.guider.cleanup_models(self.transformer)
+
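+ # Combine the per-condition predictions into a single guided prediction.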
+ noise_pred = self.guider(guider_state)[0]
+
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ if "negative_prompt_embeds" in callback_outputs:
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds")
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if output_type == "latent":
+ video = latents
+ else:
+ latents = latents.to(self.vae.dtype)
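+ # Undo the per-channel latent normalization (latents_mean / latents_std) before VAE decoding.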
+ latents = self._denormalize_latents(latents, self.vae.config.latents_mean, self.vae.config.latents_std)
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return MotifVideoPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/motif_video/pipeline_output.py b/src/diffusers/pipelines/motif_video/pipeline_output.py
new file mode 100644
index 000000000000..aa0b2b83b323
--- /dev/null
+++ b/src/diffusers/pipelines/motif_video/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class MotifVideoPipelineOutput(BaseOutput):
+ r"""
+ Output class for Motif-Video pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or `List[List[PIL.Image.Image]]`):
+ List of video outputs. It can be a nested list of length `batch_size`, with each sub-list containing
+ denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or a PyTorch tensor of
+ shape `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 738e079eba9b..656b0dc5effa 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -1515,6 +1515,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class MotifVideoTransformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class ModelMixin(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index c95c56789e37..46c015cc3db6 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -2657,6 +2657,51 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class MotifVideoImage2VideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class MotifVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class MotifVideoPipelineOutput(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class MusicLDMPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/tests/pipelines/motif_video/__init__.py b/tests/pipelines/motif_video/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/motif_video/test_motif_video.py b/tests/pipelines/motif_video/test_motif_video.py
new file mode 100644
index 000000000000..00a821717ab7
--- /dev/null
+++ b/tests/pipelines/motif_video/test_motif_video.py
@@ -0,0 +1,125 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, MotifVideoPipeline
+from diffusers.models.transformers.transformer_motif_video import MotifVideoTransformer3DModel
+from diffusers.utils.testing_utils import enable_full_determinism
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class MotifVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = MotifVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+ # Use tiny-random-t5 as a stand-in for T5Gemma2Model's encoder
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = MotifVideoTransformer3DModel(
+ in_channels=33,
+ out_channels=16,
+ num_attention_heads=2,
+ attention_head_dim=12,
+ num_layers=1,
+ num_single_layers=1,
+ mlp_ratio=4.0,
+ patch_size=1,
+ patch_size_t=1,
+ qk_norm="rms_norm",
+ text_embed_dim=32,
+ rope_axes_dim=(4, 4, 4),
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ inputs = {
+ "prompt": "A test video",
+ "negative_prompt": "bad quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "height": 16,
+ "width": 16,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+
+ @unittest.skip("MotifVideo uses guider pattern instead of guidance_scale")
+ def test_inference_batch_single_identical(self):
+ pass
diff --git a/tests/pipelines/motif_video/test_motif_video_image2video.py b/tests/pipelines/motif_video/test_motif_video_image2video.py
new file mode 100644
index 000000000000..36b36e12805d
--- /dev/null
+++ b/tests/pipelines/motif_video/test_motif_video_image2video.py
@@ -0,0 +1,135 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, MotifVideoImage2VideoPipeline
+from diffusers.models.transformers.transformer_motif_video import MotifVideoTransformer3DModel
+from diffusers.utils.testing_utils import enable_full_determinism
+
+from ..pipeline_params import (
+ IMAGE_TO_IMAGE_IMAGE_PARAMS,
+ TEXT_TO_IMAGE_BATCH_PARAMS,
+ TEXT_TO_IMAGE_IMAGE_PARAMS,
+ TEXT_TO_IMAGE_PARAMS,
+)
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class MotifVideoImage2VideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = MotifVideoImage2VideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} | {"image"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS | {"image"}
+ image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = MotifVideoTransformer3DModel(
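+ # in_channels=33 is assumed to be 16 latent + 16 first-frame-condition + 1 mask channels,
+ # matching the pipeline's channel-wise concatenation.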
+ in_channels=33,
+ out_channels=16,
+ num_attention_heads=2,
+ attention_head_dim=12,
+ num_layers=1,
+ num_single_layers=1,
+ mlp_ratio=4.0,
+ patch_size=1,
+ patch_size_t=1,
+ qk_norm="rms_norm",
+ text_embed_dim=32,
+ image_embed_dim=4,
+ rope_axes_dim=(4, 4, 4),
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "feature_extractor": None,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
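+ # 16x16 all-black placeholder image matching the requested height and width.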
+ image = Image.new("RGB", (16, 16))
+
+ inputs = {
+ "image": image,
+ "prompt": "A test video",
+ "negative_prompt": "bad quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "height": 16,
+ "width": 16,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+
+ @unittest.skip("MotifVideo uses guider pattern instead of guidance_scale")
+ def test_inference_batch_single_identical(self):
+ pass