tico 0.1.0.dev250411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196) hide show
  1. tico/__init__.py +31 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
  55. tico/experimental/quantization/public_interface.py +108 -0
  56. tico/experimental/quantization/quantizer.py +71 -0
  57. tico/interpreter/__init__.py +1 -0
  58. tico/interpreter/infer.py +116 -0
  59. tico/interpreter/interpreter.py +93 -0
  60. tico/passes/__init__.py +1 -0
  61. tico/passes/cast_aten_where_arg_type.py +185 -0
  62. tico/passes/cast_mixed_type_args.py +186 -0
  63. tico/passes/const_prop_pass.py +307 -0
  64. tico/passes/convert_conv1d_to_conv2d.py +151 -0
  65. tico/passes/convert_layout_op_to_reshape.py +84 -0
  66. tico/passes/convert_repeat_to_expand_copy.py +90 -0
  67. tico/passes/convert_to_relu6.py +180 -0
  68. tico/passes/decompose_addmm.py +127 -0
  69. tico/passes/decompose_batch_norm.py +198 -0
  70. tico/passes/decompose_fake_quantize.py +126 -0
  71. tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
  72. tico/passes/decompose_group_norm.py +258 -0
  73. tico/passes/decompose_grouped_conv2d.py +202 -0
  74. tico/passes/decompose_slice_scatter.py +167 -0
  75. tico/passes/extract_dtype_kwargs.py +121 -0
  76. tico/passes/fill_meta_val.py +57 -0
  77. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  78. tico/passes/legalize_causal_mask_value.py +113 -0
  79. tico/passes/legalize_predefined_layout_operators.py +383 -0
  80. tico/passes/lower_pow2_to_mul.py +75 -0
  81. tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
  82. tico/passes/lower_to_slice.py +112 -0
  83. tico/passes/merge_consecutive_cat.py +82 -0
  84. tico/passes/ops.py +75 -0
  85. tico/passes/remove_nop.py +85 -0
  86. tico/passes/remove_redundant_assert_nodes.py +50 -0
  87. tico/passes/remove_redundant_expand.py +70 -0
  88. tico/passes/remove_redundant_permute.py +102 -0
  89. tico/passes/remove_redundant_reshape.py +431 -0
  90. tico/passes/remove_redundant_slice.py +64 -0
  91. tico/passes/remove_redundant_to_copy.py +84 -0
  92. tico/passes/restore_linear.py +113 -0
  93. tico/passes/segment_index_select.py +143 -0
  94. tico/pt2_to_circle.py +101 -0
  95. tico/serialize/__init__.py +1 -0
  96. tico/serialize/circle_graph.py +264 -0
  97. tico/serialize/circle_mapping.py +177 -0
  98. tico/serialize/circle_serializer.py +232 -0
  99. tico/serialize/operators/__init__.py +28 -0
  100. tico/serialize/operators/hashable_opcode.py +43 -0
  101. tico/serialize/operators/node_visitor.py +80 -0
  102. tico/serialize/operators/op_add.py +69 -0
  103. tico/serialize/operators/op_alias_copy.py +64 -0
  104. tico/serialize/operators/op_any.py +142 -0
  105. tico/serialize/operators/op_arange_start_step.py +61 -0
  106. tico/serialize/operators/op_argmax.py +62 -0
  107. tico/serialize/operators/op_avg_pool2d.py +112 -0
  108. tico/serialize/operators/op_bmm.py +62 -0
  109. tico/serialize/operators/op_cat.py +66 -0
  110. tico/serialize/operators/op_clamp.py +123 -0
  111. tico/serialize/operators/op_clone.py +71 -0
  112. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  113. tico/serialize/operators/op_conv2d.py +181 -0
  114. tico/serialize/operators/op_copy.py +162 -0
  115. tico/serialize/operators/op_cos.py +59 -0
  116. tico/serialize/operators/op_cumsum.py +92 -0
  117. tico/serialize/operators/op_depthwise_conv2d.py +198 -0
  118. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  119. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  120. tico/serialize/operators/op_div.py +62 -0
  121. tico/serialize/operators/op_embedding.py +60 -0
  122. tico/serialize/operators/op_eq.py +64 -0
  123. tico/serialize/operators/op_exp.py +60 -0
  124. tico/serialize/operators/op_expand.py +91 -0
  125. tico/serialize/operators/op_full.py +48 -0
  126. tico/serialize/operators/op_full_like.py +55 -0
  127. tico/serialize/operators/op_ge.py +54 -0
  128. tico/serialize/operators/op_gelu.py +59 -0
  129. tico/serialize/operators/op_gt.py +54 -0
  130. tico/serialize/operators/op_index.py +82 -0
  131. tico/serialize/operators/op_index_select.py +64 -0
  132. tico/serialize/operators/op_instance_norm.py +91 -0
  133. tico/serialize/operators/op_linear.py +70 -0
  134. tico/serialize/operators/op_log.py +53 -0
  135. tico/serialize/operators/op_log1p.py +83 -0
  136. tico/serialize/operators/op_logical_and.py +63 -0
  137. tico/serialize/operators/op_logical_not.py +62 -0
  138. tico/serialize/operators/op_lt.py +61 -0
  139. tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
  140. tico/serialize/operators/op_maximum.py +53 -0
  141. tico/serialize/operators/op_mean.py +66 -0
  142. tico/serialize/operators/op_minimum.py +53 -0
  143. tico/serialize/operators/op_mm.py +174 -0
  144. tico/serialize/operators/op_mul.py +99 -0
  145. tico/serialize/operators/op_ne.py +54 -0
  146. tico/serialize/operators/op_neg.py +59 -0
  147. tico/serialize/operators/op_permute.py +65 -0
  148. tico/serialize/operators/op_pow.py +138 -0
  149. tico/serialize/operators/op_prelu.py +54 -0
  150. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  151. tico/serialize/operators/op_reciprocal.py +64 -0
  152. tico/serialize/operators/op_relu.py +53 -0
  153. tico/serialize/operators/op_relu6.py +52 -0
  154. tico/serialize/operators/op_repeat.py +99 -0
  155. tico/serialize/operators/op_reshape.py +73 -0
  156. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  157. tico/serialize/operators/op_rsqrt.py +53 -0
  158. tico/serialize/operators/op_scalar_tensor.py +51 -0
  159. tico/serialize/operators/op_select_copy.py +65 -0
  160. tico/serialize/operators/op_sigmoid.py +56 -0
  161. tico/serialize/operators/op_sin.py +53 -0
  162. tico/serialize/operators/op_slice.py +155 -0
  163. tico/serialize/operators/op_softmax.py +100 -0
  164. tico/serialize/operators/op_split_with_sizes.py +96 -0
  165. tico/serialize/operators/op_sqrt.py +55 -0
  166. tico/serialize/operators/op_squeeze.py +73 -0
  167. tico/serialize/operators/op_sub.py +71 -0
  168. tico/serialize/operators/op_sum.py +63 -0
  169. tico/serialize/operators/op_tanh.py +54 -0
  170. tico/serialize/operators/op_to_copy.py +105 -0
  171. tico/serialize/operators/op_unsqueeze.py +66 -0
  172. tico/serialize/operators/op_view.py +74 -0
  173. tico/serialize/operators/op_where.py +82 -0
  174. tico/serialize/operators/utils.py +51 -0
  175. tico/serialize/pack.py +35 -0
  176. tico/serialize/quant_param.py +42 -0
  177. tico/utils/__init__.py +1 -0
  178. tico/utils/convert.py +292 -0
  179. tico/utils/define.py +35 -0
  180. tico/utils/diff_graph.py +181 -0
  181. tico/utils/errors.py +35 -0
  182. tico/utils/graph.py +200 -0
  183. tico/utils/logging.py +45 -0
  184. tico/utils/model.py +37 -0
  185. tico/utils/padding.py +47 -0
  186. tico/utils/passes.py +76 -0
  187. tico/utils/register_custom_op.py +562 -0
  188. tico/utils/trace_decorators.py +101 -0
  189. tico/utils/utils.py +314 -0
  190. tico/utils/validate_args_kwargs.py +1114 -0
  191. tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
  192. tico-0.1.0.dev250411.dist-info/METADATA +17 -0
  193. tico-0.1.0.dev250411.dist-info/RECORD +196 -0
  194. tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
  195. tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
  196. tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,53 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import MinimumArgs
