tico 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
tico/serialize/operators/op_clamp.py
@@ -0,0 +1,126 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Dict, List, TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch._ops
+     import torch.fx
+ import torch
+ from circle_schema import circle
+
+ from tico.passes import ops
+
+ from tico.serialize.circle_graph import (
+     CircleSubgraph,
+     extract_circle_dtype,
+     extract_shape,
+ )
+ from tico.serialize.operators.hashable_opcode import OpCode
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+ from tico.utils.validate_args_kwargs import ClampArgs
+
+
+ @register_node_visitor
+ class ClampVisitor(NodeVisitor):
+     target: List[torch._ops.OpOverload] = ops.aten.clamp
+
+     def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
+         super().__init__(op_codes, graph)
+
+     def define_minimum_node(
+         self,
+         inputs: List[torch.fx.Node | circle.Tensor.TensorT | int | float],
+         outputs: List[torch.fx.Node | circle.Tensor.TensorT],
+     ) -> circle.Operator.OperatorT:
+
+         op_index = get_op_index(
+             circle.BuiltinOperator.BuiltinOperator.MINIMUM, self._op_codes
+         )
+
+         operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+         # Op-specific option
+         operator.builtinOptionsType = (
+             circle.BuiltinOptions.BuiltinOptions.MaximumMinimumOptions
+         )
+         option = circle.MaximumMinimumOptions.MaximumMinimumOptionsT()
+
+         operator.builtinOptions = option
+         return operator
+
+     def define_maximum_node(
+         self,
+         inputs: List[torch.fx.Node | circle.Tensor.TensorT | int | float],
+         outputs: List[torch.fx.Node | circle.Tensor.TensorT],
+     ) -> circle.Operator.OperatorT:
+
+         op_index = get_op_index(
+             circle.BuiltinOperator.BuiltinOperator.MAXIMUM, self._op_codes
+         )
+
+         operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+         # Op-specific option
+         operator.builtinOptionsType = (
+             circle.BuiltinOptions.BuiltinOptions.MaximumMinimumOptions
+         )
+         option = circle.MaximumMinimumOptions.MaximumMinimumOptionsT()
+
+         operator.builtinOptions = option
+
+         return operator
+
+     def define_node(
+         self,
+         node: torch.fx.Node,
+     ) -> circle.Operator.OperatorT:
+         args = ClampArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+         input = args.input
+         min_val = args.min
+         max_val = args.max
+
+         if min_val is None and max_val is None:
+             raise ValueError("Both min and max cannot be None")
+
+         elif min_val is not None and max_val is None:
+             # min only
+             return self.define_maximum_node([input, min_val], [node])
+
+         elif min_val is None and max_val is not None:
+             # max only
+             return self.define_minimum_node([input, max_val], [node])
+
+         elif min_val is not None and max_val is not None:
+             input_shape = extract_shape(input)
+             input_dtype = extract_circle_dtype(input)
+             minimum_tensor = self.graph.add_tensor_from_scratch(
+                 prefix=f"{input.name}_min",
+                 dtype=input_dtype,
+                 shape=list(input_shape),
+                 source_node=node,
+             )
+             minimum_operator = self.define_minimum_node(
+                 [input, max_val], [minimum_tensor]
+             )
+             self.graph.add_operator(minimum_operator)
+
+             maximum_operator = self.define_maximum_node(
+                 [minimum_tensor, min_val], [node]
+             )
+             return maximum_operator
+
+         else:
+             raise RuntimeError("Cannot reach here")
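The three-way branch above relies on the usual clamp identity: a min-only clamp lowers to a single MAXIMUM, a max-only clamp to a single MINIMUM, and a two-sided clamp to a MINIMUM followed by a MAXIMUM. A minimal sketch in plain PyTorch (nothing from tico assumed) checking that identity:

import torch

# torch.clamp(x, min=a, max=b) equals maximum(minimum(x, b), a) for a <= b,
# which is exactly the MINIMUM -> MAXIMUM chain the visitor emits.
x = torch.randn(2, 3)
a, b = -0.5, 0.5

both = torch.maximum(torch.minimum(x, torch.tensor(b)), torch.tensor(a))
assert torch.equal(torch.clamp(x, min=a, max=b), both)

# min-only clamp is a single MAXIMUM, max-only a single MINIMUM.
assert torch.equal(torch.clamp(x, min=a), torch.maximum(x, torch.tensor(a)))
assert torch.equal(torch.clamp(x, max=b), torch.minimum(x, torch.tensor(b)))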
tico/serialize/operators/op_clone.py
@@ -0,0 +1,71 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Dict, List, TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch._ops
+     import torch.fx
+ import torch
+ from circle_schema import circle
+
+ from tico.serialize.operators.hashable_opcode import OpCode
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+ from tico.utils.validate_args_kwargs import CloneArgs
+
+
+ @register_node_visitor
+ class CloneVisitor(NodeVisitor):
+     """
+     Clone tensor
+     TODO: Support dim_order and memory_format
+     A transpose may be required if 'memory_format' differs from the input tensor's 'memory_format'.
+     """
+
+     target: List[torch._ops.OpOverload] = [torch.ops.aten.clone.default]
+
+     def __init__(self, op_codes: Dict[OpCode, int], graph):
+         super().__init__(op_codes, graph)
+
+     def define_node(
+         self,
+         node: torch.fx.Node,
+     ) -> circle.Operator.OperatorT:
+         if "memory_format" in node.kwargs:
+             # TODO: Support dim_order and memory_format
+             pass
+
+         args = CloneArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+         input = args.input
+
+         op_index = get_op_index(
+             circle.BuiltinOperator.BuiltinOperator.TRANSPOSE, self._op_codes
+         )
+
+         permute = torch.IntTensor(list(range(len(input.meta["val"].shape))))
+
+         inputs = [input, permute]
+         outputs = [node]
+
+         operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+         # Op-specific option
+         operator.builtinOptionsType = (
+             circle.BuiltinOptions.BuiltinOptions.TransposeOptions
+         )
+         option = circle.TransposeOptions.TransposeOptionsT()
+         operator.builtinOptions = option
+
+         return operator
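The lowering above encodes clone as a TRANSPOSE whose permutation is the identity [0, 1, ..., n-1], i.e. a transpose that moves nothing and therefore just copies the tensor. A small sketch of that identity in plain PyTorch (the tensor `x` is hypothetical):

import torch

x = torch.randn(2, 3, 4)
# Identity permutation [0, 1, ..., n-1], as the visitor builds via
# torch.IntTensor(list(range(len(shape)))).
perm = list(range(x.dim()))
assert torch.equal(x.permute(perm), x.clone())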
tico/serialize/operators/op_constant_pad_nd.py
@@ -0,0 +1,72 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Dict, List, TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch._ops
+     import torch.fx
+ import torch
+ from circle_schema import circle
+
+ from tico.serialize.circle_mapping import extract_shape
+ from tico.serialize.operators.hashable_opcode import OpCode
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+ from tico.utils.errors import InvalidArgumentError
+ from tico.utils.validate_args_kwargs import ConstantPadNdArgs
+
+
+ @register_node_visitor
+ class ConstantPadNdVisitor(NodeVisitor):
+     target: List[torch._ops.OpOverload] = [torch.ops.aten.constant_pad_nd.default]
+
+     def __init__(self, op_codes: Dict[OpCode, int], graph):
+         super().__init__(op_codes, graph)
+
+     def define_node(
+         self,
+         node: torch.fx.Node,
+     ) -> circle.Operator.OperatorT:
+         args = ConstantPadNdArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+         input_ = args.input
+         pad = args.pad
+         val = args.value
+
+         if val != 0:
+             raise InvalidArgumentError("Only support 0 value padding.")
+
+         input_shape_len = len(extract_shape(input_))
+         padding_size = [[pad[2], pad[3]], [pad[0], pad[1]]]
+         if input_shape_len == 3:
+             padding_size = [[0, 0]] + padding_size
+         elif input_shape_len == 4:
+             padding_size = [[0, 0], [0, 0]] + padding_size
+         else:
+             raise InvalidArgumentError("Only support 3D/4D inputs.")
+
+         paddings = torch.tensor(padding_size, dtype=torch.int32)
+         inputs = [input_, paddings]
+         outputs = [node]
+
+         op_index = get_op_index(
+             circle.BuiltinOperator.BuiltinOperator.PAD, self._op_codes
+         )
+         operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+         operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.PadOptions
+         option = circle.PadOptions.PadOptionsT()
+         operator.builtinOptions = option
+
+         return operator
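Note the ordering conversion here: aten's `pad` list is last-dimension-first ([pad_left, pad_right, pad_top, pad_bottom] for the final two axes), while circle's PAD takes one [before, after] pair per axis, so the visitor prepends [0, 0] pairs for the leading axes. A sketch of the same correspondence in plain PyTorch (all shapes hypothetical):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 4, 4)   # 4D input
pad = [1, 2, 3, 4]            # [left, right, top, bottom], last dim first

# aten.constant_pad_nd semantics with value=0 ...
y = F.pad(x, pad, mode="constant", value=0.0)

# ... match the per-axis [[before, after], ...] matrix the visitor builds
# for a 4D tensor: [[0, 0], [0, 0], [pad[2], pad[3]], [pad[0], pad[1]]].
assert y.shape == (1, 1, 4 + 3 + 4, 4 + 1 + 2)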
tico/serialize/operators/op_conv2d.py
@@ -0,0 +1,186 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Dict, List, Optional, TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch._ops
+     import torch.fx
+ import torch
+ from circle_schema import circle
+
+ from tico.serialize.circle_mapping import extract_circle_dtype, extract_shape
+ from tico.serialize.operators.hashable_opcode import OpCode
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+ from tico.serialize.quant_param import QPARAM_KEY, QuantParam
+ from tico.utils.define import define_pad_node
+ from tico.utils.padding import is_same_padding, is_valid_padding, SAME, VALID
+ from tico.utils.validate_args_kwargs import Conv2DArgs
+
+
+ @register_node_visitor
+ class Conv2dVisitor(NodeVisitor):
+     """
+     NOTE
+     - CircleConv2D in circle supports only a padding type ('VALID', 'SAME'), whereas nn.Conv2d
+       accepts a padding type ('valid', 'same'), a padding value (int), or padding values (tuple -> [pad_h, pad_w]).
+       ref: https://tensorflow.org/api_docs/python/tf/nn/conv2d
+
+     [1] With valid/same padding: CircleConv2D (only)
+
+     [ATEN IR]
+     Input[NHWC] ---- circle_custom.conv2d[NHWC] ---- OUTPUT[NHWC]
+     Weight[NHWC] ---/
+     Bias ----------/
+
+     [CIRCLE IR]
+     Input[NHWC] ---- CircleConv2D[NHWC] ---- OUTPUT[NHWC]
+     Weight[NHWC] ---/
+     Bias ----------/
+
+     [2] With additional padding: CirclePad + CircleConv2D
+
+     [ATEN IR]
+     Input[NHWC] ---- circle_custom.conv2d[NHWC] ---- OUTPUT[NHWC]
+     Weight[NHWC] ---/
+     Bias ----------/
+
+     [CIRCLE IR]
+     Input[NHWC] ---- CirclePad[NHWC] ---- CircleConv2D[NHWC] ---- OUTPUT[NHWC]
+     Weight[NHWC] ------/
+     Bias -------------/
+     """
+
+     target: List[torch._ops.OpOverload] = [
+         torch.ops.circle_custom.conv2d,
+         torch.ops.circle_custom.conv2d.padding,
+     ]
+
+     def define_conv2d_node(
+         self, padding: int, stride: List, dilation: List, inputs: List, outputs: List
+     ) -> circle.Operator.OperatorT:
+         def set_conv2d_option(operator, stride, dilation):
+             operator.builtinOptionsType = (
+                 circle.BuiltinOptions.BuiltinOptions.Conv2DOptions
+             )
+             option = circle.Conv2DOptions.Conv2DOptionsT()
+             option.padding = padding
+             option.strideH = stride[0]
+             option.strideW = stride[1]
+             option.dilationHFactor = dilation[0]
+             option.dilationWFactor = dilation[1]
+             option.fusedActivationFunction = (
+                 circle.ActivationFunctionType.ActivationFunctionType.NONE
+             )
+             operator.builtinOptions = option
+
+         conv2d_op_index = get_op_index(
+             circle.BuiltinOperator.BuiltinOperator.CONV_2D, self._op_codes
+         )
+         operator = create_builtin_operator(self.graph, conv2d_op_index, inputs, outputs)
+         set_conv2d_option(operator, stride, dilation)
+         return operator
+
+     def __init__(self, op_codes: Dict[OpCode, int], graph):
+         super().__init__(op_codes, graph)
+
+     def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
+         # conv2d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, SymInt groups=1) -> Tensor
+         # conv2d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, str padding="valid", SymInt[2] dilation=1, SymInt groups=1) -> Tensor
+         args = Conv2DArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+         input_ = args.input
+         weight = args.weight
+         bias = args.bias
+         stride = args.stride
+         padding = args.padding
+         dilation = args.dilation
+         groups = args.groups
+
+         assert groups == 1, "Only support group 1 conv2d"
+
+         input_dtype: int = extract_circle_dtype(input_)
+         input_shape = list(extract_shape(input_))
+         assert len(input_shape) == 4, len(input_shape)
+         output_shape = extract_shape(node)
+         assert len(output_shape) == 4, len(output_shape)
+
+         conv_input: torch.fx.node.Node | circle.Tensor.TensorT = input_
+         weight_shape = list(extract_shape(weight))
+
+         if is_valid_padding(padding):
+             conv2d_padding_type = VALID
+         elif is_same_padding(padding, input_shape, output_shape):
+             conv2d_padding_type = SAME
+         else:
+             assert isinstance(padding, list) and len(padding) == 2
+
+             conv2d_padding_type = VALID
+
+             # Padding is neither valid nor same, so we use valid padding and add a padding operator before the conv2d operator.
+             # When data_format is "NHWC", padding should be [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]].
+             paddings = torch.tensor(
+                 [
+                     [0, 0],
+                     [padding[0], padding[0]],
+                     [padding[1], padding[1]],
+                     [0, 0],
+                 ],
+                 dtype=torch.int32,
+             )
+             pad_output_shape = [
+                 input_shape[0],
+                 input_shape[1],
+                 input_shape[2],
+                 input_shape[3],
+             ]
+             # Add (pad_top+pad_bottom) to pad_output_shape_h
+             pad_output_shape[1] += padding[0] * 2
+             # Add (pad_left+pad_right) to pad_output_shape_w
+             pad_output_shape[2] += padding[1] * 2
+             # Create the padded output tensor
+             input_qparam: Optional[QuantParam] = (
+                 input_.meta[QPARAM_KEY] if QPARAM_KEY in input_.meta else None
+             )
+             pad_output = self.graph.add_tensor_from_scratch(
+                 prefix=f"{node.name}_input_pad_output",
+                 shape=pad_output_shape,
+                 dtype=input_dtype,
+                 qparam=input_qparam,
+                 source_node=node,
+             )
+             # CirclePad
+             pad_operator = define_pad_node(
+                 self.graph, self._op_codes, [input_, paddings], [pad_output]
+             )
+             self.graph.add_operator(pad_operator)
+             conv_input = pad_output
+
+         if bias is None:
+             # luci-interpreter can't run a conv without bias. Let's add a zero vector for bias.
+             assert len(weight_shape) == 4
+             out_channel = weight_shape[0]
+             bias = [0.0] * out_channel  # type: ignore[assignment]
+
+         # Conv2D
+         conv2d_operator = self.define_conv2d_node(
+             conv2d_padding_type,  # 'SAME'(0) or 'VALID'(1)
+             stride,
+             dilation,
+             [conv_input, weight, bias],
+             [node],
+         )
+
+         return conv2d_operator
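When the padding is neither 'valid' nor effectively 'same', the visitor falls back to an explicit pad followed by a VALID convolution. The identity it relies on, sketched in plain NCHW PyTorch (the circle graph itself is NHWC, and all shapes here are hypothetical):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
b = torch.zeros(4)            # the visitor substitutes a zero bias when None
pad_h, pad_w = 2, 1

# Convolution with explicit padding ...
direct = F.conv2d(x, w, b, stride=1, padding=(pad_h, pad_w))

# ... equals zero-padding first, then a "VALID" (padding=0) convolution.
padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h])
two_step = F.conv2d(padded, w, b, stride=1, padding=0)

assert torch.allclose(direct, two_step)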
tico/serialize/operators/op_copy.py
@@ -0,0 +1,164 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Dict, List, TYPE_CHECKING, Union
+
+ if TYPE_CHECKING:
+     import torch._ops
+     import torch.fx
+ import torch
+ from circle_schema import circle
+
+ from tico.serialize.circle_graph import CircleSubgraph
+ from tico.serialize.operators.hashable_opcode import OpCode
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+ from tico.utils.errors import NotYetSupportedError
+ from tico.utils.validate_args_kwargs import CopyArgs
+
+
+ @register_node_visitor
+ class CopyVisitor(NodeVisitor):
+     """
+     NOTE `torch.Tensor.copy_`'s behavior matches `Reshape` of CIRCLE.
+     - Because `torch.Tensor.copy_` is an in-place operator, `dst` is converted to a `Shape` op of CIRCLE.
+     - After that, the `Shape` of `dst` is connected to the shape input of `Reshape`.
+     - `src` is connected to the tensor input of `Reshape`.
+     - if `dst` is not converted to `Shape`.
+         [dst]    [src]
+                    |
+                [Reshape]
+     - if `dst` is converted to `Shape`.
+         [dst]    [src]
+           |        |
+        [Shape]     |
+             \     /
+            [Reshape]
+     """
+
+     target: List[torch._ops.OpOverload] = [torch.ops.aten.copy.default]
+
+     def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
+         super().__init__(op_codes, graph)
+
+     def check_to_do_broadcast(self, dst: List[int], src: List[int]) -> bool:
+         return dst != src
+
+     def define_broadcast_to_node(
+         self,
+         inputs: List[Union[circle.Tensor.TensorT, torch.Tensor]],
+         outputs: List[circle.Tensor.TensorT],
+     ) -> circle.Operator.OperatorT:
+         op_index = get_op_index(
+             circle.BuiltinOperator.BuiltinOperator.BROADCAST_TO, self._op_codes
+         )
+         operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+         operator.builtinOptionsType = (
+             circle.BuiltinOptions.BuiltinOptions.BroadcastToOptions
+         )
+
+         option = circle.BroadcastToOptions.BroadcastToOptionsT()
+         operator.builtinOptions = option
+         return operator
+
+     def define_shape_node(
+         self, inputs: List[torch.fx.Node], outputs: List[circle.Tensor.TensorT]
+     ) -> circle.Operator.OperatorT:
+         op_index = get_op_index(
+             circle.BuiltinOperator.BuiltinOperator.SHAPE, self._op_codes
+         )
+         operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+         operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.ShapeOptions
+
+         option = circle.ShapeOptions.ShapeOptionsT()
+         option.outType = circle.TensorType.TensorType.INT32
+         operator.builtinOptions = option
+         return operator
+
+     def define_node(
+         self,
+         node: torch.fx.Node,
+     ) -> circle.Operator.OperatorT:
+         if len(node.args) == 3:
+             raise NotYetSupportedError("'non_blocking' is not supported yet.")
+
+         assert len(node.args) == 2, len(node.args)
+
+         args = CopyArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+         dst = args.dst
+         src = args.src
+
+         # To connect 'dst' to the Reshape node in the graph, 'dst' must be converted to a Shape op.
+         dst_tensor: circle.Tensor.TensorT = self.graph.get_tensor(dst)
+         dst_shape: List[int] = dst_tensor.shape
+         dst_shape_tensor = torch.as_tensor(dst_shape, dtype=torch.int32)
+
+         dst_shape_shape = [len(dst_shape)]
+         dst_name: str = dst.name
+
+         shape_output = self.graph.add_tensor_from_scratch(
+             prefix=f"{dst_name}_shape_output",
+             shape=dst_shape_shape,
+             dtype=circle.TensorType.TensorType.INT32,
+             source_node=node,
+         )
+
+         shape_operator = self.define_shape_node([dst], [shape_output])
+         self.graph.add_operator(shape_operator)
+
+         src_tensor: circle.Tensor.TensorT = self.graph.get_tensor(src)
+         src_shape: List[int] = src_tensor.shape
+
+         # The src tensor must be broadcastable with the dst tensor.
+         do_broadcast = self.check_to_do_broadcast(dst_shape, src_shape)
+         if do_broadcast:
+             # Create the BroadcastTo output tensor
+             src_name: str = src.name
+             src_type: int = src_tensor.type
+
+             broadcast_to_output: circle.Tensor.TensorT = (
+                 self.graph.add_tensor_from_scratch(
+                     prefix=f"{src_name}_broadcast_to_output",
+                     shape=dst_shape,
+                     dtype=src_type,
+                     source_node=node,
+                 )
+             )
+
+             broadcast_to_operator: circle.Operator.OperatorT = (
+                 self.define_broadcast_to_node(
+                     [src_tensor, dst_shape_tensor], [broadcast_to_output]
+                 )
+             )
+             self.graph.add_operator(broadcast_to_operator)
+             inputs: List = [broadcast_to_output, shape_output]
+         else:
+             inputs = [src, shape_output]
+
+         outputs = [node]
+         op_index = get_op_index(
+             circle.BuiltinOperator.BuiltinOperator.RESHAPE, self._op_codes
+         )
+
+         operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+         # Op-specific option
+         operator.builtinOptionsType = (
+             circle.BuiltinOptions.BuiltinOptions.ReshapeOptions
+         )
+         option = circle.ReshapeOptions.ReshapeOptionsT()
+         option.newShape = dst_shape
+
+         operator.builtinOptions = option
+         return operator
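The BroadcastTo-plus-Reshape lowering corresponds to the functional semantics of `aten.copy`: the result has `dst`'s shape and `src`'s (broadcast) values. A sketch of that reading in plain PyTorch (shapes hypothetical):

import torch

dst = torch.zeros(2, 3)
src = torch.tensor([1.0, 2.0, 3.0])   # broadcastable to dst's shape

# Functional aten.copy: dst's shape, src's values (broadcast when shapes differ).
out = dst.clone().copy_(src)

# The visitor's lowering: BroadcastTo (only when the shapes differ), then
# Reshape to dst's shape.
lowered = src.broadcast_to(dst.shape).reshape(dst.shape)
assert torch.equal(out, lowered)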
tico/serialize/operators/op_cos.py
@@ -0,0 +1,59 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Dict, List, TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch._ops
+     import torch.fx
+ import torch
+ from circle_schema import circle
+
+ from tico.serialize.circle_graph import CircleSubgraph
+ from tico.serialize.operators.hashable_opcode import OpCode
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+ from tico.utils.validate_args_kwargs import CosArgs
+
+
+ @register_node_visitor
+ class CosVisitor(NodeVisitor):
+     target: List[torch._ops.OpOverload] = [torch.ops.aten.cos.default]
+
+     def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
+         super().__init__(op_codes, graph)
+
+     def define_node(
+         self,
+         node: torch.fx.Node,
+     ) -> circle.Operator.OperatorT:
+         args = CosArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+         input = args.input
+
+         op_index = get_op_index(
+             circle.BuiltinOperator.BuiltinOperator.COS, self._op_codes
+         )
+
+         inputs = [input]
+         outputs = [node]
+
+         operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+         # Op-specific option
+         operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.CosOptions
+         option = circle.CosOptions.CosOptionsT()
+
+         operator.builtinOptions = option
+
+         return operator