Fix Ideogram4MRoPE collapsing under torch.autocast (compute rotary in float32)#13922
Fix Ideogram4MRoPE collapsing under torch.autocast (compute rotary in float32)#13922HaozheZhang6 wants to merge 3 commits into
Ideogram4MRoPE collapsing under torch.autocast (compute rotary in float32)#13922Conversation
…y in float32)
Ideogram4 builds image-token positions as IMAGE_POSITION_OFFSET (65536) + (t, h, w).
`Ideogram4MRoPE.forward` casts its operands to float32, but the rotary matmul (and
cos/sin) is on autocast's downcast list, so under torch.autocast("cuda", bfloat16) —
common in training and pipeline code — it runs in bfloat16 anyway. bfloat16's step at
65536 is 512, so every image position in a <=512 grid rounds to the same value: all
image tokens get identical rotary embeddings, spatial information is lost, and the
decoded image degenerates to a flat color.
Wrap the frequency computation in torch.autocast(enabled=False) so the rotary
embeddings are always computed in float32, matching how transformers guards its RoPE
modules. Added a regression test that fails on main and passes with the fix.
Fixes huggingface#13920
|
before committing that (and thereby closing my report), please consider that other modules might be affected, just not as bad. bfloat16 becomes inaccurate for integers starting 257.0 (which is rounded to 256.0). that's within the range of text token ids |
|
You're right — confirmed bf16 rounds 257→256, 259→260, so text positions past 256 lose precision in any RoPE that matmuls raw position ids under autocast. Ideogram4 is just the pathological case: the 65536 offset collapses a whole ≤512-wide grid onto a single value, where the others degrade gradually instead of all-at-once. I'd checked the other diffusers transformers — Ideogram4 is the only RoPE with a large position offset, so the only catastrophic one — but the gradual loss you describe is real for the rest. I can extend the same |
| # IMAGE_POSITION_OFFSET (65536), so an ambient autocast would otherwise run the matmul and | ||
| # cos/sin in bfloat16, rounding every image position to the same value and collapsing the | ||
| # rotary embeddings (all spatial information is lost). | ||
| with torch.autocast(device_type=position_ids.device.type, enabled=False): |
There was a problem hiding this comment.
We don't use autocast within our modeling implementation like this.
There was a problem hiding this comment.
Good catch — dropped the autocast guard and compute the freqs in float64 instead, which autocast doesn't downcast (matching the float64 rope path Flux uses). The autocast and float32 paths come out bit-identical (max|Δ| = 0), and the regression test still passes.
Per review: replace the torch.autocast(enabled=False) guard with a float64 computation, which autocast does not downcast — matching the float64 rope path used elsewhere (Flux). The autocast and float32 paths stay bit-identical (max|delta|=0).
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
What does this PR do?
Fixes #13920
Ideogram4MRoPEproduces collapsed rotary embeddings undertorch.autocast, so denoising inside an autocast context (common in training, and when users wrap pipeline calls) renders a flat single-color image.Root cause
Image-token positions are
IMAGE_POSITION_OFFSET (65536) + (t, h, w).Ideogram4MRoPE.forwardcasts its operands to float32, but the frequency matmul is on autocast's downcast list, so undertorch.autocast("cuda", torch.bfloat16)it executes in bfloat16 anyway. bfloat16's representable step at 65536 is 512, so every image position in a ≤512-wide grid rounds to the same value — all image tokens get identical rotary embeddings, spatial information is lost, and sampling degenerates to a flat field.Reproduced with the weight-free snippet from the issue (
max |cos_autocast − cos_fp32| ≈ 1.93, distinct positions become equal).Fix
Wrap the frequency computation in
torch.autocast(device_type=..., enabled=False)so the rotary embeddings are always computed in float32 regardless of an ambient autocast — the same guardtransformersapplies to its RoPE modules. After the fix the autocast and float32 paths are bit-identical (max |Δ| = 0.0).Scope is
Ideogram4MRoPE, the catastrophic case (others noted in the issue are far milder without the 65536 offset). Happy to extend the same guard to the sibling RoPE modules in a follow-up if you'd like.Tests
Added
test_ideogram4_mrope_is_autocast_invariant— it fails onmain(collapsed positions) and passes with the fix. Full file green:Before submitting
Ideogram4MRoPEbreaks undertorch.autocast: all image positions collapse, producing flat single-color images #13920Who can review?
@DN6 @sayakpaul