tico 0.1.0.dev250411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196) hide show
  1. tico/__init__.py +31 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
  55. tico/experimental/quantization/public_interface.py +108 -0
  56. tico/experimental/quantization/quantizer.py +71 -0
  57. tico/interpreter/__init__.py +1 -0
  58. tico/interpreter/infer.py +116 -0
  59. tico/interpreter/interpreter.py +93 -0
  60. tico/passes/__init__.py +1 -0
  61. tico/passes/cast_aten_where_arg_type.py +185 -0
  62. tico/passes/cast_mixed_type_args.py +186 -0
  63. tico/passes/const_prop_pass.py +307 -0
  64. tico/passes/convert_conv1d_to_conv2d.py +151 -0
  65. tico/passes/convert_layout_op_to_reshape.py +84 -0
  66. tico/passes/convert_repeat_to_expand_copy.py +90 -0
  67. tico/passes/convert_to_relu6.py +180 -0
  68. tico/passes/decompose_addmm.py +127 -0
  69. tico/passes/decompose_batch_norm.py +198 -0
  70. tico/passes/decompose_fake_quantize.py +126 -0
  71. tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
  72. tico/passes/decompose_group_norm.py +258 -0
  73. tico/passes/decompose_grouped_conv2d.py +202 -0
  74. tico/passes/decompose_slice_scatter.py +167 -0
  75. tico/passes/extract_dtype_kwargs.py +121 -0
  76. tico/passes/fill_meta_val.py +57 -0
  77. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  78. tico/passes/legalize_causal_mask_value.py +113 -0
  79. tico/passes/legalize_predefined_layout_operators.py +383 -0
  80. tico/passes/lower_pow2_to_mul.py +75 -0
  81. tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
  82. tico/passes/lower_to_slice.py +112 -0
  83. tico/passes/merge_consecutive_cat.py +82 -0
  84. tico/passes/ops.py +75 -0
  85. tico/passes/remove_nop.py +85 -0
  86. tico/passes/remove_redundant_assert_nodes.py +50 -0
  87. tico/passes/remove_redundant_expand.py +70 -0
  88. tico/passes/remove_redundant_permute.py +102 -0
  89. tico/passes/remove_redundant_reshape.py +431 -0
  90. tico/passes/remove_redundant_slice.py +64 -0
  91. tico/passes/remove_redundant_to_copy.py +84 -0
  92. tico/passes/restore_linear.py +113 -0
  93. tico/passes/segment_index_select.py +143 -0
  94. tico/pt2_to_circle.py +101 -0
  95. tico/serialize/__init__.py +1 -0
  96. tico/serialize/circle_graph.py +264 -0
  97. tico/serialize/circle_mapping.py +177 -0
  98. tico/serialize/circle_serializer.py +232 -0
  99. tico/serialize/operators/__init__.py +28 -0
  100. tico/serialize/operators/hashable_opcode.py +43 -0
  101. tico/serialize/operators/node_visitor.py +80 -0
  102. tico/serialize/operators/op_add.py +69 -0
  103. tico/serialize/operators/op_alias_copy.py +64 -0
  104. tico/serialize/operators/op_any.py +142 -0
  105. tico/serialize/operators/op_arange_start_step.py +61 -0
  106. tico/serialize/operators/op_argmax.py +62 -0
  107. tico/serialize/operators/op_avg_pool2d.py +112 -0
  108. tico/serialize/operators/op_bmm.py +62 -0
  109. tico/serialize/operators/op_cat.py +66 -0
  110. tico/serialize/operators/op_clamp.py +123 -0
  111. tico/serialize/operators/op_clone.py +71 -0
  112. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  113. tico/serialize/operators/op_conv2d.py +181 -0
  114. tico/serialize/operators/op_copy.py +162 -0
  115. tico/serialize/operators/op_cos.py +59 -0
  116. tico/serialize/operators/op_cumsum.py +92 -0
  117. tico/serialize/operators/op_depthwise_conv2d.py +198 -0
  118. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  119. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  120. tico/serialize/operators/op_div.py +62 -0
  121. tico/serialize/operators/op_embedding.py +60 -0
  122. tico/serialize/operators/op_eq.py +64 -0
  123. tico/serialize/operators/op_exp.py +60 -0
  124. tico/serialize/operators/op_expand.py +91 -0
  125. tico/serialize/operators/op_full.py +48 -0
  126. tico/serialize/operators/op_full_like.py +55 -0
  127. tico/serialize/operators/op_ge.py +54 -0
  128. tico/serialize/operators/op_gelu.py +59 -0
  129. tico/serialize/operators/op_gt.py +54 -0
  130. tico/serialize/operators/op_index.py +82 -0
  131. tico/serialize/operators/op_index_select.py +64 -0
  132. tico/serialize/operators/op_instance_norm.py +91 -0
  133. tico/serialize/operators/op_linear.py +70 -0
  134. tico/serialize/operators/op_log.py +53 -0
  135. tico/serialize/operators/op_log1p.py +83 -0
  136. tico/serialize/operators/op_logical_and.py +63 -0
  137. tico/serialize/operators/op_logical_not.py +62 -0
  138. tico/serialize/operators/op_lt.py +61 -0
  139. tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
  140. tico/serialize/operators/op_maximum.py +53 -0
  141. tico/serialize/operators/op_mean.py +66 -0
  142. tico/serialize/operators/op_minimum.py +53 -0
  143. tico/serialize/operators/op_mm.py +174 -0
  144. tico/serialize/operators/op_mul.py +99 -0
  145. tico/serialize/operators/op_ne.py +54 -0
  146. tico/serialize/operators/op_neg.py +59 -0
  147. tico/serialize/operators/op_permute.py +65 -0
  148. tico/serialize/operators/op_pow.py +138 -0
  149. tico/serialize/operators/op_prelu.py +54 -0
  150. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  151. tico/serialize/operators/op_reciprocal.py +64 -0
  152. tico/serialize/operators/op_relu.py +53 -0
  153. tico/serialize/operators/op_relu6.py +52 -0
  154. tico/serialize/operators/op_repeat.py +99 -0
  155. tico/serialize/operators/op_reshape.py +73 -0
  156. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  157. tico/serialize/operators/op_rsqrt.py +53 -0
  158. tico/serialize/operators/op_scalar_tensor.py +51 -0
  159. tico/serialize/operators/op_select_copy.py +65 -0
  160. tico/serialize/operators/op_sigmoid.py +56 -0
  161. tico/serialize/operators/op_sin.py +53 -0
  162. tico/serialize/operators/op_slice.py +155 -0
  163. tico/serialize/operators/op_softmax.py +100 -0
  164. tico/serialize/operators/op_split_with_sizes.py +96 -0
  165. tico/serialize/operators/op_sqrt.py +55 -0
  166. tico/serialize/operators/op_squeeze.py +73 -0
  167. tico/serialize/operators/op_sub.py +71 -0
  168. tico/serialize/operators/op_sum.py +63 -0
  169. tico/serialize/operators/op_tanh.py +54 -0
  170. tico/serialize/operators/op_to_copy.py +105 -0
  171. tico/serialize/operators/op_unsqueeze.py +66 -0
  172. tico/serialize/operators/op_view.py +74 -0
  173. tico/serialize/operators/op_where.py +82 -0
  174. tico/serialize/operators/utils.py +51 -0
  175. tico/serialize/pack.py +35 -0
  176. tico/serialize/quant_param.py +42 -0
  177. tico/utils/__init__.py +1 -0
  178. tico/utils/convert.py +292 -0
  179. tico/utils/define.py +35 -0
  180. tico/utils/diff_graph.py +181 -0
  181. tico/utils/errors.py +35 -0
  182. tico/utils/graph.py +200 -0
  183. tico/utils/logging.py +45 -0
  184. tico/utils/model.py +37 -0
  185. tico/utils/padding.py +47 -0
  186. tico/utils/passes.py +76 -0
  187. tico/utils/register_custom_op.py +562 -0
  188. tico/utils/trace_decorators.py +101 -0
  189. tico/utils/utils.py +314 -0
  190. tico/utils/validate_args_kwargs.py +1114 -0
  191. tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
  192. tico-0.1.0.dev250411.dist-info/METADATA +17 -0
  193. tico-0.1.0.dev250411.dist-info/RECORD +196 -0
  194. tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
  195. tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
  196. tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,65 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import SelectCopyIntArgs
