tico 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +42 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/quantize_bias.py +123 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +191 -0
- tico/passes/cast_mixed_type_args.py +187 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +160 -0
- tico/passes/convert_layout_op_to_reshape.py +85 -0
- tico/passes/convert_repeat_to_expand_copy.py +89 -0
- tico/passes/convert_to_relu6.py +181 -0
- tico/passes/decompose_addmm.py +124 -0
- tico/passes/decompose_batch_norm.py +192 -0
- tico/passes/decompose_fake_quantize.py +134 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
- tico/passes/decompose_group_norm.py +275 -0
- tico/passes/decompose_grouped_conv2d.py +209 -0
- tico/passes/decompose_slice_scatter.py +169 -0
- tico/passes/extract_dtype_kwargs.py +122 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +108 -0
- tico/passes/legalize_predefined_layout_operators.py +386 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
- tico/passes/lower_to_slice.py +230 -0
- tico/passes/merge_consecutive_cat.py +80 -0
- tico/passes/ops.py +78 -0
- tico/passes/remove_nop.py +84 -0
- tico/passes/remove_redundant_assert_nodes.py +51 -0
- tico/passes/remove_redundant_expand.py +66 -0
- tico/passes/remove_redundant_permute.py +122 -0
- tico/passes/remove_redundant_reshape.py +436 -0
- tico/passes/remove_redundant_slice.py +62 -0
- tico/passes/remove_redundant_to_copy.py +86 -0
- tico/passes/restore_linear.py +115 -0
- tico/passes/segment_index_select.py +145 -0
- tico/pt2_to_circle.py +105 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +319 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +240 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_abs.py +53 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +150 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +192 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +126 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +186 -0
- tico/serialize/operators/op_copy.py +164 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +95 -0
- tico/serialize/operators/op_depthwise_conv2d.py +199 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_leaky_relu.py +60 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +86 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_dim.py +70 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +177 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +141 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +100 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +99 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +94 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +296 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +282 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/mx/__init__.py +1 -0
- tico/utils/mx/elemwise_ops.py +267 -0
- tico/utils/mx/formats.py +125 -0
- tico/utils/mx/mx_ops.py +270 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +609 -0
- tico/utils/serialize.py +42 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +406 -0
- tico/utils/validate_args_kwargs.py +1149 -0
- tico-0.1.0.dist-info/LICENSE +241 -0
- tico-0.1.0.dist-info/METADATA +354 -0
- tico-0.1.0.dist-info/RECORD +206 -0
- tico-0.1.0.dist-info/WHEEL +5 -0
- tico-0.1.0.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,95 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.circle_mapping import (
|
25
|
+
circle_legalize_dtype_to,
|
26
|
+
extract_torch_dtype,
|
27
|
+
to_circle_dtype,
|
28
|
+
)
|
29
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
30
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
31
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
32
|
+
from tico.utils.validate_args_kwargs import CumsumArgs
|
33
|
+
|
34
|
+
|
35
|
+
@register_node_visitor
class CumsumVisitor(NodeVisitor):
    """Serialize ``aten.cumsum.default`` into a Circle CUMSUM operator.

    If the dtype of the input differs from the dtype recorded for the cumsum
    node itself, a CAST operator is emitted first so that CUMSUM reads and
    writes the same dtype.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.cumsum.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the CUMSUM operator (plus an optional leading CAST) for *node*.

        Args:
            node: the ``aten.cumsum`` FX node being serialized.

        Returns:
            The CUMSUM operator. When a CAST is needed, it is added to the
            graph directly via ``self.graph.add_operator``.
        """
        args = CumsumArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input_ = args.input
        dim = args.dim

        # The axis operand is legalized to an int32 value.
        dim_i32 = circle_legalize_dtype_to(dim, dtype=torch.int32)

        casted_input: torch.fx.Node | circle.Tensor.TensorT = input_
        # torch.cumsum doesn't always follow the input dtype: e.g. an int32
        # input produces an int64 result, and an explicit `dtype` argument can
        # request yet another type. Since circle-interpreter needs a model to
        # have the same dtype between input and output, cast the input to the
        # output dtype whenever the two differ.
        input_dtype = extract_torch_dtype(input_)
        output_dtype = extract_torch_dtype(node)
        if input_dtype != output_dtype:
            input_tensor: circle.Tensor.TensorT = self.graph.get_tensor(input_)
            input_shape: List[int] = input_tensor.shape
            cast_op_index = get_op_index(
                circle.BuiltinOperator.BuiltinOperator.CAST, self._op_codes
            )
            cast_name = f"{input_.name}_cast"
            cast_dtype = to_circle_dtype(output_dtype)
            cast_tensor = self.graph.add_tensor_from_scratch(
                prefix=cast_name,
                dtype=cast_dtype,
                shape=input_shape,
                source_node=node,
            )
            cast_operator = create_builtin_operator(
                self.graph, cast_op_index, [input_], [cast_tensor]
            )
            cast_operator.builtinOptionsType = (
                circle.BuiltinOptions.BuiltinOptions.CastOptions
            )
            cast_option = circle.CastOptions.CastOptionsT()
            cast_option.inDataType = to_circle_dtype(input_dtype)
            cast_option.outDataType = cast_dtype
            cast_operator.builtinOptions = cast_option
            self.graph.add_operator(cast_operator)
            casted_input = cast_tensor

        inputs = [casted_input, dim_i32]
        outputs = [node]

        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.CUMSUM, self._op_codes
        )
        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.CumsumOptions
        option = circle.CumsumOptions.CumsumOptionsT()
        operator.builtinOptions = option

        return operator
