Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/diffusers/quantizers/torchao/torchao_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,3 +376,17 @@ def is_trainable(self):
@property
def is_compileable(self) -> bool:
    """Whether models quantized by this quantizer support ``torch.compile``; always True here."""
    return True

def _dequantize(self, model):
    """Replace torchao-quantized Linear weights in *model* with plain dequantized Parameters.

    Walks every module in *model*; for each ``nn.Linear`` whose weight is a torchao
    ``TorchAOBaseTensor``, the weight is dequantized on its current device and rebound
    as a regular ``nn.Parameter``. Any ``extra_repr`` override that torchao installed
    on the module is reset to the stock ``nn.Linear.extra_repr``.

    Args:
        model: The quantized model to dequantize in place.

    Returns:
        The same model instance, mutated in place.
    """
    # Imported lazily so this quantizer module stays importable without torchao installed.
    from torchao.utils import TorchAOBaseTensor

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and isinstance(module.weight, TorchAOBaseTensor):
            device = module.weight.device
            dequantized_weight = module.weight.dequantize().to(device)
            module.weight = nn.Parameter(dequantized_weight)
            # Reset extra_repr if it was overridden (torchao patches it per-instance).
            if hasattr(module.extra_repr, "__func__") and module.extra_repr.__func__ is not nn.Linear.extra_repr:
                module.extra_repr = types.MethodType(nn.Linear.extra_repr, module)

    return model
74 changes: 58 additions & 16 deletions tests/models/testing_utils/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ def _is_module_quantized(self, module):
except (AssertionError, AttributeError):
return False

def _get_dummy_inputs_for_model(self, model):
inputs = self.get_dummy_inputs()
model_dtype = next(model.parameters()).dtype
return {
k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
for k, v in inputs.items()
}

def _load_unquantized_model(self):
kwargs = getattr(self, "pretrained_model_kwargs", {})
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
Expand Down Expand Up @@ -174,7 +182,7 @@ def _test_quantization_inference(self, config_kwargs):
model_quantized = self._create_quantized_model(config_kwargs)
model_quantized.to(torch_device)

inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model_quantized)
output = model_quantized(**inputs, return_dict=False)[0]

assert output is not None, "Model output is None"
Expand Down Expand Up @@ -222,7 +230,8 @@ def _test_quantization_lora_inference(self, config_kwargs):
# Move LoRA adapter weights to device (they default to CPU)
model.to(torch_device)

inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model)

output = model(**inputs, return_dict=False)[0]

assert output is not None, "Model output is None with LoRA"
Expand All @@ -236,7 +245,8 @@ def _test_quantization_serialization(self, config_kwargs, tmp_path):

model_loaded = self.model_class.from_pretrained(str(tmp_path))

inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model_loaded)

output = model_loaded(**inputs, return_dict=False)[0]
assert not torch.isnan(output).any(), "Loaded model output contains NaN"

Expand Down Expand Up @@ -334,7 +344,8 @@ def _test_quantization_device_map(self, config_kwargs):
assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute"
assert model.hf_device_map is not None, "hf_device_map should not be None"

inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model)

output = model(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None"
assert not torch.isnan(output).any(), "Model output contains NaN"
Expand All @@ -359,7 +370,12 @@ def _test_dequantize(self, config_kwargs):
if isinstance(module, torch.nn.Linear):
assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()"

<<<<<<< torchao
# Get model dtype from first parameter
inputs = self._get_dummy_inputs_for_model(model)
=======
inputs = self.get_dummy_inputs()
>>>>>>> main
output = model(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None after dequantization"
assert not torch.isnan(output).any(), "Model output contains NaN after dequantization"
Expand Down Expand Up @@ -405,9 +421,10 @@ def _test_quantization_training(self, config_kwargs):
pytest.skip("No attention layers found in model for adapter training test")

# Step 3: run forward and backward pass
inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model)

with torch.amp.autocast(torch_device, dtype=torch.float16):
# Use bfloat16 instead of float16 to avoid gradient underflow with quantized layers
with torch.amp.autocast(torch_device, dtype=torch.bfloat16):
out = model(**inputs, return_dict=False)[0]
out.norm().backward()

Expand Down Expand Up @@ -587,8 +604,17 @@ def test_bnb_keep_modules_in_fp32(self):
f"Module {name} should be uint8 but is {module.weight.dtype}"
)

<<<<<<< torchao
inputs = self._get_dummy_inputs_for_model(model)

_ = model(**inputs)
finally:
if original_fp32_modules is not None:
self.model_class._keep_in_fp32_modules = original_fp32_modules
=======
inputs = self.get_dummy_inputs()
_ = model(**inputs)
>>>>>>> main

