diff --git a/examples/models/stable_diffusion_3_5_large/__init__.py b/examples/models/stable_diffusion_3_5_large/__init__.py new file mode 100644 index 00000000000..1dcb55fa679 --- /dev/null +++ b/examples/models/stable_diffusion_3_5_large/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .export_sd3_5_large import parse_component, parse_dtype, SD35LargeExporter +from .model import ( + MODEL_ID, + SD3CLIPTextEncoderWrapper, + SD3T5TextEncoderWrapper, + SD3TransformerWrapper, + SD3VAEDecoderWrapper, + StableDiffusion3ModelLoader, + StableDiffusionComponent, +) + +__all__ = [ + "MODEL_ID", + "SD35LargeExporter", + "SD3CLIPTextEncoderWrapper", + "SD3T5TextEncoderWrapper", + "SD3TransformerWrapper", + "SD3VAEDecoderWrapper", + "StableDiffusion3ModelLoader", + "StableDiffusionComponent", + "parse_component", + "parse_dtype", +] diff --git a/examples/models/stable_diffusion_3_5_large/export_sd3_5_large.py b/examples/models/stable_diffusion_3_5_large/export_sd3_5_large.py new file mode 100644 index 00000000000..519b26ce85f --- /dev/null +++ b/examples/models/stable_diffusion_3_5_large/export_sd3_5_large.py @@ -0,0 +1,445 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
import argparse
import logging
import subprocess
import sys
import zipfile
from pathlib import Path
from typing import Optional

import torch

try:
    from .model import MODEL_ID, StableDiffusion3ModelLoader, StableDiffusionComponent
except ImportError:
    # Fallback for when this file is executed directly as a script (as the
    # child-process command built below does), where no parent package exists.
    from model import MODEL_ID, StableDiffusion3ModelLoader, StableDiffusionComponent
from torch.export import export, save


logging.basicConfig(level=logging.INFO)
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)


# Export order used by "export everything" code paths.
COMPONENTS = (
    StableDiffusionComponent.TEXT_ENCODER,
    StableDiffusionComponent.TEXT_ENCODER_2,
    StableDiffusionComponent.TEXT_ENCODER_3,
    StableDiffusionComponent.TRANSFORMER,
    StableDiffusionComponent.VAE_DECODER,
)
# Components whose export is retried with smaller latent sizes when the child
# process is killed by a signal.
MEMORY_SENSITIVE_COMPONENTS = {
    StableDiffusionComponent.TRANSFORMER,
    StableDiffusionComponent.VAE_DECODER,
}
# Fallback latent sizes, tried from largest to smallest.
LATENT_SIZE_RETRIES = (64, 32, 16, 8)

DTYPE_BY_NAME: dict[str, torch.dtype] = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
}
NAME_BY_DTYPE: dict[torch.dtype, str] = {
    dtype: name for name, dtype in DTYPE_BY_NAME.items()
}


def parse_dtype(dtype: str) -> torch.dtype:
    """Parse a dtype string into a torch.dtype.

    Supported values are: fp16, bf16, fp32.

    Raises:
        ValueError: If ``dtype`` is not one of the supported names.
    """
    try:
        return DTYPE_BY_NAME[dtype]
    except KeyError as e:
        raise ValueError(f"Unsupported dtype: {dtype}") from e


def dtype_name(dtype: torch.dtype) -> str:
    """Convert a torch dtype to the corresponding CLI dtype name.

    Raises:
        ValueError: If ``dtype`` has no CLI name in ``NAME_BY_DTYPE``.
    """
    try:
        return NAME_BY_DTYPE[dtype]
    except KeyError as e:
        raise ValueError(f"Unsupported dtype: {dtype}") from e


def parse_component(component: str) -> StableDiffusionComponent:
    """Parse a component string into a StableDiffusionComponent enum.

    Raises argparse.ArgumentTypeError for unsupported components.
    """
    try:
        return StableDiffusionComponent(component)
    except ValueError as e:
        raise argparse.ArgumentTypeError(f"Unsupported component: {component}") from e


def build_export_command(
    component: StableDiffusionComponent,
    model_id: str,
    text_encoder_id: Optional[str],
    text_encoder_2_id: Optional[str],
    dtype: torch.dtype,
    max_sequence_length: int,
    latent_size: Optional[int],
    output_dir: Path,
) -> list[str]:
    """Build the command used to export one component in a child process.

    The command re-runs this file with the current interpreter and always
    passes ``--in-process`` so the child exports directly instead of spawning
    another process. Optional arguments are only appended when they are set.
    """
    command = [
        sys.executable,
        str(Path(__file__).resolve()),
        "--model-id",
        model_id,
        "--component",
        component.value,
        "--dtype",
        dtype_name(dtype),
        "--max-sequence-length",
        str(max_sequence_length),
        "--output-dir",
        str(output_dir),
        "--in-process",
    ]
    if latent_size is not None:
        command.extend(["--latent-size", str(latent_size)])
    if text_encoder_id is not None:
        command.extend(["--text-encoder-id", text_encoder_id])
    if text_encoder_2_id is not None:
        command.extend(["--text-encoder-2-id", text_encoder_2_id])
    return command


def is_valid_export_file(path: Path) -> bool:
    """Return True when ``path`` is a non-empty file recognised as a zip archive.

    A file that disappears or becomes unreadable between the individual checks
    is reported as invalid rather than raising.
    """
    try:
        return path.is_file() and path.stat().st_size > 0 and zipfile.is_zipfile(path)
    except OSError:
        return False


def latent_size_attempts(
    component: StableDiffusionComponent,
    latent_size: Optional[int],
) -> tuple[Optional[int], ...]:
    """Return latent-size attempts for a component export.

    Components outside ``MEMORY_SENSITIVE_COMPONENTS`` get a single attempt
    with the requested size. Memory-sensitive components get the requested
    size first, followed by every entry of ``LATENT_SIZE_RETRIES`` that is
    smaller than it (all entries when ``latent_size`` is None).
    """
    if component not in MEMORY_SENSITIVE_COMPONENTS:
        return (latent_size,)

    smaller_sizes = (
        size
        for size in LATENT_SIZE_RETRIES
        if latent_size is None or size < latent_size
    )
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return tuple(dict.fromkeys((latent_size, *smaller_sizes)))


def remove_invalid_export_file(path: Path) -> None:
    """Remove a stale partial export if it exists and is not a valid .pt2 archive."""
    if path.exists() and not is_valid_export_file(path):
        logger.warning("Removing invalid export: %s", path)
        # missing_ok: the file may have been removed since the check above.
        path.unlink(missing_ok=True)


def run_component_export_subprocess(
    component: StableDiffusionComponent,
    model_id: str,
    text_encoder_id: Optional[str],
    text_encoder_2_id: Optional[str],
    dtype: torch.dtype,
    max_sequence_length: int,
    latent_size: Optional[int],
    output_dir: Path,
) -> None:
    """Export a component in a child process, retrying smaller latent sizes.

    Each attempt from ``latent_size_attempts`` runs the export command in a
    child process. A child that exits with a non-negative non-zero status
    fails immediately; a child killed by a signal (negative return code, which
    this module treats as a likely out-of-memory kill) triggers the next,
    smaller attempt.

    Raises:
        subprocess.CalledProcessError: If the child exits with a non-negative
            non-zero status.
        RuntimeError: If every attempt was killed by a signal.
    """
    attempted_latent_sizes = []
    last_error = None
    for attempt_latent_size in latent_size_attempts(component, latent_size):
        attempted_latent_sizes.append(attempt_latent_size)
        # A killed child may have left a truncated archive behind.
        remove_invalid_export_file(output_dir / f"{component.value}.pt2")
        if attempt_latent_size != latent_size:
            logger.info(
                "Retrying %s with --latent-size %s",
                component.value,
                attempt_latent_size,
            )

        try:
            subprocess.run(
                build_export_command(
                    component=component,
                    model_id=model_id,
                    text_encoder_id=text_encoder_id,
                    text_encoder_2_id=text_encoder_2_id,
                    dtype=dtype,
                    max_sequence_length=max_sequence_length,
                    latent_size=attempt_latent_size,
                    output_dir=output_dir,
                ),
                check=True,
            )
            return
        except subprocess.CalledProcessError as e:
            last_error = e
            # Only signal kills (negative return codes) are retried.
            if e.returncode >= 0:
                raise
            logger.warning(
                "Exporting %s was killed by signal %s with latent size %s",
                component.value,
                -e.returncode,
                attempt_latent_size,
            )

    attempted = ", ".join(
        "default" if size is None else str(size) for size in attempted_latent_sizes
    )
    raise RuntimeError(
        f"Exporting {component.value} was killed after trying latent sizes: "
        f"{attempted}. Try reducing --max-sequence-length or exporting this "
        "component on a machine with more memory."
    ) from last_error


