tico 0.1.0.dev250411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. tico/__init__.py +31 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
  55. tico/experimental/quantization/public_interface.py +108 -0
  56. tico/experimental/quantization/quantizer.py +71 -0
  57. tico/interpreter/__init__.py +1 -0
  58. tico/interpreter/infer.py +116 -0
  59. tico/interpreter/interpreter.py +93 -0
  60. tico/passes/__init__.py +1 -0
  61. tico/passes/cast_aten_where_arg_type.py +185 -0
  62. tico/passes/cast_mixed_type_args.py +186 -0
  63. tico/passes/const_prop_pass.py +307 -0
  64. tico/passes/convert_conv1d_to_conv2d.py +151 -0
  65. tico/passes/convert_layout_op_to_reshape.py +84 -0
  66. tico/passes/convert_repeat_to_expand_copy.py +90 -0
  67. tico/passes/convert_to_relu6.py +180 -0
  68. tico/passes/decompose_addmm.py +127 -0
  69. tico/passes/decompose_batch_norm.py +198 -0
  70. tico/passes/decompose_fake_quantize.py +126 -0
  71. tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
  72. tico/passes/decompose_group_norm.py +258 -0
  73. tico/passes/decompose_grouped_conv2d.py +202 -0
  74. tico/passes/decompose_slice_scatter.py +167 -0
  75. tico/passes/extract_dtype_kwargs.py +121 -0
  76. tico/passes/fill_meta_val.py +57 -0
  77. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  78. tico/passes/legalize_causal_mask_value.py +113 -0
  79. tico/passes/legalize_predefined_layout_operators.py +383 -0
  80. tico/passes/lower_pow2_to_mul.py +75 -0
  81. tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
  82. tico/passes/lower_to_slice.py +112 -0
  83. tico/passes/merge_consecutive_cat.py +82 -0
  84. tico/passes/ops.py +75 -0
  85. tico/passes/remove_nop.py +85 -0
  86. tico/passes/remove_redundant_assert_nodes.py +50 -0
  87. tico/passes/remove_redundant_expand.py +70 -0
  88. tico/passes/remove_redundant_permute.py +102 -0
  89. tico/passes/remove_redundant_reshape.py +431 -0
  90. tico/passes/remove_redundant_slice.py +64 -0
  91. tico/passes/remove_redundant_to_copy.py +84 -0
  92. tico/passes/restore_linear.py +113 -0
  93. tico/passes/segment_index_select.py +143 -0
  94. tico/pt2_to_circle.py +101 -0
  95. tico/serialize/__init__.py +1 -0
  96. tico/serialize/circle_graph.py +264 -0
  97. tico/serialize/circle_mapping.py +177 -0
  98. tico/serialize/circle_serializer.py +232 -0
  99. tico/serialize/operators/__init__.py +28 -0
  100. tico/serialize/operators/hashable_opcode.py +43 -0
  101. tico/serialize/operators/node_visitor.py +80 -0
  102. tico/serialize/operators/op_add.py +69 -0
  103. tico/serialize/operators/op_alias_copy.py +64 -0
  104. tico/serialize/operators/op_any.py +142 -0
  105. tico/serialize/operators/op_arange_start_step.py +61 -0
  106. tico/serialize/operators/op_argmax.py +62 -0
  107. tico/serialize/operators/op_avg_pool2d.py +112 -0
  108. tico/serialize/operators/op_bmm.py +62 -0
  109. tico/serialize/operators/op_cat.py +66 -0
  110. tico/serialize/operators/op_clamp.py +123 -0
  111. tico/serialize/operators/op_clone.py +71 -0
  112. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  113. tico/serialize/operators/op_conv2d.py +181 -0
  114. tico/serialize/operators/op_copy.py +162 -0
  115. tico/serialize/operators/op_cos.py +59 -0
  116. tico/serialize/operators/op_cumsum.py +92 -0
  117. tico/serialize/operators/op_depthwise_conv2d.py +198 -0
  118. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  119. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  120. tico/serialize/operators/op_div.py +62 -0
  121. tico/serialize/operators/op_embedding.py +60 -0
  122. tico/serialize/operators/op_eq.py +64 -0
  123. tico/serialize/operators/op_exp.py +60 -0
  124. tico/serialize/operators/op_expand.py +91 -0
  125. tico/serialize/operators/op_full.py +48 -0
  126. tico/serialize/operators/op_full_like.py +55 -0
  127. tico/serialize/operators/op_ge.py +54 -0
  128. tico/serialize/operators/op_gelu.py +59 -0
  129. tico/serialize/operators/op_gt.py +54 -0
  130. tico/serialize/operators/op_index.py +82 -0
  131. tico/serialize/operators/op_index_select.py +64 -0
  132. tico/serialize/operators/op_instance_norm.py +91 -0
  133. tico/serialize/operators/op_linear.py +70 -0
  134. tico/serialize/operators/op_log.py +53 -0
  135. tico/serialize/operators/op_log1p.py +83 -0
  136. tico/serialize/operators/op_logical_and.py +63 -0
  137. tico/serialize/operators/op_logical_not.py +62 -0
  138. tico/serialize/operators/op_lt.py +61 -0
  139. tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
  140. tico/serialize/operators/op_maximum.py +53 -0
  141. tico/serialize/operators/op_mean.py +66 -0
  142. tico/serialize/operators/op_minimum.py +53 -0
  143. tico/serialize/operators/op_mm.py +174 -0
  144. tico/serialize/operators/op_mul.py +99 -0
  145. tico/serialize/operators/op_ne.py +54 -0
  146. tico/serialize/operators/op_neg.py +59 -0
  147. tico/serialize/operators/op_permute.py +65 -0
  148. tico/serialize/operators/op_pow.py +138 -0
  149. tico/serialize/operators/op_prelu.py +54 -0
  150. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  151. tico/serialize/operators/op_reciprocal.py +64 -0
  152. tico/serialize/operators/op_relu.py +53 -0
  153. tico/serialize/operators/op_relu6.py +52 -0
  154. tico/serialize/operators/op_repeat.py +99 -0
  155. tico/serialize/operators/op_reshape.py +73 -0
  156. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  157. tico/serialize/operators/op_rsqrt.py +53 -0
  158. tico/serialize/operators/op_scalar_tensor.py +51 -0
  159. tico/serialize/operators/op_select_copy.py +65 -0
  160. tico/serialize/operators/op_sigmoid.py +56 -0
  161. tico/serialize/operators/op_sin.py +53 -0
  162. tico/serialize/operators/op_slice.py +155 -0
  163. tico/serialize/operators/op_softmax.py +100 -0
  164. tico/serialize/operators/op_split_with_sizes.py +96 -0
  165. tico/serialize/operators/op_sqrt.py +55 -0
  166. tico/serialize/operators/op_squeeze.py +73 -0
  167. tico/serialize/operators/op_sub.py +71 -0
  168. tico/serialize/operators/op_sum.py +63 -0
  169. tico/serialize/operators/op_tanh.py +54 -0
  170. tico/serialize/operators/op_to_copy.py +105 -0
  171. tico/serialize/operators/op_unsqueeze.py +66 -0
  172. tico/serialize/operators/op_view.py +74 -0
  173. tico/serialize/operators/op_where.py +82 -0
  174. tico/serialize/operators/utils.py +51 -0
  175. tico/serialize/pack.py +35 -0
  176. tico/serialize/quant_param.py +42 -0
  177. tico/utils/__init__.py +1 -0
  178. tico/utils/convert.py +292 -0
  179. tico/utils/define.py +35 -0
  180. tico/utils/diff_graph.py +181 -0
  181. tico/utils/errors.py +35 -0
  182. tico/utils/graph.py +200 -0
  183. tico/utils/logging.py +45 -0
  184. tico/utils/model.py +37 -0
  185. tico/utils/padding.py +47 -0
  186. tico/utils/passes.py +76 -0
  187. tico/utils/register_custom_op.py +562 -0
  188. tico/utils/trace_decorators.py +101 -0
  189. tico/utils/utils.py +314 -0
  190. tico/utils/validate_args_kwargs.py +1114 -0
  191. tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
  192. tico-0.1.0.dev250411.dist-info/METADATA +17 -0
  193. tico-0.1.0.dev250411.dist-info/RECORD +196 -0
  194. tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
  195. tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
  196. tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
