tico 0.1.0.dev250411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196) hide show
  1. tico/__init__.py +31 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
  55. tico/experimental/quantization/public_interface.py +108 -0
  56. tico/experimental/quantization/quantizer.py +71 -0
  57. tico/interpreter/__init__.py +1 -0
  58. tico/interpreter/infer.py +116 -0
  59. tico/interpreter/interpreter.py +93 -0
  60. tico/passes/__init__.py +1 -0
  61. tico/passes/cast_aten_where_arg_type.py +185 -0
  62. tico/passes/cast_mixed_type_args.py +186 -0
  63. tico/passes/const_prop_pass.py +307 -0
  64. tico/passes/convert_conv1d_to_conv2d.py +151 -0
  65. tico/passes/convert_layout_op_to_reshape.py +84 -0
  66. tico/passes/convert_repeat_to_expand_copy.py +90 -0
  67. tico/passes/convert_to_relu6.py +180 -0
  68. tico/passes/decompose_addmm.py +127 -0
  69. tico/passes/decompose_batch_norm.py +198 -0
  70. tico/passes/decompose_fake_quantize.py +126 -0
  71. tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
  72. tico/passes/decompose_group_norm.py +258 -0
  73. tico/passes/decompose_grouped_conv2d.py +202 -0
  74. tico/passes/decompose_slice_scatter.py +167 -0
  75. tico/passes/extract_dtype_kwargs.py +121 -0
  76. tico/passes/fill_meta_val.py +57 -0
  77. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  78. tico/passes/legalize_causal_mask_value.py +113 -0
  79. tico/passes/legalize_predefined_layout_operators.py +383 -0
  80. tico/passes/lower_pow2_to_mul.py +75 -0
  81. tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
  82. tico/passes/lower_to_slice.py +112 -0
  83. tico/passes/merge_consecutive_cat.py +82 -0
  84. tico/passes/ops.py +75 -0
  85. tico/passes/remove_nop.py +85 -0
  86. tico/passes/remove_redundant_assert_nodes.py +50 -0
  87. tico/passes/remove_redundant_expand.py +70 -0
  88. tico/passes/remove_redundant_permute.py +102 -0
  89. tico/passes/remove_redundant_reshape.py +431 -0
  90. tico/passes/remove_redundant_slice.py +64 -0
  91. tico/passes/remove_redundant_to_copy.py +84 -0
  92. tico/passes/restore_linear.py +113 -0
  93. tico/passes/segment_index_select.py +143 -0
  94. tico/pt2_to_circle.py +101 -0
  95. tico/serialize/__init__.py +1 -0
  96. tico/serialize/circle_graph.py +264 -0
  97. tico/serialize/circle_mapping.py +177 -0
  98. tico/serialize/circle_serializer.py +232 -0
  99. tico/serialize/operators/__init__.py +28 -0
  100. tico/serialize/operators/hashable_opcode.py +43 -0
  101. tico/serialize/operators/node_visitor.py +80 -0
  102. tico/serialize/operators/op_add.py +69 -0
  103. tico/serialize/operators/op_alias_copy.py +64 -0
  104. tico/serialize/operators/op_any.py +142 -0
  105. tico/serialize/operators/op_arange_start_step.py +61 -0
  106. tico/serialize/operators/op_argmax.py +62 -0
  107. tico/serialize/operators/op_avg_pool2d.py +112 -0
  108. tico/serialize/operators/op_bmm.py +62 -0
  109. tico/serialize/operators/op_cat.py +66 -0
  110. tico/serialize/operators/op_clamp.py +123 -0
  111. tico/serialize/operators/op_clone.py +71 -0
  112. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  113. tico/serialize/operators/op_conv2d.py +181 -0
  114. tico/serialize/operators/op_copy.py +162 -0
  115. tico/serialize/operators/op_cos.py +59 -0
  116. tico/serialize/operators/op_cumsum.py +92 -0
  117. tico/serialize/operators/op_depthwise_conv2d.py +198 -0
  118. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  119. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  120. tico/serialize/operators/op_div.py +62 -0
  121. tico/serialize/operators/op_embedding.py +60 -0
  122. tico/serialize/operators/op_eq.py +64 -0
  123. tico/serialize/operators/op_exp.py +60 -0
  124. tico/serialize/operators/op_expand.py +91 -0
  125. tico/serialize/operators/op_full.py +48 -0
  126. tico/serialize/operators/op_full_like.py +55 -0
  127. tico/serialize/operators/op_ge.py +54 -0
  128. tico/serialize/operators/op_gelu.py +59 -0
  129. tico/serialize/operators/op_gt.py +54 -0
  130. tico/serialize/operators/op_index.py +82 -0
  131. tico/serialize/operators/op_index_select.py +64 -0
  132. tico/serialize/operators/op_instance_norm.py +91 -0
  133. tico/serialize/operators/op_linear.py +70 -0
  134. tico/serialize/operators/op_log.py +53 -0
  135. tico/serialize/operators/op_log1p.py +83 -0
  136. tico/serialize/operators/op_logical_and.py +63 -0
  137. tico/serialize/operators/op_logical_not.py +62 -0
  138. tico/serialize/operators/op_lt.py +61 -0
  139. tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
  140. tico/serialize/operators/op_maximum.py +53 -0
  141. tico/serialize/operators/op_mean.py +66 -0
  142. tico/serialize/operators/op_minimum.py +53 -0
  143. tico/serialize/operators/op_mm.py +174 -0
  144. tico/serialize/operators/op_mul.py +99 -0
  145. tico/serialize/operators/op_ne.py +54 -0
  146. tico/serialize/operators/op_neg.py +59 -0
  147. tico/serialize/operators/op_permute.py +65 -0
  148. tico/serialize/operators/op_pow.py +138 -0
  149. tico/serialize/operators/op_prelu.py +54 -0
  150. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  151. tico/serialize/operators/op_reciprocal.py +64 -0
  152. tico/serialize/operators/op_relu.py +53 -0
  153. tico/serialize/operators/op_relu6.py +52 -0
  154. tico/serialize/operators/op_repeat.py +99 -0
  155. tico/serialize/operators/op_reshape.py +73 -0
  156. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  157. tico/serialize/operators/op_rsqrt.py +53 -0
  158. tico/serialize/operators/op_scalar_tensor.py +51 -0
  159. tico/serialize/operators/op_select_copy.py +65 -0
  160. tico/serialize/operators/op_sigmoid.py +56 -0
  161. tico/serialize/operators/op_sin.py +53 -0
  162. tico/serialize/operators/op_slice.py +155 -0
  163. tico/serialize/operators/op_softmax.py +100 -0
  164. tico/serialize/operators/op_split_with_sizes.py +96 -0
  165. tico/serialize/operators/op_sqrt.py +55 -0
  166. tico/serialize/operators/op_squeeze.py +73 -0
  167. tico/serialize/operators/op_sub.py +71 -0
  168. tico/serialize/operators/op_sum.py +63 -0
  169. tico/serialize/operators/op_tanh.py +54 -0
  170. tico/serialize/operators/op_to_copy.py +105 -0
  171. tico/serialize/operators/op_unsqueeze.py +66 -0
  172. tico/serialize/operators/op_view.py +74 -0
  173. tico/serialize/operators/op_where.py +82 -0
  174. tico/serialize/operators/utils.py +51 -0
  175. tico/serialize/pack.py +35 -0
  176. tico/serialize/quant_param.py +42 -0
  177. tico/utils/__init__.py +1 -0
  178. tico/utils/convert.py +292 -0
  179. tico/utils/define.py +35 -0
  180. tico/utils/diff_graph.py +181 -0
  181. tico/utils/errors.py +35 -0
  182. tico/utils/graph.py +200 -0
  183. tico/utils/logging.py +45 -0
  184. tico/utils/model.py +37 -0
  185. tico/utils/padding.py +47 -0
  186. tico/utils/passes.py +76 -0
  187. tico/utils/register_custom_op.py +562 -0
  188. tico/utils/trace_decorators.py +101 -0
  189. tico/utils/utils.py +314 -0
  190. tico/utils/validate_args_kwargs.py +1114 -0
  191. tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
  192. tico-0.1.0.dev250411.dist-info/METADATA +17 -0
  193. tico-0.1.0.dev250411.dist-info/RECORD +196 -0
  194. tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
  195. tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
  196. tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,43 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from circle_schema import circle
