Skip to content

[mps] Fix NaN in Attention.get_attention_scores when attention_mask is None#13701

Open
Shreyas-jk wants to merge 2 commits into
huggingface:mainfrom
Shreyas-jk:fix/mps-attention-scores-nan-11229
Open

[mps] Fix NaN in Attention.get_attention_scores when attention_mask is None#13701
Shreyas-jk wants to merge 2 commits into
huggingface:mainfrom
Shreyas-jk:fix/mps-attention-scores-nan-11229

Conversation

@Shreyas-jk

@Shreyas-jk Shreyas-jk commented May 8, 2026

Copy link
Copy Markdown

What does this PR do?

Fixes #11229

Attention.get_attention_scores allocates baddbmm_input with torch.empty() and uses beta=0, relying on baddbmm to ignore the uninitialized input. The MPS baddbmm kernel does not short-circuit on beta=0, so any NaN/Inf in the uninitialized memory propagates through 0 * NaN = NaN and poisons the attention output. CUDA happens to mask this because its allocator typically returns zero-initialized memory.

This change uses torch.zeros instead of torch.empty only on MPS, leaving the CUDA / CPU / XPU paths unchanged so they don't pay the extra fill cost.

In real workloads this surfaces as black/NaN images from StableDiffusionXLPipeline with enable_attention_slicing() on Apple Silicon + fp16, which is the standard memory-saving path on Macs with limited unified memory.

Reproduction

Minimal repro on M-series MPS (without the fix): 30/30 trials produce NaN. With the fix: 0/30. CPU baseline: 0/5. Verified on M4 MacBook Pro, torch 2.11.0, fp16 and fp32.

The added test GetAttentionScoresMPSTests.test_no_nan_when_attention_mask_is_none_on_mps reproduces the bug deterministically (fails 20/20 without the fix, passes 20/20 with it) by polluting the MPS allocator pool with NaN-filled tensors before each call.

Before submitting

Who can review?

@pcuenca (MPS / Apple Silicon maintainer per the PR template)

@Shreyas-jk

Copy link
Copy Markdown
Author

Ping gently, this is open since May 8. Tiny MPS-based fix for the baddbmm beta=0 NaN in get_attention_scores (similar to #2643), including a regression test that does not pass without the fix, and passes with it. This is simply rebased onto main as-is, and is ready to go. Could a maintainer trigger CI or review when time permits? @yiyixuxu @sayakpaul

@Shreyas-jk Shreyas-jk force-pushed the fix/mps-attention-scores-nan-11229 branch from 2daec10 to 3670137 Compare June 25, 2026 22:17

@pcuenca pcuenca left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good to me, will defer to the team for additional opinions

@pcuenca pcuenca requested a review from yiyixuxu June 26, 2026 08:39
@pcuenca

pcuenca commented Jun 26, 2026

Copy link
Copy Markdown
Member

also cc @Isalia20 for info

@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Isalia20

Copy link
Copy Markdown

Thanks for tagging @pcuenca It is already fixed on nightly pytorch/pytorch#187522

Comment on lines +679 to +688
if query.device.type == "mps":
# MPS' baddbmm does not short-circuit on beta=0, so an
# uninitialized input from torch.empty() can propagate NaN.
baddbmm_input = torch.zeros(
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
)
else:
baddbmm_input = torch.empty(
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be gated on a future PyTorch version?
#13701 (comment)

Because when the nightly is eventually in a stable version, then we shouldn't have to do zeros() right?

self.assertTrue(np.allclose(conversion, after_conversion, atol=1e-3))


class GetAttentionScoresMPSTests(unittest.TestCase):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Should be implemented using pytest.
  2. Should have a note or some xfail decorator indicating that we won't need it after the nightly is in a stable.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @sayakpaul! I agree in principle, but since the MPS branch in question allocates and zeros out a full (batch, q_len, k_len) tensor each time, once pytorch#187522 is included in a stable release then it will be clear that using an empty() tensor is better.

Now the issue with this approach is that the fix is in nightly right now (@Isalia20), meaning there is no stable release to guard against yet; users with the current stable PyTorch will trigger the NaN issue and so the guard is necessary. There are two possible ways to do this:

  1. Guard the MPS branch on is_torch_version("<", "<target>"), so that when the fix is released it stops doing anything. Do you know which release pytorch#187522 is targeted to? If it's tagged, then I'll gate on that specific version.
  2. If there is not yet a target version for the release, leave the guard unconditional for now and add a TODO reference to pytorch#187522, and add the guard when we know the release version.

In either case I'll move the test over to pytest. As for the idea about making the test an xfail, it seems that with the fix the test passes, so it self-deactivates.

Does that match what you had in mind?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Guard the MPS branch on is_torch_version("<", ""), so that when the fix is released it stops doing anything. Do you know which release pytorch#187522 is targeted to? If it's tagged, then I'll gate on that specific version.

We could do pytest.mark.xfails(is_torch_version(">", current_stable_version)?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confirmed pytorch#187522 is merged to main (by malfet), but PyTorch didn't set a milestone on it. Since 2.12 has already shipped, it lands in 2.13 so I'll gate on is_torch_version("<", "2.13.0"). Shout if you'd rather wait for the tag, but the fix is on main so 2.13 is where it surfaces.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I doubt it will land on 2.13. PR submissions for 2.13 was closed on:
M2: All PRs landed in PyTorch repo / Feature Submission Closed (5/6/26)
and that PR was merged after it, so please guard it on 2.14

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, thanks. Didn't realize it missed the 2.13 feature window. Guarding on
is_torch_version("<", "2.14.0") and skipping the test at >=2.14.0 to match. Pushing
the pytest port + gate shortly.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, gated the MPS workaround for is_torch_version("<", "2.14.0") and converted the test to pytest with a skipif condition at >=2.14.0 (an xfail condition will XPASS when the upstream patch is applied, thus skipif works better). A new commit on top;, thank you @Isalia20 for the 2.14 fix.

…7522), port test to pytest

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

enable_attention_slicing give NaN results for SDXL on MPS

5 participants