tico/utils/graph.py ADDED
@@ -0,0 +1,200 @@
1
+ # Portions of this file are adapted from code originally authored by
2
+ # Meta Platforms, Inc. and affiliates, licensed under the BSD-style
3
+ # license found in the LICENSE file in the root directory of their source tree.
4
+
5
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ from typing import Optional, TYPE_CHECKING
20
+
21
+ if TYPE_CHECKING:
22
+ import torch.fx
23
+ import torch
24
+ from torch.export import ExportedProgram
25
+ from torch.export.exported_program import InputKind, InputSpec, TensorArgument
26
+
27
+ from tico.utils.utils import get_fake_mode
28
+
29
+
30
+ def is_torch_param(node: torch.fx.Node, ep: ExportedProgram):
31
+ assert node.op == "placeholder"
32
+
33
+ return node.name in ep.graph_signature.inputs_to_parameters
34
+
35
+
36
+ def is_torch_buffer(node: torch.fx.Node, ep: ExportedProgram):
37
+ assert node.op == "placeholder"
38
+
39
+ return node.name in ep.graph_signature.inputs_to_buffers
40
+
41
+
42
+ def get_torch_param_value(node: torch.fx.Node, ep: ExportedProgram):
43
+ assert isinstance(node, torch.fx.Node)
44
+ assert node.op == "placeholder"
45
+ assert (
46
+ node.name in ep.graph_signature.inputs_to_parameters
47
+ ), "Node {node.name} is not in the parameters" # FIX CALLER UNLESS
48
+
49
+ param_name = ep.graph_signature.inputs_to_parameters[node.name]
50
+ named_params = dict(ep.named_parameters())
51
+ assert param_name in named_params
52
+
53
+ return named_params[param_name].data
54
+
55
+
56
+ def get_torch_buffer_value(node: torch.fx.Node, ep: ExportedProgram):
57
+ assert isinstance(node, torch.fx.Node)
58
+ assert node.op == "placeholder"
59
+ assert (
60
+ node.name in ep.graph_signature.inputs_to_buffers
61
+ ), "Node {node.name} is not in the buffers" # FIX CALLER UNLESS
62
+
63
+ buf_name = ep.graph_signature.inputs_to_buffers[node.name]
64
+ named_buf = dict(ep.named_buffers())
65
+ assert buf_name in named_buf
66
+
67
+ return named_buf[buf_name]
68
+
69
+
70
+ def get_first_user_input(exported_program: ExportedProgram) -> Optional[torch.fx.Node]:
71
+ """Returns the first user input node in the graph."""
72
+ first_user_input: Optional[torch.fx.Node] = None
73
+ graph_module = exported_program.graph_module
74
+ graph: torch.fx.Graph = graph_module.graph
75
+ for node in graph.nodes:
76
+ if (
77
+ node.op == "placeholder"
78
+ and node.name in exported_program.graph_signature.user_inputs
79
+ ):
80
+ first_user_input = node
81
+ break
82
+
83
+ return first_user_input
84
+
85
+
86
+ def generate_fqn(prefix: str, exported_program: ExportedProgram):
87
+ """
88
+ Generate fully-qualized name for constants.
89
+
90
+ This function prevents `exported_program.constants` from having duplicate keys.
91
+ """
92
+ cnt = len(exported_program.constants)
93
+ while True:
94
+ if f"{prefix}{cnt}" in exported_program.constants:
95
+ cnt += 1
96
+ continue
97
+ break
98
+ return f"{prefix}{cnt}"
99
+
100
+
101
+ def create_input_spec(node, input_kind: InputKind):
102
+ """
103
+ @ref https://pytorch.org/docs/stable/export.ir_spec.html#placeholder
104
+ """
105
+ if input_kind == InputKind.CONSTANT_TENSOR:
106
+ return InputSpec(
107
+ kind=InputKind.CONSTANT_TENSOR,
108
+ arg=TensorArgument(name=node.name),
109
+ target=node.target, # type: ignore[arg-type]
110
+ persistent=True,
111
+ )
112
+ else:
113
+ raise NotImplementedError("NYI")
114
+
115
+
116
def validate_input_specs(exported_program):
    """
    Check that every placeholder node has a matching input spec.

    A placeholder matches when its name equals `spec.arg.name` for some spec
    in `exported_program.graph_signature.input_specs`. Non-placeholder nodes
    are ignored.

    Raises:
        RuntimeError: naming the first placeholder without an input spec.
    """
    spec_names = {s.arg.name for s in exported_program.graph_signature.input_specs}

    for node in exported_program.graph.nodes:
        if node.op != "placeholder":
            continue

        if node.name not in spec_names:
            raise RuntimeError(
                f"Placeholder node {node.name} does not have corresponding input spec!"
            )
