Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/diffusers/models/transformers/transformer_ideogram4.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,12 @@ def forward(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tenso
raise ValueError(f"`position_ids` must have shape (B, L, 3), got {tuple(position_ids.shape)}.")
batch_size, seq_len, _ = position_ids.shape

pos = position_ids.permute(2, 0, 1).to(dtype=torch.float32)
inv_freq = self.inv_freq.to(dtype=torch.float32)[None, None, :, None].expand(3, batch_size, -1, 1)
# Rotary frequencies are computed in float64: Ideogram4's image positions start at
# IMAGE_POSITION_OFFSET (65536). float32 itself represents those positions exactly, but an ambient
# autocast runs a float32 matmul/cos/sin in bfloat16, which cannot tell neighbouring positions of
# that magnitude apart, so every image position collapses to the same embedding. float64 is not
# downcast by autocast, matching the float64 rope path Flux uses.
pos = position_ids.permute(2, 0, 1).to(dtype=torch.float64)
inv_freq = self.inv_freq.to(dtype=torch.float64)[None, None, :, None].expand(3, batch_size, -1, 1)
freqs = inv_freq @ pos.unsqueeze(2)
freqs = freqs.transpose(2, 3) # (3, B, L, inv_freq_size)

Expand All @@ -83,7 +87,7 @@ def forward(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tenso
freqs_t[..., idx] = freqs[axis][..., idx]

emb = torch.cat((freqs_t, freqs_t), dim=-1)
return emb.cos(), emb.sin()
return emb.cos().float(), emb.sin().float()


class Ideogram4AttnProcessor:
Expand Down
18 changes: 18 additions & 0 deletions tests/models/transformers/test_models_transformer_ideogram4.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
IMAGE_POSITION_OFFSET,
LLM_TOKEN_INDICATOR,
OUTPUT_IMAGE_INDICATOR,
Ideogram4MRoPE,
)
from diffusers.utils.torch_utils import randn_tensor

Expand Down Expand Up @@ -164,3 +165,20 @@ def test_gradient_checkpointing_is_applied(self):

class TestIdeogram4TransformerAttention(Ideogram4TransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Ideogram 4 Transformer."""


def test_ideogram4_mrope_is_autocast_invariant():
    """Check that `Ideogram4MRoPE` returns the same embeddings with and without bfloat16 autocast."""
    # Ideogram4's image positions start at IMAGE_POSITION_OFFSET (65536). The rotary matmul is done in
    # float64 because autocast leaves float64 ops alone; a float32 matmul would be executed in bfloat16
    # under an ambient autocast and round every image position to the same value, collapsing all
    # spatial information (the decoded image goes flat).
    rope = Ideogram4MRoPE(head_dim=256, base=5_000_000, mrope_section=(24, 20, 20)).to(torch_device)
    position_ids = torch.tensor([[[0, 0, 0], [0, 0, 1], [0, 63, 63]]], device=torch_device) + IMAGE_POSITION_OFFSET

    # Reference run without autocast, then the same call under a bfloat16 autocast region.
    cos_ref, sin_ref = rope(position_ids)
    with torch.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16):
        cos_ac, sin_ac = rope(position_ids)

    # The module casts its float64 result back to float32 before returning it.
    assert cos_ac.dtype == torch.float32
    assert sin_ac.dtype == torch.float32
    # Distinct image positions must keep distinct embeddings, identical to the computation without autocast.
    assert not torch.equal(cos_ac[0, 0], cos_ac[0, 1])
    assert torch.equal(cos_ac, cos_ref)
    assert torch.equal(sin_ac, sin_ref)
Loading