tico 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206) hide show
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,62 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from torch.export import ExportedProgram
16
+
17
+ from tico.passes import ops
18
+ from tico.serialize.circle_mapping import extract_shape
19
+ from tico.utils import logging
20
+ from tico.utils.passes import PassBase, PassResult
21
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
22
+ from tico.utils.utils import is_target_node
23
+ from tico.utils.validate_args_kwargs import SliceArgs
24
+
25
+
26
@trace_graph_diff_on_pass
class RemoveRedundantSlice(PassBase):
    """
    Bypass `aten.slice` nodes whose output shape equals their input shape.

    Such a slice is treated as redundant: every user of the slice node is
    rewired to consume the slice's input directly, and the orphaned slice is
    then removed by dead-code elimination.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Run the pass over `exported_program`.

        Returns:
            PassResult carrying True when at least one slice was bypassed.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if not is_target_node(node, ops.aten.slice):
                continue

            slice_args = SliceArgs(*node.args, **node.kwargs)
            source = slice_args.input

            # Identical shapes on both sides is the redundancy criterion.
            if extract_shape(source) == extract_shape(node):
                # The input keeps its own metadata; nothing is copied over.
                node.replace_all_uses_with(source, propagate_meta=False)

                modified = True
                logger.debug(f"{node.name} is replaced with {source.name}")

        # Clean-up runs whether or not anything was rewritten.
        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
@@ -0,0 +1,86 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Union
16
+
17
+ import torch
18
+ from torch.export import ExportedProgram
19
+
20
+ from tico.passes import ops
21
+ from tico.serialize.circle_mapping import extract_torch_dtype
22
+ from tico.utils import logging
23
+ from tico.utils.passes import PassBase, PassResult
24
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
25
+ from tico.utils.utils import is_target_node
26
+ from tico.utils.validate_args_kwargs import ToCopyArgs, ToDtypeArgs, ToDtypeLayoutArgs
27
+
28
+
29
@trace_graph_diff_on_pass
class RemoveRedundantToCopy(PassBase):
    """
    Bypass redundant `to_copy`-family operators.

    The handled overloads are `aten._to_copy.default`, `aten.to.dtype` and
    `aten.to.dtype_layout`. A node is bypassed only when none of its requests
    asks for a change:
      - a requested layout equals the layout of the input's `meta["val"]`,
      - a requested dtype equals the input's dtype,
      - a requested memory format is `torch.contiguous_format`.
    Requests that are absent or None are ignored.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Run the pass over `exported_program`.

        Returns:
            PassResult carrying True when at least one node was bypassed.

        Raises:
            NotImplementedError: If a matched node uses an overload other
                than the three handled ones.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if not is_target_node(node, ops.aten.to_copy):
                continue

            parsed: Union[ToCopyArgs, ToDtypeArgs, ToDtypeLayoutArgs]
            if node.target == torch.ops.aten._to_copy.default:
                parsed = ToCopyArgs(*node.args, **node.kwargs)
            elif node.target == torch.ops.aten.to.dtype:
                parsed = ToDtypeArgs(*node.args, **node.kwargs)
            elif node.target == torch.ops.aten.to.dtype_layout:
                parsed = ToDtypeLayoutArgs(*node.args, **node.kwargs)
            else:
                raise NotImplementedError(
                    f"Unsupported to_copy operator: {node.target}"
                )

            source = parsed.input

            # https://pytorch.org/docs/stable/tensor_attributes.html#torch-layout
            # layout is two types: torch.strided(dense Tensors), torch.sparse_coo(sparse COO Tensors)
            wanted_layout = getattr(parsed, "layout", None)
            if (
                wanted_layout is not None
                and wanted_layout != source.meta["val"].layout
            ):
                continue

            # Not every argument class carries every field, hence getattr.
            wanted_dtype = getattr(parsed, "dtype", None)
            if wanted_dtype is not None:
                current_dtype = extract_torch_dtype(source)
                if current_dtype != wanted_dtype:
                    continue

            wanted_format = getattr(parsed, "memory_format", None)
            if (
                wanted_format is not None
                and wanted_format != torch.contiguous_format
            ):
                continue

            node.replace_all_uses_with(source, propagate_meta=False)

            modified = True
            logger.debug(f"{node.name} is replaced with {source.name}")

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
@@ -0,0 +1,115 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ from torch.export import ExportedProgram
17
+
18
+ from tico.utils import logging
19
+ from tico.utils.graph import create_node
20
+ from tico.utils.passes import PassBase, PassResult
21
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
22
+
23
+
24
@trace_graph_diff_on_pass
class RestoreLinear(PassBase):
    """
    Linear Op is decomposed to multiple Ops in core aten.
    This pass restores linear Ops. For example,

    Before)

        bias  input  weight           input  weight
          |     |      |                |      |
          |     |   permute             |   permute
          |     V      |                |      |
          +-> addmm <--+                |      V
                                        +----> mm

    After)

        input  weight  bias           input  weight
          |      |      |               |      |
          |      V      |               |      V
          +--> linear <-+               +--> linear

    A node is rewritten only when its last operand is a 2-D transpose:
    `aten.t`, or `aten.permute` / `aten.permute_copy` with dims `[1, 0]`.
    An `addmm` carrying a non-unit `beta` or `alpha` is left untouched,
    because linear has no scaling factors.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _transposed_weight(candidate):
        """
        Return the weight behind a 2-D transpose node.

        Args:
            candidate: The last operand of an `addmm` / `mm` node.

        Returns:
            The first argument of `candidate` when it is a 2-D transpose,
            otherwise None.
        """
        target = getattr(candidate, "target", None)
        if target == torch.ops.aten.t.default:
            return candidate.args[0]

        if target in (
            torch.ops.aten.permute.default,
            torch.ops.aten.permute_copy.default,
        ):
            if len(candidate.args) > 1:
                dims = candidate.args[1]
            else:
                dims = candidate.kwargs.get("dims")
            # Only a permutation that swaps the two axes is a transpose.
            if dims is not None and list(dims) == [1, 0]:
                return candidate.args[0]

        return None

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Run the pass over `exported_program`.

        Returns:
            PassResult carrying True when at least one node became linear.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if node.op != "call_function":
                continue

            if node.target == torch.ops.aten.addmm.default:
                assert len(node.args) == 3
                bias, lhs, permute = node.args
                # addmm computes beta * bias + alpha * (lhs @ mat2); only the
                # unscaled form can be expressed as linear.
                if (
                    node.kwargs.get("beta", 1) != 1
                    or node.kwargs.get("alpha", 1) != 1
                ):
                    continue
                weight = self._transposed_weight(permute)
                if weight is None:
                    continue
                linear_args = (lhs, weight, bias)

            elif node.target == torch.ops.aten.mm.default:
                assert len(node.args) == 2
                lhs, permute = node.args
                weight = self._transposed_weight(permute)
                if weight is None:
                    continue
                linear_args = (lhs, weight)

            else:
                continue

            with graph.inserting_after(node):
                linear_node = create_node(
                    graph,
                    torch.ops.aten.linear.default,
                    args=linear_args,
                )
            node.replace_all_uses_with(linear_node, propagate_meta=True)

            modified = True
            logger.debug(f"{node.name} is replaced with linear")

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()
        return PassResult(modified)
@@ -0,0 +1,145 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+
21
+ from torch._export.utils import (
22
+ get_buffer,
23
+ get_lifted_tensor_constant,
24
+ get_param,
25
+ is_buffer,
26
+ is_lifted_tensor_constant,
27
+ is_param,
28
+ )
29
+ from torch.export import ExportedProgram
30
+
31
+ from tico.passes import ops
32
+ from tico.serialize.circle_graph import extract_shape
33
+ from tico.utils import logging
34
+ from tico.utils.graph import add_placeholder, create_node, is_single_value_tensor
35
+ from tico.utils.passes import PassBase, PassResult
36
+ from tico.utils.trace_decorators import trace_const_diff_on_pass
37
+ from tico.utils.utils import is_target_node
38
+ from tico.utils.validate_args_kwargs import IndexSelectArgs
39
+
40
+
41
@trace_const_diff_on_pass
class SegmentIndexSelectConst(PassBase):
    """
    Segment an index_select with multiple const indices into index_select operators with one index each.

    WHY?
      Gather(index, index_select, select, embedding, ...) operation with const indices can be lowered to slice by LowerToSlice pass.
      For that, we need to split 'an index_select operator with multiple indices' into 'multiple index_select operators with one index'.
      Note that NPU is not fully compatible with gather operation.

    [before]
            input
              |
        index_select.default, len(index) > 1
              |
            output

    [after]

                                 input
                                   |
            -------------------------------------------------
            |                                               |
        index_select.default, len(index) == 1 , ... , index_select.default, len(index) == 1
            |                                               |
            -------------------------------------------------
                                   |
        torch.concat (input=[index_select0, index_select1, ...], axis = dim)
                                   |
                                 output
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Run the pass over `exported_program`.

        Returns:
            PassResult carrying True when at least one index_select was split.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if not is_target_node(node, ops.aten.index_select):
                continue

            parsed = IndexSelectArgs(*node.args, **node.kwargs)
            source = parsed.input
            axis = parsed.dim
            indices = parsed.index

            # A node index is usable only if it resolves to a stored tensor:
            # a lifted constant, a parameter or a buffer, tried in that order.
            if isinstance(indices, torch.fx.Node):
                if is_lifted_tensor_constant(exported_program, indices):
                    indices = get_lifted_tensor_constant(exported_program, indices)  # type: ignore[assignment]
                elif is_param(exported_program, indices):
                    indices = get_param(exported_program, indices)  # type: ignore[assignment]
                elif is_buffer(exported_program, indices):
                    indices = get_buffer(exported_program, indices)  # type: ignore[assignment]
                else:
                    continue

            # Nothing to split unless there are at least two constant indices.
            if (
                not isinstance(indices, torch.Tensor)
                or is_single_value_tensor(indices)
                or len(indices) < 2
            ):
                continue

            source_shape = extract_shape(source)
            if axis < 0:
                # Map a negative axis onto its non-negative equivalent.
                axis = axis % len(source_shape)

            pieces = []
            for value in indices:
                # Each index becomes its own one-element constant placeholder.
                single_index = add_placeholder(
                    exported_program, torch.tensor([value]), prefix="segm_index"
                )
                with graph.inserting_before(node):
                    piece = create_node(
                        graph,
                        torch.ops.aten.index_select.default,
                        args=(source, axis, single_index),
                        origin=node,
                    )
                pieces.append(piece)

            with graph.inserting_before(node):
                joined = create_node(
                    graph,
                    torch.ops.aten.cat.default,
                    args=(pieces, axis),
                )

            node.replace_all_uses_with(joined, propagate_meta=True)

            modified = True
            logger.debug(
                f"{node.name} is replaced with {joined.name} and {pieces}"
            )

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
tico/pt2_to_circle.py ADDED
@@ -0,0 +1,105 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import os
17
+ from typing import Optional
18
+
19
+ import torch
20
+ import yaml
21
+
22
+ from tico.config import CompileConfigBase, get_default_config
23
+
24
+ from tico.utils.convert import convert_exported_module_to_circle
25
+
26
+
27
def convert(
    input: str,
    output: str,
    verbose: bool = False,
    config: Optional[CompileConfigBase] = None,
):
    """
    Convert a saved exported program into a circle model file.

    Args:
        input: Path passed to `torch.export.load`.
        output: Path the conversion result is written to, in binary mode.
        verbose: When true, the `TICO_LOG` environment variable is set to
            "4" before conversion starts.
        config: Compile configuration; `get_default_config()` supplies one
            when this is None.
    """
    compile_config = get_default_config() if config is None else config
    # TODO Check input and output

    if verbose:
        os.environ["TICO_LOG"] = "4"

    loaded_program = torch.export.load(input)
    circle_model = convert_exported_module_to_circle(
        loaded_program, config=compile_config
    )
    with open(output, "wb") as out_file:
        out_file.write(circle_model)
