tico 0.1.0.dev250411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. tico/__init__.py +31 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
  55. tico/experimental/quantization/public_interface.py +108 -0
  56. tico/experimental/quantization/quantizer.py +71 -0
  57. tico/interpreter/__init__.py +1 -0
  58. tico/interpreter/infer.py +116 -0
  59. tico/interpreter/interpreter.py +93 -0
  60. tico/passes/__init__.py +1 -0
  61. tico/passes/cast_aten_where_arg_type.py +185 -0
  62. tico/passes/cast_mixed_type_args.py +186 -0
  63. tico/passes/const_prop_pass.py +307 -0
  64. tico/passes/convert_conv1d_to_conv2d.py +151 -0
  65. tico/passes/convert_layout_op_to_reshape.py +84 -0
  66. tico/passes/convert_repeat_to_expand_copy.py +90 -0
  67. tico/passes/convert_to_relu6.py +180 -0
  68. tico/passes/decompose_addmm.py +127 -0
  69. tico/passes/decompose_batch_norm.py +198 -0
  70. tico/passes/decompose_fake_quantize.py +126 -0
  71. tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
  72. tico/passes/decompose_group_norm.py +258 -0
  73. tico/passes/decompose_grouped_conv2d.py +202 -0
  74. tico/passes/decompose_slice_scatter.py +167 -0
  75. tico/passes/extract_dtype_kwargs.py +121 -0
  76. tico/passes/fill_meta_val.py +57 -0
  77. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  78. tico/passes/legalize_causal_mask_value.py +113 -0
  79. tico/passes/legalize_predefined_layout_operators.py +383 -0
  80. tico/passes/lower_pow2_to_mul.py +75 -0
  81. tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
  82. tico/passes/lower_to_slice.py +112 -0
  83. tico/passes/merge_consecutive_cat.py +82 -0
  84. tico/passes/ops.py +75 -0
  85. tico/passes/remove_nop.py +85 -0
  86. tico/passes/remove_redundant_assert_nodes.py +50 -0
  87. tico/passes/remove_redundant_expand.py +70 -0
  88. tico/passes/remove_redundant_permute.py +102 -0
  89. tico/passes/remove_redundant_reshape.py +431 -0
  90. tico/passes/remove_redundant_slice.py +64 -0
  91. tico/passes/remove_redundant_to_copy.py +84 -0
  92. tico/passes/restore_linear.py +113 -0
  93. tico/passes/segment_index_select.py +143 -0
  94. tico/pt2_to_circle.py +101 -0
  95. tico/serialize/__init__.py +1 -0
  96. tico/serialize/circle_graph.py +264 -0
  97. tico/serialize/circle_mapping.py +177 -0
  98. tico/serialize/circle_serializer.py +232 -0
  99. tico/serialize/operators/__init__.py +28 -0
  100. tico/serialize/operators/hashable_opcode.py +43 -0
  101. tico/serialize/operators/node_visitor.py +80 -0
  102. tico/serialize/operators/op_add.py +69 -0
  103. tico/serialize/operators/op_alias_copy.py +64 -0
  104. tico/serialize/operators/op_any.py +142 -0
  105. tico/serialize/operators/op_arange_start_step.py +61 -0
  106. tico/serialize/operators/op_argmax.py +62 -0
  107. tico/serialize/operators/op_avg_pool2d.py +112 -0
  108. tico/serialize/operators/op_bmm.py +62 -0
  109. tico/serialize/operators/op_cat.py +66 -0
  110. tico/serialize/operators/op_clamp.py +123 -0
  111. tico/serialize/operators/op_clone.py +71 -0
  112. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  113. tico/serialize/operators/op_conv2d.py +181 -0
  114. tico/serialize/operators/op_copy.py +162 -0
  115. tico/serialize/operators/op_cos.py +59 -0
  116. tico/serialize/operators/op_cumsum.py +92 -0
  117. tico/serialize/operators/op_depthwise_conv2d.py +198 -0
  118. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  119. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  120. tico/serialize/operators/op_div.py +62 -0
  121. tico/serialize/operators/op_embedding.py +60 -0
  122. tico/serialize/operators/op_eq.py +64 -0
  123. tico/serialize/operators/op_exp.py +60 -0
  124. tico/serialize/operators/op_expand.py +91 -0
  125. tico/serialize/operators/op_full.py +48 -0
  126. tico/serialize/operators/op_full_like.py +55 -0
  127. tico/serialize/operators/op_ge.py +54 -0
  128. tico/serialize/operators/op_gelu.py +59 -0
  129. tico/serialize/operators/op_gt.py +54 -0
  130. tico/serialize/operators/op_index.py +82 -0
  131. tico/serialize/operators/op_index_select.py +64 -0
  132. tico/serialize/operators/op_instance_norm.py +91 -0
  133. tico/serialize/operators/op_linear.py +70 -0
  134. tico/serialize/operators/op_log.py +53 -0
  135. tico/serialize/operators/op_log1p.py +83 -0
  136. tico/serialize/operators/op_logical_and.py +63 -0
  137. tico/serialize/operators/op_logical_not.py +62 -0
  138. tico/serialize/operators/op_lt.py +61 -0
  139. tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
  140. tico/serialize/operators/op_maximum.py +53 -0
  141. tico/serialize/operators/op_mean.py +66 -0
  142. tico/serialize/operators/op_minimum.py +53 -0
  143. tico/serialize/operators/op_mm.py +174 -0
  144. tico/serialize/operators/op_mul.py +99 -0
  145. tico/serialize/operators/op_ne.py +54 -0
  146. tico/serialize/operators/op_neg.py +59 -0
  147. tico/serialize/operators/op_permute.py +65 -0
  148. tico/serialize/operators/op_pow.py +138 -0
  149. tico/serialize/operators/op_prelu.py +54 -0
  150. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  151. tico/serialize/operators/op_reciprocal.py +64 -0
  152. tico/serialize/operators/op_relu.py +53 -0
  153. tico/serialize/operators/op_relu6.py +52 -0
  154. tico/serialize/operators/op_repeat.py +99 -0
  155. tico/serialize/operators/op_reshape.py +73 -0
  156. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  157. tico/serialize/operators/op_rsqrt.py +53 -0
  158. tico/serialize/operators/op_scalar_tensor.py +51 -0
  159. tico/serialize/operators/op_select_copy.py +65 -0
  160. tico/serialize/operators/op_sigmoid.py +56 -0
  161. tico/serialize/operators/op_sin.py +53 -0
  162. tico/serialize/operators/op_slice.py +155 -0
  163. tico/serialize/operators/op_softmax.py +100 -0
  164. tico/serialize/operators/op_split_with_sizes.py +96 -0
  165. tico/serialize/operators/op_sqrt.py +55 -0
  166. tico/serialize/operators/op_squeeze.py +73 -0
  167. tico/serialize/operators/op_sub.py +71 -0
  168. tico/serialize/operators/op_sum.py +63 -0
  169. tico/serialize/operators/op_tanh.py +54 -0
  170. tico/serialize/operators/op_to_copy.py +105 -0
  171. tico/serialize/operators/op_unsqueeze.py +66 -0
  172. tico/serialize/operators/op_view.py +74 -0
  173. tico/serialize/operators/op_where.py +82 -0
  174. tico/serialize/operators/utils.py +51 -0
  175. tico/serialize/pack.py +35 -0
  176. tico/serialize/quant_param.py +42 -0
  177. tico/utils/__init__.py +1 -0
  178. tico/utils/convert.py +292 -0
  179. tico/utils/define.py +35 -0
  180. tico/utils/diff_graph.py +181 -0
  181. tico/utils/errors.py +35 -0
  182. tico/utils/graph.py +200 -0
  183. tico/utils/logging.py +45 -0
  184. tico/utils/model.py +37 -0
  185. tico/utils/padding.py +47 -0
  186. tico/utils/passes.py +76 -0
  187. tico/utils/register_custom_op.py +562 -0
  188. tico/utils/trace_decorators.py +101 -0
  189. tico/utils/utils.py +314 -0
  190. tico/utils/validate_args_kwargs.py +1114 -0
  191. tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
  192. tico-0.1.0.dev250411.dist-info/METADATA +17 -0
  193. tico-0.1.0.dev250411.dist-info/RECORD +196 -0
  194. tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
  195. tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
  196. tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,54 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import PReLUArgs
