Skip to content

[ExecuTorch][WebGPU] 2D-fold mul + permute dispatch (lift 65535 1D cap)#20651

Open
JulianCloudNTH wants to merge 6 commits into
gh/JulianCloudNTH/83/basefrom
gh/JulianCloudNTH/83/head
Open

[ExecuTorch][WebGPU] 2D-fold mul + permute dispatch (lift 65535 1D cap)#20651
JulianCloudNTH wants to merge 6 commits into
gh/JulianCloudNTH/83/basefrom
gh/JulianCloudNTH/83/head

Conversation

@JulianCloudNTH

@JulianCloudNTH JulianCloudNTH commented Jun 30, 2026

Copy link
Copy Markdown
Contributor

Stack from ghstack (oldest at bottom):

Lift the 65535 workgroup-per-dim cap for mul and permute so they run at any numel.

mul.Tensor and permute still used compute_1d_workgroup_count, which throws once numel / wg_size > 65535 — hit by a realistic Llama-3.2-1B LoRA layer (mul over [2048, 8192] = 262k workgroups; permute of [2048, 2048] = 65536). add/sub/div/fill/sdpa already use the 2D fold; this brings mul + permute in line.

Key changes:

  • mul/BinaryOp.cpp, permute/Permute.cppcompute_1d_workgroup_countcompute_2d_workgroup_count (returns utils::WgCount); dispatch + resize hook now set both workgroup_count_x and workgroup_count_y.
  • binary_mul.wgsl, permute.wgslmain takes @builtin(num_workgroups); flat index gid.x + gid.y * (num_workgroups.x * wg_size) (regenerated *_wgsl.h).

Mirrors the landed add op fold (runtime/ops/add/{BinaryOp.cpp,binary_add.wgsl}).

Co-authored-with: Claude Code.
@exported-using-ghexport

Differential Revision: D110149677

Differential Revision: D110149677

[ghstack-poisoned]
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 30, 2026
@pytorch-bot

pytorch-bot Bot commented Jun 30, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20651

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Pending

As of commit 91b2d97 with merge base 124b677 (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

[ghstack-poisoned]
[ghstack-poisoned]
@JulianCloudNTH

Copy link
Copy Markdown
Contributor Author

@claude Review thoroughly. Check for stale code / bugs.

@claude

claude Bot commented Jul 3, 2026

Copy link
Copy Markdown

Claude finished @JulianCloudNTH's task in 2m 47s —— View job


Review: WebGPU 2D-fold for mul + permute

  • Read PR diff and understand the change
  • Compare against the reference add op fold
  • Check compute_2d_workgroup_count / WgCount / add_dispatch signatures
  • Verify WGSL flat-index math is correct
  • Check for stale code / bugs
  • Post review

Reviewed against the landed add fold (ops/add/{BinaryOp.cpp,binary_add.wgsl}), WebGPUUtils.h (compute_2d_workgroup_count/fold_workgroup_count_2d/WgCount), and the dispatch path in WebGPUGraph.{h,cpp}. The change is correct and a faithful mirror of the reference — no bugs found. One stale comment worth fixing.

✅ Correctness verified

  • add_dispatch brace-init matches the struct. WebGPUDispatch field order is {pipeline, bind_group, workgroup_count_x, kernel_name, workgroup_count_y} (WebGPUGraph.h:46-51). Both call sites pass {pipeline, bind_group, workgroup_count.x, "mul"/"permute", workgroup_count.y} — positions line up exactly. add uses "" as the label; mul/permute supplying real bench labels is a minor improvement, not a bug.
  • Dispatch consumes .y. WebGPUGraph.cpp:818 and :898 both call wgpuComputePassEncoderDispatchWorkgroups(pass, workgroup_count_x, workgroup_count_y, 1), so the folded y is actually launched.
  • Flat-index math is exact. idx = gid.x + gid.y * (num_workgroups.x * wg_size). With grid (x, y, 1) and @workgroup_size(wg_size,1,1): gid.x ∈ [0, x·wg_size), gid.y ∈ [0, y), num_workgroups.x == x, so idx covers [0, x·y·wg_size), and x·y ≥ ceil(numel/wg_size). The idx >= out_meta.numel guard drops the tail. Byte-identical to binary_add.wgsl:17.
  • mul resize hook rewrites both workgroup_count_x and workgroup_count_y (BinaryOp.cpp:204-205), matching add:190-191. permute has no resize hook (it registers none), so nothing to update there — consistent with the op being static-shape only.
  • num_workgroups builtin is correctly added to both main signatures and both regenerated *_wgsl.h SHAs were updated.

🟡 Minor: stale comment

backends/webgpu/runtime/ops/mul/BinaryOp.cpp:37

// Rank guard (NCHW backend is <= 4 dims; 1D dispatch only).

The 1D dispatch only clause is now inaccurate — dispatch is 2D-folded. Suggest trimming to // Rank guard (NCHW backend is <= 4 dims). permute has no equivalent stale comment. Fix this →

💭 Observations (non-blocking, shared with the landed add)

  • gid.y * (num_workgroups.x * wg_size) is u32 arithmetic. In practice bounded by numel (a uint32_t), so no overflow for any realistic tensor; only a theoretical concern as numel → 2³². This is inherited verbatim from the reference add/sdpa fold, so it's not something this PR should change in isolation.
  • No test exercises the >65535 path. test_mul.py / test_permute.py exist but I didn't find a case at the numel that previously threw (mul over [2048, 8192], permute of [2048, 2048]). Adding one would lock in the fix and guard against regression — though if the add fold landed without such a case, matching that precedent is reasonable.

Verdict: LGTM. The one stale comment is optional cleanup; the logic is sound and consistent with the established pattern.
· branch gh/JulianCloudNTH/83/head

[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. meta-exported

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants