diff --git a/docs/source/en/api/models/motif_video_transformer_3d.md b/docs/source/en/api/models/motif_video_transformer_3d.md new file mode 100644 index 000000000000..46344f841b10 --- /dev/null +++ b/docs/source/en/api/models/motif_video_transformer_3d.md @@ -0,0 +1,32 @@ + + +# MotifVideoTransformer3DModel + +A Diffusion Transformer model for 3D video-like data was introduced in [Motif-Video](https://arxiv.org/abs/2604.16503) by the Motif Technologies team. + +The model uses a three-stage architecture with 12 dual-stream + 16 single-stream + 8 DDT decoder layers and rotary positional embeddings (RoPE) for video generation. + +The model can be loaded with the following code snippet. + +```python +import torch + +from diffusers import MotifVideoTransformer3DModel + +transformer = MotifVideoTransformer3DModel.from_pretrained("MotifTechnologies/Motif-Video-2B", subfolder="transformer", torch_dtype=torch.bfloat16) +``` + +## MotifVideoTransformer3DModel + +[[autodoc]] MotifVideoTransformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput diff --git a/docs/source/en/api/pipelines/motif_video.md b/docs/source/en/api/pipelines/motif_video.md new file mode 100644 index 000000000000..2ec549b9cfd4 --- /dev/null +++ b/docs/source/en/api/pipelines/motif_video.md @@ -0,0 +1,147 @@ + + +# Motif-Video + +[Technical Report](https://arxiv.org/abs/2604.16503) + +Motif-Video is a 2B-parameter diffusion transformer designed for text-to-video and image-to-video generation. It features a three-stage architecture with 12 dual-stream + 16 single-stream + 8 DDT decoder layers, Shared Cross-Attention for stable text-video alignment over long video sequences, a T5Gemma2 text encoder, and rectified flow matching for velocity prediction. + +

+ *Motif-Video architecture*

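+The layer counts above map onto the transformer's configuration. As a minimal sketch (the exact values come from the released checkpoint's `config.json`, so the expected numbers in the comments below are assumptions based on the description above), you can inspect them after loading the transformer:
+
+```python
+import torch
+
+from diffusers import MotifVideoTransformer3DModel
+
+transformer = MotifVideoTransformer3DModel.from_pretrained(
+    "MotifTechnologies/Motif-Video-2B", subfolder="transformer", torch_dtype=torch.bfloat16
+)
+
+# Dual-stream blocks (expected: 12).
+print(transformer.config.num_layers)
+# Single-stream blocks, encoder and DDT decoder combined (expected: 16 + 8 = 24).
+print(transformer.config.num_single_layers)
+# DDT decoder blocks taken from the end of the single stream (expected: 8).
+print(transformer.config.num_decoder_layers)
+```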
+ +## Text-to-Video Generation + +Use `MotifVideoPipeline` for text-to-video generation: + +```python +import torch +from diffusers import AdaptiveProjectedGuidance, MotifVideoPipeline +from diffusers.utils import export_to_video + +guider = AdaptiveProjectedGuidance( + guidance_scale=8.0, + adaptive_projected_guidance_rescale=12.0, + adaptive_projected_guidance_momentum=0.1, + use_original_formulation=True, + normalization_dims="spatial", +) + +pipe = MotifVideoPipeline.from_pretrained( + "MotifTechnologies/Motif-Video-2B", + torch_dtype=torch.bfloat16, + guider=guider, +) +pipe.to("cuda") + +prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair." +negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + +video = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=1280, + height=736, + num_frames=121, + num_inference_steps=50, +).frames[0] +export_to_video(video, "output.mp4", fps=24) +``` + +## Image-to-Video Generation + +Use `MotifVideoImage2VideoPipeline` for image-to-video generation: + +```python +import torch +from diffusers import AdaptiveProjectedGuidance, MotifVideoImage2VideoPipeline +from diffusers.utils import export_to_video, load_image + +guider = AdaptiveProjectedGuidance( + guidance_scale=8.0, + adaptive_projected_guidance_rescale=12.0, + adaptive_projected_guidance_momentum=0.1, + use_original_formulation=True, + normalization_dims="spatial", +) + +pipe = MotifVideoImage2VideoPipeline.from_pretrained( + "MotifTechnologies/Motif-Video-2B", + torch_dtype=torch.bfloat16, + guider=guider, +) +pipe.to("cuda") + +image = load_image("input_image.png") +prompt = "A cinematic scene with vivid colors." +negative_prompt = "worst quality, blurry, jittery, distorted" + +video = pipe( + image=image, + prompt=prompt, + negative_prompt=negative_prompt, + width=1280, + height=736, + num_frames=121, + num_inference_steps=50, +).frames[0] +export_to_video(video, "i2v_output.mp4", fps=24) +``` + +### Memory-efficient Inference + +For GPUs with less than 30GB VRAM (e.g., RTX 4090), use model CPU offloading: + +```bash +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +``` + +```python +import torch +from diffusers import AdaptiveProjectedGuidance, MotifVideoPipeline +from diffusers.utils import export_to_video + +guider = AdaptiveProjectedGuidance( + guidance_scale=8.0, + adaptive_projected_guidance_rescale=12.0, + adaptive_projected_guidance_momentum=0.1, + use_original_formulation=True, + normalization_dims="spatial", +) + +pipe = MotifVideoPipeline.from_pretrained( + "MotifTechnologies/Motif-Video-2B", + torch_dtype=torch.bfloat16, + guider=guider, +) +pipe.enable_model_cpu_offload() + +prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair." 
+negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + +video = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=1280, + height=736, + num_frames=121, + num_inference_steps=50, +).frames[0] +export_to_video(video, "output.mp4", fps=24) +``` + +## MotifVideoPipeline + +[[autodoc]] MotifVideoPipeline + - all + - __call__ + +## MotifVideoImage2VideoPipeline + +[[autodoc]] MotifVideoImage2VideoPipeline + - all + - __call__ + +## MotifVideoPipelineOutput + +[[autodoc]] pipelines.motif_video.pipeline_output.MotifVideoPipelineOutput \ No newline at end of file diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 2cbfd6e29305..ebd0c9e8ffde 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -263,6 +263,7 @@ "LuminaNextDiT2DModel", "MochiTransformer3DModel", "ModelMixin", + "MotifVideoTransformer3DModel", "MotionAdapter", "MultiAdapter", "MultiControlNetModel", @@ -625,6 +626,9 @@ "MarigoldIntrinsicsPipeline", "MarigoldNormalsPipeline", "MochiPipeline", + "MotifVideoImage2VideoPipeline", + "MotifVideoPipeline", + "MotifVideoPipelineOutput", "MusicLDMPipeline", "NucleusMoEImagePipeline", "OmniGenPipeline", @@ -1073,6 +1077,7 @@ LuminaNextDiT2DModel, MochiTransformer3DModel, ModelMixin, + MotifVideoTransformer3DModel, MotionAdapter, MultiAdapter, MultiControlNetModel, @@ -1410,6 +1415,9 @@ MarigoldIntrinsicsPipeline, MarigoldNormalsPipeline, MochiPipeline, + MotifVideoImage2VideoPipeline, + MotifVideoPipeline, + MotifVideoPipelineOutput, MusicLDMPipeline, NucleusMoEImagePipeline, OmniGenPipeline, diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index b210cb3e67aa..55d38a0d32e7 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -48,6 +48,12 @@ class AdaptiveProjectedGuidance(BaseGuidance): Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. See [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + normalization_dims (`str` or `list[int]` or `None`, defaults to `None`): + Dimensions to normalize over for the guidance computation. Can be: + - `None` (default): Normalize over all non-batch dimensions (e.g., [C, H, W] for 4D, [C, T, H, W] for 5D) + - `"spatial"`: Spatial-only normalization - normalize over [C, H, W] per frame for 5D tensors, standard for + 4D + - `list[int]`: Custom dimensions to normalize over (e.g., `[-1, -2, -4]` for [W, H, C]) start (`float`, defaults to `0.0`): The fraction of the total number of denoising steps after which guidance starts. 
stop (`float`, defaults to `1.0`): @@ -65,6 +71,7 @@ def __init__( eta: float = 1.0, guidance_rescale: float = 0.0, use_original_formulation: bool = False, + normalization_dims: str | list[int] | None = None, start: float = 0.0, stop: float = 1.0, enabled: bool = True, @@ -77,6 +84,7 @@ def __init__( self.eta = eta self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation + self.normalization_dims = normalization_dims self.momentum_buffer = None def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]: @@ -117,6 +125,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = No self.eta, self.adaptive_projected_guidance_rescale, self.use_original_formulation, + self.normalization_dims, ) if self.guidance_rescale > 0.0: @@ -210,9 +219,25 @@ def normalized_guidance( eta: float = 1.0, norm_threshold: float = 0.0, use_original_formulation: bool = False, + normalization_dims: str | list[int] | None = None, ): diff = pred_cond - pred_uncond - dim = [-i for i in range(1, len(diff.shape))] + + # Determine normalization dimensions + if normalization_dims == "spatial": + # Spatial-only normalization: normalize over [C, H, W] per frame for 5D tensors + if len(diff.shape) == 5: + # [B, C, T, H, W] -> normalize over W(-1), H(-2), C(-4), skip T(-3) + dim = [-1, -2, -4] + else: + # [B, C, H, W] -> standard behavior + dim = [-i for i in range(1, len(diff.shape))] + elif normalization_dims is None: + # Default: normalize over all non-batch dimensions + dim = [-i for i in range(1, len(diff.shape))] + else: + # Custom dimensions provided by user + dim = normalization_dims if momentum_buffer is not None: momentum_buffer.update(diff) diff --git a/src/diffusers/guiders/adaptive_projected_guidance_mix.py b/src/diffusers/guiders/adaptive_projected_guidance_mix.py index 559e30d2aabe..d9187672f157 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance_mix.py +++ b/src/diffusers/guiders/adaptive_projected_guidance_mix.py @@ -48,6 +48,12 @@ class AdaptiveProjectedMixGuidance(BaseGuidance): Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. See [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + normalization_dims (`str` or `list[int]` or `None`, defaults to `None`): + Dimensions to normalize over for the guidance computation. Can be: + - `None` (default): Normalize over all non-batch dimensions (e.g., [C, H, W] for 4D, [C, T, H, W] for 5D) + - `"spatial"`: Spatial-only normalization - normalize over [C, H, W] per frame for 5D tensors, standard for + 4D + - `list[int]`: Custom dimensions to normalize over (e.g., `[-1, -2, -4]` for [W, H, C]) start (`float`, defaults to `0.0`): The fraction of the total number of denoising steps after which the classifier-free guidance starts. 
stop (`float`, defaults to `1.0`): @@ -71,6 +77,7 @@ def __init__( adaptive_projected_guidance_rescale: float = 10.0, eta: float = 0.0, use_original_formulation: bool = False, + normalization_dims: str | list[int] | None = None, start: float = 0.0, stop: float = 1.0, adaptive_projected_guidance_start_step: int = 5, @@ -86,6 +93,7 @@ def __init__( self.eta = eta self.adaptive_projected_guidance_start_step = adaptive_projected_guidance_start_step self.use_original_formulation = use_original_formulation + self.normalization_dims = normalization_dims self.momentum_buffer = None def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]: @@ -138,6 +146,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = No self.eta, self.adaptive_projected_guidance_rescale, self.use_original_formulation, + self.normalization_dims, ) if self.guidance_rescale > 0.0: @@ -269,6 +278,7 @@ def normalized_guidance( eta: float = 1.0, norm_threshold: float = 0.0, use_original_formulation: bool = False, + normalization_dims: str | list[int] | None = None, ): if momentum_buffer is not None: update_momentum_buffer(pred_cond, pred_uncond, momentum_buffer) @@ -276,7 +286,21 @@ def normalized_guidance( else: diff = pred_cond - pred_uncond - dim = [-i for i in range(1, len(diff.shape))] + # Determine normalization dimensions + if normalization_dims == "spatial": + # Spatial-only normalization: normalize over [C, H, W] per frame for 5D tensors + if len(diff.shape) == 5: + # [B, C, T, H, W] -> normalize over W(-1), H(-2), C(-4), skip T(-3) + dim = [-1, -2, -4] + else: + # [B, C, H, W] -> standard behavior + dim = [-i for i in range(1, len(diff.shape))] + elif normalization_dims is None: + # Default: normalize over all non-batch dimensions + dim = [-i for i in range(1, len(diff.shape))] + else: + # Custom dimensions provided by user + dim = normalization_dims if norm_threshold > 0: ones = torch.ones_like(diff) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index c7bb2de4437a..43fc8d897fe6 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -21,7 +21,11 @@ from typing_extensions import Self from .. import __version__ -from ..models.model_loading_utils import _caching_allocator_warmup, _determine_device_map, _expand_device_map +from ..models.model_loading_utils import ( + _caching_allocator_warmup, + _determine_device_map, + _expand_device_map, +) from ..quantizers import DiffusersAutoQuantizer from ..utils import deprecate, is_accelerate_available, is_torch_version, logging from ..utils.torch_utils import empty_device_cache @@ -194,6 +198,10 @@ "checkpoint_mapping_fn": convert_ltx2_audio_vae_to_diffusers, "default_subfolder": "audio_vae", }, + "MotifVideoTransformer3DModel": { + "checkpoint_mapping_fn": lambda checkpoint, **kwargs: checkpoint, + "default_subfolder": "transformer", + }, } @@ -336,7 +344,11 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: str | None = No disable_mmap = kwargs.pop("disable_mmap", False) device_map = kwargs.pop("device_map", None) - user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"} + user_agent = { + "diffusers": __version__, + "file_type": "single_file", + "framework": "pytorch", + } # In order to ensure popular quantization methods are supported. 
Can be disable with `disable_telemetry` if quantization_config is not None: user_agent["quant"] = quantization_config.quant_method.value @@ -393,7 +405,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: str | None = No config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs) diffusers_model_config = config_mapping_fn( - original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs + original_config=original_config, + checkpoint=checkpoint, + **config_mapping_kwargs, ) else: if config is not None: @@ -465,7 +479,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: str | None = No if _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint): diffusers_format_checkpoint = checkpoint_mapping_fn( - config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs + config=diffusers_model_config, + checkpoint=checkpoint, + **checkpoint_mapping_kwargs, ) else: diffusers_format_checkpoint = checkpoint diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index ba9b7810e054..d9a9611b6fd3 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -119,6 +119,7 @@ _import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] + _import_structure["transformers.transformer_motif_video"] = ["MotifVideoTransformer3DModel"] _import_structure["transformers.transformer_nucleusmoe_image"] = ["NucleusMoEImageTransformer2DModel"] _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] _import_structure["transformers.transformer_ovis_image"] = ["OvisImageTransformer2DModel"] @@ -243,6 +244,7 @@ Lumina2Transformer2DModel, LuminaNextDiT2DModel, MochiTransformer3DModel, + MotifVideoTransformer3DModel, NucleusMoEImageTransformer2DModel, OmniGenTransformer2DModel, OvisImageTransformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index d4ac6ff4301e..d79c84fba784 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -42,6 +42,7 @@ from .transformer_ltx2 import LTX2VideoTransformer3DModel from .transformer_lumina2 import Lumina2Transformer2DModel from .transformer_mochi import MochiTransformer3DModel + from .transformer_motif_video import MotifVideoTransformer3DModel from .transformer_nucleusmoe_image import NucleusMoEImageTransformer2DModel from .transformer_omnigen import OmniGenTransformer2DModel from .transformer_ovis_image import OvisImageTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_motif_video.py b/src/diffusers/models/transformers/transformer_motif_video.py new file mode 100644 index 000000000000..0de30c2925a8 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_motif_video.py @@ -0,0 +1,1014 @@ +# Copyright 2026 Motif Technologies and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ..attention import FeedForward +from ..attention_processor import Attention, AttentionProcessor +from ..cache_utils import CacheMixin +from ..embeddings import ( + PixArtAlphaTextProjection, + TimestepEmbedding, + Timesteps, + apply_rotary_emb, + get_1d_rotary_pos_embed, +) +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import ( + AdaLayerNormContinuous, + AdaLayerNormZero, + AdaLayerNormZeroSingle, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class CrossAttentionInputs: + """Inputs for cross-attention mode where query is pre-projected externally. + + Args: + query: Pre-projected query tensor [B, L, D] + key: Key tensor for attention [B, L, D] + value: Value tensor for attention [B, L, D] + """ + + query: torch.Tensor + key: torch.Tensor + value: torch.Tensor + + +class MotifVideoAttnProcessor2_0: + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "MotifVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + cross_attn_inputs: Optional[CrossAttentionInputs] = None, + ) -> torch.Tensor: + """ + Process attention with support for cross-attention mode. 
+ + Args: + attn: Attention layer + hidden_states: Input hidden states [B, L, D] + encoder_hidden_states: Optional encoder states [B, E, D] + attention_mask: Optional attention mask [B, 1, 1, N] + image_rotary_emb: Optional rotary embeddings + cross_attn_inputs: Optional pre-projected cross-attention inputs + + Returns: + Tuple of (hidden_states, encoder_hidden_states) + """ + if cross_attn_inputs is not None: + return self._handle_cross_attention_mode(attn, cross_attn_inputs, attention_mask, image_rotary_emb) + + # Standard attention mode + if attn.add_q_proj is None and encoder_hidden_states is not None: + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + query, key, value = self._project_qkv(attn, hidden_states) + query, key = self._normalize_qk(attn, query, key) + query, key = self._apply_rope(attn, query, key, encoder_hidden_states, image_rotary_emb) + query, key, value = self._add_encoder_conditioning(attn, query, key, value, encoder_hidden_states) + hidden_states = self._compute_attention(query, key, value, attention_mask) + return self._project_output(attn, hidden_states, encoder_hidden_states) + + def _handle_cross_attention_mode( + self, + attn: Attention, + cross_attn_inputs: CrossAttentionInputs, + attention_mask: Optional[torch.Tensor], + image_rotary_emb: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, None]: + """Handle cross-attention mode with pre-projected query. + + Query is already projected externally (cross_attn_query_proj + norm), so we skip to_q and only apply reshape + + norm_q + RoPE. K/V use to_k/to_v as normal. + """ + query = cross_attn_inputs.query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = attn.to_k(cross_attn_inputs.key) + value = attn.to_v(cross_attn_inputs.value) + + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + return hidden_states, None + + def _project_qkv( + self, + attn: Attention, + hidden_states: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Project hidden states to Q, K, V and reshape for attention.""" + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + return query, key, value + + def _normalize_qk( + self, + attn: Attention, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply QK normalization if present.""" + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + return query, key + + def _apply_rope( + self, + attn: Attention, + query: torch.Tensor, + key: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + image_rotary_emb: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply rotary positional embeddings to query and key.""" + if image_rotary_emb is not None: + if attn.add_q_proj is 
None and encoder_hidden_states is not None: + split_idx = -encoder_hidden_states.shape[1] + query = torch.cat( + [ + apply_rotary_emb(query[:, :, :split_idx], image_rotary_emb), + query[:, :, split_idx:], + ], + dim=2, + ) + key = torch.cat( + [ + apply_rotary_emb(key[:, :, :split_idx], image_rotary_emb), + key[:, :, split_idx:], + ], + dim=2, + ) + else: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + return query, key + + def _add_encoder_conditioning( + self, + attn: Attention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Add encoder conditioning QKV projections and normalization.""" + if attn.add_q_proj is not None and encoder_hidden_states is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_query = attn.norm_added_q(encoder_query) + if attn.norm_added_k is not None: + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([query, encoder_query], dim=2) + key = torch.cat([key, encoder_key], dim=2) + value = torch.cat([value, encoder_value], dim=2) + + return query, key, value + + def _compute_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + ) -> torch.Tensor: + """Compute scaled dot-product attention.""" + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + return hidden_states + + def _project_output( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Apply output projections and split encoder states.""" + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, : -encoder_hidden_states.shape[1]], + hidden_states[:, -encoder_hidden_states.shape[1] :], + ) + + if getattr(attn, "to_out", None) is not None: + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if getattr(attn, "to_add_out", None) is not None: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + + +class MotifVideoPatchEmbed(nn.Module): + def __init__( + self, + patch_size: Union[int, Tuple[int, int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + super().__init__() + + patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.proj(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC + return hidden_states + + +class MotifVideoAdaNorm(nn.Module): + def __init__(self, in_features: int, out_features: Optional[int] = None) -> 
None: + super().__init__() + + out_features = out_features or 2 * in_features + self.linear = nn.Linear(in_features, out_features) + self.nonlinearity = nn.SiLU() + + def forward(self, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + temb = self.linear(self.nonlinearity(temb)) + gate_msa, gate_mlp = temb.chunk(2, dim=1) + gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1) + return gate_msa, gate_mlp + + +class MotifVideoConditionEmbedding(nn.Module): + def __init__( + self, + embedding_dim: int, + ): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward( + self, + timestep: torch.Tensor, + ) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + compute_dtype = next( + (p.dtype for p in self.timestep_embedder.parameters() if p.is_floating_point()), + torch.float32, # safe fallback + ) + conditioning = self.timestep_embedder(timesteps_proj.to(compute_dtype)) # (N, D) + + return conditioning + + +class MotifVideoRotaryPosEmbed(nn.Module): + def __init__( + self, + patch_size: int, + patch_size_t: int, + rope_dim: List[int], + theta: float = 256.0, + ): + """ + Rotary Positional Embedding (RoPE) for video latents. + + Args: + patch_size (`int`): Spatial patch size. + patch_size_t (`int`): Temporal patch size. + rope_dim (`List[int]`): Dimensions for RoPE across [Time, Height, Width] axes. + theta (`float`, *optional*, defaults to 256.0): Base frequency for rotary embeddings. + """ + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.rope_dim = rope_dim + self.theta = theta + + def forward(self, hidden_states: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + rope_sizes = [ + num_frames // self.patch_size_t, + height // self.patch_size, + width // self.patch_size, + ] + + axes_grids = [] + for i in range(3): + grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32) + axes_grids.append(grid) + grid = torch.meshgrid(*axes_grids, indexing="ij") + grid = torch.stack(grid, dim=0) + + freqs = [] + for i in range(3): + freq = get_1d_rotary_pos_embed( + dim=self.rope_dim[i], + pos=grid[i].reshape(-1), + theta=self.theta, + use_real=True, + freqs_dtype=torch.float64, + ) + freqs.append(freq) + + freqs_cos = torch.cat([f[0] for f in freqs], dim=1) + freqs_sin = torch.cat([f[1] for f in freqs], dim=1) + return freqs_cos, freqs_sin + + +class MotifVideoImageProjection(nn.Module): + def __init__(self, in_features: int, hidden_size: int): + super().__init__() + self.norm_in = nn.LayerNorm(in_features) + self.linear_1 = nn.Linear(in_features, in_features) + self.act_fn = nn.GELU() + self.linear_2 = nn.Linear(in_features, hidden_size) + self.norm_out = nn.LayerNorm(hidden_size) + + def forward(self, image_embeds: torch.Tensor) -> torch.Tensor: + hidden_states = self.norm_in(image_embeds) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.linear_2(hidden_states) + hidden_states = self.norm_out(hidden_states) + return hidden_states + + +class MotifVideoSingleTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 4.0, + qk_norm: str = "rms_norm", + norm_type: str = "layer_norm", + 
enable_text_cross_attention: bool = False, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + mlp_dim = int(hidden_size * mlp_ratio) + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + bias=True, + processor=MotifVideoAttnProcessor2_0(), + qk_norm=qk_norm, + eps=1e-6, + pre_only=True, + ) + + self.enable_text_cross_attention = enable_text_cross_attention + if enable_text_cross_attention: + self.cross_attn_query_proj = nn.Linear(hidden_size, hidden_size) + self.cross_attn_query_norm = nn.LayerNorm(hidden_size, eps=1e-6) + self.cross_attn_out_proj = nn.Linear(hidden_size, hidden_size) + nn.init.zeros_(self.cross_attn_out_proj.weight) + nn.init.zeros_(self.cross_attn_out_proj.bias) + + self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type=norm_type) + self.proj_mlp = nn.Linear(hidden_size, mlp_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + image_embed_seq_len: int = 0, + ) -> torch.Tensor: + video_tokens = hidden_states.shape[1] + encoder_seq_length = encoder_hidden_states.shape[1] + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + residual = hidden_states + + # 1. Input normalization + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + + norm_hidden_states, norm_encoder_hidden_states = ( + norm_hidden_states[:, :-encoder_seq_length, :], + norm_hidden_states[:, -encoder_seq_length:, :], + ) + + # 2. Attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + + # Text cross-attention: Q=proj(attn_output), K/V=normed text, reuse self.attn weights + if self.enable_text_cross_attention: + txt_kv = norm_encoder_hidden_states[:, image_embed_seq_len:, :] + text_mask = None + if attention_mask is not None: + text_mask = attention_mask[:, :, :, video_tokens + image_embed_seq_len :] + cross_q = self.cross_attn_query_proj(attn_output) + cross_output, _ = self.attn( + hidden_states=cross_q, + cross_attn_inputs=CrossAttentionInputs( + query=cross_q, + key=txt_kv, + value=txt_kv, + ), + attention_mask=text_mask, + image_rotary_emb=image_rotary_emb, + ) + attn_output = attn_output + self.cross_attn_out_proj(cross_output) + + attn_output = torch.cat([attn_output, context_attn_output], dim=1) + + # 3. 
Modulation and residual connection + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states) + hidden_states = hidden_states + residual + + hidden_states, encoder_hidden_states = ( + hidden_states[:, :-encoder_seq_length, :], + hidden_states[:, -encoder_seq_length:, :], + ) + return hidden_states, encoder_hidden_states + + +class MotifVideoTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float, + qk_norm: str = "rms_norm", + norm_type: str = "layer_norm", + enable_text_cross_attention: bool = False, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = AdaLayerNormZero(hidden_size, norm_type=norm_type) + self.norm1_context = AdaLayerNormZero(hidden_size, norm_type=norm_type) + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + added_kv_proj_dim=hidden_size, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + context_pre_only=False, + bias=True, + processor=MotifVideoAttnProcessor2_0(), + qk_norm=qk_norm, + eps=1e-6, + ) + + self.enable_text_cross_attention = enable_text_cross_attention + if enable_text_cross_attention: + self.cross_attn_query_proj = nn.Linear(hidden_size, hidden_size) + self.cross_attn_query_norm = nn.LayerNorm(hidden_size, eps=1e-6) + self.cross_attn_out_proj = nn.Linear(hidden_size, hidden_size) + nn.init.zeros_(self.cross_attn_out_proj.weight) + nn.init.zeros_(self.cross_attn_out_proj.bias) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + image_embed_seq_len: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Input normalization + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # 2. Joint attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + + # 3. 
Modulation and residual connection + hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1) + + # Text cross-attention: Q=proj(attn_output), K/V=normed text, reuse self.attn weights + if self.enable_text_cross_attention: + txt_kv = norm_encoder_hidden_states[:, image_embed_seq_len:, :] + text_mask = None + if attention_mask is not None: + text_mask = attention_mask[:, :, :, hidden_states.shape[1] + image_embed_seq_len :] + cross_q = self.cross_attn_query_proj(attn_output) + cross_output, _ = self.attn( + hidden_states=cross_q, + cross_attn_inputs=CrossAttentionInputs( + query=cross_q, + key=txt_kv, + value=txt_kv, + ), + attention_mask=text_mask, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states + self.cross_attn_out_proj(cross_output) + + encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1) + + norm_hidden_states = self.norm2(hidden_states) + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + # 4. Feed-forward + ff_output = self.ff(norm_hidden_states) + context_ff_output = self.ff_context(norm_encoder_hidden_states) + + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + + return hidden_states, encoder_hidden_states + + +TransformerBlockRegistry.register( + model_class=MotifVideoTransformerBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), +) +TransformerBlockRegistry.register( + model_class=MotifVideoSingleTransformerBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), +) + + +class MotifVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): + r""" + A Transformer model for video-like data used in the Motif-Video model. + + Args: + in_channels (`int`, defaults to `33`): + The number of channels in the input. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + num_attention_heads (`int`, defaults to `24`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + num_layers (`int`, defaults to `20`): + The number of layers of dual-stream blocks to use. + num_single_layers (`int`, defaults to `40`): + The number of layers of single-stream blocks to use. + num_decoder_layers (`int`, defaults to `0`): + The number of decoder layers in single-stream blocks. + mlp_ratio (`float`, defaults to `4.0`): + The ratio of the hidden layer size to the input size in the feedforward network. + patch_size (`int`, defaults to `2`): + The size of the spatial patches to use in the patch embedding layer. + patch_size_t (`int`, defaults to `1`): + The size of the temporal patches to use in the patch embedding layer. + qk_norm (`str`, defaults to `rms_norm`): + The normalization to use for the query and key projections in the attention layers. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. + image_embed_dim (`int`, *optional*): + Input dimension of image embeddings from a vision encoder. If provided, enables image conditioning. 
+ rope_theta (`float`, defaults to `256.0`): + The value of theta to use in the RoPE layer. + rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions of the axes to use in the RoPE layer. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"] + _no_split_modules = [ + "MotifVideoTransformerBlock", + "MotifVideoSingleTransformerBlock", + "MotifVideoPatchEmbed", + ] + + @register_to_config + def __init__( + self, + in_channels: int = 33, + out_channels: int = 16, + num_attention_heads: int = 24, + attention_head_dim: int = 128, + num_layers: int = 20, + num_single_layers: int = 40, + num_decoder_layers: int = 0, + mlp_ratio: float = 4.0, + patch_size: int = 2, + patch_size_t: int = 1, + qk_norm: str = "rms_norm", + norm_type: str = "layer_norm", + text_embed_dim: int = 4096, + image_embed_dim: int | None = None, + rope_theta: float = 256.0, + rope_axes_dim: Tuple[int, ...] = (16, 56, 56), + enable_text_cross_attention_dual: bool = False, + enable_text_cross_attention_single: bool = False, + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. Latent and condition embedders + self.x_embedder = MotifVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) + self.context_embedder = PixArtAlphaTextProjection(in_features=text_embed_dim, hidden_size=inner_dim) + + # First frame conditioning: Image conditioning embedders + self.image_embed_dim = image_embed_dim + if image_embed_dim is not None: + self.image_embedder = MotifVideoImageProjection(in_features=image_embed_dim, hidden_size=inner_dim) + + self.time_text_embed = MotifVideoConditionEmbedding(inner_dim) + + # 2. RoPE + self.rope = MotifVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta) + + # Cross-attention config + self.enable_text_cross_attention_dual = enable_text_cross_attention_dual + self.enable_text_cross_attention_single = enable_text_cross_attention_single + + # 3. Dual stream transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + MotifVideoTransformerBlock( + num_attention_heads, + attention_head_dim, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + norm_type=norm_type, + enable_text_cross_attention=enable_text_cross_attention_dual, + ) + for _ in range(num_layers) + ] + ) + + # 4. Single stream transformer blocks + # Encoder blocks get cross-attention; decoder blocks do not (no text stream in decoder) + num_encoder_single = num_single_layers - num_decoder_layers + self.single_transformer_blocks = nn.ModuleList( + [ + MotifVideoSingleTransformerBlock( + num_attention_heads, + attention_head_dim, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + norm_type=norm_type, + enable_text_cross_attention=enable_text_cross_attention_single + if i < num_encoder_single + else False, + ) + for i in range(num_single_layers) + ] + ) + + # 5. Output projection + self.norm_out = AdaLayerNormContinuous( + inner_dim, + inner_dim, + elementwise_affine=False, + eps=1e-6, + norm_type=norm_type, + ) + self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels) + + # Verify cross-attention config matches actual block state. + # Catches silent misconfiguration (e.g. checkpoint config with renamed keys). 
+ for i, block in enumerate(self.transformer_blocks): + if block.enable_text_cross_attention != enable_text_cross_attention_dual: + raise ValueError( + f"transformer_blocks[{i}].enable_text_cross_attention=" + f"{block.enable_text_cross_attention}, expected {enable_text_cross_attention_dual}. " + f"Check checkpoint config.json key names match __init__ parameters." + ) + for i, block in enumerate(self.single_transformer_blocks): + expected = enable_text_cross_attention_single if i < num_encoder_single else False + if block.enable_text_cross_attention != expected: + raise ValueError( + f"single_transformer_blocks[{i}].enable_text_cross_attention=" + f"{block.enable_text_cross_attention}, expected {expected}. " + f"Check checkpoint config.json key names match __init__ parameters." + ) + + self.gradient_checkpointing = False + self.num_decoder_layers = num_decoder_layers + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + processors = {} + + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def _maybe_gradient_checkpoint_block(self, block, *args): + if torch.is_grad_enabled() and self.gradient_checkpointing: + return self._gradient_checkpointing_func(block, *args) + return block(*args) + + def _create_attention_mask( + self, + hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + ) -> torch.Tensor: + attention_mask = F.pad( + encoder_attention_mask.to(torch.bool), + (hidden_states.shape[1], 0), + value=True, + ) + attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) + return attention_mask + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor | None = None, + image_embeds: torch.Tensor | None = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Forward pass of the MotifVideoTransformer3DModel. + + Args: + hidden_states (`torch.Tensor`): + Input latent tensor of shape `(batch_size, channels, num_frames, height, width)`. + timestep (`torch.LongTensor`): + Diffusion timesteps of shape `(batch_size,)`. + encoder_hidden_states (`torch.Tensor`): + Text conditioning of shape `(batch_size, sequence_length, embed_dim)`. + encoder_attention_mask (`torch.Tensor`): + Mask for text conditioning of shape `(batch_size, sequence_length)`. + image_embeds (`torch.Tensor`, *optional*): + Image embeddings from vision encoder of shape `(batch_size, num_tokens, embed_dim)`. + attention_kwargs (`dict`, *optional*): + Additional arguments for attention processors. + return_dict (`bool`, defaults to `True`): + Whether to return a [`~models.modeling_outputs.Transformer2DModelOutput`]. + + Returns: + [`~models.modeling_outputs.Transformer2DModelOutput`] or `tuple`: + The predicted samples. + """ + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, _, num_frames, height, width = hidden_states.shape + p, p_t = self.config.patch_size, self.config.patch_size_t + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + + # 1. RoPE + image_rotary_emb = self.rope(hidden_states, timestep=timestep) + + # 2. 
Conditional embeddings + temb = self.time_text_embed(timestep) + hidden_states = self.x_embedder(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + # First frame conditioning: Image embeddings from vision encoder + if image_embeds is not None: + image_embeds = self.image_embedder(image_embeds) + encoder_hidden_states = torch.cat([image_embeds, encoder_hidden_states], dim=1) + if encoder_attention_mask is not None: + image_mask = torch.ones( + image_embeds.shape[0], + image_embeds.shape[1], + device=encoder_attention_mask.device, + dtype=encoder_attention_mask.dtype, + ) + encoder_attention_mask = torch.cat([image_mask, encoder_attention_mask], dim=1) + + # image_embed_seq_len: used by cross-attention blocks to slice text from encoder_hidden_states + image_embed_seq_len = image_embeds.shape[1] if image_embeds is not None else 0 + + decoder_hidden_states = hidden_states.clone() + + if encoder_attention_mask is not None: + attention_mask = self._create_attention_mask( + hidden_states=hidden_states, + encoder_attention_mask=encoder_attention_mask, + ) + else: + attention_mask = None + + # 3. Dual stream transformer blocks + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block( + block, + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + image_rotary_emb, + image_embed_seq_len, + ) + + # 4. Single stream transformer blocks (Encoder) + single_transformer_blocks = self.single_transformer_blocks + + for block in single_transformer_blocks[: len(single_transformer_blocks) - self.num_decoder_layers]: + hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block( + block, + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + image_rotary_emb, + image_embed_seq_len, + ) + + # 5. Single stream transformer blocks (Decoder) + if self.num_decoder_layers > 0: + encoder_hidden_states = hidden_states + attention_mask = None + + for block in single_transformer_blocks[-self.num_decoder_layers :]: + decoder_hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block( + block, + decoder_hidden_states, + encoder_hidden_states, + temb, + attention_mask, + image_rotary_emb, + ) + + hidden_states = decoder_hidden_states + + # 6. 
Output projection + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, + post_patch_num_frames, + post_patch_height, + post_patch_width, + -1, + p_t, + p, + p, + ) + hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) + hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if USE_PEFT_BACKEND: + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (hidden_states,) + + return Transformer2DModelOutput( + sample=hidden_states, + ) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index ae1849a587e8..6c71ab9dd3f8 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -340,6 +340,12 @@ ] ) _import_structure["mochi"] = ["MochiPipeline"] + _import_structure["motif_video"] = [ + "MotifVideoPipeline", + "MotifVideoImage2VideoPipeline", + "MotifVideoPipelineOutput", + ] + _import_structure["musicldm"] = ["MusicLDMPipeline"] _import_structure["omnigen"] = ["OmniGenPipeline"] _import_structure["ernie_image"] = ["ErnieImagePipeline"] _import_structure["ovis_image"] = ["OvisImagePipeline"] @@ -778,6 +784,12 @@ MarigoldNormalsPipeline, ) from .mochi import MochiPipeline + from .motif_video import ( + MotifVideoImage2VideoPipeline, + MotifVideoPipeline, + MotifVideoPipelineOutput, + ) + from .musicldm import MusicLDMPipeline from .nucleusmoe_image import NucleusMoEImagePipeline from .omnigen import OmniGenPipeline from .ovis_image import OvisImagePipeline diff --git a/src/diffusers/pipelines/motif_video/__init__.py b/src/diffusers/pipelines/motif_video/__init__.py new file mode 100644 index 000000000000..ee1d7c72ee65 --- /dev/null +++ b/src/diffusers/pipelines/motif_video/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_motif_video"] = ["MotifVideoPipeline"] + _import_structure["pipeline_motif_video_image2video"] = ["MotifVideoImage2VideoPipeline"] + _import_structure["pipeline_output"] = ["MotifVideoPipelineOutput"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_motif_video import MotifVideoPipeline + from .pipeline_motif_video_image2video import MotifVideoImage2VideoPipeline + from .pipeline_output import MotifVideoPipelineOutput +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/motif_video/pipeline_motif_video.py b/src/diffusers/pipelines/motif_video/pipeline_motif_video.py new file mode 100644 index 000000000000..099a55e23d90 --- 
/dev/null +++ b/src/diffusers/pipelines/motif_video/pipeline_motif_video.py @@ -0,0 +1,868 @@ +# Copyright 2026 Motif Technologies, Inc. and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch + +from ...utils import is_transformers_version + + +# Check transformers version before importing T5Gemma2Encoder +if not is_transformers_version(">=", "5.1.0"): + import transformers + + raise ImportError( + f"MotifVideoPipeline requires transformers>=5.1.0. " + f"Found: {transformers.__version__}. " + "Please upgrade transformers: pip install transformers --upgrade" + ) + +from transformers import BatchEncoding, PreTrainedTokenizerBase, SiglipImageProcessor, T5Gemma2Encoder + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...guiders import AdaptiveProjectedGuidance, ClassifierFreeGuidance, SkipLayerGuidance +from ...models import AutoencoderKLWan +from ...models.transformers import MotifVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import MotifVideoPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import MotifVideoPipeline + >>> from diffusers.utils import export_to_video + + >>> # Load the Motif-Video pipeline + >>> motif_video_model_id = "MotifTechnologies/Motif-Video-2B" + >>> pipe = MotifVideoPipeline.from_pretrained(motif_video_model_id, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage" + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> video = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=640, + ... height=352, + ... num_frames=65, + ... num_inference_steps=50, + ... 
).frames[0] + >>> export_to_video(video, "output.mp4", fps=16) + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +def get_linear_quadratic_sigmas( + num_inference_steps: int, + linear_quadratic_emulating_steps: int = 250, +) -> np.ndarray: + """ + Compute a linear-quadratic sigma schedule for flow matching. + + This schedule combines: + - First half: Linear interpolation from high noise to medium noise (slow denoising) + - Second half: Quadratic interpolation from medium noise to clean (faster denoising) + + Convention: + - sigma=1.0 represents pure noise + - sigma=0.0 represents clean image + - Output sigmas are in descending order (1.0 → ~0) + + Args: + num_inference_steps: Total number of denoising steps (must be even). + linear_quadratic_emulating_steps: Controls the slope of linear interpolation. + Higher values result in gentler slope in the first half. + + Returns: + np.ndarray: Array of sigma values with shape (num_inference_steps,). + The scheduler will append a terminal 0. + + Raises: + ValueError: If num_inference_steps is not even. + """ + if num_inference_steps % 2 != 0: + raise ValueError( + f"num_inference_steps must be even for linear-quadratic schedule, but got {num_inference_steps}" + ) + + steps = num_inference_steps + N = linear_quadratic_emulating_steps + half_steps = steps // 2 + + # First half: linear interpolation from 1 toward 0 + linear_part = np.linspace(1.0, 0.0, N + 1)[:half_steps] + + # Second half: quadratic interpolation + x = np.linspace(0.0, 1.0, half_steps + 1) + scale_factor = half_steps / N - 1 + quadratic_part = x**2 * scale_factor - scale_factor + + # Concatenate and exclude the last 0 (scheduler appends terminal 0) + sigmas = np.concatenate([linear_part, quadratic_part]) + sigmas = sigmas[:-1] + + return sigmas.astype(np.float32) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + use_linear_quadratic_schedule: bool = False, + linear_quadratic_emulating_steps: int = 250, + **kwargs, +): + r""" + Retrieve timesteps from the scheduler. + + Args: + scheduler: The noise scheduler to use. + num_inference_steps: Number of denoising steps. + device: Device to place timesteps on. + timesteps: Custom timestep values (mutually exclusive with sigmas). + sigmas: Custom sigma values (mutually exclusive with timesteps). + use_linear_quadratic_schedule: If True, use linear-quadratic sigma schedule. + linear_quadratic_emulating_steps: Controls the linear portion slope. + **kwargs: Additional arguments passed to scheduler.set_timesteps(). + + Returns: + Tuple of (timesteps, num_inference_steps). + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + + if use_linear_quadratic_schedule: + if sigmas is not None: + raise ValueError( + "Cannot use both `sigmas` and `use_linear_quadratic_schedule`. " + "The linear-quadratic schedule computes sigmas automatically." 
+ ) + if num_inference_steps is None: + raise ValueError("`num_inference_steps` must be provided when using `use_linear_quadratic_schedule`.") + sigmas = get_linear_quadratic_sigmas( + num_inference_steps=num_inference_steps, + linear_quadratic_emulating_steps=linear_quadratic_emulating_steps, + ) + + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class MotifVideoPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Motif-Video. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + transformer ([`MotifVideoTransformer3DModel`]): + Conditional Transformer architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5Gemma2Encoder`]): + Primary text encoder for encoding text prompts into embeddings. + tokenizer ([`PreTrainedTokenizerBase`]): + Tokenizer corresponding to the primary text encoder. + guider ([`ClassifierFreeGuidance`] or [`SkipLayerGuidance`] or [`AdaptiveProjectedGuidance`]): + The guidance method to use. Can be `ClassifierFreeGuidance`, `SkipLayerGuidance`, or + `AdaptiveProjectedGuidance`. For video generation with `AdaptiveProjectedGuidance`, use + `normalization_dims="spatial"` for spatial-only normalization that preserves temporal quality by + normalizing over [C, H, W] per frame instead of collapsing the temporal dimension. 
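+                If no `guider` is passed at construction time, the pipeline falls back to a default
+                `ClassifierFreeGuidance()` instance.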
+ """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = ["feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLWan, + text_encoder: T5Gemma2Encoder, + tokenizer: PreTrainedTokenizerBase, + transformer: MotifVideoTransformer3DModel, + guider: Union[ClassifierFreeGuidance, SkipLayerGuidance, AdaptiveProjectedGuidance] = None, + feature_extractor: Optional[SiglipImageProcessor] = None, + ): + super().__init__() + + if guider is None: + guider = ClassifierFreeGuidance() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + guider=guider, + feature_extractor=feature_extractor, + ) + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 2 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 512 + ) + + def _get_default_embeds( + self, + text_encoder, + tokenizer: PreTrainedTokenizerBase, + prompt: Union[str, List[str]], + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + dtype = dtype or text_encoder.dtype + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_inputs = BatchEncoding( + {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in text_inputs.items()} + ) + + prompt_embeds = text_encoder(**text_inputs)[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, text_inputs.attention_mask + + def _get_prompt_embeds( + self, + text_encoder: T5Gemma2Encoder, + tokenizer: PreTrainedTokenizerBase, + prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + prompt_embeds_kwargs = { + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "prompt": prompt, + "max_sequence_length": max_sequence_length, + "device": device, + "dtype": dtype, + } + prompt_embeds, prompt_attention_mask = self._get_default_embeds(**prompt_embeds_kwargs) + + return prompt_embeds, prompt_attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Encodes the prompt 
into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to be encoded. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos to generate per prompt. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + max_sequence_length (`int`, defaults to 512): + Maximum sequence length for the tokenizer. + device (`torch.device`, *optional*): + Device to place tensors on. + dtype (`torch.dtype`, *optional*): + Data type for tensors. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_prompt_embeds( + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + prompt=prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + seq_len = prompt_embeds.shape[1] + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.bool() + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat_interleave(num_videos_per_prompt, dim=0) + + return ( + prompt_embeds, + prompt_attention_mask, + ) + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + batch_size, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % self.vae_scale_factor_spatial != 0 or width % self.vae_scale_factor_spatial != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor_spatial} but are {height} and {width}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None: + if not isinstance(negative_prompt, (str, list)): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + if isinstance(negative_prompt, list) and len(negative_prompt) != batch_size: + raise ValueError( + f"`negative_prompt` list length ({len(negative_prompt)}) must match batch_size ({batch_size})." 
+ ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + def _prepare_negative_prompt( + self, + negative_prompt: Optional[Union[str, List[str]]], + batch_size: int, + ) -> List[str]: + """Prepare negative_prompt to match batch_size.""" + if negative_prompt is None: + return [""] * batch_size + if isinstance(negative_prompt, str): + return [negative_prompt] * batch_size + return negative_prompt + + @staticmethod + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor + ) -> torch.Tensor: + latents_mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) / latents_std + return latents + + @staticmethod + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor + ) -> torch.Tensor: + latents_mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std + latents_mean + return latents + + def prepare_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 16, + height: int = 736, + width: int = 1280, + num_frames: int = 121, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+            )
+
+        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        return latents
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
+    @property
+    def attention_kwargs(self):
+        return self._attention_kwargs
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        height: int = 736,
+        width: int = 1280,
+        num_frames: int = 121,
+        frame_rate: int = 25,
+        num_inference_steps: int = 50,
+        timesteps: Optional[List[int]] = None,
+        use_linear_quadratic_schedule: bool = True,
+        linear_quadratic_emulating_steps: int = 250,
+        num_videos_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 512,
+        use_attention_mask: bool = True,
+        vae_batch_size: Optional[int] = None,
+    ):
+        r"""
+        The call function to the pipeline for text-to-video generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the video generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance.
+            height (`int`, defaults to `736`):
+                The height in pixels of the generated video.
+            width (`int`, defaults to `1280`):
+                The width in pixels of the generated video.
+            num_frames (`int`, defaults to `121`):
+                The number of video frames to generate.
+            frame_rate (`int`, defaults to `25`):
+                Frame rate for the output video.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process.
+            use_linear_quadratic_schedule (`bool`, defaults to `True`):
+                Whether to use a linear-quadratic sigma schedule instead of the default linear schedule. Requires
+                `num_inference_steps` to be even.
+            linear_quadratic_emulating_steps (`int`, defaults to `250`):
+                Controls the slope of linear interpolation in the first half of the linear-quadratic schedule. Only
+                used when `use_linear_quadratic_schedule=True`.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                The number of videos to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                PyTorch Generator object(s) for deterministic generation.
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated noisy latents.
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings.
+ prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `"pil"`, `"np"`, or `"latent"`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~MotifVideoPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + Arguments passed to the attention processor. + callback_on_step_end (`Callable`, *optional*): + A function or subclass of `PipelineCallback` or `MultiPipelineCallbacks` called at the end of each + denoising step. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. + max_sequence_length (`int`, defaults to `512`): + Maximum sequence length for the tokenizer. + use_attention_mask (`bool`, defaults to `True`): + Whether to use attention masks for text embeddings. + vae_batch_size (`int`, *optional*): + Batch size for VAE decoding. If provided and latents batch size is larger, VAE decoding will be done in + chunks. + + Examples: + + Returns: + [`~MotifVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~MotifVideoPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list of generated video frames. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 2. Check inputs + self.check_inputs( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + batch_size=batch_size, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + device = self._execution_device + + # 3. Prepare text embeddings + prompt_embeds, prompt_attention_mask = self.encode_prompt( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.guider._enabled and self.guider.num_conditions > 1: + negative_prompt = self._prepare_negative_prompt(negative_prompt, batch_size) + negative_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + + # 4. 
Prepare latents + num_channels_latents = self.vae.config.z_dim + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + self.transformer.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + latent_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + + packed_latent_height = latent_height // self.transformer_spatial_patch_size + packed_latent_width = latent_width // self.transformer_spatial_patch_size + packed_latent_num_frames = latent_num_frames // self.transformer_temporal_patch_size + video_sequence_length = packed_latent_num_frames * packed_latent_height * packed_latent_width + + if use_linear_quadratic_schedule: + sigmas = None + else: + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + use_linear_quadratic_schedule=use_linear_quadratic_schedule, + linear_quadratic_emulating_steps=linear_quadratic_emulating_steps, + mu=mu, + ) + + # Prepare conditioning tensors (T2V mode: no first-frame conditioning) + batch_size, latent_channels, latent_num_frames, latent_height, latent_width = latents.shape + latent_condition = torch.zeros( + batch_size, + latent_channels, + latent_num_frames, + latent_height, + latent_width, + device=latents.device, + dtype=latents.dtype, + ) + latent_mask = torch.zeros( + batch_size, + 1, + latent_num_frames, + latent_height, + latent_width, + device=latents.device, + dtype=latents.dtype, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + # Concatenate current latents with conditioning: [latents | latent_condition | latent_mask] + hidden_states = torch.cat([latents, latent_condition, latent_mask], dim=1) + + timestep = t.expand(latents.shape[0]) + + # Guider: collect model inputs + guider_inputs = { + "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds), + } + if use_attention_mask: + guider_inputs["encoder_attention_mask"] = ( + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t) + guider_state = self.guider.prepare_inputs(guider_inputs) + + # Sigma injection for guiders that support sigma-based gating + if hasattr(self.guider, "_current_sigma") and hasattr(self.scheduler, "sigmas"): + self.guider._current_sigma = float(self.scheduler.sigmas[i]) + + for guider_state_batch in guider_state: + self.guider.prepare_models(self.transformer) + + cond_kwargs = { + input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys() + } + + context_name = getattr(guider_state_batch, self.guider._identifier_key) + with self.transformer.cache_context(context_name): + noise_pred = self.transformer( + hidden_states=hidden_states, + timestep=timestep, + attention_kwargs=self.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0].clone() + + guider_state_batch.noise_pred = noise_pred + self.guider.cleanup_models(self.transformer) + + noise_pred = self.guider(guider_state)[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + if "negative_prompt_embeds" in callback_outputs: + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds") + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + video = latents + else: + latents = latents.to(self.vae.dtype) + latents = self._denormalize_latents(latents, self.vae.config.latents_mean, self.vae.config.latents_std) + if vae_batch_size is not None and latents.shape[0] > vae_batch_size: + video_chunks = [] + for i in range(0, latents.shape[0], vae_batch_size): + chunk = latents[i : i + vae_batch_size] + video_chunks.append(self.vae.decode(chunk, return_dict=False)[0]) + video = torch.cat(video_chunks, dim=0) + del video_chunks + else: + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return MotifVideoPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/motif_video/pipeline_motif_video_image2video.py b/src/diffusers/pipelines/motif_video/pipeline_motif_video_image2video.py new file mode 100644 index 000000000000..e8bf594da0a7 --- /dev/null +++ 
b/src/diffusers/pipelines/motif_video/pipeline_motif_video_image2video.py @@ -0,0 +1,921 @@ +# Copyright 2026 Motif Technologies, Inc. and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch + +from ...utils import is_transformers_version + + +# Check transformers version before importing T5Gemma2Encoder +if not is_transformers_version(">=", "5.1.0"): + import transformers + + raise ImportError( + f"MotifVideoImage2VideoPipeline requires transformers>=5.1.0. " + f"Found: {transformers.__version__}. " + "Please upgrade transformers: pip install transformers --upgrade" + ) + +from transformers import BatchEncoding, PreTrainedTokenizerBase, SiglipImageProcessor, T5Gemma2Encoder + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...guiders import AdaptiveProjectedGuidance, ClassifierFreeGuidance, SkipLayerGuidance +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLWan +from ...models.transformers import MotifVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import MotifVideoPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from PIL import Image + >>> from diffusers import MotifVideoImage2VideoPipeline + >>> from diffusers.utils import export_to_video, load_image + + >>> # Load the Motif-Video image-to-video pipeline + >>> motif_video_model_id = "MotifTechnologies/Motif-Video-2B" + >>> pipe = MotifVideoImage2VideoPipeline.from_pretrained(motif_video_model_id, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # Load an image + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.png" + ... ) + + >>> prompt = "An astronaut is walking on the moon surface, kicking up dust with each step" + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> video = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=640, + ... height=352, + ... num_frames=65, + ... num_inference_steps=50, + ... 
).frames[0] + >>> export_to_video(video, "output.mp4", fps=16) + ``` +""" + + +# Copied from diffusers.pipelines.motif_video.pipeline_motif_video.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.motif_video.pipeline_motif_video.get_linear_quadratic_sigmas +def get_linear_quadratic_sigmas( + num_inference_steps: int, + linear_quadratic_emulating_steps: int = 250, +) -> np.ndarray: + if num_inference_steps % 2 != 0: + raise ValueError( + f"num_inference_steps must be even for linear-quadratic schedule, but got {num_inference_steps}" + ) + + steps = num_inference_steps + N = linear_quadratic_emulating_steps + half_steps = steps // 2 + + linear_part = np.linspace(1.0, 0.0, N + 1)[:half_steps] + + x = np.linspace(0.0, 1.0, half_steps + 1) + scale_factor = half_steps / N - 1 + quadratic_part = x**2 * scale_factor - scale_factor + + sigmas = np.concatenate([linear_part, quadratic_part]) + sigmas = sigmas[:-1] + + return sigmas.astype(np.float32) + + +# Copied from diffusers.pipelines.motif_video.pipeline_motif_video.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + use_linear_quadratic_schedule: bool = False, + linear_quadratic_emulating_steps: int = 250, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + + if use_linear_quadratic_schedule: + if sigmas is not None: + raise ValueError( + "Cannot use both `sigmas` and `use_linear_quadratic_schedule`. " + "The linear-quadratic schedule computes sigmas automatically." + ) + if num_inference_steps is None: + raise ValueError("`num_inference_steps` must be provided when using `use_linear_quadratic_schedule`.") + sigmas = get_linear_quadratic_sigmas( + num_inference_steps=num_inference_steps, + linear_quadratic_emulating_steps=linear_quadratic_emulating_steps, + ) + + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class MotifVideoImage2VideoPipeline(DiffusionPipeline): + r""" + Pipeline for image-to-video generation using Motif-Video with first frame conditioning. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + transformer ([`MotifVideoTransformer3DModel`]): + Conditional Transformer architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5Gemma2Encoder`]): + Primary text encoder for encoding text prompts into embeddings. + tokenizer ([`PreTrainedTokenizerBase`]): + Tokenizer corresponding to the primary text encoder. + feature_extractor ([`SiglipImageProcessor`]): + Image processor for the SigLIP vision encoder. + guider ([`ClassifierFreeGuidance`] or [`SkipLayerGuidance`] or [`AdaptiveProjectedGuidance`]): + The guidance method to use. Can be `ClassifierFreeGuidance`, `SkipLayerGuidance`, or + `AdaptiveProjectedGuidance`. For video generation with `AdaptiveProjectedGuidance`, use + `normalization_dims="spatial"` for spatial-only normalization that preserves temporal quality by + normalizing over [C, H, W] per frame instead of collapsing the temporal dimension. 
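+                When guidance is disabled or the guider exposes only a single condition, the negative prompt
+                branch is skipped and only the positive prompt is encoded.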
+ """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLWan, + text_encoder: T5Gemma2Encoder, + tokenizer: PreTrainedTokenizerBase, + transformer: MotifVideoTransformer3DModel, + feature_extractor: SiglipImageProcessor, + guider: Optional[Union[AdaptiveProjectedGuidance, ClassifierFreeGuidance, SkipLayerGuidance]] = None, + ): + super().__init__() + + if guider is None: + guider = ClassifierFreeGuidance() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + feature_extractor=feature_extractor, + guider=guider, + ) + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 2 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 512 + ) + + def _get_default_embeds( + self, + text_encoder, + tokenizer: PreTrainedTokenizerBase, + prompt: Union[str, List[str]], + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + dtype = dtype or text_encoder.dtype + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_inputs = BatchEncoding( + {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in text_inputs.items()} + ) + + prompt_embeds = text_encoder(**text_inputs)[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, text_inputs.attention_mask + + def _get_prompt_embeds( + self, + text_encoder: T5Gemma2Encoder, + tokenizer: PreTrainedTokenizerBase, + prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + prompt_embeds_kwargs = { + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "prompt": prompt, + "max_sequence_length": max_sequence_length, + "device": device, + "dtype": dtype, + } + prompt_embeds, prompt_attention_mask = self._get_default_embeds(**prompt_embeds_kwargs) + + return prompt_embeds, prompt_attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to be encoded. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos to generate per prompt. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + max_sequence_length (`int`, defaults to 512): + Maximum sequence length for the tokenizer. + device (`torch.device`, *optional*): + Device to place tensors on. + dtype (`torch.dtype`, *optional*): + Data type for tensors. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_prompt_embeds( + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + prompt=prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + seq_len = prompt_embeds.shape[1] + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.bool() + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat_interleave(num_videos_per_prompt, dim=0) + + return ( + prompt_embeds, + prompt_attention_mask, + ) + + @property + def vision_encoder(self): + """Get the vision encoder from T5Gemma2.""" + return self.text_encoder.vision_tower + + @staticmethod + def _get_image_embeds( + image_encoder, + feature_extractor: SiglipImageProcessor, + image, + device: torch.device, + ) -> torch.Tensor: + """Helper to encode single image with SigLIP.""" + image_encoder_dtype = next(image_encoder.parameters()).dtype + + if isinstance(image, torch.Tensor): + image = feature_extractor.preprocess( + images=image.float(), + do_resize=True, + do_rescale=False, + do_normalize=True, + do_convert_rgb=True, + return_tensors="pt", + ) + else: + image = feature_extractor.preprocess( + images=image, + do_resize=True, + do_rescale=False, + do_normalize=True, + do_convert_rgb=True, + return_tensors="pt", + ) + + image = image.to(device=device, dtype=image_encoder_dtype) + return image_encoder(**image).last_hidden_state + + def _prepare_first_frame_conditioning( + self, + video: torch.Tensor, + latents: torch.Tensor, + use_conditioning: bool, + generator: Optional[torch.Generator] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Prepare first frame conditioning tensors. + + For I2V mode: + 1. Extract and VAE-encode first frame from video + 2. Create latent_condition with first frame latents at frame 0 + 3. Create latent_mask with 1.0 at frame 0 + 4. Get image_embeds from vision encoder + + For T2V mode: + 1. Return zeros for latent_condition and latent_mask, None for image_embeds + + Args: + video: Input video tensor [batch_size, frames, channels, height, width] in [-1, 1] + latents: Latents [batch_size, channels, num_frames, height, width] + use_conditioning: Whether to use first-frame conditioning (True for I2V) + generator: Optional random number generator + + Returns: + Tuple of (latent_condition, latent_mask, image_embeds). 
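+
+        Illustrative shapes (assuming the 640x352 / 65-frame example settings, the default 8x spatial and 4x
+        temporal VAE compression, and 16 latent channels):
+            latent_condition: [B, 16, 17, 44, 80] (same shape as `latents`; first-frame latents at index 0)
+            latent_mask:      [B, 1, 17, 44, 80] (1.0 at frame index 0, 0.0 elsewhere)
+            image_embeds:     SigLIP `last_hidden_state` of the first frame, or `None` when conditioning is off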
+ """ + batch_size, latent_channels, latent_num_frames, latent_height, latent_width = latents.shape + device = latents.device + dtype = latents.dtype + + use_conditioning = use_conditioning and (latent_num_frames > 1) + + latent_condition = torch.zeros( + batch_size, latent_channels, latent_num_frames, latent_height, latent_width, device=device, dtype=dtype + ) + latent_mask = torch.zeros( + batch_size, 1, latent_num_frames, latent_height, latent_width, device=device, dtype=dtype + ) + image_embeds = None + + if use_conditioning: + with torch.no_grad(): + # video shape: [B, F, C, H, W] -> [B, C, F, H, W] for VAE + first_frame_latents = self.vae.encode(video[:, 0:1].permute(0, 2, 1, 3, 4)).latent_dist.sample( + generator=generator + ) + first_frame_latents = self._normalize_latents( + latents=first_frame_latents, + latents_mean=self.vae.config.latents_mean, + latents_std=self.vae.config.latents_std, + ) + + latent_condition = first_frame_latents.repeat(1, 1, latent_num_frames, 1, 1) + latent_condition[:, :, 1:, :, :] = 0 + + latent_mask[:, :, 0] = 1.0 + + first_frame_vision = video[:, 0] # [B, C, H, W] + first_frame_vision = ((first_frame_vision + 1) / 2).clamp(0, 1) + + with torch.no_grad(): + image_embeds = self._get_image_embeds( + image_encoder=self.vision_encoder, + feature_extractor=self.feature_extractor, + image=first_frame_vision, + device=device, + ) + + return latent_condition, latent_mask, image_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + batch_size, + image=None, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % self.vae_scale_factor_spatial != 0 or width % self.vae_scale_factor_spatial != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor_spatial} but are {height} and {width}." + ) + + if image is not None: + if isinstance(image, torch.Tensor): + if image.dim() != 4: + raise ValueError(f"`image` must be a 4D tensor [B, C, H, W], got {image.dim()}D") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None: + if not isinstance(negative_prompt, (str, list)): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + if isinstance(negative_prompt, list) and len(negative_prompt) != batch_size: + raise ValueError( + f"`negative_prompt` list length ({len(negative_prompt)}) must match batch_size ({batch_size})." 
+ ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + f"`prompt_embeds` and `negative_prompt_embeds` must have the same shape, " + f"got {prompt_embeds.shape} and {negative_prompt_embeds.shape}." + ) + + def _prepare_negative_prompt( + self, + negative_prompt: Optional[Union[str, List[str]]], + batch_size: int, + ) -> List[str]: + """Prepare negative_prompt to match batch_size.""" + if negative_prompt is None: + return [""] * batch_size + if isinstance(negative_prompt, str): + return [negative_prompt] * batch_size + return negative_prompt + + @staticmethod + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor + ) -> torch.Tensor: + latents_mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) / latents_std + return latents + + @staticmethod + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor + ) -> torch.Tensor: + latents_mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std + latents_mean + return latents + + def prepare_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 16, + height: int = 736, + width: int = 1280, + num_frames: int = 121, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+            )
+
+        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        return latents
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
+    @property
+    def attention_kwargs(self):
+        return self._attention_kwargs
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        image: PipelineImageInput,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        height: int = 736,
+        width: int = 1280,
+        num_frames: int = 121,
+        frame_rate: int = 25,
+        num_inference_steps: int = 50,
+        timesteps: Optional[List[int]] = None,
+        use_linear_quadratic_schedule: bool = True,
+        linear_quadratic_emulating_steps: int = 250,
+        num_videos_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 512,
+    ):
+        r"""
+        The call function to the pipeline for image-to-video generation.
+
+        Args:
+            image (`PipelineImageInput`):
+                The input image to use as the first frame for video generation.
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the video generation.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the video generation.
+            height (`int`, defaults to `736`):
+                The height in pixels of the generated video.
+            width (`int`, defaults to `1280`):
+                The width in pixels of the generated video.
+            num_frames (`int`, defaults to `121`):
+                The number of video frames to generate.
+            frame_rate (`int`, defaults to `25`):
+                Frame rate for the output video.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process.
+            use_linear_quadratic_schedule (`bool`, defaults to `True`):
+                Whether to use a linear-quadratic sigma schedule.
+            linear_quadratic_emulating_steps (`int`, defaults to `250`):
+                Controls the slope of linear interpolation in the first half.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                The number of videos to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                PyTorch Generator object(s) for deterministic generation.
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated noisy latents.
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings.
+            prompt_attention_mask (`torch.Tensor`, *optional*):
+                Pre-generated attention mask for text embeddings.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings.
+            negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated attention mask for negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~MotifVideoPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + Arguments passed to the attention processor. + callback_on_step_end (`Callable`, *optional*): + A function or subclass of `PipelineCallback` called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. + max_sequence_length (`int`, defaults to `512`): + Maximum sequence length for the tokenizer. + + Examples: + + Returns: + [`~MotifVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~MotifVideoPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list of generated video frames. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 2. Check inputs + self.check_inputs( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + batch_size=batch_size, + image=image, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + device = self._execution_device + + # 3. Preprocess image + if latents is None: + if isinstance(image, torch.Tensor): + image = image.to(device=device, dtype=self.transformer.dtype) + if image.min() >= 0 and image.max() <= 1: + image = image * 2 - 1 + image = image.clamp(-1, 1) + else: + image = self.video_processor.preprocess(image, height=height, width=width) + image = image.to(device=device, dtype=self.transformer.dtype) + video = image.unsqueeze(1) # [B, 1, C, H, W] + + # 4. Prepare latents + num_channels_latents = self.vae.config.z_dim + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + self.transformer.dtype, + device, + generator, + latents, + ) + + # 5. Prepare text embeddings + prompt_embeds, prompt_attention_mask = self.encode_prompt( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + + # 6. 
First frame conditioning + latent_condition, latent_mask, image_embeds = self._prepare_first_frame_conditioning( + video, + latents, + use_conditioning=True, + generator=generator, + ) + + # Repeat conditioning tensors for each generation per prompt + if num_videos_per_prompt > 1: + latent_condition = latent_condition.repeat_interleave(num_videos_per_prompt, dim=0) + latent_mask = latent_mask.repeat_interleave(num_videos_per_prompt, dim=0) + if image_embeds is not None: + image_embeds = image_embeds.repeat_interleave(num_videos_per_prompt, dim=0) + + if self.guider._enabled and self.guider.num_conditions > 1: + negative_prompt = self._prepare_negative_prompt(negative_prompt, batch_size) + negative_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + + # 7. Prepare timesteps + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + latent_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + packed_latent_height = latent_height // self.transformer_spatial_patch_size + packed_latent_width = latent_width // self.transformer_spatial_patch_size + packed_latent_num_frames = latent_num_frames // self.transformer_temporal_patch_size + video_sequence_length = packed_latent_num_frames * packed_latent_height * packed_latent_width + + if use_linear_quadratic_schedule: + sigmas = None + else: + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + use_linear_quadratic_schedule=use_linear_quadratic_schedule, + linear_quadratic_emulating_steps=linear_quadratic_emulating_steps, + mu=mu, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 8. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + # Concatenate: [latents | latent_condition | latent_mask] + hidden_states = torch.cat([latents, latent_condition, latent_mask], dim=1) + + timestep = t.expand(latents.shape[0]) + + guider_inputs = { + "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds), + "encoder_attention_mask": (prompt_attention_mask, negative_prompt_attention_mask), + } + + self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t) + guider_state = self.guider.prepare_inputs(guider_inputs) + + for guider_state_batch in guider_state: + self.guider.prepare_models(self.transformer) + + cond_kwargs = { + input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys() + } + + context_name = getattr(guider_state_batch, self.guider._identifier_key) + with self.transformer.cache_context(context_name): + noise_pred = self.transformer( + hidden_states=hidden_states, + timestep=timestep, + image_embeds=image_embeds, + attention_kwargs=self.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0].clone() + + guider_state_batch.noise_pred = noise_pred + self.guider.cleanup_models(self.transformer) + + noise_pred = self.guider(guider_state)[0] + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + if "negative_prompt_embeds" in callback_outputs: + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds") + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + video = latents + else: + latents = latents.to(self.vae.dtype) + latents = self._denormalize_latents(latents, self.vae.config.latents_mean, self.vae.config.latents_std) + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return MotifVideoPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/motif_video/pipeline_output.py b/src/diffusers/pipelines/motif_video/pipeline_output.py new file mode 100644 index 000000000000..aa0b2b83b323 --- /dev/null +++ b/src/diffusers/pipelines/motif_video/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class MotifVideoPipelineOutput(BaseOutput): + r""" + Output class for Motif-Video pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. 
+ """ + + frames: torch.Tensor diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 738e079eba9b..656b0dc5effa 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1515,6 +1515,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class MotifVideoTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ModelMixin(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index c95c56789e37..46c015cc3db6 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2657,6 +2657,51 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class MotifVideoImage2VideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class MotifVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class MotifVideoPipelineOutput(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class MusicLDMPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/motif_video/__init__.py b/tests/pipelines/motif_video/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/motif_video/test_motif_video.py b/tests/pipelines/motif_video/test_motif_video.py new file mode 100644 index 000000000000..00a821717ab7 --- /dev/null +++ b/tests/pipelines/motif_video/test_motif_video.py @@ -0,0 +1,125 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, MotifVideoPipeline +from diffusers.models.transformers.transformer_motif_video import MotifVideoTransformer3DModel +from diffusers.utils.testing_utils import enable_full_determinism + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class MotifVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = MotifVideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + # Use tiny-random-t5 as a stand-in for T5Gemma2Model's encoder + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = MotifVideoTransformer3DModel( + in_channels=33, + out_channels=16, + num_attention_heads=2, + attention_head_dim=12, + num_layers=1, + num_single_layers=1, + mlp_ratio=4.0, + patch_size=1, + patch_size_t=1, + qk_norm="rms_norm", + text_embed_dim=32, + rope_axes_dim=(4, 4, 4), + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "A test video", + "negative_prompt": "bad quality", + "generator": generator, + "num_inference_steps": 2, + "height": 16, + "width": 16, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + + @unittest.skip("MotifVideo uses guider pattern instead of guidance_scale") + def test_inference_batch_single_identical(self): + pass diff --git a/tests/pipelines/motif_video/test_motif_video_image2video.py b/tests/pipelines/motif_video/test_motif_video_image2video.py new file mode 100644 index 000000000000..36b36e12805d --- /dev/null +++ b/tests/pipelines/motif_video/test_motif_video_image2video.py @@ -0,0 +1,135 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from PIL import Image +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, MotifVideoImage2VideoPipeline +from diffusers.models.transformers.transformer_motif_video import MotifVideoTransformer3DModel +from diffusers.utils.testing_utils import enable_full_determinism + +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_BATCH_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_PARAMS, +) +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class MotifVideoImage2VideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = MotifVideoImage2VideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} | {"image"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS | {"image"} + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = MotifVideoTransformer3DModel( + in_channels=33, + out_channels=16, + num_attention_heads=2, + attention_head_dim=12, + num_layers=1, + num_single_layers=1, + mlp_ratio=4.0, + patch_size=1, + patch_size_t=1, + qk_norm="rms_norm", + text_embed_dim=32, + image_embed_dim=4, + rope_axes_dim=(4, 4, 4), + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "feature_extractor": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + image = Image.new("RGB", (16, 16)) + + inputs = { + "image": image, + "prompt": "A test video", + "negative_prompt": "bad quality", + "generator": generator, + "num_inference_steps": 2, + "height": 16, + "width": 16, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + + @unittest.skip("MotifVideo uses 
guider pattern instead of guidance_scale") + def test_inference_batch_single_identical(self): + pass