28
+
29
+
30
@register_node_visitor
class PReLUVisitor(NodeVisitor):
    """Serialize ``aten.prelu`` as a Circle PRELU operator.

    The emitted operator consumes the op's ``input`` and ``weight`` arguments
    (in that order) and writes to the tensor associated with the node itself.
    No builtin options are attached.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.prelu.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Create the PRELU operator for ``node`` and return it."""
        prelu_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.PRELU, self._op_codes
        )

        prelu_args = PReLUArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        return create_builtin_operator(
            self.graph,
            prelu_index,
            [prelu_args.input, prelu_args.weight],
            [node],
        )
@@ -0,0 +1,79 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import QuantizePerTensorArgs
28
+
29
+
30
@register_node_visitor
class QuantizePerTensorDefaultVisitor(NodeVisitor):
    """Serialize ``quantized_decomposed.quantize_per_tensor`` as Circle QUANTIZE.

    The output tensor is expected to already carry quantization parameters.
    This visitor asserts that those parameters agree with the op's
    ``scale`` / ``zero_p`` / ``quant_min`` / ``quant_max`` arguments, then
    emits a QUANTIZE operator with default ``QuantizeOptions``.
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.quantized_decomposed.quantize_per_tensor.default
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Validate qparams against the op arguments and build QUANTIZE."""
        qargs = QuantizePerTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        out_tensor: circle.Tensor.TensorT = self.graph.get_tensor(node)

        # Tensor should have qparam when it's exported, and that qparam must
        # match with the arguments of this Op.
        qparam = out_tensor.quantization
        assert qparam is not None
        assert qparam.scale[0] == qargs.scale
        assert qparam.zeroPoint[0] == qargs.zero_p

        out_type = out_tensor.type
        if out_type == circle.TensorType.TensorType.UINT8:
            assert qargs.quant_min == 0 and qargs.quant_max == 255
        elif out_type == circle.TensorType.TensorType.INT16:
            # Some frameworks use -32767 as quant_min of int16
            assert qargs.quant_min in (-32768, -32767) and qargs.quant_max == 32767

        quantize_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.QUANTIZE, self._op_codes
        )
        operator = create_builtin_operator(
            self.graph, quantize_index, [qargs.tensor], [node]
        )

        # Op-specific option
        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.QuantizeOptions
        )
        operator.builtinOptions = circle.QuantizeOptions.QuantizeOptionsT()

        return operator
