diff --git a/src/diffusers/models/transformers/transformer_ideogram4.py b/src/diffusers/models/transformers/transformer_ideogram4.py index 121118e3bd80..03cc6c84a051 100644 --- a/src/diffusers/models/transformers/transformer_ideogram4.py +++ b/src/diffusers/models/transformers/transformer_ideogram4.py @@ -70,8 +70,12 @@ def forward(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tenso raise ValueError(f"`position_ids` must have shape (B, L, 3), got {tuple(position_ids.shape)}.") batch_size, seq_len, _ = position_ids.shape - pos = position_ids.permute(2, 0, 1).to(dtype=torch.float32) - inv_freq = self.inv_freq.to(dtype=torch.float32)[None, None, :, None].expand(3, batch_size, -1, 1) + # Rotary frequencies are computed in float64. Image positions start at IMAGE_POSITION_OFFSET (65536), and + # under an ambient autocast a float32 matmul/cos/sin runs in bfloat16, collapsing every image position to + # the same embedding. float64 is not downcast by autocast, matching the float64 rope path Flux uses. + # NOTE(review): some backends (e.g. MPS) lack float64 support - confirm whether a float32 fallback is needed. 
+ pos = position_ids.permute(2, 0, 1).to(dtype=torch.float64) + inv_freq = self.inv_freq.to(dtype=torch.float64)[None, None, :, None].expand(3, batch_size, -1, 1) freqs = inv_freq @ pos.unsqueeze(2) freqs = freqs.transpose(2, 3) # (3, B, L, inv_freq_size) @@ -83,7 +87,7 @@ def forward(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tenso freqs_t[..., idx] = freqs[axis][..., idx] emb = torch.cat((freqs_t, freqs_t), dim=-1) - return emb.cos(), emb.sin() + return emb.cos().float(), emb.sin().float() class Ideogram4AttnProcessor: diff --git a/tests/models/transformers/test_models_transformer_ideogram4.py b/tests/models/transformers/test_models_transformer_ideogram4.py index 31592ada64bc..d8e7318d501d 100644 --- a/tests/models/transformers/test_models_transformer_ideogram4.py +++ b/tests/models/transformers/test_models_transformer_ideogram4.py @@ -21,6 +21,7 @@ IMAGE_POSITION_OFFSET, LLM_TOKEN_INDICATOR, OUTPUT_IMAGE_INDICATOR, + Ideogram4MRoPE, ) from diffusers.utils.torch_utils import randn_tensor @@ -164,3 +165,20 @@ def test_gradient_checkpointing_is_applied(self): class TestIdeogram4TransformerAttention(Ideogram4TransformerTesterConfig, AttentionTesterMixin): """Attention processor tests for Ideogram 4 Transformer.""" + + +def test_ideogram4_mrope_is_autocast_invariant(): + # Ideogram4's image positions start at IMAGE_POSITION_OFFSET (65536), so the rotary matmul must run in a dtype + # that autocast leaves alone (float64): a float32 matmul would execute in bfloat16 under an ambient autocast and + # round every image position to the same value, collapsing all spatial information (the decoded image goes flat). 
+ rope = Ideogram4MRoPE(head_dim=256, base=5_000_000, mrope_section=(24, 20, 20)).to(torch_device) + position_ids = torch.tensor([[[0, 0, 0], [0, 0, 1], [0, 63, 63]]], device=torch_device) + IMAGE_POSITION_OFFSET + + cos_ref, sin_ref = rope(position_ids) + with torch.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16): + cos_ac, sin_ac = rope(position_ids) + + # Distinct image positions must keep distinct embeddings, identical to the reference computed without autocast. + assert not torch.equal(cos_ac[0, 0], cos_ac[0, 1]) + assert torch.equal(cos_ac, cos_ref) + assert torch.equal(sin_ac, sin_ref)