def test_bnb_modules_to_not_convert(self):
"""Test that modules_to_not_convert parameter works correctly."""
Expand Down Expand Up @@ -805,6 +831,10 @@ class TorchAoConfigMixin:
@staticmethod
def _get_quant_config(config_name):
    """Build a ``TorchAoConfig`` wrapping the named torchao quantization config class."""
    quant_cls = getattr(_torchao_quantization, config_name)
    # TorchAO int4 quantization requires plain_int32 packing format on Intel XPU.
    needs_xpu_packing = config_name == "Int4WeightOnlyConfig" and torch_device == "xpu"
    if needs_xpu_packing:
        return TorchAoConfig(quant_cls(int4_packing_format="plain_int32"))
    return TorchAoConfig(quant_cls())

def _create_quantized_model(self, config_name, **extra_kwargs):
Expand All @@ -816,11 +846,12 @@ def _create_quantized_model(self, config_name, **extra_kwargs):
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)

def _verify_if_layer_quantized(self, name, module, config_kwargs):
    """Assert that *module* is an ``nn.Linear`` whose weight is a torchao quantized tensor.

    Args:
        name: Dotted module name, used only in assertion messages.
        module: The layer to check.
        config_kwargs: Quantization config kwargs (unused here; kept for a uniform hook signature).

    Raises:
        AssertionError: If the layer is not Linear or its weight is not a ``TorchAOBaseTensor``.
    """
    from torchao.utils import TorchAOBaseTensor

    assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
    assert isinstance(module.weight, TorchAOBaseTensor), (
        f"Layer {name} weight is {type(module.weight)}, expected TorchAOBaseTensor"
    )


@is_torchao
Expand Down Expand Up @@ -848,7 +879,7 @@ class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):
@pytest.mark.parametrize(
"quant_type",
[
pytest.param("int4wo", marks=_int4wo_skip),
"int4wo",
"int8wo",
"int8dq",
],
Expand All @@ -860,7 +891,7 @@ def test_torchao_quantization_num_parameters(self, quant_type):
@pytest.mark.parametrize(
"quant_type",
[
pytest.param("int4wo", marks=_int4wo_skip),
"int4wo",
"int8wo",
"int8dq",
],
Expand All @@ -875,7 +906,7 @@ def test_torchao_quantization_memory_footprint(self, quant_type):
@pytest.mark.parametrize(
"quant_type",
[
pytest.param("int4wo", marks=_int4wo_skip),
"int4wo",
"int8wo",
"int8dq",
],
Expand All @@ -902,7 +933,8 @@ def test_torchao_quantization_serialization(self, quant_type, tmp_path):

model_loaded = self.model_class.from_pretrained(str(tmp_path), device_map=str(torch_device))

inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model_loaded)

output = model_loaded(**inputs, return_dict=False)[0]
assert not torch.isnan(output).any(), "Loaded model output contains NaN"

Expand Down Expand Up @@ -1159,6 +1191,14 @@ class QuantizationCompileTesterMixin:
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
"""

def _get_dummy_inputs_for_model(self, model):
inputs = self.get_dummy_inputs()
model_dtype = next(model.parameters()).dtype
return {
k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
for k, v in inputs.items()
}
Comment on lines +1194 to +1200
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to override?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QuantizationCompileTesterMixin is an independent mixin that doesn't inherit from QuantizationTesterMixin. Test classes may use either one or both, so the method needs to be defined in both places.

Alternatively, I can extract it into a shared base class or a standalone utility function to avoid code duplication. Let me know which approach you prefer. Please review this change in #13539


def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
Expand All @@ -1184,7 +1224,8 @@ def _test_torch_compile(self, config_kwargs):
model = torch.compile(model, fullgraph=True)

with torch._dynamo.config.patch(error_on_recompile=True):
inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model)

output = model(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None"
assert not torch.isnan(output).any(), "Model output contains NaN"
Expand Down Expand Up @@ -1215,7 +1256,8 @@ def _test_torch_compile_with_group_offload(self, config_kwargs, use_stream=False
model.enable_group_offload(**group_offload_kwargs)
model = torch.compile(model)

inputs = self.get_dummy_inputs()
inputs = self._get_dummy_inputs_for_model(model)

output = model(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None"
assert not torch.isnan(output).any(), "Model output contains NaN"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def get_dummy_inputs(self):
"""Override to provide inputs matching the tiny Wan Animate model dimensions."""
return {
"hidden_states": randn_tensor(
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
(1, 36, 5, 16, 16), generator=self.generator, device=torch_device, dtype=self.torch_dtype
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explain the changes.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's for avoiding OOM, details see: #13541. Please let me know if you want comments in the code.

),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
Expand All @@ -233,10 +233,10 @@ def get_dummy_inputs(self):
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"pose_hidden_states": randn_tensor(
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
(1, 16, 4, 16, 16), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"face_pixel_values": randn_tensor(
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
(1, 3, 13, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
Expand Down
Loading