[ExecuTorch][WebGPU] 2D-fold mul + permute dispatch (lift 65535 1D cap)#20651
JulianCloudNTH wants to merge 6 commits into
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20651
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Pending
As of commit 91b2d97 with merge base 124b677:
NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR needs a
@claude Review thoroughly. Check for stale code / bugs.
Claude finished @JulianCloudNTH's task in 2m 47s —— View job Review: WebGPU 2D-fold for
Stack from ghstack (oldest at bottom):

Lift the 65535 workgroup-per-dim cap for `mul` and `permute` so they run at any numel.

`mul.Tensor` and `permute` still used `compute_1d_workgroup_count`, which throws once `numel / wg_size > 65535`. A realistic Llama-3.2-1B LoRA layer hits this (`mul` over `[2048, 8192]` = 262k workgroups; `permute` of `[2048, 2048]` = 65536). `add`/`sub`/`div`/`fill`/`sdpa` already use the 2D fold; this brings `mul` + `permute` in line.

Key changes:
- `mul/BinaryOp.cpp`, `permute/Permute.cpp`: `compute_1d_workgroup_count` → `compute_2d_workgroup_count` (returns `utils::WgCount`); dispatch + resize hook now set both `workgroup_count_x` and `workgroup_count_y`.
- `binary_mul.wgsl`, `permute.wgsl`: `main` takes `@builtin(num_workgroups)`; flat index is `gid.x + gid.y * (num_workgroups.x * wg_size)` (regenerated `*_wgsl.h`).

Mirrors the landed `add` op fold (`runtime/ops/add/{BinaryOp.cpp,binary_add.wgsl}`).

Co-authored-with: Claude Code.
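
For reviewers, a minimal host-side sketch of the 2D fold described above. It is not the code in this PR: `fold_2d_workgroup_count`, `flat_index`, the `WgCount` struct, and `kMaxWgPerDim` are hypothetical names used for illustration, and the real `compute_2d_workgroup_count` may choose a different x/y split. The sketch only assumes a workgroup size of `(wg_size, 1, 1)`, so that `gid.y` equals the workgroup's y index.

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical stand-in for utils::WgCount; the real struct may differ.
struct WgCount {
  uint32_t x;
  uint32_t y;
};

// WebGPU default limit for maxComputeWorkgroupsPerDimension.
constexpr uint32_t kMaxWgPerDim = 65535;

// Fold the 1D workgroup count ceil(numel / wg_size) into an (x, y) grid
// so that neither dimension exceeds the per-dimension cap.
inline WgCount fold_2d_workgroup_count(uint64_t numel, uint32_t wg_size) {
  if (wg_size == 0) {
    throw std::invalid_argument("wg_size must be > 0");
  }
  const uint64_t total = (numel + wg_size - 1) / wg_size;  // ceil-div
  if (total <= kMaxWgPerDim) {
    // Fits in one dimension; dispatch at least one workgroup.
    return {static_cast<uint32_t>(total == 0 ? 1 : total), 1};
  }
  // Smallest number of rows that keeps x under the cap.
  const uint64_t y = (total + kMaxWgPerDim - 1) / kMaxWgPerDim;
  if (y > kMaxWgPerDim) {
    throw std::runtime_error("numel too large even for a 2D dispatch");
  }
  const uint64_t x = (total + y - 1) / y;  // x * y >= total
  return {static_cast<uint32_t>(x), static_cast<uint32_t>(y)};
}

// Host-side mirror of the shader's flat index:
//   gid.x + gid.y * (num_workgroups.x * wg_size)
inline uint64_t flat_index(uint32_t gid_x, uint32_t gid_y,
                           uint32_t num_wg_x, uint32_t wg_size) {
  return gid_x +
         static_cast<uint64_t>(gid_y) *
             (static_cast<uint64_t>(num_wg_x) * wg_size);
}
```

Under this sketch with `wg_size = 64`, the `[2048, 8192]` `mul` (262144 workgroups) folds to `(52429, 5)` and the `[2048, 2048]` `permute` (65536 workgroups) folds to `(32768, 2)`. Because `x * y` can exceed the exact workgroup count, a shader using this scheme would still need to return early when the flat index is `>= numel`.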
@exported-using-ghexport
Differential Revision: D110149677