diff --git a/docs/source/en/quantization/autoround.md b/docs/source/en/quantization/autoround.md
new file mode 100644
index 000000000000..cc3b3693e381
--- /dev/null
+++ b/docs/source/en/quantization/autoround.md
@@ -0,0 +1,146 @@
+
+
+# AutoRound
+
+[AutoRound](https://github.com/intel/auto-round) is an advanced quantization toolkit. It achieves high accuracy at ultra-low bit widths (2-4 bits) with minimal tuning by leveraging sign-gradient descent, and it offers broad hardware compatibility. See our papers [SignRoundV1](https://arxiv.org/pdf/2309.05516) and [SignRoundV2](https://arxiv.org/abs/2512.04746) for more details.
+
+Install `auto-round` (version ≥ 0.13.0):
+
+```bash
+pip install "auto-round>=0.13.0"
+```
+
+To use the Marlin kernel for faster CUDA inference, install `gptqmodel`:
+
+```bash
+pip install "gptqmodel>=5.8.0"
+```
+
+## Load a quantized model
+
+Load a pre-quantized AutoRound model by passing [`AutoRoundConfig`] to [`~ModelMixin.from_pretrained`]. The method works with any model that loads via [Accelerate](https://hf.co/docs/accelerate/index) and has `torch.nn.Linear` layers.
+
+```python
+import torch
+from diffusers import ZImageTransformer2DModel, ZImagePipeline, AutoRoundConfig
+
+model_id = "INCModel/Z-Image-W4A16-AutoRound"
+
+quantization_config = AutoRoundConfig(backend="marlin")
+transformer = ZImageTransformer2DModel.from_pretrained(
+    model_id,
+    subfolder="transformer",
+    quantization_config=quantization_config,
+    torch_dtype=torch.bfloat16,
+    device_map="cuda",
+)
+
+pipe = ZImagePipeline.from_pretrained(
+    model_id,
+    transformer=transformer,
+    torch_dtype=torch.bfloat16,
+    device_map="cuda",
+)
+
+image = pipe("a cat holding a sign that says hello").images[0]
+image.save("output.png")
+```
+
+> [!NOTE]
+> AutoRound in Diffusers only supports loading *pre-quantized* models. To quantize a model from scratch, use the [AutoRound CLI or Python API](https://github.com/intel/auto-round) directly, then load the result with Diffusers.
+
+## Backends
+
+AutoRound supports multiple inference backends. The backend controls which kernel handles dequantization during the forward pass. Set the `backend` parameter in [`AutoRoundConfig`] to choose one:
+
+| Backend | Value | Device | Requirements | Notes |
+|---------|-------|--------|--------------|-------|
+| **Auto** | `"auto"` | Any | — | Default. Automatically selects the best available backend. |
+| **PyTorch** | `"torch"` | CPU / CUDA | — | Pure PyTorch implementation. Broadest compatibility. |
+| **Triton** | `"tritonv2"` | CUDA | `triton` | Triton-based kernel for GPU inference. |
+| **ExllamaV2** | `"exllamav2"` | CUDA | `gptqmodel>=5.8.0` | Good CUDA performance via the ExllamaV2 kernel. |
+| **Marlin** | `"marlin"` | CUDA | `gptqmodel>=5.8.0` | Best CUDA performance via the Marlin kernel. |
+
+```python
+from diffusers import AutoRoundConfig
+
+# Auto-select (default)
+config = AutoRoundConfig()
+
+# Explicit Triton backend for CUDA
+config = AutoRoundConfig(backend="tritonv2")
+
+# Marlin backend for best CUDA performance (requires gptqmodel>=5.8.0)
+config = AutoRoundConfig(backend="marlin")
+
+# ExllamaV2 backend for CUDA (requires gptqmodel>=5.8.0)
+config = AutoRoundConfig(backend="exllamav2")
+
+# PyTorch backend for CPU/CUDA inference
+config = AutoRoundConfig(backend="torch")
+```
+
+## Quantization configurations
+
+AutoRound focuses on weight-only quantization. The primary configuration is W4A16 (4-bit weights, 16-bit activations), with flexibility in group size and symmetry (see the example after the table):
+
+| Configuration | `bits` | `group_size` | `sym` | Description |
+|--------------|--------|-------------|-------|-------------|
+| W4G128 asymmetric | `4` | `128` | `False` | Good balance of accuracy and compression. |
+| W4G128 symmetric | `4` | `128` | `True` | Default for [`AutoRoundConfig`]. Faster dequantization, small accuracy trade-off. |
+| W4G32 asymmetric | `4` | `32` | `False` | Higher accuracy at the cost of more metadata. |
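+
+These settings map directly onto [`AutoRoundConfig`] fields. The snippet below is a minimal sketch with illustrative values; when loading a pre-quantized checkpoint they should match how the checkpoint was quantized:
+
+```python
+from diffusers import AutoRoundConfig
+
+# W4G128 asymmetric (first row of the table)
+config = AutoRoundConfig(bits=4, group_size=128, sym=False)
+
+# W4G32 asymmetric with an explicitly chosen backend
+config = AutoRoundConfig(bits=4, group_size=32, sym=False, backend="auto")
+```
+
+When a checkpoint already stores a quantization config, fields present in it (such as `bits`, `group_size`, and `sym`) are kept as saved; only fields missing from it, such as `backend`, are taken from the [`AutoRoundConfig`] passed at load time.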
+
+## Save and load
+
+Quantize and save a model with the AutoRound Python API:
+
+```python
+from auto_round import AutoRound
+autoround = AutoRound(
+    tiny_z_image_model_path,  # path of the model to quantize
+    num_inference_steps=3,
+    guidance_scale=7.5,
+    dataset="coco2014",
+)
+autoround.quantize_and_save("Z-Image-W4A16-AutoRound")
+```
+
+Load the saved model for inference:
+
+```python
+import torch
+from diffusers import ZImageTransformer2DModel, ZImagePipeline
+
+model_id = "INCModel/Z-Image-W4A16-AutoRound"
+
+# The inference backend will be automatically selected.
+pipe = ZImagePipeline.from_pretrained(
+    model_id,
+    torch_dtype=torch.bfloat16,
+    device_map="cuda",
+)
+
+image = pipe("a cat holding a sign that says hello").images[0]
+image.save("output.png")
+```
+
+## Resources
+
+- [Pre-quantized AutoRound models on the Hub](https://huggingface.co/models?search=autoround)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 2cbfd6e29305..4bce624764e1 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -7,6 +7,7 @@ OptionalDependencyNotAvailable, _LazyModule, is_accelerate_available, + is_auto_round_available, is_bitsandbytes_available, is_flax_available, is_gguf_available, @@ -122,6 +123,18 @@ else: _import_structure["quantizers.quantization_config"].append("NVIDIAModelOptConfig") +try: + if not is_auto_round_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_auto_round_objects + + _import_structure["utils.dummy_auto_round_objects"] = [ + name for name in dir(dummy_auto_round_objects) if not name.startswith("_") + ] +else: + _import_structure["quantizers.quantization_config"].append("AutoRoundConfig") + try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() @@ -945,6 +958,14 @@ else: from .quantizers.quantization_config import NVIDIAModelOptConfig + try: + if not is_auto_round_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_auto_round_objects import * + else: + from .quantizers.quantization_config import AutoRoundConfig + try: if not is_onnx_available(): raise OptionalDependencyNotAvailable()
diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py index 6cd24c459c9d..fae2dfab7327 100644 --- a/src/diffusers/quantizers/auto.py +++ b/src/diffusers/quantizers/auto.py @@ -18,10 +18,12 @@ import warnings +from .autoround import AutoRoundQuantizer from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer from .gguf import GGUFQuantizer from .modelopt import NVIDIAModelOptQuantizer from .quantization_config import ( + AutoRoundConfig, BitsAndBytesConfig, GGUFQuantizationConfig, NVIDIAModelOptConfig, @@ -41,6 +43,7 @@ "quanto": QuantoQuantizer, "torchao": TorchAoHfQuantizer, "modelopt": NVIDIAModelOptQuantizer, + "auto-round": AutoRoundQuantizer, } AUTO_QUANTIZATION_CONFIG_MAPPING = { @@ -50,6 +53,7 @@ "quanto": QuantoConfig, "torchao": TorchAoConfig, "modelopt": NVIDIAModelOptConfig, + "auto-round": AutoRoundConfig, } @@ -136,13
+140,26 @@ def merge_quantization_configs( ) else: warning_msg = "" - if isinstance(quantization_config, dict): + existing_fields = set(quantization_config.keys()) quantization_config = cls.from_dict(quantization_config) + else: + existing_fields = set(quantization_config.__dict__.keys()) if isinstance(quantization_config, NVIDIAModelOptConfig): quantization_config.check_model_patching() + if quantization_config_from_args is not None: + # Only override fields that the user explicitly set. + for key, value in quantization_config_from_args.__dict__.items(): + if key not in existing_fields: + # Field does not exist in the model's quantization_config, add it. + setattr(quantization_config, key, value) + warning_msg += ( + f" Field `{key}` from `quantization_config_from_args` is not present in the model's " + f"`quantization_config`. Adding it with value: {value!r}." + ) + if warning_msg != "": warnings.warn(warning_msg) diff --git a/src/diffusers/quantizers/autoround/__init__.py b/src/diffusers/quantizers/autoround/__init__.py new file mode 100644 index 000000000000..2fe2083d4a5f --- /dev/null +++ b/src/diffusers/quantizers/autoround/__init__.py @@ -0,0 +1 @@ +from .autoround_quantizer import AutoRoundQuantizer diff --git a/src/diffusers/quantizers/autoround/autoround_quantizer.py b/src/diffusers/quantizers/autoround/autoround_quantizer.py new file mode 100644 index 000000000000..f64c328f0261 --- /dev/null +++ b/src/diffusers/quantizers/autoround/autoround_quantizer.py @@ -0,0 +1,128 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import ( + is_auto_round_available, + is_torch_available, + logging, +) +from ..base import DiffusersQuantizer + + +if TYPE_CHECKING: + from ...models.modeling_utils import ModelMixin + + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) + + +class AutoRoundQuantizer(DiffusersQuantizer): + r""" + Diffusers Quantizer for AutoRound (https://github.com/intel/auto-round). + + AutoRound is a weight-only quantization method that uses sign gradient descent to jointly optimize + rounding values and min-max ranges for weights. It supports W4A16 (4-bit weight, 16-bit activation) + quantization for efficient inference. + + This quantizer only supports loading pre-quantized AutoRound models. On-the-fly quantization + (calibration) is not supported through this interface. + """ + + # AutoRound requires data calibration — we only support loading pre-quantized checkpoints. + requires_calibration = True + required_packages = ["auto_round"] + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + def validate_environment(self, *args, **kwargs): + """ + Validates that the auto-round library (>= 0.5) is installed and captures the device_map + for later use during model conversion. 
+ """ + self.device_map = kwargs.get("device_map", None) + if not is_auto_round_available(): + raise ImportError( + "Loading an AutoRound quantized model requires the auto-round library " + "(`pip install 'auto-round>=0.5'`)" + ) + + def _process_model_before_weight_loading( + self, + model: "ModelMixin", + device_map, + keep_in_fp32_modules: list[str] = [], + **kwargs, + ): + """ + Replaces target nn.Linear layers with AutoRound's quantized QuantLinear layers before + weights are loaded from the checkpoint. + + Uses `auto_round.inference.convert_model.convert_hf_model` which: + - Inspects the model architecture and the quantization config (bits, group_size, sym, backend). + - Replaces eligible nn.Linear modules with the appropriate QuantLinear variant + (the packed-weight layer that stores qweight, scales, qzeros). + - Returns the converted model and a set of used backend names. + + `infer_target_device` resolves the device_map into a single target device string + that AutoRound uses to select the correct kernel backend (e.g. "cuda", "cpu"). + """ + from auto_round.inference.convert_model import convert_hf_model, infer_target_device + + if self.pre_quantized: + target_device = infer_target_device(self.device_map) + model, used_backends = convert_hf_model(model, target_device) + self.used_backends = used_backends + + def _process_model_after_weight_loading(self, model, **kwargs): + """ + Finalizes the model after all quantized weights (qweight, scales, qzeros, etc.) have + been loaded into the QuantLinear layers. + + Uses `auto_round.inference.convert_model.post_init` which: + - Performs backend-specific finalization (e.g. repacking weights into the kernel's + expected memory layout, moving buffers to the correct device). + - Freezes quantized parameters (requires_grad=False). + - Prepares the model for inference. + + Raises ValueError if the model is not pre-quantized, since AutoRound does not support + on-the-fly quantization through this loading path. + """ + if self.pre_quantized: + from auto_round.inference.convert_model import post_init + + post_init(model, self.used_backends) + else: + raise ValueError( + "AutoRound quantizer in diffusers only supports pre-quantized models. " + "Please provide a model that has already been quantized with AutoRound." + ) + return model + + @property + def is_trainable(self) -> bool: + """AutoRound W4A16 pre-quantized models do not support training.""" + return False + + @property + def is_serializable(self): + """AutoRound quantized models can be serialized (the quantization config may be + updated by the backend, e.g. for GPTQ/AWQ-compatible formats).""" + return True + diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index c3d829fde8cf..6f0bbd7bf5c7 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -48,6 +48,7 @@ class QuantizationMethod(str, Enum): TORCHAO = "torchao" QUANTO = "quanto" MODELOPT = "modelopt" + AUTOROUND = "auto-round" @dataclass @@ -749,3 +750,73 @@ def get_config_from_quant_type(self) -> dict[str, Any]: ) return BASE_CONFIG + + +@dataclass +class AutoRoundConfig(QuantizationConfigMixin): + """Configuration class for AutoRound quantization. + + AutoRound is a weight-only quantization algorithm that uses sign gradient descent to + jointly optimize weight rounding and min-max values. This config targets the W4A16 + (4-bit weights, 16-bit activations) setting. 
+ + Reference: https://github.com/intel/auto-round + + Args: + bits (`int`, *optional*, defaults to `4`): + The number of bits to quantize weights to. For W4A16 this should be 4. + group_size (`int`, *optional*, defaults to `128`): + The group size for weight quantization. Weights in each group share the same + scale and zero-point. Common choices: 32, 64, 128, -1 (per-channel). + sym (`bool`, *optional*, defaults to `True`): + Whether to use symmetric quantization (zero-point fixed at 0) or asymmetric + quantization (zero-point is learned). + backend (`str`, *optional*, defaults to `"auto"`): + The backend kernel to use for quantized inference. Available backends: + - `"auto"`: Automatically select the best available backend for the current device. + - `"auto_round:torch_zp"`: Pure PyTorch kernel — works on CPU and CUDA. + - `"auto_round:tritonv2_zp"`: Triton-based kernel — requires CUDA. + - `"gptqmodel:marlin_zp"`: Marlin kernel via GPTQModel — requires CUDA and + `gptqmodel>=5.8.0`. Offers the best CUDA inference performance. + kwargs (`dict[str, Any]`, *optional*): + Additional keyword arguments forwarded to AutoRound (e.g. `iters`, `seqlen`, + `batch_size`, `lr`, `minmax_lr` for calibration when quantizing from scratch). + """ + + def __init__( + self, + bits: int = 4, + group_size: int = 128, + sym: bool = True, + backend: str = "auto", + **kwargs, + ) -> None: + self.quant_method = QuantizationMethod.AUTOROUND + self.bits = bits + self.group_size = group_size + self.sym = sym + self.backend = backend + for k, v in kwargs.items(): + setattr(self, k, v) + + def to_dict(self) -> dict: + """Serialize the config to a JSON-compatible dict. + + Output: A dict containing all config fields. The `quant_method` is stored as + its string value so it can be round-tripped through JSON. + """ + output = super().to_dict() + output["quant_method"] = output["quant_method"].value + return output + + @classmethod + def from_dict(cls, config_dict: dict, return_unused_kwargs: bool = False, **kwargs): + """Instantiate an AutoRoundConfig from a dictionary. + + Input: config_dict with keys like bits, group_size, sym, etc. + Output: An AutoRoundConfig instance (and optionally unused kwargs). + """ + # Filter out keys that are not constructor parameters + # (e.g. quant_method is set automatically) + config_dict = {k: v for k, v in config_dict.items() if k != "quant_method"} + return super().from_dict(config_dict, return_unused_kwargs=return_unused_kwargs, **kwargs) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index cf18cacbe535..849ed04a22ac 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -88,6 +88,7 @@ is_hpu_available, is_inflect_available, is_invisible_watermark_available, + is_auto_round_available, is_kernels_available, is_kernels_version, is_kornia_available, diff --git a/src/diffusers/utils/dummy_auto_round_objects.py b/src/diffusers/utils/dummy_auto_round_objects.py new file mode 100644 index 000000000000..be7a6b8403cb --- /dev/null +++ b/src/diffusers/utils/dummy_auto_round_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. 
+from ..utils import DummyObject, requires_backends + + +class AutoRoundConfig(metaclass=DummyObject): + _backends = ["auto_round"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["auto_round"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["auto_round"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["auto_round"]) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 64e3e54887f5..5bcee256b918 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -230,6 +230,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> tuple[b _aiter_available, _aiter_version = _is_package_available("aiter", get_dist_name=True) _kornia_available, _kornia_version = _is_package_available("kornia") _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True) +_auto_round_available, _auto_round_version = _is_package_available("auto_round") _flashpack_available, _flashpack_version = _is_package_available("flashpack") _av_available, _av_version = _is_package_available("av") @@ -378,6 +379,10 @@ def is_nvidia_modelopt_available(): return _nvidia_modelopt_available +def is_auto_round_available(): + return _auto_round_available + + def is_timm_available(): return _timm_available diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 619a37034949..2ea9d039e4f0 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -33,6 +33,7 @@ from .import_utils import ( BACKENDS_MAPPING, is_accelerate_available, + is_auto_round_available, is_bitsandbytes_available, is_compel_available, is_flax_available, @@ -654,6 +655,19 @@ def decorator(test_case): return decorator +def require_auto_round_version_greater_or_equal(auto_round_version): + def decorator(test_case): + correct_auto_round_version = is_auto_round_available() and version.parse( + version.parse(importlib.metadata.version("auto_round")).base_version + ) >= version.parse(auto_round_version) + return unittest.skipUnless( + correct_auto_round_version, + f"Test requires auto-round with version greater than {auto_round_version}.", + )(test_case) + + return decorator + + def require_kernels_version_greater_or_equal(kernels_version): def decorator(test_case): correct_kernels_version = is_kernels_available() and version.parse( diff --git a/tests/quantization/auto_round/__init__.py b/tests/quantization/auto_round/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/quantization/auto_round/test_autoround.py b/tests/quantization/auto_round/test_autoround.py new file mode 100644 index 000000000000..7b53cd0fdaef --- /dev/null +++ b/tests/quantization/auto_round/test_autoround.py @@ -0,0 +1,463 @@ +import gc +import tempfile +import unittest +import warnings + +from diffusers import AutoRoundConfig, ZImageTransformer2DModel, ZImagePipeline +from diffusers.quantizers.auto import DiffusersAutoQuantizer +from diffusers.quantizers.quantization_config import QuantizationMethod +from diffusers.utils import is_auto_round_available, is_torch_available +from diffusers.utils.testing_utils import ( + backend_empty_cache, + backend_reset_peak_memory_stats, + enable_full_determinism, + nightly, + numpy_cosine_similarity_distance, + require_accelerate, + require_big_accelerator, + require_auto_round_version_greater_or_equal, + 
require_torch_cuda_compatibility, + torch_device, +) + + +if is_torch_available(): + import torch + + from ..utils import get_memory_consumption_stat + + +def _is_gptqmodel_available(min_version="5.8.0"): + """Check if gptqmodel is installed with a minimum version.""" + try: + import importlib.metadata + + from packaging import version + + gptqmodel_version = importlib.metadata.version("gptqmodel") + return version.parse(gptqmodel_version) >= version.parse(min_version) + except importlib.metadata.PackageNotFoundError: + return False + + +enable_full_determinism() + + +@nightly +@require_big_accelerator +@require_accelerate +@require_auto_round_version_greater_or_equal("0.13.0") +class AutoRoundBaseTesterMixin: + """Base test mixin for AutoRound quantized models. + + AutoRound is a weight-only quantization method (W4A16). It supports multiple inference + backends depending on the hardware: + - CPU: `auto_round:torch_zp` backend + - CUDA: `auto_round:tritonv2_zp` backend + - CUDA + GPTQModel>=5.8.0: `gptqmodel:marlin_zp` backend (best performance) + + When `backend="auto"`, AutoRound selects the best available backend automatically. + + Key differences from ModelOpt tests: + - Only pre-quantized model loading is supported (no on-the-fly quantization). + - `is_trainable` returns False, so no LoRA training test. + - No `test_dtype_assignment` (AutoRound doesn't restrict dtype changes). + - `requires_calibration = True` means we always load pre-quantized checkpoints. + """ + + # TODO: Replace with a real tiny AutoRound-quantized checkpoint on the Hub. + # This should be a small model that has been quantized with AutoRound and uploaded + # in the standard format (qweight, scales, qzeros, g_idx). + model_id = "INCModel/Z-Image-tiny-for-testing-W4A16-AutoRound" + model_cls = ZImageTransformer2DModel + pipeline_cls = ZImagePipeline + torch_dtype = torch.bfloat16 + expected_memory_reduction = 0.0 + _test_torch_compile = False + + def setUp(self): + backend_reset_peak_memory_stats(torch_device) + backend_empty_cache(torch_device) + gc.collect() + + def tearDown(self): + backend_reset_peak_memory_stats(torch_device) + backend_empty_cache(torch_device) + gc.collect() + + def get_dummy_init_kwargs(self): + """Returns the default AutoRoundConfig kwargs for W4A16 quantization. + + Subclasses override this to specify backend, group_size, sym, etc. + """ + return { + "bits": 4, + "group_size": 128, + "sym": False, + } + + def get_dummy_model_init_kwargs(self): + """Returns kwargs for model_cls.from_pretrained() with AutoRound quantization.""" + return { + "pretrained_model_name_or_path": self.model_id, + "torch_dtype": self.torch_dtype, + "quantization_config": AutoRoundConfig(**self.get_dummy_init_kwargs()), + "subfolder": "transformer", + } + + def get_dummy_inputs(self): + """Creates dummy inputs matching ZImageTransformer2DModel.forward() signature. + + ZImageTransformer2DModel expects: + - x: list of (C, F, H, W) tensors, one per batch item + - t: 1-D timestep tensor of shape (batch_size,) + - cap_feats: list of (seq_len, cap_feat_dim) tensors, one per batch item + + Dimensions are chosen to match the tiny test checkpoint + (in_channels=16, cap_feat_dim=512, patch_size=2, f_patch_size=1). 
+ """ + batch_size = 1 + in_channels = 16 # matches tiny model config + cap_feat_dim = 512 # matches tiny model config + height = width = 8 # must be divisible by patch_size=2 + frames = 1 # must be divisible by f_patch_size=1 + seq_len = 16 # caption token count (will be padded to multiple of 32) + + torch.manual_seed(0) + x = [ + torch.randn((in_channels, frames, height, width)).to(torch_device, dtype=self.torch_dtype) + for _ in range(batch_size) + ] + cap_feats = [ + torch.randn((seq_len, cap_feat_dim)).to(torch_device, dtype=self.torch_dtype) + for _ in range(batch_size) + ] + t = torch.tensor([0.5] * batch_size).to(torch_device, dtype=self.torch_dtype) + + return {"x": x, "cap_feats": cap_feats, "t": t} + + def test_autoround_memory_usage(self): + """Compare peak memory between unquantized and AutoRound-quantized model. + + The quantized model should use significantly less memory due to 4-bit weight packing. + `expected_memory_reduction` defines the minimum ratio (unquantized / quantized). + """ + inputs = self.get_dummy_inputs() + # x and cap_feats are lists of tensors; move each element individually. + inputs = { + k: [t.to(device=torch_device, dtype=self.torch_dtype) for t in v] + if isinstance(v, list) + else v.to(device=torch_device, dtype=self.torch_dtype) + for k, v in inputs.items() + if not isinstance(v, bool) + } + + unquantized_model = self.model_cls.from_pretrained( + self.model_id, torch_dtype=self.torch_dtype, subfolder="transformer" + ) + unquantized_model.to(torch_device) + unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs) + + quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) + quantized_model.to(torch_device) + quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs) + + assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_reduction + + + def test_serialization(self): + """Test round-trip save and load of an AutoRound quantized model.""" + model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) + inputs = self.get_dummy_inputs() + + model.to(torch_device) + with torch.no_grad(): + model_output = model(**inputs) + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + saved_model = self.model_cls.from_pretrained( + tmp_dir, + torch_dtype=self.torch_dtype, + ) + + saved_model.to(torch_device) + with torch.no_grad(): + saved_model_output = saved_model(**inputs) + + # model_output.sample is a list of per-item tensors + for out, saved_out in zip(model_output.sample, saved_model_output.sample): + assert torch.allclose(out, saved_out, rtol=1e-5, atol=1e-5) + + def test_torch_compile(self): + """Test that the quantized model works with torch.compile.""" + if not self._test_torch_compile: + return + + model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) + compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True, dynamic=False) + + model.to(torch_device) + with torch.no_grad(): + model_output = model(**self.get_dummy_inputs()).sample + + compiled_model.to(torch_device) + with torch.no_grad(): + compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample + + # model_output is a list of per-item tensors; stack for comparison + model_output = torch.stack([o.detach().float().cpu() for o in model_output]).numpy() + compiled_model_output = torch.stack([o.detach().float().cpu() for o in compiled_model_output]).numpy() + + max_diff = 
numpy_cosine_similarity_distance(model_output.flatten(), compiled_model_output.flatten()) + assert max_diff < 1e-3 + + def test_model_cpu_offload(self): + """Test that the quantized model works with pipeline CPU offload.""" + init_kwargs = self.get_dummy_init_kwargs() + transformer = self.model_cls.from_pretrained( + self.model_id, + quantization_config=AutoRoundConfig(**init_kwargs), + subfolder="transformer", + torch_dtype=self.torch_dtype, + ) + pipe = self.pipeline_cls.from_pretrained(self.model_id, transformer=transformer, torch_dtype=self.torch_dtype) + pipe.enable_model_cpu_offload(device=torch_device) + _ = pipe("a cat holding a sign that says hello", num_inference_steps=2) + + +# ============================================================================ +# Backend: auto (auto-select best available backend) +# ============================================================================ + + +class AutoRoundW4G128AsymAutoBackendTest(AutoRoundBaseTesterMixin, unittest.TestCase): + """W4A16, group_size=128, asymmetric, backend='auto' (default — auto-selects best backend).""" + + expected_memory_reduction = 0.55 + + def get_dummy_init_kwargs(self): + return { + "bits": 4, + "group_size": 128, + "sym": False, + "backend": "auto", + } + + +class AutoRoundW4G128SymAutoBackendTest(AutoRoundBaseTesterMixin, unittest.TestCase): + """W4A16, group_size=128, symmetric, backend='auto'.""" + + expected_memory_reduction = 0.55 + + def get_dummy_init_kwargs(self): + return { + "bits": 4, + "group_size": 128, + "sym": True, + "backend": "auto", + } + + +class AutoRoundW4G32AsymAutoBackendTest(AutoRoundBaseTesterMixin, unittest.TestCase): + """W4A16, group_size=32, asymmetric, backend='auto' (finer granularity).""" + + expected_memory_reduction = 0.50 + + def get_dummy_init_kwargs(self): + return { + "bits": 4, + "group_size": 32, + "sym": False, + "backend": "auto", + } + + +# ============================================================================ +# Backend: auto_round:tritonv2_zp (CUDA, Triton-based kernel) +# ============================================================================ + + +@require_torch_cuda_compatibility(7.0) +class AutoRoundW4G128AsymTritonTest(AutoRoundBaseTesterMixin, unittest.TestCase): + """W4A16, group_size=128, asymmetric, backend='auto_round:tritonv2_zp' (CUDA Triton kernel).""" + + expected_memory_reduction = 0.55 + + def get_dummy_init_kwargs(self): + return { + "bits": 4, + "group_size": 128, + "sym": False, + "backend": "auto_round:tritonv2_zp", + } + + +@require_torch_cuda_compatibility(7.0) +class AutoRoundW4G128SymTritonTest(AutoRoundBaseTesterMixin, unittest.TestCase): + """W4A16, group_size=128, symmetric, backend='auto_round:tritonv2_zp'.""" + + expected_memory_reduction = 0.55 + + def get_dummy_init_kwargs(self): + return { + "bits": 4, + "group_size": 128, + "sym": True, + "backend": "auto_round:tritonv2_zp", + } + + +# ============================================================================ +# Backend: gptqmodel:marlin_zp (CUDA, requires GPTQModel>=5.8.0, best perf) +# ============================================================================ + + +@unittest.skipUnless(_is_gptqmodel_available("5.8.0"), "Test requires gptqmodel>=5.8.0") +@require_torch_cuda_compatibility(8.0) +class AutoRoundW4G128AsymMarlinTest(AutoRoundBaseTesterMixin, unittest.TestCase): + """W4A16, group_size=128, asymmetric, backend='gptqmodel:marlin_zp' (best CUDA performance).""" + + _test_torch_compile = True + expected_memory_reduction = 0.55 + + def 
get_dummy_init_kwargs(self): + return { + "bits": 4, + "group_size": 128, + "sym": False, + "backend": "gptqmodel:marlin_zp", + } + + +@unittest.skipUnless(_is_gptqmodel_available("5.8.0"), "Test requires gptqmodel>=5.8.0") +@require_torch_cuda_compatibility(8.0) +class AutoRoundW4G128SymMarlinTest(AutoRoundBaseTesterMixin, unittest.TestCase): + """W4A16, group_size=128, symmetric, backend='gptqmodel:marlin_zp'.""" + + _test_torch_compile = True + expected_memory_reduction = 0.55 + + def get_dummy_init_kwargs(self): + return { + "bits": 4, + "group_size": 128, + "sym": True, + "backend": "gptqmodel:marlin_zp", + } + + +# ============================================================================ +# Backend: auto_round:torch_zp (CPU, pure PyTorch kernel) +# ============================================================================ + + +class AutoRoundW4G128AsymTorchCPUTest(AutoRoundBaseTesterMixin, unittest.TestCase): + """W4A16, group_size=128, asymmetric, backend='auto_round:torch_zp' (CPU).""" + + expected_memory_reduction = 0.50 + + def get_dummy_init_kwargs(self): + return { + "bits": 4, + "group_size": 128, + "sym": False, + "backend": "auto_round:torch_zp", + } + + +# ============================================================================ +# Unit tests: AutoRoundConfig (no hardware required) +# ============================================================================ + + +class AutoRoundConfigTest(unittest.TestCase): + """Unit tests for AutoRoundConfig — no GPU / nightly decorator needed.""" + + def test_defaults(self): + cfg = AutoRoundConfig() + self.assertEqual(cfg.bits, 4) + self.assertEqual(cfg.group_size, 128) + self.assertTrue(cfg.sym) + self.assertEqual(cfg.backend, "auto") + self.assertEqual(cfg.quant_method, QuantizationMethod.AUTOROUND) + + def test_backend_values(self): + """All documented backend strings are stored correctly.""" + for backend in ("auto", "torch", "tritonv2", "marlin", "exllamav2"): + self.assertEqual(AutoRoundConfig(backend=backend).backend, backend) + + def test_to_dict_round_trip(self): + """to_dict → from_dict preserves all fields including backend and extra kwargs.""" + cfg = AutoRoundConfig(bits=4, group_size=32, sym=False, backend="gptqmodel:marlin_zp", + packing_format="auto_round:auto_gptq") + restored = AutoRoundConfig.from_dict(cfg.to_dict()) + self.assertEqual(restored.bits, cfg.bits) + self.assertEqual(restored.group_size, cfg.group_size) + self.assertEqual(restored.sym, cfg.sym) + self.assertEqual(restored.backend, cfg.backend) + self.assertEqual(restored.packing_format, cfg.packing_format) + self.assertEqual(restored.to_dict()["quant_method"], "auto-round") + + +# ============================================================================ +# Unit tests: DiffusersAutoQuantizer.merge_quantization_configs (no hardware) +# ============================================================================ + + +class MergeQuantizationConfigsTest(unittest.TestCase): + """Tests for the merge logic in DiffusersAutoQuantizer.merge_quantization_configs. + + Key behaviours under test: + 1. New fields in quantization_config_from_args (e.g. `backend`) are forwarded to + the merged config when they are absent from the model's saved config. + 2. Fields already present in the model's saved config are NOT overridden. + 3. A warning is emitted when quantization_config_from_args is provided. + 4. No warning when quantization_config_from_args is None. 
+ """ + + def _model_config_dict(self, **overrides): + """Simulate a minimal saved AutoRound quantization_config dict (no 'backend' key).""" + base = { + "quant_method": "auto-round", + "bits": 4, + "group_size": 128, + "sym": True, + "autoround_version": "0.13.0", + "packing_format": "auto_round:auto_gptq", + } + base.update(overrides) + return base + + def test_new_fields_from_args_are_forwarded(self): + """Fields absent from the model config (backend, or arbitrary kwargs) are added from args.""" + for backend in ("marlin", "triton", "torch", "exllamav2"): # tritonv2 equals triton + with self.subTest(backend=backend): + merged = DiffusersAutoQuantizer.merge_quantization_configs( + self._model_config_dict(), AutoRoundConfig(backend=backend) + ) + self.assertEqual(merged.backend, backend) + + def test_existing_fields_not_overridden(self): + """Fields already in model config are NOT overridden; absent fields (backend) ARE added.""" + args_cfg = AutoRoundConfig(bits=2, group_size=32, sym=False, backend="torch") + merged = DiffusersAutoQuantizer.merge_quantization_configs(self._model_config_dict(), args_cfg) + + self.assertEqual(merged.bits, 4) # model value kept + self.assertEqual(merged.group_size, 128) # model value kept + self.assertTrue(merged.sym) # model value kept + self.assertEqual(merged.backend, "torch") # new field added + + def test_warning_behaviour(self): + """Warning emitted with args; no warning without args.""" + model_cfg = self._model_config_dict() + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + DiffusersAutoQuantizer.merge_quantization_configs(model_cfg, AutoRoundConfig(backend="auto")) + self.assertTrue(any("quantization_config" in str(x.message) for x in w)) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + DiffusersAutoQuantizer.merge_quantization_configs(model_cfg, None) + self.assertFalse(any("quantization_config" in str(x.message).lower() for x in w))