tico 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206) hide show
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,209 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+
22
+ from tico.passes import ops
23
+ from tico.serialize.circle_mapping import extract_shape
24
+ from tico.utils import logging
25
+ from tico.utils.errors import InvalidArgumentError, NotYetSupportedError
26
+ from tico.utils.graph import add_placeholder, create_node
27
+ from tico.utils.passes import PassBase, PassResult
28
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
29
+ from tico.utils.utils import is_target_node
30
+ from tico.utils.validate_args_kwargs import Conv2DArgs
31
+
32
+
33
@trace_graph_diff_on_pass
class DecomposeGroupedConv2d(PassBase):
    """
    Rewrite every grouped Conv2d as `groups` independent Conv2d operators (each
    with groups=1) whose results are concatenated along the channel axis.

    A Conv2d is considered "grouped" when its `groups` argument is neither 1 nor
    equal to the number of input channels. Those two cases are left untouched.

    [before]

          input      weight      bias
            |           |          |
            +-----------+----------+
                        |
        Conv2d (groups != IN_CHANNEL && groups != 1)
                        |
                     output

    [after]

    The input is sliced along axis 1, the weight and the bias along axis 0, once
    per group. The numbered tensors below are those per-group slices.

        input  --slice(axis=1)-->  input_1,  ..., input_N
        weight --slice(axis=0)-->  weight_1, ..., weight_N
        bias   --slice(axis=0)-->  bias_1,   ..., bias_N

      input_1 weight_1 bias_1            input_N weight_N bias_N
         |       |       |                  |       |       |
         +-------+-------+       ...        +-------+-------+
                 |                                  |
             Conv2d_1            ...            Conv2d_N
                 |                                  |
                 +----------------+-----------------+
                                  |
                           concat (axis=1)
                                  |
                               output

    When the original Conv2d has no bias, a zero-filled bias placeholder of
    length `out_channels` is added to the exported program and sliced instead.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _select_conv2d_overload(padding):
        """Return the aten.conv2d overload matching the kind of `padding`.

        A list of ints maps to `conv2d.default`, a string maps to
        `conv2d.padding`. Anything else raises InvalidArgumentError.
        """
        if isinstance(padding, str):
            return torch.ops.aten.conv2d.padding
        if isinstance(padding, list) and all(
            isinstance(element, int) for element in padding
        ):
            return torch.ops.aten.conv2d.default
        raise InvalidArgumentError(
            f"Unsupported padding type: {padding}"
        )  # Unreachable to here

    @staticmethod
    def _slice_group(graph, tensor, axis, width, index, origin):
        """Create the slice node selecting the `index`-th chunk of `width` elements along `axis`."""
        return create_node(
            graph,
            torch.ops.aten.slice.Tensor,
            (tensor, axis, width * index, width * (index + 1), 1),
            origin=origin,
        )

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """Decompose all grouped conv2d nodes of `exported_program` in place.

        Returns a PassResult whose flag is True when at least one node was rewritten.
        Raises NotYetSupportedError for a non-4D input and InvalidArgumentError
        for a padding that is neither a list of ints nor a string.
        """
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        graph: torch.fx.Graph = gm.graph
        modified = False

        for node in graph.nodes:
            if not is_target_node(node, ops.aten.conv2d):
                continue

            conv_args = Conv2DArgs(*node.args)
            input_, weight, bias = conv_args.input, conv_args.weight, conv_args.bias
            stride = conv_args.stride
            padding = conv_args.padding
            dilation = conv_args.dilation
            groups = conv_args.groups

            input_shape = extract_shape(input_)
            if len(input_shape) != 4:
                raise NotYetSupportedError(
                    f"Only support 4D input tensor: node's input shape: {input_shape}"
                )

            in_channels = input_shape[1]
            is_grouped = groups != 1 and groups != in_channels
            if not is_grouped:
                # Regular and depthwise-style convolutions are not handled here.
                continue
            assert (
                in_channels % groups == 0
            ), f"in_channels should be divisible by groups: in_channels: {in_channels}, groups: {groups}"

            output_shape = extract_shape(node)
            assert len(output_shape) == 4, len(output_shape)

            out_channels = output_shape[1]
            assert (
                out_channels % groups == 0
            ), f"out_channels should be divisible by groups: out_channels: {out_channels}, groups: {groups}"

            weight_shape = extract_shape(weight)
            assert len(weight_shape) == 4, len(weight_shape)
            assert (
                weight_shape[0] == out_channels
            ), f"weight shape[0]: {weight_shape[0]}, out channels: {out_channels}"
            assert (
                weight_shape[1] == in_channels // groups
            ), f"weight shape[1]: {weight_shape[1]}, in channels: {in_channels}"

            if bias is None:
                # No bias given: substitute a zero bias so every group can be sliced uniformly.
                bias = add_placeholder(
                    exported_program, torch.zeros(out_channels), "bias"
                )
            else:
                bias_shape = extract_shape(bias)
                assert (
                    bias_shape[0] == out_channels
                ), f"bias shape[0]: {bias_shape[0]}, out channels: {out_channels}"

            group_size = in_channels // groups
            out_group_size = out_channels // groups

            with graph.inserting_before(node):
                conv2d_op = self._select_conv2d_overload(padding)

                group_outputs = []
                for index in range(groups):
                    input_chunk = self._slice_group(
                        graph, input_, 1, group_size, index, node
                    )
                    weight_chunk = self._slice_group(
                        graph, weight, 0, out_group_size, index, node
                    )
                    bias_chunk = self._slice_group(
                        graph, bias, 0, out_group_size, index, node
                    )
                    # Each per-group convolution is an ordinary conv2d (groups=1).
                    conv_inputs = (
                        input_chunk,
                        weight_chunk,
                        bias_chunk,
                        stride,
                        padding,
                        dilation,
                        1,
                    )
                    group_outputs.append(
                        create_node(graph, conv2d_op, conv_inputs, origin=node)
                    )

                concat_output = create_node(
                    graph, torch.ops.aten.cat.default, (group_outputs, 1)
                )

            node.replace_all_uses_with(concat_output, propagate_meta=True)

            modified = True
            logger.debug(
                f"{node.name} is replaced with groups of conv2d: The number of groups: {groups}, groups size: {group_size}"
            )

        graph.eliminate_dead_code()
        gm.recompile()
        return PassResult(modified)
@@ -0,0 +1,169 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Optional, TYPE_CHECKING
17
+
18
+ if TYPE_CHECKING:
19
+ import torch.fx
20
+ import torch
21
+ from torch.export import ExportedProgram
22
+
23
+ from tico.serialize.circle_mapping import extract_shape
24
+
25
+ from tico.utils import logging
26
+ from tico.utils.graph import create_node
27
+ from tico.utils.passes import PassBase, PassResult
28
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
29
+ from tico.utils.utils import enforce_type, is_target_node
30
+
31
+
32
@trace_graph_diff_on_pass
class DecomposeSliceScatter(PassBase):
    """
    Decompose slice_scatter.default into slice_copy + cat.

    slice_scatter with step=1 embeds the `src` tensor into the `input` tensor.
    It can be replaced by (1) slicing the untouched parts of `input` and
    (2) concatenating them around `src`.

    [1] When step = 1,

        (1) Split input into slice_0 and slice_1 (either of them can be zero-size)
        (2) Concatenate slice_0, src, slice_1

        Before)

            input                  src
              |                     |
              +--> slice_scatter <--+

        After)

            input
              |-------------------------+
              |                         |
          slice_copy                slice_copy
              |                         |
           slice_0*        src       slice_1*
              |             |           |
              +----------> cat <--------+

        *Either of slice_0 or slice_1 could be empty. Then it's ignored.

    [2] When step > 1, not supported yet. (TBD)
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """Replace every slice_scatter.default node with slice_copy + cat nodes.

        Returns a PassResult whose flag is True when at least one node was rewritten.
        Raises RuntimeError when a slice_scatter node uses step > 1.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph: torch.fx.Graph = graph_module.graph
        modified = False

        # Defined once, outside the node loop: the schema is the same for every node.
        @enforce_type
        @dataclass
        class Args:
            """
            input (Tensor) the input tensor.
            src (Tensor) the tensor to embed into input
            dim (int) the dimension to insert the slice into
            start (Optional[int]) the start index of where to insert the slice
            end (Optional[int]) the end index of where to insert the slice
            step (int) how many elements to skip between written positions
            """

            input: torch.fx.Node
            src: torch.fx.Node
            dim: int = 0
            start: Optional[int] = None
            end: Optional[int] = None
            step: int = 1

        for node in graph.nodes:
            if not is_target_node(node, torch.ops.aten.slice_scatter.default):
                continue

            args = Args(*node.args, **node.kwargs)  # type: ignore[arg-type]

            input_ = args.input
            src = args.src
            dim = args.dim
            step = args.step

            # TODO Support step > 1 cases
            if step > 1:
                raise RuntimeError(
                    f"slice_scatter with step > 1 is not yet supported. Node: {node}"
                )

            dim_size = extract_shape(input_)[dim]
            src_size = extract_shape(src)[dim]

            start: int = 0 if args.start is None else args.start
            if start < 0:
                # Negative indices count from the end of the dimension, as in
                # Python slicing; map them to the equivalent non-negative index.
                start = max(start + dim_size, 0)

            # With step == 1, `src` occupies exactly `src_size` consecutive
            # positions beginning at `start`, so the untouched tail of `input`
            # begins right after them. `args.end` is deliberately not used here:
            # deriving the boundary from the length of `src` gives the same
            # result for a non-negative `end` and stays correct when `end` is
            # negative or larger than the dimension.
            tail_begin = start + src_size

            with graph.inserting_before(node):
                slices = []

                if 0 < start:
                    slice_0 = create_node(
                        graph,
                        torch.ops.aten.slice_copy.Tensor,
                        args=(input_, dim, 0, start, 1),
                        origin=node,
                    )
                    slices.append(slice_0)

                slices.append(src)

                if tail_begin < dim_size:
                    slice_1 = create_node(
                        graph,
                        torch.ops.aten.slice_copy.Tensor,
                        args=(input_, dim, tail_begin, dim_size, 1),
                        origin=node,
                    )
                    slices.append(slice_1)

                concat = create_node(
                    graph, torch.ops.aten.cat.default, args=(slices, dim)
                )
                node.replace_all_uses_with(concat, propagate_meta=True)

            modified = True
            logger.debug(f"{node.name} is replaced with slice_copy + concat")

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(modified)
@@ -0,0 +1,122 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+ from torch.utils import _pytree as pytree
22
+
23
+ from tico.utils import logging
24
+ from tico.utils.passes import PassBase, PassResult
25
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
26
+ from tico.utils.utils import is_target_node
27
+
28
+
29
+ def _extract_to_output(node: torch.fx.Node, graph: torch.fx.Graph) -> bool:
30
+ """
31
+ This extracts dtype kwargs to node's output direction
32
+
33
+ So, op(..., dtype = X) is converted to op(...).to(X)
34
+
35
+ Return true if modified
36
+
37
+ NOTE
38
+
39
+ [1] This function always returns true. Return value is introduced for extension
40
+ [2] This conversion is not safe for some Ops whose inputs should also be casted to X (ex: Mean).
41
+
42
+ """
43
+ logger = logging.getLogger(__name__)
44
+
45
+ node_kwargs = node.kwargs
46
+ # Remove "dtype" from node's kwargs
47
+ new_kwargs = {}
48
+ for k, v in node_kwargs.items():
49
+ if k == "dtype":
50
+ continue
51
+ new_kwargs[k] = v
52
+ node.kwargs = new_kwargs
53
+ # Create new val for node
54
+ # `node.target()` needs only `Tensor` for its arguments. Therefore, let's retrieve `FakeTensor` if it is `torch.fx.Node`.
55
+ args, kwargs = pytree.tree_map_only(
56
+ torch.fx.Node, lambda x: x.meta["val"], (node.args, node.kwargs)
57
+ )
58
+ new_val = node.target(*args, **kwargs) # type: ignore[operator]
59
+ # Set args, kwargs of `to_copy`
60
+ to_args = (node,)
61
+ to_kwargs = {"dtype": node_kwargs["dtype"]}
62
+ with graph.inserting_after(node):
63
+ to_copy = graph.call_function(torch.ops.aten._to_copy.default, (), {})
64
+ node.replace_all_uses_with(to_copy, propagate_meta=True)
65
+ # Q) Why lazy-update args, kwargs of the `to_copy`?
66
+ # A) `replace_all_uses_with` replace all the uses of `node`. If `to_copy` args is set to
67
+ # (node, ) before `replace_all_uses_with`, the function would even replace the args of
68
+ # `to_copy` with `to_copy`.
69
+ to_copy.args = to_args
70
+ to_copy.kwargs = to_kwargs
71
+ # Update meta["val"] to change dtype
72
+ node.meta["val"] = new_val
73
+
74
+ logger.debug(f"{node.name}'s dtype kwargs is extracted into {to_copy.name}")
75
+
76
+ return True
77
+
78
+
79
@trace_graph_diff_on_pass
class ExtractDtypeKwargsPass(PassBase):
    """
    Pull the "dtype" keyword argument out of selected nodes.

    Some torch APIs accept a "dtype" keyword argument.

    E.g. x_bool = torch.full_like(x, 0, dtype=torch.bool)

    That argument complicates the circle build logic, because many operators
    share their dtype with their inputs.

    So, this pass changes `op(dtype)` to `op + to(dtype)`.

    NOTE

    [1] Only the ops registered in `self.target_ops` are rewritten; ops for which
        a "dtype" kwarg is natural are simply not registered.
    [2] Nodes of a registered op that carry no "dtype" kwarg are left untouched.
    """

    def __init__(self):
        super().__init__()
        # Maps each op whose "dtype" kwarg is extracted to the function doing the rewrite.
        self.target_ops = {torch.ops.aten.full_like.default: _extract_to_output}

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """Apply the registered rewrite to every matching node carrying a "dtype" kwarg.

        Returns a PassResult whose flag is True when at least one node was rewritten.
        """
        graph_module = exported_program.graph_module
        graph: torch.fx.Graph = graph_module.graph

        handled_ops = list(self.target_ops)
        modified = False
        for node in graph.nodes:
            if is_target_node(node, handled_ops) and "dtype" in node.kwargs:
                rewrite = self.target_ops[node.target]
                modified |= rewrite(node, graph)

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
@@ -0,0 +1,57 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from torch.export import ExportedProgram
16
+
17
+ from tico.utils import logging
18
+ from tico.utils.passes import PassBase, PassResult
19
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
20
+ from tico.utils.utils import set_new_meta_val
21
+
22
+
23
@trace_graph_diff_on_pass
class FillMetaVal(PassBase):
    """
    Populate meta['val'] on every call_function node that does not have one yet.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """Fill in the missing meta['val'] entries of `exported_program`'s graph.

        Returns a PassResult whose flag is True when at least one node was updated.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph

        # Lint first to make sure the graph is topologically sorted.
        graph.lint()

        modified = False
        for node in graph.nodes:
            is_call = node.op == "call_function"
            has_val = hasattr(node, "meta") and "val" in node.meta
            if is_call and not has_val:
                set_new_meta_val(node)
                modified = True
                logger.debug(f"{node.name} has new meta values.")

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
@@ -0,0 +1,112 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Sequence
16
+
17
+ import torch
18
+ from torch.export import ExportedProgram
19
+
20
+ from tico.passes import ops
21
+ from tico.serialize.circle_mapping import extract_shape
22
+ from tico.utils import logging
23
+ from tico.utils.graph import create_node
24
+ from tico.utils.passes import PassBase, PassResult
25
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
26
+ from tico.utils.utils import is_target_node
27
+ from tico.utils.validate_args_kwargs import PermuteArgs, ReshapeArgs
28
+
29
+
30
+ def _is_leading_unsqueeze(target: Sequence[int], permuted: Sequence[int]) -> bool:
31
+ """
32
+ True if `target` == [1]*k + permuted, k>=1.
33
+ """
34
+ k = len(target) - len(permuted)
35
+ return (
36
+ k > 0 and all(d == 1 for d in target[:k]) and list(target[k:]) == list(permuted)
37
+ )
38
+
39
+
40
@trace_graph_diff_on_pass
class FuseLeadingUnsqueezeReshape(PassBase):
    """
    Fuse reshape → permute → reshape where the second reshape only
    prepends one-sized dims (unsqueeze) to the permuted tensor.

    [BEFORE]
      x - aten.reshape(s1) - aten.permute(p) - aten.reshape([1]*k + p(s1))
    [AFTER]
      x - aten.reshape([1]*k + s1) - aten.permute(list(range(k)) + [d+k for d in p])

    Negative entries of `p` are first converted to their non-negative
    equivalents so that shifting them by `k` keeps pointing at the same axis.
    """

    def call(self, ep: ExportedProgram) -> PassResult:
        """Fuse every matching reshape/permute/reshape chain of `ep`'s graph.

        Returns a PassResult whose flag is True when at least one chain was fused.
        """
        logger = logging.getLogger(__name__)

        gm = ep.graph_module
        graph = gm.graph
        modified = False
        for reshape_back in graph.nodes:
            if not is_target_node(reshape_back, ops.aten.reshape):
                continue
            reshape_back_args = ReshapeArgs(*reshape_back.args, **reshape_back.kwargs)  # type: ignore[arg-type]
            permute = reshape_back_args.input

            if not is_target_node(permute, ops.aten.permute):
                continue
            permute_args = PermuteArgs(*permute.args, **permute.kwargs)  # type: ignore[arg-type]
            reshape_front, permute_dims = permute_args.input, permute_args.dims

            if not is_target_node(reshape_front, ops.aten.reshape):
                continue
            reshape_front_args = ReshapeArgs(*reshape_front.args, **reshape_front.kwargs)  # type: ignore[arg-type]
            reshape_front_input, reshape_front_size = (
                reshape_front_args.input,
                reshape_front_args.shape,
            )

            # ---- condition: only leading unsqueeze ------------------------
            back_shape = extract_shape(reshape_back)
            permute_shape = extract_shape(permute)

            if not _is_leading_unsqueeze(back_shape, permute_shape):
                continue

            # ---- create new reshape & new permute -------------------------
            k = len(back_shape) - len(permute_shape)
            rank = len(permute_dims)
            with graph.inserting_before(permute):
                new_shape = [1] * k + list(reshape_front_size)
                r_new = create_node(
                    graph,
                    torch.ops.aten.reshape.default,
                    args=(reshape_front_input, new_shape),
                    origin=reshape_back,
                )
                # The k new leading axes stay in place; the original axes are
                # shifted by k. `d % rank` maps a negative dim (e.g. -1) to its
                # non-negative equivalent first; without it, `d + k` would
                # point at a different axis or collide with the leading ones.
                new_p_dims = list(range(k)) + [(d % rank) + k for d in permute_dims]
                p_new = create_node(
                    graph,
                    torch.ops.aten.permute.default,
                    args=(r_new, new_p_dims),
                )

            reshape_back.replace_all_uses_with(p_new, propagate_meta=True)
            modified = True
            logger.debug(f"{reshape_back.name} is fused to {r_new.name}")

        if modified:
            graph.eliminate_dead_code()
            graph.lint()
            gm.recompile()

        return PassResult(modified)