tico-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
tico/passes/const_prop_pass.py
@@ -0,0 +1,307 @@
+ # Portions of this file are adapted from code originally authored by
+ # Meta Platforms, Inc. and affiliates, licensed under the BSD-style
+ # license found in the LICENSE file in the root directory of their source tree.
+
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # https://github.com/pytorch/executorch/blob/61ddee5/exir/passes/constant_prop_pass.py
+
+ from collections import OrderedDict
+ from typing import List, Mapping, Optional, TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+ from torch._export.utils import (
+     get_buffer,
+     get_lifted_tensor_constant,
+     get_param,
+     is_buffer,
+     is_lifted_tensor_constant,
+     is_param,
+ )
+ from torch.export import ExportedProgram
+ from torch.export.exported_program import InputKind, InputSpec
+ from torch.utils import _pytree as pytree
+
+ from tico.serialize.circle_graph import _PRIMITIVE_TYPES
+ from tico.utils import logging
+ from tico.utils.graph import create_input_spec, generate_fqn, get_first_user_input
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import (
+     trace_const_diff_on_pass,
+     trace_graph_diff_on_pass,
+ )
+ from tico.utils.utils import get_fake_mode
+
+
+ def get_constant_placeholder_to_tensor_dict(
+     exported_program: ExportedProgram,
+ ) -> OrderedDict[torch.fx.Node, torch.Tensor]:
+     """
+     Returns a dictionary mapping each constant placeholder node to its constant tensor.
+     """
+     const_node_to_tensor: OrderedDict[torch.fx.Node, torch.Tensor] = OrderedDict()
+     graph_module = exported_program.graph_module
+     graph: torch.fx.Graph = graph_module.graph
+     for node in graph.nodes:
+         if node.op != "placeholder":
+             continue
+         tensor: Optional[torch.Tensor] = None
+         if is_param(exported_program, node):
+             tensor = get_param(exported_program, node)
+         elif is_buffer(exported_program, node):
+             tensor = get_buffer(exported_program, node)
+         elif is_lifted_tensor_constant(exported_program, node):
+             tensor = get_lifted_tensor_constant(exported_program, node)
+
+         if tensor is not None:
+             assert node not in const_node_to_tensor
+             const_node_to_tensor[node] = tensor
+
+     return const_node_to_tensor
+
+
+ def has_constant_data(arg, const_node_to_tensor=None) -> bool:
+     """
+     Check whether `arg` has constant data.
+
+     Assume that `const_node_to_tensor` is retrieved from the exported program.
+     When a node is a placeholder, the only way to check whether it is constant
+     is to consult the exported program.
+     """
+     if isinstance(arg, (tuple, list)):
+         return all(has_constant_data(a, const_node_to_tensor) for a in arg)
+     elif isinstance(arg, dict):
+         return all(has_constant_data(a, const_node_to_tensor) for a in arg.values())
+     elif isinstance(
+         arg,
+         _PRIMITIVE_TYPES,
+     ):
+         return True
+     elif not isinstance(arg, torch.fx.Node):
+         return False
+     elif const_node_to_tensor is not None and arg in const_node_to_tensor:
+         return True
+
+     return False
+
+
+ def get_data(
+     arg,
+     exported_program: ExportedProgram,
+     const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
+ ):
+     if isinstance(arg, (tuple, list)):
+         return (get_data(x, exported_program, const_node_to_tensor) for x in arg)
+     elif isinstance(arg, _PRIMITIVE_TYPES):
+         return arg
+     elif arg in const_node_to_tensor:
+         return const_node_to_tensor[arg]
+     return None
+
+
+ def propagate_constants(
+     exported_program: ExportedProgram,
+ ) -> OrderedDict[torch.fx.Node, torch.Tensor]:
+     """
+     Propagates constants and returns a dictionary mapping graph nodes to their constant tensors.
+     """
+     const_node_to_tensor = get_constant_placeholder_to_tensor_dict(exported_program)
+
+     graph_module = exported_program.graph_module
+     graph: torch.fx.Graph = graph_module.graph
+     for node in graph.nodes:
+         if node.op != "call_function":
+             continue
+         if node.target in [
+             torch.ops.quantized_decomposed.dequantize_per_channel.default,
+             torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+         ]:
+             continue
+         if not has_constant_data(
+             [node.args, node.kwargs],
+             const_node_to_tensor,
+         ):
+             continue
+
+         args_data, kwargs_data = pytree.tree_map(
+             lambda x: get_data(x, exported_program, const_node_to_tensor),
+             (node.args, node.kwargs),
+         )
+
+         # Propagate the constant because all of its args are constant tensors.
+         with torch.no_grad():
+             prop_constant_tensor = node.target(*args_data, **kwargs_data)
+         const_node_to_tensor[node] = prop_constant_tensor
+
+     return const_node_to_tensor
+
+
+ def erase_constant_node(
+     exported_program: ExportedProgram,
+     node: torch.fx.Node,
+ ) -> None:
+     """
+     Remove the corresponding tensor from the param/constants dict.
+
+     Q) Isn't it necessary to remove a node from `inputs_to_parameters`, `inputs_to_lifted_tensor_constants`
+     and `inputs_to_buffers` as well? Why do they just call `get`?
+     A) They internally use `exported_program.graph_signature.input_specs`, and the `input_specs` are updated
+     at the end of the const_prop_pass.
+     """
+     signature = exported_program.graph_signature
+     if name := signature.inputs_to_parameters.get(node.name, None):
+         exported_program.state_dict.pop(name, None)
+     elif name := signature.inputs_to_lifted_tensor_constants.get(node.name, None):
+         exported_program.constants.pop(name, None)
+     elif name := signature.inputs_to_buffers.get(node.name, None):
+         exported_program.constants.pop(name, None)
+         exported_program.state_dict.pop(name, None)
+
+     # Remove from graph.
+     exported_program.graph.erase_node(node)
+
+
+ def create_constant_placeholder(
+     const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
+     exported_program: ExportedProgram,
+ ) -> List[torch.fx.Node]:
+     """
+     Creates constant placeholder nodes for the given constant nodes
+     (`const_node_to_tensor`) and replaces the original nodes with them.
+     """
+     placeholders = []
+
+     fake_mode = get_fake_mode(exported_program)
+     first_user_input = get_first_user_input(exported_program)
+     if not first_user_input:
+         # Placeholder nodes must be the first N nodes in the nodes list of a graph.
+         # Therefore, insert the newly created placeholders at the start of the node list.
+         assert exported_program.graph.nodes
+         first_node = list(exported_program.graph.nodes)[0]
+         first_user_input = first_node
+
+     # Iterate over nodes in reverse order to insert each created placeholder before `first_user_input`.
+     for node, prop_constant_tensor in reversed(const_node_to_tensor.items()):
+         if all(x in const_node_to_tensor for x in node.users):
+             # All users of this constant node are also constant, so we don't need to create a new constant node.
+             erase_constant_node(exported_program, node)
+             continue
+
+         if node.op == "placeholder":
+             continue
+
+         # Add `prop_constant_tensor` to program.state_dict.
+         prop_constant_tensor_fqn = generate_fqn(
+             "_prop_tensor_constant", exported_program
+         )
+
+         # Insert a new placeholder node for the propagated constant tensor.
+         with exported_program.graph.inserting_before(first_user_input):
+             const_placeholder_node = exported_program.graph.placeholder(
+                 prop_constant_tensor_fqn
+             )
+
+         # The key here must match the "target" arg of InputSpec when creating input specs.
+         exported_program.constants[prop_constant_tensor_fqn] = prop_constant_tensor
+
+         # Replace the original node with the new constant node.
+         node.replace_all_uses_with(const_placeholder_node, propagate_meta=True)
+         exported_program.graph.erase_node(node)
+
+         # Update the metadata of the new placeholder node.
+         const_placeholder_node.meta["val"] = fake_mode.from_tensor(
+             prop_constant_tensor, static_shapes=True
+         )
+         const_placeholder_node.meta["val"].constant = prop_constant_tensor
+
+         placeholders.append(const_placeholder_node)
+
+     return placeholders
+
+
+ def create_input_specs(
+     placeholders: List[torch.fx.Node],
+ ) -> dict[str, InputSpec]:
+     name_to_spec: dict[str, InputSpec] = {}
+
+     # https://pytorch.org/docs/stable/export.ir_spec.html#placeholder
+     # %name = placeholder[target = name](args = ())
+     for node in placeholders:
+         name_to_spec[node.name] = create_input_spec(node, InputKind.CONSTANT_TENSOR)
+
+     return name_to_spec
+
+
+ @trace_graph_diff_on_pass
+ @trace_const_diff_on_pass
+ class ConstPropPass(PassBase):
+     """
+     Performs constant folding and constant propagation.
+
+     NOTE The exported program guarantees that parameters, buffers, and constant tensors are lifted out of the graph as inputs.
+     This means the pass needs to update the input specs after folding the constant nodes.
+     # ref: https://pytorch.org/docs/stable/export.html#torch.export.ExportGraphSignature
+
+     [WHAT IT DOES]
+     [1] Propagate the constants.
+     [2] Get propagated data from constant nodes.
+     [3] Create the constant placeholder nodes according to the propagated data.
+     [4] Create input specs according to the created placeholders.
+     [5] Update the input specs.
+     """
+
+     def __init__(self) -> None:
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         graph_module = exported_program.graph_module
+         graph: torch.fx.Graph = graph_module.graph
+
+         # [1], [2]
+         const_node_to_tensor: OrderedDict[
+             torch.fx.Node, torch.Tensor
+         ] = propagate_constants(exported_program)
+         # [3]
+         placeholders = create_constant_placeholder(
+             const_node_to_tensor, exported_program
+         )
+         # [4]
+         new_name_to_spec = create_input_specs(placeholders)
+
+         # [5]
+         # Get existing input specs.
+         existing_name_to_spec = {
+             s.arg.name: s for s in exported_program.graph_signature.input_specs
+         }
+         # Add the new constants to the existing input specs dict.
+         existing_name_to_spec.update(new_name_to_spec)
+         # Generate new input specs.
+         new_input_specs = []
+         for node in exported_program.graph.nodes:
+             if node.op != "placeholder":
+                 continue
+             assert node.name in existing_name_to_spec, node.name
+             new_input_specs.append(existing_name_to_spec[node.name])
+         exported_program.graph_signature.input_specs = new_input_specs
+
+         graph.eliminate_dead_code()
+         graph_module.recompile()
+
+         logger.debug("Constant nodes are propagated")
+         # Constant folding completes in a single run, so set `modified` to False.
+         modified = False
+         return PassResult(modified)
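
For orientation, the pass above precomputes any `call_function` node whose inputs are all constants (lifted parameters, buffers, or tensor constants) and re-lifts the result as a new constant placeholder. A minimal usage sketch, assuming the pass is invoked directly through its `call` method on an exported program (the model and shapes below are illustrative, not from the package):

    import torch
    from tico.passes.const_prop_pass import ConstPropPass

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.w = torch.nn.Parameter(torch.ones(4, 4))

        def forward(self, x):
            # `self.w + 1` depends only on a lifted parameter, so the pass
            # can fold it into a single propagated constant placeholder.
            return x @ (self.w + 1)

    ep = torch.export.export(Model(), (torch.randn(2, 4),))
    ConstPropPass().call(ep)  # `w + 1` becomes a `_prop_tensor_constant*` input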
tico/passes/convert_conv1d_to_conv2d.py
@@ -0,0 +1,160 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.serialize.circle_graph import extract_shape
+ from tico.utils import logging
+ from tico.utils.errors import NotYetSupportedError
+ from tico.utils.graph import create_node
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.utils import is_target_node
+ from tico.utils.validate_args_kwargs import Conv1DArgs
+
+
+ @trace_graph_diff_on_pass
+ class ConvertConv1dToConv2d(PassBase):
+     """
+     This pass converts `torch.ops.aten.conv1d` to `torch.ops.aten.conv2d`
+     because Circle does not support `conv1d`.
+
+     [before]
+
+           input            weight
+       (tensor,dim=3)   (tensor,dim=3)
+             |                |
+           conv1d<------------+
+             |
+           output
+       (tensor,dim=3)
+
+     [after]
+
+           input            weight
+       (tensor,dim=3)   (tensor,dim=3)
+             |                |
+         unsqueeze        unsqueeze
+          (dim=4)          (dim=4)
+             |                |
+           conv2d<------------+
+             |
+          squeeze
+          (dim=3)
+             |
+           output
+       (tensor,dim=3)
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def convert(self, exported_program: ExportedProgram, node: torch.fx.Node) -> bool:
+         logger = logging.getLogger(__name__)
+         modified = False
+
+         graph_module = exported_program.graph_module
+         graph = graph_module.graph
+
+         # conv1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] dilation=1, SymInt groups=1) -> Tensor
+         # conv1d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, str padding="valid", SymInt[1] dilation=1, SymInt groups=1) -> Tensor
+         args = Conv1DArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+         input = args.input
+         weight = args.weight
+         bias = args.bias
+         stride = args.stride
+         padding = args.padding
+         dilation = args.dilation
+         groups = args.groups
+
+         input_shape = extract_shape(input)
+         if len(input_shape) != 3:
+             raise NotYetSupportedError(
+                 f"Only 3D input tensors are supported; node's input shape: {input_shape}"
+             )
+
+         with graph.inserting_after(input):
+             input_unsqueeze = create_node(
+                 graph,
+                 torch.ops.aten.unsqueeze.default,
+                 args=(input, 3),
+                 origin=input,
+             )
+
+         with graph.inserting_after(weight):
+             weight_unsqueeze = create_node(
+                 graph,
+                 torch.ops.aten.unsqueeze.default,
+                 args=(weight, 3),
+                 origin=weight,
+             )
+
+         with graph.inserting_before(node):
+             if isinstance(padding, list):
+                 conv2d_op = torch.ops.aten.conv2d.default
+             elif isinstance(padding, str):
+                 conv2d_op = torch.ops.aten.conv2d.padding
+             else:
+                 raise RuntimeError("Invalid padding type")
+
+             conv2d = create_node(
+                 graph,
+                 conv2d_op,
+                 args=(
+                     input_unsqueeze,
+                     weight_unsqueeze,
+                     bias,
+                     [*stride, 1],
+                     [*padding, 0] if isinstance(padding, list) else padding,
+                     [*dilation, 1],
+                     groups,
+                 ),
+                 kwargs=node.kwargs,
+                 origin=node,
+             )
+
+             conv_out_squeeze = create_node(
+                 graph,
+                 torch.ops.aten.squeeze.dims,
+                 args=(conv2d, [3]),
+             )
+
+         node.replace_all_uses_with(conv_out_squeeze, propagate_meta=True)
+
+         logger.debug(f"{node.name} is replaced with {conv2d.name}")
+         modified = True
+         return modified
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         target_conv_op = [torch.ops.aten.conv1d.default, torch.ops.aten.conv1d.padding]
+
+         graph_module = exported_program.graph_module
+         graph = graph_module.graph
+         modified = False
+         for node in graph.nodes:
+             if not is_target_node(node, target_conv_op):
+                 continue
+
+             modified |= self.convert(exported_program, node)
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         return PassResult(modified)
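
The rewrite's numerical equivalence can be checked in eager PyTorch: append a trailing unit dimension to the input and weight, extend stride/padding with the extra entries the pass appends, then squeeze the trailing dimension away. A sanity-check sketch (illustrative, not part of the package):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 8, 32)  # (N, C_in, L)
    w = torch.randn(16, 8, 3)  # (C_out, C_in, K)

    y_1d = F.conv1d(x, w, stride=2, padding=1)

    # Mirror the pass: unsqueeze input/weight at dim 3, append the extra
    # stride/padding entries, then squeeze dim 3 away.
    y_2d = F.conv2d(
        x.unsqueeze(3), w.unsqueeze(3), stride=(2, 1), padding=(1, 0)
    ).squeeze(3)

    assert torch.allclose(y_1d, y_2d, atol=1e-5)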
tico/passes/convert_layout_op_to_reshape.py
@@ -0,0 +1,85 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.passes import ops
+ from tico.serialize.circle_mapping import extract_shape
+ from tico.utils import logging
+ from tico.utils.graph import create_node
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.validate_args_kwargs import SqueezeArgs, UnSqueezeArgs, ViewArgs
+
+
+ @trace_graph_diff_on_pass
+ class ConvertLayoutOpToReshape(PassBase):
+     """
+     This pass converts layout-transformation ops to reshape where possible,
+     which enables further optimization.
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         graph_module = exported_program.graph_module
+         graph = graph_module.graph
+         modified = False
+
+         def convert(node, input):
+             out_shape = list(extract_shape(node))
+
+             with graph.inserting_after(node):
+                 reshape_node = create_node(
+                     graph,
+                     torch.ops.aten.reshape.default,
+                     args=(input, out_shape),
+                 )
+                 node.replace_all_uses_with(reshape_node, propagate_meta=True)
+
+             logger.debug(f"{node.name} is replaced with {reshape_node.name}")
+
+         for node in graph.nodes:
+             if node.op != "call_function":
+                 continue
+
+             if node.target in ops.aten.view:
+                 view_args = ViewArgs(*node.args, **node.kwargs)
+                 convert(node, view_args.input)
+                 modified = True
+                 continue
+             elif node.target in ops.aten.unsqueeze:
+                 unsqueeze_args = UnSqueezeArgs(*node.args, **node.kwargs)
+                 convert(node, unsqueeze_args.input)
+                 modified = True
+                 continue
+             elif node.target in ops.aten.squeeze:
+                 squeeze_args = SqueezeArgs(*node.args, **node.kwargs)
+                 convert(node, squeeze_args.input)
+                 modified = True
+                 continue
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         return PassResult(modified)
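
The underlying observation is that `view`, `squeeze`, and `unsqueeze` never move data; they only reinterpret the shape, so each is equivalent to a reshape to the op's output shape. An illustrative check (not part of the package):

    import torch

    x = torch.randn(2, 3, 4)
    # Each layout op is just a reshape to its own output shape.
    assert torch.equal(x.unsqueeze(1), x.reshape(2, 1, 3, 4))
    assert torch.equal(x.view(6, 4), x.reshape(6, 4))

    y = torch.randn(2, 1, 4)
    assert torch.equal(y.squeeze(1), y.reshape(2, 4))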
tico/passes/convert_repeat_to_expand_copy.py
@@ -0,0 +1,89 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import List, TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.utils import logging
+ from tico.utils.graph import create_node
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.utils import is_target_node
+ from tico.utils.validate_args_kwargs import RepeatArgs
+
+
+ @trace_graph_diff_on_pass
+ class ConvertRepeatToExpandCopy(PassBase):
+     """
+     Converts aten.repeat.default to aten.expand_copy.default.
+     Why? There is no CircleNode mapped to repeat,
+     so we convert it using the existing aten.expand_copy.default.
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         graph_module = exported_program.graph_module
+         graph = graph_module.graph
+         modified = False
+         for node in graph.nodes:
+             if not is_target_node(node, torch.ops.aten.repeat.default):
+                 continue
+
+             repeat_args = RepeatArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+             tensor, repeats = repeat_args.input, repeat_args.repeats
+
+             tensor_shape: List[int] = [int(dim) for dim in tensor.meta["val"].shape]
+
+             # Check whether conversion to aten.expand_copy.default is possible.
+             cannot_convert = False
+             extending_idx = len(repeats) - len(tensor_shape)
+             for idx, dim in enumerate(tensor_shape):
+                 if not (dim == 1 or repeats[extending_idx + idx] == 1):
+                     cannot_convert = True
+             if cannot_convert:
+                 continue
+
+             size = []
+             for idx, repeats_dim in enumerate(repeats):
+                 if idx < extending_idx:
+                     size.append(repeats_dim)
+                 else:
+                     size.append(repeats_dim * tensor_shape[idx - extending_idx])
+
+             expand_copy_args = (tensor, size)
+
+             with graph.inserting_after(node):
+                 expand_copy_node = create_node(
+                     graph,
+                     torch.ops.aten.expand_copy.default,
+                     args=expand_copy_args,
+                 )
+                 node.replace_all_uses_with(expand_copy_node, propagate_meta=True)
+
+             modified = True
+             logger.debug(f"{node.name} is replaced with expand_copy operator")
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         return PassResult(modified)
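
The convertibility check mirrors broadcasting semantics: `expand` can only stretch size-1 dimensions, so `repeat` is rewritable only when every tiled source dimension has size 1, while repeat counts beyond the source rank simply become new leading dimensions. An illustrative check (not part of the package):

    import torch

    t = torch.randn(1, 3)
    # Convertible: extending_idx = 3 - 2 = 1, so the pass computes
    # size = [2, 4 * 1, 1 * 3] = [2, 4, 3].
    assert torch.equal(t.repeat(2, 4, 1), t.expand(2, 4, 3))

    # Not convertible: tiling a dim of size > 1 duplicates data, which
    # broadcasting cannot express, so the pass skips such nodes.
    u = torch.randn(2, 3)
    assert u.repeat(2, 1).shape == (4, 3)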