16
+
17
+
18
class OpCode(circle.OperatorCode.OperatorCodeT):
    """
    Wrapper class for operator code in circle schema

    This implements __eq__ and __hash__ for use with dict().

    Two operator codes are considered equal when they have the same
    ``version`` and ``builtinCode``; for ``CUSTOM`` operators the
    ``customCode`` has to match as well. ``__hash__`` is computed from exactly
    the fields ``__eq__`` compares, so objects that compare equal always hash
    to the same value (a requirement for reliable dict lookups).
    """

    def __init__(self):
        super().__init__()

    def __eq__(self, other):
        # Let Python fall back to its default handling for unrelated types
        # instead of failing with AttributeError on a missing field.
        if not isinstance(other, circle.OperatorCode.OperatorCodeT):
            return NotImplemented

        if self.version != other.version:
            return False

        if self.builtinCode != other.builtinCode:
            return False

        # customCode only distinguishes CUSTOM operators.
        if self.builtinCode == circle.BuiltinOperator.BuiltinOperator.CUSTOM:
            return self.customCode == other.customCode

        return True

    def __hash__(self):
        is_custom = self.builtinCode == circle.BuiltinOperator.BuiltinOperator.CUSTOM
        # Only hash what __eq__ compares: deprecatedBuiltinCode is never
        # compared, and customCode is compared for CUSTOM operators only.
        val = (
            self.version,
            self.builtinCode,
            self.customCode if is_custom else None,
        )
        return hash(val)
