146 changes: 146 additions & 0 deletions docs/source/en/quantization/autoround.md
@@ -0,0 +1,146 @@
<!-- Copyright 2026 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoRound

[AutoRound](https://github.com/intel/auto-round) is an advanced quantization toolkit. It achieves high accuracy at ultra-low bit widths (2-4 bits) with minimal tuning by leveraging sign-gradient descent and providing broad hardware compatibility. See our papers [SignRoundV1](https://arxiv.org/pdf/2309.05516) and [SignRoundV2](https://arxiv.org/abs/2512.04746) for more details.


Install `auto-round` (version ≥ 0.13.0):

```bash
pip install "auto-round>=0.13.0"
```

To use the Marlin kernel for faster CUDA inference, install `gptqmodel`:

```bash
pip install "gptqmodel>=5.8.0"
```

## Load a quantized model

Load a pre-quantized AutoRound model by passing [`AutoRoundConfig`] to [`~ModelMixin.from_pretrained`]. The method works with any model that loads via [Accelerate](https://hf.co/docs/accelerate/index) and has `torch.nn.Linear` layers.

```python
import torch
from diffusers import ZImageTransformer2DModel, ZImagePipeline, AutoRoundConfig

model_id = "INCModel/Z-Image-W4A16-AutoRound"

quantization_config = AutoRoundConfig(backend="marlin")
transformer = ZImageTransformer2DModel.from_pretrained(
model_id,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
device_map="cuda",
)

pipe = ZImagePipeline.from_pretrained(
model_id,
transformer=transformer,
torch_dtype=torch.bfloat16,
device_map="cuda",
)

image = pipe("a cat holding a sign that says hello").images[0]
image.save("output.png")
```

> [!NOTE]
> AutoRound in Diffusers only supports loading *pre-quantized* models. To quantize a model from scratch, use the [AutoRound CLI or Python API](https://github.com/intel/auto-round) directly, then load the result with Diffusers.

## Backends

AutoRound supports multiple inference backends. The backend controls which kernel handles dequantization during the forward pass. Set the `backend` parameter in [`AutoRoundConfig`] to choose one:

| Backend | Value | Device | Requirements | Notes |
|---------|-------|--------|--------------|-------|
| **Auto** | `"auto"` | Any | — | Default. Automatically selects the best available backend. |
| **PyTorch** | `"torch"` | CPU / CUDA | — | Pure PyTorch implementation. Broadest compatibility. |
| **Triton** | `"tritonv2"` | CUDA | `triton` | Triton-based kernel for GPU inference. |
| **ExllamaV2** | `"exllamav2"` | CUDA | `gptqmodel>=5.8.0` | Good CUDA performance via the ExllamaV2 kernel. |
| **Marlin** | `"marlin"` | CUDA | `gptqmodel>=5.8.0` | Best CUDA performance via the Marlin kernel. |


```python
from diffusers import AutoRoundConfig

# Auto-select (default)
config = AutoRoundConfig()

# Explicit Triton backend for CUDA
config = AutoRoundConfig(backend="tritonv2")

# Marlin backend for best CUDA performance (requires gptqmodel>=5.8.0)
config = AutoRoundConfig(backend="marlin")

# ExllamaV2 backend for good CUDA performance (requires gptqmodel>=5.8.0)
config = AutoRoundConfig(backend="exllamav2")

# PyTorch backend for CPU/CUDA inference
config = AutoRoundConfig(backend="torch")
```


## Quantization configurations

AutoRound focuses on weight-only quantization. The primary configuration is W4A16 (4-bit weights, 16-bit activations), with flexibility in group size and symmetry:

| Configuration | `bits` | `group_size` | `sym` | Description |
|--------------|--------|-------------|-------|-------------|
| W4G128 asymmetric | `4` | `128` | `False` | Default. Good balance of accuracy and compression. |
| W4G128 symmetric | `4` | `128` | `True` | Faster dequantization, small accuracy trade-off. |
| W4G32 asymmetric | `4` | `32` | `False` | Higher accuracy at the cost of more metadata. |
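
The trade-offs in the table can be seen in a small NumPy sketch of group-wise weight-only quantization (illustrative only; the real kernels store packed integer tensors such as `qweight`, `scales`, and `qzeros`):

```python
import numpy as np

def quantize_groupwise(w, bits=4, group_size=128, sym=False):
    """Quantize a weight matrix with one scale (and zero point) per group."""
    qmax = 2**bits - 1
    g = w.reshape(-1, group_size)  # one row per group
    if sym:
        # symmetric: a single scale per group, fixed mid-grid zero point
        scale = np.abs(g).max(axis=1, keepdims=True) / (qmax / 2)
        zero = np.full((g.shape[0], 1), (qmax + 1) / 2)
    else:
        # asymmetric: scale and zero point both follow the group's min/max
        wmin = g.min(axis=1, keepdims=True)
        wmax = g.max(axis=1, keepdims=True)
        scale = (wmax - wmin) / qmax
        zero = np.round(-wmin / scale)
    q = np.clip(np.round(g / scale + zero), 0, qmax)
    return q, scale, zero

def dequantize(q, scale, zero, shape):
    return ((q - zero) * scale).reshape(shape)

w = np.random.default_rng(0).normal(size=(256, 512)).astype(np.float32)
q, s, z = quantize_groupwise(w, group_size=128, sym=False)
w_hat = dequantize(q, s, z, w.shape)
print(np.abs(w - w_hat).max())  # worst case is about half the largest group scale
```

A smaller `group_size` means more groups and therefore more scale/zero-point metadata, but each group's scale fits its weights more tightly, which is why W4G32 trades storage for accuracy.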

## Save and load

<hfoptions id="save-and-load">
<hfoption id="save">

```python
from auto_round import AutoRound
autoround = AutoRound(
    model_path,  # path to the original, unquantized model checkpoint
    num_inference_steps=3,
    guidance_scale=7.5,
    dataset="coco2014",
)
autoround.quantize_and_save("Z-Image-W4A16-AutoRound")
```

</hfoption>
<hfoption id="load">

```python
import torch
from diffusers import ZImageTransformer2DModel, ZImagePipeline

model_id = "INCModel/Z-Image-W4A16-AutoRound"

# The inference backend will be automatically selected.
pipe = ZImagePipeline.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
device_map="cuda",
)

image = pipe("a cat holding a sign that says hello").images[0]
image.save("output.png")
```

</hfoption>
</hfoptions>

## Resources

- [Pre-quantized AutoRound models on the Hub](https://huggingface.co/models?search=autoround)
21 changes: 21 additions & 0 deletions src/diffusers/__init__.py
@@ -7,6 +7,7 @@
OptionalDependencyNotAvailable,
_LazyModule,
is_accelerate_available,
is_auto_round_available,
is_bitsandbytes_available,
is_flax_available,
is_gguf_available,
@@ -122,6 +123,18 @@
else:
_import_structure["quantizers.quantization_config"].append("NVIDIAModelOptConfig")

try:
if not is_auto_round_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_auto_round_objects

_import_structure["utils.dummy_auto_round_objects"] = [
name for name in dir(dummy_auto_round_objects) if not name.startswith("_")
]
else:
_import_structure["quantizers.quantization_config"].append("AutoRoundConfig")

try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -945,6 +958,14 @@
else:
from .quantizers.quantization_config import NVIDIAModelOptConfig

try:
if not is_auto_round_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_auto_round_objects import *
else:
from .quantizers.quantization_config import AutoRoundConfig

try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
19 changes: 18 additions & 1 deletion src/diffusers/quantizers/auto.py
@@ -18,10 +18,12 @@

import warnings

from .autoround import AutoRoundQuantizer
from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
from .gguf import GGUFQuantizer
from .modelopt import NVIDIAModelOptQuantizer
from .quantization_config import (
AutoRoundConfig,
BitsAndBytesConfig,
GGUFQuantizationConfig,
NVIDIAModelOptConfig,
@@ -41,6 +43,7 @@
"quanto": QuantoQuantizer,
"torchao": TorchAoHfQuantizer,
"modelopt": NVIDIAModelOptQuantizer,
"auto-round": AutoRoundQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -50,6 +53,7 @@
"quanto": QuantoConfig,
"torchao": TorchAoConfig,
"modelopt": NVIDIAModelOptConfig,
"auto-round": AutoRoundConfig,
}


@@ -136,13 +140,26 @@ def merge_quantization_configs(
)
else:
warning_msg = ""

if isinstance(quantization_config, dict):
existing_fields = set(quantization_config.keys())
quantization_config = cls.from_dict(quantization_config)
else:
existing_fields = set(quantization_config.__dict__.keys())

if isinstance(quantization_config, NVIDIAModelOptConfig):
quantization_config.check_model_patching()

if quantization_config_from_args is not None:
# Only override fields that the user explicitly set.
for key, value in quantization_config_from_args.__dict__.items():
if key not in existing_fields:
# Field does not exist in the model's quantization_config, add it.
setattr(quantization_config, key, value)
warning_msg += (
f" Field `{key}` from `quantization_config_from_args` is not present in the model's "
f"`quantization_config`. Adding it with value: {value!r}."
)

if warning_msg != "":
warnings.warn(warning_msg)

1 change: 1 addition & 0 deletions src/diffusers/quantizers/autoround/__init__.py
@@ -0,0 +1 @@
from .autoround_quantizer import AutoRoundQuantizer
128 changes: 128 additions & 0 deletions src/diffusers/quantizers/autoround/autoround_quantizer.py
@@ -0,0 +1,128 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...utils import (
is_auto_round_available,
is_torch_available,
logging,
)
from ..base import DiffusersQuantizer


if TYPE_CHECKING:
from ...models.modeling_utils import ModelMixin


if is_torch_available():
import torch

logger = logging.get_logger(__name__)


class AutoRoundQuantizer(DiffusersQuantizer):
r"""
Diffusers Quantizer for AutoRound (https://github.com/intel/auto-round).

AutoRound is a weight-only quantization method that uses sign gradient descent to jointly optimize
rounding values and min-max ranges for weights. It supports W4A16 (4-bit weight, 16-bit activation)
quantization for efficient inference.

This quantizer only supports loading pre-quantized AutoRound models. On-the-fly quantization
(calibration) is not supported through this interface.
"""

# AutoRound requires data calibration — we only support loading pre-quantized checkpoints.
requires_calibration = True
required_packages = ["auto_round"]

def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)

def validate_environment(self, *args, **kwargs):
"""
Validates that the auto-round library (>= 0.5) is installed and captures the device_map
for later use during model conversion.
"""
self.device_map = kwargs.get("device_map", None)
if not is_auto_round_available():
raise ImportError(
"Loading an AutoRound quantized model requires the auto-round library "
"(`pip install 'auto-round>=0.5'`)"
)

def _process_model_before_weight_loading(
self,
model: "ModelMixin",
device_map,
keep_in_fp32_modules: list[str] = [],
**kwargs,
):
"""
Replaces target nn.Linear layers with AutoRound's quantized QuantLinear layers before
weights are loaded from the checkpoint.

Uses `auto_round.inference.convert_model.convert_hf_model` which:
- Inspects the model architecture and the quantization config (bits, group_size, sym, backend).
- Replaces eligible nn.Linear modules with the appropriate QuantLinear variant
(the packed-weight layer that stores qweight, scales, qzeros).
- Returns the converted model and a set of used backend names.

`infer_target_device` resolves the device_map into a single target device string
that AutoRound uses to select the correct kernel backend (e.g. "cuda", "cpu").
"""
from auto_round.inference.convert_model import convert_hf_model, infer_target_device

if self.pre_quantized:
target_device = infer_target_device(self.device_map)
model, used_backends = convert_hf_model(model, target_device)
self.used_backends = used_backends

def _process_model_after_weight_loading(self, model, **kwargs):
"""
Finalizes the model after all quantized weights (qweight, scales, qzeros, etc.) have
been loaded into the QuantLinear layers.

Uses `auto_round.inference.convert_model.post_init` which:
- Performs backend-specific finalization (e.g. repacking weights into the kernel's
expected memory layout, moving buffers to the correct device).
- Freezes quantized parameters (requires_grad=False).
- Prepares the model for inference.

Raises ValueError if the model is not pre-quantized, since AutoRound does not support
on-the-fly quantization through this loading path.
"""
if self.pre_quantized:
from auto_round.inference.convert_model import post_init

post_init(model, self.used_backends)
else:
raise ValueError(
"AutoRound quantizer in diffusers only supports pre-quantized models. "
"Please provide a model that has already been quantized with AutoRound."
)
return model

@property
def is_trainable(self) -> bool:
"""AutoRound W4A16 pre-quantized models do not support training."""
return False

@property
def is_serializable(self):
"""AutoRound quantized models can be serialized (the quantization config may be
updated by the backend, e.g. for GPTQ/AWQ-compatible formats)."""
return True
