[ExecuTorch][WebGPU] 2D compute dispatch — lift the 65535 per-dim cap (prefill path)#20583
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20583
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 1 Cancelled JobAs of commit a0b5f02 with merge base 124b677 ( NEW FAILURE - The following job has failed:
CANCELLED JOB - The following job was cancelled. Please retry:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
|
@claude review and check for any areas or opportunities for modularization |
|
Claude finished @JulianCloudNTH's task in 3m 55s —— View job Review: WebGPU 2D compute dispatch (#20583)
Solid, well-scoped change. The 🔴 Blocking: lint failure (CI red)
lintrunner -aModularization opportunities1. The graph.add_dispatch({pipeline, bind_group, workgroup_count.x, "", workgroup_count.y});
// ^^ kernel_name placeholderA field that's logically adjacent to
2. inline void set_workgroup_count(WebGPUDispatch& d, utils::WgCount c) {
d.workgroup_count_x = c.x;
d.workgroup_count_y = c.y;
}3. 4. Minor: reuse 5. Shader-side duplication (noting, not necessarily actionable). The thread-form reconstruction Nits / questions
None of the modularization items are blockers — only the lint fix is required to turn CI green. Items 1–3 would meaningfully reduce the positional-init fragility if you want to fold them in. |
99d36d6
into
gh/JulianCloudNTH/75/base
… (prefill path) Pull Request resolved: #20583 **Lift the 65535 workgroup-per-dim dispatch cap so single-shot SDPA prefill runs at any sequence length.** **Problem**: The WebGPU backend is 1D-dispatch-only and throws when a kernel's workgroup count exceeds the device per-dim limit (`maxComputeWorkgroupsPerDimension`, spec floor 65535). SDPA prefill QK exceeds it around S~362 (softmax/AV at S=2048), blocking single-shot / long-context prefill. **Solution**: Fold a >limit 1D workgroup count into 2D; the shader reconstructs the linear index from `@builtin(num_workgroups)`. - **Before**: `compute_1d_workgroup_count` throws if `count > limit`; dispatch `(count, 1, 1)`. - **After**: `compute_2d_workgroup_count` returns `{count, 1}` (fast path) or a near-square `{x, y}` (`x = ceil(sqrt(count))` clamped to `limit`, `y = div_up(count, x)`); dispatch `(x, y, 1)`. A flat `{limit, div_up(count, limit)}` split would idle up to ~half the launched workgroups when `count` just exceeds `limit`; the near-square split holds the waste to `O(sqrt(count))` (e.g. 65536 -> `{256, 256}`, 0 inactive). **Implementation**: - `WgCount` + pure `fold_workgroup_count_2d` + `compute_2d_workgroup_count` in `WebGPUUtils.h` (device-free, unit-testable; `queried_max_workgroups` factored out of the 1D path) - `WebGPUDispatch.workgroup_count_y` (default 1, declared last so existing aggregate inits are unchanged); both `dispatchWorkgroups` calls + the profiling record pass `(x, y, 1)` - Per-kernel in-shader reconstruction: thread-form `idx = gid.x + gid.y*(num_workgroups.x*wg_size)` (QK/AV/add); row-form `row_idx = wid.x + wid.y*num_workgroups.x` (softmax — keeps a `valid` predicate, not an early return, so `workgroupBarrier()`s stay uniform) - `Sdpa.cpp`: QK/softmax/AV counts via the 2D helper; the dynamic-`input_pos` resize hook recomputes both x and y for QK - Reference: ET-Vulkan dispatches over natural N-D extents (never folds a flat count nor guards the per-dim limit) and MLX `get_2d_grid_dims` packs whole tensor dims; for our flattened scalar count the near-square split is the correct no-shape-info analog (a pack-to-limit split would reproduce the idle-half waste) **Constraints**: - `y=1` fast path keeps every non-folded dispatch byte-identical to the prior 1D path - Scope = prefill path only; `rms_norm`/`embedding`/`lm_head`/`update_cache` are row/token-indexed and never hit the cap, so they keep the 1D path - Throws if a 3rd dispatch dimension would be needed — unreachable for real prefill (the `uint32` element guard fires first at S~11585) Co-authored-with: Claude Code. ghstack-source-id: 399812920 @exported-using-ghexport Differential Revision: [D109517684](https://our.internmc.facebook.com/intern/diff/D109517684/)
… (prefill path) Pull Request resolved: #20583 **Lift the 65535 workgroup-per-dim dispatch cap so single-shot SDPA prefill runs at any sequence length.** **Problem**: The WebGPU backend is 1D-dispatch-only and throws when a kernel's workgroup count exceeds the device per-dim limit (`maxComputeWorkgroupsPerDimension`, spec floor 65535). SDPA prefill QK exceeds it around S~362 (softmax/AV at S=2048), blocking single-shot / long-context prefill. **Solution**: Fold a >limit 1D workgroup count into 2D; the shader reconstructs the linear index from `@builtin(num_workgroups)`. - **Before**: `compute_1d_workgroup_count` throws if `count > limit`; dispatch `(count, 1, 1)`. - **After**: `compute_2d_workgroup_count` returns `{count, 1}` (fast path) or a near-square `{x, y}` (`x = ceil(sqrt(count))` clamped to `limit`, `y = div_up(count, x)`); dispatch `(x, y, 1)`. A flat `{limit, div_up(count, limit)}` split would idle up to ~half the launched workgroups when `count` just exceeds `limit`; the near-square split holds the waste to `O(sqrt(count))` (e.g. 65536 -> `{256, 256}`, 0 inactive). **Implementation**: - `WgCount` + pure `fold_workgroup_count_2d` + `compute_2d_workgroup_count` in `WebGPUUtils.h` (device-free, unit-testable; `queried_max_workgroups` factored out of the 1D path) - `WebGPUDispatch.workgroup_count_y` (default 1, declared last so existing aggregate inits are unchanged); both `dispatchWorkgroups` calls + the profiling record pass `(x, y, 1)` - Per-kernel in-shader reconstruction: thread-form `idx = gid.x + gid.y*(num_workgroups.x*wg_size)` (QK/AV/add); row-form `row_idx = wid.x + wid.y*num_workgroups.x` (softmax — keeps a `valid` predicate, not an early return, so `workgroupBarrier()`s stay uniform) - `Sdpa.cpp`: QK/softmax/AV counts via the 2D helper; the dynamic-`input_pos` resize hook recomputes both x and y for QK - Reference: ET-Vulkan dispatches over natural N-D extents (never folds a flat count nor guards the per-dim limit) and MLX `get_2d_grid_dims` packs whole tensor dims; for our flattened scalar count the near-square split is the correct no-shape-info analog (a pack-to-limit split would reproduce the idle-half waste) **Constraints**: - `y=1` fast path keeps every non-folded dispatch byte-identical to the prior 1D path - Scope = prefill path only; `rms_norm`/`embedding`/`lm_head`/`update_cache` are row/token-indexed and never hit the cap, so they keep the 1D path - Throws if a 3rd dispatch dimension would be needed — unreachable for real prefill (the `uint32` element guard fires first at S~11585) Co-authored-with: Claude Code. ghstack-source-id: 399812920 @exported-using-ghexport Differential Revision: [D109517684](https://our.internmc.facebook.com/intern/diff/D109517684/)
… (prefill path) Pull Request resolved: #20583 **Lift the 65535 workgroup-per-dim dispatch cap so single-shot SDPA prefill runs at any sequence length.** **Problem**: The WebGPU backend is 1D-dispatch-only and throws when a kernel's workgroup count exceeds the device per-dim limit (`maxComputeWorkgroupsPerDimension`, spec floor 65535). SDPA prefill QK exceeds it around S~362 (softmax/AV at S=2048), blocking single-shot / long-context prefill. **Solution**: Fold a >limit 1D workgroup count into 2D; the shader reconstructs the linear index from `@builtin(num_workgroups)`. - **Before**: `compute_1d_workgroup_count` throws if `count > limit`; dispatch `(count, 1, 1)`. - **After**: `compute_2d_workgroup_count` returns `{count, 1}` (fast path) or a near-square `{x, y}` (`x = ceil(sqrt(count))` clamped to `limit`, `y = div_up(count, x)`); dispatch `(x, y, 1)`. A flat `{limit, div_up(count, limit)}` split would idle up to ~half the launched workgroups when `count` just exceeds `limit`; the near-square split holds the waste to `O(sqrt(count))` (e.g. 65536 -> `{256, 256}`, 0 inactive). **Implementation**: - `WgCount` + pure `fold_workgroup_count_2d` + `compute_2d_workgroup_count` in `WebGPUUtils.h` (device-free, unit-testable; `queried_max_workgroups` factored out of the 1D path) - `WebGPUDispatch.workgroup_count_y` (default 1, declared last so existing aggregate inits are unchanged); both `dispatchWorkgroups` calls + the profiling record pass `(x, y, 1)` - Per-kernel in-shader reconstruction: thread-form `idx = gid.x + gid.y*(num_workgroups.x*wg_size)` (QK/AV/add); row-form `row_idx = wid.x + wid.y*num_workgroups.x` (softmax — keeps a `valid` predicate, not an early return, so `workgroupBarrier()`s stay uniform) - `Sdpa.cpp`: QK/softmax/AV counts via the 2D helper; the dynamic-`input_pos` resize hook recomputes both x and y for QK - Reference: ET-Vulkan dispatches over natural N-D extents (never folds a flat count nor guards the per-dim limit) and MLX `get_2d_grid_dims` packs whole tensor dims; for our flattened scalar count the near-square split is the correct no-shape-info analog (a pack-to-limit split would reproduce the idle-half waste) **Constraints**: - `y=1` fast path keeps every non-folded dispatch byte-identical to the prior 1D path - Scope = prefill path only; `rms_norm`/`embedding`/`lm_head`/`update_cache` are row/token-indexed and never hit the cap, so they keep the 1D path - Throws if a 3rd dispatch dimension would be needed — unreachable for real prefill (the `uint32` element guard fires first at S~11585) Co-authored-with: Claude Code. ghstack-source-id: 399812920 @exported-using-ghexport Differential Revision: [D109517684](https://our.internmc.facebook.com/intern/diff/D109517684/)
… (prefill path) Pull Request resolved: #20583 **Lift the 65535 workgroup-per-dim dispatch cap so single-shot SDPA prefill runs at any sequence length.** **Problem**: The WebGPU backend is 1D-dispatch-only and throws when a kernel's workgroup count exceeds the device per-dim limit (`maxComputeWorkgroupsPerDimension`, spec floor 65535). SDPA prefill QK exceeds it around S~362 (softmax/AV at S=2048), blocking single-shot / long-context prefill. **Solution**: Fold a >limit 1D workgroup count into 2D; the shader reconstructs the linear index from `@builtin(num_workgroups)`. - **Before**: `compute_1d_workgroup_count` throws if `count > limit`; dispatch `(count, 1, 1)`. - **After**: `compute_2d_workgroup_count` returns `{count, 1}` (fast path) or a near-square `{x, y}` (`x = ceil(sqrt(count))` clamped to `limit`, `y = div_up(count, x)`); dispatch `(x, y, 1)`. A flat `{limit, div_up(count, limit)}` split would idle up to ~half the launched workgroups when `count` just exceeds `limit`; the near-square split holds the waste to `O(sqrt(count))` (e.g. 65536 -> `{256, 256}`, 0 inactive). **Implementation**: - `WgCount` + pure `fold_workgroup_count_2d` + `compute_2d_workgroup_count` in `WebGPUUtils.h` (device-free, unit-testable; `queried_max_workgroups` factored out of the 1D path) - `WebGPUDispatch.workgroup_count_y` (default 1, declared last so existing aggregate inits are unchanged); both `dispatchWorkgroups` calls + the profiling record pass `(x, y, 1)` - Per-kernel in-shader reconstruction: thread-form `idx = gid.x + gid.y*(num_workgroups.x*wg_size)` (QK/AV/add); row-form `row_idx = wid.x + wid.y*num_workgroups.x` (softmax — keeps a `valid` predicate, not an early return, so `workgroupBarrier()`s stay uniform) - `Sdpa.cpp`: QK/softmax/AV counts via the 2D helper; the dynamic-`input_pos` resize hook recomputes both x and y for QK - Reference: ET-Vulkan dispatches over natural N-D extents (never folds a flat count nor guards the per-dim limit) and MLX `get_2d_grid_dims` packs whole tensor dims; for our flattened scalar count the near-square split is the correct no-shape-info analog (a pack-to-limit split would reproduce the idle-half waste) **Constraints**: - `y=1` fast path keeps every non-folded dispatch byte-identical to the prior 1D path - Scope = prefill path only; `rms_norm`/`embedding`/`lm_head`/`update_cache` are row/token-indexed and never hit the cap, so they keep the 1D path - Throws if a 3rd dispatch dimension would be needed — unreachable for real prefill (the `uint32` element guard fires first at S~11585) Co-authored-with: Claude Code. ghstack-source-id: 399812920 @exported-using-ghexport Differential Revision: [D109517684](https://our.internmc.facebook.com/intern/diff/D109517684/)
… (prefill path) Pull Request resolved: #20583 **Lift the 65535 workgroup-per-dim dispatch cap so single-shot SDPA prefill runs at any sequence length.** **Problem**: The WebGPU backend is 1D-dispatch-only and throws when a kernel's workgroup count exceeds the device per-dim limit (`maxComputeWorkgroupsPerDimension`, spec floor 65535). SDPA prefill QK exceeds it around S~362 (softmax/AV at S=2048), blocking single-shot / long-context prefill. **Solution**: Fold a >limit 1D workgroup count into 2D; the shader reconstructs the linear index from `@builtin(num_workgroups)`. - **Before**: `compute_1d_workgroup_count` throws if `count > limit`; dispatch `(count, 1, 1)`. - **After**: `compute_2d_workgroup_count` returns `{count, 1}` (fast path) or a near-square `{x, y}` (`x = ceil(sqrt(count))` clamped to `limit`, `y = div_up(count, x)`); dispatch `(x, y, 1)`. A flat `{limit, div_up(count, limit)}` split would idle up to ~half the launched workgroups when `count` just exceeds `limit`; the near-square split holds the waste to `O(sqrt(count))` (e.g. 65536 -> `{256, 256}`, 0 inactive). **Implementation**: - `WgCount` + pure `fold_workgroup_count_2d` + `compute_2d_workgroup_count` in `WebGPUUtils.h` (device-free, unit-testable; `queried_max_workgroups` factored out of the 1D path) - `WebGPUDispatch.workgroup_count_y` (default 1, declared last so existing aggregate inits are unchanged); both `dispatchWorkgroups` calls + the profiling record pass `(x, y, 1)` - Per-kernel in-shader reconstruction: thread-form `idx = gid.x + gid.y*(num_workgroups.x*wg_size)` (QK/AV/add); row-form `row_idx = wid.x + wid.y*num_workgroups.x` (softmax — keeps a `valid` predicate, not an early return, so `workgroupBarrier()`s stay uniform) - `Sdpa.cpp`: QK/softmax/AV counts via the 2D helper; the dynamic-`input_pos` resize hook recomputes both x and y for QK - Reference: ET-Vulkan dispatches over natural N-D extents (never folds a flat count nor guards the per-dim limit) and MLX `get_2d_grid_dims` packs whole tensor dims; for our flattened scalar count the near-square split is the correct no-shape-info analog (a pack-to-limit split would reproduce the idle-half waste) **Constraints**: - `y=1` fast path keeps every non-folded dispatch byte-identical to the prior 1D path - Scope = prefill path only; `rms_norm`/`embedding`/`lm_head`/`update_cache` are row/token-indexed and never hit the cap, so they keep the 1D path - Throws if a 3rd dispatch dimension would be needed — unreachable for real prefill (the `uint32` element guard fires first at S~11585) Co-authored-with: Claude Code. ghstack-source-id: 399812920 @exported-using-ghexport Differential Revision: [D109517684](https://our.internmc.facebook.com/intern/diff/D109517684/)
… (prefill path) Pull Request resolved: #20583 **Lift the 65535 workgroup-per-dim dispatch cap so single-shot SDPA prefill runs at any sequence length.** **Problem**: The WebGPU backend is 1D-dispatch-only and throws when a kernel's workgroup count exceeds the device per-dim limit (`maxComputeWorkgroupsPerDimension`, spec floor 65535). SDPA prefill QK exceeds it around S~362 (softmax/AV at S=2048), blocking single-shot / long-context prefill. **Solution**: Fold a >limit 1D workgroup count into 2D; the shader reconstructs the linear index from `@builtin(num_workgroups)`. - **Before**: `compute_1d_workgroup_count` throws if `count > limit`; dispatch `(count, 1, 1)`. - **After**: `compute_2d_workgroup_count` returns `{count, 1}` (fast path) or a near-square `{x, y}` (`x = ceil(sqrt(count))` clamped to `limit`, `y = div_up(count, x)`); dispatch `(x, y, 1)`. A flat `{limit, div_up(count, limit)}` split would idle up to ~half the launched workgroups when `count` just exceeds `limit`; the near-square split holds the waste to `O(sqrt(count))` (e.g. 65536 -> `{256, 256}`, 0 inactive). **Implementation**: - `WgCount` + pure `fold_workgroup_count_2d` + `compute_2d_workgroup_count` in `WebGPUUtils.h` (device-free, unit-testable; `queried_max_workgroups` factored out of the 1D path) - `WebGPUDispatch.workgroup_count_y` (default 1, declared last so existing aggregate inits are unchanged); both `dispatchWorkgroups` calls + the profiling record pass `(x, y, 1)` - Per-kernel in-shader reconstruction: thread-form `idx = gid.x + gid.y*(num_workgroups.x*wg_size)` (QK/AV/add); row-form `row_idx = wid.x + wid.y*num_workgroups.x` (softmax — keeps a `valid` predicate, not an early return, so `workgroupBarrier()`s stay uniform) - `Sdpa.cpp`: QK/softmax/AV counts via the 2D helper; the dynamic-`input_pos` resize hook recomputes both x and y for QK - Reference: ET-Vulkan dispatches over natural N-D extents (never folds a flat count nor guards the per-dim limit) and MLX `get_2d_grid_dims` packs whole tensor dims; for our flattened scalar count the near-square split is the correct no-shape-info analog (a pack-to-limit split would reproduce the idle-half waste) **Constraints**: - `y=1` fast path keeps every non-folded dispatch byte-identical to the prior 1D path - Scope = prefill path only; `rms_norm`/`embedding`/`lm_head`/`update_cache` are row/token-indexed and never hit the cap, so they keep the 1D path - Throws if a 3rd dispatch dimension would be needed — unreachable for real prefill (the `uint32` element guard fires first at S~11585) Co-authored-with: Claude Code. ghstack-source-id: 399812920 @exported-using-ghexport Differential Revision: [D109517684](https://our.internmc.facebook.com/intern/diff/D109517684/)
… (prefill path) Pull Request resolved: #20583 **Lift the 65535 workgroup-per-dim dispatch cap so single-shot SDPA prefill runs at any sequence length.** **Problem**: The WebGPU backend is 1D-dispatch-only and throws when a kernel's workgroup count exceeds the device per-dim limit (`maxComputeWorkgroupsPerDimension`, spec floor 65535). SDPA prefill QK exceeds it around S~362 (softmax/AV at S=2048), blocking single-shot / long-context prefill. **Solution**: Fold a >limit 1D workgroup count into 2D; the shader reconstructs the linear index from `@builtin(num_workgroups)`. - **Before**: `compute_1d_workgroup_count` throws if `count > limit`; dispatch `(count, 1, 1)`. - **After**: `compute_2d_workgroup_count` returns `{count, 1}` (fast path) or a near-square `{x, y}` (`x = ceil(sqrt(count))` clamped to `limit`, `y = div_up(count, x)`); dispatch `(x, y, 1)`. A flat `{limit, div_up(count, limit)}` split would idle up to ~half the launched workgroups when `count` just exceeds `limit`; the near-square split holds the waste to `O(sqrt(count))` (e.g. 65536 -> `{256, 256}`, 0 inactive). **Implementation**: - `WgCount` + pure `fold_workgroup_count_2d` + `compute_2d_workgroup_count` in `WebGPUUtils.h` (device-free, unit-testable; `queried_max_workgroups` factored out of the 1D path) - `WebGPUDispatch.workgroup_count_y` (default 1, declared last so existing aggregate inits are unchanged); both `dispatchWorkgroups` calls + the profiling record pass `(x, y, 1)` - Per-kernel in-shader reconstruction: thread-form `idx = gid.x + gid.y*(num_workgroups.x*wg_size)` (QK/AV/add); row-form `row_idx = wid.x + wid.y*num_workgroups.x` (softmax — keeps a `valid` predicate, not an early return, so `workgroupBarrier()`s stay uniform) - `Sdpa.cpp`: QK/softmax/AV counts via the 2D helper; the dynamic-`input_pos` resize hook recomputes both x and y for QK - Reference: ET-Vulkan dispatches over natural N-D extents (never folds a flat count nor guards the per-dim limit) and MLX `get_2d_grid_dims` packs whole tensor dims; for our flattened scalar count the near-square split is the correct no-shape-info analog (a pack-to-limit split would reproduce the idle-half waste) **Constraints**: - `y=1` fast path keeps every non-folded dispatch byte-identical to the prior 1D path - Scope = prefill path only; `rms_norm`/`embedding`/`lm_head`/`update_cache` are row/token-indexed and never hit the cap, so they keep the 1D path - Throws if a 3rd dispatch dimension would be needed — unreachable for real prefill (the `uint32` element guard fires first at S~11585) Co-authored-with: Claude Code. ghstack-source-id: 399812920 @exported-using-ghexport Differential Revision: [D109517684](https://our.internmc.facebook.com/intern/diff/D109517684/)
… (prefill path) Pull Request resolved: #20583 **Lift the 65535 workgroup-per-dim dispatch cap so single-shot SDPA prefill runs at any sequence length.** **Problem**: The WebGPU backend is 1D-dispatch-only and throws when a kernel's workgroup count exceeds the device per-dim limit (`maxComputeWorkgroupsPerDimension`, spec floor 65535). SDPA prefill QK exceeds it around S~362 (softmax/AV at S=2048), blocking single-shot / long-context prefill. **Solution**: Fold a >limit 1D workgroup count into 2D; the shader reconstructs the linear index from `@builtin(num_workgroups)`. - **Before**: `compute_1d_workgroup_count` throws if `count > limit`; dispatch `(count, 1, 1)`. - **After**: `compute_2d_workgroup_count` returns `{count, 1}` (fast path) or a near-square `{x, y}` (`x = ceil(sqrt(count))` clamped to `limit`, `y = div_up(count, x)`); dispatch `(x, y, 1)`. A flat `{limit, div_up(count, limit)}` split would idle up to ~half the launched workgroups when `count` just exceeds `limit`; the near-square split holds the waste to `O(sqrt(count))` (e.g. 65536 -> `{256, 256}`, 0 inactive). **Implementation**: - `WgCount` + pure `fold_workgroup_count_2d` + `compute_2d_workgroup_count` in `WebGPUUtils.h` (device-free, unit-testable; `queried_max_workgroups` factored out of the 1D path) - `WebGPUDispatch.workgroup_count_y` (default 1, declared last so existing aggregate inits are unchanged); both `dispatchWorkgroups` calls + the profiling record pass `(x, y, 1)` - Per-kernel in-shader reconstruction: thread-form `idx = gid.x + gid.y*(num_workgroups.x*wg_size)` (QK/AV/add); row-form `row_idx = wid.x + wid.y*num_workgroups.x` (softmax — keeps a `valid` predicate, not an early return, so `workgroupBarrier()`s stay uniform) - `Sdpa.cpp`: QK/softmax/AV counts via the 2D helper; the dynamic-`input_pos` resize hook recomputes both x and y for QK - Reference: ET-Vulkan dispatches over natural N-D extents (never folds a flat count nor guards the per-dim limit) and MLX `get_2d_grid_dims` packs whole tensor dims; for our flattened scalar count the near-square split is the correct no-shape-info analog (a pack-to-limit split would reproduce the idle-half waste) **Constraints**: - `y=1` fast path keeps every non-folded dispatch byte-identical to the prior 1D path - Scope = prefill path only; `rms_norm`/`embedding`/`lm_head`/`update_cache` are row/token-indexed and never hit the cap, so they keep the 1D path - Throws if a 3rd dispatch dimension would be needed — unreachable for real prefill (the `uint32` element guard fires first at S~11585) Co-authored-with: Claude Code. ghstack-source-id: 399812920 @exported-using-ghexport Differential Revision: [D109517684](https://our.internmc.facebook.com/intern/diff/D109517684/)
… (prefill path) Pull Request resolved: #20583 **Lift the 65535 workgroup-per-dim dispatch cap so single-shot SDPA prefill runs at any sequence length.** **Problem**: The WebGPU backend is 1D-dispatch-only and throws when a kernel's workgroup count exceeds the device per-dim limit (`maxComputeWorkgroupsPerDimension`, spec floor 65535). SDPA prefill QK exceeds it around S~362 (softmax/AV at S=2048), blocking single-shot / long-context prefill. **Solution**: Fold a >limit 1D workgroup count into 2D; the shader reconstructs the linear index from `@builtin(num_workgroups)`. - **Before**: `compute_1d_workgroup_count` throws if `count > limit`; dispatch `(count, 1, 1)`. - **After**: `compute_2d_workgroup_count` returns `{count, 1}` (fast path) or a near-square `{x, y}` (`x = ceil(sqrt(count))` clamped to `limit`, `y = div_up(count, x)`); dispatch `(x, y, 1)`. A flat `{limit, div_up(count, limit)}` split would idle up to ~half the launched workgroups when `count` just exceeds `limit`; the near-square split holds the waste to `O(sqrt(count))` (e.g. 65536 -> `{256, 256}`, 0 inactive). **Implementation**: - `WgCount` + pure `fold_workgroup_count_2d` + `compute_2d_workgroup_count` in `WebGPUUtils.h` (device-free, unit-testable; `queried_max_workgroups` factored out of the 1D path) - `WebGPUDispatch.workgroup_count_y` (default 1, declared last so existing aggregate inits are unchanged); both `dispatchWorkgroups` calls + the profiling record pass `(x, y, 1)` - Per-kernel in-shader reconstruction: thread-form `idx = gid.x + gid.y*(num_workgroups.x*wg_size)` (QK/AV/add); row-form `row_idx = wid.x + wid.y*num_workgroups.x` (softmax — keeps a `valid` predicate, not an early return, so `workgroupBarrier()`s stay uniform) - `Sdpa.cpp`: QK/softmax/AV counts via the 2D helper; the dynamic-`input_pos` resize hook recomputes both x and y for QK - Reference: ET-Vulkan dispatches over natural N-D extents (never folds a flat count nor guards the per-dim limit) and MLX `get_2d_grid_dims` packs whole tensor dims; for our flattened scalar count the near-square split is the correct no-shape-info analog (a pack-to-limit split would reproduce the idle-half waste) **Constraints**: - `y=1` fast path keeps every non-folded dispatch byte-identical to the prior 1D path - Scope = prefill path only; `rms_norm`/`embedding`/`lm_head`/`update_cache` are row/token-indexed and never hit the cap, so they keep the 1D path - Throws if a 3rd dispatch dimension would be needed — unreachable for real prefill (the `uint32` element guard fires first at S~11585) Co-authored-with: Claude Code. ghstack-source-id: 399812920 @exported-using-ghexport Differential Revision: [D109517684](https://our.internmc.facebook.com/intern/diff/D109517684/)
… (prefill path) Pull Request resolved: #20583 **Lift the 65535 workgroup-per-dim dispatch cap so single-shot SDPA prefill runs at any sequence length.** **Problem**: The WebGPU backend is 1D-dispatch-only and throws when a kernel's workgroup count exceeds the device per-dim limit (`maxComputeWorkgroupsPerDimension`, spec floor 65535). SDPA prefill QK exceeds it around S~362 (softmax/AV at S=2048), blocking single-shot / long-context prefill. **Solution**: Fold a >limit 1D workgroup count into 2D; the shader reconstructs the linear index from `@builtin(num_workgroups)`. - **Before**: `compute_1d_workgroup_count` throws if `count > limit`; dispatch `(count, 1, 1)`. - **After**: `compute_2d_workgroup_count` returns `{count, 1}` (fast path) or a near-square `{x, y}` (`x = ceil(sqrt(count))` clamped to `limit`, `y = div_up(count, x)`); dispatch `(x, y, 1)`. A flat `{limit, div_up(count, limit)}` split would idle up to ~half the launched workgroups when `count` just exceeds `limit`; the near-square split holds the waste to `O(sqrt(count))` (e.g. 65536 -> `{256, 256}`, 0 inactive). **Implementation**: - `WgCount` + pure `fold_workgroup_count_2d` + `compute_2d_workgroup_count` in `WebGPUUtils.h` (device-free, unit-testable; `queried_max_workgroups` factored out of the 1D path) - `WebGPUDispatch.workgroup_count_y` (default 1, declared last so existing aggregate inits are unchanged); both `dispatchWorkgroups` calls + the profiling record pass `(x, y, 1)` - Per-kernel in-shader reconstruction: thread-form `idx = gid.x + gid.y*(num_workgroups.x*wg_size)` (QK/AV/add); row-form `row_idx = wid.x + wid.y*num_workgroups.x` (softmax — keeps a `valid` predicate, not an early return, so `workgroupBarrier()`s stay uniform) - `Sdpa.cpp`: QK/softmax/AV counts via the 2D helper; the dynamic-`input_pos` resize hook recomputes both x and y for QK - Reference: ET-Vulkan dispatches over natural N-D extents (never folds a flat count nor guards the per-dim limit) and MLX `get_2d_grid_dims` packs whole tensor dims; for our flattened scalar count the near-square split is the correct no-shape-info analog (a pack-to-limit split would reproduce the idle-half waste) **Constraints**: - `y=1` fast path keeps every non-folded dispatch byte-identical to the prior 1D path - Scope = prefill path only; `rms_norm`/`embedding`/`lm_head`/`update_cache` are row/token-indexed and never hit the cap, so they keep the 1D path - Throws if a 3rd dispatch dimension would be needed — unreachable for real prefill (the `uint32` element guard fires first at S~11585) Co-authored-with: Claude Code. ghstack-source-id: 399812920 @exported-using-ghexport Differential Revision: [D109517684](https://our.internmc.facebook.com/intern/diff/D109517684/)
Stack from ghstack (oldest at bottom):
Lift the 65535 workgroup-per-dim dispatch cap so single-shot SDPA prefill runs at any sequence length.
Problem: The WebGPU backend is 1D-dispatch-only and throws when a kernel's workgroup count exceeds the device per-dim limit (
maxComputeWorkgroupsPerDimension, spec floor 65535). SDPA prefill QK exceeds it around S~362 (softmax/AV at S=2048), blocking single-shot / long-context prefill.Solution: Fold a >limit 1D workgroup count into 2D; the shader reconstructs the linear index from
@builtin(num_workgroups).compute_1d_workgroup_countthrows ifcount > limit; dispatch(count, 1, 1).compute_2d_workgroup_countreturns{count, 1}(fast path) or a near-square{x, y}(x = ceil(sqrt(count))clamped tolimit,y = div_up(count, x)); dispatch(x, y, 1). A flat{limit, div_up(count, limit)}split would idle up to ~half the launched workgroups whencountjust exceedslimit; the near-square split holds the waste toO(sqrt(count))(e.g. 65536 ->{256, 256}, 0 inactive).Implementation:
WgCount+ purefold_workgroup_count_2d+compute_2d_workgroup_countinWebGPUUtils.h(device-free, unit-testable;queried_max_workgroupsfactored out of the 1D path)WebGPUDispatch.workgroup_count_y(default 1, declared last so existing aggregate inits are unchanged); bothdispatchWorkgroupscalls + the profiling record pass(x, y, 1)idx = gid.x + gid.y*(num_workgroups.x*wg_size)(QK/AV/add); row-formrow_idx = wid.x + wid.y*num_workgroups.x(softmax — keeps avalidpredicate, not an early return, soworkgroupBarrier()s stay uniform)Sdpa.cpp: QK/softmax/AV counts via the 2D helper; the dynamic-input_posresize hook recomputes both x and y for QKget_2d_grid_dimspacks whole tensor dims; for our flattened scalar count the near-square split is the correct no-shape-info analog (a pack-to-limit split would reproduce the idle-half waste)Constraints:
y=1fast path keeps every non-folded dispatch byte-identical to the prior 1D pathrms_norm/embedding/lm_head/update_cacheare row/token-indexed and never hit the cap, so they keep the 1D pathuint32element guard fires first at S~11585)Co-authored-with: Claude Code.
@exported-using-ghexport
Differential Revision: D109517684
Differential Revision: D109517684