tico 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,70 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import LinearArgs
28
+
29
+
30
@register_node_visitor
class LinearVisitor(NodeVisitor):
    """Serialize ``aten.linear`` as a Circle FULLY_CONNECTED operator."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.linear.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the FULLY_CONNECTED operator for a linear ``node``.

        The operator reads ``(input, weight, bias)`` from the linear call and
        writes to ``node`` itself.
        """
        fc_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.FULLY_CONNECTED, self._op_codes
        )
        linear = LinearArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        # NOTE(review): ``linear.bias`` may be None for bias-free layers;
        # presumably create_builtin_operator maps None to an omitted optional
        # input — confirm.
        fc_inputs = [linear.input, linear.weight, linear.bias]
        operator = create_builtin_operator(self.graph, fc_index, fc_inputs, [node])

        fc_options = circle.FullyConnectedOptions.FullyConnectedOptionsT()
        # No activation is fused into the linear layer.
        fc_options.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )
        fc_options.weightsFormat = (
            circle.FullyConnectedOptionsWeightsFormat.FullyConnectedOptionsWeightsFormat.DEFAULT
        )
        # Keep the input's leading dimensions in the output, as torch's linear does.
        fc_options.keepNumDims = True
        fc_options.asymmetricQuantizeInputs = False

        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.FullyConnectedOptions
        )
        operator.builtinOptions = fc_options
        return operator
@@ -0,0 +1,53 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import LogArgs
28
+
29
+
30
@register_node_visitor
class LogVisitor(NodeVisitor):
    """Serialize ``aten.log`` as the Circle LOG builtin."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.log.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Emit a LOG operator mapping the call's single input to ``node``."""
        log_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.LOG, self._op_codes
        )
        log_args = LogArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        return create_builtin_operator(
            self.graph, log_index, [log_args.input], [node]
        )
@@ -0,0 +1,86 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.circle_mapping import (
25
+ extract_circle_dtype,
26
+ extract_shape,
27
+ extract_torch_dtype,
28
+ )
29
+ from tico.serialize.operators.hashable_opcode import OpCode
30
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
31
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
32
+ from tico.utils.validate_args_kwargs import Log1pArgs
33
+
34
+
35
@register_node_visitor
class Log1pVisitor(NodeVisitor):
    """Serialize ``aten.log1p`` as ``LOG(ADD(x, 1))``.

    The visitor lowers log1p into two Circle builtins: an ADD that adds a
    constant one to the input, followed by a LOG over that sum.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.log1p.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_add_node(self, inputs: List, outputs: List) -> circle.Operator.OperatorT:
        """Create an ADD operator (no fused activation) from ``inputs`` to ``outputs``."""
        add_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.ADD, self._op_codes
        )
        add_op = create_builtin_operator(self.graph, add_index, inputs, outputs)

        add_options = circle.AddOptions.AddOptionsT()
        add_options.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )
        add_options.potScaleInt16 = False
        add_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.AddOptions
        add_op.builtinOptions = add_options
        return add_op

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Emit ``ADD(x, 1)`` into the graph and return the trailing LOG operator."""
        x = Log1pArgs(*node.args, **node.kwargs).input  # type: ignore[arg-type]

        # Intermediate tensor that holds x + 1, with x's shape and dtype.
        # NOTE(review): no qparam is attached here; confirm log1p only
        # reaches this visitor with non-quantized inputs.
        x_plus_one: circle.Tensor.TensorT = self.graph.add_tensor_from_scratch(
            prefix=f"{x.name}_add",
            shape=list(extract_shape(x)),
            dtype=extract_circle_dtype(x),
            source_node=node,
        )
        one = torch.tensor([1]).to(extract_torch_dtype(x))
        self.graph.add_operator(self.define_add_node([x, one], [x_plus_one]))

        # The LOG opcode is registered after ADD, matching emission order.
        log_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.LOG, self._op_codes
        )
        return create_builtin_operator(self.graph, log_index, [x_plus_one], [node])
@@ -0,0 +1,63 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import LogicalAndArgs
28
+
29
+
30
@register_node_visitor
class LogicalAndVisitor(NodeVisitor):
    """Serialize ``aten.logical_and`` as the Circle LOGICAL_AND builtin."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.logical_and.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Emit LOGICAL_AND over ``(input, other)`` producing ``node``."""
        and_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.LOGICAL_AND,
            self._op_codes,
        )
        and_args = LogicalAndArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        and_op = create_builtin_operator(
            self.graph, and_index, [and_args.input, and_args.other], [node]
        )

        # LogicalAndOptions carries no fields; it only tags the option type.
        and_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.LogicalAndOptions
        )
        and_op.builtinOptions = circle.LogicalAndOptions.LogicalAndOptionsT()
        return and_op