@@ -0,0 +1,80 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, Type, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from circle_schema import circle
21
+
22
+ from tico.serialize.circle_graph import CircleSubgraph
23
+ from tico.serialize.operators.hashable_opcode import OpCode
24
+
25
+
26
class NodeVisitor:
    """
    Base node visitor for lowering edge IR to circle.

    This class only stores the shared state; `define_node` raises
    NotImplementedError here and has to be provided by a subclass.
    """

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        self.graph = graph
        # Used for setting the opcode index in the circle model.
        # This mapping is updated during serialization.
        self._op_codes = op_codes

    def define_node(
        self,
        node: torch.fx.node.Node,
    ) -> circle.Operator.OperatorT:
        """Define the circle model operator for `node`."""
        raise NotImplementedError("NodeVisitor must be extended.")
43
+
44
+
45
# Container for all node visitors: maps a torch op overload to the visitor
# class that was registered for it.
_node_visitor_dict: Dict[torch._ops.OpOverload, Type[NodeVisitor]] = {}
47
+
48
+
49
def register_node_visitor(visitor):
    """
    Decorator for each visitor.

    Stores `visitor` in the visitor container under every entry of its
    `target` attribute and returns `visitor` itself.
    """
    _node_visitor_dict.update((target, visitor) for target in visitor.target)
    return visitor
54
+
55
+
56
def get_node_visitor(target: torch._ops.OpOverload) -> Type[NodeVisitor]:
    """
    Get a single node visitor (for unittest purpose)

    Raises LookupError when the container holds no visitor for `target`.
    """
    registered = _node_visitor_dict.get(target)
    if registered:
        return registered

    raise LookupError(f"NodeVisitor for {target} is not registered")
66
+
67
+
68
def get_node_visitors(
    op_codes: Dict[OpCode, int], graph: CircleSubgraph
) -> Dict[torch._ops.OpOverload, NodeVisitor]:
    """
    Get all node visitors.

    Every registered visitor class is instantiated with `op_codes` and
    `graph`; the instances are returned keyed by the target they were
    registered for.
    """
    return {
        target: visitor_cls(op_codes, graph)
        for target, visitor_cls in _node_visitor_dict.items()
    }
77
+
78
+
79
def get_support_targets():
    """Return the keys view of the visitor container, i.e. every registered target."""
    return _node_visitor_dict.keys()
@@ -0,0 +1,69 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import AddTensorArgs
28
+
29
+
30
@register_node_visitor
class AddVisitor(NodeVisitor):
    """
    Visitor that emits a circle ADD operator for aten.add.Tensor and
    aten.add.Scalar nodes.
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.add.Tensor,
        torch.ops.aten.add.Scalar,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the ADD operator whose inputs are the node's two operands."""
        add_args = AddTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        lhs, rhs = add_args.input, add_args.other

        opcode_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.ADD, self._op_codes
        )
        add_op = create_builtin_operator(self.graph, opcode_index, [lhs, rhs], [node])

        # Op-specific option: no fused activation, potScaleInt16 disabled.
        add_option = circle.AddOptions.AddOptionsT()
        add_option.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )
        add_option.potScaleInt16 = False

        add_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.AddOptions
        add_op.builtinOptions = add_option

        return add_op