28
+
29
+
30
@register_node_visitor
class MinimumVisitor(NodeVisitor):
    """Serialize ``aten.minimum.default`` nodes as a Circle ``MINIMUM`` operator."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.minimum.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
        """
        Build the ``MINIMUM`` operator for ``node``.

        The node's positional and keyword arguments are bound through
        ``MinimumArgs``; its ``input`` and ``other`` fields become the two
        operator inputs and ``node`` itself is the single output.
        """
        parsed = MinimumArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        operands = [parsed.input, parsed.other]

        minimum_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.MINIMUM, self._op_codes
        )

        return create_builtin_operator(self.graph, minimum_index, operands, [node])
@@ -0,0 +1,174 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph, is_const
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import MatmulArgs
28
+
29
+
30
@register_node_visitor
class MatmulDefaultVisitor(NodeVisitor):
    """
    Convert matmul to an equivalent BatchMatMul, or to FullyConnected with Transpose.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.mm.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    # NOTE: Matmul is equivalent to Batch MatMul (batch=1)
    def define_bmm_node(self, inputs, outputs) -> circle.Operator.OperatorT:
        """Create a ``BATCH_MATMUL`` operator with both adjoint flags disabled."""
        bmm_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
        )
        bmm_op = create_builtin_operator(self.graph, bmm_index, inputs, outputs)

        bmm_options = circle.BatchMatMulOptions.BatchMatMulOptionsT()
        bmm_options.adjointLhs = False
        bmm_options.adjointRhs = False
        bmm_options.asymmetricQuantizeInputs = False

        bmm_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
        )
        bmm_op.builtinOptions = bmm_options

        return bmm_op

    def define_transpose_node(self, inputs, outputs) -> circle.Operator.OperatorT:
        """Create a ``TRANSPOSE`` operator carrying a default ``TransposeOptions``."""
        transpose_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.TRANSPOSE, self._op_codes
        )
        transpose_op = create_builtin_operator(
            self.graph, transpose_index, inputs, outputs
        )

        transpose_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.TransposeOptions
        )
        transpose_op.builtinOptions = circle.TransposeOptions.TransposeOptionsT()

        return transpose_op

    def define_fc_node(self, inputs, outputs) -> circle.Operator.OperatorT:
        """
        Create a ``FULLY_CONNECTED`` operator.

        The options request no fused activation, the default weights format,
        ``keepNumDims=False``, no asymmetric input quantization and a FLOAT32
        quantized-bias type.
        """
        fc_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.FULLY_CONNECTED, self._op_codes
        )
        fc_op = create_builtin_operator(self.graph, fc_index, inputs, outputs)

        fc_options = circle.FullyConnectedOptions.FullyConnectedOptionsT()
        fc_options.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )
        fc_options.weightsFormat = (
            circle.FullyConnectedOptionsWeightsFormat.FullyConnectedOptionsWeightsFormat.DEFAULT
        )
        fc_options.keepNumDims = False
        fc_options.asymmetricQuantizeInputs = False
        fc_options.quantizedBiasType = circle.TensorType.TensorType.FLOAT32

        fc_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.FullyConnectedOptions
        )
        fc_op.builtinOptions = fc_options

        return fc_op

    def define_fc_with_transpose(self, inputs, outputs) -> circle.Operator.OperatorT:
        """
        Define a FullyConnected operator fed by a Transpose of ``rhs``.

        Note that these two sets of operators are equivalent.
        (1) Matmul
            matmul( lhs[H, K], rhs[K, W'] ) -> output(H, W')

        (2) Transpose + FullyConnected
            transpose( rhs[K, W'] ) -> trs_output[W', K]
            fullyconnected( lhs[H, K], trs_output[W', K] ) -> output(H, W')

        The Transpose operator is added to the graph here; the FullyConnected
        operator is only returned to the caller.
        """
        lhs, rhs = inputs

        # Read name, dtype and shape of rhs from its registered circle tensor.
        registered_rhs: circle.Tensor.TensorT = self.graph.tensors[
            self.graph.get_tid_registered(rhs)
        ]
        rhs_label: str = rhs.name
        rhs_dtype: int = registered_rhs.type
        rhs_dims: List[int] = registered_rhs.shape
        assert len(rhs_dims) == 2, len(rhs_dims)

        # Transpose rhs with perm [1, 0] into a freshly created tensor.
        transposed = self.graph.add_tensor_from_scratch(
            prefix=f"{rhs_label}_transposed_output",
            shape=[rhs_dims[1], rhs_dims[0]],
            dtype=rhs_dtype,
        )
        perm = self.graph.add_const_tensor(data=[1, 0])
        self.graph.add_operator(self.define_transpose_node([rhs, perm], [transposed]))

        # FullyConnected takes a bias input; supply zeros, one per row of the
        # transposed weight.
        zero_bias = self.graph.add_const_tensor(
            data=[0.0] * transposed.shape[0],
        )

        return self.define_fc_node([lhs, transposed, zero_bias], outputs)

    def define_node(
        self, node: torch.fx.Node, prior_latency=True
    ) -> circle.Operator.OperatorT:
        """
        NOTE: Possibility of accuracy-latency trade-off
        From ONE compiler's perspective:
        - BMM uses per-tensor quantization for both rhs and lhs.
        - FC uses per-channel quantization for weight and per-tensor for input.
        Thus, FC is better in terms of accuracy.
        FC necessarily involves an additional transpose operation to be identical with mm.
        If the transposed operand is const, it can be optimized by constant folding.
        Thus, convert to FC only if the transpose can be folded.
        TODO set prior_latency outside
        """
        parsed = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        lhs = parsed.input
        rhs = parsed.other

        # FC path: rhs is const, or latency is not prioritized.
        if is_const(rhs) or not prior_latency:
            return self.define_fc_with_transpose([lhs, rhs], [node])

        return self.define_bmm_node([lhs, rhs], [node])