def export_all_in_subprocesses(
    model_id: str,
    text_encoder_id: Optional[str],
    text_encoder_2_id: Optional[str],
    dtype: torch.dtype,
    max_sequence_length: int,
    latent_size: Optional[int],
    output_dir: Path,
    skip_existing: bool,
) -> list[Path]:
    """Export each SD3 component in a fresh Python process to release memory.

    When ``skip_existing`` is true, components whose destination file already
    passes ``is_valid_export_file`` are not exported again.

    Returns:
        The destination path of every component, in ``COMPONENTS`` order.
    """
    output_paths = []
    for component in COMPONENTS:
        output_path = output_dir / f"{component.value}.pt2"
        if skip_existing and is_valid_export_file(output_path):
            logger.info("Skipping %s; %s already exists", component.value, output_path)
            output_paths.append(output_path)
            continue
        if skip_existing and output_path.exists():
            logger.warning(
                "Re-exporting %s because %s is not a valid export",
                component.value,
                output_path,
            )

        logger.info("Exporting %s in a fresh process", component.value)
        run_component_export_subprocess(
            component=component,
            model_id=model_id,
            text_encoder_id=text_encoder_id,
            text_encoder_2_id=text_encoder_2_id,
            dtype=dtype,
            max_sequence_length=max_sequence_length,
            latent_size=latent_size,
            output_dir=output_dir,
        )
        output_paths.append(output_path)

    return output_paths


class SD35LargeExporter:
    """Export wrapped SD3 components, defaulting to Stable Diffusion 3.5 Large."""

    def __init__(
        self,
        model_id: str = MODEL_ID,
        text_encoder_id: Optional[str] = None,
        text_encoder_2_id: Optional[str] = None,
        dtype: torch.dtype = torch.float16,
        max_sequence_length: int = 256,
        latent_size: Optional[int] = None,
    ):
        """Store the dummy-input settings and create the model loader."""
        self.max_sequence_length = max_sequence_length
        self.latent_size = latent_size
        self.model_loader = StableDiffusion3ModelLoader(
            model_id=model_id,
            text_encoder_id=text_encoder_id,
            text_encoder_2_id=text_encoder_2_id,
            dtype=dtype,
        )

    def load_models(self) -> bool:
        """Load all configured SD3 components."""
        return self.model_loader.load_models()

    def load_component(self, component: StableDiffusionComponent) -> bool:
        """Load only the model component needed for a single export."""
        return self.model_loader.load_models([component])

    def _component_model(self, component: StableDiffusionComponent) -> torch.nn.Module:
        """Return the export wrapper for ``component``.

        Raises:
            ValueError: If ``component`` has no associated wrapper.
        """
        loader = self.model_loader
        wrapper_getters = {
            StableDiffusionComponent.TEXT_ENCODER: loader.get_text_encoder_wrapper,
            StableDiffusionComponent.TEXT_ENCODER_2: loader.get_text_encoder_2_wrapper,
            StableDiffusionComponent.TEXT_ENCODER_3: loader.get_text_encoder_3_wrapper,
            StableDiffusionComponent.TRANSFORMER: loader.get_transformer_wrapper,
            StableDiffusionComponent.VAE_DECODER: loader.get_vae_decoder_wrapper,
        }
        try:
            get_wrapper = wrapper_getters[component]
        except KeyError:
            raise ValueError(
                f"Unsupported SD3.5 component: {component.value}"
            ) from None
        return get_wrapper()

    def export_component(
        self,
        component: StableDiffusionComponent,
        output_dir: Path,
    ) -> Path:
        """Export a single SD3 component to a .pt2 file.

        Returns:
            The path of the written ``<component>.pt2`` file inside
            ``output_dir`` (created if missing).

        Raises:
            ValueError: If no dummy inputs exist for ``component``, which is
                the case when that component has not been loaded.
        """
        dummy_inputs = self.model_loader.get_dummy_inputs(
            max_sequence_length=self.max_sequence_length,
            latent_size=self.latent_size,
        )
        if component not in dummy_inputs:
            raise ValueError(f"No dummy inputs are available for {component.value}")

        model = self._component_model(component).eval()
        component_inputs = dummy_inputs[component]
        logger.info("Exporting %s", component.value)
        exported_program = export(model, component_inputs)

        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / f"{component.value}.pt2"
        save(exported_program, output_path)
        logger.info("Saved %s", output_path)
        return output_path

    def export_all(self, output_dir: Path) -> list[Path]:
        """Export all wrapped SD3 components."""
        return [
            self.export_component(component, output_dir) for component in COMPONENTS
        ]