45
+
46
+
47
def main():
    """
    Command-line entry point: parse arguments, load the optional config file
    and run `convert`.

    Raises:
        ValueError: If the config file does not contain a mapping, has no
            `version` field, or names an unsupported version.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--input",
        required=True,
        help="provide a path to .pt2 model.",
    )

    parser.add_argument(
        "-o",
        "--output",
        required=True,
        help="provide a path to .circle model.",
    )

    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="print logs.",
    )

    parser.add_argument(
        "-c",
        "--config",
        required=False,
        help="provide a path to config file.",
    )

    args = parser.parse_args()

    config = None
    if args.config:
        with open(args.config, encoding="utf-8") as f:
            config_dict = yaml.safe_load(f)

        latest_version = "1.0"

        # An empty file loads as None, and a scalar or list document has no
        # fields to read; report both as a config error.
        if not isinstance(config_dict, dict):
            raise ValueError(
                f"The config file must contain a mapping with a 'version' field. (latest: {latest_version})"
            )

        version = config_dict.get("version", None)
        if version is None:
            raise ValueError(
                f"'version' field must be provided in the config file. (latest: {latest_version})"
            )

        # YAML loads an unquoted `version: 1.0` as a float, so compare and
        # forward the textual form.
        version = str(version)
        config_dict = {**config_dict, "version": version}

        if version == "1.0":
            from tico.config.v1 import CompileConfigV1

            config = CompileConfigV1.from_dict(config_dict)
        else:
            raise ValueError(
                f"Unsupported version '{version}'. (latest: {latest_version})"
            )

    convert(args.input, args.output, args.verbose, config)


if __name__ == "__main__":
    main()
@@ -0,0 +1 @@
1
+ # DO NOT REMOVE THIS FILE