@@ -0,0 +1,99 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import MulScalarArgs, MulTensorArgs
28
+
29
+
30
class BaseMulVisitor(NodeVisitor):
    """Common builder that turns a multiplication node into a Circle ``MUL`` operator."""

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.node.Node,
    ) -> circle.Operator.OperatorT:
        """
        Build the ``MUL`` operator for ``node``.

        All positional arguments of ``node`` are used as operator inputs and
        ``node`` is the single output. No activation function is fused.
        """
        mul_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.MUL, self._op_codes
        )
        mul_op = create_builtin_operator(self.graph, mul_index, [*node.args], [node])

        # Op-specific option
        mul_options = circle.MulOptions.MulOptionsT()
        mul_options.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )
        mul_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.MulOptions
        mul_op.builtinOptions = mul_options

        return mul_op
56
+
57
+
58
@register_node_visitor
class MulTensorVisitor(BaseMulVisitor):
    """Visitor for ``aten.mul.Tensor``; the operator is built by ``BaseMulVisitor``."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.mul.Tensor]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        Check the node's arguments, then delegate to ``BaseMulVisitor.define_node``.

        Returns the ``MUL`` operator produced by the base class.
        """
        # Instantiated only so that an args/kwargs combination that does not
        # fit MulTensorArgs fails here. The parsed fields are not needed: the
        # base class reads ``node.args`` directly.
        MulTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        return super().define_node(node)
