tico 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206) hide show
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,150 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from circle_schema import circle
21
+
22
+ from tico.serialize.circle_graph import CircleSubgraph
23
+ from tico.serialize.circle_mapping import (
24
+ circle_legalize_dtype_to,
25
+ extract_circle_dtype,
26
+ extract_shape,
27
+ extract_torch_dtype,
28
+ )
29
+ from tico.serialize.operators.hashable_opcode import OpCode
30
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
31
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
32
+ from tico.utils.validate_args_kwargs import AnyArgs
33
+
34
+
35
@register_node_visitor
class AnyVisitor(NodeVisitor):
    """
    Serialize ``aten.any`` (default / dim / dims overloads) into Circle IR.

    [RESTRICTION]
      1. ReduceAny is not supported (luci-interpreter), so the reduction is
         expressed as NotEqual(0) -> ReduceMax instead.

    [CASE: BOOL]
      (Bool tensors don't need the 'Not Equal 0' step.)
      bool[d0..dN] --- Reduce Max ---> bool[]

    [CASE: FLOAT, INT]
      int/float[d0..dN] --- Not Equal 0 ---> bool[d0..dN]
                        --- Reduce Max ---> bool[]

    * [d0..dN] means a tensor with any shape
    * [] means Scalar
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.any.default,
        torch.ops.aten.any.dim,
        torch.ops.aten.any.dims,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_max_node(
        self, inputs: List, outputs: List, keepdims: bool
    ) -> circle.Operator.OperatorT:
        """
        Create a REDUCE_MAX operator.

        Args:
            inputs: [tensor to reduce, reduction axes].
            outputs: output tensors/nodes of the operator.
            keepdims: value written to ReducerOptions.keepDims.

        Returns:
            The created operator (not added to the graph here).
        """
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.REDUCE_MAX, self._op_codes
        )

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ReducerOptions
        )
        option = circle.ReducerOptions.ReducerOptionsT()
        option.keepDims = keepdims

        operator.builtinOptions = option

        return operator

    def define_ne_node(self, inputs: List, outputs: List) -> circle.Operator.OperatorT:
        """
        Create a NOT_EQUAL operator with default NotEqualOptions.

        Args:
            inputs: [lhs, rhs] operands to compare.
            outputs: output tensors/nodes of the operator.

        Returns:
            The created operator (not added to the graph here).
        """
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.NOT_EQUAL, self._op_codes
        )

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.NotEqualOptions
        )
        option = circle.NotEqualOptions.NotEqualOptionsT()
        operator.builtinOptions = option
        return operator

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        Lower ``node`` to an optional NOT_EQUAL followed by REDUCE_MAX.

        For int32/int64/float32/float64 inputs, a NOT_EQUAL(input, 0) operator
        producing a bool tensor is added to the graph directly. The returned
        REDUCE_MAX operator then reduces that bool tensor (or the original
        input for other dtypes, e.g. bool) over the requested axes.

        Returns:
            The REDUCE_MAX operator writing to ``node``.

        Raises:
            AssertionError: if ``dim`` is not None, an int, or a list/tuple.
        """
        args = AnyArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input_node = args.input
        dim = args.dim
        keepdim = args.keepdim

        input_shape = list(extract_shape(input_node))

        if dim is None:
            # No dim given: reduce over every axis of the input.
            dim_i32 = tuple(
                circle_legalize_dtype_to(d, dtype=torch.int32)
                for d in range(len(input_shape))
            )
        elif isinstance(dim, int):
            dim_i32 = circle_legalize_dtype_to(dim, dtype=torch.int32)
        elif isinstance(dim, (list, tuple)):
            # int[] arguments (aten.any.dims) typically arrive as lists in FX.
            dim_i32 = tuple(
                circle_legalize_dtype_to(d, dtype=torch.int32) for d in dim
            )
        else:
            raise AssertionError(f"Unsupported type for 'dim': {type(dim)}")

        dtype_torch = extract_torch_dtype(input_node)
        reduce_input: torch.fx.Node | circle.Tensor.TensorT = input_node

        if dtype_torch in [torch.int32, torch.int64, torch.float32, torch.float64]:
            # Convert to bool first: (input != 0) has the same truthiness as input.
            ne_tensor: circle.Tensor.TensorT = self.graph.add_tensor_from_scratch(
                prefix=f"{input_node.name}_ne",
                shape=input_shape,
                dtype=circle.TensorType.TensorType.BOOL,
                source_node=input_node,
            )
            ne_node = self.define_ne_node(
                [input_node, torch.Tensor([0]).to(dtype_torch)], [ne_tensor]
            )
            self.graph.add_operator(ne_node)
            reduce_input = ne_tensor

        reduce_node: circle.Operator.OperatorT = self.define_max_node(
            [reduce_input, dim_i32], [node], keepdim
        )

        return reduce_node
