tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181) hide show
  1. tico/__init__.py +9 -1
  2. tico/config/base.py +1 -1
  3. tico/config/v1.py +5 -0
  4. tico/passes/cast_aten_where_arg_type.py +1 -1
  5. tico/passes/cast_clamp_mixed_type_args.py +169 -0
  6. tico/passes/cast_mixed_type_args.py +4 -2
  7. tico/passes/const_prop_pass.py +1 -1
  8. tico/passes/convert_conv1d_to_conv2d.py +1 -1
  9. tico/passes/convert_expand_to_slice_cat.py +153 -0
  10. tico/passes/convert_matmul_to_linear.py +312 -0
  11. tico/passes/convert_to_relu6.py +1 -1
  12. tico/passes/decompose_addmm.py +0 -3
  13. tico/passes/decompose_batch_norm.py +2 -2
  14. tico/passes/decompose_fake_quantize.py +0 -3
  15. tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
  16. tico/passes/decompose_group_norm.py +0 -3
  17. tico/passes/legalize_predefined_layout_operators.py +2 -11
  18. tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
  19. tico/passes/lower_to_slice.py +1 -1
  20. tico/passes/merge_consecutive_cat.py +1 -1
  21. tico/passes/ops.py +1 -1
  22. tico/passes/remove_redundant_assert_nodes.py +3 -1
  23. tico/passes/remove_redundant_expand.py +3 -6
  24. tico/passes/remove_redundant_reshape.py +5 -5
  25. tico/passes/segment_index_select.py +1 -1
  26. tico/quantization/__init__.py +6 -0
  27. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
  28. tico/quantization/algorithm/gptq/quantizer.py +292 -0
  29. tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
  30. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
  31. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  32. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  33. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  34. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  35. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
  36. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  37. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  38. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  39. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  40. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  41. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  42. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  44. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
  45. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  46. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  47. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  48. tico/quantization/config/base.py +26 -0
  49. tico/quantization/config/gptq.py +29 -0
  50. tico/quantization/config/pt2e.py +25 -0
  51. tico/quantization/config/ptq.py +119 -0
  52. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
  53. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
  54. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  55. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  56. tico/quantization/evaluation/metric.py +146 -0
  57. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  58. tico/quantization/passes/__init__.py +1 -0
  59. tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
  60. tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
  61. tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
  62. tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
  63. tico/{experimental/quantization → quantization}/public_interface.py +19 -18
  64. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  65. tico/quantization/quantizer_registry.py +73 -0
  66. tico/quantization/wrapq/__init__.py +1 -0
  67. tico/quantization/wrapq/dtypes.py +70 -0
  68. tico/quantization/wrapq/examples/__init__.py +1 -0
  69. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  70. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  71. tico/quantization/wrapq/examples/quantize_linear.py +107 -0
  72. tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
  73. tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
  74. tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
  75. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  76. tico/quantization/wrapq/mode.py +32 -0
  77. tico/quantization/wrapq/observers/__init__.py +1 -0
  78. tico/quantization/wrapq/observers/affine_base.py +128 -0
  79. tico/quantization/wrapq/observers/base.py +98 -0
  80. tico/quantization/wrapq/observers/ema.py +62 -0
  81. tico/quantization/wrapq/observers/identity.py +74 -0
  82. tico/quantization/wrapq/observers/minmax.py +39 -0
  83. tico/quantization/wrapq/observers/mx.py +60 -0
  84. tico/quantization/wrapq/qscheme.py +40 -0
  85. tico/quantization/wrapq/quantizer.py +179 -0
  86. tico/quantization/wrapq/utils/__init__.py +1 -0
  87. tico/quantization/wrapq/utils/introspection.py +167 -0
  88. tico/quantization/wrapq/utils/metrics.py +124 -0
  89. tico/quantization/wrapq/utils/reduce_utils.py +25 -0
  90. tico/quantization/wrapq/wrappers/__init__.py +1 -0
  91. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  92. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  93. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  94. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  95. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  96. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  97. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  98. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  99. tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
  100. tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
  101. tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
  102. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  103. tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
  104. tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
  105. tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
  106. tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
  107. tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
  108. tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
  109. tico/quantization/wrapq/wrappers/registry.py +125 -0
  110. tico/serialize/circle_graph.py +12 -4
  111. tico/serialize/circle_mapping.py +76 -2
  112. tico/serialize/circle_serializer.py +253 -148
  113. tico/serialize/operators/adapters/__init__.py +1 -0
  114. tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
  115. tico/serialize/operators/op_any.py +7 -14
  116. tico/serialize/operators/op_avg_pool2d.py +11 -4
  117. tico/serialize/operators/op_clamp.py +5 -7
  118. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  119. tico/serialize/operators/op_conv2d.py +14 -6
  120. tico/serialize/operators/op_copy.py +26 -3
  121. tico/serialize/operators/op_cumsum.py +3 -1
  122. tico/serialize/operators/op_depthwise_conv2d.py +17 -7
  123. tico/serialize/operators/op_full_like.py +0 -2
  124. tico/serialize/operators/op_index_select.py +8 -1
  125. tico/serialize/operators/op_instance_norm.py +0 -6
  126. tico/serialize/operators/op_le.py +54 -0
  127. tico/serialize/operators/op_log1p.py +3 -2
  128. tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
  129. tico/serialize/operators/op_mm.py +15 -131
  130. tico/serialize/operators/op_mul.py +2 -8
  131. tico/serialize/operators/op_pow.py +3 -1
  132. tico/serialize/operators/op_repeat.py +12 -3
  133. tico/serialize/operators/op_reshape.py +1 -1
  134. tico/serialize/operators/op_rmsnorm.py +65 -0
  135. tico/serialize/operators/op_softmax.py +7 -14
  136. tico/serialize/operators/op_split_with_sizes.py +16 -8
  137. tico/serialize/operators/op_transpose_conv.py +11 -8
  138. tico/serialize/operators/op_view.py +2 -1
  139. tico/serialize/quant_param.py +5 -5
  140. tico/utils/convert.py +30 -17
  141. tico/utils/dtype.py +42 -0
  142. tico/utils/graph.py +1 -1
  143. tico/utils/model.py +2 -1
  144. tico/utils/padding.py +2 -2
  145. tico/utils/pytree_utils.py +134 -0
  146. tico/utils/record_input.py +102 -0
  147. tico/utils/register_custom_op.py +29 -4
  148. tico/utils/serialize.py +16 -3
  149. tico/utils/signature.py +247 -0
  150. tico/utils/torch_compat.py +52 -0
  151. tico/utils/utils.py +50 -58
  152. tico/utils/validate_args_kwargs.py +38 -3
  153. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
  154. tico-0.1.0.dev251102.dist-info/RECORD +271 -0
  155. tico/experimental/quantization/__init__.py +0 -1
  156. tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
  157. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  158. tico/experimental/quantization/evaluation/metric.py +0 -109
  159. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
  160. tico-0.1.0.dev250714.dist-info/RECORD +0 -209
  161. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  162. /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
  163. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  164. /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
  165. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
  166. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  167. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  168. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
  169. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  170. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
  171. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  172. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  173. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  174. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  175. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  176. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  177. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  178. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
  179. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
  180. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
  181. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,312 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List, Optional, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