@@ -0,0 +1,64 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.circle_mapping import extract_torch_dtype
25
+ from tico.serialize.operators.hashable_opcode import OpCode
26
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
27
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
28
+ from tico.utils.validate_args_kwargs import ReciprocalArgs
29
+
30
+
31
@register_node_visitor
class ReciprocalVisitor(NodeVisitor):
    """Serialize ``aten.reciprocal`` as a Circle DIV computing ``1 / input``.

    A scalar constant ``1`` with the input's dtype is added to the graph and
    used as the dividend; the op's input is the divisor.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.reciprocal.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build ``DIV(1, input)`` for ``node`` and return it."""
        div_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.DIV, self._op_codes
        )

        recip_args = ReciprocalArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        divisor = recip_args.input

        # Constant dividend: scalar 1 in the same dtype as the divisor.
        one = self.graph.add_const_tensor(
            torch.tensor(1, dtype=extract_torch_dtype(divisor))
        )

        operator = create_builtin_operator(
            self.graph, div_index, [one, divisor], [node]
        )

        # Op-specific option
        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.DivOptions
        div_option = circle.DivOptions.DivOptionsT()
        div_option.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )
        operator.builtinOptions = div_option

        return operator
@@ -0,0 +1,53 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import ReluArgs
28
+
29
+
30
@register_node_visitor
class ReluVisitor(NodeVisitor):
    """Serialize ``aten.relu`` as a Circle RELU operator (no options)."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.relu.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Create the RELU operator mapping the op's input to ``node``."""
        relu_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.RELU, self._op_codes
        )

        relu_args = ReluArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        return create_builtin_operator(
            self.graph, relu_index, [relu_args.input], [node]
        )
@@ -0,0 +1,52 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import Relu6Args
28
+
29
+
30
@register_node_visitor
class Relu6Visitor(NodeVisitor):
    """Serialize ``aten.relu6`` as a Circle RELU6 operator (no options)."""

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.relu6.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
        """Create the RELU6 operator mapping the op's input to ``node``."""
        relu6_args = Relu6Args(*node.args, **node.kwargs)  # type: ignore[arg-type]

        relu6_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.RELU6, self._op_codes
        )

        return create_builtin_operator(
            self.graph, relu6_index, [relu6_args.input], [node]
        )