@@ -0,0 +1,62 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import LogicalNotArgs
28
+
29
+
30
@register_node_visitor
class LogicalNotVisitor(NodeVisitor):
    """Serialize ``aten.logical_not`` as the Circle LOGICAL_NOT builtin."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.logical_not.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Emit LOGICAL_NOT over the call's single input producing ``node``."""
        not_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.LOGICAL_NOT,
            self._op_codes,
        )
        not_args = LogicalNotArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        not_op = create_builtin_operator(
            self.graph, not_index, [not_args.input], [node]
        )

        # LogicalNotOptions carries no fields; it only tags the option type.
        not_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.LogicalNotOptions
        )
        not_op.builtinOptions = circle.LogicalNotOptions.LogicalNotOptionsT()
        return not_op
@@ -0,0 +1,61 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import LtArgs
28
+
29
+
30
@register_node_visitor
class LtVisitor(NodeVisitor):
    """Serialize ``aten.lt.Tensor`` (element-wise ``input < other``) as Circle LESS."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.lt.Tensor]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Emit LESS over ``(input, other)`` producing ``node``."""
        less_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.LESS,
            self._op_codes,
        )
        lt_args = LtArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        less_op = create_builtin_operator(
            self.graph, less_index, [lt_args.input, lt_args.other], [node]
        )

        # LessOptions carries no fields; it only tags the option type.
        less_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.LessOptions
        less_op.builtinOptions = circle.LessOptions.LessOptionsT()
        return less_op
@@ -0,0 +1,70 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.circle_mapping import circle_legalize_dtype_to
25
+ from tico.serialize.operators.hashable_opcode import OpCode
26
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
27
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
28
+ from tico.utils.validate_args_kwargs import MaxDimArgs
29
+
30
+
31
@register_node_visitor
class MaxDimVisitor(NodeVisitor):
    """Serialize ``aten.max.dim`` as a Circle REDUCE_MAX operator.

    ``aten.max.dim`` returns ``(values, indices)``. Only graphs that consume
    output 0 through a single user are supported. The REDUCE_MAX operator is
    wired directly to that user instead of to ``node`` itself.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.max.dim]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Emit REDUCE_MAX for ``node``.

        Raises:
            AssertionError: if ``node`` does not have exactly one user, or
                that user selects an output other than index 0.
        """
        args = MaxDimArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input = args.input
        dim = args.dim
        keepdim = args.keepdim

        # Only the first output (the max values) is supported; the user is
        # expected to be a getitem selecting it: node.users == {getitem: None}.
        users = list(node.users)
        assert (
            len(users) == 1
        ), f"{node.name}: expected exactly one user of max.dim, got {len(users)}"
        values_user = users[0]
        assert values_user.args[1] == 0, (
            f"{node.name}: only output 0 (values) of max.dim is supported, "
            f"got output {values_user.args[1]}"
        )

        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.REDUCE_MAX, self._op_codes
        )

        # The reduction axis is passed as an int32 tensor input.
        dim_i32 = circle_legalize_dtype_to(dim, dtype=torch.int32)
        inputs = [input, dim_i32]
        outputs = [values_user]

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        # Op-specific option
        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ReducerOptions
        )
        option = circle.ReducerOptions.ReducerOptionsT()
        option.keepDims = keepdim

        operator.builtinOptions = option

        return operator
@@ -0,0 +1,155 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from enum import IntEnum
16
+ from typing import Dict, List, Optional, TYPE_CHECKING
17
+
18
+ if TYPE_CHECKING:
19
+ import torch._ops
20
+ import torch.fx
21
+ import torch
22
+ from circle_schema import circle
23
+
24
+ from tico.serialize.circle_graph import CircleSubgraph
25
+ from tico.serialize.circle_mapping import extract_circle_dtype, extract_shape
26
+ from tico.serialize.operators.hashable_opcode import OpCode
27
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
28
+ from tico.serialize.operators.utils import (
29
+ create_builtin_operator,
30
+ get_integer_dtype_min,
31
+ get_op_index,
32
+ )
33
+ from tico.serialize.quant_param import QPARAM_KEY, QuantParam
34
+ from tico.utils.validate_args_kwargs import MaxPool2dWithIndicesArgs
35
+
36
+
37
class PaddingType(IntEnum):
    """Padding mode stored in Circle ``Pool2DOptions.padding``.

    Members are written into the serialized options via ``int(...)``, so
    their integer values must match the Circle schema's ``Padding`` enum.
    """

    SAME = 0  # implicit padding chosen by the runtime
    VALID = 1  # no implicit padding
