tico 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,141 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.circle_mapping import extract_torch_dtype, to_circle_dtype
25
+ from tico.serialize.operators.hashable_opcode import OpCode
26
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
27
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
28
+ from tico.utils.validate_args_kwargs import PowTensorScalarArgs, PowTensorTensorArgs
29
+
30
+
31
class BasePowVisitor(NodeVisitor):
    """Shared helpers for the Circle POW serializers.

    Subclasses handle one ``aten.pow`` overload each. This base class builds
    the POW operator itself and can insert a CAST-to-FLOAT32 in front of an
    operand whose dtype must be aligned with the other operand.
    """

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def cast_to_float(self, node: torch.fx.Node) -> circle.Tensor.TensorT:
        """Emit a CAST operator that converts ``node`` to FLOAT32.

        A new FLOAT32 tensor is created with the same shape as ``node``'s
        tensor and a name prefixed by ``<node.name>_cast``. The CAST operator
        is added to the graph right away.

        Returns:
            The newly created FLOAT32 tensor, to be used as an operand in
            place of ``node``.
        """
        assert isinstance(node, torch.fx.Node), type(node)
        float32 = circle.TensorType.TensorType.FLOAT32

        source_tensor: circle.Tensor.TensorT = self.graph.get_tensor(node)
        source_shape: List[int] = source_tensor.shape
        cast_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.CAST, self._op_codes
        )

        casted = self.graph.add_tensor_from_scratch(
            prefix=f"{node.name}_cast",
            dtype=float32,
            shape=source_shape,
            source_node=node,
        )
        cast_op = create_builtin_operator(self.graph, cast_index, [node], [casted])

        cast_options = circle.CastOptions.CastOptionsT()
        cast_options.inDataType = to_circle_dtype(extract_torch_dtype(node))
        cast_options.outDataType = float32
        cast_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.CastOptions
        cast_op.builtinOptions = cast_options

        self.graph.add_operator(cast_op)
        return casted

    def define_pow_node(self, inputs: List, outputs: List) -> circle.Operator.OperatorT:
        """Create a POW operator mapping ``inputs`` to ``outputs``.

        The operator is not passed to ``add_operator`` here. It is returned
        for the caller to handle.
        """
        pow_op = create_builtin_operator(
            self.graph,
            get_op_index(circle.BuiltinOperator.BuiltinOperator.POW, self._op_codes),
            inputs,
            outputs,
        )
        # PowOptions has no fields; it is attached only so the options type
        # matches the operator.
        pow_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.PowOptions
        pow_op.builtinOptions = circle.PowOptions.PowOptionsT()
        return pow_op
79
+
80
+
81
+ # TODO Support `aten::pow.Scalar` (base=scalar, exponent=tensor).
82
+ # ExecuTorch does not support it as of 2024/02/13.
83
+
84
+
85
@register_node_visitor
class PowTensorScalarVisitor(BasePowVisitor):
    """Serialize ``aten.pow.Tensor_Scalar`` (tensor base, scalar exponent) as POW."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.pow.Tensor_Scalar]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        parsed = PowTensorScalarArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        base = parsed.input
        power = parsed.exponent

        base_dtype = extract_torch_dtype(base)
        # Circle POW requires both operands to share a dtype.
        if base_dtype == torch.float32 and isinstance(power, int):
            # float32 base with an integer scalar: promote the scalar.
            power = float(power)
        elif base_dtype in (torch.int32, torch.int64) and isinstance(power, float):
            # Integer base with a float scalar: cast the tensor to FLOAT32.
            base = self.cast_to_float(base)  # type: ignore[assignment]

        return self.define_pow_node([base, power], [node])
112
+
113
+
114
@register_node_visitor
class PowTensorTensorVisitor(BasePowVisitor):
    """Serialize ``aten.pow.Tensor_Tensor`` (tensor base, tensor exponent) as POW.

    Circle POW requires both operands to have the same dtype. When one side
    is FLOAT32 and the other an integer tensor (int32 or int64), the integer
    side is cast to FLOAT32 first.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.pow.Tensor_Tensor]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        args = PowTensorTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input = args.input
        exponent = args.exponent  # type: ignore[arg-type]

        lhs_dtype = extract_torch_dtype(input)
        rhs_dtype = extract_torch_dtype(exponent)  # type: ignore[arg-type]
        # Both int32 and int64 must be handled on either side. `torch.int` is
        # an alias of int32 only, so int64 exponents (e.g. `torch.tensor(2)`)
        # would otherwise reach POW with a dtype mismatch.
        integer_dtypes = (torch.int32, torch.int64)
        if lhs_dtype == torch.float32 and rhs_dtype in integer_dtypes:
            exponent = self.cast_to_float(exponent)  # type: ignore[arg-type, assignment]
        elif lhs_dtype in integer_dtypes and rhs_dtype == torch.float32:
            input = self.cast_to_float(input)  # type: ignore[assignment]

        return self.define_pow_node([input, exponent], [node])
