tico 0.1.0.dev250411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. tico/__init__.py +31 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
  55. tico/experimental/quantization/public_interface.py +108 -0
  56. tico/experimental/quantization/quantizer.py +71 -0
  57. tico/interpreter/__init__.py +1 -0
  58. tico/interpreter/infer.py +116 -0
  59. tico/interpreter/interpreter.py +93 -0
  60. tico/passes/__init__.py +1 -0
  61. tico/passes/cast_aten_where_arg_type.py +185 -0
  62. tico/passes/cast_mixed_type_args.py +186 -0
  63. tico/passes/const_prop_pass.py +307 -0
  64. tico/passes/convert_conv1d_to_conv2d.py +151 -0
  65. tico/passes/convert_layout_op_to_reshape.py +84 -0
  66. tico/passes/convert_repeat_to_expand_copy.py +90 -0
  67. tico/passes/convert_to_relu6.py +180 -0
  68. tico/passes/decompose_addmm.py +127 -0
  69. tico/passes/decompose_batch_norm.py +198 -0
  70. tico/passes/decompose_fake_quantize.py +126 -0
  71. tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
  72. tico/passes/decompose_group_norm.py +258 -0
  73. tico/passes/decompose_grouped_conv2d.py +202 -0
  74. tico/passes/decompose_slice_scatter.py +167 -0
  75. tico/passes/extract_dtype_kwargs.py +121 -0
  76. tico/passes/fill_meta_val.py +57 -0
  77. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  78. tico/passes/legalize_causal_mask_value.py +113 -0
  79. tico/passes/legalize_predefined_layout_operators.py +383 -0
  80. tico/passes/lower_pow2_to_mul.py +75 -0
  81. tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
  82. tico/passes/lower_to_slice.py +112 -0
  83. tico/passes/merge_consecutive_cat.py +82 -0
  84. tico/passes/ops.py +75 -0
  85. tico/passes/remove_nop.py +85 -0
  86. tico/passes/remove_redundant_assert_nodes.py +50 -0
  87. tico/passes/remove_redundant_expand.py +70 -0
  88. tico/passes/remove_redundant_permute.py +102 -0
  89. tico/passes/remove_redundant_reshape.py +431 -0
  90. tico/passes/remove_redundant_slice.py +64 -0
  91. tico/passes/remove_redundant_to_copy.py +84 -0
  92. tico/passes/restore_linear.py +113 -0
  93. tico/passes/segment_index_select.py +143 -0
  94. tico/pt2_to_circle.py +101 -0
  95. tico/serialize/__init__.py +1 -0
  96. tico/serialize/circle_graph.py +264 -0
  97. tico/serialize/circle_mapping.py +177 -0
  98. tico/serialize/circle_serializer.py +232 -0
  99. tico/serialize/operators/__init__.py +28 -0
  100. tico/serialize/operators/hashable_opcode.py +43 -0
  101. tico/serialize/operators/node_visitor.py +80 -0
  102. tico/serialize/operators/op_add.py +69 -0
  103. tico/serialize/operators/op_alias_copy.py +64 -0
  104. tico/serialize/operators/op_any.py +142 -0
  105. tico/serialize/operators/op_arange_start_step.py +61 -0
  106. tico/serialize/operators/op_argmax.py +62 -0
  107. tico/serialize/operators/op_avg_pool2d.py +112 -0
  108. tico/serialize/operators/op_bmm.py +62 -0
  109. tico/serialize/operators/op_cat.py +66 -0
  110. tico/serialize/operators/op_clamp.py +123 -0
  111. tico/serialize/operators/op_clone.py +71 -0
  112. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  113. tico/serialize/operators/op_conv2d.py +181 -0
  114. tico/serialize/operators/op_copy.py +162 -0
  115. tico/serialize/operators/op_cos.py +59 -0
  116. tico/serialize/operators/op_cumsum.py +92 -0
  117. tico/serialize/operators/op_depthwise_conv2d.py +198 -0
  118. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  119. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  120. tico/serialize/operators/op_div.py +62 -0
  121. tico/serialize/operators/op_embedding.py +60 -0
  122. tico/serialize/operators/op_eq.py +64 -0
  123. tico/serialize/operators/op_exp.py +60 -0
  124. tico/serialize/operators/op_expand.py +91 -0
  125. tico/serialize/operators/op_full.py +48 -0
  126. tico/serialize/operators/op_full_like.py +55 -0
  127. tico/serialize/operators/op_ge.py +54 -0
  128. tico/serialize/operators/op_gelu.py +59 -0
  129. tico/serialize/operators/op_gt.py +54 -0
  130. tico/serialize/operators/op_index.py +82 -0
  131. tico/serialize/operators/op_index_select.py +64 -0
  132. tico/serialize/operators/op_instance_norm.py +91 -0
  133. tico/serialize/operators/op_linear.py +70 -0
  134. tico/serialize/operators/op_log.py +53 -0
  135. tico/serialize/operators/op_log1p.py +83 -0
  136. tico/serialize/operators/op_logical_and.py +63 -0
  137. tico/serialize/operators/op_logical_not.py +62 -0
  138. tico/serialize/operators/op_lt.py +61 -0
  139. tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
  140. tico/serialize/operators/op_maximum.py +53 -0
  141. tico/serialize/operators/op_mean.py +66 -0
  142. tico/serialize/operators/op_minimum.py +53 -0
  143. tico/serialize/operators/op_mm.py +174 -0
  144. tico/serialize/operators/op_mul.py +99 -0
  145. tico/serialize/operators/op_ne.py +54 -0
  146. tico/serialize/operators/op_neg.py +59 -0
  147. tico/serialize/operators/op_permute.py +65 -0
  148. tico/serialize/operators/op_pow.py +138 -0
  149. tico/serialize/operators/op_prelu.py +54 -0
  150. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  151. tico/serialize/operators/op_reciprocal.py +64 -0
  152. tico/serialize/operators/op_relu.py +53 -0
  153. tico/serialize/operators/op_relu6.py +52 -0
  154. tico/serialize/operators/op_repeat.py +99 -0
  155. tico/serialize/operators/op_reshape.py +73 -0
  156. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  157. tico/serialize/operators/op_rsqrt.py +53 -0
  158. tico/serialize/operators/op_scalar_tensor.py +51 -0
  159. tico/serialize/operators/op_select_copy.py +65 -0
  160. tico/serialize/operators/op_sigmoid.py +56 -0
  161. tico/serialize/operators/op_sin.py +53 -0
  162. tico/serialize/operators/op_slice.py +155 -0
  163. tico/serialize/operators/op_softmax.py +100 -0
  164. tico/serialize/operators/op_split_with_sizes.py +96 -0
  165. tico/serialize/operators/op_sqrt.py +55 -0
  166. tico/serialize/operators/op_squeeze.py +73 -0
  167. tico/serialize/operators/op_sub.py +71 -0
  168. tico/serialize/operators/op_sum.py +63 -0
  169. tico/serialize/operators/op_tanh.py +54 -0
  170. tico/serialize/operators/op_to_copy.py +105 -0
  171. tico/serialize/operators/op_unsqueeze.py +66 -0
  172. tico/serialize/operators/op_view.py +74 -0
  173. tico/serialize/operators/op_where.py +82 -0
  174. tico/serialize/operators/utils.py +51 -0
  175. tico/serialize/pack.py +35 -0
  176. tico/serialize/quant_param.py +42 -0
  177. tico/utils/__init__.py +1 -0
  178. tico/utils/convert.py +292 -0
  179. tico/utils/define.py +35 -0
  180. tico/utils/diff_graph.py +181 -0
  181. tico/utils/errors.py +35 -0
  182. tico/utils/graph.py +200 -0
  183. tico/utils/logging.py +45 -0
  184. tico/utils/model.py +37 -0
  185. tico/utils/padding.py +47 -0
  186. tico/utils/passes.py +76 -0
  187. tico/utils/register_custom_op.py +562 -0
  188. tico/utils/trace_decorators.py +101 -0
  189. tico/utils/utils.py +314 -0
  190. tico/utils/validate_args_kwargs.py +1114 -0
  191. tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
  192. tico-0.1.0.dev250411.dist-info/METADATA +17 -0
  193. tico-0.1.0.dev250411.dist-info/RECORD +196 -0
  194. tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
  195. tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
  196. tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,202 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.passes import ops
+ from tico.serialize.circle_mapping import extract_shape
+ from tico.utils import logging
+ from tico.utils.errors import InvalidArgumentError, NotYetSupportedError
+ from tico.utils.graph import add_placeholder
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.validate_args_kwargs import Conv2DArgs
+
+
+ @trace_graph_diff_on_pass
+ class DecomposeGroupedConv2d(PassBase):
+     """
+     This pass decomposes a grouped Conv2d operator into multiple Conv2d operators whose groups=1.
+
+     Grouped Conv2d denotes a Conv2d operator whose `groups` argument equals neither the input channels nor 1.
+
+     [before]
+
+         input  weight  bias
+           |       |       |
+           +-------+-------+
+                   |
+                 Conv2d (groups != IN_CHANNEL && groups != 1)
+                   |
+                 output
+
+     [after]
+
+     The `slice` operators below slice the input tensor, weight, and bias along the channel axis into `groups` pieces.
+     The numbered input, weight, and bias denote the sliced input tensor, weight, and bias, respectively.
+
+         input
+           |        weight
+         slice         |         bias
+           |         slice         |
+           |           |         slice
+           |           |           |
+           +---------------------------+---------------------------+
+           |           |           |   |                           |
+           |           +---------------------------+---------------------------+
+           |           |           |   |           |               |           |
+           |           |           +---------------------------+---------------------------+
+           |           |           |   |           |           |   |           |           |
+        input_1        |           |  ...          |           | input_N       |           |
+           |       weight_1        |   |          ...          |   |       weight_N        |
+           |           |        bias_1 |           |          ...  |           |        bias_N
+           +-----------+-----------+   +-----------+-----------+   +-----------+-----------+
+                       |                           |                           |
+                    Conv2d_1                      ...                       Conv2d_N
+                       |                           |                           |
+                       +---------------------------+---------------------------+
+                                                   |
+                                                concat
+                                                   |
+                                                output
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         gm = exported_program.graph_module
+         graph: torch.fx.Graph = gm.graph
+         modified = False
+
+         for node in graph.nodes:
+             if node.op != "call_function":
+                 continue
+             if node.target not in ops.aten.conv2d:
+                 continue
+
+             args = Conv2DArgs(*node.args)
+             input_ = args.input
+             weight = args.weight
+             bias = args.bias
+             stride = args.stride
+             padding = args.padding
+             dilation = args.dilation
+             groups = args.groups
+
+             input_shape = extract_shape(input_)
+             if len(input_shape) != 4:
+                 raise NotYetSupportedError(
+                     f"Only 4D input tensors are supported: node's input shape: {input_shape}"
+                 )
+
+             in_channels = input_shape[1]
+             if groups == 1 or groups == in_channels:
+                 continue
+             assert (
+                 in_channels % groups == 0
+             ), f"in_channels should be divisible by groups: in_channels: {in_channels}, groups: {groups}"
+
+             output_shape = extract_shape(node)
+             assert len(output_shape) == 4, len(output_shape)
+
+             out_channels = output_shape[1]
+             assert (
+                 out_channels % groups == 0
+             ), f"out_channels should be divisible by groups: out_channels: {out_channels}, groups: {groups}"
+
+             weight_shape = extract_shape(weight)
+             assert len(weight_shape) == 4, len(weight_shape)
+             assert (
+                 weight_shape[0] == out_channels
+             ), f"weight shape[0]: {weight_shape[0]}, out channels: {out_channels}"
+             assert (
+                 weight_shape[1] == in_channels // groups
+             ), f"weight shape[1]: {weight_shape[1]}, in channels: {in_channels}"
+
+             if bias is not None:
+                 bias_shape = extract_shape(bias)
+                 assert (
+                     bias_shape[0] == out_channels
+                 ), f"bias shape[0]: {bias_shape[0]}, out channels: {out_channels}"
+             else:  # Make a dummy bias tensor
+                 bias = add_placeholder(
+                     exported_program, torch.zeros(out_channels), "bias"
+                 )
+
+             group_size = in_channels // groups
+             out_group_size = out_channels // groups
+
+             with gm.graph.inserting_before(node):
+                 conv2d_op = None
+                 if isinstance(padding, list) and all(
+                     isinstance(element, int) for element in padding
+                 ):
+                     conv2d_op = torch.ops.aten.conv2d.default
+                 elif isinstance(padding, str):
+                     conv2d_op = torch.ops.aten.conv2d.padding
+                 else:
+                     raise InvalidArgumentError(
+                         f"Unsupported padding type: {padding}"
+                     )  # Unreachable
+
+                 conv2d_tensors = []
+                 for i in range(groups):
+                     sliced_input = graph.call_function(
+                         torch.ops.aten.slice.Tensor,
+                         (input_, 1, group_size * i, group_size * (i + 1), 1),
+                     )
+                     sliced_weight = graph.call_function(
+                         torch.ops.aten.slice.Tensor,
+                         (weight, 0, out_group_size * i, out_group_size * (i + 1), 1),
+                     )
+                     sliced_bias = graph.call_function(
+                         torch.ops.aten.slice.Tensor,
+                         (bias, 0, out_group_size * i, out_group_size * (i + 1), 1),
+                     )
+                     conv2d_tensor = graph.call_function(
+                         conv2d_op,
+                         (
+                             sliced_input,
+                             sliced_weight,
+                             sliced_bias,
+                             stride,
+                             padding,
+                             dilation,
+                             1,
+                         ),
+                     )
+                     conv2d_tensors.append(conv2d_tensor)
+
+                 concat_output = graph.call_function(
+                     torch.ops.aten.cat.default, (conv2d_tensors, 1)
+                 )
+
+             node.replace_all_uses_with(concat_output, propagate_meta=True)
+
+             modified = True
+             logger.debug(
+                 f"{node.name} is replaced with groups of conv2d. Number of groups: {groups}, group size: {group_size}"
+             )
+
+         graph.eliminate_dead_code()
+         gm.recompile()
+         return PassResult(modified)
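
The identity behind this pass can be checked directly in eager PyTorch: a grouped convolution equals per-group convolutions whose results are concatenated along the channel axis. A minimal sketch (not from the package; shapes and names are illustrative):

import torch
import torch.nn.functional as F

# Grouped convolution: in_channels=6, out_channels=9, groups=3.
x = torch.randn(1, 6, 8, 8)
w = torch.randn(9, 2, 3, 3)  # weight shape[1] == in_channels // groups
b = torch.randn(9)
groups = 3
group_size = 6 // groups      # input channels per group
out_group_size = 9 // groups  # output channels per group

grouped = F.conv2d(x, w, b, groups=groups)

# Decomposition: slice input/weight/bias per group, run groups=1 convs, concat.
pieces = []
for i in range(groups):
    xi = x[:, group_size * i : group_size * (i + 1)]
    wi = w[out_group_size * i : out_group_size * (i + 1)]
    bi = b[out_group_size * i : out_group_size * (i + 1)]
    pieces.append(F.conv2d(xi, wi, bi, groups=1))
decomposed = torch.cat(pieces, dim=1)

assert torch.allclose(grouped, decomposed, atol=1e-6)
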
@@ -0,0 +1,167 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import dataclass
+ from typing import Optional, TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.serialize.circle_mapping import extract_shape
+
+ from tico.utils import logging
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.utils import enforce_type
+
+
+ @trace_graph_diff_on_pass
+ class DecomposeSliceScatter(PassBase):
+     """
+     Let's decompose slice_scatter.default into cat.
+
+     slice_scatter with step=1 embeds the src tensor into the input tensor.
+     We can replace it with (1) slicing the input tensor and (2) concatenating all tensors.
+
+     [1] When step = 1,
+
+       (1) Split input into input_0 and input_1 (either of them can be zero-size)
+       (2) Concatenate input_0, src, input_1
+
+     Before)
+
+         input                  src
+           |                     |
+           |                     |
+           |                     |
+           +--> slice_scatter <--+
+
+     After)
+
+         input
+           |---------------------+
+           |                     |
+           |                     |
+           |                     |
+       slice_copy           slice_copy
+           |                     |
+           |                     |
+           |                     |
+        slice_0*      src     slice_1*
+           |           |          |
+           |           |          |
+           |           |          |
+           +-------> cat <--------+
+
+     *Either slice_0 or slice_1 could be empty; an empty slice is ignored.
+
+     [2] When step > 1, not supported yet. (TBD)
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         graph_module = exported_program.graph_module
+         graph: torch.fx.Graph = graph_module.graph
+         modified = False
+
+         for node in graph.nodes:
+             if node.op != "call_function":
+                 continue
+             if node.target != torch.ops.aten.slice_scatter.default:
+                 continue
+
+             @enforce_type
+             @dataclass
+             class Args:
+                 """
+                 input (Tensor) the input tensor
+                 src (Tensor) the tensor to embed into input
+                 dim (int) the dimension to insert the slice into
+                 start (Optional[int]) the start index of where to insert the slice
+                 end (Optional[int]) the end index of where to insert the slice
+                 step (int) how many elements to skip between entries
+                 """
+
+                 input: torch.fx.Node
+                 src: torch.fx.Node
+                 dim: int = 0
+                 start: Optional[int] = None
+                 end: Optional[int] = None
+                 step: int = 1
+
+             args = Args(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+             input = args.input
+             src = args.src
+             dim = args.dim
+             s = args.start
+             e = args.end
+             step = args.step
+
+             # TODO Support step > 1 cases
+             if step > 1:
+                 raise RuntimeError(
+                     f"slice_scatter with step > 1 is not yet supported. Node: {node}"
+                 )
+
+             start: int = 0 if s is None else s
+             end: int = (
+                 extract_shape(src)[dim]
+                 if e is None
+                 else min(extract_shape(src)[dim], e)
+             )
+
+             with graph.inserting_before(node):
+                 slices = []
+
+                 if 0 < start:
+                     slice_0 = graph.call_function(
+                         torch.ops.aten.slice_copy.Tensor,
+                         args=(input, dim, 0, start, 1),
+                     )
+                     slices.append(slice_0)
+
+                 slices.append(src)
+
+                 if start + end < extract_shape(input)[dim]:
+                     slice_1 = graph.call_function(
+                         torch.ops.aten.slice_copy.Tensor,
+                         args=(
+                             input,
+                             dim,
+                             start + end,
+                             extract_shape(input)[dim],
+                             1,
+                         ),
+                     )
+                     slices.append(slice_1)
+
+                 concat = graph.call_function(
+                     torch.ops.aten.cat.default, args=(slices, dim)
+                 )
+                 # Do not set meta here; propagate_meta=True copies the replaced node's meta.
+                 node.replace_all_uses_with(concat, propagate_meta=True)
+
+             modified = True
+             logger.debug(f"{node.name} is replaced with slice_copy + concat")
+
+         graph.eliminate_dead_code()
+         graph_module.recompile()
+         return PassResult(modified)
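
The step=1 identity this pass exploits is easy to verify in eager PyTorch with `torch.slice_scatter` (a minimal sketch; shapes are illustrative):

import torch

inp = torch.zeros(8, 4)
src = torch.ones(3, 4)
dim, start = 0, 2
length = src.shape[dim]

# slice_scatter embeds src into inp[start : start + length] along dim.
scattered = torch.slice_scatter(inp, src, dim=dim, start=start, end=start + length)

# Decomposition: inp[:start] ++ src ++ inp[start + length:]
decomposed = torch.cat([inp[:start], src, inp[start + length :]], dim=dim)

assert torch.equal(scattered, decomposed)
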
@@ -0,0 +1,121 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+ from torch.export import ExportedProgram
+ from torch.utils import _pytree as pytree
+
+ from tico.utils import logging
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+
+
+ def _extract_to_output(node: torch.fx.Node, graph: torch.fx.Graph) -> bool:
+     """
+     This extracts the dtype kwarg in the node's output direction.
+
+     So, op(..., dtype=X) is converted to op(...).to(X).
+
+     Returns True if modified.
+
+     NOTE
+
+     [1] This function always returns True. The return value is introduced for future extension.
+     [2] This conversion is not safe for some ops whose inputs should also be cast to X (e.g., Mean).
+
+     """
+     logger = logging.getLogger(__name__)
+
+     node_kwargs = node.kwargs
+     # Remove "dtype" from node's kwargs
+     new_kwargs = {}
+     for k, v in node_kwargs.items():
+         if k == "dtype":
+             continue
+         new_kwargs[k] = v
+     node.kwargs = new_kwargs
+     # Create a new val for the node.
+     # `node.target()` accepts only `Tensor` arguments, so retrieve the `FakeTensor` when an argument is a `torch.fx.Node`.
+     args, kwargs = pytree.tree_map_only(
+         torch.fx.Node, lambda x: x.meta["val"], (node.args, node.kwargs)
+     )
+     new_val = node.target(*args, **kwargs)  # type: ignore[operator]
+     # Set args, kwargs of `to_copy`
+     to_args = (node,)
+     to_kwargs = {"dtype": node_kwargs["dtype"]}
+     with graph.inserting_after(node):
+         to_copy = graph.call_function(torch.ops.aten._to_copy.default, (), {})
+         node.replace_all_uses_with(to_copy, propagate_meta=True)
+         # Q) Why lazy-update the args and kwargs of `to_copy`?
+         # A) `replace_all_uses_with` replaces all the uses of `node`. If `to_copy`'s args were set to
+         #    (node, ) before `replace_all_uses_with`, the function would even replace the args of
+         #    `to_copy` with `to_copy`.
+         to_copy.args = to_args
+         to_copy.kwargs = to_kwargs
+         # Update meta["val"] to change dtype
+         node.meta["val"] = new_val
+
+     logger.debug(f"{node.name}'s dtype kwarg is extracted into {to_copy.name}")
+
+     return True
+
+
+ @trace_graph_diff_on_pass
+ class ExtractDtypeKwargsPass(PassBase):
+     """
+     This pass extracts the "dtype" keyword argument from nodes.
+
+     Sometimes, a torch API receives a "dtype" keyword argument.
+
+     E.g. x_bool = torch.full_like(x, 0, dtype=torch.bool)
+
+     But this argument complicates the Circle build logic because many operators
+     have the same dtype as their inputs.
+
+     So, this pass changes `op(dtype)` to `op + to(dtype)`.
+
+     NOTE
+
+     [1] Some ops naturally have "dtype" kwargs. The pass is not applied to those ops.
+     [2] If node.kwargs["dtype"] is redundant, i.e. `op(dtype).dtype == op().dtype`, the pass is not applied.
+
+     """
+
+     def __init__(self):
+         super().__init__()
+         # List of ops whose "dtype" kwarg is extracted
+         self.target_ops = dict()
+         self.target_ops[torch.ops.aten.full_like.default] = _extract_to_output
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         graph_module = exported_program.graph_module
+         graph: torch.fx.Graph = graph_module.graph
+         modified = False
+         for node in graph.nodes:
+             if node.op != "call_function" or node.target not in self.target_ops:
+                 continue
+             if "dtype" not in node.kwargs:
+                 continue
+
+             modified |= self.target_ops[node.target](node, graph)
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         return PassResult(modified)
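
For the registered `full_like` target, the rewrite op(..., dtype=X) -> op(...).to(X) can be checked in eager PyTorch (a minimal sketch; tensors are illustrative):

import torch

x = torch.randn(2, 3)

with_dtype_kwarg = torch.full_like(x, 0, dtype=torch.bool)  # op(..., dtype=X)
extracted = torch.full_like(x, 0).to(torch.bool)            # op(...).to(X)

assert torch.equal(with_dtype_kwarg, extracted)

As NOTE [2] of `_extract_to_output` warns, this identity does not hold for ops like Mean, whose inputs would also need casting.
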
@@ -0,0 +1,57 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from torch.export import ExportedProgram
+
+ from tico.utils import logging
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.utils import set_new_meta_val
+
+
+ @trace_graph_diff_on_pass
+ class FillMetaVal(PassBase):
+     """
+     Let's set a new meta['val'] for nodes that don't have meta['val'].
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         graph_module = exported_program.graph_module
+         graph = graph_module.graph
+         modified = False
+         # Make sure the graph is topologically sorted.
+         graph.lint()
+         for node in graph.nodes:
+             if node.op != "call_function":
+                 continue
+
+             if hasattr(node, "meta") and "val" in node.meta:
+                 continue
+
+             set_new_meta_val(node)
+
+             modified = True
+
+             logger.debug(f"{node.name} has new meta values.")
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         return PassResult(modified)
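
`set_new_meta_val` is a tico-internal helper, but the general idea of recomputing a missing meta['val'] can be sketched with torch.fx's FakeTensorProp (an analogy under that assumption, not the package's actual implementation):

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

def f(x):
    return (x + 1).relu()

gm = make_fx(f)(torch.randn(2, 2))
for n in gm.graph.nodes:
    n.meta.pop("val", None)  # simulate nodes that lost meta['val']

# Re-propagate fake tensors to restore each node's meta['val'].
FakeTensorProp(gm).propagate(torch.randn(2, 2))
assert all("val" in n.meta for n in gm.graph.nodes if n.op == "call_function")
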
@@ -0,0 +1,102 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+ from torch.export import ExportedProgram
+ from torch.utils import _pytree as pytree
+
+ from tico.passes import ops
+ from tico.serialize.circle_mapping import extract_shape
+ from tico.utils import logging
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+
+
+ @trace_graph_diff_on_pass
+ class FuseRedundantReshapeToMean(PassBase):
+     """
+     This pass removes redundant `aten.reshape` operators that can be fused into `aten.mean` with `keepdim=True`.
+
+     Shape(aten.reshape(aten.mean(input))) == Shape(aten.mean(input, keepdim=True))
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         graph_module = exported_program.graph_module
+         graph = graph_module.graph
+         modified = False
+         for node in graph.nodes:
+             if node.op != "call_function":
+                 continue
+
+             if node.target != torch.ops.aten.mean.dim:
+                 continue
+
+             # If mean is used by other nodes, do not fuse it.
+             if len(node.users) != 1:
+                 continue
+
+             user_node = next(iter(node.users))
+             if user_node.target not in ops.aten.reshape:
+                 continue
+
+             mean_args, mean_kwargs = pytree.tree_map_only(
+                 torch.fx.Node,
+                 lambda n: n.meta["val"],
+                 (node.args, node.kwargs),
+             )
+             # Signature of aten.mean.dim is as follows.
+             # mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+             # `keepdim` in `node.kwargs` is moved to `node.args` in `run_decompositions`.
+             # `dtype` in `node.kwargs` is not moved.
+             assert len(mean_args) == 3 or len(mean_args) == 2  # keepdim exists or not
+             assert len(mean_kwargs) <= 1  # dtype exists or not
+             fused_mean_args = mean_args
+             keep_dims = True
+             if len(mean_args) == 2:
+                 fused_mean_args += (keep_dims,)
+
+             fused_val = node.target(*fused_mean_args, **mean_kwargs)
+
+             # Check that both shapes are the same:
+             # 1. Shape(aten.reshape(aten.mean))
+             # 2. Shape(aten.mean(keepdim=True))
+             if fused_val.size() != extract_shape(user_node):
+                 continue
+
+             # Update args
+             if len(mean_args) == 2:
+                 updated_args = node.args + (keep_dims,)
+             elif len(mean_args) == 3:
+                 updated_args = node.args
+             node.args = updated_args
+             node.meta["val"] = fused_val
+             user_node.replace_all_uses_with(node, propagate_meta=False)
+
+             modified = True
+             logger.debug(f"{user_node.name} is replaced with {node.name}")
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         return PassResult(modified)
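
The shape identity in the docstring can be checked in eager PyTorch (a minimal sketch; shapes are illustrative):

import torch

x = torch.randn(2, 3, 4)

reshaped = torch.mean(x, dim=(1,)).reshape(2, 1, 4)  # mean + reshape
fused = torch.mean(x, dim=(1,), keepdim=True)        # mean with keepdim=True

assert reshaped.shape == fused.shape == (2, 1, 4)
assert torch.equal(reshaped, fused)
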