Improve TorchAO quantization test coverage and XPU support #13530
Base: main · Changes from all commits: bef284c, ca507a8, c51708e, 6df4b31, 8a9013d, 81e7015, 4e4e759, 8180979, d210d4a, 0ba8682, 98aba52
```diff
@@ -142,6 +142,14 @@ def _is_module_quantized(self, module):
         except (AssertionError, AttributeError):
             return False

+    def _get_dummy_inputs_for_model(self, model):
+        inputs = self.get_dummy_inputs()
+        model_dtype = next(model.parameters()).dtype
+        return {
+            k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
+            for k, v in inputs.items()
+        }
+
     def _load_unquantized_model(self):
         kwargs = getattr(self, "pretrained_model_kwargs", {})
         return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
```
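Taken on its own, the helper above casts only floating-point tensors to the model's parameter dtype and leaves everything else (integer timesteps, plain Python values) untouched. A standalone sketch of the same pattern — the tiny module and input names here are hypothetical, purely for illustration:

```python
import torch

# Hypothetical tiny model, only to show the dtype-casting behaviour.
class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8, dtype=torch.bfloat16)

    def forward(self, hidden_states, timestep):
        return self.proj(hidden_states)

def cast_inputs_to_model_dtype(model, inputs):
    # Same pattern as _get_dummy_inputs_for_model: only floating-point tensors are cast.
    model_dtype = next(model.parameters()).dtype
    return {
        k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
        for k, v in inputs.items()
    }

model = TinyModel()
inputs = {"hidden_states": torch.randn(1, 8), "timestep": torch.tensor([1])}
out = model(**cast_inputs_to_model_dtype(model, inputs))
print(out.dtype)  # torch.bfloat16; the integer timestep is left as-is
```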
```diff
@@ -174,7 +182,7 @@ def _test_quantization_inference(self, config_kwargs):
         model_quantized = self._create_quantized_model(config_kwargs)
         model_quantized.to(torch_device)

-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model_quantized)
         output = model_quantized(**inputs, return_dict=False)[0]

         assert output is not None, "Model output is None"
```
```diff
@@ -222,7 +230,8 @@ def _test_quantization_lora_inference(self, config_kwargs):
         # Move LoRA adapter weights to device (they default to CPU)
         model.to(torch_device)

-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model)

         output = model(**inputs, return_dict=False)[0]

         assert output is not None, "Model output is None with LoRA"
```
```diff
@@ -236,7 +245,8 @@ def _test_quantization_serialization(self, config_kwargs, tmp_path):

         model_loaded = self.model_class.from_pretrained(str(tmp_path))

-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model_loaded)

         output = model_loaded(**inputs, return_dict=False)[0]
         assert not torch.isnan(output).any(), "Loaded model output contains NaN"
```
```diff
@@ -334,7 +344,8 @@ def _test_quantization_device_map(self, config_kwargs):
         assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute"
         assert model.hf_device_map is not None, "hf_device_map should not be None"

-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model)

         output = model(**inputs, return_dict=False)[0]
         assert output is not None, "Model output is None"
         assert not torch.isnan(output).any(), "Model output contains NaN"
```
```diff
@@ -359,7 +370,12 @@ def _test_dequantize(self, config_kwargs):
             if isinstance(module, torch.nn.Linear):
                 assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()"

-        inputs = self.get_dummy_inputs()
+        # Get model dtype from first parameter
+        inputs = self._get_dummy_inputs_for_model(model)
         output = model(**inputs, return_dict=False)[0]
         assert output is not None, "Model output is None after dequantization"
         assert not torch.isnan(output).any(), "Model output contains NaN after dequantization"
```
```diff
@@ -405,9 +421,10 @@ def _test_quantization_training(self, config_kwargs):
             pytest.skip("No attention layers found in model for adapter training test")

         # Step 3: run forward and backward pass
-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model)

-        with torch.amp.autocast(torch_device, dtype=torch.float16):
+        # Use bfloat16 instead of float16 to avoid gradient underflow with quantized layers
+        with torch.amp.autocast(torch_device, dtype=torch.bfloat16):
             out = model(**inputs, return_dict=False)[0]
             out.norm().backward()
```
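The switch from float16 to bfloat16 autocast matters because bfloat16 keeps the full float32 exponent range, so small gradients that would flush to zero in float16 survive the backward pass. A minimal sketch of the pattern on a plain linear layer (the layer and loss are illustrative, not the test's actual model):

```python
import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(16, 16).to(device_type)
inputs = torch.randn(4, 16, device=device_type)

# bfloat16 shares float32's exponent range, so tiny gradients from quantized
# or low-magnitude activations are far less likely to underflow than in float16.
with torch.amp.autocast(device_type, dtype=torch.bfloat16):
    out = model(inputs)
    loss = out.norm()

loss.backward()
print(model.weight.grad.dtype)  # gradients land on the float32 master weights
```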
```diff
@@ -587,8 +604,17 @@ def test_bnb_keep_modules_in_fp32(self):
                     f"Module {name} should be uint8 but is {module.weight.dtype}"
                 )

-            inputs = self.get_dummy_inputs()
-            _ = model(**inputs)
+            inputs = self._get_dummy_inputs_for_model(model)

+            _ = model(**inputs)
+        finally:
+            if original_fp32_modules is not None:
+                self.model_class._keep_in_fp32_modules = original_fp32_modules

     def test_bnb_modules_to_not_convert(self):
         """Test that modules_to_not_convert parameter works correctly."""
```
```diff
@@ -805,6 +831,10 @@ class TorchAoConfigMixin:
     @staticmethod
     def _get_quant_config(config_name):
         config_cls = getattr(_torchao_quantization, config_name)
+        # TorchAO int4 quantization requires plain_int32 packing format on Intel XPU
+        if config_name == "Int4WeightOnlyConfig" and torch_device == "xpu":
+            return TorchAoConfig(config_cls(int4_packing_format="plain_int32"))
+
         return TorchAoConfig(config_cls())

     def _create_quantized_model(self, config_name, **extra_kwargs):
```
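For readers trying the XPU path outside the test suite, the same branch can be written directly against the diffusers and torchao APIs. This is a sketch under the assumption that the installed torchao version exposes `int4_packing_format` on `Int4WeightOnlyConfig`, as the PR relies on:

```python
import torch
from diffusers import TorchAoConfig
from torchao.quantization import Int4WeightOnlyConfig

def int4_config_for(device: str) -> TorchAoConfig:
    # Intel XPU needs the plain_int32 packing format; the default packing
    # path relies on CUDA-specific int4 kernels.
    if device == "xpu":
        return TorchAoConfig(Int4WeightOnlyConfig(int4_packing_format="plain_int32"))
    return TorchAoConfig(Int4WeightOnlyConfig())

device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"
quant_config = int4_config_for(device)
# quant_config can then be passed to Model.from_pretrained(..., quantization_config=quant_config)
```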
```diff
@@ -816,11 +846,12 @@ def _create_quantized_model(self, config_name, **extra_kwargs):
         return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)

     def _verify_if_layer_quantized(self, name, module, config_kwargs):
-        assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
         from torchao.utils import TorchAOBaseTensor

-# int4wo requires CUDA-specific ops (_convert_weight_to_int4pack)
-_int4wo_skip = pytest.mark.skipif(torch_device != "cuda", reason="int4wo quantization requires CUDA")
+        assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
         assert isinstance(module.weight, TorchAOBaseTensor), (
             f"Layer {name} weight is {type(module.weight)}, expected TorchAOBaseTensor"
         )


 @is_torchao
```
```diff
@@ -848,7 +879,7 @@ class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):
     @pytest.mark.parametrize(
         "quant_type",
         [
-            pytest.param("int4wo", marks=_int4wo_skip),
+            "int4wo",
             "int8wo",
             "int8dq",
         ],

@@ -860,7 +891,7 @@ def test_torchao_quantization_num_parameters(self, quant_type):
     @pytest.mark.parametrize(
         "quant_type",
         [
-            pytest.param("int4wo", marks=_int4wo_skip),
+            "int4wo",
             "int8wo",
             "int8dq",
         ],

@@ -875,7 +906,7 @@ def test_torchao_quantization_memory_footprint(self, quant_type):
     @pytest.mark.parametrize(
         "quant_type",
         [
-            pytest.param("int4wo", marks=_int4wo_skip),
+            "int4wo",
             "int8wo",
             "int8dq",
         ],
```
```diff
@@ -902,7 +933,8 @@ def test_torchao_quantization_serialization(self, quant_type, tmp_path):

         model_loaded = self.model_class.from_pretrained(str(tmp_path), device_map=str(torch_device))

-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model_loaded)

         output = model_loaded(**inputs, return_dict=False)[0]
         assert not torch.isnan(output).any(), "Loaded model output contains NaN"
```
```diff
@@ -1159,6 +1191,14 @@ class QuantizationCompileTesterMixin:
     - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
     """

+    def _get_dummy_inputs_for_model(self, model):
+        inputs = self.get_dummy_inputs()
+        model_dtype = next(model.parameters()).dtype
+        return {
+            k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
+            for k, v in inputs.items()
+        }
+
```
Comment on lines +1194 to +1200

Member: Why do we need to override?

Contributor (author): QuantizationCompileTesterMixin is an independent mixin that doesn't inherit from QuantizationTesterMixin. Test classes may use either one or both, so the method needs to be defined in both places. Alternatively, I can extract it into a shared base class or a standalone utility function to avoid code duplication. Let me know which approach you prefer. Please review this change in #13539.
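One way to read the author's alternative: a module-level helper that both mixins call, so the casting logic lives in one place. The sketch below is illustrative only — the helper name and the bare mixin classes are stand-ins, not what this PR actually does:

```python
import torch

def cast_dummy_inputs_to_model_dtype(model: torch.nn.Module, inputs: dict) -> dict:
    """Cast floating-point tensor inputs to the model's parameter dtype."""
    model_dtype = next(model.parameters()).dtype
    return {
        k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
        for k, v in inputs.items()
    }

# Stand-ins for the real mixins: concrete test classes would supply get_dummy_inputs().
class QuantizationTesterMixin:
    def _get_dummy_inputs_for_model(self, model):
        return cast_dummy_inputs_to_model_dtype(model, self.get_dummy_inputs())

class QuantizationCompileTesterMixin:
    def _get_dummy_inputs_for_model(self, model):
        return cast_dummy_inputs_to_model_dtype(model, self.get_dummy_inputs())
```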
```diff
     def setup_method(self):
         gc.collect()
         backend_empty_cache(torch_device)
```
```diff
@@ -1184,7 +1224,8 @@ def _test_torch_compile(self, config_kwargs):
         model = torch.compile(model, fullgraph=True)

         with torch._dynamo.config.patch(error_on_recompile=True):
-            inputs = self.get_dummy_inputs()
+            inputs = self._get_dummy_inputs_for_model(model)

             output = model(**inputs, return_dict=False)[0]
             assert output is not None, "Model output is None"
             assert not torch.isnan(output).any(), "Model output contains NaN"
```
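The `error_on_recompile` guard turns silent recompilations into hard failures, which is how the test catches dtype-mismatch-induced re-traces. A minimal sketch of the same pattern on a toy module (not the test's model):

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
compiled = torch.compile(model, fullgraph=True)

# Any recompilation inside this context raises instead of silently re-tracing,
# e.g. if a later call changes input dtype or shape in a way that breaks guards.
with torch._dynamo.config.patch(error_on_recompile=True):
    out = compiled(torch.randn(2, 16))
    out = compiled(torch.randn(2, 16))  # same dtype/shape: no recompile, no error
print(out.shape)
```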
```diff
@@ -1215,7 +1256,8 @@ def _test_torch_compile_with_group_offload(self, config_kwargs, use_stream=False):
         model.enable_group_offload(**group_offload_kwargs)
         model = torch.compile(model)

-        inputs = self.get_dummy_inputs()
+        inputs = self._get_dummy_inputs_for_model(model)

         output = model(**inputs, return_dict=False)[0]
         assert output is not None, "Model output is None"
         assert not torch.isnan(output).any(), "Model output contains NaN"
```
The second file in the diff shrinks the dummy-input shapes for the tiny Wan Animate model:
```diff
@@ -224,7 +224,7 @@ def get_dummy_inputs(self):
         """Override to provide inputs matching the tiny Wan Animate model dimensions."""
         return {
             "hidden_states": randn_tensor(
-                (1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
+                (1, 36, 5, 16, 16), generator=self.generator, device=torch_device, dtype=self.torch_dtype
```
Member: Explain the changes.

Contributor (author): It's to avoid OOM; see #13541 for details. Please let me know if you want comments in the code.
```diff
             ),
             "encoder_hidden_states": randn_tensor(
                 (1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
```
```diff
@@ -233,10 +233,10 @@ def get_dummy_inputs(self):
                 (1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
             ),
             "pose_hidden_states": randn_tensor(
-                (1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
+                (1, 16, 4, 16, 16), generator=self.generator, device=torch_device, dtype=self.torch_dtype
             ),
             "face_pixel_values": randn_tensor(
-                (1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
+                (1, 3, 13, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
             ),
             "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
         }
```
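To see why the smaller shapes help with OOM, comparing element counts is enough. The sketch below assumes float32 dummy tensors (4 bytes per element), which is an assumption, not something the diff states:

```python
import math

def tensor_mib(shape, bytes_per_elem=4):
    # Rough memory footprint of a single tensor in MiB.
    return math.prod(shape) * bytes_per_elem / 2**20

# face_pixel_values dominates; shapes taken from the diff above.
print(tensor_mib((1, 3, 77, 512, 512)))  # ~231 MiB (old)
print(tensor_mib((1, 3, 13, 512, 512)))  # ~39 MiB  (new)
print(tensor_mib((1, 36, 21, 64, 64)))   # ~11.8 MiB (old hidden_states)
print(tensor_mib((1, 36, 5, 16, 16)))    # ~0.18 MiB (new hidden_states)
```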
Reviewer: We shouldn't have dequantize here in this PR, right?

Author: Yes, please review the change here: #13538