diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py
index 3a20dca88ecf..59387e41654e 100644
--- a/src/diffusers/quantizers/torchao/torchao_quantizer.py
+++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py
@@ -376,3 +376,17 @@ def is_trainable(self):
     @property
     def is_compileable(self) -> bool:
         return True
+
+    def _dequantize(self, model):
+        from torchao.utils import TorchAOBaseTensor
+
+        for name, module in model.named_modules():
+            if isinstance(module, nn.Linear) and isinstance(module.weight, TorchAOBaseTensor):
+                device = module.weight.device
+                dequantized_weight = module.weight.dequantize().to(device)
+                module.weight = nn.Parameter(dequantized_weight)
+                # Reset extra_repr if it was overridden
+                if hasattr(module.extra_repr, "__func__") and module.extra_repr.__func__ is not nn.Linear.extra_repr:
+                    module.extra_repr = types.MethodType(nn.Linear.extra_repr, module)
+
+        return model
diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py
index 1aab0b240148..dd94ccf4f324 100644
--- a/tests/models/testing_utils/quantization.py
+++ b/tests/models/testing_utils/quantization.py
@@ -142,6 +142,14 @@ def _is_module_quantized(self, module):
         except (AssertionError, AttributeError):
             return False
 
+    def _get_dummy_inputs_for_model(self, model):
+        inputs = self.get_dummy_inputs()
+        model_dtype = next(model.parameters()).dtype
+        return {
+            k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
+            for k, v in inputs.items()
+        }
+
     def _load_unquantized_model(self):
         kwargs = getattr(self, "pretrained_model_kwargs", {})
         return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
@@ -174,7 +182,7 @@ def _test_quantization_inference(self, config_kwargs):
         model_quantized = self._create_quantized_model(config_kwargs)
         model_quantized.to(torch_device)
 
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model_quantized)
         output = model_quantized(**inputs, return_dict=False)[0]
 
         assert output is not None, "Model output is None"
@@ -222,7 +230,8 @@ def _test_quantization_lora_inference(self, config_kwargs):
         # Move LoRA adapter weights to device (they default to CPU)
         model.to(torch_device)
 
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model)
+
         output = model(**inputs, return_dict=False)[0]
 
         assert output is not None, "Model output is None with LoRA"
@@ -236,7 +245,8 @@ def _test_quantization_serialization(self, config_kwargs, tmp_path):
         model_loaded = self.model_class.from_pretrained(str(tmp_path))
 
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model_loaded)
+
         output = model_loaded(**inputs, return_dict=False)[0]
 
         assert not torch.isnan(output).any(), "Loaded model output contains NaN"
@@ -334,7 +344,8 @@ def _test_quantization_device_map(self, config_kwargs):
         assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute"
         assert model.hf_device_map is not None, "hf_device_map should not be None"
 
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model)
+
         output = model(**inputs, return_dict=False)[0]
         assert output is not None, "Model output is None"
         assert not torch.isnan(output).any(), "Model output contains NaN"
@@ -359,7 +370,12 @@ def _test_dequantize(self, config_kwargs):
             if isinstance(module, torch.nn.Linear):
                 assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()"
 
-        inputs = self.get_dummy_inputs()
+        # Get model dtype from first parameter
+        inputs = self._get_dummy_inputs_for_model(model)
+
         output = model(**inputs, return_dict=False)[0]
         assert output is not None, "Model output is None after dequantization"
         assert not torch.isnan(output).any(), "Model output contains NaN after dequantization"
@@ -405,9 +421,10 @@ def _test_quantization_training(self, config_kwargs):
             pytest.skip("No attention layers found in model for adapter training test")
 
         # Step 3: run forward and backward pass
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model)
 
-        with torch.amp.autocast(torch_device, dtype=torch.float16):
+        # Use bfloat16 instead of float16 to avoid gradient underflow with quantized layers
+        with torch.amp.autocast(torch_device, dtype=torch.bfloat16):
             out = model(**inputs, return_dict=False)[0]
             out.norm().backward()
@@ -587,8 +604,17 @@ def test_bnb_keep_modules_in_fp32(self):
                     f"Module {name} should be uint8 but is {module.weight.dtype}"
                 )
 
-        inputs = self.get_dummy_inputs()
-        _ = model(**inputs)
+            inputs = self._get_dummy_inputs_for_model(model)
+
+            _ = model(**inputs)
+        finally:
+            if original_fp32_modules is not None:
+                self.model_class._keep_in_fp32_modules = original_fp32_modules
 
     def test_bnb_modules_to_not_convert(self):
         """Test that modules_to_not_convert parameter works correctly."""
@@ -805,6 +831,10 @@ class TorchAoConfigMixin:
     @staticmethod
     def _get_quant_config(config_name):
         config_cls = getattr(_torchao_quantization, config_name)
+        # TorchAO int4 quantization requires plain_int32 packing format on Intel XPU
+        if config_name == "Int4WeightOnlyConfig" and torch_device == "xpu":
+            return TorchAoConfig(config_cls(int4_packing_format="plain_int32"))
+
         return TorchAoConfig(config_cls())
 
     def _create_quantized_model(self, config_name, **extra_kwargs):
@@ -816,11 +846,12 @@ def _create_quantized_model(self, config_name, **extra_kwargs):
         return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
 
     def _verify_if_layer_quantized(self, name, module, config_kwargs):
-        assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
+        from torchao.utils import TorchAOBaseTensor
 
-
-# int4wo requires CUDA-specific ops (_convert_weight_to_int4pack)
-_int4wo_skip = pytest.mark.skipif(torch_device != "cuda", reason="int4wo quantization requires CUDA")
+        assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
+        assert isinstance(module.weight, TorchAOBaseTensor), (
+            f"Layer {name} weight is {type(module.weight)}, expected TorchAOBaseTensor"
+        )
 
 
 @is_torchao
@@ -848,7 +879,7 @@ class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):
     @pytest.mark.parametrize(
         "quant_type",
         [
-            pytest.param("int4wo", marks=_int4wo_skip),
+            "int4wo",
             "int8wo",
             "int8dq",
         ],
@@ -860,7 +891,7 @@ def test_torchao_quantization_num_parameters(self, quant_type):
     @pytest.mark.parametrize(
         "quant_type",
        [
-            pytest.param("int4wo", marks=_int4wo_skip),
+            "int4wo",
             "int8wo",
             "int8dq",
         ],
@@ -875,7 +906,7 @@ def test_torchao_quantization_memory_footprint(self, quant_type):
     @pytest.mark.parametrize(
         "quant_type",
        [
-            pytest.param("int4wo", marks=_int4wo_skip),
+            "int4wo",
             "int8wo",
             "int8dq",
         ],
@@ -902,7 +933,8 @@ def test_torchao_quantization_serialization(self, quant_type, tmp_path):
         model_loaded = self.model_class.from_pretrained(str(tmp_path), device_map=str(torch_device))
 
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model_loaded)
+
         output = model_loaded(**inputs, return_dict=False)[0]
 
         assert not torch.isnan(output).any(), "Loaded model output contains NaN"
@@ -1159,6 +1191,14 @@ class QuantizationCompileTesterMixin:
     - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
     """
 
+    def _get_dummy_inputs_for_model(self, model):
+        inputs = self.get_dummy_inputs()
+        model_dtype = next(model.parameters()).dtype
+        return {
+            k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
+            for k, v in inputs.items()
+        }
+
     def setup_method(self):
         gc.collect()
         backend_empty_cache(torch_device)
@@ -1184,7 +1224,8 @@ def _test_torch_compile(self, config_kwargs):
         model = torch.compile(model, fullgraph=True)
 
         with torch._dynamo.config.patch(error_on_recompile=True):
-            inputs = self.get_dummy_inputs()
+            inputs = self._get_dummy_inputs_for_model(model)
+
             output = model(**inputs, return_dict=False)[0]
             assert output is not None, "Model output is None"
             assert not torch.isnan(output).any(), "Model output contains NaN"
@@ -1215,7 +1256,8 @@ def _test_torch_compile_with_group_offload(self, config_kwargs, use_stream=False
         model.enable_group_offload(**group_offload_kwargs)
         model = torch.compile(model)
 
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model)
+
         output = model(**inputs, return_dict=False)[0]
         assert output is not None, "Model output is None"
         assert not torch.isnan(output).any(), "Model output contains NaN"
diff --git a/tests/models/transformers/test_models_transformer_wan_animate.py b/tests/models/transformers/test_models_transformer_wan_animate.py
index df67e55c9b5d..94dab90dc20a 100644
--- a/tests/models/transformers/test_models_transformer_wan_animate.py
+++ b/tests/models/transformers/test_models_transformer_wan_animate.py
@@ -224,7 +224,7 @@ def get_dummy_inputs(self):
         """Override to provide inputs matching the tiny Wan Animate model dimensions."""
         return {
             "hidden_states": randn_tensor(
-                (1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
+                (1, 36, 5, 16, 16), generator=self.generator, device=torch_device, dtype=self.torch_dtype
             ),
             "encoder_hidden_states": randn_tensor(
                 (1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
@@ -233,10 +233,10 @@ def get_dummy_inputs(self):
                 (1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
             ),
             "pose_hidden_states": randn_tensor(
-                (1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
+                (1, 16, 4, 16, 16), generator=self.generator, device=torch_device, dtype=self.torch_dtype
             ),
             "face_pixel_values": randn_tensor(
-                (1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
+                (1, 3, 13, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
             ),
             "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
         }
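A minimal standalone sketch, not part of the patch, of the dtype-matching idea that the new _get_dummy_inputs_for_model helper applies throughout these tests: floating-point dummy tensors are cast to the model's parameter dtype, while integer tensors keep their dtype. The function name and the toy model below are illustrative only.

import torch
import torch.nn as nn


def cast_inputs_to_model_dtype(model: nn.Module, inputs: dict) -> dict:
    # Infer the working dtype from the first model parameter.
    model_dtype = next(model.parameters()).dtype
    # Cast only floating-point tensors; non-floating entries (e.g. integer ids) are left untouched.
    return {
        k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
        for k, v in inputs.items()
    }


if __name__ == "__main__":
    model = nn.Linear(4, 4).to(torch.bfloat16)
    dummy = {"hidden_states": torch.randn(2, 4), "timestep": torch.tensor([1])}
    cast = cast_inputs_to_model_dtype(model, dummy)
    print(cast["hidden_states"].dtype)  # torch.bfloat16, matches the model parameters
    print(cast["timestep"].dtype)       # torch.int64, unchanged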