@@ -0,0 +1,64 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.operators.hashable_opcode import OpCode
24
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
25
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
26
+ from tico.utils.validate_args_kwargs import AliasCopyArgs
27
+
28
+
29
@register_node_visitor
class AliasCopyVisitor(NodeVisitor):
    """
    Visitor that emits a circle TRANSPOSE operator for aten.alias and
    aten.alias_copy nodes.

    The permutation passed to TRANSPOSE is 0..rank-1 of the input, i.e. the
    axes keep their original order.
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.alias.default,
        torch.ops.aten.alias_copy.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the TRANSPOSE operator with an in-order permutation."""
        alias_args = AliasCopyArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        source = alias_args.input

        opcode_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.TRANSPOSE, self._op_codes
        )

        # One permutation entry per input dimension, in ascending order.
        rank = len(source.meta["val"].shape)
        in_order_perm = torch.IntTensor(list(range(rank)))

        transpose_op = create_builtin_operator(
            self.graph, opcode_index, [source, in_order_perm], [node]
        )

        # Op-specific option
        transpose_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.TransposeOptions
        )
        transpose_op.builtinOptions = circle.TransposeOptions.TransposeOptionsT()

        return transpose_op
@@ -0,0 +1,142 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from circle_schema import circle
21
+
22
+ from tico.serialize.circle_graph import CircleSubgraph
23
+ from tico.serialize.circle_mapping import (
24
+ circle_legalize_dtype_to,
25
+ extract_circle_dtype,
26
+ extract_shape,
27
+ extract_torch_dtype,
28
+ )
29
+ from tico.serialize.operators.hashable_opcode import OpCode
30
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
31
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
32
+ from tico.utils.validate_args_kwargs import AnyArgs
33
+
34
+
35
@register_node_visitor
class AnyVisitor(NodeVisitor):
    """
    Let's take NotEqual0 -> ReduceMax workaround for float, int

    [RESTRICTION]
        1. ReduceAny is not supported (luci-interpreter)

    [CASE: BOOL]
        (Bool tensors don't need 'Not Equal 0' at the first step.)

        bool[d0..dN] --- Reduce Max ---> bool[]

    [CASE: FLOAT, INT]
        int/float[d0..dN] --- Not Equal 0 ---> bool[d0,...dN]
                          --- Reduce Max  ---> bool[]

    * [d0..dN] means a tensor with any shape
    * [] means Scalar
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.any.default,
        torch.ops.aten.any.dim,
        torch.ops.aten.any.dims,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_max_node(
        self, inputs: List, outputs: List, keepdims: bool
    ) -> circle.Operator.OperatorT:
        """Create a REDUCE_MAX operator with `keepDims` set to `keepdims`."""
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.REDUCE_MAX, self._op_codes
        )

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ReducerOptions
        )
        option = circle.ReducerOptions.ReducerOptionsT()
        option.keepDims = keepdims

        operator.builtinOptions = option

        return operator

    def define_ne_node(self, inputs: List, outputs: List) -> circle.Operator.OperatorT:
        """Create a NOT_EQUAL operator over `inputs` writing to `outputs`."""
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.NOT_EQUAL, self._op_codes
        )

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.NotEqualOptions
        )
        option = circle.NotEqualOptions.NotEqualOptionsT()
        operator.builtinOptions = option
        return operator

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        Lower an `any` node to (optional NOT_EQUAL 0 ->) REDUCE_MAX.

        For int32/int64/float32/float64 inputs a NOT_EQUAL operator comparing
        against 0 is added to the graph first and its bool result is reduced;
        other inputs are fed to REDUCE_MAX directly. The REDUCE_MAX operator
        is returned.

        Raises:
            TypeError: if `dim` is neither None, an int, a tuple nor a list.
        """
        args = AnyArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input = args.input
        dim = args.dim
        keepdim = args.keepdim

        input_shape = list(extract_shape(input))

        if dim is None:
            # No dim given: reduce over every axis of the input.
            dim_i32 = tuple(
                circle_legalize_dtype_to(d, dtype=torch.int32)
                for d in range(len(input_shape))
            )
        elif isinstance(dim, int):
            dim_i32 = circle_legalize_dtype_to(dim, dtype=torch.int32)
        elif isinstance(dim, (tuple, list)):
            # Lists are accepted as well; previously a list-valued dim left
            # dim_i32 unassigned and failed later with UnboundLocalError.
            dim_i32 = tuple(circle_legalize_dtype_to(d, dtype=torch.int32) for d in dim)
        else:
            raise TypeError(
                f"Unsupported type for dim of any: {type(dim).__name__}"
            )

        dtype_torch = extract_torch_dtype(input)
        input_tensor: torch.fx.node.Node | circle.Tensor.TensorT = input

        if dtype_torch in [torch.int32, torch.int64, torch.float32, torch.float64]:
            # Non-bool input: insert `input != 0` to obtain a bool tensor of
            # the same shape, then reduce that tensor instead.
            ne_tensor: circle.Tensor.TensorT = self.graph.add_tensor_from_scratch(
                prefix=f"{input.name}_ne",
                shape=input_shape,
                dtype=circle.TensorType.TensorType.BOOL,
            )
            ne_node = self.define_ne_node(
                [input_tensor, torch.Tensor([0]).to(dtype_torch)], [ne_tensor]
            )
            self.graph.add_operator(ne_node)

            input_tensor = ne_tensor

        reduce_node: circle.Operator.OperatorT = self.define_max_node(
            [input_tensor, dim_i32], [node], keepdim
        )

        return reduce_node
