tico 0.1.0.dev250411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196) hide show
  1. tico/__init__.py +31 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
  55. tico/experimental/quantization/public_interface.py +108 -0
  56. tico/experimental/quantization/quantizer.py +71 -0
  57. tico/interpreter/__init__.py +1 -0
  58. tico/interpreter/infer.py +116 -0
  59. tico/interpreter/interpreter.py +93 -0
  60. tico/passes/__init__.py +1 -0
  61. tico/passes/cast_aten_where_arg_type.py +185 -0
  62. tico/passes/cast_mixed_type_args.py +186 -0
  63. tico/passes/const_prop_pass.py +307 -0
  64. tico/passes/convert_conv1d_to_conv2d.py +151 -0
  65. tico/passes/convert_layout_op_to_reshape.py +84 -0
  66. tico/passes/convert_repeat_to_expand_copy.py +90 -0
  67. tico/passes/convert_to_relu6.py +180 -0
  68. tico/passes/decompose_addmm.py +127 -0
  69. tico/passes/decompose_batch_norm.py +198 -0
  70. tico/passes/decompose_fake_quantize.py +126 -0
  71. tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
  72. tico/passes/decompose_group_norm.py +258 -0
  73. tico/passes/decompose_grouped_conv2d.py +202 -0
  74. tico/passes/decompose_slice_scatter.py +167 -0
  75. tico/passes/extract_dtype_kwargs.py +121 -0
  76. tico/passes/fill_meta_val.py +57 -0
  77. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  78. tico/passes/legalize_causal_mask_value.py +113 -0
  79. tico/passes/legalize_predefined_layout_operators.py +383 -0
  80. tico/passes/lower_pow2_to_mul.py +75 -0
  81. tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
  82. tico/passes/lower_to_slice.py +112 -0
  83. tico/passes/merge_consecutive_cat.py +82 -0
  84. tico/passes/ops.py +75 -0
  85. tico/passes/remove_nop.py +85 -0
  86. tico/passes/remove_redundant_assert_nodes.py +50 -0
  87. tico/passes/remove_redundant_expand.py +70 -0
  88. tico/passes/remove_redundant_permute.py +102 -0
  89. tico/passes/remove_redundant_reshape.py +431 -0
  90. tico/passes/remove_redundant_slice.py +64 -0
  91. tico/passes/remove_redundant_to_copy.py +84 -0
  92. tico/passes/restore_linear.py +113 -0
  93. tico/passes/segment_index_select.py +143 -0
  94. tico/pt2_to_circle.py +101 -0
  95. tico/serialize/__init__.py +1 -0
  96. tico/serialize/circle_graph.py +264 -0
  97. tico/serialize/circle_mapping.py +177 -0
  98. tico/serialize/circle_serializer.py +232 -0
  99. tico/serialize/operators/__init__.py +28 -0
  100. tico/serialize/operators/hashable_opcode.py +43 -0
  101. tico/serialize/operators/node_visitor.py +80 -0
  102. tico/serialize/operators/op_add.py +69 -0
  103. tico/serialize/operators/op_alias_copy.py +64 -0
  104. tico/serialize/operators/op_any.py +142 -0
  105. tico/serialize/operators/op_arange_start_step.py +61 -0
  106. tico/serialize/operators/op_argmax.py +62 -0
  107. tico/serialize/operators/op_avg_pool2d.py +112 -0
  108. tico/serialize/operators/op_bmm.py +62 -0
  109. tico/serialize/operators/op_cat.py +66 -0
  110. tico/serialize/operators/op_clamp.py +123 -0
  111. tico/serialize/operators/op_clone.py +71 -0
  112. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  113. tico/serialize/operators/op_conv2d.py +181 -0
  114. tico/serialize/operators/op_copy.py +162 -0
  115. tico/serialize/operators/op_cos.py +59 -0
  116. tico/serialize/operators/op_cumsum.py +92 -0
  117. tico/serialize/operators/op_depthwise_conv2d.py +198 -0
  118. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  119. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  120. tico/serialize/operators/op_div.py +62 -0
  121. tico/serialize/operators/op_embedding.py +60 -0
  122. tico/serialize/operators/op_eq.py +64 -0
  123. tico/serialize/operators/op_exp.py +60 -0
  124. tico/serialize/operators/op_expand.py +91 -0
  125. tico/serialize/operators/op_full.py +48 -0
  126. tico/serialize/operators/op_full_like.py +55 -0
  127. tico/serialize/operators/op_ge.py +54 -0
  128. tico/serialize/operators/op_gelu.py +59 -0
  129. tico/serialize/operators/op_gt.py +54 -0
  130. tico/serialize/operators/op_index.py +82 -0
  131. tico/serialize/operators/op_index_select.py +64 -0
  132. tico/serialize/operators/op_instance_norm.py +91 -0
  133. tico/serialize/operators/op_linear.py +70 -0
  134. tico/serialize/operators/op_log.py +53 -0
  135. tico/serialize/operators/op_log1p.py +83 -0
  136. tico/serialize/operators/op_logical_and.py +63 -0
  137. tico/serialize/operators/op_logical_not.py +62 -0
  138. tico/serialize/operators/op_lt.py +61 -0
  139. tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
  140. tico/serialize/operators/op_maximum.py +53 -0
  141. tico/serialize/operators/op_mean.py +66 -0
  142. tico/serialize/operators/op_minimum.py +53 -0
  143. tico/serialize/operators/op_mm.py +174 -0
  144. tico/serialize/operators/op_mul.py +99 -0
  145. tico/serialize/operators/op_ne.py +54 -0
  146. tico/serialize/operators/op_neg.py +59 -0
  147. tico/serialize/operators/op_permute.py +65 -0
  148. tico/serialize/operators/op_pow.py +138 -0
  149. tico/serialize/operators/op_prelu.py +54 -0
  150. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  151. tico/serialize/operators/op_reciprocal.py +64 -0
  152. tico/serialize/operators/op_relu.py +53 -0
  153. tico/serialize/operators/op_relu6.py +52 -0
  154. tico/serialize/operators/op_repeat.py +99 -0
  155. tico/serialize/operators/op_reshape.py +73 -0
  156. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  157. tico/serialize/operators/op_rsqrt.py +53 -0
  158. tico/serialize/operators/op_scalar_tensor.py +51 -0
  159. tico/serialize/operators/op_select_copy.py +65 -0
  160. tico/serialize/operators/op_sigmoid.py +56 -0
  161. tico/serialize/operators/op_sin.py +53 -0
  162. tico/serialize/operators/op_slice.py +155 -0
  163. tico/serialize/operators/op_softmax.py +100 -0
  164. tico/serialize/operators/op_split_with_sizes.py +96 -0
  165. tico/serialize/operators/op_sqrt.py +55 -0
  166. tico/serialize/operators/op_squeeze.py +73 -0
  167. tico/serialize/operators/op_sub.py +71 -0
  168. tico/serialize/operators/op_sum.py +63 -0
  169. tico/serialize/operators/op_tanh.py +54 -0
  170. tico/serialize/operators/op_to_copy.py +105 -0
  171. tico/serialize/operators/op_unsqueeze.py +66 -0
  172. tico/serialize/operators/op_view.py +74 -0
  173. tico/serialize/operators/op_where.py +82 -0
  174. tico/serialize/operators/utils.py +51 -0
  175. tico/serialize/pack.py +35 -0
  176. tico/serialize/quant_param.py +42 -0
  177. tico/utils/__init__.py +1 -0
  178. tico/utils/convert.py +292 -0
  179. tico/utils/define.py +35 -0
  180. tico/utils/diff_graph.py +181 -0
  181. tico/utils/errors.py +35 -0
  182. tico/utils/graph.py +200 -0
  183. tico/utils/logging.py +45 -0
  184. tico/utils/model.py +37 -0
  185. tico/utils/padding.py +47 -0
  186. tico/utils/passes.py +76 -0
  187. tico/utils/register_custom_op.py +562 -0
  188. tico/utils/trace_decorators.py +101 -0
  189. tico/utils/utils.py +314 -0
  190. tico/utils/validate_args_kwargs.py +1114 -0
  191. tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
  192. tico-0.1.0.dev250411.dist-info/METADATA +17 -0
  193. tico-0.1.0.dev250411.dist-info/RECORD +196 -0
  194. tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
  195. tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
  196. tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,84 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Union
