Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions backends/webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,14 @@ if(EXECUTORCH_BUILD_WEBGPU_TEST)
webgpu_dynamic_shape_test test/native/test_dynamic_shape.cpp
)
target_link_libraries(webgpu_dynamic_shape_test PRIVATE GTest::gtest)

# Unit test for the 2D workgroup-count fold (test/native/test_dispatch_2d.cpp).
# The test source defines no main() of its own, so GTest::gtest_main is linked
# alongside GTest::gtest to supply the entry point. The test is described as
# device-free: it exercises the fold arithmetic only and needs no GPU.
add_webgpu_native_test(
webgpu_dispatch_2d_test test/native/test_dispatch_2d.cpp
)
target_link_libraries(
webgpu_dispatch_2d_test PRIVATE GTest::gtest GTest::gtest_main
)
endif()
add_webgpu_native_test(webgpu_index_test test/native/test_index.cpp)
endif()
4 changes: 3 additions & 1 deletion backends/webgpu/scripts/test_webgpu_native_ci.sh
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ cmake \
"${EXECUTORCH_ROOT}"

# ── Build + run every native test target that exists in this tree ────────────
TARGETS=(webgpu_native_test webgpu_dispatch_order_test webgpu_scratch_buffer_test webgpu_update_cache_test webgpu_index_test)
TARGETS=(webgpu_native_test webgpu_dispatch_order_test webgpu_scratch_buffer_test webgpu_update_cache_test webgpu_index_test webgpu_dispatch_2d_test)
BIN_DIR="${BUILD_DIR}/backends/webgpu"

# Which targets are defined depends on which diffs are landed (native_test +
Expand Down Expand Up @@ -212,6 +212,8 @@ if [[ "${INDEX_OK}" == "1" && -x "${BIN_DIR}/webgpu_index_test" ]]; then
"${BIN_DIR}/webgpu_index_test" "${INDEX_DIR}"
fi
[[ -x "${BIN_DIR}/webgpu_scratch_buffer_test" ]] && "${BIN_DIR}/webgpu_scratch_buffer_test"
# Device-free: pure 2D workgroup-count fold unit test (no .pte, no GPU).
# Runs only when the binary exists and is executable; otherwise the -x guard
# short-circuits and the test is skipped without printing anything.
[[ -x "${BIN_DIR}/webgpu_dispatch_2d_test" ]] && "${BIN_DIR}/webgpu_dispatch_2d_test"

echo "=== WebGPU native tests on Dawn: all run targets passed ==="

Expand Down
60 changes: 60 additions & 0 deletions backends/webgpu/test/native/test_dispatch_2d.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

// Device-free unit test for the pure 2D workgroup-count fold that lifts the
// 65535 per-dim dispatch cap. Exercises the fold arithmetic only — no GPU.

#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>

#include <gtest/gtest.h>

#include <cmath>
#include <cstdint>

using executorch::backends::webgpu::utils::fold_workgroup_count_2d;
using executorch::backends::webgpu::utils::WgCount;

namespace {

constexpr uint32_t kMax = 65535u;

// count <= max -> {count, 1}: the 1D fast path, byte-identical to the old path.
// Probes the boundary values 1, max-1 and max. Every expectation streams the
// count under test so that a failure inside the loop identifies which boundary
// value broke (the bare EXPECT_EQ output only shows got.x / got.y).
TEST(DispatchFold, FastPath1D) {
  for (uint32_t count : {1u, kMax - 1u, kMax}) {
    const WgCount got = fold_workgroup_count_2d(count, kMax, "test");
    EXPECT_EQ(got.x, count) << "count=" << count;
    EXPECT_EQ(got.y, 1u) << "count=" << count;
  }
}

// Folded path (count > max): the result must be a near-square {x, y} grid.
// Properties checked for every probe count:
//   * neither dimension exceeds the per-dim cap,
//   * x * y covers every requested workgroup,
//   * the surplus (launched - count) is below 2 * ceil(sqrt(count)), i.e.
//     O(sqrt(count)) inactive workgroups; a flat {max, div_up} split would
//     idle up to ~half.
TEST(DispatchFold, NearSquareFold) {
  // The last two entries are prefill-scale QK counts
  // (Hq*ceil(S/4)*ceil(ctx/4)/wg) that fold:
  // 131072 = S=2048 (32*512*512/64); 2097152 = large-S stress.
  const uint32_t counts[] = {
      kMax + 1u, 2u * kMax, 2u * kMax + 1u, 131072u, 2097152u};

  for (const uint32_t count : counts) {
    const WgCount got = fold_workgroup_count_2d(count, kMax, "test");

    // Side length of the smallest square holding `count` workgroups.
    const double side = std::ceil(std::sqrt(static_cast<double>(count)));
    const uint32_t root = static_cast<uint32_t>(side);

    // Widen before multiplying: x * y may not fit in 32 bits.
    const uint64_t launched = static_cast<uint64_t>(got.x) * got.y;

    EXPECT_LE(got.x, kMax) << "count=" << count;
    EXPECT_LE(got.y, kMax) << "count=" << count;
    EXPECT_GE(launched, count) << "count=" << count;
    EXPECT_LT(launched - count, 2ull * root)
        << "count=" << count << " launched=" << launched;
  }
}

// A count above max^2 cannot be covered by two dispatch dimensions; using a
// 3rd dimension is out of scope, so the fold is expected to throw.
TEST(DispatchFold, ThrowsWhenNeeds3rdDimension) {
  // 65535 * 65535 + 1 = 4294836226, which still fits in uint32_t (no wrap).
  constexpr uint32_t kOverCapacity = kMax * kMax + 1u;
  EXPECT_ANY_THROW(fold_workgroup_count_2d(kOverCapacity, kMax, "test"));
}

} // namespace
3 changes: 3 additions & 0 deletions backends/webgpu/test/ops/test_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ class SdpaConfig:
SdpaConfig("llama1b_decode", 32, 8, 64, 1, 512, 127),
# D=6 is not a multiple of 4: the WebGPU head_dim%4 guard must reject it at load.
SdpaConfig("reject_d6", 4, 4, 6, 4, 16, 0),
# 2D-dispatch cap (>65535 wg): S=512 folds QK; S=2048 folds QK+softmax+AV (cap+1).
SdpaConfig("llama1b_prefill_512", 32, 8, 64, 512, 512, 0),
SdpaConfig("llama1b_prefill_2048", 32, 8, 64, 2048, 2048, 0),
]


Expand Down
12 changes: 12 additions & 0 deletions backends/webgpu/test/test_webgpu_native.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,18 @@ static const SdpaConfig kSdpaConfigs[] = {
16.0f,
/*required=*/false,
/*expect_reject=*/true},
// 2D-dispatch cap (>65535 wg): S=512 folds QK; S=2048 folds QK+softmax+AV
// (cap+1).
{"llama1b_prefill_512", 32, 8, 64, 512, 512, 0, 16.0f, /*required=*/true},
{"llama1b_prefill_2048",
32,
8,
64,
2048,
2048,
0,
16.0f,
/*required=*/true},
};

// Ramp denominator; mirror of test_sdpa.py::_RAMP_DENOM (keep in sync).
Expand Down
Loading