def _positive_int(value: str) -> int:
    """Parse a command-line value as an integer greater than zero.

    Intended as an argparse ``type=`` callable.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not an integer or is <= 0.
    """
    try:
        parsed = int(value)
    except ValueError as e:
        raise argparse.ArgumentTypeError(
            f"Expected an integer, got: {value!r}"
        ) from e
    if parsed <= 0:
        raise argparse.ArgumentTypeError(
            f"Expected a positive integer, got: {value}"
        )
    return parsed


def main() -> None:
    """Parse command line arguments and export the requested SD3 components."""
    parser = argparse.ArgumentParser(
        description=(
            "Export wrapped SD3 components, defaulting to Stable Diffusion 3.5 "
            "Large."
        )
    )
    parser.add_argument(
        "--model-id",
        default=MODEL_ID,
        help="HuggingFace model id to load.",
    )
    parser.add_argument(
        "--text-encoder-id",
        default=None,
        help=(
            "HuggingFace CLIP text encoder repo id. Defaults to the "
            "text_encoder subfolder of --model-id."
        ),
    )
    parser.add_argument(
        "--text-encoder-2-id",
        default=None,
        help=(
            "HuggingFace repo id for the second CLIP text encoder. Defaults "
            "to the text_encoder_2 subfolder of --model-id."
        ),
    )
    parser.add_argument(
        "--component",
        default="all",
        choices=("all", *(component.value for component in COMPONENTS)),
        help=(
            "Component to export: all, text_encoder, text_encoder_2, "
            "text_encoder_3, transformer, or vae_decoder."
        ),
    )
    parser.add_argument(
        "--dtype",
        choices=("fp16", "bf16", "fp32"),
        default="fp16",
        help="Model dtype used while loading and exporting.",
    )
    parser.add_argument(
        "--max-sequence-length",
        # Reject 0 and negatives here instead of failing later inside torch.
        type=_positive_int,
        default=256,
        help="T5 sequence length used for text_encoder_3 and transformer inputs.",
    )
    parser.add_argument(
        "--latent-size",
        # A value of 0 would otherwise be silently replaced by the default.
        type=_positive_int,
        default=None,
        help=(
            "Latent height/width used for transformer and VAE decoder dummy "
            "inputs. Defaults to the transformer sample size, or 128 for a "
            "standalone VAE decoder export."
        ),
    )
    parser.add_argument(
        "--no-skip-existing",
        action="store_true",
        help="Re-export components even if the destination .pt2 file already exists.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("sd3.5-large-exported"),
        help="Directory where exported .pt2 files are written.",
    )
    # Internal flag set by build_export_command for child processes.
    parser.add_argument(
        "--in-process",
        action="store_true",
        help=argparse.SUPPRESS,
    )
    args = parser.parse_args()
    dtype = parse_dtype(args.dtype)

    if args.component == "all":
        export_all_in_subprocesses(
            model_id=args.model_id,
            text_encoder_id=args.text_encoder_id,
            text_encoder_2_id=args.text_encoder_2_id,
            dtype=dtype,
            max_sequence_length=args.max_sequence_length,
            latent_size=args.latent_size,
            output_dir=args.output_dir,
            skip_existing=not args.no_skip_existing,
        )
        return

    component = parse_component(args.component)
    # Memory-sensitive components go through a child process so that a signal
    # kill can be retried with a smaller latent size; --in-process marks the
    # child itself and prevents it from spawning again.
    if not args.in_process and component in MEMORY_SENSITIVE_COMPONENTS:
        run_component_export_subprocess(
            component=component,
            model_id=args.model_id,
            text_encoder_id=args.text_encoder_id,
            text_encoder_2_id=args.text_encoder_2_id,
            dtype=dtype,
            max_sequence_length=args.max_sequence_length,
            latent_size=args.latent_size,
            output_dir=args.output_dir,
        )
        return

    exporter = SD35LargeExporter(
        model_id=args.model_id,
        text_encoder_id=args.text_encoder_id,
        text_encoder_2_id=args.text_encoder_2_id,
        dtype=dtype,
        max_sequence_length=args.max_sequence_length,
        latent_size=args.latent_size,
    )
    if not exporter.load_component(component):
        raise RuntimeError("Failed to load SD3.5 Large models")

    exporter.export_component(component, args.output_dir)