@@ -0,0 +1,61 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.utils.validate_args_kwargs import ArangeStartStepArgs
27
+
28
+
29
@register_node_visitor
class ArangeStartStepVisitor(NodeVisitor):
    """
    Fuse arange_start_step to const_tensor

    The range is evaluated with torch.arange and written into the buffer of
    the tensor named after the node; no circle operator is created.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.arange.start_step]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        Evaluate the arange and store the result as the node's tensor buffer.

        Returns None (no operator is emitted for this node).
        """
        args = ArangeStartStepArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        start = args.start
        end = args.end
        step = args.step
        delta = 1

        if step is not None:
            # NOTE(review): the original unconditionally read step[0], which
            # fails for a plain number. The aten.arange.start_step schema
            # declares step as a Scalar, so fall back to the value itself when
            # it cannot be indexed - confirm against ArangeStartStepArgs.
            try:
                delta = step[0]  # type: ignore[index]
            except TypeError:
                delta = step
            # NOTE(review): a disabled assertion here used to say
            # "This pass must not be in use." - presumably arange is expected
            # to be folded before serialization; confirm.

        # Mirror torch.arange's dtype inference: int64 only when start, end
        # and step are all integers, otherwise floating point.
        arange_dtype: torch.dtype = torch.float32
        if (
            isinstance(start, int)
            and isinstance(end, int)
            and isinstance(delta, int)
        ):
            arange_dtype = torch.int64

        output_data = torch.arange(start=start, end=end, step=delta, dtype=arange_dtype)
        self.graph.update_tensor_buffer(output_data, node.name)

        return None  # type: ignore[return-value]
@@ -0,0 +1,62 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.circle_mapping import circle_legalize_dtype_to
25
+ from tico.serialize.operators.hashable_opcode import OpCode
26
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
27
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
28
+ from tico.utils.validate_args_kwargs import ArgMaxArgs
29
+
30
+
31
@register_node_visitor
class ArgMaxVisitor(NodeVisitor):
    """
    Visitor that emits a circle ARG_MAX operator for aten.argmax.default
    nodes, with the option's output type set to INT64.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.argmax.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the ARG_MAX operator from the node's tensor and dim."""
        argmax_args = ArgMaxArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        source, axis = argmax_args.tensor, argmax_args.dim

        opcode_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.ARG_MAX, self._op_codes
        )

        # The dim operand is legalized with dtype=torch.int32 before use.
        axis_i32 = circle_legalize_dtype_to(axis, dtype=torch.int32)
        argmax_op = create_builtin_operator(
            self.graph, opcode_index, [source, axis_i32], [node]
        )

        # Op-specific option
        argmax_option = circle.ArgMaxOptions.ArgMaxOptionsT()
        argmax_option.outputType = circle.TensorType.TensorType.INT64

        argmax_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ArgMaxOptions
        )
        argmax_op.builtinOptions = argmax_option

        return argmax_op