40
+
41
+
42
@register_node_visitor
class MaxPool2DWithIndicesVisitor(NodeVisitor):
    """Serialize ``circle_custom.maxpool2d`` as a Circle MAX_POOL_2D operator.

    If the output shape equals the input shape, the pool uses Circle SAME
    padding. Otherwise any non-zero explicit padding is materialized by a
    preceding PADV2 operator, and the pool runs with VALID padding.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.circle_custom.maxpool2d]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_padV2_node(
        self, inputs: List, outputs: List
    ) -> circle.Operator.OperatorT:
        """Create a PADV2 operator from ``inputs`` to ``outputs``.

        Returns the operator; the caller adds it to the graph.
        """
        padv2_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.PADV2, self._op_codes
        )
        pad_op = create_builtin_operator(self.graph, padv2_index, inputs, outputs)
        pad_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.PadV2Options
        pad_op.builtinOptions = circle.PadV2Options.PadV2OptionsT()
        return pad_op

    def _pad_input_explicitly(self, node, src, padding) -> circle.Tensor.TensorT:
        """Add a PADV2 op that pads axes 1 and 2 of ``src``; return the padded tensor."""
        assert isinstance(padding, list), type(padding)
        pad_h, pad_w = padding[0], padding[1]
        paddings = torch.tensor(
            [[0, 0], [pad_h, pad_h], [pad_w, pad_w], [0, 0]],
            dtype=torch.int32,
        )
        src_shape = list(extract_shape(src))
        src_dtype: int = extract_circle_dtype(src)
        # Copy the four input dims, then grow the two padded axes.
        padded_shape = [src_shape[axis] for axis in range(4)]
        padded_shape[1] += pad_h * 2
        padded_shape[2] += pad_w * 2
        src_qparam: Optional[QuantParam] = src.meta.get(QPARAM_KEY)

        padded = self.graph.add_tensor_from_scratch(
            prefix=f"{src.name}_pad_output",
            shape=padded_shape,
            dtype=src_dtype,
            qparam=src_qparam,
            source_node=node,
        )
        # Fill with the smallest representable value so padded cells never
        # win the max: -inf for float, the integer dtype's minimum if quantized.
        if src_qparam is None:
            fill_value = float("-inf")
        else:
            fill_value = get_integer_dtype_min(src_qparam.dtype)

        pad_op = self.define_padV2_node([src, paddings, fill_value], [padded])
        self.graph.add_operator(pad_op)
        return padded

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Emit MAX_POOL_2D for ``node``, preceded by PADV2 if explicit padding is needed."""
        # max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
        pool_args = MaxPool2dWithIndicesArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        src = pool_args.input
        kernel_size = pool_args.kernel_size
        stride = pool_args.stride
        assert stride
        padding = pool_args.padding

        pool_input: torch.fx.Node | circle.Tensor.TensorT = src
        padding_type = PaddingType.VALID
        if padding is not None:
            if extract_shape(src) == extract_shape(node):
                # Output shape equals input shape: use Circle's implicit SAME padding.
                padding_type = PaddingType.SAME
            elif padding[0] != 0 or padding[1] != 0:
                # Materialize the explicit padding, then pool with VALID.
                pool_input = self._pad_input_explicitly(node, src, padding)

        pool_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.MAX_POOL_2D,
            self._op_codes,
        )
        operator = create_builtin_operator(self.graph, pool_index, [pool_input], [node])

        pool_options = circle.Pool2DOptions.Pool2DOptionsT()
        pool_options.padding = int(padding_type)
        pool_options.strideH, pool_options.strideW = stride[0], stride[1]
        pool_options.filterHeight = kernel_size[0]
        pool_options.filterWidth = kernel_size[1]
        pool_options.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )
        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.Pool2DOptions
        operator.builtinOptions = pool_options
        return operator
@@ -0,0 +1,53 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import MaximumArgs
28
+
29
+
30
@register_node_visitor
class MaximumVisitor(NodeVisitor):
    """Serialize ``torch.ops.aten.maximum.default`` nodes as Circle ``MAXIMUM`` operators."""

    # FX call targets this visitor is registered for.
    target: List[torch._ops.OpOverload] = [torch.ops.aten.maximum.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
        """Build the Circle ``MAXIMUM`` operator for ``node``.

        Args:
            node: FX node calling ``aten.maximum.default``. Its positional and
                keyword arguments are unpacked through ``MaximumArgs`` into the
                ``input`` and ``other`` operands.

        Returns:
            The operator produced by ``create_builtin_operator``, with
            ``[input, other]`` as its inputs and ``node`` as its single output.
        """
        parsed = MaximumArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        lhs, rhs = parsed.input, parsed.other

        # Look up the operator-code index for MAXIMUM in this visitor's op-code table.
        maximum_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.MAXIMUM, self._op_codes
        )

        return create_builtin_operator(self.graph, maximum_index, [lhs, rhs], [node])