Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/layout_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.neg.default,
exir_ops.edge.aten.pow.Tensor_Scalar,
exir_ops.edge.aten.prelu.default,
exir_ops.edge.aten.rand.default,
exir_ops.edge.aten.randn.default,
exir_ops.edge.aten.reflection_pad1d.default,
exir_ops.edge.aten.reflection_pad2d.default,
exir_ops.edge.aten.repeat.default,
Expand Down
5 changes: 3 additions & 2 deletions backends/qualcomm/builders/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ Please help update following table if you are contributing new operators:
+ 🚫 = Deprecated, supported with other QNN Ops


| Operators | HTP - 98/119 Enabled |
| Operators | HTP - 99/120 Enabled |
|-----------|---------|
| Argmax | ✓ |
| Argmin | ✓ |
Expand Down Expand Up @@ -457,7 +457,8 @@ Please help update following table if you are contributing new operators:
| PoolMax2d | ✓ |
| Prelu | ✓ |
| Quantize | ✓ |
| Rand | ✓ |
| RandomUniformLike | ✓ |
| RandomNormalLike | ✓ |
| ReduceMax | ✓ |
| ReduceMean | ✓ |
| ReduceMin | ✓ |
Expand Down
2 changes: 2 additions & 0 deletions backends/qualcomm/builders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
op_prelu,
op_quantize,
op_rand,
op_randn,
op_relu,
op_repeat,
op_reshape,
Expand Down Expand Up @@ -194,6 +195,7 @@
op_prelu,
op_quantize,
op_rand,
op_randn,
op_relu,
op_repeat,
op_reshape,
Expand Down
79 changes: 79 additions & 0 deletions backends/qualcomm/builders/op_randn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor
from .qnn_constants import OpRandomNormalLike, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Randn(NodeVisitor):
target = ["aten.randn.default", "aten.randn_like.default"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
) -> PyQnnManager.PyQnnOpWrapper:
output_tensor = node.meta["val"]
output_shape = list(output_tensor.shape)

shape_data = np.array(output_shape, dtype=np.uint32)
shape_dims = [len(output_shape)]

shape_tensor_wrapper = PyQnnManager.TensorWrapper(
f"{node.name}_shape",
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
{},
len(shape_dims),
shape_dims,
[],
shape_data,
True,
)

output_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor,
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

randn_op = PyQnnManager.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpRandomNormalLike.op_name,
)

randn_op.AddInputTensors([shape_tensor_wrapper])
randn_op.AddOutputTensors([output_tensor_wrapper])

randn_op.AddScalarParam(
OpRandomNormalLike.param_mean,
PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
{QCOM_DATA: np.float32(0.0)},
)

randn_op.AddScalarParam(
OpRandomNormalLike.param_scale,
PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
{QCOM_DATA: np.float32(1.0)},
)

return randn_op
7 changes: 7 additions & 0 deletions backends/qualcomm/builders/qnn_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,13 @@ class OpQuantize:
op_name: str = "Quantize"


@dataclass(init=False, frozen=True)
class OpRandomNormalLike:
op_name: str = "RandomNormalLike"
param_mean: str = "mean"
param_scale: str = "scale"


@dataclass(init=False, frozen=True)
class OpRandomUniformLike:
op_name: str = "RandomUniformLike"
Expand Down
2 changes: 2 additions & 0 deletions backends/qualcomm/quantizer/annotators/htp_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,8 @@ class ColIm(GeneralOpDef):
torch.ops.aten.zeros_like.default,
torch.ops.aten.ones.default,
torch.ops.aten.ones_like.default,
torch.ops.aten.rand.default,
torch.ops.aten.randn.default,
],
qnn_op=None,
)
Expand Down
2 changes: 2 additions & 0 deletions backends/qualcomm/quantizer/annotators/lpai_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ class ColIm(GeneralOpDef):
torch.ops.aten.zeros_like.default,
torch.ops.aten.ones.default,
torch.ops.aten.ones_like.default,
torch.ops.aten.rand.default,
torch.ops.aten.randn.default,
],
qnn_op=None,
)
Expand Down
8 changes: 8 additions & 0 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1929,6 +1929,14 @@ def forward(self, x):
return torch.rand_like(x) + x


class Randn(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.randn_like(x) + x


class Reciprocal(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
43 changes: 22 additions & 21 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1809,24 +1809,6 @@ def test_qnn_backend_prelu(self):
index += 1
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_rand(self):
sample_inputs = [
(torch.randn(3, 4, 5),),
(torch.randn(2, 8),),
(
torch.randn(
10,
),
),
(torch.randn(1, 3, 32, 32),),
]
for i, sample_input in enumerate(sample_inputs):
with self.subTest(i=i):
module = Rand() # noqa: F405
self.lower_module_and_test_output(
module, sample_input, assert_output_equal=False
)

def test_qnn_backend_reciprocal(self):
module = Reciprocal() # noqa: F405
sample_input = (torch.randn([2, 2, 2, 2]),)
Expand Down Expand Up @@ -4555,6 +4537,7 @@ def test_qnn_backend_prelu(self):
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_rand(self):
module = Rand() # noqa: F405
sample_inputs = [
(torch.randn(3, 4, 5),),
(torch.randn(2, 8),),
Expand All @@ -4567,10 +4550,28 @@ def test_qnn_backend_rand(self):
]
for i, sample_input in enumerate(sample_inputs):
with self.subTest(i=i):
module = Rand() # noqa: F405
module = self.get_qdq_module(module, sample_input)
qdq_module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(
qdq_module, sample_input, assert_output_equal=False
)

def test_qnn_backend_randn(self):
module = Randn() # noqa: F405
sample_inputs = [
(torch.randn(3, 4, 5),),
(torch.randn(2, 8),),
(
torch.randn(
10,
),
),
(torch.randn(1, 3, 32, 32),),
]
for i, sample_input in enumerate(sample_inputs):
with self.subTest(i=i):
qdq_module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(
module, sample_input, assert_output_equal=False
qdq_module, sample_input, assert_output_equal=False
)

def test_qnn_backend_reciprocal(self):
Expand Down
Loading