@@ -0,0 +1,61 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.utils.validate_args_kwargs import ArangeStartStepArgs
27
+
28
+
29
@register_node_visitor
class ArangeStartStepVisitor(NodeVisitor):
    """
    Fold ``aten.arange.start_step`` into a constant tensor.

    No Circle operator is emitted: the arange values are computed at
    serialization time and stored as the buffer of the tensor named after
    the node.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.arange.start_step]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        Compute the arange values and write them into the node's tensor buffer.

        Returns:
            None — this visitor only updates a tensor buffer and creates no
            operator.
        """
        args = ArangeStartStepArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        start = args.start
        end = args.end
        step = args.step

        delta = 1
        if step is not None:
            # `step` is a Scalar in the arange.start_step schema; indexing a
            # plain number would raise TypeError, so only unwrap sequences.
            delta = step[0] if isinstance(step, (list, tuple)) else step  # type: ignore[index]

        # Mirror torch.arange's inference: an integer result only when start,
        # end AND step are all integers; otherwise a float32 result.
        # NOTE(review): an explicit `dtype` kwarg on the node is not consulted
        # here — confirm that ArangeStartStepArgs/upstream passes make that safe.
        arange_dtype: torch.dtype = torch.float32
        if all(isinstance(v, int) for v in (start, end, delta)):
            arange_dtype = torch.int64

        output_data = torch.arange(start=start, end=end, step=delta, dtype=arange_dtype)
        self.graph.update_tensor_buffer(output_data, node.name)

        return None  # type: ignore[return-value]