|
@@ -0,0 +1,199 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_mapping import extract_circle_dtype, extract_shape
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.define import define_pad_node
|
28
|
+
from tico.utils.padding import is_same_padding, is_valid_padding, SAME, VALID
|
29
|
+
from tico.utils.validate_args_kwargs import Conv2DArgs
|
30
|
+
|
31
|
+
|
32
|
+
@register_node_visitor
class DepthwiseConv2dVisitor(NodeVisitor):
    """
    NOTE
    - In circle, the padding of DepthwiseCircleConv2D is only a padding type ('VALID', 'SAME'),
      whereas nn.Conv2d accepts a padding type ('valid', 'same'), a padding value (int),
      or padding values (tuple -> [pad_h, pad_w]).
      ref: https://tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d

    [1] With valid/same padding: DepthwiseCircleConv2D (only)

      [ATEN IR]
        Input[NHWC] ---- circle_custom.depthwise_conv2d[NHWC] ---- OUTPUT[NHWC]
        Weight[1HWO] ---/
        Bias ----------/

      [CIRCLE IR]
        Input[NHWC] ---- DepthwiseCircleConv2D[NHWC] ---- OUTPUT[NHWC]
        Weight[1HWO] ---/
        Bias ----------/

    [2] With explicit padding values: CirclePad + DepthwiseCircleConv2D

      [ATEN IR]
        Input[NHWC] ---- circle_custom.depthwise_conv2d[NHWC] ---- OUTPUT[NHWC]
        Weight[1HWO] ---/
        Bias ----------/

      [CIRCLE IR]
        Input[NHWC] ---- CirclePad[NHWC] ---- DepthwiseCircleConv2D[NHWC] ---- OUTPUT[NHWC]
        Weight[1HWO] ------/
        Bias -------------/
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.circle_custom.depthwise_conv2d,
        torch.ops.circle_custom.depthwise_conv2d.padding,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph):
        super().__init__(op_codes, graph)

    def define_dconv_node(
        self,
        padding: int,
        stride: List[int],
        dilation: List[int],
        depthMultiplier: int,
        inputs: List,
        outputs: List,
    ) -> circle.Operator.OperatorT:
        """Create a DEPTHWISE_CONV_2D operator with the given options.

        Args:
            padding: circle padding type (``VALID`` or ``SAME``).
            stride: ``[stride_h, stride_w]``.
            dilation: ``[dilation_h, dilation_w]``.
            depthMultiplier: output channels per input channel.
            inputs: ``[input, weight, bias]`` operands.
            outputs: output operands.

        Returns:
            The configured operator. It is not added to the graph here.
        """

        def set_conv2d_option(operator, stride, dilation):
            operator.builtinOptionsType = (
                circle.BuiltinOptions.BuiltinOptions.DepthwiseConv2DOptions
            )
            option = circle.DepthwiseConv2DOptions.DepthwiseConv2DOptionsT()

            option.padding = padding
            option.strideH = stride[0]
            option.strideW = stride[1]
            option.depthMultiplier = depthMultiplier
            option.dilationHFactor = dilation[0]
            option.dilationWFactor = dilation[1]
            option.fusedActivationFunction = (
                circle.ActivationFunctionType.ActivationFunctionType.NONE
            )
            operator.builtinOptions = option

        conv2d_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.DEPTHWISE_CONV_2D, self._op_codes
        )
        operator = create_builtin_operator(self.graph, conv2d_op_index, inputs, outputs)
        set_conv2d_option(operator, stride, dilation)
        return operator

    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
        """Serialize a depthwise conv2d node, inserting a PAD op for explicit padding.

        Args:
            node: FX node whose arguments parse as ``Conv2DArgs``.

        Returns:
            The DEPTHWISE_CONV_2D operator. A preceding PAD operator, when
            needed, is added to the graph directly.
        """
        # Let's get Conv2dArgs because torch Conv2D with group == input_channel maps to CircleDepthwiseConv2D
        args = Conv2DArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input_ = args.input
        weight = args.weight
        bias = args.bias
        stride = args.stride
        padding = args.padding
        dilation = args.dilation
        groups = args.groups

        input_dtype: int = extract_circle_dtype(input_)
        input_shape = list(extract_shape(input_))  # NHWC
        assert len(input_shape) == 4, len(input_shape)

        output_shape = list(extract_shape(node))  # NHWC
        assert len(output_shape) == 4, len(output_shape)

        weight_shape = list(extract_shape(weight))  # 1HWO
        assert (
            weight_shape[3] % groups == 0
        ), "Depthwise convolution requires output channel to be divisible by groups"

        assert weight_shape[0] == 1
        assert weight_shape[3] == output_shape[3]
        assert input_shape[3] == groups

        # Validate divisibility before deriving the multiplier from it.
        assert weight_shape[3] % input_shape[3] == 0, "depthMultiplier must be integer"
        depthMultiplier = weight_shape[3] // input_shape[3]

        conv_input: torch.fx.node.Node | circle.Tensor.TensorT = input_

        if is_valid_padding(padding):
            dconv2d_padding_type = VALID
        elif is_same_padding(padding, input_shape, output_shape):
            dconv2d_padding_type = SAME
        else:
            assert isinstance(padding, list) and len(padding) == 2

            dconv2d_padding_type = VALID

            # Padding is neither valid nor same, so use VALID padding and add a PAD operator before the conv2d operator.
            # For "NHWC" data, paddings are [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]].
            paddings = torch.tensor(
                [
                    [0, 0],
                    [padding[0], padding[0]],
                    [padding[1], padding[1]],
                    [0, 0],
                ],
                dtype=torch.int32,
            )
            pad_output_shape = [
                input_shape[0],
                input_shape[1],
                input_shape[2],
                input_shape[3],
            ]
            # Add (pad_top + pad_bottom) to the padded height
            pad_output_shape[1] += padding[0] * 2
            # Add (pad_left + pad_right) to the padded width
            pad_output_shape[2] += padding[1] * 2

            # Create the padded output tensor
            pad_output = self.graph.add_tensor_from_scratch(
                prefix=f"{node.name}_input_pad_output",
                shape=pad_output_shape,
                dtype=input_dtype,
                source_node=node,
            )
            # CirclePad
            pad_operator = define_pad_node(
                self.graph, self._op_codes, [input_, paddings], [pad_output]
            )
            self.graph.add_operator(pad_operator)
            conv_input = pad_output

        if bias is None:
            # luci-interpreter can't run a conv without bias, so add a zero vector as bias.
            assert len(weight_shape) == 4
            out_channel = weight_shape[3]
            bias = [0.0] * out_channel  # type: ignore[assignment]

        # DConv2D
        dconv2d_operator = self.define_dconv_node(
            dconv2d_padding_type,
            stride,
            dilation,
            depthMultiplier,
            [conv_input, weight, bias],
            [node],
        )

        return dconv2d_operator
|
@@ -0,0 +1,82 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import numpy as np
|
21
|
+
import torch
|
22
|
+
from circle_schema import circle
|
23
|
+
|
24
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
25
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
26
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
27
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
28
|
+
from tico.utils.validate_args_kwargs import DequantizePerChannelArgs
|
29
|
+
|
30
|
+
|
31
|
+
@register_node_visitor
class DequantizePerChannelDefaultVisitor(NodeVisitor):
    """Serialize ``quantized_decomposed.dequantize_per_channel`` as Circle DEQUANTIZE.

    The per-channel scales and zero points are read from their constant
    buffers and attached as quantization parameters to the node's output
    tensor, which must not already carry any.
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.quantized_decomposed.dequantize_per_channel.default
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the DEQUANTIZE operator for *node*.

        Args:
            node: the ``dequantize_per_channel`` FX node being serialized.

        Returns:
            The DEQUANTIZE operator.

        Raises:
            AssertionError: if the output tensor already has quantization
                parameters, or if the decoded scales and zero points differ
                in length.
        """
        args = DequantizePerChannelArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input_ = args.input
        scales = args.scales
        zero_points = args.zero_points
        axis = args.axis
        quant_min = args.quant_min
        quant_max = args.quant_max

        # NOTE(review): qparams are attached to the DEQUANTIZE *output* tensor
        # rather than its input — confirm this is intended.
        output_tensor: circle.Tensor.TensorT = self.graph.get_tensor(node)
        assert not output_tensor.quantization
        quant_param = circle.QuantizationParameters.QuantizationParametersT()
        # NOTE(review): min/max hold the integer quant_min/quant_max bounds;
        # the circle/TFLite schema describes these fields as float ranges.
        quant_param.min = [quant_min]
        quant_param.max = [quant_max]

        # Decode scales as float32 and zero points as int32 from the raw buffers.
        scale_buf = bytes(self.graph.get_buffer(scales).data)
        scale_list = np.frombuffer(scale_buf, dtype=np.float32).tolist()
        zp_buf = bytes(self.graph.get_buffer(zero_points).data)
        zp_list = np.frombuffer(zp_buf, dtype=np.int32).tolist()
        # Reinterpreting a buffer with the wrong element width (e.g. int64
        # zero points read as int32) silently changes the element count;
        # fail loudly instead of emitting inconsistent per-channel qparams.
        assert len(scale_list) == len(zp_list), (
            f"per-channel scales ({len(scale_list)}) and zero points "
            f"({len(zp_list)}) must have the same length"
        )
        quant_param.scale = scale_list  # type: ignore[assignment]
        quant_param.zeroPoint = zp_list  # type: ignore[assignment]
        quant_param.quantizedDimension = axis
        output_tensor.quantization = quant_param

        inputs = [input_]
        outputs = [node]

        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.DEQUANTIZE, self._op_codes
        )
        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        # Op-specific option
        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.DequantizeOptions
        )
        option = circle.DequantizeOptions.DequantizeOptionsT()
        operator.builtinOptions = option

        return operator
|
@@ -0,0 +1,64 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import DequantizePerTensorArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class DequantizePerTensorDefaultVisitor(NodeVisitor):
    """Lower ``quantized_decomposed.dequantize_per_tensor`` to a Circle DEQUANTIZE.

    The source tensor must already carry quantization parameters; this
    visitor only connects it to the node's output through a DEQUANTIZE op.
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.quantized_decomposed.dequantize_per_tensor.default
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Return a DEQUANTIZE operator reading the node's input and writing *node*."""
        parsed = DequantizePerTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        src = parsed.input

        # The quantized source tensor must already have qparams attached.
        src_tensor: circle.Tensor.TensorT = self.graph.get_tensor(src)
        assert src_tensor.quantization

        dequant_op = create_builtin_operator(
            self.graph,
            get_op_index(
                circle.BuiltinOperator.BuiltinOperator.DEQUANTIZE, self._op_codes
            ),
            [src],
            [node],
        )

        # No fields are set on the options object; only its type is recorded.
        dequant_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.DequantizeOptions
        )
        dequant_op.builtinOptions = circle.DequantizeOptions.DequantizeOptionsT()
        return dequant_op
