diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 84ffb67bfd6a..5340c3884789 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -195,7 +195,7 @@ def forward( self, x: torch.Tensor, emb: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: emb = self.linear(self.silu(emb)) shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1) x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]