129
+
130
+
131
def add_placeholder(
    exported_program: ExportedProgram,
    tensor: torch.Tensor,
    prefix: str,
) -> torch.fx.Node:
    """
    Add a constant-tensor placeholder to the graph and update the exported program.

    The new placeholder:
      - is named via `generate_fqn(prefix, exported_program)`;
      - is inserted before the first user input, or before the first graph
        node when there is no user input;
      - is registered in `exported_program.constants`;
      - gets a CONSTANT_TENSOR input spec.

    The input specs are then rebuilt so that they follow the order of the
    placeholder nodes.

    Args:
        exported_program: The program to modify in place.
        tensor: The constant value backing the new placeholder.
        prefix: Name prefix for the new constant.

    Returns:
        The newly created placeholder node.
    """
    fqn_name = generate_fqn(prefix, exported_program)

    # Get fake mode before adding placeholder
    fake_mode = get_fake_mode(exported_program)

    insert_point = get_first_user_input(exported_program)
    if insert_point is None:
        # Placeholder nodes must be the first N nodes in the nodes list of a graph.
        # Therefore, insert the newly created placeholder at the start of the node list.
        assert exported_program.graph.nodes
        insert_point = next(iter(exported_program.graph.nodes))

    # Add a placeholder to the graph.
    with exported_program.graph.inserting_before(insert_point):
        const_node = exported_program.graph.placeholder(fqn_name)

    const_node.meta["val"] = fake_mode.from_tensor(tensor, static_shapes=True)
    const_node.meta["val"].constant = tensor

    # Add a new constant to the exported program.
    exported_program.constants[const_node.name] = tensor

    # Index the existing input specs by placeholder name.
    name_to_spec_dict = {
        s.arg.name: s for s in exported_program.graph_signature.input_specs
    }
    # The new node must not already own a spec; otherwise it would be silently replaced.
    assert const_node.name not in name_to_spec_dict
    name_to_spec_dict[const_node.name] = create_input_spec(
        const_node, InputKind.CONSTANT_TENSOR
    )

    # Regenerate the input specs *in the same order as the placeholder nodes*.
    # IMPORTANT Input specs and their placeholder nodes must have the same order.
    exported_program.graph_signature.input_specs = [
        name_to_spec_dict[node.name]
        for node in exported_program.graph.nodes
        if node.op == "placeholder"
    ]

    return const_node