78
+
79
+
80
@register_node_visitor
class MulScalarVisitor(BaseMulVisitor):
    """Visitor for ``aten.mul.Scalar``; the operator is built by ``BaseMulVisitor``."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.mul.Scalar]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        Check the node's arguments, then delegate to ``BaseMulVisitor.define_node``.

        Returns the ``MUL`` operator produced by the base class.
        """
        # Instantiated only so that an args/kwargs combination that does not
        # fit MulScalarArgs fails here. The parsed fields are not needed: the
        # base class reads ``node.args`` directly.
        MulScalarArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        return super().define_node(node)
@@ -0,0 +1,54 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import NeTensorArgs
28
+
29
+
30
@register_node_visitor
class NeVisitor(NodeVisitor):
    """Serialize ``aten.ne.Scalar`` and ``aten.ne.Tensor`` as Circle ``NOT_EQUAL``."""

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.ne.Scalar,
        torch.ops.aten.ne.Tensor,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
        """
        Build the ``NOT_EQUAL`` operator for ``node``.

        The ``input`` and ``other`` fields of the parsed arguments are the two
        operator inputs; ``node`` is the single output.
        """
        not_equal_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.NOT_EQUAL, self._op_codes
        )

        # NOTE(review): NeTensorArgs is applied to the Scalar overload as
        # well -- confirm that is intended.
        parsed = NeTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        operands = [parsed.input, parsed.other]

        return create_builtin_operator(self.graph, not_equal_index, operands, [node])