21
+ from torch.export import ExportedProgram
22
+
23
+ from tico.serialize.circle_mapping import extract_shape
24
+
25
+ from tico.utils import logging
26
+ from tico.utils.graph import create_node
27
+ from tico.utils.passes import PassBase, PassResult
28
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
29
+ from tico.utils.validate_args_kwargs import BmmArgs, MatmulArgs
30
+
31
+
32
class Converter:
    """
    Base interface for a single matmul-to-linear rewrite rule.

    Subclasses override ``match`` to recognize a node they can rewrite and
    ``convert`` to perform the rewrite on ``exported_program``'s graph.
    """

    def __init__(self):
        super().__init__()

    def match(self, exported_program, node) -> bool:
        """
        Return True when this converter can rewrite ``node``.

        The base implementation never matches, so a bare ``Converter`` is a
        no-op rule.
        """
        return False

    def convert(self, exported_program, node) -> torch.fx.Node:
        """
        Rewrite ``node`` in ``exported_program`` and return its replacement.

        Raises:
            NotImplementedError: always in the base class. Any subclass whose
                ``match`` can return True must override this method; raising
                here keeps a missing override from silently yielding ``None``
                where a ``torch.fx.Node`` is promised.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement convert()"
        )
41
+
42
+
43
class MatmulToLinearConverter(Converter):
    """
    Rewrite ``aten.mm(lhs, rhs)`` as ``aten.linear(lhs, permute(rhs, [1, 0]))``.

    ``linear(x, w)`` computes ``x @ w^T``; passing the transposed rhs as the
    weight therefore reproduces ``lhs @ rhs``. This class supplies only the
    rewrite; subclasses decide when it applies by overriding ``match``.
    """

    def __init__(self):
        super().__init__()

    def convert(self, exported_program, node) -> torch.fx.Node:
        """Insert permute + linear before ``node`` and redirect its users."""
        graph = exported_program.graph_module.graph

        operands = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        left, right = operands.input, operands.other

        with graph.inserting_before(node):
            # rhs^T, so that linear's implicit transpose cancels it out.
            right_t = create_node(
                graph,
                torch.ops.aten.permute.default,
                args=(right, [1, 0]),
            )
            replacement = create_node(
                graph,
                torch.ops.aten.linear.default,
                args=(left, right_t),
            )

        # The linear node produces the same value as the mm node, so it can
        # inherit the mm node's metadata.
        node.replace_all_uses_with(replacement, propagate_meta=True)
        return replacement
70
+
71
+
72
class RhsConstMatmulToLinearConverter(MatmulToLinearConverter):
    """
    Apply the mm -> linear rewrite when the right-hand operand of
    ``aten.mm.default`` is a lifted tensor constant, a parameter, or a buffer
    of the exported program.
    """

    def __init__(self):
        super().__init__()

    def match(self, exported_program, node) -> bool:
        if node.target != torch.ops.aten.mm.default:
            return False

        right = MatmulArgs(*node.args, **node.kwargs).other  # type: ignore[arg-type]

        # Operands that are not fx Nodes are never considered constant here.
        if not isinstance(right, torch.fx.Node):
            return False

        return bool(
            is_lifted_tensor_constant(exported_program, right)
            or is_param(exported_program, right)
            or is_buffer(exported_program, right)
        )

    def convert(self, exported_program, node) -> torch.fx.Node:
        # The rewrite itself is inherited unchanged from the parent class.
        return super().convert(exported_program, node)
96
+
97
+
98
class LhsConstMatmulToLinearConverter(MatmulToLinearConverter):
    """
    Apply the mm -> linear rewrite when the left-hand operand of
    ``aten.mm.default`` is a lifted tensor constant, a parameter, or a buffer
    of the exported program.
    """

    def __init__(self):
        super().__init__()

    def match(self, exported_program, node) -> bool:
        if node.target != torch.ops.aten.mm.default:
            return False

        left = MatmulArgs(*node.args, **node.kwargs).input  # type: ignore[arg-type]

        # Operands that are not fx Nodes are never considered constant here.
        if not isinstance(left, torch.fx.Node):
            return False

        return bool(
            is_lifted_tensor_constant(exported_program, left)
            or is_param(exported_program, left)
            or is_buffer(exported_program, left)
        )

    def convert(self, exported_program, node) -> torch.fx.Node:
        # The rewrite itself is inherited unchanged from the parent class.
        return super().convert(exported_program, node)
119
+
120
+
121
class SingleBatchLhsConstBmmToLinearConverter(Converter):
    """
    Convert `single-batched & lhs-const BatchMatMul` to `linear` operation.

    [1] exchange lhs and rhs
    [2] transpose rhs
    [3] transpose output

    Because linear(x, W) == x @ W^T, the rewrite relies on
    (rhs^T @ lhs^T)^T == lhs @ rhs.

    **Before**

        lhs[1,a,b](const)   rhs[1,b,c]
                |               |
                -------bmm-------
                        |
                  output[1,a,c]

    **After**

            rhs[1,b,c]
                |
                tr          lhs'[a,b](const-folded)
                |[1,c,b]        |
                |               |
                -------fc--------
                        |[1,c,a]
                        tr
                        |
                  output[1,a,c]
    """

    def __init__(self):
        super().__init__()

    def match(self, exported_program, node) -> bool:
        """Return True for an ``aten.bmm`` with a constant lhs and batch size 1."""
        if not node.target == torch.ops.aten.bmm.default:
            return False

        bmm_args = BmmArgs(*node.args, **node.kwargs)
        lhs = bmm_args.input
        rhs = bmm_args.mat2

        # [1] Lhs is constant. Checked before any shape inspection so that
        # `extract_shape` below is only reached once `lhs` is known to be an
        # fx.Node registered as a constant, parameter, or buffer.
        if not isinstance(lhs, torch.fx.Node):
            return False
        if not (
            is_lifted_tensor_constant(exported_program, lhs)
            or is_param(exported_program, lhs)
            or is_buffer(exported_program, lhs)
        ):
            return False

        # [2] Single-batch
        lhs_shape = extract_shape(lhs)
        rhs_shape = extract_shape(rhs)

        assert len(lhs_shape) == len(
            rhs_shape
        ), f"Bmm input's ranks must be the same but got {lhs_shape} and {rhs_shape}"

        if not (lhs_shape[0] == rhs_shape[0] == 1):
            return False

        return True

    def convert(self, exported_program, node) -> torch.fx.Node:
        """
        Replace ``bmm(lhs, rhs)`` with ``permute(linear(permute(rhs), view(lhs)))``.

        Returns the final permute node, which takes over all uses of ``node``.
        """
        graph_module = exported_program.graph_module
        graph = graph_module.graph

        bmm_args = BmmArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        lhs = bmm_args.input  # const, [1, a, b]
        rhs = bmm_args.mat2  # non-const, [1, b, c]
        lhs_shape = extract_shape(lhs)
        rhs_shape = extract_shape(rhs)
        assert rhs_shape[0] == 1
        assert lhs_shape[0] == 1

        with graph.inserting_before(node):
            # rhs^T: [1, c, b]
            rhs_tr = create_node(
                graph,
                torch.ops.aten.permute.default,
                args=(rhs, [0, 2, 1]),
            )
            # Drop the unit batch dim so lhs can act as the linear weight: [a, b]
            lhs_reshape = create_node(
                graph,
                torch.ops.aten.view.default,
                args=(lhs, list(lhs_shape[1:])),
            )

            # rhs^T @ lhs^T: [1, c, a]
            linear_node = create_node(
                graph,
                torch.ops.aten.linear.default,
                args=(rhs_tr, lhs_reshape),
            )

            # (rhs^T @ lhs^T)^T == lhs @ rhs: [1, a, c]
            tr_linear_node = create_node(
                graph,
                torch.ops.aten.permute.default,
                args=(linear_node, [0, 2, 1]),
            )

        node.replace_all_uses_with(tr_linear_node, propagate_meta=False)

        return tr_linear_node
228
+
229
+
230
@trace_graph_diff_on_pass
class ConvertMatmulToLinear(PassBase):
    """
    This pass selectively converts matmul (and, optionally, single-batch bmm)
    to linear.

    How to select between `matmul` and `linear`?

    * Linear has better quantization accuracy (NPU backend)
      Due to ONE compiler's quantization policy:
      FullyConnected(=Linear) uses per-channel quantization for weight and per-tensor for input.
      BatchMatmul(=matmul) uses per-tensor quantization for both rhs and lhs.

    * Matmul to Linear requires Transpose, which may harm latency
      When RHS is constant, the additional transpose can be folded.

    [RHS non-const case]
    Constant folding cannot be performed.

        lhs       rhs (non-const)
         |         |
         |     transpose
         |         |
         -- linear --
              |
             out

    [RHS const case]
    Constant folding can be performed to

        lhs       rhs (const)             lhs       rhs (folded const)
         |         |                       |         |
         |     transpose                   |         |
         |         |                       |         |
         -- linear --          -->         -- linear --
              |                                 |
             out                               out

    Args:
        enable_lhs_const: If true, convert matmul where LHS is constant tensor. Default is False.
        enable_rhs_const: If true, convert matmul where RHS is constant tensor. Default is True.
        enable_single_batch_lhs_const_bmm: If true, convert bmm whose inputs both
            have batch size 1 and whose LHS is a constant tensor. Default is False.
    """

    def __init__(
        self,
        enable_lhs_const: Optional[bool] = False,
        enable_rhs_const: Optional[bool] = True,
        enable_single_batch_lhs_const_bmm: Optional[bool] = False,
    ):
        super().__init__()
        # Converters are tried in this order; the first match wins per node.
        self.converters: List[Converter] = []
        if enable_lhs_const:
            self.converters.append(LhsConstMatmulToLinearConverter())
        if enable_rhs_const:
            self.converters.append(RhsConstMatmulToLinearConverter())
        if enable_single_batch_lhs_const_bmm:
            self.converters.append(SingleBatchLhsConstBmmToLinearConverter())

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """Rewrite every matching call_function node, then clean up the graph."""
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if not node.op == "call_function":
                continue

            for converter in self.converters:
                if not converter.match(exported_program, node):
                    continue

                new_node = converter.convert(exported_program, node)
                modified = True
                logger.debug(
                    f"{node.name} is replaced with {new_node.name} operator (permute + linear)"
                )
                # After a rewrite every use of `node` points at the
                # replacement; letting another converter match the same node
                # would only build a second, dead copy of the rewrite.
                break

        # The original mm/bmm nodes have no users left and are removed here.
        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
@@ -172,7 +172,7 @@ class ConvertToReLU6(PassBase):
172
172
  converter.convert(exported_program, node)
173
173
  modified = True
174
174
  logger.debug(f"{node.name} is replaced with ReLU6 operator")
175
- break
175
+ continue
176
176
 
177
177
  graph.eliminate_dead_code()
178
178
  graph.lint()
@@ -20,7 +20,6 @@ import torch
20
20
  from torch.export import ExportedProgram
21
21
 
22
22
  from tico.serialize.circle_mapping import extract_shape
23
- from tico.utils import logging
24
23
  from tico.utils.graph import add_placeholder, create_node
25
24
  from tico.utils.passes import PassBase, PassResult
26
25
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
@@ -59,8 +58,6 @@ class DecomposeAddmm(PassBase):
59
58
  super().__init__()
60
59
 
61
60
  def call(self, exported_program: ExportedProgram) -> PassResult:
62
- logger = logging.getLogger(__name__)
63
-
64
61
  gm = exported_program.graph_module
65
62
  graph: torch.fx.Graph = gm.graph
66
63
  modified = False
@@ -96,9 +96,9 @@ class DecomposeBatchNorm(PassBase):
96
96
  eps = args.eps
97
97
 
98
98
  if not running_mean:
99
- raise NotYetSupportedError(f"running_mean=None is not supported yet")
99
+ raise NotYetSupportedError("running_mean=None is not supported yet")
100
100
  if not running_var:
101
- raise NotYetSupportedError(f"running_var=None is not supported yet")
101
+ raise NotYetSupportedError("running_var=None is not supported yet")
102
102
 
103
103
  """