if __name__ == "__main__":
    main()


# Patch header of the next file, kept verbatim from the source diff:
# diff --git a/examples/models/stable_diffusion_3_5_large/model.py b/examples/models/stable_diffusion_3_5_large/model.py
# new file mode 100644
# index 00000000000..2badae5e1f1
# --- /dev/null
# +++ b/examples/models/stable_diffusion_3_5_large/model.py
# @@ -0,0 +1,415 @@
# Copyright 2026 Arm Limited and/or
# its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from collections.abc import Iterable
from enum import Enum
from typing import Any, Optional

import torch
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.models.transformers import SD3Transformer2DModel
from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel


logger = logging.getLogger(__name__)

MODEL_ID = "stabilityai/stable-diffusion-3.5-large"


def _subfolder_for_model_repo(
    repo_id: str, model_id: str, subfolder: str
) -> Optional[str]:
    """Return ``subfolder`` when ``repo_id`` is the main model repo, else None."""
    return subfolder if repo_id == model_id else None


def _optional_subfolder_kwargs(subfolder: Optional[str]) -> dict[str, str]:
    """Return ``{"subfolder": subfolder}`` when set, otherwise an empty dict."""
    return {"subfolder": subfolder} if subfolder else {}


def _dummy_token_ids(batch_size: int, sequence_length: int) -> torch.Tensor:
    """Return a ``(batch_size, sequence_length)`` int64 tensor of dummy token ids.

    The values are rounded magnitudes of standard-normal samples, so they are
    small non-negative integers.
    """
    return (
        torch.randn(batch_size, sequence_length).abs().round().to(dtype=torch.long)
    )


class StableDiffusionComponent(Enum):
    """Stable Diffusion component names used by this exporter."""

    TEXT_ENCODER = "text_encoder"
    TEXT_ENCODER_2 = "text_encoder_2"
    TEXT_ENCODER_3 = "text_encoder_3"
    TRANSFORMER = "transformer"
    VAE_DECODER = "vae_decoder"


class SD3CLIPTextEncoderWrapper(torch.nn.Module):
    """Wrapper for SD3 CLIP text encoders."""

    def __init__(self, text_encoder, clip_skip: Optional[int] = None):
        super().__init__()
        self.text_encoder = text_encoder
        self.clip_skip = clip_skip

    def forward(self, input_ids):
        """Forward pass for CLIP text encoder.

        Returns:
            A tuple of (selected hidden state, ``output[0]``). The hidden
            state is the second-to-last one when ``clip_skip`` is None,
            otherwise the one at index ``-(clip_skip + 2)``.
        """
        output = self.text_encoder(
            input_ids, output_hidden_states=True, return_dict=True
        )
        hidden_state_index = -2 if self.clip_skip is None else -(self.clip_skip + 2)
        # NOTE(review): output[0] is presumably the projected pooled embedding
        # of CLIPTextModelWithProjection - confirm against transformers docs.
        return output.hidden_states[hidden_state_index], output[0]


class SD3T5TextEncoderWrapper(torch.nn.Module):
    """Wrapper for SD3 T5 text encoder."""

    def __init__(self, text_encoder):
        super().__init__()
        self.text_encoder = text_encoder

    def forward(self, input_ids):
        """Forward pass for T5 text encoder; returns ``last_hidden_state``."""
        output = self.text_encoder(input_ids, return_dict=True)
        return output.last_hidden_state


class SD3TransformerWrapper(torch.nn.Module):
    """Wrapper for SD3 transformer denoiser that extracts sample tensor."""

    def __init__(self, transformer):
        super().__init__()
        self.transformer = transformer

    def forward(
        self,
        latents,
        timestep,
        encoder_hidden_states,
        pooled_projections,
    ):
        """Forward pass through the transformer denoiser.

        Args:
            latents: Input latent tensor
            timestep: Timestep for denoising
            encoder_hidden_states: Hidden states from text encoder
            pooled_projections: Pooled projection embeddings

        Returns:
            Sample output tensor from transformer
        """
        output = self.transformer(
            hidden_states=latents,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            pooled_projections=pooled_projections,
            return_dict=True,
        )
        return output.sample