28
+
29
+
30
@register_node_visitor
class SelectCopyIntVisitor(NodeVisitor):
    """
    Serialize `aten.select(_copy).int` as a circle GATHER with a scalar index.

    GATHER's output shape is input.shape[:axis] + indices.shape + input.shape[axis+1:],
    so gathering a 0-d index tensor drops `dim`, exactly like `select`.
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.select_copy.int,
        torch.ops.aten.select.int,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        args = SelectCopyIntArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input = args.input
        dim = args.dim
        index = args.index

        input_shape: List[int] = self.graph.get_tensor(input).shape
        rank = len(input_shape)

        # torch accepts negative `dim` / `index` (counted from the end). GATHER
        # indices must lie in [0, input_shape[axis]), so normalize both here
        # instead of emitting an out-of-range constant index.
        if dim < 0:
            dim += rank
        if index < 0:
            index += input_shape[dim]

        indices = torch.as_tensor(index, dtype=torch.int32)
        inputs = [input, indices]
        outputs = [node]

        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.GATHER, self._op_codes
        )
        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        # Op-specific option
        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.GatherOptions
        option = circle.GatherOptions.GatherOptionsT()
        option.axis = dim
        # TODO option.batchDims
        operator.builtinOptions = option

        return operator
@@ -0,0 +1,56 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import SigmoidArgs
28
+
29
+
30
@register_node_visitor
class SigmoidVisitor(NodeVisitor):
    """Serialize `aten.sigmoid` as a circle LOGISTIC operator."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.sigmoid.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        # `Sigmoid` operation is implemented as a `Logistic` operation in circle.
        # https://github.com/Samsung/ONE/blob/170382a/nnpackage/schema/circle_schema.fbs#L288
        logistic_code = circle.BuiltinOperator.BuiltinOperator.LOGISTIC
        logistic_index = get_op_index(logistic_code, self._op_codes)

        sigmoid_args = SigmoidArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        # Single input, single output; LOGISTIC takes no builtin options.
        return create_builtin_operator(
            self.graph, logistic_index, [sigmoid_args.input], [node]
        )