|
@@ -0,0 +1,62 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import DivTensorArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class DivVisitor(NodeVisitor):
    """Serialize ``aten.div.Tensor`` as a Circle DIV with no fused activation."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.div.Tensor]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Return a DIV operator computing ``input / other`` into *node*."""
        parsed = DivTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        div_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.DIV, self._op_codes
        )

        div_op = create_builtin_operator(
            self.graph, div_index, [parsed.input, parsed.other], [node]
        )

        # DIV carries an activation field; explicitly disable fusion.
        div_options = circle.DivOptions.DivOptionsT()
        div_options.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )
        div_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.DivOptions
        div_op.builtinOptions = div_options
        return div_op
|
@@ -0,0 +1,60 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
24
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
25
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
26
|
+
from tico.utils.validate_args_kwargs import EmbeddingArgs
|
27
|
+
|
28
|
+
|
29
|
+
@register_node_visitor
class EmbeddingVisitor(NodeVisitor):
    """Serialize ``aten.embedding.default`` as a Circle ``GATHER`` operator.

    An embedding lookup selects rows of the weight table by index, which is
    expressed here as ``GATHER`` along axis 0 with the weight table as the
    params tensor and the indices as the index tensor.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.embedding.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the Circle GATHER operator for an embedding ``node``.

        Args:
            node: FX node whose target is ``aten.embedding.default``.

        Returns:
            The Circle operator with inputs ``[weight, indices]``, output
            ``node``, and ``GatherOptions`` with ``axis = 0``.
        """
        # Parse/validate the node arguments first, matching the other visitors,
        # before touching ``self._op_codes`` via ``get_op_index``.
        args = EmbeddingArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        # The optional arguments (``padding_idx``, ``scale_grad_by_freq``,
        # ``sparse``) are intentionally not serialized: in PyTorch they only
        # influence gradient computation, not the forward row lookup.
        embedding_table = args.weight
        index_tensor = args.indices

        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.GATHER, self._op_codes
        )

        inputs = [embedding_table, index_tensor]
        outputs = [node]

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        # Op-specific option: gather whole rows of the table (axis 0).
        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.GatherOptions
        option = circle.GatherOptions.GatherOptionsT()
        option.axis = 0
        operator.builtinOptions = option

        return operator