class SD3VAEDecoderWrapper(torch.nn.Module):
    """Wrapper for SD3 VAE decoder with scaling, shift, and normalization."""

    def __init__(self, vae):
        super().__init__()
        self.vae = vae

    def forward(self, latents):
        """Decode latents to image using VAE decoder with scaling and normalization.

        The latents are divided by the VAE ``scaling_factor``, shifted by
        ``shift_factor`` when the config defines one, decoded, and the decoded
        sample is mapped through ``x / 2 + 0.5`` and clamped to [0, 1].
        """
        latents = latents / self.vae.config.scaling_factor
        shift_factor = getattr(self.vae.config, "shift_factor", None)
        if shift_factor is not None:
            latents = latents + shift_factor
        image = self.vae.decode(latents, return_dict=True).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        return image


class StableDiffusion3ModelLoader:
    """Load SD3 components and construct export wrappers locally."""

    def __init__(
        self,
        model_id: str = MODEL_ID,
        text_encoder_id: Optional[str] = None,
        text_encoder_2_id: Optional[str] = None,
        dtype: torch.dtype = torch.float16,
    ):
        """Record repo ids and dtype; no model is loaded until requested.

        ``text_encoder_id`` and ``text_encoder_2_id`` fall back to
        ``model_id`` when not given.
        """
        self.model_id = model_id
        self.text_encoder_id = text_encoder_id or model_id
        self.text_encoder_2_id = text_encoder_2_id or model_id
        self.dtype = dtype
        self.text_encoder: Any = None
        self.text_encoder_2: Any = None
        self.text_encoder_3: Any = None
        self.transformer: Any = None
        self.vae: Any = None
        self.tokenizer: Any = None
        self.tokenizer_2: Any = None

    def _load_tokenizer(self) -> None:
        """Load the first CLIP tokenizer once; later calls are no-ops."""
        if self.tokenizer is not None:
            return

        self.tokenizer = CLIPTokenizer.from_pretrained(
            self.text_encoder_id,
            **_optional_subfolder_kwargs(
                _subfolder_for_model_repo(
                    self.text_encoder_id,
                    self.model_id,
                    "tokenizer",
                )
            ),
        )

    def _load_text_encoder(self) -> None:
        """Load the first CLIP text encoder and its tokenizer if not loaded."""
        if self.text_encoder is not None:
            return

        text_encoder_subfolder = _subfolder_for_model_repo(
            self.text_encoder_id,
            self.model_id,
            "text_encoder",
        )
        logger.info(
            "Loading CLIP text encoder: %s%s (dtype: %s)",
            self.text_encoder_id,
            f"/{text_encoder_subfolder}" if text_encoder_subfolder else "",
            self.dtype,
        )
        self._load_tokenizer()
        self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
            self.text_encoder_id,
            torch_dtype=self.dtype,
            **_optional_subfolder_kwargs(text_encoder_subfolder),
        )

    def _load_text_encoder_2(self) -> None:
        """Load the second CLIP text encoder and its tokenizer if not loaded."""
        if self.text_encoder_2 is not None:
            return

        tokenizer_2_subfolder = _subfolder_for_model_repo(
            self.text_encoder_2_id,
            self.model_id,
            "tokenizer_2",
        )
        text_encoder_2_subfolder = _subfolder_for_model_repo(
            self.text_encoder_2_id,
            self.model_id,
            "text_encoder_2",
        )
        logger.info(
            "Loading CLIP text encoder 2: %s%s (dtype: %s)",
            self.text_encoder_2_id,
            f"/{text_encoder_2_subfolder}" if text_encoder_2_subfolder else "",
            self.dtype,
        )
        self.tokenizer_2 = CLIPTokenizer.from_pretrained(
            self.text_encoder_2_id,
            **_optional_subfolder_kwargs(tokenizer_2_subfolder),
        )
        self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
            self.text_encoder_2_id,
            torch_dtype=self.dtype,
            **_optional_subfolder_kwargs(text_encoder_2_subfolder),
        )

    def _load_text_encoder_3(self) -> None:
        """Load the T5 text encoder from the main model repo if not loaded."""
        if self.text_encoder_3 is not None:
            return

        logger.info(
            "Loading T5 text encoder: %s/text_encoder_3 (dtype: %s)",
            self.model_id,
            self.dtype,
        )
        self.text_encoder_3 = T5EncoderModel.from_pretrained(
            self.model_id,
            subfolder="text_encoder_3",
            torch_dtype=self.dtype,
        )

    def _load_transformer(self) -> None:
        """Load the SD3 transformer from the main model repo if not loaded."""
        if self.transformer is not None:
            return

        logger.info(
            "Loading SD3 transformer: %s/transformer (dtype: %s)",
            self.model_id,
            self.dtype,
        )
        # get_dummy_inputs reads the tokenizer's model_max_length to size the
        # transformer's text sequence input.
        self._load_tokenizer()
        self.transformer = SD3Transformer2DModel.from_pretrained(
            self.model_id,
            subfolder="transformer",
            torch_dtype=self.dtype,
        )

    def _load_vae(self) -> None:
        """Load the VAE from the main model repo if not loaded."""
        if self.vae is not None:
            return

        logger.info(
            "Loading VAE: %s/vae (dtype: %s)",
            self.model_id,
            self.dtype,
        )
        self.vae = AutoencoderKL.from_pretrained(
            self.model_id,
            subfolder="vae",
            torch_dtype=self.dtype,
        )

    def load_models(
        self, components: Optional[Iterable[StableDiffusionComponent]] = None
    ) -> bool:
        """Load the requested SD3 components.

        Args:
            components: Components to load; all components when None.

        Returns:
            True on success. False when loading raised OSError, ValueError,
            RuntimeError or ImportError; the failure is logged.
        """
        try:
            requested_components = set(
                StableDiffusionComponent if components is None else components
            )
            if StableDiffusionComponent.TEXT_ENCODER in requested_components:
                self._load_text_encoder()
            if StableDiffusionComponent.TEXT_ENCODER_2 in requested_components:
                self._load_text_encoder_2()
            if StableDiffusionComponent.TEXT_ENCODER_3 in requested_components:
                self._load_text_encoder_3()
            if StableDiffusionComponent.TRANSFORMER in requested_components:
                self._load_transformer()
            if StableDiffusionComponent.VAE_DECODER in requested_components:
                self._load_vae()

            # Applies to every model currently held, including ones loaded by
            # earlier calls.
            for model in (
                self.text_encoder,
                self.text_encoder_2,
                self.text_encoder_3,
                self.transformer,
                self.vae,
            ):
                if model is not None:
                    model.to(dtype=self.dtype)
                    model.eval()

            logger.info("Successfully loaded requested SD3 model components")
            return True
        except (OSError, ValueError, RuntimeError, ImportError) as e:
            logger.exception("Failed to load SD3 models: %s", e)
            return False

    def get_text_encoder_wrapper(
        self, clip_skip: Optional[int] = None
    ) -> SD3CLIPTextEncoderWrapper:
        """Get wrapped first CLIP text encoder ready for export.

        Raises:
            ValueError: If the first CLIP text encoder has not been loaded.
        """
        if self.text_encoder is None:
            raise ValueError("Models not loaded. Call load_models() first.")
        return SD3CLIPTextEncoderWrapper(self.text_encoder, clip_skip=clip_skip)

    def get_text_encoder_2_wrapper(
        self, clip_skip: Optional[int] = None
    ) -> SD3CLIPTextEncoderWrapper:
        """Get wrapped second CLIP text encoder ready for export.

        Raises:
            ValueError: If the second CLIP text encoder has not been loaded.
        """
        if self.text_encoder_2 is None:
            raise ValueError("Models not loaded. Call load_models() first.")
        return SD3CLIPTextEncoderWrapper(self.text_encoder_2, clip_skip=clip_skip)

    def get_text_encoder_3_wrapper(self) -> SD3T5TextEncoderWrapper:
        """Get wrapped T5 text encoder ready for export.

        Raises:
            ValueError: If the T5 text encoder has not been loaded.
        """
        if self.text_encoder_3 is None:
            raise ValueError("Models not loaded. Call load_models() first.")
        return SD3T5TextEncoderWrapper(self.text_encoder_3)

    def get_transformer_wrapper(self) -> SD3TransformerWrapper:
        """Get wrapped SD3 transformer ready for export.

        Raises:
            ValueError: If the transformer has not been loaded.
        """
        if self.transformer is None:
            raise ValueError("Models not loaded. Call load_models() first.")
        return SD3TransformerWrapper(self.transformer)

    def get_vae_decoder_wrapper(self) -> SD3VAEDecoderWrapper:
        """Get wrapped SD3 VAE decoder ready for export.

        Raises:
            ValueError: If the VAE has not been loaded.
        """
        if self.vae is None:
            raise ValueError("Models not loaded. Call load_models() first.")
        return SD3VAEDecoderWrapper(self.vae)

    def get_dummy_inputs(
        self,
        max_sequence_length: int = 256,
        latent_size: Optional[int] = None,
    ) -> dict[StableDiffusionComponent, tuple[Any, ...]]:
        """Get dummy inputs for each wrapped SD3 component.

        Only components that are currently loaded get an entry.

        Args:
            max_sequence_length: Length of the T5 token input; also added to
                the CLIP tokenizer length to size the transformer's text
                sequence input.
            latent_size: Latent height/width for the transformer and VAE
                inputs. When falsy, the transformer's ``sample_size`` is used
                (128 for the VAE when no transformer is loaded).

        Raises:
            ValueError: If no component has been loaded.
        """
        loaded_models = (
            self.text_encoder,
            self.text_encoder_2,
            self.text_encoder_3,
            self.transformer,
            self.vae,
        )
        # Compare against None explicitly rather than relying on module
        # truthiness.
        if all(model is None for model in loaded_models):
            raise ValueError("Models not loaded. Call load_models() first.")

        batch_size = 1
        tokenizer_max_length = (
            self.tokenizer.model_max_length if self.tokenizer is not None else 77
        )
        # The second CLIP encoder is sized from its own tokenizer; fall back
        # to the first tokenizer's length when tokenizer_2 is unavailable.
        tokenizer_2_max_length = (
            self.tokenizer_2.model_max_length
            if self.tokenizer_2 is not None
            else tokenizer_max_length
        )
        text_seq_len = tokenizer_max_length + max_sequence_length

        dummy_inputs: dict[StableDiffusionComponent, tuple[Any, ...]] = {}
        if self.text_encoder is not None:
            dummy_inputs[StableDiffusionComponent.TEXT_ENCODER] = (
                _dummy_token_ids(batch_size, tokenizer_max_length),
            )
        if self.text_encoder_2 is not None:
            dummy_inputs[StableDiffusionComponent.TEXT_ENCODER_2] = (
                _dummy_token_ids(batch_size, tokenizer_2_max_length),
            )
        if self.text_encoder_3 is not None:
            dummy_inputs[StableDiffusionComponent.TEXT_ENCODER_3] = (
                _dummy_token_ids(batch_size, max_sequence_length),
            )

        if self.transformer is not None:
            latent_channels = self.transformer.config.in_channels
            transformer_latent_size = latent_size or self.transformer.config.sample_size
            joint_attention_dim = self.transformer.config.joint_attention_dim
            pooled_projection_dim = self.transformer.config.pooled_projection_dim
            # Order matches SD3TransformerWrapper.forward: latents, timestep,
            # encoder_hidden_states, pooled_projections.
            dummy_inputs[StableDiffusionComponent.TRANSFORMER] = (
                torch.randn(
                    batch_size,
                    latent_channels,
                    transformer_latent_size,
                    transformer_latent_size,
                    dtype=self.dtype,
                ),
                torch.tensor([1.0], dtype=torch.float32),
                torch.randn(
                    batch_size,
                    text_seq_len,
                    joint_attention_dim,
                    dtype=self.dtype,
                ),
                torch.randn(batch_size, pooled_projection_dim, dtype=self.dtype),
            )

        if self.vae is not None:
            vae_latent_size = latent_size or (
                self.transformer.config.sample_size
                if self.transformer is not None
                else 128
            )
            vae_latent_channels = getattr(self.vae.config, "latent_channels", 16)
            dummy_inputs[StableDiffusionComponent.VAE_DECODER] = (
                torch.randn(
                    batch_size,
                    vae_latent_channels,
                    vae_latent_size,
                    vae_latent_size,
                    dtype=self.dtype,
                ),
            )
        return dummy_inputs