@@ -0,0 +1,53 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import SinArgs
28
+
29
+
30
@register_node_visitor
class SinVisitor(NodeVisitor):
    """Serialize `aten.sin` as a circle SIN operator (no builtin options are set)."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.sin.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        sin_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.SIN, self._op_codes
        )
        sin_args = SinArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        return create_builtin_operator(
            self.graph, sin_index, [sin_args.input], [node]
        )
@@ -0,0 +1,155 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import copy
16
+ from typing import Dict, List, TYPE_CHECKING
17
+
18
+ if TYPE_CHECKING:
19
+ import torch._ops
20
+ import torch.fx
21
+ import torch
22
+ from circle_schema import circle
23
+
24
+ from tico.serialize.circle_graph import CircleSubgraph
25
+ from tico.serialize.operators.hashable_opcode import OpCode
26
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
27
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
28
+ from tico.utils.errors import InvalidArgumentError
29
+ from tico.utils.validate_args_kwargs import SliceArgs
30
+
31
+
32
@register_node_visitor
class SliceCopyVisitor(NodeVisitor):
    """
    Serialize `aten.slice(_copy).Tensor` as a circle STRIDED_SLICE operator.

    NOTE `torch.slice_copy`'s behavior matches with `strided slice` of CIRCLE, not `slice`.
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.slice.Tensor,
        torch.ops.aten.slice_copy.Tensor,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    @staticmethod
    def _clamp_bound(bound: int, dim_size: int) -> int:
        """
        Normalize a slice bound (`start` or `end`) into the range [0, dim_size].

        Mirrors torch/Python slicing rules: a negative bound counts from the end of
        the dimension, and a bound outside the dimension is clamped to the nearest
        edge (e.g. with dim_size=5: -7 -> 0, -2 -> 3, 9 -> 5).
        """
        if bound < 0:
            bound += dim_size
        return min(max(bound, 0), dim_size)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        Build a STRIDED_SLICE that slices only along `dim`; every other dimension
        is taken whole with stride 1.

        Raises:
            InvalidArgumentError: if the normalized slice is empty (`end <= start`),
                which circle cannot represent. This includes `start >= dim size`.
        """
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.STRIDED_SLICE, self._op_codes
        )

        args = SliceArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input = args.input
        dim = args.dim
        start = args.start
        end = args.end
        step = args.step

        input_tensor: circle.Tensor.TensorT = self.graph.get_tensor(input)
        input_shape: List[int] = input_tensor.shape
        rank = len(input_shape)

        assert dim is not None
        assert (
            -rank <= dim < rank
        ), "Cannot reach here (Dimension Out of Range error must be thrown by torch)"
        if dim < 0:
            dim = dim % rank
        dim_size = input_shape[dim]

        # `None` bounds / step fall back to torch's defaults: whole dimension, stride 1.
        if start is None:
            start = 0
        if end is None:
            end = dim_size
        if step is None:
            step = 1

        assert isinstance(start, int), type(start)
        assert isinstance(end, int), type(end)
        assert isinstance(step, int), type(step)
        assert 0 < step, "Restriction of torch.slice_copy"

        # After clamping both bounds lie in [0, dim_size]; `start == dim_size` is a
        # legal (empty) torch slice and is rejected by the check below.
        start = self._clamp_bound(start, dim_size)
        end = self._clamp_bound(end, dim_size)

        if end <= start:
            # CONSTRAINTS
            # In torch, 'end <= start' condition generates zero tensor with a peculiar
            # shape - ex. tensor([], size=(5,0,5)). In circle, it's not accepted at all.
            raise InvalidArgumentError(
                f"end({end}) must be greater than start ({start})"
            )

        # Build new arguments: begin/end/stride vectors covering every dimension.
        begin = [0] * rank
        begin[dim] = start

        end_bounds = list(input_shape)
        end_bounds[dim] = end

        strides = [1] * rank
        strides[dim] = step

        inputs = [
            input,
            torch.as_tensor(begin, dtype=torch.int32),
            torch.as_tensor(end_bounds, dtype=torch.int32),
            torch.as_tensor(strides, dtype=torch.int32),
        ]
        outputs = [node]

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.StridedSliceOptions
        )
        # All StridedSlice options are left at their defaults.
        operator.builtinOptions = circle.StridedSliceOptions.StridedSliceOptionsT()
        return operator
@@ -0,0 +1,100 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.operators.hashable_opcode import OpCode
24
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
25
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
26
+ from tico.utils.errors import NotYetSupportedError
27
+ from tico.utils.utils import HAS_TORCH_OVER_25
28
+ from tico.utils.validate_args_kwargs import SafeSoftmaxArgs, SoftmaxArgs
29
+
30
+
31
@register_node_visitor
class SoftMaxVisitor(NodeVisitor):
    """
    Serialize `aten._softmax` (and, when `HAS_TORCH_OVER_25`, `aten._safe_softmax`)
    as a circle SOFTMAX operator.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten._softmax.default] + (
        # NOTE: Let's treat _safe_softmax as normal _softmax as its usage is for training.
        # In order for optimization during inference, it can be replaced to softmax.
        # ref: https://github.com/pytorch/pytorch/pull/133882
        [torch.ops.aten._safe_softmax.default]
        if HAS_TORCH_OVER_25
        else []
    )

    def __init__(self, op_codes: Dict[OpCode, int], graph):
        super().__init__(op_codes, graph)

    def define_softmax_node(self, inputs, outputs) -> circle.Operator.OperatorT:
        """
        Emit a SOFTMAX operator with `beta = 1.0` for the given inputs/outputs.

        `define_node` only calls this when the softmax axis is the last dimension.
        """
        softmax_code = circle.BuiltinOperator.BuiltinOperator.SOFTMAX
        softmax_index = get_op_index(softmax_code, self._op_codes)

        operator = create_builtin_operator(self.graph, softmax_index, inputs, outputs)

        softmax_option = circle.SoftmaxOptions.SoftmaxOptionsT()
        softmax_option.beta = 1.0
        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.SoftmaxOptions
        )
        operator.builtinOptions = softmax_option
        return operator

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        Currently, Softmax is supported only when `dim` is the last dimension and
        `half_to_float` is False.

        Raises:
            NotYetSupportedError: for `half_to_float=True` or a non-last `dim`.
        """
        if node.target == torch.ops.aten._softmax.default:
            args = SoftmaxArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
            if args.half_to_float:
                raise NotYetSupportedError(
                    "softmax with half to float conversion is not supported on circle."
                )
        elif node.target == torch.ops.aten._safe_softmax.default:
            # Lowered exactly like `_softmax` (see the NOTE on `target`).
            args = SafeSoftmaxArgs(*node.args, **node.kwargs)  # type: ignore[arg-type, assignment]

        input: torch.fx.Node = args.input
        dim: int = args.dim

        tid = self.graph.get_tid_registered(input)
        rank = len(self.graph.tensors[tid].shape)

        if dim < 0:
            dim %= rank

        if dim != rank - 1:
            raise NotYetSupportedError("softmax only supports last dimension for now.")

        return self.define_softmax_node([input], [node])
@@ -0,0 +1,96 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING, Union
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+ from torch._subclasses.fake_tensor import FakeTensor
23
+
24
+ from tico.serialize.circle_graph import CircleSubgraph
25
+ from tico.serialize.circle_mapping import circle_legalize_dtype_to, to_circle_dtype
26
+ from tico.serialize.operators.hashable_opcode import OpCode
27
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
28
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
29
+ from tico.utils.validate_args_kwargs import SplitWithSizesArgs
30
+
31
+
32
@register_node_visitor
class SplitWithSizesVisitor(NodeVisitor):
    """Serialize `aten.split_with_sizes` as a circle SPLIT_V operator."""

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.split_with_sizes.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.SPLIT_V, self._op_codes
        )
        args = SplitWithSizesArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input = args.input
        split_sizes = args.split_sizes
        axis = args.dim

        split_sizes_i32 = [
            circle_legalize_dtype_to(split_size, dtype=torch.int32)
            for split_size in split_sizes
        ]
        axis_i32 = circle_legalize_dtype_to(axis, dtype=torch.int32)
        inputs = [input, split_sizes_i32, axis_i32]

        # `split_with_sizes` has multiple output tensors, represented as `getitem`
        # users. Unlike other ops, the node itself doesn't become a circle tensor;
        # each `getitem` does.
        # A torch module may use only some of the outputs. Then `getitem` nodes exist
        # only for the selected outputs. Since one-compiler assumes that `CircleSplitV`
        # always has all of its outputs, unused output tensors are added to compensate.
        #
        # Look users up by their getitem index (args[1]) rather than pairing them
        # positionally, so an output slot is never filled by a getitem that reads a
        # different index (e.g. when the same index is read by more than one getitem).
        users_by_index = {user.args[1]: user for user in node.users}

        node_val = node.meta.get("val")
        outputs: List[Union[circle.Tensor.TensorT, torch.fx.node.Node]] = []
        for idx in range(len(split_sizes)):
            user_node = users_by_index.get(idx)
            if user_node is not None:
                outputs.append(user_node)
                continue

            # Let's add unused output tensor to satisfy circle split_v operator scheme.
            # Its shape/dtype come from the node's fake output value at this index.
            assert isinstance(node_val, list)
            fake_tensor = node_val[idx]
            assert isinstance(fake_tensor, FakeTensor)
            shape = list(fake_tensor.size())
            dtype = to_circle_dtype(fake_tensor.dtype)
            tensor = self.graph.add_tensor_from_scratch(
                f"{node.name}_unused_{idx}", shape, dtype
            )
            outputs.append(tensor)

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.SplitVOptions
        option = circle.SplitVOptions.SplitVOptionsT()
        option.numSplits = len(split_sizes)
        operator.builtinOptions = option

        return operator
@@ -0,0 +1,55 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import SqrtArgs
28
+
29
+
30
@register_node_visitor
class SqrtVisitor(NodeVisitor):
    """Serialize `aten.sqrt` as a circle SQRT operator (SQRT has no builtin options)."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.sqrt.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        sqrt_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.SQRT, self._op_codes
        )
        sqrt_args = SqrtArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        return create_builtin_operator(
            self.graph, sqrt_index, [sqrt_args.input], [node]
        )
