tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +9 -1
- tico/config/base.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_aten_where_arg_type.py +1 -1
- tico/passes/cast_clamp_mixed_type_args.py +169 -0
- tico/passes/cast_mixed_type_args.py +4 -2
- tico/passes/const_prop_pass.py +1 -1
- tico/passes/convert_conv1d_to_conv2d.py +1 -1
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_addmm.py +0 -3
- tico/passes/decompose_batch_norm.py +2 -2
- tico/passes/decompose_fake_quantize.py +0 -3
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
- tico/passes/decompose_group_norm.py +0 -3
- tico/passes/legalize_predefined_layout_operators.py +2 -11
- tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
- tico/passes/lower_to_slice.py +1 -1
- tico/passes/merge_consecutive_cat.py +1 -1
- tico/passes/ops.py +1 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -6
- tico/passes/remove_redundant_reshape.py +5 -5
- tico/passes/segment_index_select.py +1 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
- tico/quantization/algorithm/gptq/quantizer.py +292 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
- tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
- tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
- tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
- tico/{experimental/quantization → quantization}/public_interface.py +19 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +125 -0
- tico/serialize/circle_graph.py +12 -4
- tico/serialize/circle_mapping.py +76 -2
- tico/serialize/circle_serializer.py +253 -148
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_any.py +7 -14
- tico/serialize/operators/op_avg_pool2d.py +11 -4
- tico/serialize/operators/op_clamp.py +5 -7
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_conv2d.py +14 -6
- tico/serialize/operators/op_copy.py +26 -3
- tico/serialize/operators/op_cumsum.py +3 -1
- tico/serialize/operators/op_depthwise_conv2d.py +17 -7
- tico/serialize/operators/op_full_like.py +0 -2
- tico/serialize/operators/op_index_select.py +8 -1
- tico/serialize/operators/op_instance_norm.py +0 -6
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_log1p.py +3 -2
- tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
- tico/serialize/operators/op_mm.py +15 -131
- tico/serialize/operators/op_mul.py +2 -8
- tico/serialize/operators/op_pow.py +3 -1
- tico/serialize/operators/op_repeat.py +12 -3
- tico/serialize/operators/op_reshape.py +1 -1
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/serialize/operators/op_softmax.py +7 -14
- tico/serialize/operators/op_split_with_sizes.py +16 -8
- tico/serialize/operators/op_transpose_conv.py +11 -8
- tico/serialize/operators/op_view.py +2 -1
- tico/serialize/quant_param.py +5 -5
- tico/utils/convert.py +30 -17
- tico/utils/dtype.py +42 -0
- tico/utils/graph.py +1 -1
- tico/utils/model.py +2 -1
- tico/utils/padding.py +2 -2
- tico/utils/pytree_utils.py +134 -0
- tico/utils/record_input.py +102 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/serialize.py +16 -3
- tico/utils/signature.py +247 -0
- tico/utils/torch_compat.py +52 -0
- tico/utils/utils.py +50 -58
- tico/utils/validate_args_kwargs.py +38 -3
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
- tico-0.1.0.dev251102.dist-info/RECORD +271 -0
- tico/experimental/quantization/__init__.py +0 -1
- tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
- tico-0.1.0.dev250714.dist-info/RECORD +0 -209
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
--- a/tico/serialize/operators/op_constant_pad_nd.py
+++ b/tico/serialize/operators/op_constant_pad_nd.py
@@ -28,6 +28,42 @@ from tico.utils.errors import InvalidArgumentError
 from tico.utils.validate_args_kwargs import ConstantPadNdArgs
 
 
+def convert_to_circle_padding(pad, input_shape_len):
+    MAX_RANK = 4
+
+    if not (1 <= input_shape_len <= MAX_RANK):
+        raise InvalidArgumentError(
+            f"Input rank must be between 1 and {MAX_RANK}, got {input_shape_len}"
+        )
+
+    if len(pad) % 2 != 0 or len(pad) < 2 or len(pad) > 8:
+        raise InvalidArgumentError(
+            f"Pad length must be an even number between 2 and 8, got {len(pad)}"
+        )
+
+    if len(pad) == 2:
+        padding = [[pad[0], pad[1]]]
+    elif len(pad) == 4:
+        padding = [[pad[2], pad[3]], [pad[0], pad[1]]]
+    elif len(pad) == 6:
+        padding = [[pad[4], pad[5]], [pad[2], pad[3]], [pad[0], pad[1]]]
+    elif len(pad) == 8:
+        padding = [
+            [pad[6], pad[7]],
+            [pad[4], pad[5]],
+            [pad[2], pad[3]],
+            [pad[0], pad[1]],
+        ]
+    else:
+        assert False, "Cannot reach here"
+
+    # Fill [0, 0] padding for the rest of dimension
+    while len(padding) < input_shape_len:
+        padding.insert(0, [0, 0])
+
+    return padding
+
+
 @register_node_visitor
 class ConstantPadNdVisitor(NodeVisitor):
     target: List[torch._ops.OpOverload] = [torch.ops.aten.constant_pad_nd.default]
@@ -45,19 +81,13 @@ class ConstantPadNdVisitor(NodeVisitor):
         val = args.value
 
         if val != 0:
-            raise InvalidArgumentError("Only support 0 value padding.")
+            raise InvalidArgumentError(f"Only support 0 value padding. pad:{pad}")
 
         input_shape_len = len(extract_shape(input_))
-
-
-
-
-            padding_size = [[0, 0], [0, 0]] + padding_size
-        else:
-            raise InvalidArgumentError("Only support 3D/4D inputs.")
-
-        paddings = torch.tensor(padding_size, dtype=torch.int32)
-        inputs = [input_, paddings]
+
+        padding = convert_to_circle_padding(pad, input_shape_len)
+
+        inputs = [input_, torch.tensor(padding, dtype=torch.int32)]
         outputs = [node]
 
         op_index = get_op_index(
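The new `convert_to_circle_padding` helper replaces the old hard-coded 3D/4D branch: it accepts any rank from 1 to 4 and flips `torch.nn.functional.pad`'s last-dimension-first pairs into Circle's outermost-first `[before, after]` rows. A worked example of the mapping, in plain Python:

```python
# torch pad convention: pairs are given for the *last* dimension first.
pad = [1, 2, 3, 4]   # last dim padded by [1, 2], second-to-last by [3, 4]
rank = 4             # e.g. an NHWC tensor

padding = [[pad[2], pad[3]], [pad[0], pad[1]]]  # reverse into dim order
while len(padding) < rank:                      # leading dims are unpadded
    padding.insert(0, [0, 0])

assert padding == [[0, 0], [0, 0], [3, 4], [1, 2]]
```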
--- a/tico/serialize/operators/op_conv2d.py
+++ b/tico/serialize/operators/op_conv2d.py
@@ -20,7 +20,11 @@ if TYPE_CHECKING:
     import torch
 from circle_schema import circle
 
-from tico.serialize.circle_mapping import
+from tico.serialize.circle_mapping import (
+    extract_circle_dtype,
+    extract_shape,
+    to_circle_shape,
+)
 from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
 from tico.serialize.operators.utils import create_builtin_operator, get_op_index
@@ -111,9 +115,9 @@ class Conv2dVisitor(NodeVisitor):
 
         assert groups == 1, "Only support group 1 conv2d"
 
-        input_shape =
-        output_shape =
-        weight_shape =
+        input_shape = extract_shape(input_)
+        output_shape = extract_shape(node)
+        weight_shape = extract_shape(weight)
         assert len(input_shape) == 4, len(input_shape)
         assert len(output_shape) == 4, len(output_shape)
         assert len(weight_shape) == 4, len(weight_shape)
@@ -132,17 +136,21 @@ class Conv2dVisitor(NodeVisitor):
             ],
             dtype=torch.int32,
         )
-        pad_output_shape = [
+        pad_output_shape: List[int | torch.SymInt] = [
             input_shape[0],
             input_shape[1] + pad_h * 2,
             input_shape[2] + pad_w * 2,
             input_shape[3],
         ]
+        pad_output_cshape, pad_output_cshape_signature = to_circle_shape(
+            pad_output_shape
+        )
         # create padded output tensor
         input_qparam: Optional[QuantParam] = input_.meta.get(QPARAM_KEY)
         pad_output = self.graph.add_tensor_from_scratch(
             prefix=f"{node.name}_input_pad_output",
-            shape=
+            shape=pad_output_cshape,
+            shape_signature=pad_output_cshape_signature,
             dtype=extract_circle_dtype(input_),
             qparam=input_qparam,
             source_node=node,
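`to_circle_shape` is the new seam for dynamic shapes in this release: it splits a possibly-symbolic shape into a static `shape` plus a `shape_signature`, following the TFLite/Circle convention of marking dynamic dimensions with -1 in the signature. A minimal sketch of that split, assuming only the semantics implied by the call sites (the real helper lives in `tico/serialize/circle_mapping.py` and may differ in detail):

```python
from typing import List, Optional, Tuple, Union

def to_circle_shape_sketch(
    dims: List[Union[int, "torch.SymInt"]]
) -> Tuple[List[int], Optional[List[int]]]:
    """Hypothetical re-implementation for illustration only."""
    shape: List[int] = []
    signature: List[int] = []
    for d in dims:
        if isinstance(d, int):
            shape.append(d)        # static dim: kept as-is in both fields
            signature.append(d)
        else:
            shape.append(1)        # dynamic dim: placeholder in `shape`...
            signature.append(-1)   # ...and -1 in `shape_signature`
    # Fully static shapes need no signature.
    return shape, (None if all(s >= 0 for s in signature) else signature)
```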
--- a/tico/serialize/operators/op_copy.py
+++ b/tico/serialize/operators/op_copy.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, TYPE_CHECKING, Union
+from typing import Dict, List, Optional, TYPE_CHECKING, Union
 
 if TYPE_CHECKING:
     import torch._ops
@@ -52,7 +52,15 @@ class CopyVisitor(NodeVisitor):
     def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
         super().__init__(op_codes, graph)
 
-    def check_to_do_broadcast(
+    def check_to_do_broadcast(
+        self,
+        dst: List[int],
+        dst_sig: Optional[List[int]],
+        src: List[int],
+        src_sig: Optional[List[int]],
+    ) -> bool:
+        assert dst_sig is None
+        assert src_sig is None
         return dst != src
 
     def define_broadcast_to_node(
@@ -102,6 +110,12 @@ class CopyVisitor(NodeVisitor):
         # To connect 'dst' to Reshape node in the graph, 'dst' must be converted to Shape op.
         dst_tensor: circle.Tensor.TensorT = self.graph.get_tensor(dst)
         dst_shape: List[int] = dst_tensor.shape
+        dst_shape_signature: Optional[List[int]] = dst_tensor.shapeSignature
+
+        if dst_shape_signature is not None:
+            # TODO: support dynamic shape
+            raise NotYetSupportedError("Dynamic shape is not supported yet.")
+
         dst_shape_tensor = torch.as_tensor(dst_shape, dtype=torch.int32)
 
         dst_shape_shape = [len(dst_shape)]
@@ -110,6 +124,7 @@ class CopyVisitor(NodeVisitor):
         shape_output = self.graph.add_tensor_from_scratch(
             prefix=f"{dst_name}_shape_output",
             shape=dst_shape_shape,
+            shape_signature=None,
             dtype=circle.TensorType.TensorType.INT32,
             source_node=node,
         )
@@ -119,9 +134,16 @@ class CopyVisitor(NodeVisitor):
 
         src_tensor: circle.Tensor.TensorT = self.graph.get_tensor(src)
         src_shape: List[int] = src_tensor.shape
+        src_shape_signature: Optional[List[int]] = src_tensor.shapeSignature
+
+        if src_shape_signature is not None:
+            # TODO: support dynamic shape
+            raise NotYetSupportedError("Dynamic shape is not supported yet.")
 
         # The src tensor must be broadcastable with the dst tensor.
-        do_broadcast = self.check_to_do_broadcast(
+        do_broadcast = self.check_to_do_broadcast(
+            dst_shape, dst_shape_signature, src_shape, src_shape_signature
+        )
         if do_broadcast:
             # create braodcastTo output tensor
             src_name: str = src.name
@@ -131,6 +153,7 @@ class CopyVisitor(NodeVisitor):
             self.graph.add_tensor_from_scratch(
                 prefix=f"{src_name}_broadcast_to_output",
                 shape=dst_shape,
+                shape_signature=dst_shape_signature,
                 dtype=src_type,
                 source_node=node,
             )
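With dynamic shapes rejected up front via `NotYetSupportedError`, the broadcast decision stays a plain static-shape comparison. Illustrative calls, assuming a constructed `CopyVisitor` instance named `visitor`:

```python
# Both signatures must be None (static shapes) by the time this runs.
visitor.check_to_do_broadcast([2, 3], None, [1, 3], None)  # True: src broadcasts
visitor.check_to_do_broadcast([2, 3], None, [2, 3], None)  # False: shapes match
```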
--- a/tico/serialize/operators/op_cumsum.py
+++ b/tico/serialize/operators/op_cumsum.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, TYPE_CHECKING
+from typing import Dict, List, Optional, TYPE_CHECKING
 
 if TYPE_CHECKING:
     import torch._ops
@@ -57,6 +57,7 @@ class CumsumVisitor(NodeVisitor):
         if input_dtype == torch.int32:
             input_tensor: circle.Tensor.TensorT = self.graph.get_tensor(input)
             input_shape: List[int] = input_tensor.shape
+            input_shape_signature: Optional[List[int]] = input_tensor.shapeSignature
             cast_op_index = get_op_index(
                 circle.BuiltinOperator.BuiltinOperator.CAST, self._op_codes
             )
@@ -66,6 +67,7 @@ class CumsumVisitor(NodeVisitor):
                 prefix=cast_name,
                 dtype=cast_dtype,
                 shape=input_shape,
+                shape_signature=input_shape_signature,
                 source_node=node,
             )
             cast_operator = create_builtin_operator(
--- a/tico/serialize/operators/op_depthwise_conv2d.py
+++ b/tico/serialize/operators/op_depthwise_conv2d.py
@@ -20,7 +20,11 @@ if TYPE_CHECKING:
     import torch
 from circle_schema import circle
 
-from tico.serialize.circle_mapping import
+from tico.serialize.circle_mapping import (
+    extract_circle_dtype,
+    extract_shape,
+    to_circle_shape,
+)
 from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
 from tico.serialize.operators.utils import create_builtin_operator, get_op_index
@@ -115,12 +119,13 @@ class DepthwiseConv2dVisitor(NodeVisitor):
         dilation = args.dilation
         groups = args.groups
 
-        input_shape =
-        output_shape =
-        weight_shape =
+        input_shape = extract_shape(input_)  # OHWI
+        output_shape = extract_shape(node)  # OHWI
+        weight_shape = extract_shape(weight)  # 1HWO
         assert len(input_shape) == 4, len(input_shape)
         assert len(output_shape) == 4, len(output_shape)
-        assert len(weight_shape) == 4
+        assert len(weight_shape) == 4, len(weight_shape)
+
         assert weight_shape[0] == 1
         assert weight_shape[3] == output_shape[3]
         assert input_shape[3] == groups
@@ -145,17 +150,22 @@ class DepthwiseConv2dVisitor(NodeVisitor):
             ],
             dtype=torch.int32,
         )
-        pad_output_shape = [
+        pad_output_shape: List[int | torch.SymInt] = [
            input_shape[0],
            input_shape[1] + pad_h * 2,
            input_shape[2] + pad_w * 2,
            input_shape[3],
        ]
+
+        pad_output_cshape, pad_output_cshape_signature = to_circle_shape(
+            pad_output_shape
+        )
         # create padded output tensor
         input_qparam: Optional[QuantParam] = input_.meta.get(QPARAM_KEY)
         pad_output = self.graph.add_tensor_from_scratch(
             prefix=f"{node.name}_input_pad_output",
-            shape=
+            shape=pad_output_cshape,
+            shape_signature=pad_output_cshape_signature,
             dtype=extract_circle_dtype(input_),
             qparam=input_qparam,
             source_node=node,
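Both convolution visitors materialize an explicit Pad operator in front of the conv, so the intermediate tensor's NHWC shape grows by twice the per-side padding. A worked example of the computation above:

```python
input_shape = [1, 8, 8, 16]   # NHWC
pad_h, pad_w = 1, 1

pad_output_shape = [
    input_shape[0],              # N stays 1
    input_shape[1] + pad_h * 2,  # H: 8 -> 10
    input_shape[2] + pad_w * 2,  # W: 8 -> 10
    input_shape[3],              # C stays 16
]
assert pad_output_shape == [1, 10, 10, 16]
```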
--- a/tico/serialize/operators/op_full_like.py
+++ b/tico/serialize/operators/op_full_like.py
@@ -21,10 +21,8 @@ import torch
 from circle_schema import circle
 
 from tico.serialize.circle_graph import CircleSubgraph
-from tico.serialize.circle_mapping import to_circle_dtype
 from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
-from tico.serialize.operators.utils import create_builtin_operator, get_op_index
 from tico.utils.validate_args_kwargs import FullLikeArgs
 
 
--- a/tico/serialize/operators/op_index_select.py
+++ b/tico/serialize/operators/op_index_select.py
@@ -49,7 +49,14 @@ class IndexSelectVisitor(NodeVisitor):
             self._op_codes,
         )
 
+        # TODO: Revise this to be simple
         dim_i32 = circle_legalize_dtype_to(dim, dtype=torch.int32)
+        assert (
+            dim_i32.dim() == 0 or len(dim_i32) == 1
+        ), f"dim should be scalar: {dim_i32}"
+        dim_i32_item = dim_i32.item()
+        assert isinstance(dim_i32_item, int)
+
         inputs = [input, index]
         outputs = [node]
 
@@ -57,7 +64,7 @@ class IndexSelectVisitor(NodeVisitor):
 
         operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.GatherOptions
         option = circle.GatherOptions.GatherOptionsT()
-        option.axis =
+        option.axis = dim_i32_item
 
         operator.builtinOptions = option
 
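`index_select` is serialized as a Circle Gather, so `dim` must collapse to a plain Python int for `GatherOptions.axis`; the new assertions enforce that. A plain-torch illustration:

```python
import torch

dim_i32 = torch.tensor(1, dtype=torch.int32)  # scalar after legalization
assert dim_i32.dim() == 0 or len(dim_i32) == 1
axis = dim_i32.item()                         # -> 1, becomes option.axis
assert isinstance(axis, int)
```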
--- a/tico/serialize/operators/op_instance_norm.py
+++ b/tico/serialize/operators/op_instance_norm.py
@@ -73,12 +73,6 @@ class InstanceNormVisitor(NodeVisitor):
         eps = args.eps
 
         # Ignore training-related args
-        running_mean = args.running_mean
-        running_var = args.running_var
-        use_input_stats = args.use_input_stats
-        momentum = args.momentum
-        cudnn_enabled = args.cudnn_enabled
-
         input_shape = list(extract_shape(input))
         assert len(input_shape) == 4, len(input_shape)
 
--- /dev/null
+++ b/tico/serialize/operators/op_le.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch._ops
+    import torch.fx
+import torch
+from circle_schema import circle
+
+from tico.serialize.circle_graph import CircleSubgraph
+from tico.serialize.operators.hashable_opcode import OpCode
+from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+from tico.utils.validate_args_kwargs import LeArgs
+
+
+@register_node_visitor
+class LeVisitor(NodeVisitor):
+    target: List[torch._ops.OpOverload] = [
+        torch.ops.aten.le.Scalar,
+        torch.ops.aten.le.Tensor,
+    ]
+
+    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
+        super().__init__(op_codes, graph)
+
+    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
+        args = LeArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        input = args.input
+        other = args.other
+
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.LESS_EQUAL, self._op_codes
+        )
+
+        inputs = [input, other]
+        outputs = [node]
+
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+        return operator
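The new visitor lets graphs containing `<=` comparisons serialize to Circle's LESS_EQUAL operator. A minimal module that would exercise it; the `tico.convert` call is sketched from the package's public interface and shown commented out:

```python
import torch

class LeModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # `x <= 0.5` traces to torch.ops.aten.le.Scalar, handled by LeVisitor
        return x <= 0.5

# import tico
# circle_model = tico.convert(LeModel().eval(), (torch.randn(4),))
```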
--- a/tico/serialize/operators/op_log1p.py
+++ b/tico/serialize/operators/op_log1p.py
@@ -23,7 +23,7 @@ from circle_schema import circle
 from tico.serialize.circle_graph import CircleSubgraph
 from tico.serialize.circle_mapping import (
     extract_circle_dtype,
-
+    extract_circle_shape,
     extract_torch_dtype,
 )
 from tico.serialize.operators.hashable_opcode import OpCode
@@ -62,11 +62,12 @@ class Log1pVisitor(NodeVisitor):
         args = Log1pArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
         input = args.input
 
-        input_shape =
+        input_shape, input_shape_signature = extract_circle_shape(input)
         dst_dtype_circle = extract_circle_dtype(input)
         add_tensor: circle.Tensor.TensorT = self.graph.add_tensor_from_scratch(
             prefix=f"{input.name}_add",
             shape=input_shape,
+            shape_signature=input_shape_signature,
             dtype=dst_dtype_circle,
             source_node=node,
         )
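Judging by the intermediate `{input.name}_add` tensor, `log1p` is decomposed into an Add with 1 followed by a Log, and the scratch Add output now carries the input's shape signature too. The identity being relied on:

```python
import torch

x = torch.tensor([0.0, 0.5, 3.0])
# log1p(x) == log(x + 1): the Add output feeds a Log operator.
assert torch.allclose(torch.log1p(x), torch.log(x + 1.0))
```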
--- a/tico/serialize/operators/op_max_pool2d_with_indices.py
+++ b/tico/serialize/operators/op_max_pool2d_with_indices.py
@@ -22,7 +22,11 @@ import torch
 from circle_schema import circle
 
 from tico.serialize.circle_graph import CircleSubgraph
-from tico.serialize.circle_mapping import
+from tico.serialize.circle_mapping import (
+    extract_circle_dtype,
+    extract_shape,
+    to_circle_shape,
+)
 from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
 from tico.serialize.operators.utils import (
@@ -88,8 +92,15 @@ class MaxPool2DWithIndicesVisitor(NodeVisitor):
             ],
             dtype=torch.int32,
         )
-
+
+        input_shape = extract_shape(input)
         input_dtype: int = extract_circle_dtype(input)
+
+        input_qparam: Optional[QuantParam] = (
+            input.meta[QPARAM_KEY] if QPARAM_KEY in input.meta else None
+        )
+
+        # create padded input tensor
         padded_input_shape = [
             input_shape[0],
             input_shape[1],
@@ -98,17 +109,16 @@ class MaxPool2DWithIndicesVisitor(NodeVisitor):
         ]
         padded_input_shape[1] += padding[0] * 2
         padded_input_shape[2] += padding[1] * 2
-
-            input.meta[QPARAM_KEY] if QPARAM_KEY in input.meta else None
-        )
-        # create padded input tensor
+        padded_cshape, padded_cshape_signature = to_circle_shape(padded_input_shape)
         padded_input_tensor = self.graph.add_tensor_from_scratch(
             prefix=f"{input.name}_pad_output",
-            shape=
+            shape=padded_cshape,
+            shape_signature=padded_cshape_signature,
             dtype=input_dtype,
             qparam=input_qparam,
             source_node=node,
         )
+
         if input_qparam is not None:
             padding_value = get_integer_dtype_min(input_qparam.dtype)
         else:
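For quantized inputs, the pad region is filled with the minimum representable value of the quantized integer dtype, so MaxPool can never select a padding cell. A hypothetical stand-in for what a `get_integer_dtype_min`-style lookup returns (the real helper's signature may differ):

```python
import torch

def integer_dtype_min(dtype: torch.dtype) -> int:
    # Illustration only: smallest representable value of an integer dtype.
    return torch.iinfo(dtype).min

assert integer_dtype_min(torch.uint8) == 0
assert integer_dtype_min(torch.int8) == -128
assert integer_dtype_min(torch.int16) == -32768
```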
--- a/tico/serialize/operators/op_mm.py
+++ b/tico/serialize/operators/op_mm.py
@@ -20,7 +20,7 @@ if TYPE_CHECKING:
     import torch
 from circle_schema import circle
 
-from tico.serialize.circle_graph import CircleSubgraph
+from tico.serialize.circle_graph import CircleSubgraph
 from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
 from tico.serialize.operators.utils import create_builtin_operator, get_op_index
@@ -28,9 +28,9 @@ from tico.utils.validate_args_kwargs import MatmulArgs
 
 
 @register_node_visitor
-class MatmulDefaultVisitor(NodeVisitor):
+class MatmulVisitor(NodeVisitor):
     """
-    Convert matmul to
+    Convert matmul to Circle BatchMatMul
     """
 
     target: List[torch._ops.OpOverload] = [torch.ops.aten.mm.default]
@@ -38,130 +38,7 @@ class MatmulDefaultVisitor(NodeVisitor):
     def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
         super().__init__(op_codes, graph)
 
-
-    def define_bmm_node(self, inputs, outputs) -> circle.Operator.OperatorT:
-        def set_bmm_option(operator):
-            operator.builtinOptionsType = (
-                circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
-            )
-            option = circle.BatchMatMulOptions.BatchMatMulOptionsT()
-            option.adjointLhs, option.adjointRhs = False, False
-            option.asymmetricQuantizeInputs = False
-            operator.builtinOptions = option
-
-        op_index = get_op_index(
-            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
-        )
-        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
-        set_bmm_option(operator)
-
-        return operator
-
-    def define_transpose_node(self, inputs, outputs) -> circle.Operator.OperatorT:
-        def set_transpose_option(operator):
-            operator.builtinOptionsType = (
-                circle.BuiltinOptions.BuiltinOptions.TransposeOptions
-            )
-            option = circle.TransposeOptions.TransposeOptionsT()
-            operator.builtinOptions = option
-
-        transpose_op_index = get_op_index(
-            circle.BuiltinOperator.BuiltinOperator.TRANSPOSE, self._op_codes
-        )
-        operator = create_builtin_operator(
-            self.graph, transpose_op_index, inputs, outputs
-        )
-        set_transpose_option(operator)
-        return operator
-
-    def define_fc_node(self, inputs, outputs) -> circle.Operator.OperatorT:
-        def set_fc_option(operator):
-            operator.builtinOptionsType = (
-                circle.BuiltinOptions.BuiltinOptions.FullyConnectedOptions
-            )
-            option = circle.FullyConnectedOptions.FullyConnectedOptionsT()
-
-            option.fusedActivationFunction = (
-                circle.ActivationFunctionType.ActivationFunctionType.NONE
-            )
-            option.weightsFormat = (
-                circle.FullyConnectedOptionsWeightsFormat.FullyConnectedOptionsWeightsFormat.DEFAULT
-            )
-            option.keepNumDims = False
-            option.asymmetricQuantizeInputs = False
-            option.quantizedBiasType = circle.TensorType.TensorType.FLOAT32
-
-            operator.builtinOptions = option
-
-        fc_op_index = get_op_index(
-            circle.BuiltinOperator.BuiltinOperator.FULLY_CONNECTED, self._op_codes
-        )
-        operator = create_builtin_operator(self.graph, fc_op_index, inputs, outputs)
-        set_fc_option(operator)
-        return operator
-
-    """
-    Define FullyConnnected with Tranpose operator.
-    Note that those sets of operators are equivalent.
-    (1) Matmul
-        matmul( lhs[H, K], rhs[K, W'] ) -> output(H, W')
-
-    (2) Transpose + FullyConneccted
-        transpose( rhs[K, W'] ) -> trs_output[W', K]
-        fullyconnected( lhs[H, K], trs_output[W', K] ) -> output(H, W')
-    """
-
-    def define_fc_with_transpose(
-        self, node, inputs, outputs
-    ) -> circle.Operator.OperatorT:
-        lhs, rhs = inputs
-
-        # get transpose shape
-        rhs_tid: int = self.graph.get_tid_registered(rhs)
-        rhs_tensor: circle.Tensor.TensorT = self.graph.tensors[rhs_tid]
-        rhs_name: str = rhs.name
-        rhs_type: int = rhs_tensor.type
-        rhs_shape: List[int] = rhs_tensor.shape
-        assert len(rhs_shape) == 2, len(rhs_shape)
-        rhs_shape_transpose = [rhs_shape[1], rhs_shape[0]]
-
-        # create transpose output tensor
-        trs_output = self.graph.add_tensor_from_scratch(
-            prefix=f"{rhs_name}_transposed_output",
-            shape=rhs_shape_transpose,
-            dtype=rhs_type,
-            source_node=node,
-        )
-        trs_perm = self.graph.add_const_tensor(data=[1, 0], source_node=node)
-        trs_operator = self.define_transpose_node([rhs, trs_perm], [trs_output])
-        self.graph.add_operator(trs_operator)
-
-        # define fc node
-        fc_input = lhs
-        fc_weight = trs_output
-        fc_shape = [fc_weight.shape[0]]
-        fc_bias = self.graph.add_const_tensor(
-            data=[0.0] * fc_shape[0], source_node=node
-        )
-
-        operator = self.define_fc_node([fc_input, fc_weight, fc_bias], outputs)
-
-        return operator
-
-    def define_node(
-        self, node: torch.fx.Node, prior_latency=True
-    ) -> circle.Operator.OperatorT:
-        """
-        NOTE: Possibility of accuracy-latency trade-off
-        From ONE compiler's perspective:
-        - BMM uses per-tensor quantization for both rhs and lhs.
-        - FC uses per-channel quantization for weight and per-tensor for input.
-        Thus, FC is better in terms of accuracy.
-        FC necessarily involves an additional transpose operation to be identical with mm.
-        If transposed operand is const, it can be optimized by constant folding.
-        Thus, convert FC only if tranpose can be folded.
-        TODO set prior_latency outside
-        """
+    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
         args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
         input = args.input
         other = args.other
@@ -169,9 +46,16 @@ class MatmulDefaultVisitor(NodeVisitor):
         inputs = [input, other]
         outputs = [node]
 
-
-
-
-
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
+        )
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+        operator.builtinOptionsType = (
+            circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
+        )
+        option = circle.BatchMatMulOptions.BatchMatMulOptionsT()
+        option.adjointLhs, option.adjointRhs = False, False
+        option.asymmetricQuantizeInputs = False
+        operator.builtinOptions = option
 
         return operator
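The deleted FC path relied on the equivalence `mm(lhs, rhs) == fully_connected(lhs, transpose(rhs))` spelled out in the removed docstring; the simplified visitor now always emits BatchMatMul, and the FC conversion is likely superseded by the new `tico/passes/convert_matmul_to_linear.py` pass listed above. A quick numeric check of that equivalence in plain torch:

```python
import torch

lhs = torch.randn(3, 4)  # [H, K]
rhs = torch.randn(4, 5)  # [K, W']

out_mm = torch.mm(lhs, rhs)                        # what BatchMatMul computes
out_fc = torch.nn.functional.linear(lhs, rhs.t())  # FC with weight [W', K]
assert torch.allclose(out_mm, out_fc, atol=1e-6)
```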
--- a/tico/serialize/operators/op_mul.py
+++ b/tico/serialize/operators/op_mul.py
@@ -66,10 +66,7 @@ class MulTensorVisitor(BaseMulVisitor):
         self,
         node: torch.fx.Node,
     ) -> circle.Operator.OperatorT:
-
-        input = args.input
-        other = args.other
-
+        _ = MulTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
         operator = super().define_node(
             node,
         )
@@ -88,10 +85,7 @@ class MulScalarVisitor(BaseMulVisitor):
         self,
         node: torch.fx.Node,
     ) -> circle.Operator.OperatorT:
-
-        input = args.input
-        other = args.other
-
+        _ = MulScalarArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
         operator = super().define_node(
             node,
         )
--- a/tico/serialize/operators/op_pow.py
+++ b/tico/serialize/operators/op_pow.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, TYPE_CHECKING
+from typing import Dict, List, Optional, TYPE_CHECKING
 
 if TYPE_CHECKING:
     import torch._ops
@@ -36,6 +36,7 @@ class BasePowVisitor(NodeVisitor):
         assert isinstance(node, torch.fx.Node), type(node)
         node_tensor: circle.Tensor.TensorT = self.graph.get_tensor(node)
         node_shape: List[int] = node_tensor.shape
+        node_shape_signature: Optional[List[int]] = node_tensor.shapeSignature
         op_index = get_op_index(
             circle.BuiltinOperator.BuiltinOperator.CAST, self._op_codes
         )
@@ -45,6 +46,7 @@ class BasePowVisitor(NodeVisitor):
             prefix=cast_name,
             dtype=cast_dtype,
             shape=node_shape,
+            shape_signature=node_shape_signature,
             source_node=node,
         )
         cast_operator = create_builtin_operator(