@@ -0,0 +1,54 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import PReLUArgs
28
+
29
+
30
@register_node_visitor
class PReLUVisitor(NodeVisitor):
    """Serialize ``aten.prelu`` as a Circle PRELU operator (input, weight)."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.prelu.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
        prelu_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.PRELU, self._op_codes
        )
        parsed = PReLUArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        return create_builtin_operator(
            self.graph, prelu_index, [parsed.input, parsed.weight], [node]
        )
@@ -0,0 +1,79 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import QuantizePerTensorArgs
28
+
29
+
30
@register_node_visitor
class QuantizePerTensorDefaultVisitor(NodeVisitor):
    """Serialize ``quantized_decomposed.quantize_per_tensor`` as Circle QUANTIZE.

    The output tensor must already carry quantization parameters. This
    visitor only asserts that they agree with the op's scale / zero point
    and, for UINT8 and INT16 outputs, with the quant_min / quant_max range.
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.quantized_decomposed.quantize_per_tensor.default
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        parsed = QuantizePerTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        q_min, q_max = parsed.quant_min, parsed.quant_max

        out_tensor: circle.Tensor.TensorT = self.graph.get_tensor(node)
        qparam = out_tensor.quantization
        assert qparam is not None

        # The exported tensor's qparam must match this op's arguments.
        assert qparam.scale[0] == parsed.scale
        assert qparam.zeroPoint[0] == parsed.zero_p

        tensor_types = circle.TensorType.TensorType
        if out_tensor.type == tensor_types.UINT8:
            assert q_min == 0 and q_max == 255
        elif out_tensor.type == tensor_types.INT16:
            # Some frameworks use -32767 as quant_min of int16
            assert q_min in (-32768, -32767) and q_max == 32767

        quantize_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.QUANTIZE, self._op_codes
        )
        quantize_op = create_builtin_operator(
            self.graph, quantize_index, [parsed.tensor], [node]
        )
        quantize_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.QuantizeOptions
        )
        quantize_op.builtinOptions = circle.QuantizeOptions.QuantizeOptionsT()
        return quantize_op
@@ -0,0 +1,64 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.circle_mapping import extract_torch_dtype
25
+ from tico.serialize.operators.hashable_opcode import OpCode
26
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
27
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
28
+ from tico.utils.validate_args_kwargs import ReciprocalArgs
29
+
30
+
31
@register_node_visitor
class ReciprocalVisitor(NodeVisitor):
    """Serialize ``aten.reciprocal`` as ``1 / input`` using a Circle DIV.

    The numerator is a constant ``1`` created in the input's torch dtype.
    NOTE(review): torch.reciprocal on integer tensors yields a floating
    result, while this DIV runs in the input dtype. Confirm that integer
    inputs are cast before reaching this visitor.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.reciprocal.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        div_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.DIV, self._op_codes
        )

        divisor = ReciprocalArgs(*node.args, **node.kwargs).input  # type: ignore[arg-type]

        one = torch.tensor(1, dtype=extract_torch_dtype(divisor))
        numerator = self.graph.add_const_tensor(one, source_node=node)

        div_op = create_builtin_operator(
            self.graph, div_index, [numerator, divisor], [node]
        )

        div_options = circle.DivOptions.DivOptionsT()
        div_options.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )
        div_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.DivOptions
        div_op.builtinOptions = div_options
        return div_op
@@ -0,0 +1,53 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import ReluArgs
28
+
29
+
30
@register_node_visitor
class ReluVisitor(NodeVisitor):
    """Serialize ``aten.relu`` as a single Circle RELU operator."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.relu.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
        relu_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.RELU, self._op_codes
        )
        relu_input = ReluArgs(*node.args, **node.kwargs).input  # type: ignore[arg-type]
        return create_builtin_operator(self.graph, relu_index, [relu_input], [node])