|
@@ -0,0 +1,64 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import EqArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class EqVisitor(NodeVisitor):
    """Lower element-wise ``aten.eq`` (scalar and tensor overloads) to Circle ``EQUAL``."""

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.eq.Scalar,
        torch.ops.aten.eq.Tensor,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Emit a Circle EQUAL operator comparing ``input`` with ``other``.

        The two operands are the operator inputs, in order, and ``node`` is
        the single output. A default-constructed ``EqualOptions`` is attached.

        NOTE(review): for ``eq.Scalar`` the ``other`` operand is a Python
        scalar; presumably ``create_builtin_operator`` materializes it as a
        constant tensor — confirm.
        """
        eq_args = EqArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        lhs = eq_args.input
        rhs = eq_args.other

        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.EQUAL,
            self._op_codes,
        )
        operator = create_builtin_operator(self.graph, op_index, [lhs, rhs], [node])

        # Op-specific option: no fields are set on the options object.
        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.EqualOptions
        operator.builtinOptions = circle.EqualOptions.EqualOptionsT()
        return operator
|
@@ -0,0 +1,60 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import ExpArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class ExpVisitor(NodeVisitor):
    """Lower ``aten.exp.default`` to a Circle ``EXP`` builtin operator."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.exp.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Emit a Circle EXP operator for ``node``.

        The single ``input`` operand is the operator input and ``node`` is
        the single output. A default-constructed ``ExpOptions`` is attached.
        """
        operand = ExpArgs(*node.args, **node.kwargs).input  # type: ignore[arg-type]

        exp_opcode = circle.BuiltinOperator.BuiltinOperator.EXP
        op_index = get_op_index(exp_opcode, self._op_codes)

        operator = create_builtin_operator(self.graph, op_index, [operand], [node])

        # Op-specific option: no fields are set on the options object.
        exp_options = circle.ExpOptions.ExpOptionsT()
        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.ExpOptions
        operator.builtinOptions = exp_options
        return operator
|