192
+
193
+
194
def is_single_value_tensor(t: torch.Tensor):
    """
    Tell whether `t` has a scalar-like shape holding exactly one value.

    Only rank-0 tensors (shape `()`) and rank-1 tensors of shape `(1,)`
    qualify. For example, a `(1, 1)` tensor is rejected even though it has a
    single element.
    """
    shape = t.size()
    rank = len(shape)
    return rank == 0 or (rank == 1 and shape[0] == 1)
tico/utils/logging.py ADDED
@@ -0,0 +1,45 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ import os
17
+
18
+
19
+ def _loggerLevel():
20
+ TICO_LOG = os.environ.get("TICO_LOG")
21
+ if TICO_LOG == "1":
22
+ log_level = logging.FATAL
23
+ elif TICO_LOG == "2":
24
+ log_level = logging.WARNING
25
+ elif TICO_LOG == "3":
26
+ log_level = logging.INFO
27
+ elif TICO_LOG == "4":
28
+ log_level = logging.DEBUG
29
+ else:
30
+ log_level = logging.WARNING
31
+ return log_level
32
+
33
+
34
+ LOG_LEVEL = _loggerLevel()
35
+
36
+
37
def getLogger(name: str):
    """
    Return the logger called `name` with its level set to `LOG_LEVEL`.

    `LOG_LEVEL` is derived from the `TICO_LOG` environment variable.
    `logging.basicConfig()` is called on every invocation so that the root
    logger gets a default handler; per the stdlib docs it is a no-op once the
    root logger already has handlers.
    """
    logging.basicConfig()
    named_logger = logging.getLogger(name)
    named_logger.setLevel(LOG_LEVEL)

    return named_logger
