
[mps] Fix NaN in Attention.get_attention_scores when attention_mask is None #13701

Open
Shreyas-jk wants to merge 1 commit into huggingface:main from Shreyas-jk:fix/mps-attention-scores-nan-11229

Shreyas-jk commented May 8, 2026


What does this PR do?

Fixes #11229

Attention.get_attention_scores allocates baddbmm_input with torch.empty() and passes beta=0, relying on baddbmm to ignore the uninitialized input. That matches the documented behavior of torch.baddbmm, which states that input is ignored when beta is 0 so that NaN and Inf in it are not propagated. The MPS baddbmm kernel does not short-circuit on beta=0, so any NaN/Inf in the uninitialized memory propagates through 0 * NaN = NaN and poisons the attention output. CUDA seems to mask this, either because its kernel honors beta=0 or because the memory its allocator returns rarely holds NaN/Inf.
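To make the failure mode concrete, here is a minimal pure-Python sketch of the arithmetic involved. It is not the MPS kernel: `naive_axpby` and `short_circuit_axpby` are hypothetical names that model, element by element, a kernel that always computes `beta * input + alpha * product` versus one that skips `input` when `beta == 0`.

```python
import math


def naive_axpby(inp, prod, beta, alpha):
    """Always evaluates beta * inp, so garbage in inp can leak through."""
    return [beta * i + alpha * p for i, p in zip(inp, prod)]


def short_circuit_axpby(inp, prod, beta, alpha):
    """Skips inp entirely when beta == 0, so its contents never matter."""
    if beta == 0:
        return [alpha * p for p in prod]
    return [beta * i + alpha * p for i, p in zip(inp, prod)]


# Stand-in for uninitialized memory that happens to hold NaN and Inf.
garbage = [math.nan, math.inf, 1.0]
product = [2.0, 4.0, 6.0]

naive = naive_axpby(garbage, product, beta=0, alpha=0.5)
safe = short_circuit_axpby(garbage, product, beta=0, alpha=0.5)

print(naive)  # first two entries are NaN because 0 * NaN and 0 * Inf are NaN
print(safe)   # all entries finite
```

Under IEEE 754, multiplying NaN or Inf by zero yields NaN, which is why a zero `beta` alone does not neutralize uninitialized input unless the kernel skips the read.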

This change uses torch.zeros instead of torch.empty only on MPS, leaving the CUDA / CPU / XPU paths unchanged so they don't pay the extra fill cost.
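The change described above might look roughly like the following sketch. The helper name `allocate_baddbmm_input` and the tensor shapes are assumptions introduced only for illustration; the actual diff in this PR may be structured differently.

```python
import torch


def allocate_baddbmm_input(query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
    """Allocate the placeholder `input` for baddbmm when there is no mask.

    Hypothetical helper: zero-fills on MPS so that beta=0 cannot produce
    0 * NaN, and keeps the cheaper uninitialized allocation elsewhere.
    """
    shape = (query.shape[0], query.shape[1], key.shape[1])
    alloc = torch.zeros if query.device.type == "mps" else torch.empty
    return alloc(shape, dtype=query.dtype, device=query.device)


# Usage on CPU; shapes are (batch, sequence, head_dim).
query = torch.randn(2, 4, 8)
key = torch.randn(2, 5, 8)
scores = torch.baddbmm(
    allocate_baddbmm_input(query, key),
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=8 ** -0.5,  # illustrative scale of 1 / sqrt(head_dim)
)
```

Branching on `query.device.type` keeps the zero fill confined to the backend that needs it, which is the trade-off the description argues for.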

In real workloads this surfaces as black/NaN images from StableDiffusionXLPipeline with enable_attention_slicing() on Apple Silicon + fp16, which is the standard memory-saving path on Macs with limited unified memory.

Reproduction

Minimal repro on M-series MPS (without the fix): 30/30 trials produce NaN. With the fix: 0/30. CPU baseline: 0/5. Verified on M4 MacBook Pro, torch 2.11.0, fp16 and fp32.

The added test GetAttentionScoresMPSTests.test_no_nan_when_attention_mask_is_none_on_mps reproduces the bug deterministically (fails 20/20 without the fix, passes 20/20 with it) by polluting the MPS allocator pool with NaN-filled tensors before each call.
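The pollution idea can be sketched in isolation as follows. This is not the test added in the PR: `pollute_allocator` and `scores_contain_nan` are hypothetical helpers, the shapes are arbitrary, and whether freed blocks are actually reused depends on the allocator, so treat it as an illustration of the approach rather than a guaranteed reproduction.

```python
import torch


def pollute_allocator(device, shape, dtype, rounds=4):
    """Allocate NaN-filled tensors and drop them so freed blocks may hold NaN."""
    for _ in range(rounds):
        junk = torch.full(shape, float("nan"), dtype=dtype, device=device)
        del junk  # release the block; a caching allocator may hand it out again


def scores_contain_nan(device, dtype=torch.float32):
    """Mimic the mask-free path: uninitialized input combined with beta=0."""
    query = torch.randn(2, 16, 8, dtype=dtype, device=device)
    key = torch.randn(2, 16, 8, dtype=dtype, device=device)
    shape = (query.shape[0], query.shape[1], key.shape[1])
    pollute_allocator(device, shape, dtype)
    placeholder = torch.empty(shape, dtype=dtype, device=device)
    scores = torch.baddbmm(
        placeholder, query, key.transpose(-1, -2), beta=0, alpha=8 ** -0.5
    )
    return bool(torch.isnan(scores).any())


# Falls back to CPU when MPS is unavailable; CPU is expected to stay NaN-free.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device.type, scores_contain_nan(device))
```

Seeding the allocator with NaN removes the dependence on whatever happened to be in memory, which is what turns an intermittent failure into a repeatable one.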

Before submitting

Who can review?

@pcuenca (MPS / Apple Silicon maintainer per the PR template)

Shreyas-jk (Author) commented

Gentle ping: this has been open since May 8. It is a small MPS-only fix for the baddbmm beta=0 NaN in get_attention_scores (similar to #2643), with a regression test that fails without the fix and passes with it. The branch is rebased onto main and ready to go. Could a maintainer trigger CI or review when time permits? @yiyixuxu @sayakpaul

Shreyas-jk force-pushed the fix/mps-attention-scores-nan-11229 branch from 2daec10 to 3670137 on June 25, 2026 22:17

pcuenca (Member) left a comment

looks good to me, will defer to the team for additional opinions

pcuenca requested a review from yiyixuxu June 26, 2026 08:39

pcuenca (Member) commented Jun 26, 2026

also cc @Isalia20 for info

HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Isalia20 commented

Thanks for tagging, @pcuenca. This is already fixed in PyTorch nightly: pytorch/pytorch#187522

Development

Successfully merging this pull request may close these issues.

enable_attention_slicing give NaN results for SDXL on MPS