16
+
17
+ import torch
18
+ from torch.export import ExportedProgram
19
+
20
+ from tico.passes import ops
21
+ from tico.serialize.circle_mapping import extract_torch_dtype
22
+ from tico.utils import logging
23
+ from tico.utils.passes import PassBase, PassResult
24
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
25
+ from tico.utils.validate_args_kwargs import ToCopyArgs, ToDtypeArgs, ToDtypeLayoutArgs
26
+
27
+
28
@trace_graph_diff_on_pass
class RemoveRedundantToCopy(PassBase):
    """
    Drop `aten._to_copy` / `aten.to.dtype` / `aten.to.dtype_layout` nodes
    that would not change their input.

    A conversion node is considered redundant when every layout and dtype it
    explicitly requests already matches the input. Its users are then rewired
    to read the input directly, and the dead conversion is removed.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _parse_args(node) -> Union[ToCopyArgs, ToDtypeArgs, ToDtypeLayoutArgs]:
        """Parse the args/kwargs of a supported conversion overload.

        Raises:
            NotImplementedError: if `node.target` is not one of the three
                handled overloads.
        """
        target = node.target
        if target == torch.ops.aten._to_copy.default:
            return ToCopyArgs(*node.args, **node.kwargs)
        if target == torch.ops.aten.to.dtype:
            return ToDtypeArgs(*node.args, **node.kwargs)
        if target == torch.ops.aten.to.dtype_layout:
            return ToDtypeLayoutArgs(*node.args, **node.kwargs)
        raise NotImplementedError(f"Unsupported to_copy operator: {node.target}")

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        fx_graph = gm.graph
        changed = False
        for candidate in fx_graph.nodes:
            if candidate.op != "call_function":
                continue
            if candidate.target not in ops.aten.to_copy:
                continue

            parsed = self._parse_args(candidate)
            source = parsed.input

            # See https://pytorch.org/docs/stable/tensor_attributes.html#torch-layout
            # A layout is either torch.strided (dense) or torch.sparse_coo (sparse COO).
            # Only compare when the overload carries a layout and it was given.
            requested_layout = getattr(parsed, "layout", None)
            if (
                requested_layout is not None
                and requested_layout != source.meta["val"].layout
            ):
                continue

            # Likewise, only compare dtypes when one was explicitly requested.
            requested_dtype = getattr(parsed, "dtype", None)
            if (
                requested_dtype is not None
                and extract_torch_dtype(source) != requested_dtype
            ):
                continue

            candidate.replace_all_uses_with(source, propagate_meta=False)
            changed = True
            logger.debug(f"{candidate.name} is replaced with {source.name}")

        # The bypassed conversion nodes now have no users and are swept here.
        fx_graph.eliminate_dead_code()
        fx_graph.lint()
        gm.recompile()

        return PassResult(changed)
@@ -0,0 +1,113 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ from torch.export import ExportedProgram
17
+
18
+ from tico.utils import logging
19
+ from tico.utils.passes import PassBase, PassResult
20
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
21
+
22
+
23
@trace_graph_diff_on_pass
class RestoreLinear(PassBase):
    """
    Fold the core-ATen decomposition of `linear` back into `aten.linear`.

    In core ATen, linear is decomposed into `addmm` (with bias) or `mm`
    (without bias), where the second matmul operand is the transposed weight.
    This pass restores the `linear` op. For example,

    Before)

        bias  input  weight              input  weight
          |     |      |                   |      |
          |     |   permute / t            |   permute / t
          |     V      |                   |      |
          +---> addmm <+                   +----> mm

    After)

        input  weight  bias              input  weight
          |      |      |                  |      |
          |      V      |                  |      V
          +---> linear <+                  +---> linear

    Only a genuine 2-D transpose is folded:
    - `aten.t`, or
    - `aten.permute` whose dims normalize to [1, 0].

    An `addmm` with non-default `beta`/`alpha` is left untouched, because
    `linear` cannot express that scaling.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _transposed_weight(node):
        """Return X when `node` computes a 2-D transpose of X; otherwise None."""
        if not isinstance(node, torch.fx.Node) or node.op != "call_function":
            return None
        if node.target == torch.ops.aten.t.default:
            return node.args[0]
        if node.target == torch.ops.aten.permute.default:
            dims = node.args[1]
            # mm/addmm operands are 2-D and permute keeps the rank, so
            # `% 2` maps negative dims (-1, -2) onto (1, 0). Identity
            # permutations such as [0, 1] are rejected.
            if [d % 2 for d in dims] == [1, 0]:
                return node.args[0]
        return None

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if node.op != "call_function":
                continue

            if node.target == torch.ops.aten.addmm.default:
                assert len(node.args) == 3
                bias, input_, mat2 = node.args
                # addmm computes beta * bias + alpha * (input @ mat2). linear
                # can only express the beta == alpha == 1 case.
                if node.kwargs.get("beta", 1) != 1 or node.kwargs.get("alpha", 1) != 1:
                    continue
                weight = self._transposed_weight(mat2)
                if weight is None:
                    continue
                linear_args = (input_, weight, bias)
            elif node.target == torch.ops.aten.mm.default:
                assert len(node.args) == 2
                input_, mat2 = node.args
                weight = self._transposed_weight(mat2)
                if weight is None:
                    continue
                linear_args = (input_, weight)
            else:
                continue

            with graph.inserting_after(node):
                linear_node = graph.call_function(
                    torch.ops.aten.linear.default, args=linear_args
                )
            node.replace_all_uses_with(linear_node, propagate_meta=True)
            # `node` now has no users; eliminate_dead_code() below removes it
            # (and the transpose, if nothing else uses it).

            modified = True
            logger.debug(f"{node.name} is replaced with linear")

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()
        return PassResult(modified)