tico/utils/model.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ from typing import Any
18
+
19
+ from tico.interpreter import infer
20
+
21
+
22
class CircleModel:
    """
    Thin wrapper around a serialized Circle model binary.

    Instances are callable: calling one forwards the stored binary and all
    call arguments to `tico.interpreter.infer.infer`.
    """

    def __init__(self, circle_binary: bytes):
        self.circle_binary = circle_binary

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run inference on the wrapped binary via `infer.infer`."""
        return infer.infer(self.circle_binary, *args, **kwargs)

    @classmethod
    def load(cls, circle_path: str) -> "CircleModel":
        """
        Read a Circle binary from `circle_path` and wrap it.

        This is a classmethod (not a staticmethod), so calling it on a
        subclass returns an instance of that subclass.
        """
        with open(circle_path, "rb") as f:
            # Binary-mode read() already returns `bytes`; no conversion needed.
            return cls(f.read())

    def save(self, circle_path: str) -> None:
        """Write the wrapped binary to `circle_path`, overwriting any existing file."""
        with open(circle_path, "wb") as f:
            f.write(self.circle_binary)
tico/utils/padding.py ADDED
@@ -0,0 +1,47 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+
17
+ from tico.utils.errors import InvalidArgumentError
18
+
19
# Integer padding codes. NOTE(review): presumably these mirror the Circle/TFLite
# `Padding` enum (SAME=0, VALID=1) — confirm against the schema.
SAME, VALID = 0, 1
21
+
22
+
23
+ def is_valid_padding(padding: str | list):
24
+ if isinstance(padding, str):
25
+ return padding == "valid"
26
+
27
+ if isinstance(padding, list):
28
+ assert len(padding) == 2, "Padding should be a list of length 2."
29
+ return padding == [0, 0]
30
+
31
+ raise InvalidArgumentError("Invalid padding.")
32
+
33
+
34
+ def is_same_padding(
35
+ padding: str | list, input_shape: list | torch.Size, output_shape: list | torch.Size
36
+ ):
37
+ if isinstance(padding, str):
38
+ return padding == "same"
39
+
40
+ if isinstance(padding, list):
41
+ assert len(padding) == 2, "Padding should be a list of length 2."
42
+
43
+ input_HW = input_shape[1:2] # N H W C
44
+ output_HW = output_shape[1:2] # N H W C
45
+ return input_HW == output_HW
46
+
47
+ raise InvalidArgumentError("Invalid padding.")
tico/utils/passes.py ADDED
@@ -0,0 +1,76 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABC, abstractmethod
16
+ from dataclasses import dataclass
17
+ from enum import Enum
18
+ from typing import List
19
+
20
+ from torch.export import ExportedProgram
21
+
22
+
23
@dataclass
class PassResult:
    """Outcome reported by a single pass invocation."""

    # True when the pass changed the program; PassManager keeps sweeping
    # while any pass reports True.
    modified: bool
26
+
27
+
28
class PassBase(ABC):
    """
    Abstract interface that every pass implements.

    Subclasses override `call`, which receives the ExportedProgram and
    reports through a `PassResult` whether it changed anything.
    NOTE(review): PassManager only reads `.modified` from the result, so
    passes presumably modify the program in place.
    """

    @abstractmethod
    def call(self, exported_program: ExportedProgram) -> PassResult:
        """Apply this pass to `exported_program` and report whether it changed."""
        ...
36
+
37
+
38
class PassStrategy(Enum):
    """
    How PassManager iterates over its passes.

    Values are plain ints; the previous trailing commas accidentally made
    them one-element tuples.
    """

    # Run every pass in order; repeat the whole list until no pass reports a change.
    UNTIL_NO_CHANGE = 1
    # Same as `UNTIL_NO_CHANGE`, but start again from the first pass as soon
    # as any pass reports a change.
    RESTART = 2
43
+
44
+
45
class PassManager:
    """
    Apply a list of passes to an ExportedProgram until it reaches a fixed point.

    Each sweep runs `passes` in order:
      - With `PassStrategy.RESTART` (the default), a sweep stops at the first
        pass that reports a change, and the next sweep starts again from the
        first pass.
      - With `PassStrategy.UNTIL_NO_CHANGE`, every pass runs in every sweep.

    Sweeps repeat until one completes with no pass reporting a change.
    """

    def __init__(
        self,
        passes: List[PassBase],
        strategy: PassStrategy = PassStrategy.RESTART,
    ):
        self.passes: List[PassBase] = passes
        self.strategy: PassStrategy = strategy

    def run(self, exported_program: ExportedProgram):
        """
        Run the passes on `exported_program` until nothing changes.

        Raises:
            AssertionError: after MAXIMUM_STEP_COUNT sweeps that each reported
                a change, which suggests passes undoing each other. It is
                raised explicitly (not via `assert`) so the guard still fires
                under `python -O`.
        """
        MAXIMUM_STEP_COUNT = 1000
        step = 0
        while True:
            modified = False
            for _pass in self.passes:
                # Automatically update the signatures of the input and output.
                # https://github.com/pytorch/executorch/issues/4013#issuecomment-2187161844
                with exported_program.graph_module._set_replace_hook(
                    exported_program.graph_signature.get_replace_hook()
                ):
                    result = _pass.call(exported_program)
                modified = modified or result.modified
                if modified and self.strategy == PassStrategy.RESTART:
                    break

            if not modified:
                break
            step += 1

            if step >= MAXIMUM_STEP_COUNT:
                raise AssertionError(
                    f"Loop iterated for {MAXIMUM_STEP_COUNT} times. Circular loop is suspected in {self.passes}"
                )