@@ -0,0 +1,62 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.circle_mapping import circle_legalize_dtype_to
25
+ from tico.serialize.operators.hashable_opcode import OpCode
26
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
27
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
28
+ from tico.utils.validate_args_kwargs import ArgMaxArgs
29
+
30
+
31
@register_node_visitor
class ArgMaxVisitor(NodeVisitor):
    """
    Serialize ``aten.argmax.default`` into a Circle ARG_MAX operator.

    The reduction axis is fed as an int32-legalized second input, and the
    operator's output index type is set to INT64.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.argmax.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        # NOTE(review): only `tensor` and `dim` are used; dim=None and any
        # keepdim argument are not handled here — confirm upstream legalizes them.
        parsed = ArgMaxArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        argmax_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.ARG_MAX, self._op_codes
        )
        axis_i32 = circle_legalize_dtype_to(parsed.dim, dtype=torch.int32)

        operator = create_builtin_operator(
            self.graph, argmax_index, [parsed.tensor, axis_i32], [node]
        )

        # ARG_MAX options: emit int64 indices.
        argmax_options = circle.ArgMaxOptions.ArgMaxOptionsT()
        argmax_options.outputType = circle.TensorType.TensorType.INT64
        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.ArgMaxOptions
        operator.builtinOptions = argmax_options

        return operator
@@ -0,0 +1,192 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import Dict, List, TYPE_CHECKING
17
+
18
+ if TYPE_CHECKING:
19
+ import torch._ops
20
+ import torch.fx
21
+ import torch
22
+ from circle_schema import circle
23
+
24
+ from tico.serialize.circle_graph import CircleSubgraph
25
+ from tico.serialize.circle_mapping import extract_circle_dtype, extract_shape
26
+ from tico.serialize.operators.hashable_opcode import OpCode
27
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
28
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
29
+ from tico.utils.define import define_pad_node
30
+ from tico.utils.errors import NotYetSupportedError
31
+ from tico.utils.validate_args_kwargs import AvgPool2dArgs
32
+
33
+
34
@register_node_visitor
class AvgPool2DVisitor(NodeVisitor):
    """
    This class defines how to serialize AvgPool2D operation into Circle IR.

    Torch                                          | Circle

    count_include_pad: True/False                  | (count_include_pad): Always False
    padding: number (could be valid, same, or etc) | padding: "valid"/"same"

    * Circle's avgpool2d has no option for count_include_pad, so we always set it as False.
      When torch requests count_include_pad=True, explicit zero padding is
      inserted with a PAD operator and the pool then uses VALID padding.
    * Spatial height/width are read from axes 1 and 2 of the input shape (the
      PAD vector pads the same axes), i.e. an NHWC input layout is assumed.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.circle_custom.avgpool2d]

    # Values written to Pool2DOptions.padding for each supported padding mode.
    _PADDING_TYPE = {"SAME": 0, "VALID": 1}

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def has_padding(self, args: AvgPool2dArgs) -> bool:
        """Return True if either spatial padding value is non-zero."""
        padding = args.padding
        return padding[0] != 0 or padding[1] != 0

    def has_same_padding(self, args: AvgPool2dArgs) -> bool:
        """
        Return True if the torch padding keeps H and W unchanged.

        Such padding can be expressed with Circle's SAME padding mode.
        Only ceil_mode=False is supported.
        """
        input_shape = list(extract_shape(args.input))
        kernel_size = args.kernel_size
        stride = args.stride
        assert stride
        padding = args.padding
        # TODO Update this function when supporting ceil_mode = True
        assert args.ceil_mode is False
        output_height = math.floor(
            (input_shape[1] + padding[0] * 2 - kernel_size[0]) / stride[0] + 1
        )
        output_width = math.floor(
            (input_shape[2] + padding[1] * 2 - kernel_size[1]) / stride[1] + 1
        )

        return input_shape[1] == output_height and input_shape[2] == output_width

    def define_avgpool_node(self, inputs, outputs, padding, stride, kernel_size):
        """
        Create an AVERAGE_POOL_2D operator (not added to the graph here).

        Args:
            inputs: operator inputs (the tensor to pool).
            outputs: operator outputs.
            padding: "SAME" or "VALID".
            stride: (stride_h, stride_w).
            kernel_size: (kernel_h, kernel_w).
        """
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.AVERAGE_POOL_2D,
            self._op_codes,
        )
        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        # Op-specific option
        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.Pool2DOptions
        option = circle.Pool2DOptions.Pool2DOptionsT()

        assert padding in self._PADDING_TYPE, padding

        option.padding = self._PADDING_TYPE[padding]
        option.strideH = stride[0]
        option.strideW = stride[1]
        option.filterHeight = kernel_size[0]
        option.filterWidth = kernel_size[1]
        option.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )

        operator.builtinOptions = option
        return operator

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        PSEUDO CODE

        if count_include_pad == True:
            (Circle cannot represent count_include_pad=True in AvgPool2D. Therefore we manually add zero padding node.)
            DEFINE zero padding node
            DEFINE avgpool node with no padding (valid)
        if count_include_pad == False:
            (Lucky! Circle can represent count_include_pad=False)
            DEFINE avgpool node with same/valid padding.

            (However, it cannot represent all paddings. So, if the padding is not same or valid, we throw an error.)
            if the padding is neither same nor valid:
                THROW an error.

        Raises:
            NotYetSupportedError: count_include_pad=False with a padding that
                is neither VALID nor SAME.
        """
        args = AvgPool2dArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input_node = args.input
        kernel_size = args.kernel_size
        stride = args.stride
        padding = args.padding
        count_include_pad = args.count_include_pad

        avgpool_input: torch.fx.Node | circle.Tensor.TensorT = input_node

        def define_padding_node():
            """Add a PAD operator zero-padding H/W; return the padded tensor."""
            # A tuple is as usable as a list for building the pad vector.
            assert isinstance(padding, (list, tuple)), type(padding)
            padding_vec = torch.tensor(
                [
                    [0, 0],
                    [padding[0], padding[0]],
                    [padding[1], padding[1]],
                    [0, 0],
                ],
                dtype=torch.int32,
            )
            input_shape = list(extract_shape(input_node))
            input_dtype: int = extract_circle_dtype(input_node)
            padded_input_shape = input_shape[:4]
            padded_input_shape[1] += padding[0] * 2
            padded_input_shape[2] += padding[1] * 2
            # create padded input tensor
            padded_input_tensor = self.graph.add_tensor_from_scratch(
                prefix=f"{input_node.name}_pad_output",
                shape=padded_input_shape,
                dtype=input_dtype,
                source_node=node,
            )
            pad_operator = define_pad_node(
                self.graph,
                self._op_codes,
                [input_node, padding_vec],
                [padded_input_tensor],
            )
            self.graph.add_operator(pad_operator)
            return padded_input_tensor

        if count_include_pad is True:
            # Add padding before avgpool2d
            # Circle's avgpool2d does not support count_include_pad=True, so we need to add padding manually
            if self.has_padding(args):
                avgpool_input = define_padding_node()

            result = self.define_avgpool_node(
                [avgpool_input], [node], "VALID", stride, kernel_size
            )
        elif count_include_pad is False:
            if not self.has_padding(args):  # valid padding
                result = self.define_avgpool_node(
                    [avgpool_input], [node], "VALID", stride, kernel_size
                )
            elif self.has_same_padding(args):
                result = self.define_avgpool_node(
                    [avgpool_input], [node], "SAME", stride, kernel_size
                )
            else:
                # CASE: count_include_pad is False and not VALID/SAME padding
                #
                # Implement this when it's needed.
                # If needed, may it help: the idea of ratio masking in https://github.com/Samsung/TICO/pull/119
                raise NotYetSupportedError(
                    f"Padding({padding}) with count_include_pad({count_include_pad}) is not supported yet."
                )
        else:
            raise RuntimeError("Cannot reach here")

        return result
@@ -0,0 +1,62 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import BmmArgs
28
+
29
+
30
@register_node_visitor
class BatchMatmulVisitor(NodeVisitor):
    """
    Serialize ``aten.bmm.default`` into a Circle BATCH_MATMUL operator.

    Both adjoint flags are left False, so the two operands are used as given.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.bmm.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        bmm_args = BmmArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        lhs, rhs = bmm_args.input, bmm_args.mat2

        bmm_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
        )
        operator = create_builtin_operator(self.graph, bmm_index, [lhs, rhs], [node])

        # No adjoint (transpose) on either side.
        bmm_options = circle.BatchMatMulOptions.BatchMatMulOptionsT()
        bmm_options.adjointLhs = False
        bmm_options.adjointRhs = False
        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
        )
        operator.builtinOptions = bmm_options

        return operator
@@ -0,0 +1,66 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.operators.hashable_opcode import OpCode
24
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
25
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
26
+ from tico.utils.validate_args_kwargs import CatArgs
27
+
28
+
29
@register_node_visitor
class CatVisitor(NodeVisitor):
    """
    Serialize ``aten.cat.default`` into a Circle CONCATENATION operator.

    Every input tensor becomes an operator input, ``dim`` is written as-is to
    ConcatenationOptions.axis, and no fused activation is applied.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.cat.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        cat_args = CatArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        parts, axis = cat_args.tensors, cat_args.dim

        concat_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.CONCATENATION, self._op_codes
        )
        operator = create_builtin_operator(self.graph, concat_index, parts, [node])

        concat_options = circle.ConcatenationOptions.ConcatenationOptionsT()
        concat_options.axis = axis
        concat_options.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )
        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ConcatenationOptions
        )
        operator.builtinOptions = concat_options

        return operator