104
104
  Only support the cases generated from torch.nn.BatchNorm2d module,
@@ -19,10 +19,8 @@ if TYPE_CHECKING:
19
19
  import torch
20
20
 
21
21
  # To import torch.ops.quantized_decomposed related operator
22
- from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
23
22
  from torch.export import ExportedProgram
24
23
 
25
- from tico.utils import logging
26
24
  from tico.utils.graph import create_node
27
25
  from tico.utils.passes import PassBase, PassResult
28
26
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
@@ -66,7 +64,6 @@ class DecomposeFakeQuantize(PassBase):
66
64
  super().__init__()
67
65
 
68
66
  def call(self, exported_program: ExportedProgram) -> PassResult:
69
- logger = logging.getLogger(__name__)
70
67
  modified = False
71
68
 
72
69
  gm = exported_program.graph_module
@@ -26,10 +26,8 @@ from torch._export.utils import (
26
26
  )
27
27
 
28
28
  # To import torch.ops.quantized_decomposed related operator
29
- from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
30
29
  from torch.export import ExportedProgram
31
30
 
32
- from tico.utils import logging
33
31
  from tico.utils.graph import create_node
34
32
  from tico.utils.passes import PassBase, PassResult
35
33
  from tico.utils.trace_decorators import (
@@ -246,10 +244,11 @@ class DecomposeFakeQuantizeTensorQParams(PassBase):
246
244
  # So, let's remove `mask` from the output.args first.
247
245
  # mask_user(output).args == (dequantize_per_tensor.tensor, mask)
248
246
  if mask:
249
- len(mask) == 1
250
- mask_user = list(mask[0].users.keys())[0]
251
- assert len(mask_user.args) == 1
252
- mask_user.args = ((mask_user.args[0][0],),)
247
+ assert len(mask) == 1
248
+ if len(mask[0].users) > 0:
249
+ mask_user = list(mask[0].users.keys())[0]
250
+ assert len(mask_user.args) == 1
251
+ mask_user.args = ((mask_user.args[0][0],),)
253
252
  modified = True
254
253
  if (
255
254
  node.target
@@ -22,7 +22,6 @@ import torch
22
22
  from torch.export import ExportedProgram
23
23
 
24
24
  from tico.serialize.circle_mapping import extract_shape
25
- from tico.utils import logging
26
25
  from tico.utils.graph import create_node
27
26
  from tico.utils.passes import PassBase, PassResult
28
27
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
@@ -126,8 +125,6 @@ class DecomposeGroupNorm(PassBase):
126
125
  )
127
126
 
128
127
  def call(self, exported_program: ExportedProgram) -> PassResult:
129
- logger = logging.getLogger(__name__)
130
-
131
128
  gm = exported_program.graph_module
132
129
  graph: torch.fx.Graph = gm.graph
133
130
  modified = False
@@ -20,7 +20,7 @@ if TYPE_CHECKING:
20
20
  import torch
21
21
  from torch.export import ExportedProgram
22
22
 
23
- from tico.serialize.circle_graph import extract_shape
23
+ from tico.serialize.circle_mapping import extract_shape
24
24
  from tico.utils import logging
25
25
  from tico.utils.errors import NotYetSupportedError
26
26
  from tico.utils.graph import create_node
@@ -206,7 +206,6 @@ class LegalizePreDefinedLayoutOperators(PassBase):
206
206
 
207
207
  args = ConvTranspose2DArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
208
208
  input = args.input
209
- padding = args.padding
210
209
  groups = args.groups
211
210
  dilation = args.dilation
212
211
 
@@ -288,13 +287,12 @@ class LegalizePreDefinedLayoutOperators(PassBase):
288
287
  input = args.input
289
288
  weight = args.weight
290
289
  bias = args.bias
291
- eps = args.eps
292
290
 
293
291
  running_mean = args.running_mean
294
292
  running_var = args.running_var
295
293
  use_input_stats = args.use_input_stats
296
294
 
297
- if not (use_input_stats == True):
295
+ if not use_input_stats:
298
296
  raise NotYetSupportedError("Only support use_input_stats is True.")
299
297
  if not isinstance(running_mean, NoneType):
300
298
  raise NotYetSupportedError("Only support running_mean=None")
@@ -350,10 +348,6 @@ class LegalizePreDefinedLayoutOperators(PassBase):
350
348
  # max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
351
349
  args = MaxPool2dWithIndicesArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
352
350
  input_ = args.input
353
- kernel_size = args.kernel_size
354
- stride = args.stride
355
- padding = args.padding
356
- dilation = args.dilation
357
351
  ceil_mode = args.ceil_mode
358
352
  if ceil_mode:
359
353
  raise NotYetSupportedError("Only support non-ceil model.")
@@ -402,9 +396,6 @@ class LegalizePreDefinedLayoutOperators(PassBase):
402
396
  # avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)
403
397
  args = AvgPool2dArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
404
398
  input_ = args.input
405
- kernel_size = args.kernel_size
406
- stride = args.stride
407
- padding = args.padding
408
399
  ceil_mode = args.ceil_mode
409
400
  if ceil_mode:
410
401
  raise NotYetSupportedError("Only support non-ceil model.")
@@ -67,7 +67,7 @@ class LowerToResizeNearestNeighbor(PassBase):
67
67
  return None
68
68
  # indices = [None, None, H index, W index]
69
69
  N, C, H, W = indices
70
- if N != None or C != None:
70
+ if N is not None or C is not None:
71
71
  return None
72
72
  if not isinstance(H, torch.fx.Node):
73
73
  return None
@@ -28,7 +28,7 @@ from torch._export.utils import (
28
28
  from torch.export import ExportedProgram
29
29
 
30
30
  from tico.passes import ops
31
- from tico.serialize.circle_graph import extract_shape
31
+ from tico.serialize.circle_mapping import extract_shape
32
32
  from tico.utils import logging
33
33
  from tico.utils.graph import create_node, is_single_value_tensor
34
34
  from tico.utils.passes import PassBase, PassResult
@@ -51,7 +51,7 @@ class MergeConsecutiveCat(PassBase):
51
51
  if not prev_cat.op == "call_function":
52
52
  continue
53
53
 
54
- if not prev_cat.target in ops.aten.cat:
54
+ if prev_cat.target not in ops.aten.cat:
55
55
  continue
56
56
 
57
57
  prev_args = CatArgs(*prev_cat.args, **prev_cat.kwargs) # type: ignore[arg-type]
tico/passes/ops.py CHANGED
@@ -69,10 +69,10 @@ class AtenOps:
69
69
  torch.ops.aten.unsqueeze_copy.default,
70
70
  ]
71
71
  self.view = [
72
- torch.ops.aten.view,
73
72
  torch.ops.aten.view.default,
74
73
  torch.ops.aten.view_copy.default,
75
74
  ]
75
+ self._to_copy = [torch.ops.aten._to_copy.default]
76
76
 
77
77
 
78
78
  aten = AtenOps()
@@ -21,7 +21,9 @@ from tico.utils.utils import is_target_node
21
21
 
22
22
 
23
23
  assert_node_targets = [
24
+ torch.ops.aten._assert_scalar.default,
24
25
  torch.ops.aten._assert_tensor_metadata.default,
26
+ torch.ops.aten.sym_constrain_range_for_size.default, # Related to symbolic shape validation
25
27
  ]
26
28
 
27
29
 
@@ -29,7 +31,7 @@ assert_node_targets = [
29
31
  class RemoveRedundantAssertionNodes(PassBase):
30
32
  """
31
33
  This removes redundant assertion nodes.
32
- - `aten.assert_tensor_meta.default`
34
+ When assertion node is erased, related comparison nodes are also removed by graph.eliminate_dead_code().
33
35
  """
34
36
 
35
37
  def __init__(self):
@@ -12,11 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import TYPE_CHECKING
16
-
17
- if TYPE_CHECKING:
18
- import torch.fx
19
- import torch
20
15
  from torch.export import ExportedProgram
21
16
 
22
17
  from tico.passes import ops
@@ -51,7 +46,9 @@ class RemoveRedundantExpand(PassBase):
51
46
  input, size = args.input, args.size
52
47
 
53
48
  input_shape = extract_shape(input)
54
- if list(input_shape) != size:
49
+ output_shape = extract_shape(node)
50
+
51
+ if input_shape != output_shape:
55
52
  continue
56
53
 
57
54
  node.replace_all_uses_with(input, propagate_meta=False)
@@ -90,7 +90,7 @@ class RemoveRedundantReshapePattern1(PassBase):
90
90
  if len(permute.users) != 1:
91
91
  continue
92
92
  permute_args = PermuteArgs(*permute.args, **permute.kwargs) # type: ignore[arg-type]
93
- permute_input, permute_dims = permute_args.input, permute_args.dims
93
+ permute_dims = permute_args.dims
94
94
  # (1xAxBxC) - `aten.permute` - (1xAxCxB)
95
95
  if permute_dims != [0, 1, 3, 2]:
96
96
  continue
@@ -172,7 +172,7 @@ class RemoveRedundantReshapePattern2(PassBase):
172
172
  if len(permute.users) != 1:
173
173
  continue
174
174
  permute_args = PermuteArgs(*permute.args, **permute.kwargs) # type: ignore[arg-type]
175
- permute_input, permute_dims = permute_args.input, permute_args.dims
175
+ permute_dims = permute_args.dims
176
176
  # (1xAxBxC) - `aten.permute` - (Bx1xAxC)
177
177
  if permute_dims != [2, 0, 1, 3]:
178
178
  continue
@@ -262,7 +262,7 @@ class RemoveRedundantReshapePattern3(PassBase):
262
262
  continue
263
263
 
264
264
  # add
265
- if not add.target in ops.aten.add:
265
+ if add.target not in ops.aten.add:
266
266
  continue
267
267
  add_args = AddTensorArgs(*add.args, **add.kwargs) # type: ignore[arg-type]
268
268
  reshape_2, reshape_3 = add_args.input, add_args.other
@@ -272,7 +272,7 @@ class RemoveRedundantReshapePattern3(PassBase):
272
272
  # reshape_2
273
273
  if not reshape_2.op == "call_function":
274
274
  continue
275
- if not reshape_2.target in ops.aten.reshape:
275
+ if reshape_2.target not in ops.aten.reshape:
276
276
  continue
277
277
  reshape_2_args = ReshapeArgs(*reshape_2.args, **reshape_2.kwargs) # type: ignore[arg-type]
278
278
  reshape_2_input = reshape_2_args.input
@@ -280,7 +280,7 @@ class RemoveRedundantReshapePattern3(PassBase):
280
280
  # reshape_3
281
281
  if not reshape_3.op == "call_function":
282
282
  continue
283
- if not reshape_3.target in ops.aten.reshape:
283
+ if reshape_3.target not in ops.aten.reshape:
284
284
  continue
285
285
  reshape_3_args = ReshapeArgs(*reshape_3.args, **reshape_3.kwargs) # type: ignore[arg-type]
286
286
  reshape_3_input = reshape_3_args.input
@@ -29,7 +29,7 @@ from torch._export.utils import (
29
29
  from torch.export import ExportedProgram
30
30
 
31
31
  from tico.passes import ops
32
- from tico.serialize.circle_graph import extract_shape
32
+ from tico.serialize.circle_mapping import extract_shape
33
33
  from tico.utils import logging
34
34
  from tico.utils.graph import add_placeholder, create_node, is_single_value_tensor
35
35
  from tico.utils.passes import PassBase, PassResult
@@ -0,0 +1,6 @@
1
+ from tico.quantization.public_interface import convert, prepare
2
+
3
# Public API of the package: the entry points imported from
# tico.quantization.public_interface.
__all__ = ["convert", "prepare"]
@@ -25,7 +25,7 @@ from typing import Optional
25
25
  import torch
26
26
  import torch.nn as nn
27
27
 
28
- from tico.experimental.quantization.algorithm.gptq.quant import quantize, Quantizer
28
+ from tico.quantization.algorithm.gptq.quant import quantize, Quantizer
29
29
 
30
30
  torch.backends.cuda.matmul.allow_tf32 = False
31
31
  torch.backends.cudnn.allow_tf32 = False