@@ -0,0 +1,59 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import NegArgs
28
+
29
+
30
@register_node_visitor
class NegVisitor(NodeVisitor):
    """Serialize ``aten.neg.default`` nodes as a Circle ``NEG`` operator."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.neg.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        Build the ``NEG`` operator for ``node``.

        The parsed ``input`` is the only operator input and ``node`` is the
        single output. A default ``NegOptions`` object is attached.
        """
        neg_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.NEG,
            self._op_codes,
        )
        parsed = NegArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        neg_op = create_builtin_operator(
            self.graph, neg_index, [parsed.input], [node]
        )

        # Op-specific option
        neg_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.NegOptions
        neg_op.builtinOptions = circle.NegOptions.NegOptionsT()

        return neg_op
@@ -0,0 +1,65 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.circle_mapping import circle_legalize_dtype_to
25
+ from tico.serialize.operators.hashable_opcode import OpCode
26
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
27
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
28
+ from tico.utils.validate_args_kwargs import PermuteArgs
29
+
30
+
31
@register_node_visitor
class PermuteVisitor(NodeVisitor):
    """Serialize ``aten.permute.default`` nodes as a Circle ``TRANSPOSE`` operator."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.permute.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        Build the ``TRANSPOSE`` operator for ``node``.

        The operator inputs are the parsed ``input`` and the parsed ``dims``
        after passing them through ``circle_legalize_dtype_to`` with
        ``torch.int32``; ``node`` is the single output.
        """
        transpose_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.TRANSPOSE,
            self._op_codes,
        )

        parsed = PermuteArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        perm_i32 = circle_legalize_dtype_to(parsed.dims, dtype=torch.int32)

        transpose_op = create_builtin_operator(
            self.graph, transpose_index, [parsed.input, perm_i32], [node]
        )

        # Op-specific option
        transpose_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.TransposeOptions
        )
        transpose_op.builtinOptions = circle.TransposeOptions.TransposeOptionsT()

        return transpose_op