@@ -0,0 +1,99 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.circle_mapping import extract_circle_dtype, extract_shape
25
+ from tico.serialize.operators.hashable_opcode import OpCode
26
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
27
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
28
+ from tico.utils.errors import InvalidArgumentError, NotYetSupportedError
29
+ from tico.utils.validate_args_kwargs import RepeatArgs
30
+
31
+
32
@register_node_visitor
class RepeatVisitor(NodeVisitor):
    """Lower ``aten.repeat`` to a chain of Circle CONCATENATION operators.

    Each axis whose repeat factor ``r`` is greater than one becomes one
    CONCATENATION of ``r`` copies of the running tensor along that axis.
    Every CONCATENATION except the last writes to a freshly created
    intermediate tensor and is registered on the graph here. The last one
    writes to ``node`` and is returned to the caller.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.repeat.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Emit the CONCATENATION chain for ``node`` and return the final one.

        Raises:
            NotYetSupportedError: a repeat factor is 0, the repeats length
                differs from the input rank, or every factor is 1 (nothing
                to emit).
            InvalidArgumentError: a repeat factor is negative.
        """
        args = RepeatArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input = args.input
        repeats = args.repeats

        for r in repeats:
            if r == 0:
                # TODO: Support r == 0 case
                raise NotYetSupportedError("Only support positive repeat value")
            elif r < 0:
                raise InvalidArgumentError("Only support positive repeat value")

        tensor_shape = extract_shape(input)
        assert len(tensor_shape) <= len(repeats)
        if len(tensor_shape) != len(repeats):
            # TODO Support len(tensor_shape) < len(repeats)
            raise NotYetSupportedError(
                "Length of both input tensor and repeats vector should be same."
            )

        # Number of CONCATENATIONs still to emit (factors are all > 0 here).
        repeat_dim_cnt = len(repeats) - repeats.count(1)
        if repeat_dim_cnt == 0:
            # Without this guard no operator is built and the final `return`
            # would fail with UnboundLocalError.
            raise NotYetSupportedError(
                "repeat whose repeat values are all 1 is not supported yet."
            )

        tensor_dtype = extract_circle_dtype(input)
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.CONCATENATION, self._op_codes
        )

        concat_input: torch.fx.Node | circle.Tensor.TensorT = input
        concat_output: torch.fx.node.Node | circle.Tensor.TensorT = node
        # Shape of the running result. It must accumulate the factors already
        # applied to earlier axes; recomputing it from the original input shape
        # produced wrong intermediate shapes when 3+ axes were repeated.
        current_shape = list(tensor_shape)
        for idx, r in enumerate(repeats):
            if r == 1:
                continue
            # concat along idx dimension
            current_shape[idx] = current_shape[idx] * r
            is_last = repeat_dim_cnt == 1
            if is_last:
                concat_output = node
            else:
                # Except last created concat, a tensor should be created.
                # Pass a copy so later mutation of `current_shape` cannot alias
                # a shape stored by the graph.
                concat_output = self.graph.add_tensor_from_scratch(
                    prefix=f"{node.name}_concat_{idx}",
                    shape=list(current_shape),
                    dtype=tensor_dtype,
                )

            inputs = [concat_input] * r
            outputs: List[torch.fx.node.Node | circle.Tensor.TensorT] = [
                concat_output
            ]
            operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
            operator.builtinOptionsType = (
                circle.BuiltinOptions.BuiltinOptions.ConcatenationOptions
            )
            option = circle.ConcatenationOptions.ConcatenationOptionsT()
            option.axis = idx
            operator.builtinOptions = option

            if not is_last:
                # Intermediate operators are registered here; the last one is
                # returned to the caller instead.
                self.graph.add_operator(operator)
            concat_input = concat_output
            repeat_dim_cnt -= 1

        return operator
@@ -0,0 +1,73 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph, is_const
24
+ from tico.serialize.circle_mapping import circle_legalize_dtype_to
25
+ from tico.serialize.operators.hashable_opcode import OpCode
26
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
27
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
28
+ from tico.utils.errors import NotYetSupportedError
29
+ from tico.utils.validate_args_kwargs import ReshapeArgs
30
+
31
+
32
@register_node_visitor
class ReshapeVisitor(NodeVisitor):
    """Serialize ``aten.reshape`` as a Circle RESHAPE operator.

    The target shape must be a constant (it is asserted via ``is_const``). It
    is legalized to int32 and passed both as the operator's second input and
    as ``ReshapeOptions.newShape``.
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.reshape.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the RESHAPE operator for ``node``.

        Raises:
            NotYetSupportedError: when the requested size is a bare ``int``.
        """
        reshape_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.RESHAPE,
            self._op_codes,
        )
        reshape_args = ReshapeArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        new_shape = reshape_args.size

        if isinstance(new_shape, int):
            raise NotYetSupportedError("scalar size conversion is not supported yet.")

        assert is_const(new_shape), type(new_shape)

        shape_i32 = circle_legalize_dtype_to(new_shape, dtype=torch.int32)
        operator = create_builtin_operator(
            self.graph, reshape_index, [reshape_args.input, shape_i32], [node]
        )

        # Op-specific option
        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ReshapeOptions
        )
        reshape_option = circle.ReshapeOptions.ReshapeOptionsT()
        reshape_option.newShape = shape_i32
        operator.builtinOptions = reshape_option

        return operator