@@ -0,0 +1,52 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import Relu6Args
28
+
29
+
30
@register_node_visitor
class Relu6Visitor(NodeVisitor):
    """Serialize ``aten.relu6`` as a single Circle RELU6 operator."""

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.relu6.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
        relu6_input = Relu6Args(*node.args, **node.kwargs).input  # type: ignore[arg-type]
        relu6_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.RELU6, self._op_codes
        )
        return create_builtin_operator(
            self.graph, relu6_index, [relu6_input], [node]
        )
@@ -0,0 +1,100 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.circle_mapping import extract_circle_dtype, extract_shape
25
+ from tico.serialize.operators.hashable_opcode import OpCode
26
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
27
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
28
+ from tico.utils.errors import InvalidArgumentError, NotYetSupportedError
29
+ from tico.utils.validate_args_kwargs import RepeatArgs
30
+
31
+
32
@register_node_visitor
class RepeatVisitor(NodeVisitor):
    """Serialize ``aten.repeat`` as a chain of Circle CONCATENATION operators.

    Each dimension with a repeat count ``r > 1`` becomes one CONCATENATION of
    ``r`` copies of the running result along that axis.

    - Every concat except the last writes to a newly created intermediate
      tensor and is added to the graph here.
    - The last concat writes to ``node`` and is returned.
    - If every repeat count is 1, a single-input concatenation (identity) is
      returned so that ``node`` still has a producing operator.

    Raises:
        NotYetSupportedError: a repeat count is 0, or ``repeats`` is longer
            than the input rank.
        InvalidArgumentError: a repeat count is negative.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.repeat.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def _make_concat(
        self, op_index: int, inputs: List, outputs: List, axis: int
    ) -> circle.Operator.OperatorT:
        """Create a CONCATENATION operator along ``axis`` (not added to the graph)."""
        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ConcatenationOptions
        )
        option = circle.ConcatenationOptions.ConcatenationOptionsT()
        option.axis = axis
        operator.builtinOptions = option
        return operator

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        args = RepeatArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input = args.input
        repeats = args.repeats

        for r in repeats:
            if r == 0:
                # TODO: Support r == 0 case
                raise NotYetSupportedError("Only support positive repeat value")
            elif r < 0:
                raise InvalidArgumentError("Only support positive repeat value")

        tensor_shape = extract_shape(input)
        assert len(tensor_shape) <= len(repeats)
        if len(tensor_shape) != len(repeats):
            # TODO Support len(tensor_shape) < len(repeats)
            raise NotYetSupportedError(
                "Length of both input tensor and repeats vector should be same."
            )
        tensor_dtype = extract_circle_dtype(input)
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.CONCATENATION, self._op_codes
        )

        # Only dimensions with r > 1 need a concatenation.
        repeated_dims = [(idx, r) for idx, r in enumerate(repeats) if r > 1]
        if not repeated_dims:
            # All repeat counts are 1, so the output equals the input. Emit
            # an identity concat so that `node` has a producer; previously
            # this path hit an unbound `operator`.
            return self._make_concat(op_index, [input], [node], 0)

        *leading_dims, (last_idx, last_r) = repeated_dims
        concat_input: torch.fx.Node | circle.Tensor.TensorT = input
        # Running shape of `concat_input`. It must accumulate every repeat
        # applied so far, so that intermediate tensors get the correct shape.
        current_shape = list(tensor_shape)
        for idx, r in leading_dims:
            current_shape[idx] = current_shape[idx] * r
            concat_output = self.graph.add_tensor_from_scratch(
                prefix=f"{node.name}_concat_{idx}",
                shape=list(current_shape),
                dtype=tensor_dtype,
                source_node=node,
            )
            operator = self._make_concat(
                op_index, [concat_input] * r, [concat_output], idx
            )
            self.graph.add_operator(operator)
            concat_input = concat_output

        # The final concat writes directly into the node's output tensor.
        return self._make_concat(op_index, [concat_input] * last_r, [node], last_idx)
@@ -0,0 +1,73 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph, is_const
24
+ from tico.serialize.circle_mapping import circle_legalize_dtype_to
25
+ from tico.serialize.operators.hashable_opcode import OpCode
26
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
27
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
28
+ from tico.utils.errors import NotYetSupportedError
29
+ from tico.utils.validate_args_kwargs import ReshapeArgs
30
+
31
+
32
@register_node_visitor
class ReshapeVisitor(NodeVisitor):
    """Lower ``aten.reshape.default`` to a Circle RESHAPE operator.

    The requested shape must be a constant. After legalization to int32 it is
    used both as the operator's second input and as ``ReshapeOptions.newShape``.
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.reshape.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Create the RESHAPE operator for ``node``.

        Args:
            node: the FX node holding the reshape call (input, shape).

        Returns:
            The Circle operator whose inputs are ``[input, shape_i32]`` and whose
            single output is ``node``.

        Raises:
            NotYetSupportedError: if the shape argument is a bare ``int``.
            AssertionError: if the shape is not constant (plain ``assert``, so
                skipped under ``python -O``).
        """
        # The opcode is registered before any validation runs (original ordering).
        reshape_code = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.RESHAPE,
            self._op_codes,
        )
        parsed = ReshapeArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        target_shape = parsed.shape

        # A scalar shape argument has no lowering yet.
        if isinstance(target_shape, int):
            raise NotYetSupportedError("scalar size conversion is not supported yet.")
        assert is_const(target_shape), type(target_shape)

        # Legalize the constant shape to int32 before it is serialized.
        shape_i32 = circle_legalize_dtype_to(target_shape, dtype=torch.int32)
        operator = create_builtin_operator(
            self.graph, reshape_code, [parsed.input, shape_i32], [node]
        )

        # The same shape value is also recorded in the op-specific options.
        reshape_opts = circle.ReshapeOptions.ReshapeOptionsT()
        reshape_opts.newShape = shape_i32
        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ReshapeOptions
        )
        operator.builtinOptions = reshape_opts
        return operator
@@ -0,0 +1,70 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_mapping import circle_legalize_dtype_to
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import ResizeNearestNeighborArgs
28
+
29
+
30
@register_node_visitor
class ResizeNearestNeighborVisitor(NodeVisitor):
    """Lower ``circle_custom.resize_nearest_neighbor`` to Circle RESIZE_NEAREST_NEIGHBOR.

    The size argument is legalized to int32 and passed as the operator's second
    input. ``alignCorners`` and ``halfPixelCenters`` are both fixed to False.
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.circle_custom.resize_nearest_neighbor
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Create the RESIZE_NEAREST_NEIGHBOR operator for ``node``.

        Args:
            node: the FX node holding the resize call (input, size).

        Returns:
            The Circle operator whose inputs are ``[input, size_i32]`` and whose
            single output is ``node``.
        """
        # TODO Support generic algorithm
        parsed = ResizeNearestNeighborArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        size_i32 = circle_legalize_dtype_to(parsed.size, dtype=torch.int32)

        resize_code = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.RESIZE_NEAREST_NEIGHBOR,
            self._op_codes,
        )
        operator = create_builtin_operator(
            self.graph, resize_code, [parsed.input, size_i32], [node]
        )

        resize_opts = circle.ResizeNearestNeighborOptions.ResizeNearestNeighborOptionsT()
        # TODO Consider these options; both are currently hard-wired to False.
        # alignCorners: when True, the centers of the four corner pixels of the
        # input and output are aligned, preserving the corner pixel values.
        resize_opts.alignCorners = False
        # halfPixelCenters: when True, pixel centers are assumed to lie at
        # (0.5, 0.5); this requires alignCorners to be False.
        resize_opts.halfPixelCenters = False
        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ResizeNearestNeighborOptions
        )
        operator.builtinOptions = resize_opts
        return operator