@@ -0,0 +1,138 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.circle_mapping import extract_torch_dtype, to_circle_dtype
25
+ from tico.serialize.operators.hashable_opcode import OpCode
26
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
27
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
28
+ from tico.utils.validate_args_kwargs import PowTensorScalarArgs, PowTensorTensorArgs
29
+
30
+
31
class BasePowVisitor(NodeVisitor):
    """Helpers shared by the ``pow`` visitors: a float cast and the ``POW`` builder."""

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def cast_to_float(self, node: torch.fx.Node) -> circle.Tensor.TensorT:
        """
        Add a ``CAST`` operator that converts ``node`` to FLOAT32.

        A new tensor with the same shape as ``node``'s circle tensor is
        created, the cast operator is added to the graph, and the new tensor
        is returned so it can be used in place of ``node``.
        """
        assert isinstance(node, torch.fx.Node), type(node)
        source_tensor: circle.Tensor.TensorT = self.graph.get_tensor(node)
        source_dims: List[int] = source_tensor.shape

        cast_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.CAST, self._op_codes
        )
        float32 = circle.TensorType.TensorType.FLOAT32
        float_tensor = self.graph.add_tensor_from_scratch(
            prefix=f"{node.name}_cast", dtype=float32, shape=source_dims
        )
        cast_op = create_builtin_operator(
            self.graph, cast_index, [node], [float_tensor]
        )

        # The cast options record both the source and the destination dtype.
        cast_options = circle.CastOptions.CastOptionsT()
        cast_options.inDataType = to_circle_dtype(extract_torch_dtype(node))
        cast_options.outDataType = float32

        cast_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.CastOptions
        cast_op.builtinOptions = cast_options
        self.graph.add_operator(cast_op)

        return float_tensor

    def define_pow_node(self, inputs: List, outputs: List) -> circle.Operator.OperatorT:
        """Create a ``POW`` operator over ``inputs`` / ``outputs`` and return it."""
        pow_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.POW, self._op_codes
        )
        pow_op = create_builtin_operator(self.graph, pow_index, inputs, outputs)

        # Pow operation does not have any options.
        pow_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.PowOptions
        pow_op.builtinOptions = circle.PowOptions.PowOptionsT()

        return pow_op
76
+
77
+
78
+ # TODO Support `aten::pow.Scalar` (base=scalar, exponent=tensor)
79
+ # ExecuTorch currently does not support it as of now (2024/02/13).
80
+
81
+
82
@register_node_visitor
class PowTensorScalarVisitor(BasePowVisitor):
    """Serialize ``aten.pow.Tensor_Scalar`` nodes as a Circle ``POW`` operator."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.pow.Tensor_Scalar]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        Build the ``POW`` operator for ``node``.

        Before the operator is built the operand dtypes are aligned: an
        ``int`` exponent becomes a ``float`` when the base is float32, and an
        int32/int64 base is cast to float when the exponent is a ``float``.
        """
        parsed = PowTensorScalarArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        base = parsed.input
        power = parsed.exponent

        base_dtype = extract_torch_dtype(base)
        # Circle supports only same dtype between lhs and rhs.
        if base_dtype == torch.float32:
            if isinstance(power, int):
                power = float(power)
        elif base_dtype in (torch.int32, torch.int64):
            if isinstance(power, float):
                base = self.cast_to_float(base)  # type: ignore[assignment]

        return self.define_pow_node([base, power], [node])
109
+
110
+
111
@register_node_visitor
class PowTensorTensorVisitor(BasePowVisitor):
    """Serialize ``aten.pow.Tensor_Tensor`` nodes as a Circle ``POW`` operator."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.pow.Tensor_Tensor]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        Build the ``POW`` operator for ``node``.

        When one operand is float32 and the other is int32/int64, the integer
        operand is routed through ``cast_to_float`` first so that both
        operator inputs share a dtype.
        """
        parsed = PowTensorTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        base = parsed.input
        power = parsed.exponent  # type: ignore[arg-type]

        base_dtype = extract_torch_dtype(base)
        power_dtype = extract_torch_dtype(power)  # type: ignore[arg-type]

        # Circle supports only same dtype between lhs and rhs.
        # ``torch.int`` is only int32; int64 exponents need the cast as well,
        # matching the int32/int64 handling on the base side.
        integer_dtypes = (torch.int32, torch.int64)
        if base_dtype == torch.float32 and power_dtype in integer_dtypes:
            power = self.cast_to_float(power)  # type: ignore[arg-type, assignment]
        elif base_dtype in integer_dtypes and power_dtype == torch.float32:
            base = self.cast_to_float(base)  # type: ignore[assignment]

        return self.define_pow_node([base, power], [node])