From 3f95804a60d4fdc4d396d26af3e5986ec3ebac60 Mon Sep 17 00:00:00 2001 From: Arik Horodniceanu Date: Wed, 6 May 2026 15:20:28 -0700 Subject: [PATCH] Qualcomm AI Engine Direct - Adding QNN backend support for randn core ATen op --- backends/qualcomm/_passes/layout_transform.py | 2 + backends/qualcomm/builders/README.md | 5 +- backends/qualcomm/builders/__init__.py | 2 + backends/qualcomm/builders/op_randn.py | 79 +++++++++++++++++++ backends/qualcomm/builders/qnn_constants.py | 7 ++ .../quantizer/annotators/htp_rules.py | 2 + .../quantizer/annotators/lpai_rules.py | 2 + backends/qualcomm/tests/models.py | 8 ++ backends/qualcomm/tests/test_qnn_delegate.py | 43 +++++----- 9 files changed, 127 insertions(+), 23 deletions(-) create mode 100644 backends/qualcomm/builders/op_randn.py diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index 9422051addd..5953360a13b 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -113,6 +113,8 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.neg.default, exir_ops.edge.aten.pow.Tensor_Scalar, exir_ops.edge.aten.prelu.default, + exir_ops.edge.aten.rand.default, + exir_ops.edge.aten.randn.default, exir_ops.edge.aten.reflection_pad1d.default, exir_ops.edge.aten.reflection_pad2d.default, exir_ops.edge.aten.repeat.default, diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md index d71aeb27be6..a0282377eaf 100644 --- a/backends/qualcomm/builders/README.md +++ b/backends/qualcomm/builders/README.md @@ -368,7 +368,7 @@ Please help update following table if you are contributing new operators: + 🚫 = Deprecated, supported with other QNN Ops -| Operators | HTP - 98/119 Enabled | +| Operators | HTP - 99/120 Enabled | |-----------|---------| | Argmax | ✓ | | Argmin | ✓ | @@ -457,7 +457,8 @@ Please help update following table if you are contributing new operators: | PoolMax2d | ✓ | | Prelu | ✓ | | Quantize | ✓ | -| Rand | ✓ | +| RandomUniformLike | ✓ | +| RandomNormalLike | ✓ | | ReduceMax | ✓ | | ReduceMean | ✓ | | ReduceMin | ✓ | diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index a897dfa53bd..8549e4f255f 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -81,6 +81,7 @@ op_prelu, op_quantize, op_rand, + op_randn, op_relu, op_repeat, op_reshape, @@ -194,6 +195,7 @@ op_prelu, op_quantize, op_rand, + op_randn, op_relu, op_repeat, op_reshape, diff --git a/backends/qualcomm/builders/op_randn.py b/backends/qualcomm/builders/op_randn.py new file mode 100644 index 00000000000..6160fc79609 --- /dev/null +++ b/backends/qualcomm/builders/op_randn.py @@ -0,0 +1,79 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager + +import numpy as np +import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA + +from .node_visitor import NodeVisitor +from .node_visitor_manager import register_node_visitor +from .qnn_constants import OpRandomNormalLike, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class Randn(NodeVisitor): + target = ["aten.randn.default", "aten.randn_like.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: + output_tensor = node.meta["val"] + output_shape = list(output_tensor.shape) + + shape_data = np.array(output_shape, dtype=np.uint32) + shape_dims = [len(output_shape)] + + shape_tensor_wrapper = PyQnnManager.TensorWrapper( + f"{node.name}_shape", + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, + {}, + len(shape_dims), + shape_dims, + [], + shape_data, + True, + ) + + output_tensor_wrapper = self.define_tensor( + node, + node, + output_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + randn_op = PyQnnManager.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpRandomNormalLike.op_name, + ) + + randn_op.AddInputTensors([shape_tensor_wrapper]) + randn_op.AddOutputTensors([output_tensor_wrapper]) + + randn_op.AddScalarParam( + OpRandomNormalLike.param_mean, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + {QCOM_DATA: np.float32(0.0)}, + ) + + randn_op.AddScalarParam( + OpRandomNormalLike.param_scale, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + {QCOM_DATA: np.float32(1.0)}, + ) + + return randn_op diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index d7ec30fddc0..168601530c6 100644 --- a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -504,6 +504,13 @@ class OpQuantize: op_name: str = "Quantize" +@dataclass(init=False, frozen=True) +class OpRandomNormalLike: + op_name: str = "RandomNormalLike" + param_mean: str = "mean" + param_scale: str = "scale" + + @dataclass(init=False, frozen=True) class OpRandomUniformLike: op_name: str = "RandomUniformLike" diff --git a/backends/qualcomm/quantizer/annotators/htp_rules.py b/backends/qualcomm/quantizer/annotators/htp_rules.py index 342db1cb633..1dd0ab7f488 100644 --- a/backends/qualcomm/quantizer/annotators/htp_rules.py +++ b/backends/qualcomm/quantizer/annotators/htp_rules.py @@ -345,6 +345,8 @@ class ColIm(GeneralOpDef): torch.ops.aten.zeros_like.default, torch.ops.aten.ones.default, torch.ops.aten.ones_like.default, + torch.ops.aten.rand.default, + torch.ops.aten.randn.default, ], qnn_op=None, ) diff --git a/backends/qualcomm/quantizer/annotators/lpai_rules.py b/backends/qualcomm/quantizer/annotators/lpai_rules.py index 30a3cb1dc9d..3a3bffa174d 100644 --- a/backends/qualcomm/quantizer/annotators/lpai_rules.py +++ b/backends/qualcomm/quantizer/annotators/lpai_rules.py @@ -272,6 +272,8 @@ class ColIm(GeneralOpDef): torch.ops.aten.zeros_like.default, torch.ops.aten.ones.default, torch.ops.aten.ones_like.default, + torch.ops.aten.rand.default, + torch.ops.aten.randn.default, ], qnn_op=None, ) diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 12d5e0902db..ffb1ccb26ef 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -1929,6 +1929,14 @@ def forward(self, x): return torch.rand_like(x) + x +class Randn(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.randn_like(x) + x + + class Reciprocal(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 688dddf5c2a..2931538b7a0 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -1809,24 +1809,6 @@ def test_qnn_backend_prelu(self): index += 1 self.lower_module_and_test_output(module, sample_input) - def test_qnn_backend_rand(self): - sample_inputs = [ - (torch.randn(3, 4, 5),), - (torch.randn(2, 8),), - ( - torch.randn( - 10, - ), - ), - (torch.randn(1, 3, 32, 32),), - ] - for i, sample_input in enumerate(sample_inputs): - with self.subTest(i=i): - module = Rand() # noqa: F405 - self.lower_module_and_test_output( - module, sample_input, assert_output_equal=False - ) - def test_qnn_backend_reciprocal(self): module = Reciprocal() # noqa: F405 sample_input = (torch.randn([2, 2, 2, 2]),) @@ -4555,6 +4537,7 @@ def test_qnn_backend_prelu(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_rand(self): + module = Rand() # noqa: F405 sample_inputs = [ (torch.randn(3, 4, 5),), (torch.randn(2, 8),), @@ -4567,10 +4550,28 @@ def test_qnn_backend_rand(self): ] for i, sample_input in enumerate(sample_inputs): with self.subTest(i=i): - module = Rand() # noqa: F405 - module = self.get_qdq_module(module, sample_input) + qdq_module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output( + qdq_module, sample_input, assert_output_equal=False + ) + + def test_qnn_backend_randn(self): + module = Randn() # noqa: F405 + sample_inputs = [ + (torch.randn(3, 4, 5),), + (torch.randn(2, 8),), + ( + torch.randn( + 10, + ), + ), + (torch.randn(1, 3, 32, 32),), + ] + for i, sample_input in enumerate(sample_inputs): + with self.subTest(i=i): + qdq_module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output( - module, sample_input, assert_output_equal=False + qdq_module, sample_input, assert_output_equal=False ) def test_qnn_backend_reciprocal(self):