@@ -0,0 +1,73 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import SqueezeArgs
28
+
29
+
30
@register_node_visitor
class SqueezeVisitor(NodeVisitor):
    """
    Serialize `aten.squeeze(_copy).dims` as a circle SQUEEZE operator, or as an
    identity RESHAPE when the requested squeeze is a no-op.
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.squeeze.dims,
        torch.ops.aten.squeeze_copy.dims,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def _define_identity_reshape(
        self,
        input_node: torch.fx.Node,
        node: torch.fx.Node,
        shape: List[int],
    ) -> circle.Operator.OperatorT:
        """Emit a RESHAPE from `input_node` to `node` that keeps `shape` unchanged."""
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.RESHAPE,
            self._op_codes,
        )
        shape_tensor = torch.as_tensor(shape, dtype=torch.int32)
        operator = create_builtin_operator(
            self.graph, op_index, [input_node, shape_tensor], [node]
        )

        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ReshapeOptions
        )
        option = circle.ReshapeOptions.ReshapeOptionsT()
        option.newShape = shape
        operator.builtinOptions = option
        return operator

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        args = SqueezeArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input = args.input
        dims = args.dims

        # torch leaves requested dims whose size != 1 untouched, so only the
        # size-1 ones are forwarded to circle.
        squeeze_dims: List = []
        shape = input.meta["val"].size()
        if dims:
            squeeze_dims += [axis for axis in dims if shape[axis] == 1]

        if dims and not squeeze_dims:
            # None of the requested dims has size 1, so torch returns the input
            # unchanged. An empty `squeezeDims` would make circle squeeze *every*
            # size-1 dim instead, so lower this no-op to an identity RESHAPE.
            return self._define_identity_reshape(input, node, list(shape))

        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.SQUEEZE,
            self._op_codes,
        )
        operator = create_builtin_operator(self.graph, op_index, [input], [node])

        # Op-specific option. Leaving `squeezeDims` unset (only when `dims` is
        # empty/None here) squeezes all size-1 dims.
        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.SqueezeOptions
        )
        option = circle.SqueezeOptions.SqueezeOptionsT()
        if squeeze_dims:
            option.squeezeDims = squeeze_dims

        operator.builtinOptions = option

        return operator