diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py
index 1aab0b240148..3a4f59c5fbea 100644
--- a/tests/models/testing_utils/quantization.py
+++ b/tests/models/testing_utils/quantization.py
@@ -407,7 +407,8 @@ def _test_quantization_training(self, config_kwargs):
 
         # Step 3: run forward and backward pass
         inputs = self.get_dummy_inputs()
-        with torch.amp.autocast(torch_device, dtype=torch.float16):
+        # Use bfloat16 instead of float16 to avoid gradient underflow with quantized layers
+        with torch.amp.autocast(torch_device, dtype=torch.bfloat16):
             out = model(**inputs, return_dict=False)[0]
             out.norm().backward()
 
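
The rationale behind the switch: bfloat16 keeps float32's 8-bit exponent, so very small gradient values that would flush to zero under float16 (5 exponent bits, smallest normal value around 6.1e-5) survive the backward pass. Below is a minimal sketch of the same autocast pattern, using a plain `nn.Sequential` and random inputs as hypothetical stand-ins for the quantized model under test and `self.get_dummy_inputs()` from the diff; only `torch.amp.autocast` and the dtype semantics are the real PyTorch API.

```python
import torch

# Hypothetical stand-in for the quantized model under test; the real test
# builds the model from config_kwargs and calls self.get_dummy_inputs().
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Sequential(
    torch.nn.Linear(8, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 1),
).to(device)
inputs = torch.randn(4, 8, device=device)

# bfloat16 shares float32's exponent range, so tiny gradients flowing
# through low-precision layers do not underflow to exact zeros the way
# they can in float16.
with torch.amp.autocast(device, dtype=torch.bfloat16):
    out = model(inputs)
    # Reduce to a scalar before backward, mirroring the test's out.norm()
    out.norm().backward()

# Parameters (and their grads) stay in float32 under autocast, so the
# gradient magnitudes remain meaningful for assertions in the test.
print(model[0].weight.grad.dtype, model[0].weight.grad.abs().max())
```

A side benefit of this choice: a float16 training loop would normally pair autocast with `torch.amp.GradScaler` to rescale small gradients back into float16's representable range, whereas bfloat16 sidesteps that machinery entirely, keeping the test simple.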