@@ -0,0 +1,143 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+
21
+ from torch._export.utils import (
22
+ get_buffer,
23
+ get_lifted_tensor_constant,
24
+ get_param,
25
+ is_buffer,
26
+ is_lifted_tensor_constant,
27
+ is_param,
28
+ )
29
+ from torch.export import ExportedProgram
30
+
31
+ from tico.passes import ops
32
+ from tico.serialize.circle_graph import extract_shape
33
+ from tico.utils import logging
34
+ from tico.utils.graph import add_placeholder, is_single_value_tensor
35
+ from tico.utils.passes import PassBase, PassResult
36
+ from tico.utils.trace_decorators import trace_const_diff_on_pass
37
+ from tico.utils.validate_args_kwargs import IndexSelectArgs
38
+
39
+
40
@trace_const_diff_on_pass
class SegmentIndexSelectConst(PassBase):
    """
    Split an `index_select` with several constant indices into one
    `index_select` per index, then concatenate the pieces back along `dim`.

    WHY?
    Gather-like ops (index, index_select, select, embedding, ...) with constant
    indices can be lowered to slice by the LowerToSlice pass. That lowering
    needs each op to select a single index, so a multi-index `index_select`
    is segmented here first. Note that the NPU is not fully compatible with
    gather operations.

    [before]
                          input
                            |
              index_select.default, len(index) > 1
                            |
                          output

    [after]
                                  input
                                    |
        ------------------------------------------------------------
        |                                                          |
      index_select.default, len(index) == 1, ..., index_select.default, len(index) == 1
        |                                                          |
        ------------------------------------------------------------
                                    |
          torch.concat(input=[index_select0, index_select1, ...], axis=dim)
                                    |
                                  output
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _as_const_tensor(exported_program, index):
        """Resolve `index` to a constant tensor, or return None if it is not one.

        A graph node is resolved through the lifted-constant, parameter, and
        buffer tables, in that order. Any other node is treated as non-constant.
        """
        if isinstance(index, torch.fx.Node):
            lookups = (
                (is_lifted_tensor_constant, get_lifted_tensor_constant),
                (is_param, get_param),
                (is_buffer, get_buffer),
            )
            for is_kind, get_kind in lookups:
                if is_kind(exported_program, index):
                    index = get_kind(exported_program, index)
                    break
            else:
                return None
        return index if isinstance(index, torch.Tensor) else None

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if node.op != "call_function" or node.target not in ops.aten.index_select:
                continue

            args = IndexSelectArgs(*node.args, **node.kwargs)
            input_node = args.input
            dim = args.dim

            const_index = self._as_const_tensor(exported_program, args.index)
            if const_index is None:
                continue
            # Nothing to segment when at most one index is selected.
            if is_single_value_tensor(const_index) or len(const_index) < 2:
                continue

            input_shape = extract_shape(input_node)
            if dim < 0:
                dim = dim % len(input_shape)

            pieces = []
            for scalar_index in const_index:
                # Each single index becomes its own constant placeholder.
                index_placeholder = add_placeholder(
                    exported_program, torch.tensor([scalar_index]), prefix="segm_index"
                )
                with graph.inserting_before(node):
                    pieces.append(
                        graph.call_function(
                            torch.ops.aten.index_select.default,
                            args=(input_node, dim, index_placeholder),
                        )
                    )

            with graph.inserting_before(node):
                concat_node = graph.call_function(
                    torch.ops.aten.cat.default, args=(pieces, dim)
                )

            node.replace_all_uses_with(concat_node, propagate_meta=False)

            modified = True
            logger.debug(
                f"{node.name} is replaced with {concat_node.name} and {pieces}"
            )

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
tico/pt2_to_circle.py ADDED
@@ -0,0 +1,101 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
import argparse
import os
from typing import Optional

import torch
import yaml

from tico.config import CompileConfigBase, get_default_config
from tico.utils.convert import convert_exported_module_to_circle
24
+
25
+
26
def convert(
    input: str,
    output: str,
    verbose: bool = False,
    config: Optional[CompileConfigBase] = None,
):
    """
    Convert a serialized `torch.export` program (.pt2) into a Circle model file.

    Args:
        input: Path to the .pt2 file, loaded with `torch.export.load`.
        output: Path where the converted Circle binary is written.
        verbose: When True, sets the TICO_LOG environment variable to "4"
            before converting (presumably the most verbose tico log level).
        config: Compile configuration. When omitted, a fresh
            `get_default_config()` is built for this call.
    """
    # TODO Check input and output

    if config is None:
        # Build the default per call. A default evaluated in the signature
        # would be created once at import time and shared by every call.
        config = get_default_config()

    if verbose:
        os.environ["TICO_LOG"] = "4"

    exported_program = torch.export.load(input)
    circle_binary = convert_exported_module_to_circle(exported_program, config=config)
    with open(output, "wb") as f:
        f.write(circle_binary)
42
+
43
+
44
def _load_config(path: str) -> CompileConfigBase:
    """Build a compile config from the YAML file at *path*.

    Raises:
        ValueError: if the file is not a YAML mapping, lacks a 'version'
            field, or names an unsupported version.
    """
    latest_version = "1.0"

    with open(path) as f:
        # safe_load never constructs arbitrary Python objects from the file.
        config_dict = yaml.safe_load(f)

    # An empty file yields None and a scalar/list file yields a non-dict;
    # neither can hold a 'version' field.
    if not isinstance(config_dict, dict):
        raise ValueError(
            f"Config file '{path}' must be a YAML mapping with a 'version' field. (latest: {latest_version})"
        )

    version = config_dict.get("version", None)
    if version is None:
        raise ValueError(
            f"'version' field must be provided in the config file. (latest: {latest_version})"
        )

    # YAML parses an unquoted `version: 1.0` as the float 1.0; normalize to
    # text so it matches the string version tags below.
    version = str(version)

    if version == "1.0":
        from tico.config.v1 import CompileConfigV1

        return CompileConfigV1.from_dict(config_dict)

    raise ValueError(f"Unsupported version '{version}'. (latest: {latest_version})")


def main():
    """Command-line entry point: parse arguments, resolve the config, and convert."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--input",
        required=True,
        help="provide a path to .pt2 model.",
    )

    parser.add_argument(
        "-o",
        "--output",
        required=True,
        help="provide a path to .circle model.",
    )

    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="print logs.",
    )

    parser.add_argument(
        "-c",
        "--config",
        required=False,
        help="provide a path to config file.",
    )

    args = parser.parse_args()

    # Without -c, fall back to the default configuration. The previous code
    # left `config` unbound here and crashed with UnboundLocalError.
    if args.config:
        config = _load_config(args.config)
    else:
        config = get_default_config()

    convert(args.input, args.output, args.verbose, config)


if __name__ == "__main__":
    main()
@@ -0,0 +1 @@
1
+ # DO NOT REMOVE THIS FILE