@@ -0,0 +1,70 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_mapping import circle_legalize_dtype_to
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import ResizeNearestNeighborArgs
28
+
29
+
30
@register_node_visitor
class ResizeNearestNeighborVisitor(NodeVisitor):
    """Serialize ``circle_custom.resize_nearest_neighbor`` as Circle RESIZE_NEAREST_NEIGHBOR.

    The op's ``size`` argument is legalized to int32 and fed as the second
    operator input. ``alignCorners`` and ``halfPixelCenters`` are both fixed
    to ``False``.
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.circle_custom.resize_nearest_neighbor
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the RESIZE_NEAREST_NEIGHBOR operator for ``node``."""
        # TODO Support generic algorithm
        resize_args = ResizeNearestNeighborArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        size_i32 = circle_legalize_dtype_to(resize_args.size, dtype=torch.int32)

        resize_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.RESIZE_NEAREST_NEIGHBOR,
            self._op_codes,
        )
        operator = create_builtin_operator(
            self.graph, resize_index, [resize_args.input, size_i32], [node]
        )

        # Op-specific option
        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ResizeNearestNeighborOptions
        )
        resize_option = (
            circle.ResizeNearestNeighborOptions.ResizeNearestNeighborOptionsT()
        )
        # TODO Consider these options
        # If True, the centers of the 4 corner pixels of the input and output
        # tensors are aligned, preserving the values at the corner pixels.
        resize_option.alignCorners = False
        # If True, the pixel centers are assumed to be at (0.5, 0.5). If this
        # parameter is True, then align_corners parameter must be False.
        resize_option.halfPixelCenters = False
        operator.builtinOptions = resize_option

        return operator
@@ -0,0 +1,53 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import RsqrtArgs
28
+
29
+
30
@register_node_visitor
class RsqrtVisitor(NodeVisitor):
    """Lower ``aten.rsqrt.default`` to Circle's RSQRT builtin operator."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.rsqrt.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Emit an RSQRT operator that reads the node's ``input`` argument
        and writes ``node``.

        Returns:
            The created ``circle.Operator.OperatorT``.
        """
        # Resolve the opcode index before parsing arguments (original order).
        rsqrt_code = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.RSQRT, self._op_codes
        )
        src = RsqrtArgs(*node.args, **node.kwargs).input  # type: ignore[arg-type]
        return create_builtin_operator(self.graph, rsqrt_code, [src], [node])
@@ -0,0 +1,51 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.circle_mapping import extract_torch_dtype
25
+ from tico.serialize.operators.hashable_opcode import OpCode
26
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
27
+ from tico.utils.validate_args_kwargs import ScalarTensorArgs
28
+
29
+
30
@register_node_visitor
class ScalarTensorVisitor(NodeVisitor):
    """Handle ``aten.scalar_tensor.default`` by recording a constant buffer.

    No Circle operator is created for this node. Instead, the scalar is
    materialized as a tensor and handed to ``graph.update_tensor_buffer``
    under ``node.name``.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.scalar_tensor.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Store the node's scalar as a tensor buffer named ``node.name``.

        Returns:
            ``None``, because no operator is emitted. The declared return type
            is kept for interface compatibility, hence the ``type: ignore``.
        """
        # NOTE(review): an earlier revision left a commented-out
        # ``assert False`` here, suggesting this visitor was not expected to be
        # reached. Confirm whether an upstream pass normally removes it.
        scalar_value = ScalarTensorArgs(*node.args, **node.kwargs).scalar  # type: ignore[arg-type]

        # Use the node's own dtype: ``scalar_tensor`` would otherwise yield a
        # float tensor even if the input scalar is an int.
        target_dtype = extract_torch_dtype(node)
        self.graph.update_tensor_buffer(
            torch.scalar_tensor(scalar_value, dtype=target_dtype), node.name
        )

        return None  # type: ignore[return-value]