From be17d15675be3a05145c13fbd787bd2ece94e3d5 Mon Sep 17 00:00:00 2001 From: ssjia Date: Tue, 26 May 2026 13:29:07 -0700 Subject: [PATCH 1/2] Update (base update) [ghstack-poisoned] --- .../operator_support/index_select_support.py | 14 +- .../operator_support/unfold_copy_support.py | 14 +- backends/arm/scripts/build_executorch.sh | 8 + backends/arm/scripts/pre-push | 2 +- backends/arm/test/ops/test_index_select.py | 32 +++ backends/arm/test/ops/test_unfold_copy.py | 24 ++ backends/cortex_m/CMakeLists.txt | 9 + .../ops/op_quantized_batch_matmul.cpp | 35 +-- backends/cortex_m/ops/op_quantized_conv2d.cpp | 34 +-- .../ops/op_quantized_depthwise_conv2d.cpp | 31 +- .../ops/op_quantized_transpose_conv2d.cpp | 44 +-- backends/cortex_m/ops/operators.py | 28 +- backends/cortex_m/ops/operators.yaml | 9 +- backends/cortex_m/passes/__init__.py | 1 + .../passes/convert_to_cortex_m_pass.py | 64 ++++- .../cortex_m/passes/scratch_buffer_sizes.py | 266 ++++++++++++++++++ backends/cortex_m/test/build_test_runner.sh | 4 +- .../ops_converters/add_tensor_converter.py | 42 ++- .../ops_converters/sub_tensor_converter.py | 40 ++- .../test_add_tensor_converter.py | 263 ++++++++++++++++- .../test_avg_pool2d_converter.py | 9 +- .../test_max_pool_2d_converter.py | 7 +- .../test_mul_tensor_converter.py | 5 - .../test_sub_tensor_converter.py | 260 ++++++++++++++++- backends/nxp/tests/models.py | 4 +- backends/nxp/tests/ops_aliases.py | 2 + 26 files changed, 1112 insertions(+), 139 deletions(-) create mode 100644 backends/cortex_m/passes/scratch_buffer_sizes.py diff --git a/backends/arm/operator_support/index_select_support.py b/backends/arm/operator_support/index_select_support.py index a3188e739c7..285b2cfe79f 100644 --- a/backends/arm/operator_support/index_select_support.py +++ b/backends/arm/operator_support/index_select_support.py @@ -77,8 +77,16 @@ def is_node_tosa_supported( f"{node.target}: dtype {values_dtype} requires INT profile.", ) return False - # fp16/fp32: either FP profile, or INT profile (via quantization) - elif values_dtype in (torch.float16, torch.float32): + # fp16/fp32/bf16: either FP profile, or INT profile (via quantization) + elif values_dtype in (torch.float16, torch.float32, torch.bfloat16): + if values_dtype == torch.bfloat16 and not tosa_spec.support_extension( + "bf16" + ): + self.reporter.report_reject( + node, + f"{node.target}: dtype {values_dtype} requires bf16 extension.", + ) + return False if not (tosa_spec.support_float() or tosa_spec.support_integer()): self.reporter.report_reject( node, @@ -90,7 +98,7 @@ def is_node_tosa_supported( self.reporter.report_reject( node, f"{node.target}: unsupported values dtype {values_dtype}; " - "expected bool/int8/int16/int32/float16/float32.", + "expected bool/int8/int16/int32/float16/bfloat16/float32.", ) return False diff --git a/backends/arm/operator_support/unfold_copy_support.py b/backends/arm/operator_support/unfold_copy_support.py index bf6c1cad22e..ac9fc7d0ee3 100644 --- a/backends/arm/operator_support/unfold_copy_support.py +++ b/backends/arm/operator_support/unfold_copy_support.py @@ -84,8 +84,16 @@ def is_node_tosa_supported( f"{node.target}: dtype {values_dtype} requires INT profile.", ) return False - # fp16/fp32: either FP profile, or INT profile (via quantization) - elif values_dtype in (torch.float16, torch.float32): + # fp16/fp32/bf16: either FP profile, or INT profile (via quantization) + elif values_dtype in (torch.float16, torch.float32, torch.bfloat16): + if values_dtype == torch.bfloat16 and not tosa_spec.support_extension( + "bf16" + ): + self.reporter.report_reject( + node, + f"{node.target}: dtype {values_dtype} requires bf16 extension.", + ) + return False if not (tosa_spec.support_float() or tosa_spec.support_integer()): self.reporter.report_reject( node, @@ -97,7 +105,7 @@ def is_node_tosa_supported( self.reporter.report_reject( node, f"{node.target}: unsupported values dtype {values_dtype}; " - "expected bool/int8/int16/int32/float16/float32.", + "expected bool/int8/int16/int32/float16/bfloat16/float32.", ) return False diff --git a/backends/arm/scripts/build_executorch.sh b/backends/arm/scripts/build_executorch.sh index 54d2091d1f4..5ac2674f964 100755 --- a/backends/arm/scripts/build_executorch.sh +++ b/backends/arm/scripts/build_executorch.sh @@ -7,6 +7,7 @@ # Optional parameter: # --build_type= "Release" | "Debug" | "RelWithDebInfo" | "UndefinedSanitizer" | "AddressSanitizer" # --etdump build with devtools-etdump support +# --cmake-args= Additional arguments passed to cmake configure set -eu @@ -24,6 +25,7 @@ build_type="Release" build_devtools=OFF build_with_etdump=OFF is_linux_musl=0 +extra_cmake_args=() target_cpu="" help() { @@ -33,6 +35,7 @@ help() { echo " --build_type= Build with Release, Debug, RelWithDebInfo, UndefinedSanitizer or AddressSanitizer, default is ${build_type}" echo " --devtools Build Devtools libs" echo " --etdump Adds Devtools etdump support to track timing, etdump area will be base64 encoded in the log" + echo " --cmake-args= Additional arguments passed to cmake configure" echo " --toolchain= Toolchain can be specified (arm-none-eabi-gcc, arm-zephyr-eabi-gcc, aarch64-linux-musl-gcc). Default: ${toolchain}" echo " --target_cpu= Override the toolchain's default TARGET_CPU (e.g. cortex-m4). Switching target_cpu reuses the same cmake-out dir, so clear ${et_build_root}/cmake-out first to avoid stale per-CPU artifacts. Default: unset (toolchain default)." exit 0 @@ -45,6 +48,10 @@ for arg in "$@"; do --build_type=*) build_type="${arg#*=}";; --devtools) build_devtools=ON ;; --etdump) build_with_etdump=ON ;; + --cmake-args=*) + # shellcheck disable=SC2206 + extra_cmake_args=(${arg#*=}) + ;; --toolchain=*) toolchain="${arg#*=}";; --target_cpu=*) target_cpu="${arg#*=}";; *) @@ -89,6 +96,7 @@ cmake_args=( -DEXECUTORCH_BUILD_DEVTOOLS=${build_devtools} -DEXECUTORCH_BUILD_ARM_ETDUMP=${build_with_etdump} -DEXECUTORCH_BAREMETAL_SKIP_INSTALL=OFF + "${extra_cmake_args[@]}" ) if [[ -n "${target_cpu}" ]]; then diff --git a/backends/arm/scripts/pre-push b/backends/arm/scripts/pre-push index 8e26463cd94..6aa32d07286 100755 --- a/backends/arm/scripts/pre-push +++ b/backends/arm/scripts/pre-push @@ -177,7 +177,7 @@ for COMMIT in ${COMMITS}; do for committed_file in "${license_files[@]}"; do # Skip files with certain extensions case "$committed_file" in - *.md|*.md.in|*.json|*.yml|*.yaml|*.cmake|*.patch|.gitignore|*.bzl) + *.md|*.md.in|*.json|*.yml|*.yaml|*.cmake|*.patch|.gitignore|*.bzl|BUCK|*/BUCK|TARGETS|*/TARGETS) echo -e "${INFO} Skipping license check for ${committed_file} (excluded extension)" continue ;; diff --git a/backends/arm/test/ops/test_index_select.py b/backends/arm/test/ops/test_index_select.py index bb5f0a92c51..4de19d30daf 100644 --- a/backends/arm/test/ops/test_index_select.py +++ b/backends/arm/test/ops/test_index_select.py @@ -61,6 +61,26 @@ def forward(self, input_: torch.Tensor, dim: int, index_: torch.Tensor): torch.tensor([3, 1], dtype=torch.int32), # [W=2] ), } +test_data_fp_bf16: dict[str, input_params] = { + # Rank-2: [K, C] -> index_select dim=0 => [W, C] + "test_bf16_rank2_dim0": ( + torch.tensor( + [[0.5, 1.25, 2.5], [3.5, 4.25, 5.75], [6.5, 7.25, 8.75]], + dtype=torch.bfloat16, + ), # [K=3, C=3] + 0, + torch.tensor([2, 0], dtype=torch.int32), # [W=2] + ), + # Rank-3: [N, K, C] -> index_select dim=-1 => [N, K, W] + "test_bf16_rank3_dim_neg1": ( + torch.tensor( + [[[0.5, 1.5], [2.5, 3.5]], [[4.5, 5.5], [6.5, 7.5]]], + dtype=torch.bfloat16, + ), # [N=2, K=2, C=2] + -1, + torch.tensor([1, 0], dtype=torch.int32), # [W=2] + ), +} # ---- INT profile: integer inputs + bool ---- test_data_int: dict[str, input_params] = { @@ -104,6 +124,18 @@ def test_index_select_tosa_FP(test_data: input_params): pipeline.run() +@common.parametrize("test_data", test_data_fp_bf16) +def test_index_select_tosa_FP_bf16(test_data: input_params): + pipeline = TosaPipelineFP[input_params]( + IndexSelect(), + test_data, + aten_op=IndexSelect.aten_op, + exir_op=IndexSelect.exir_op, + tosa_extensions=["bf16"], + ) + pipeline.run() + + @common.parametrize("test_data", test_data_int | test_data_fp) def test_index_select_tosa_INT(test_data: input_params): # INT profile runs quantized, so we test both int inputs and float inputs here. diff --git a/backends/arm/test/ops/test_unfold_copy.py b/backends/arm/test/ops/test_unfold_copy.py index 2b502a9be10..baa4b7f64bc 100644 --- a/backends/arm/test/ops/test_unfold_copy.py +++ b/backends/arm/test/ops/test_unfold_copy.py @@ -120,6 +120,18 @@ def forward(self, input_: torch.Tensor, dim_: int, size_: int, step_: int): ), } +test_data_bf16: dict[str, input_params] = { + "test_bf16_2d_dim1": ( + torch.tensor( + [[0.1, 0.2, 0.3, 0.4, 0.5], [1.1, 1.2, 1.3, 1.4, 1.5]], + dtype=torch.bfloat16, + ), # [B=2, T=5] + 1, + 3, + 2, # U=(5-3)//2+1=2 -> [B=2, U=2, C=3] + ), +} + @common.parametrize("test_data", test_data_fp) def test_unfold_copy_tosa_FP(test_data: input_params): @@ -132,6 +144,18 @@ def test_unfold_copy_tosa_FP(test_data: input_params): pipeline.run() +@common.parametrize("test_data", test_data_bf16) +def test_unfold_copy_tosa_FP_bf16(test_data: input_params): + pipeline = TosaPipelineFP[input_params]( + UnfoldCopy(), + test_data, + aten_op=UnfoldCopy.aten_op, + exir_op=UnfoldCopy.exir_op, + tosa_extensions=["bf16"], + ) + pipeline.run() + + @common.parametrize("test_data", test_data_int | test_data_fp) def test_unfold_copy_tosa_INT(test_data: input_params): pipeline = TosaPipelineINT[input_params]( diff --git a/backends/cortex_m/CMakeLists.txt b/backends/cortex_m/CMakeLists.txt index 876c65982e6..627406c1935 100644 --- a/backends/cortex_m/CMakeLists.txt +++ b/backends/cortex_m/CMakeLists.txt @@ -30,6 +30,10 @@ set(CMSIS_NN_LOCAL_PATH "" CACHE PATH "Path to existing local CMSIS-NN installation" ) +option(CORTEX_M_ENABLE_RUNTIME_CHECKS + "Enable additional Cortex-M runtime assertions and validation checks" + OFF +) # Try to find existing / local CMSIS-NN installation. This is useful for # debugging and testing with local changes. This is not common, as the CMSIS-NN @@ -107,6 +111,11 @@ target_link_libraries( PRIVATE executorch PRIVATE kernels_util_all_deps ) +target_compile_definitions( + cortex_m_kernels + PRIVATE + $<$:CORTEX_M_ENABLE_RUNTIME_CHECKS> +) # Include directories for cortex_m_kernels target_include_directories( diff --git a/backends/cortex_m/ops/op_quantized_batch_matmul.cpp b/backends/cortex_m/ops/op_quantized_batch_matmul.cpp index e6bc5a949ce..345753ca8fc 100644 --- a/backends/cortex_m/ops/op_quantized_batch_matmul.cpp +++ b/backends/cortex_m/ops/op_quantized_batch_matmul.cpp @@ -1,6 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2026 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -71,6 +72,7 @@ Tensor& quantized_batch_matmul_out( int64_t output_offset, int64_t output_multiplier, int64_t output_shift, + const Tensor& scratch, Tensor& out) { if (!validate_batch_matmul_arguments(context, lhs, rhs_transposed, out)) { return out; @@ -100,25 +102,26 @@ Tensor& quantized_batch_matmul_out( quant_params.multiplier = static_cast(output_multiplier); quant_params.shift = static_cast(output_shift); - const int32_t buf_size = arm_fully_connected_s8_get_buffer_size(&out_dims); - cmsis_nn_context ctx; ctx.buf = nullptr; - ctx.size = 0; - - if (buf_size > 0) { - auto buffer_or_error = context.allocate_temp(buf_size); - if (!buffer_or_error.ok()) { - ET_LOG( - Error, - "quantized_batch_matmul: failed to allocate scratch buffer (%d bytes)", - buf_size); - context.fail(buffer_or_error.error()); - return out; - } - ctx.buf = buffer_or_error.get(); - ctx.size = buf_size; + ctx.size = scratch.nbytes(); + if (ctx.size > 0) { + ctx.buf = scratch.mutable_data_ptr(); + } + +#ifdef CORTEX_M_ENABLE_RUNTIME_CHECKS + const int32_t runtime_buffer_bytes = + arm_fully_connected_s8_get_buffer_size(&out_dims); + if (ctx.size != static_cast(runtime_buffer_bytes)) { + ET_LOG( + Error, + "quantized_batch_matmul: scratch buffer size incorrect - actual: (%d) needed: (%d)", + static_cast(ctx.size), + runtime_buffer_bytes); + context.fail(Error::Internal); + return out; } +#endif const arm_cmsis_nn_status status = arm_batch_matmul_s8( &ctx, diff --git a/backends/cortex_m/ops/op_quantized_conv2d.cpp b/backends/cortex_m/ops/op_quantized_conv2d.cpp index 7d4433690f6..8af374c03f8 100644 --- a/backends/cortex_m/ops/op_quantized_conv2d.cpp +++ b/backends/cortex_m/ops/op_quantized_conv2d.cpp @@ -112,6 +112,7 @@ Tensor& quantized_conv2d_out( const Tensor& requantize_shifts, const int64_t activation_min, const int64_t activation_max, + const Tensor& scratch, Tensor& out) { if (!validate_conv2d_arguments( context, @@ -182,31 +183,30 @@ Tensor& quantized_conv2d_out( cmsis_nn_context cmsis_context; cmsis_context.buf = nullptr; - cmsis_context.size = 0; + cmsis_context.size = scratch.nbytes(); + if (cmsis_context.size > 0) { + cmsis_context.buf = scratch.mutable_data_ptr(); + } - const int32_t buffer_bytes = arm_convolve_wrapper_s8_get_buffer_size( +#ifdef CORTEX_M_ENABLE_RUNTIME_CHECKS + const int32_t runtime_buffer_bytes = arm_convolve_wrapper_s8_get_buffer_size( &conv_params, &input_dims, &filter_dims, &output_dims); - if (buffer_bytes < 0) { + if (runtime_buffer_bytes < 0) { ET_LOG( Error, "quantized_conv2d_out: CMSIS-NN buffer size calculation failed"); context.fail(Error::Internal); return out; } - if (buffer_bytes > 0) { - auto buffer_or_error = - context.allocate_temp(buffer_bytes, kCortexMMveAlignment); - if (!buffer_or_error.ok()) { - ET_LOG( - Error, - "quantized_conv2d_out: failed to allocate scratch buffer (%d bytes, error %d)", - static_cast(buffer_bytes), - static_cast(buffer_or_error.error())); - context.fail(buffer_or_error.error()); - return out; - } - cmsis_context.buf = buffer_or_error.get(); - cmsis_context.size = buffer_bytes; + if (scratch.nbytes() != static_cast(runtime_buffer_bytes)) { + ET_LOG( + Error, + "quantized_conv2d_out: scratch buffer size incorrect - actual: (%d) needed: (%d)", + static_cast(scratch.nbytes()), + static_cast(runtime_buffer_bytes)); + context.fail(Error::Internal); + return out; } +#endif const arm_cmsis_nn_status status = arm_convolve_wrapper_s8( &cmsis_context, diff --git a/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp b/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp index 8dec61e0af1..21d4f257501 100644 --- a/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp +++ b/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp @@ -150,6 +150,7 @@ Tensor& quantized_depthwise_conv2d_out( const Tensor& requantize_shifts, const int64_t activation_min, const int64_t activation_max, + const Tensor& scratch, Tensor& out) { if (!validate_depthwise_conv2d_arguments( context, @@ -220,32 +221,32 @@ Tensor& quantized_depthwise_conv2d_out( cmsis_nn_context cmsis_context; cmsis_context.buf = nullptr; - cmsis_context.size = 0; + cmsis_context.size = scratch.nbytes(); + if (cmsis_context.size > 0) { + cmsis_context.buf = scratch.mutable_data_ptr(); + } - const int32_t buffer_bytes = arm_depthwise_conv_wrapper_s8_get_buffer_size( - &dw_conv_params, &input_dims, &filter_dims, &output_dims); - if (buffer_bytes < 0) { +#ifdef CORTEX_M_ENABLE_RUNTIME_CHECKS + const int32_t runtime_buffer_bytes = + arm_depthwise_conv_wrapper_s8_get_buffer_size( + &dw_conv_params, &input_dims, &filter_dims, &output_dims); + if (runtime_buffer_bytes < 0) { ET_LOG( Error, "quantized_depthwise_conv2d_out: CMSIS-NN buffer size calculation failed"); context.fail(Error::Internal); return out; } - - auto buffer_or_error = context.allocate_temp( - static_cast(buffer_bytes), kCortexMMveAlignment); - if (!buffer_or_error.ok()) { + if (scratch.nbytes() != static_cast(runtime_buffer_bytes)) { ET_LOG( Error, - "quantized_depthwise_conv2d_out: failed to allocate scratch buffer (%d bytes, error %d)", - static_cast(buffer_bytes), - static_cast(buffer_or_error.error())); - context.fail(buffer_or_error.error()); + "quantized_depthwise_conv2d_out: scratch buffer size incorrect - actual: (%d) needed: (%d)", + static_cast(scratch.nbytes()), + static_cast(runtime_buffer_bytes)); + context.fail(Error::Internal); return out; } - cmsis_context.buf = buffer_or_error.get(); - cmsis_context.size = buffer_bytes; - +#endif const arm_cmsis_nn_status status = arm_depthwise_conv_wrapper_s8( &cmsis_context, &dw_conv_params, diff --git a/backends/cortex_m/ops/op_quantized_transpose_conv2d.cpp b/backends/cortex_m/ops/op_quantized_transpose_conv2d.cpp index e3f6135c7b9..d2b66b18802 100644 --- a/backends/cortex_m/ops/op_quantized_transpose_conv2d.cpp +++ b/backends/cortex_m/ops/op_quantized_transpose_conv2d.cpp @@ -1,6 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2026 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -97,6 +98,8 @@ Tensor& quantized_transpose_conv2d_out( const Tensor& requantize_shifts, const int64_t activation_min, const int64_t activation_max, + const Tensor& scratch, + const Tensor& output_scratch, Tensor& out) { if (!validate_transpose_conv2d_arguments( context, @@ -179,44 +182,43 @@ Tensor& quantized_transpose_conv2d_out( cmsis_nn_context cmsis_context; cmsis_context.buf = nullptr; - cmsis_context.size = 0; + cmsis_context.size = scratch.nbytes(); + if (cmsis_context.size > 0) { + cmsis_context.buf = scratch.mutable_data_ptr(); + } cmsis_nn_context output_context; output_context.buf = nullptr; - output_context.size = 0; - + output_context.size = output_scratch.nbytes(); + if (output_context.size > 0) { + output_context.buf = output_scratch.mutable_data_ptr(); + } +#ifdef CORTEX_M_ENABLE_RUNTIME_CHECKS const int32_t buffer_bytes = arm_transpose_conv_s8_get_buffer_size( &transpose_conv_params, &input_dims, &filter_dims, &output_dims); - auto buffer_or_error = context.allocate_temp( - static_cast(buffer_bytes), kCortexMMveAlignment); - if (!buffer_or_error.ok()) { + if (scratch.nbytes() != static_cast(buffer_bytes)) { ET_LOG( Error, - "quantized_transpose_conv2d_out: failed to allocate scratch buffer (%d bytes, error %d)", - buffer_bytes, - static_cast(buffer_or_error.error())); - context.fail(buffer_or_error.error()); + "quantized_transpose_conv2d_out: scratch buffer size incorrect - actual: (%d) needed: (%d)", + static_cast(scratch.nbytes()), + buffer_bytes); + context.fail(Error::Internal); return out; } - cmsis_context.buf = buffer_or_error.get(); - cmsis_context.size = buffer_bytes; const int32_t output_buffer_bytes = arm_transpose_conv_s8_get_reverse_conv_buffer_size( &transpose_conv_params, &input_dims, &filter_dims); - auto output_buffer_or_error = context.allocate_temp( - static_cast(output_buffer_bytes), kCortexMMveAlignment); - if (!output_buffer_or_error.ok()) { + if (output_scratch.nbytes() != static_cast(output_buffer_bytes)) { ET_LOG( Error, - "quantized_transpose_conv2d_out: failed to allocate output scratch buffer (%d bytes, error %d)", - output_buffer_bytes, - static_cast(output_buffer_or_error.error())); - context.fail(output_buffer_or_error.error()); + "quantized_transpose_conv2d_out: output scratch buffer size incorrect - actual: (%d) needed: (%d)", + static_cast(output_scratch.nbytes()), + output_buffer_bytes); + context.fail(Error::Internal); return out; } - output_context.buf = output_buffer_or_error.get(); - output_context.size = output_buffer_bytes; +#endif const arm_cmsis_nn_status status = arm_transpose_conv_wrapper_s8( &cmsis_context, diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index 2c35ed8730b..d4393bc7ada 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -271,13 +271,15 @@ def quantized_mul_impl( "quantized_batch_matmul(" "Tensor lhs, int lhs_zero_point, " "Tensor rhs_transposed, int rhs_zero_point, " - "int output_zero_point, int output_multiplier, int output_shift) -> Tensor" + "int output_zero_point, int output_multiplier, int output_shift, " + "Tensor scratch) -> Tensor" ) lib.define( "quantized_batch_matmul.out(" "Tensor lhs, int lhs_zero_point, " "Tensor rhs_transposed, int rhs_zero_point, " "int output_zero_point, int output_multiplier, int output_shift, " + "Tensor scratch, " "*, Tensor(a!) out) -> Tensor(a!)" ) @@ -291,6 +293,7 @@ def quantized_batch_matmul_meta( output_zero_point: int, output_multiplier: int, output_shift: int, + scratch: torch.Tensor, ) -> torch.Tensor: batch, lhs_rows, inner = lhs.shape batch_rhs, rhs_cols, inner_rhs = rhs_transposed.shape @@ -307,6 +310,7 @@ def quantized_batch_matmul_impl( output_zero_point: int, output_multiplier: int, output_shift: int, + scratch: torch.Tensor, ) -> torch.Tensor: # Offsets are negated zero points (CMSIS-NN convention) lhs_fp = lhs.to(torch.float32) + float(lhs_zero_point) @@ -638,7 +642,8 @@ def pad_impl( "Tensor requantize_multipliers, " "Tensor requantize_shifts, " "int activation_min, " - "int activation_max" + "int activation_max, " + "Tensor scratch" ") -> Tensor" ) @@ -657,6 +662,7 @@ def pad_impl( "Tensor requantize_shifts, " "int activation_min, " "int activation_max, " + "Tensor scratch, " "*, Tensor(a!) out" ") -> Tensor(a!)" ) @@ -733,6 +739,7 @@ def quantized_conv2d_meta( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, ) -> torch.Tensor: stride_vals = list(stride) padding_vals = list(padding) @@ -762,6 +769,7 @@ def quantized_conv2d_impl( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, ) -> torch.Tensor: if input.dim() != 4 or weight.dim() != 4: raise RuntimeError("quantized_conv2d expects 4D input and weight tensors") @@ -830,7 +838,8 @@ def quantized_conv2d_impl( "Tensor requantize_multipliers, " "Tensor requantize_shifts, " "int activation_min, " - "int activation_max" + "int activation_max, " + "Tensor scratch" ") -> Tensor" ) @@ -850,6 +859,7 @@ def quantized_conv2d_impl( "Tensor requantize_shifts, " "int activation_min, " "int activation_max, " + "Tensor scratch, " "*, Tensor(a!) out" ") -> Tensor(a!)" ) @@ -870,6 +880,7 @@ def quantized_depthwise_conv2d_meta( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, ) -> torch.Tensor: stride_vals = list(stride) padding_vals = list(padding) @@ -900,6 +911,7 @@ def quantized_depthwise_conv2d_impl( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, ) -> torch.Tensor: if input.dim() != 4 or weight.dim() != 4: raise RuntimeError( @@ -973,7 +985,9 @@ def quantized_depthwise_conv2d_impl( "Tensor requantize_multipliers, " "Tensor requantize_shifts, " "int activation_min, " - "int activation_max" + "int activation_max, " + "Tensor scratch, " + "Tensor output_scratch" ") -> Tensor" ) @@ -992,6 +1006,8 @@ def quantized_depthwise_conv2d_impl( "Tensor requantize_shifts, " "int activation_min, " "int activation_max, " + "Tensor scratch, " + "Tensor output_scratch, " "*, Tensor(a!) out) -> Tensor(a!)" ) @@ -1057,6 +1073,8 @@ def quantized_transpose_conv2d_meta( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, + output_scratch: torch.Tensor, ) -> torch.Tensor: stride_vals = list(stride) padding_vals = list(padding) @@ -1095,6 +1113,8 @@ def quantized_transpose_conv2d_impl( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, + output_scratch: torch.Tensor, ) -> torch.Tensor: """ Reference implementation of quantized transposed convolution. diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml index e0ebbfab868..8db109dea43 100644 --- a/backends/cortex_m/ops/operators.yaml +++ b/backends/cortex_m/ops/operators.yaml @@ -65,19 +65,20 @@ - arg_meta: null kernel_name: cortex_m::pad_out -- func: cortex_m::quantized_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!) +- func: cortex_m::quantized_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, Tensor scratch, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null kernel_name: cortex_m::quantized_conv2d_out -- func: cortex_m::quantized_depthwise_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int depth_multiplier, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!) + +- func: cortex_m::quantized_depthwise_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int depth_multiplier, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, Tensor scratch, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null kernel_name: cortex_m::quantized_depthwise_conv2d_out -- func: cortex_m::quantized_transpose_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!) +- func: cortex_m::quantized_transpose_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, Tensor scratch, Tensor output_scratch, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null @@ -94,7 +95,7 @@ - arg_meta: null kernel_name: cortex_m::quantized_max_pool2d_out -- func: cortex_m::quantized_batch_matmul.out(Tensor lhs, int lhs_zero_point, Tensor rhs_transposed, int rhs_zero_point, int output_zero_point, int output_multiplier, int output_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cortex_m::quantized_batch_matmul.out(Tensor lhs, int lhs_zero_point, Tensor rhs_transposed, int rhs_zero_point, int output_zero_point, int output_multiplier, int output_shift, Tensor scratch, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null diff --git a/backends/cortex_m/passes/__init__.py b/backends/cortex_m/passes/__init__.py index 92179ec6654..c379461949f 100644 --- a/backends/cortex_m/passes/__init__.py +++ b/backends/cortex_m/passes/__init__.py @@ -33,6 +33,7 @@ def _ensure_cortex_m_dependencies() -> None: _ensure_cortex_m_dependencies() +from .cortex_m_pass import CortexMPass # noqa # usort: skip from .activation_fusion_pass import ActivationFusionPass # noqa from .clamp_hardswish_pass import ClampHardswishPass # noqa from .convert_to_cortex_m_pass import ConvertToCortexMPass # noqa diff --git a/backends/cortex_m/passes/convert_to_cortex_m_pass.py b/backends/cortex_m/passes/convert_to_cortex_m_pass.py index 418f6cd63ff..e61ddaf63bc 100644 --- a/backends/cortex_m/passes/convert_to_cortex_m_pass.py +++ b/backends/cortex_m/passes/convert_to_cortex_m_pass.py @@ -6,25 +6,32 @@ # LICENSE file in the root directory of this source tree. import executorch.backends.cortex_m.ops.operators # noqa +import executorch.exir as exir import torch import torch.fx from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor + +from executorch.backends.cortex_m.passes import CortexMPass from executorch.backends.cortex_m.passes.passes_utils import quantize_multiplier_aot +from executorch.backends.cortex_m.passes.scratch_buffer_sizes import ( + required_cmsis_nn_buffer_sizes, +) from executorch.backends.transforms.utils import ( create_constant_placeholder, get_param_tensor, is_param_node, ) - -from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.passes import make_alloc_node +from torch._subclasses.fake_tensor import FakeTensorMode + from torch.export.graph_signature import InputKind from torch.fx.passes.infra.pass_manager import PassResult -class ConvertToCortexMPass(XNNPACKPass): +class ConvertToCortexMPass(CortexMPass): """ Cortex-M backend pass for replacing supported quantized kernels with Cortex-M accelerated kernels. @@ -33,6 +40,15 @@ class ConvertToCortexMPass(XNNPACKPass): by call_operator. """ + def _create_uninitialized_alloc_node(self): + """Create an unitialized alloc node to be initialize at a later point.""" + with FakeTensorMode() as mode: + return make_alloc_node( + self.exported_program.graph_module, + mode.from_tensor(torch.empty(0)), + None, + ) + def _compute_kernel_sum(self, weights, bias, input_offset, weight_offset): """ Computes the precomputed kernel sum term (bias optional) @@ -238,6 +254,9 @@ def _get_convolution_replacement(self, node): torch.tensor(quantized_shifts, dtype=torch.int32), ) + with node.graph.inserting_before(node): + scratch = self._create_uninitialized_alloc_node() + if use_depthwise_conv: # Compute depth_multiplier for depthwise convolution # For depthwise: output_channels = input_channels * depth_multiplier @@ -263,6 +282,7 @@ def _get_convolution_replacement(self, node): quantized_shift_tensor, output_qmin, output_qmax, + scratch, ) return exir_ops.edge.cortex_m.quantized_depthwise_conv2d.default, new_args else: @@ -280,9 +300,36 @@ def _get_convolution_replacement(self, node): quantized_shift_tensor, output_qmin, output_qmax, + scratch, ) return exir_ops.edge.cortex_m.quantized_conv2d.default, new_args + def _initialize_alloc_node_size(self, node: torch.fx.Node) -> None: + """For nodes with a registered buffer size function for node.target, set the buffer sizes + of the last n args, which should be exir.memory.alloc nodes. For nodes without a + registered function, do nothing. + """ + + scratch_buffer_sizes = required_cmsis_nn_buffer_sizes( + node, self.target_config.backend + ) + if scratch_buffer_sizes is None: + return + + # Assume that scratch_buffer_sizes are given from left to right in the call signature of node.target. + for i, scratch_buffer_size in enumerate(reversed(scratch_buffer_sizes)): + scratch_arg = node.args[-(i + 1)] + if ( + not isinstance(scratch_arg, torch.fx.Node) + or scratch_arg.target != exir.memory.alloc + ): + raise RuntimeError( + f"Expected scratch alloc node as final argument(s) for {node.target}, got {scratch_arg}." + ) + + # buffer size is given in bytes, always use uint8 as dtype. + scratch_arg.args = (((scratch_buffer_size,), torch.uint8),) + def _get_transpose_conv2d_replacement(self, node): """ Transform aten.convolution with transposed=True to cortex_m.quantized_transpose_conv2d @@ -363,6 +410,10 @@ def _get_transpose_conv2d_replacement(self, node): torch.tensor(quantized_shifts, dtype=torch.int32), ) + with node.graph.inserting_before(node): + scratch = self._create_uninitialized_alloc_node() + output_scratch = self._create_uninitialized_alloc_node() + new_args = ( x, weight_nhwc, @@ -377,6 +428,8 @@ def _get_transpose_conv2d_replacement(self, node): quantized_shift_tensor, output_qmin, output_qmax, + scratch, + output_scratch, ) return exir_ops.edge.cortex_m.quantized_transpose_conv2d.default, new_args @@ -415,6 +468,9 @@ def _get_bmm_replacement(self, node): args=(rhs_node, [0, 2, 1]), ) + with node.graph.inserting_before(node): + scratch = self._create_uninitialized_alloc_node() + args = ( lhs_node, -lhs_zp, @@ -423,6 +479,7 @@ def _get_bmm_replacement(self, node): output_zp, output_mult, output_shift, + scratch, ) return exir_ops.edge.cortex_m.quantized_batch_matmul.default, args @@ -459,6 +516,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: args=args, kwargs={}, ) + self._initialize_alloc_node_size(cortex_m_op) node.replace_all_uses_with(cortex_m_op) graph_module.graph.erase_node(node) diff --git a/backends/cortex_m/passes/scratch_buffer_sizes.py b/backends/cortex_m/passes/scratch_buffer_sizes.py new file mode 100644 index 00000000000..36f3f8bbc17 --- /dev/null +++ b/backends/cortex_m/passes/scratch_buffer_sizes.py @@ -0,0 +1,266 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Callable +from typing import Any, cast + +import cmsis_nn # type: ignore[import-not-found, import-untyped] +import executorch.backends.cortex_m.ops.operators # noqa + +import torch +import torch.fx + +from executorch.exir.dialects._ops import ops as exir_ops + +BufferSizeFunction = Callable[[cmsis_nn.Backend, torch.fx.Node], list[int]] + + +def _tensor_from_node(node: torch.fx.Node) -> torch.Tensor: + if "val" in node.meta: + return node.meta["val"] + elif node.op == "call_function": + args = ( + _tensor_from_node(arg) if isinstance(arg, torch.fx.Node) else arg + for arg in node.args + ) + return node.target(*args, **node.kwargs) # type: ignore[operator] + else: + raise RuntimeError("Encountered non-call_function without 'val' meta.") + + +def _shape_from_node(node: torch.fx.Node) -> torch.Size: + return _tensor_from_node(node).shape + + +def _get_common_conv_buffer_size_inputs( + conv_node: torch.fx.Node, + *, + stride_arg_idx: int = 3, + padding_arg_idx: int = 4, + dilation_arg_idx: int = 5, +) -> tuple[ + list[int], + list[int], + list[int], + list[int], + list[int], + list[int], +]: + x = cast(torch.fx.Node, conv_node.args[0]) + weight = cast(torch.fx.Node, conv_node.args[1]) + stride = cast(list[int], conv_node.args[stride_arg_idx]) + padding = cast(list[int], conv_node.args[padding_arg_idx]) + dilation = cast(list[int], conv_node.args[dilation_arg_idx]) + + # Input is NCHW (PyTorch); CMSIS-NN wants NHWC dims. + n, c_in, height, width = _shape_from_node(x) + + weight_shape = _shape_from_node(weight) + + # Output is NCHW; convert to NHWC dims. + out_n, out_c, out_h, out_w = _shape_from_node(conv_node) + + input_nhwc = [n, height, width, c_in] + output_nhwc = [out_n, out_h, out_w, out_c] + stride_hw = [int(stride[0]), int(stride[1])] + padding_hw = [int(padding[0]), int(padding[1])] + dilation_hw = [int(dilation[0]), int(dilation[1])] + + return ( + input_nhwc, + list(weight_shape), + output_nhwc, + stride_hw, + padding_hw, + dilation_hw, + ) + + +def cmsis_nn_conv_buffer_size( + backend: cmsis_nn.Backend, + conv_node: torch.fx.Node, +) -> list[int]: + ( + input_nhwc, + weight_shape, + output_nhwc, + stride_hw, + padding_hw, + dilation_hw, + ) = _get_common_conv_buffer_size_inputs(conv_node=conv_node) + input_offset = cast(int, conv_node.args[6]) + output_offset = cast(int, conv_node.args[7]) + output_qmin = cast(int, conv_node.args[10]) + output_qmax = cast(int, conv_node.args[11]) + + # Weight is in OHWI layout after conversion. + c_out, kernel_h, kernel_w, c_in = weight_shape + filter_nhwc = [c_out, kernel_h, kernel_w, c_in] + + return [ + int( + cmsis_nn.convolve_wrapper_buffer_size( + backend, + cmsis_nn.DataType.A8W8, + input_nhwc=input_nhwc, + filter_nhwc=filter_nhwc, + output_nhwc=output_nhwc, + padding_hw=padding_hw, + stride_hw=stride_hw, + dilation_hw=dilation_hw, + input_offset=input_offset, + output_offset=output_offset, + activation_min=output_qmin, + activation_max=output_qmax, + ) + ) + ] + + +def cmsis_nn_depthwise_conv_buffer_size( + backend: cmsis_nn.Backend, + conv_node: torch.fx.Node, +) -> list[int]: + ( + input_nhwc, + weight_shape, + output_nhwc, + stride_hw, + padding_hw, + dilation_hw, + ) = _get_common_conv_buffer_size_inputs(conv_node=conv_node) + depth_multiplier = cast(int, conv_node.args[6]) + input_offset = cast(int, conv_node.args[7]) + output_offset = cast(int, conv_node.args[8]) + output_qmin = cast(int, conv_node.args[11]) + output_qmax = cast(int, conv_node.args[12]) + + # Weight is in IHWO layout after conversion. + _, kernel_h, kernel_w, c_out = weight_shape + filter_nhwc = [c_out, kernel_h, kernel_w, 1] + + return [ + int( + cmsis_nn.depthwise_conv_wrapper_buffer_size( + backend, + cmsis_nn.DataType.A8W8, + input_nhwc=input_nhwc, + filter_nhwc=filter_nhwc, + output_nhwc=output_nhwc, + padding_hw=padding_hw, + stride_hw=stride_hw, + dilation_hw=dilation_hw, + ch_mult=depth_multiplier, + input_offset=input_offset, + output_offset=output_offset, + activation_min=output_qmin, + activation_max=output_qmax, + ) + ) + ] + + +def cmsis_nn_batch_matmul_buffer_size( + backend: cmsis_nn.Backend, + matmul_node: torch.fx.Node, +) -> list[int]: + rhs_transposed = cast(torch.fx.Node, matmul_node.args[2]) + rhs_shape = _shape_from_node(rhs_transposed) + + _, rhs_cols, inner = rhs_shape + + return [ + int( + cmsis_nn.fully_connected_buffer_size( + backend, + cmsis_nn.DataType.A8W8, + filter_nhwc=[inner, -1, -1, rhs_cols], # H and W values are unused. + ) + ) + ] + + +def cmsis_nn_transpose_conv_buffer_size( + backend: cmsis_nn.Backend, + conv_node: torch.fx.Node, +) -> list[int]: + ( + input_nhwc, + weight_shape, + output_nhwc, + stride_hw, + padding_hw, + dilation_hw, + ) = _get_common_conv_buffer_size_inputs( + conv_node=conv_node, + stride_arg_idx=3, + padding_arg_idx=4, + dilation_arg_idx=6, + ) + output_padding = cast(list[int], conv_node.args[5]) + input_offset = cast(int, conv_node.args[7]) + output_offset = cast(int, conv_node.args[8]) + output_qmin = cast(int, conv_node.args[11]) + output_qmax = cast(int, conv_node.args[12]) + c_out, kernel_h, kernel_w, kernel_c_in = weight_shape + filter_nhwc = [c_out, kernel_h, kernel_w, kernel_c_in] + padding_offsets_hw = [int(output_padding[0]), int(output_padding[1])] + + return [ + int( + cmsis_nn.transpose_conv_buffer_size( + backend, + cmsis_nn.DataType.A8W8, + input_nhwc=input_nhwc, + filter_nhwc=filter_nhwc, + output_nhwc=output_nhwc, + padding_hw=padding_hw, + stride_hw=stride_hw, + dilation_hw=dilation_hw, + padding_offsets_hw=padding_offsets_hw, + input_offset=input_offset, + output_offset=output_offset, + activation_min=output_qmin, + activation_max=output_qmax, + ) + ), + int( + cmsis_nn.transpose_conv_reverse_conv_buffer_size( + backend, + cmsis_nn.DataType.A8W8, + input_nhwc=input_nhwc, + filter_nhwc=filter_nhwc, + padding_hw=padding_hw, + stride_hw=stride_hw, + dilation_hw=dilation_hw, + padding_offsets_hw=padding_offsets_hw, + input_offset=input_offset, + output_offset=output_offset, + activation_min=output_qmin, + activation_max=output_qmax, + ) + ), + ] + + +_target_to_buffer_sizes_registry: dict[Any, BufferSizeFunction] = { + exir_ops.edge.cortex_m.quantized_conv2d.default: cmsis_nn_conv_buffer_size, + exir_ops.edge.cortex_m.quantized_depthwise_conv2d.default: cmsis_nn_depthwise_conv_buffer_size, + exir_ops.edge.cortex_m.quantized_batch_matmul.default: cmsis_nn_batch_matmul_buffer_size, + exir_ops.edge.cortex_m.quantized_transpose_conv2d.default: cmsis_nn_transpose_conv_buffer_size, +} + + +def required_cmsis_nn_buffer_sizes( + node: torch.fx.Node, backend: cmsis_nn.Backend +) -> list[int] | None: + """Returns a sequence of scratch buffer sizes required by node, in bytes. + If no function is registered to compute this for the target of the node, return None. + """ + if node.target not in _target_to_buffer_sizes_registry: + return None + + buffer_size_function = _target_to_buffer_sizes_registry[node.target] + return buffer_size_function(backend, node) diff --git a/backends/cortex_m/test/build_test_runner.sh b/backends/cortex_m/test/build_test_runner.sh index bdca1a21e7c..a67c5a907a4 100755 --- a/backends/cortex_m/test/build_test_runner.sh +++ b/backends/cortex_m/test/build_test_runner.sh @@ -28,7 +28,7 @@ fi script_dir=$(realpath "$(dirname "${BASH_SOURCE[0]}")") et_root_dir=$(realpath "${script_dir}/../../..") build_executorch="${et_root_dir}/backends/arm/scripts/build_executorch.sh" -${build_executorch} --devtools --target_cpu="${target_cpu}" +${build_executorch} --devtools --target_cpu="${target_cpu}" --cmake-args="-DCORTEX_M_ENABLE_RUNTIME_CHECKS=ON" # Build executor runner with selected aten ops and semi hosting build_dir="${et_root_dir}/arm_test" @@ -48,4 +48,4 @@ aten::unsqueeze_copy.out,\ aten::select_copy.int_out,\ aten::amax.out" -${build_executor_runner} --pte=semihosting --bundleio --target="${target}" --output="${build_root_test_dir}" --select_ops_list="${select_ops_list}" --extra_build_flags="-DET_ATOL=5.0 -DET_RTOL=1.0" +${build_executor_runner} --pte=semihosting --bundleio --target="${target}" --output="${build_root_test_dir}" --select_ops_list="${select_ops_list}" --extra_build_flags="-DET_ATOL=5.0 -DET_RTOL=1.0 -DET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE=0" diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py index fd28b077b8a..673af19310f 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py @@ -3,6 +3,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import torch + +from executorch.backends.nxp.backend.data_format import NXP_NODE_FORMAT from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, NodeConverter, @@ -23,11 +26,33 @@ def _is_supported_on_target( parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - if NodeConverter.uses_shape_broadcasting(node): - # Shape broadcasting may require the addition of `Transpose` ops during conversion. - return False + if custom_delegation_options.use_new_flow_neutron_c: + if not NodeConverter.at_least_one_input_shape_matches_the_output_shape( + node + ): + return False - return True + # If one input is in channel first and ranks of input tensors are not equal, we need to add Transposes + # Transpose is currently not supported for new flow + if any( + input_node.meta[NXP_NODE_FORMAT].is_channels_first() + for input_node in node.all_input_nodes + ) and NodeConverter._node_inputs_ranks_not_equal(node): + return False + + supported_types = [torch.int8, torch.uint8] + if not NodeConverter.uses_quantization_type_for_io( + node, supported_types, [0, 1], [0] + ): + return False + + return True + else: + if NodeConverter.uses_shape_broadcasting(node): + # Shape broadcasting may require the addition of `Transpose` ops during conversion. + return False + + return True @staticmethod def _is_supported_in_IR( @@ -43,12 +68,13 @@ def _is_supported_in_IR( return True - # add.Tensor Node format: (Tensor self, Tensor other, *, Scalar alpha=1) def convert(self, node: Node): - """Convert 'add_tensor' operator to TFLite 'add'.""" + """Convert 'add_tensor' operator to NeutronIR 'Add'. + The ExecuTorch schema is: + add.Tensor(Tensor self, Tensor other, Scalar alpha=1) + """ self.assert_convertible(node) - t_op = self._create_tflite_op_with_io_tensors(node) - t_op.builtin_options = add_options.Add() + self.builder.append_operators([t_op]) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py index e97f4bf63c2..79dbcbcc012 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py @@ -3,6 +3,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import torch + +from executorch.backends.nxp.backend.data_format import NXP_NODE_FORMAT from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, NodeConverter, @@ -23,11 +26,33 @@ def _is_supported_on_target( parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - if NodeConverter.uses_shape_broadcasting(node): - # Shape broadcasting may require the addition of `Transpose` ops during conversion. - return False + if custom_delegation_options.use_new_flow_neutron_c: + if not NodeConverter.at_least_one_input_shape_matches_the_output_shape( + node + ): + return False - return True + # If one input is in channel first and ranks of input tensors are not equal, we need to add Transposes + # Transpose is currently not supported for new flow + if any( + input_node.meta[NXP_NODE_FORMAT].is_channels_first() + for input_node in node.all_input_nodes + ) and NodeConverter._node_inputs_ranks_not_equal(node): + return False + + supported_types = [torch.int8, torch.uint8] + if not NodeConverter.uses_quantization_type_for_io( + node, supported_types, [0, 1], [0] + ): + return False + + return True + else: + if NodeConverter.uses_shape_broadcasting(node): + # Shape broadcasting may require the addition of `Transpose` ops during conversion. + return False + + return True @staticmethod def _is_supported_in_IR( @@ -45,9 +70,12 @@ def _is_supported_in_IR( return True - # sub.Tensor Node format: (Tensor self, Tensor other, *, Scalar alpha=1) def convert(self, node: Node): - """Convert 'sub_tensor' operator to NeutronIR 'Sub'.""" + """Convert 'sub_tensor' operator to NeutronIR 'Sub'. + The ExecuTorch schema is: + sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) + """ + self.assert_convertible(node) t_op = self._create_tflite_op_with_io_tensors(node) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py index 1aa58ab5d95..4a656eb9517 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py @@ -1,7 +1,8 @@ -# Copyright 2025 NXP +# Copyright 2025-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + import numpy as np import pytest import torch @@ -9,17 +10,29 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) -from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.tests.dataset_creator import RandomDatasetCreator +from executorch.backends.nxp.tests.executorch_pipeline import ( + ModelInputSpec, + to_quantized_edge_program, +) from executorch.backends.nxp.tests.executors import ( convert_run_compare, + graph_contains_any_of_ops, ToChannelFirstPreprocess, ToChannelLastPreprocess, ) +from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier from executorch.backends.nxp.tests.models import ( AddTensorConvModule, AddTensorModule, AddTensorOneInputModule, ) +from executorch.backends.nxp.tests.nsys_testing import lower_run_compare +from executorch.backends.nxp.tests.ops_aliases import ( + AddTensor, + Convolution, + ExecutorchDelegateCall, +) from torch.export import ExportedProgram from executorch.backends.nxp.tests.use_qat import * # noqa F403 @@ -92,20 +105,26 @@ def test_add_tensor_one_input_quant_conversion(mocker, input_shape, use_qat): @pytest.mark.parametrize( - "input_shape", + "x_input_shape", [ pytest.param((1, 4, 8, 8), id="4D."), pytest.param((1, 4, 5, 5), id="4D, product of dims is not a multiple of 8."), ], ) -def test_add_tensor_w_conv_quant_conversion(mocker, input_shape, use_qat): +def test_add_tensor_w_conv_quant_conversion(mocker, x_input_shape, use_qat): model = AddTensorConvModule() converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + n, c, h, w = x_input_shape + y_input_shape = (n, 8, h, w) + # Run conversion _ = to_quantized_edge_program( - model, input_shape, use_qat=use_qat, use_neutron_for_format_conversion=False + model, + [x_input_shape, y_input_shape], + use_qat=use_qat, + use_neutron_for_format_conversion=False, ) # Capture generated model @@ -114,7 +133,13 @@ def test_add_tensor_w_conv_quant_conversion(mocker, input_shape, use_qat): # Capture converted program exported_program: ExportedProgram = converter_spy.call_args.args[1] - input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) + input_data_1 = (np.random.random(x_input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + input_data_2 = (np.random.random(y_input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + input_data = {0: input_data_1, 1: input_data_2} convert_run_compare( exported_program, @@ -149,7 +174,7 @@ def test_add_tensor_broadcasting_unsupported_quant_conversion( nodes = list(edge_program.graph.nodes) # Broadcast is not supported, node is not converted - assert nodes[6].target.__name__ == "aten.add.Tensor" # Add Tensor is not delegated. + assert nodes[6].target == AddTensor # Add Tensor is not delegated. # Capture converted program # exported_program: ExportedProgram = converter_spy.call_args.args[1] @@ -159,3 +184,227 @@ def test_add_tensor_broadcasting_unsupported_quant_conversion( # input_data = {0: x_input_data, 1: y_input_data} # # convert_run_compare(exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data) + + +class TestAddTensorNewNeutronFlow: + @pytest.mark.parametrize( + "x_input_shape", + [ + pytest.param((1,), id="1D."), + pytest.param((6, 5), id="2D."), + pytest.param((1, 4, 7), id="3D."), + pytest.param((2, 4, 3, 15), id="4D."), + pytest.param( + (6, 82), + id="2D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + pytest.param( + (1, 68, 7), + id="3D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + pytest.param( + (1, 4, 9, 11, 4), + id="5D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__basic_nsys_inference(self, x_input_shape, mocker): + x_input_spec = ModelInputSpec(x_input_shape) + model = AddTensorModule() + graph_verifier = DetailedGraphVerifier( + mocker, expected_delegated_ops={AddTensor: 1}, expected_non_delegated_ops={} + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + [x_input_spec, x_input_spec], + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "x_input_shape", + [ + pytest.param((1,), id="1D."), + pytest.param((6, 5), id="2D."), + pytest.param((1, 4, 7), id="3D."), + pytest.param((2, 4, 3, 15), id="4D."), + pytest.param( + (1, 4, 9, 11, 4), + id="5D.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__basic_nsys_inference_qat(self, x_input_shape, mocker): + x_input_spec = ModelInputSpec(x_input_shape) + model = AddTensorModule() + graph_verifier = DetailedGraphVerifier( + mocker, expected_delegated_ops={AddTensor: 1}, expected_non_delegated_ops={} + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + [x_input_spec, x_input_spec], + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + use_qat=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((4, 6)), ModelInputSpec((1, 6))], id="2 inputs 2D." + ), + pytest.param( + [ModelInputSpec((5, 3, 4)), ModelInputSpec((1, 3, 1))], + id="2 inputs 3D.", + ), + pytest.param( + [ModelInputSpec((4,)), ModelInputSpec((4, 4))], id="2 inputs 1D + 2D." + ), + pytest.param( + [ModelInputSpec((69, 73)), ModelInputSpec((1, 73))], + id="2 inputs 2D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__broadcast(self, input_spec, mocker): + model = AddTensorModule() + graph_verifier = DetailedGraphVerifier( + mocker, expected_delegated_ops={AddTensor: 1}, expected_non_delegated_ops={} + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + input_spec, + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((4, 1)), ModelInputSpec((1, 6))], id="2 inputs 2D." + ), + pytest.param( + [ModelInputSpec((1, 3, 4)), ModelInputSpec((5, 3, 1))], + id="2 inputs 3D.", + ), + pytest.param( + [ModelInputSpec((6, 4)), ModelInputSpec((6, 6, 1))], + id="2 inputs 2D + 3D.", + ), + ], + ) + def test__broadcast_unsupported(self, input_spec): + # Broadcast where at least one of the inputs is not equal to output is not supported + model = AddTensorModule() + + delegated_ep = to_quantized_edge_program( + model, input_spec, use_new_flow_neutron_c=True + ).exported_program() + + # Make sure the `add.Tensor` was NOT delegated. + assert not graph_contains_any_of_ops( + delegated_ep.graph, [ExecutorchDelegateCall] + ) + assert graph_contains_any_of_ops(delegated_ep.graph, [AddTensor]) + + @pytest.mark.parametrize( + "x_input_shape", + [ + pytest.param( + (1, 4, 5, 5), id="4D, product of dims is not a multiple of 8." + ), + ], + ) + def test__w_conv(self, x_input_shape, mocker): + model = AddTensorConvModule() + + n, c, h, w = x_input_shape + y_input_spec = ModelInputSpec((n, 8, h, w)) + x_input_spec = ModelInputSpec(x_input_shape) + + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops={AddTensor: 1, Convolution: 1}, + expected_non_delegated_ops={}, + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + [x_input_spec, y_input_spec], + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((1, 4, 5, 5)), ModelInputSpec((1, 8, 5, 1))], + id="2 inputs 4D + 4D.", + ), + pytest.param( + [ModelInputSpec((1, 4, 5, 67)), ModelInputSpec((1, 8, 5, 1))], + id="2 inputs 4D + 4D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__w_conv_broadcast(self, input_spec, mocker): + model = AddTensorConvModule() + + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops={AddTensor: 1, Convolution: 1}, + expected_non_delegated_ops={}, + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + input_spec, + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((1, 4, 5, 5)), ModelInputSpec((1, 5))], + id="2 inputs 4D + 2D.", + ), + pytest.param( + [ModelInputSpec((1, 4, 4, 10)), ModelInputSpec((1, 4, 1))], + id="2 inputs 4D + 3D.", + ), + ], + ) + def test__w_conv_unsupported(self, input_spec): + model = AddTensorConvModule() + + delegated_ep = to_quantized_edge_program( + model, input_spec, use_new_flow_neutron_c=True + ).exported_program() + + # Make sure the `add.Tensor` was NOT delegated. + assert graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall]) + assert graph_contains_any_of_ops(delegated_ep.graph, [AddTensor]) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py index 2c73ccd8092..193b7ecf9ab 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py @@ -6,6 +6,7 @@ import numpy as np import pytest import torch + from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) @@ -29,13 +30,8 @@ ToNHWCPreprocess, ) from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier -from executorch.backends.nxp.tests.model_output_comparator import ( - NumericalStatsOutputComparator, -) from executorch.backends.nxp.tests.models import AvgPool2dConvModule, AvgPool2dModule - from executorch.backends.nxp.tests.nsys_testing import lower_run_compare - from executorch.backends.nxp.tests.ops_aliases import ( AvgPool2D, ExecutorchDelegateCall, @@ -45,6 +41,7 @@ Unsqueeze, ViewCopy, ) + from torch.export import ExportedProgram from executorch.backends.nxp.tests.use_qat import * # noqa F403 @@ -320,7 +317,6 @@ def test__basic_nsys_inference(self, mocker): def test__basic_nsys_inference_qat(self, mocker): input_shape = (2, 9, 6, 15) model = AvgPool2dModule(False, 0) - comparator = NumericalStatsOutputComparator() graph_verifier = DetailedGraphVerifier( mocker, expected_delegated_ops={AvgPool2D: 1}, expected_non_delegated_ops={} ) @@ -329,7 +325,6 @@ def test__basic_nsys_inference_qat(self, mocker): model, input_shape, graph_verifier, - output_comparator=comparator, use_new_flow_neutron_c=True, use_qat=True, ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py index 583dc2bfd04..9062d5efbfc 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import numpy as np +import pytest import torch from executorch.backends.nxp.backend.edge_program_converter import ( @@ -17,9 +18,6 @@ ToChannelLastPreprocess, ) from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier -from executorch.backends.nxp.tests.model_output_comparator import ( - NumericalStatsOutputComparator, -) from executorch.backends.nxp.tests.nsys_testing import lower_run_compare from executorch.backends.nxp.tests.ops_aliases import ( ExecutorchDelegateCall, @@ -32,7 +30,6 @@ ViewCopy, ) from executorch.backends.nxp.tests.use_qat import * # noqa F403 -import pytest class MaxPool1DModule(torch.nn.Module): @@ -286,7 +283,6 @@ def test__basic_nsys_inference(self, mocker): def test__basic_nsys_inference_qat(self, mocker): input_shape = (2, 11, 7, 16) # The old flow limited the batch size to 1. model = MaxPool2dModule() - comparator = NumericalStatsOutputComparator() graph_verifier = DetailedGraphVerifier( mocker, expected_delegated_ops={MaxPool2DWithIndices: 1, GetItem: 1}, @@ -297,7 +293,6 @@ def test__basic_nsys_inference_qat(self, mocker): model, input_shape, graph_verifier, - output_comparator=comparator, use_new_flow_neutron_c=True, use_qat=True, ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_mul_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_mul_tensor_converter.py index 927af47bbf5..90113f484ad 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_mul_tensor_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_mul_tensor_converter.py @@ -21,9 +21,6 @@ ToChannelLastPreprocess, ) from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier -from executorch.backends.nxp.tests.model_output_comparator import ( - NumericalStatsOutputComparator, -) from executorch.backends.nxp.tests.models import ( MulTensorConvModule, MulTensorModule, @@ -256,7 +253,6 @@ def test__basic_nsys_inference(self, x_input_shape, mocker): def test__basic_nsys_inference_qat(self, x_input_shape, mocker): x_input_spec = ModelInputSpec(x_input_shape) model = MulTensorModule() - comparator = NumericalStatsOutputComparator() graph_verifier = DetailedGraphVerifier( mocker, expected_delegated_ops={MulTensor: 1}, expected_non_delegated_ops={} ) @@ -265,7 +261,6 @@ def test__basic_nsys_inference_qat(self, x_input_shape, mocker): model, [x_input_spec, x_input_spec], graph_verifier, - output_comparator=comparator, use_new_flow_neutron_c=True, use_qat=True, ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py index 9ce3e93f39b..2734e89bc5d 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py @@ -1,7 +1,8 @@ -# Copyright 2025 NXP +# Copyright 2025-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + import numpy as np import pytest import torch @@ -9,18 +10,29 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) -from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.tests.dataset_creator import RandomDatasetCreator +from executorch.backends.nxp.tests.executorch_pipeline import ( + ModelInputSpec, + to_quantized_edge_program, +) from executorch.backends.nxp.tests.executors import ( convert_run_compare, + graph_contains_any_of_ops, ToChannelFirstPreprocess, ToChannelLastPreprocess, ) +from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier from executorch.backends.nxp.tests.models import ( SubTensorConvModule, SubTensorModule, SubTensorOneInputModule, ) -from executorch.exir.dialects._ops import ops as exir_ops +from executorch.backends.nxp.tests.nsys_testing import lower_run_compare +from executorch.backends.nxp.tests.ops_aliases import ( + Convolution, + ExecutorchDelegateCall, + SubTensor, +) from torch.export import ExportedProgram from executorch.backends.nxp.tests.use_qat import * # noqa F403 @@ -63,7 +75,7 @@ def test_sub_tensor_quant_conversion(mocker, input_shape, use_qat): input_data = {0: input_data_1, 1: input_data_2} nodes = list(exported_program.graph.nodes) - assert nodes[4].target == exir_ops.edge.aten.sub.Tensor + assert nodes[4].target == SubTensor convert_run_compare( exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data @@ -96,7 +108,7 @@ def test_sub_tensor_one_input_quant_conversion(mocker, input_shape, use_qat): input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) nodes = list(exported_program.graph.nodes) - assert nodes[2].target == exir_ops.edge.aten.sub.Tensor + assert nodes[2].target == SubTensor convert_run_compare( exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data @@ -141,7 +153,7 @@ def test_sub_tensor_w_conv_quant_conversion(mocker, x_input_shape, use_qat): input_data = {0: input_data_1, 1: input_data_2} nodes = list(exported_program.graph.nodes) - assert nodes[15].target == exir_ops.edge.aten.sub.Tensor + assert nodes[15].target == SubTensor convert_run_compare( exported_program, @@ -176,6 +188,236 @@ def test_sub_tensor_broadcasting_unsupported_quant_conversion( nodes = list(edge_program.graph.nodes) # Broadcast is not supported, node is not converted - assert ( - nodes[6].target == exir_ops.edge.aten.sub.Tensor - ) # Sub Tensor is not delegated. + assert nodes[6].target == SubTensor # Sub Tensor is not delegated. + + +class TestSubTensorNewNeutronFlow: + @pytest.mark.parametrize( + "x_input_shape", + [ + pytest.param((1,), id="1D."), + pytest.param((6, 5), id="2D."), + pytest.param((1, 4, 7), id="3D."), + pytest.param( + (6, 82), + id="2D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + pytest.param( + (1, 68, 7), + id="3D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + pytest.param( + (2, 4, 3, 15), + id="4D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + pytest.param( + (1, 4, 9, 11, 4), + id="5D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__basic_nsys_inference(self, x_input_shape, mocker): + x_input_spec = ModelInputSpec(x_input_shape) + model = SubTensorModule() + graph_verifier = DetailedGraphVerifier( + mocker, expected_delegated_ops={SubTensor: 1}, expected_non_delegated_ops={} + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + [x_input_spec, x_input_spec], + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "x_input_shape", + [ + pytest.param((1,), id="1D."), + pytest.param((6, 5), id="2D."), + pytest.param((2, 4, 3, 15), id="4D."), + pytest.param( + (1, 4, 7), + id="3D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + pytest.param( + (1, 4, 9, 11, 4), + id="5D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__basic_nsys_inference_qat(self, x_input_shape, mocker): + x_input_spec = ModelInputSpec(x_input_shape) + model = SubTensorModule() + graph_verifier = DetailedGraphVerifier( + mocker, expected_delegated_ops={SubTensor: 1}, expected_non_delegated_ops={} + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + [x_input_spec, x_input_spec], + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + use_qat=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((4, 6)), ModelInputSpec((1, 6))], id="2 inputs 2D." + ), + pytest.param( + [ModelInputSpec((4,)), ModelInputSpec((4, 4))], id="2 inputs 1D + 2D." + ), + pytest.param( + [ModelInputSpec((5, 3, 4)), ModelInputSpec((1, 3, 1))], + id="2 inputs 3D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + pytest.param( + [ModelInputSpec((69, 73)), ModelInputSpec((1, 73))], + id="2 inputs 2D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__broadcast(self, input_spec, mocker): + model = SubTensorModule() + graph_verifier = DetailedGraphVerifier( + mocker, expected_delegated_ops={SubTensor: 1}, expected_non_delegated_ops={} + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + input_spec, + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((4, 1)), ModelInputSpec((1, 6))], id="2 inputs 2D." + ), + pytest.param( + [ModelInputSpec((1, 3, 4)), ModelInputSpec((5, 3, 1))], + id="2 inputs 3D.", + ), + pytest.param( + [ModelInputSpec((6, 4)), ModelInputSpec((6, 6, 1))], + id="2 inputs 2D+3D.", + ), + ], + ) + def test__broadcast_unsupported(self, input_spec): + # Broadcast where at least one of the inputs is not equal to output is not supported + model = SubTensorModule() + + delegated_ep = to_quantized_edge_program( + model, input_spec, use_new_flow_neutron_c=True + ).exported_program() + + # Make sure the `sub.Tensor` was NOT delegated. + assert not graph_contains_any_of_ops( + delegated_ep.graph, [ExecutorchDelegateCall] + ) + assert graph_contains_any_of_ops(delegated_ep.graph, [SubTensor]) + + @pytest.mark.parametrize( + "x_input_shape", + [ + pytest.param( + (1, 4, 5, 5), id="4D, product of dims is not a multiple of 8." + ), + ], + ) + def test__w_conv(self, x_input_shape, mocker): + model = SubTensorConvModule() + + n, c, h, w = x_input_shape + y_input_spec = ModelInputSpec((n, 8, h, w)) + x_input_spec = ModelInputSpec(x_input_shape) + + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops={SubTensor: 1, Convolution: 1}, + expected_non_delegated_ops={}, + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + [x_input_spec, y_input_spec], + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((1, 4, 7, 1)), ModelInputSpec((1, 8, 1, 1))], + id="2 inputs 4D + 4D.", + ), + pytest.param( + [ModelInputSpec((1, 4, 5, 5)), ModelInputSpec((1, 8, 5, 1))], + id="2 inputs 4D + 4D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__w_conv_broadcast(self, input_spec, mocker): + model = SubTensorConvModule() + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops={SubTensor: 1, Convolution: 1}, + expected_non_delegated_ops={}, + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + input_spec, + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((1, 4, 5, 5)), ModelInputSpec((1, 5))], + id="2 inputs 4D + 2D.", + ), + pytest.param( + [ModelInputSpec((1, 4, 4, 10)), ModelInputSpec((1, 4, 1))], + id="2 inputs 4D + 3D.", + ), + ], + ) + def test__w_conv_unsupported(self, input_spec): + model = SubTensorConvModule() + + delegated_ep = to_quantized_edge_program( + model, input_spec, use_new_flow_neutron_c=True + ).exported_program() + + # Make sure the `sub.Tensor` was NOT delegated. + assert graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall]) + assert graph_contains_any_of_ops(delegated_ep.graph, [SubTensor]) diff --git a/backends/nxp/tests/models.py b/backends/nxp/tests/models.py index 045dcfaba40..1292c4cf17d 100644 --- a/backends/nxp/tests/models.py +++ b/backends/nxp/tests/models.py @@ -656,9 +656,9 @@ def __init__(self): super().__init__() self.conv = Conv2dModule(padding=1, stride=1) - def forward(self, x): + def forward(self, x, y): x = self.conv(x) - return x + x + return x + y class AddTensorOneInputModule(torch.nn.Module): diff --git a/backends/nxp/tests/ops_aliases.py b/backends/nxp/tests/ops_aliases.py index ec58072658d..7f855dd63af 100644 --- a/backends/nxp/tests/ops_aliases.py +++ b/backends/nxp/tests/ops_aliases.py @@ -13,6 +13,7 @@ Abs = exir_ops.edge.aten.abs.default AdaptiveAvgPool2D = exir_ops.edge.aten._adaptive_avg_pool2d.default +AddTensor = exir_ops.edge.aten.add.Tensor AvgPool2D = exir_ops.edge.aten.avg_pool2d.default Bmm = exir_ops.edge.aten.bmm.default ConstantPadND = exir_ops.edge.aten.constant_pad_nd.default @@ -36,6 +37,7 @@ Squeeze = exir_ops.edge.aten.squeeze.default SqueezeDim = exir_ops.edge.aten.squeeze.dim SqueezeDims = exir_ops.edge.aten.squeeze.dims +SubTensor = exir_ops.edge.aten.sub.Tensor Unsqueeze = exir_ops.edge.aten.unsqueeze.default UpsampleBilinear2D = exir_ops.edge.aten.upsample_bilinear2d.vec UpsampleNearest2D = exir_ops.edge.aten.upsample_nearest2d.vec From 41c2ff133266f8f5835186403bf3d689f70ad26d Mon Sep 17 00:00:00 2001 From: ssjia Date: Tue, 26 May 2026 13:29:07 -0700 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- backends/cortex_m/passes/BUCK | 1 + backends/cortex_m/passes/convert_to_cortex_m_pass.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/backends/cortex_m/passes/BUCK b/backends/cortex_m/passes/BUCK index 4e49c8cd319..f1b7b9a201d 100644 --- a/backends/cortex_m/passes/BUCK +++ b/backends/cortex_m/passes/BUCK @@ -36,6 +36,7 @@ fbcode_target(_kind = runtime.python_library, "decompose_hardswish_pass.py", "decompose_mean_pass.py", "quantized_clamp_activation_pass.py", + "scratch_buffer_sizes.py", ], deps=[ "//caffe2:torch", diff --git a/backends/cortex_m/passes/convert_to_cortex_m_pass.py b/backends/cortex_m/passes/convert_to_cortex_m_pass.py index e61ddaf63bc..5704645caf8 100644 --- a/backends/cortex_m/passes/convert_to_cortex_m_pass.py +++ b/backends/cortex_m/passes/convert_to_cortex_m_pass.py @@ -12,7 +12,7 @@ import torch.fx from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor -from executorch.backends.cortex_m.passes import CortexMPass +from executorch.backends.cortex_m.passes.cortex_m_pass import CortexMPass from executorch.backends.cortex_m.passes.passes_utils import quantize_multiplier_aot from executorch.backends.cortex_m.passes.scratch_buffer_sizes import ( required_cmsis_nn_buffer_sizes,