Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/diffusers/quantizers/torchao/torchao_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,3 +376,17 @@ def is_trainable(self):
@property
def is_compileable(self) -> bool:
    """Whether models quantized with this quantizer can be wrapped in `torch.compile`."""
    return True

def _dequantize(self, model):
    """Replace every torchao-quantized ``nn.Linear`` weight in *model* with a plain
    dequantized ``nn.Parameter`` and return the (mutated) model.

    Also restores the stock ``nn.Linear.extra_repr`` on modules whose repr was
    overridden (torchao patches it to show quantization info), so the dequantized
    module prints like an ordinary linear layer again.
    """
    from torchao.utils import TorchAOBaseTensor

    for _, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        weight = module.weight
        if not isinstance(weight, TorchAOBaseTensor):
            continue

        # Dequantize in place, keeping the tensor on its original device.
        target_device = weight.device
        module.weight = nn.Parameter(weight.dequantize().to(target_device))

        # Reset extra_repr if it was overridden
        bound_repr = module.extra_repr
        if hasattr(bound_repr, "__func__") and bound_repr.__func__ is not nn.Linear.extra_repr:
            module.extra_repr = types.MethodType(nn.Linear.extra_repr, module)

    return model
5 changes: 5 additions & 0 deletions tests/models/testing_utils/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,12 @@ def _create_quantized_model(self, config_name, **extra_kwargs):
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)

def _verify_if_layer_quantized(self, name, module, config_kwargs):
    """Assert that *module* (the layer registered under *name*) was actually quantized by torchao.

    A torchao-quantized linear layer keeps its ``nn.Linear`` type but has its
    ``weight`` swapped for a ``TorchAOBaseTensor`` subclass; both conditions are
    checked here. ``config_kwargs`` is accepted for interface parity with other
    backends' verifiers but is not inspected.
    """
    # Imported lazily so the test utilities stay importable without torchao installed.
    from torchao.utils import TorchAOBaseTensor

    assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
    assert isinstance(module.weight, TorchAOBaseTensor), (
        f"Layer {name} weight is {type(module.weight)}, expected TorchAOBaseTensor"
    )


# int4wo requires CUDA-specific ops (_convert_weight_to_int4pack)
Expand Down
Loading