[mps] Fix NaN in Attention.get_attention_scores when attention_mask is None#13701
[mps] Fix NaN in Attention.get_attention_scores when attention_mask is None#13701Shreyas-jk wants to merge 2 commits into
Conversation
|
Ping gently, this is open since May 8. Tiny MPS-based fix for the baddbmm beta=0 NaN in get_attention_scores (similar to #2643), including a regression test that does not pass without the fix, and passes with it. This is simply rebased onto main as-is, and is ready to go. Could a maintainer trigger CI or review when time permits? @yiyixuxu @sayakpaul |
2daec10 to
3670137
Compare
pcuenca
left a comment
There was a problem hiding this comment.
looks good to me, will defer to the team for additional opinions
|
also cc @Isalia20 for info |
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
Thanks for tagging @pcuenca It is already fixed on nightly pytorch/pytorch#187522 |
| if query.device.type == "mps": | ||
| # MPS' baddbmm does not short-circuit on beta=0, so an | ||
| # uninitialized input from torch.empty() can propagate NaN. | ||
| baddbmm_input = torch.zeros( | ||
| query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device | ||
| ) | ||
| else: | ||
| baddbmm_input = torch.empty( | ||
| query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device | ||
| ) |
There was a problem hiding this comment.
Should this be gated on a future PyTorch version?
#13701 (comment)
Because when the nightly is eventually in a stable version, then we shouldn't have to do zeros() right?
| self.assertTrue(np.allclose(conversion, after_conversion, atol=1e-3)) | ||
|
|
||
|
|
||
| class GetAttentionScoresMPSTests(unittest.TestCase): |
There was a problem hiding this comment.
- Should be implemented using
pytest. - Should have a note or some
xfaildecorator indicating that we won't need it after the nightly is in a stable.
There was a problem hiding this comment.
Thank you @sayakpaul! I agree in principle, but since the MPS branch in question allocates and zeros out a full (batch, q_len, k_len) tensor each time, once pytorch#187522 is included in a stable release then it will be clear that using an empty() tensor is better.
Now the issue with this approach is that the fix is in nightly right now (@Isalia20), meaning there is no stable release to guard against yet; users with the current stable PyTorch will trigger the NaN issue and so the guard is necessary. There are two possible ways to do this:
- Guard the MPS branch on
is_torch_version("<", "<target>"), so that when the fix is released it stops doing anything. Do you know which release pytorch#187522 is targeted to? If it's tagged, then I'll gate on that specific version. - If there is not yet a target version for the release, leave the guard unconditional for now and add a TODO reference to pytorch#187522, and add the guard when we know the release version.
In either case I'll move the test over to pytest. As for the idea about making the test an xfail, it seems that with the fix the test passes, so it self-deactivates.
Does that match what you had in mind?
There was a problem hiding this comment.
Guard the MPS branch on is_torch_version("<", ""), so that when the fix is released it stops doing anything. Do you know which release pytorch#187522 is targeted to? If it's tagged, then I'll gate on that specific version.
We could do pytest.mark.xfails(is_torch_version(">", current_stable_version)?
There was a problem hiding this comment.
Confirmed pytorch#187522 is merged to main (by malfet), but PyTorch didn't set a milestone on it. Since 2.12 has already shipped, it lands in 2.13 so I'll gate on is_torch_version("<", "2.13.0"). Shout if you'd rather wait for the tag, but the fix is on main so 2.13 is where it surfaces.
There was a problem hiding this comment.
I doubt it will land on 2.13. PR submissions for 2.13 was closed on:
M2: All PRs landed in PyTorch repo / Feature Submission Closed (5/6/26)
and that PR was merged after it, so please guard it on 2.14
There was a problem hiding this comment.
Good catch, thanks. Didn't realize it missed the 2.13 feature window. Guarding on
is_torch_version("<", "2.14.0") and skipping the test at >=2.14.0 to match. Pushing
the pytest port + gate shortly.
There was a problem hiding this comment.
Done, gated the MPS workaround for is_torch_version("<", "2.14.0") and converted the test to pytest with a skipif condition at >=2.14.0 (an xfail condition will XPASS when the upstream patch is applied, thus skipif works better). A new commit on top;, thank you @Isalia20 for the 2.14 fix.
…7522), port test to pytest Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
What does this PR do?
Fixes #11229
Attention.get_attention_scoresallocatesbaddbmm_inputwithtorch.empty()and usesbeta=0, relying onbaddbmmto ignore the uninitialized input. The MPSbaddbmmkernel does not short-circuit onbeta=0, so any NaN/Inf in the uninitialized memory propagates through0 * NaN = NaNand poisons the attention output. CUDA happens to mask this because its allocator typically returns zero-initialized memory.This change uses
torch.zerosinstead oftorch.emptyonly on MPS, leaving the CUDA / CPU / XPU paths unchanged so they don't pay the extra fill cost.In real workloads this surfaces as black/NaN images from
StableDiffusionXLPipelinewithenable_attention_slicing()on Apple Silicon + fp16, which is the standard memory-saving path on Macs with limited unified memory.Reproduction
Minimal repro on M-series MPS (without the fix): 30/30 trials produce NaN. With the fix: 0/30. CPU baseline: 0/5. Verified on M4 MacBook Pro, torch 2.11.0, fp16 and fp32.
The added test
GetAttentionScoresMPSTests.test_no_nan_when_attention_mask_is_none_on_mpsreproduces the bug deterministically (fails 20/20 without the fix, passes 20/20 with it) by polluting the MPS allocator pool with NaN-filled tensors before each call.Before submitting
Who can review?
@pcuenca (MPS / Apple Silicon maintainer per the PR template)