@@ -0,0 +1,112 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.circle_mapping import extract_circle_dtype, extract_shape
25
+ from tico.serialize.operators.hashable_opcode import OpCode
26
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
27
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
28
+ from tico.utils.define import define_pad_node
29
+ from tico.utils.validate_args_kwargs import AvgPool2dArgs
30
+
31
+
32
@register_node_visitor
class AvgPool2DVisitor(NodeVisitor):
    """Node visitor that serializes ``circle_custom.avgpool2d`` as Circle ``AVERAGE_POOL_2D``.

    Explicit padding is not expressed through the pooling options. When any
    padding amount is non-zero, a separate pad operator is inserted in front
    of the pooling operator, and the pooling operator itself always uses
    VALID padding.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.circle_custom.avgpool2d]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the Circle operator for an avgpool2d FX node.

        Args:
            node: FX node whose args/kwargs are parsed through ``AvgPool2dArgs``.

        Returns:
            An ``AVERAGE_POOL_2D`` operator. If a pad operator was needed it
            has already been added to ``self.graph`` as a side effect.

        Raises:
            AssertionError: If ``padding`` is given but is not a ``list``.
        """
        args = AvgPool2dArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input_node = args.input
        kernel_size = args.kernel_size
        stride = args.stride
        padding = args.padding

        avgpool_input: torch.fx.Node | circle.Tensor.TensorT = input_node

        def define_padding_node():
            """Add a pad operator for ``input_node`` and return its output tensor."""
            # Axes 1 and 2 are the ones padded (symmetrically); axes 0 and 3
            # are left untouched.
            # NOTE(review): this presumes a 4-D input with spatial dims at
            # axes 1 and 2 (NHWC-like) - confirm against the op definition.
            padding_vec = torch.tensor(
                [
                    [0, 0],
                    [padding[0], padding[0]],
                    [padding[1], padding[1]],
                    [0, 0],
                ],
                dtype=torch.int32,
            )
            input_shape = list(extract_shape(input_node))
            input_dtype: int = extract_circle_dtype(input_node)
            padded_input_shape = [
                input_shape[0],
                input_shape[1],
                input_shape[2],
                input_shape[3],
            ]
            padded_input_shape[1] += padding[0] * 2
            padded_input_shape[2] += padding[1] * 2
            # create padded input tensor
            padded_input_tensor = self.graph.add_tensor_from_scratch(
                prefix=f"{input_node.name}_pad_output",
                shape=padded_input_shape,
                dtype=input_dtype,
            )
            pad_operator = define_pad_node(
                self.graph,
                self._op_codes,
                [input_node, padding_vec],
                [padded_input_tensor],
            )
            self.graph.add_operator(pad_operator)
            return padded_input_tensor

        if padding is not None:
            assert isinstance(padding, list), type(padding)
            # An all-zero padding would only produce an identity pad operator
            # and an extra intermediate tensor, so feed the input directly.
            if any(padding):
                # NOTE(review): padded elements take part in the average
                # because pooling runs over the padded tensor - confirm this
                # matches the intended count_include_pad semantics.
                avgpool_input = define_padding_node()

        inputs = [avgpool_input]
        outputs = [node]

        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.AVERAGE_POOL_2D,
            self._op_codes,
        )
        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        # Op-specific option
        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.Pool2DOptions
        option = circle.Pool2DOptions.Pool2DOptionsT()

        # Padding enum values used by this file: SAME = 0, VALID = 1.
        # Padding is handled by the explicit pad operator above, so the
        # pooling operator itself is always VALID.
        VALID = 1
        option.padding = VALID
        option.strideH = stride[0]
        option.strideW = stride[1]
        option.filterHeight = kernel_size[0]
        option.filterWidth = kernel_size[1]
        option.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )

        operator.builtinOptions = option

        return operator
@@ -0,0 +1,62 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import BmmArgs
28
+
29
+
30
@register_node_visitor
class BatchMatmulVisitor(NodeVisitor):
    """Node visitor that serializes ``aten.bmm.default`` as a Circle ``BATCH_MATMUL`` operator."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.bmm.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
        """Build the Circle operator for a bmm FX node.

        Args:
            node: FX node whose args/kwargs are parsed through ``BmmArgs``.

        Returns:
            A ``BATCH_MATMUL`` operator over the two parsed operands, with both
            ``adjointLhs`` and ``adjointRhs`` disabled.
        """
        parsed = BmmArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        lhs = parsed.input
        rhs = parsed.mat2

        opcode_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
        )
        bmm_op = create_builtin_operator(self.graph, opcode_index, [lhs, rhs], [node])

        # Op-specific option: neither operand is marked as adjoint.
        bmm_options = circle.BatchMatMulOptions.BatchMatMulOptionsT()
        bmm_options.adjointLhs = False
        bmm_options.adjointRhs = False

        bmm_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
        )
        bmm_op.builtinOptions = bmm_options

        return bmm_op