tico 0.1.0.dev250411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. tico/__init__.py +31 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
  55. tico/experimental/quantization/public_interface.py +108 -0
  56. tico/experimental/quantization/quantizer.py +71 -0
  57. tico/interpreter/__init__.py +1 -0
  58. tico/interpreter/infer.py +116 -0
  59. tico/interpreter/interpreter.py +93 -0
  60. tico/passes/__init__.py +1 -0
  61. tico/passes/cast_aten_where_arg_type.py +185 -0
  62. tico/passes/cast_mixed_type_args.py +186 -0
  63. tico/passes/const_prop_pass.py +307 -0
  64. tico/passes/convert_conv1d_to_conv2d.py +151 -0
  65. tico/passes/convert_layout_op_to_reshape.py +84 -0
  66. tico/passes/convert_repeat_to_expand_copy.py +90 -0
  67. tico/passes/convert_to_relu6.py +180 -0
  68. tico/passes/decompose_addmm.py +127 -0
  69. tico/passes/decompose_batch_norm.py +198 -0
  70. tico/passes/decompose_fake_quantize.py +126 -0
  71. tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
  72. tico/passes/decompose_group_norm.py +258 -0
  73. tico/passes/decompose_grouped_conv2d.py +202 -0
  74. tico/passes/decompose_slice_scatter.py +167 -0
  75. tico/passes/extract_dtype_kwargs.py +121 -0
  76. tico/passes/fill_meta_val.py +57 -0
  77. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  78. tico/passes/legalize_causal_mask_value.py +113 -0
  79. tico/passes/legalize_predefined_layout_operators.py +383 -0
  80. tico/passes/lower_pow2_to_mul.py +75 -0
  81. tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
  82. tico/passes/lower_to_slice.py +112 -0
  83. tico/passes/merge_consecutive_cat.py +82 -0
  84. tico/passes/ops.py +75 -0
  85. tico/passes/remove_nop.py +85 -0
  86. tico/passes/remove_redundant_assert_nodes.py +50 -0
  87. tico/passes/remove_redundant_expand.py +70 -0
  88. tico/passes/remove_redundant_permute.py +102 -0
  89. tico/passes/remove_redundant_reshape.py +431 -0
  90. tico/passes/remove_redundant_slice.py +64 -0
  91. tico/passes/remove_redundant_to_copy.py +84 -0
  92. tico/passes/restore_linear.py +113 -0
  93. tico/passes/segment_index_select.py +143 -0
  94. tico/pt2_to_circle.py +101 -0
  95. tico/serialize/__init__.py +1 -0
  96. tico/serialize/circle_graph.py +264 -0
  97. tico/serialize/circle_mapping.py +177 -0
  98. tico/serialize/circle_serializer.py +232 -0
  99. tico/serialize/operators/__init__.py +28 -0
  100. tico/serialize/operators/hashable_opcode.py +43 -0
  101. tico/serialize/operators/node_visitor.py +80 -0
  102. tico/serialize/operators/op_add.py +69 -0
  103. tico/serialize/operators/op_alias_copy.py +64 -0
  104. tico/serialize/operators/op_any.py +142 -0
  105. tico/serialize/operators/op_arange_start_step.py +61 -0
  106. tico/serialize/operators/op_argmax.py +62 -0
  107. tico/serialize/operators/op_avg_pool2d.py +112 -0
  108. tico/serialize/operators/op_bmm.py +62 -0
  109. tico/serialize/operators/op_cat.py +66 -0
  110. tico/serialize/operators/op_clamp.py +123 -0
  111. tico/serialize/operators/op_clone.py +71 -0
  112. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  113. tico/serialize/operators/op_conv2d.py +181 -0
  114. tico/serialize/operators/op_copy.py +162 -0
  115. tico/serialize/operators/op_cos.py +59 -0
  116. tico/serialize/operators/op_cumsum.py +92 -0
  117. tico/serialize/operators/op_depthwise_conv2d.py +198 -0
  118. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  119. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  120. tico/serialize/operators/op_div.py +62 -0
  121. tico/serialize/operators/op_embedding.py +60 -0
  122. tico/serialize/operators/op_eq.py +64 -0
  123. tico/serialize/operators/op_exp.py +60 -0
  124. tico/serialize/operators/op_expand.py +91 -0
  125. tico/serialize/operators/op_full.py +48 -0
  126. tico/serialize/operators/op_full_like.py +55 -0
  127. tico/serialize/operators/op_ge.py +54 -0
  128. tico/serialize/operators/op_gelu.py +59 -0
  129. tico/serialize/operators/op_gt.py +54 -0
  130. tico/serialize/operators/op_index.py +82 -0
  131. tico/serialize/operators/op_index_select.py +64 -0
  132. tico/serialize/operators/op_instance_norm.py +91 -0
  133. tico/serialize/operators/op_linear.py +70 -0
  134. tico/serialize/operators/op_log.py +53 -0
  135. tico/serialize/operators/op_log1p.py +83 -0
  136. tico/serialize/operators/op_logical_and.py +63 -0
  137. tico/serialize/operators/op_logical_not.py +62 -0
  138. tico/serialize/operators/op_lt.py +61 -0
  139. tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
  140. tico/serialize/operators/op_maximum.py +53 -0
  141. tico/serialize/operators/op_mean.py +66 -0
  142. tico/serialize/operators/op_minimum.py +53 -0
  143. tico/serialize/operators/op_mm.py +174 -0
  144. tico/serialize/operators/op_mul.py +99 -0
  145. tico/serialize/operators/op_ne.py +54 -0
  146. tico/serialize/operators/op_neg.py +59 -0
  147. tico/serialize/operators/op_permute.py +65 -0
  148. tico/serialize/operators/op_pow.py +138 -0
  149. tico/serialize/operators/op_prelu.py +54 -0
  150. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  151. tico/serialize/operators/op_reciprocal.py +64 -0
  152. tico/serialize/operators/op_relu.py +53 -0
  153. tico/serialize/operators/op_relu6.py +52 -0
  154. tico/serialize/operators/op_repeat.py +99 -0
  155. tico/serialize/operators/op_reshape.py +73 -0
  156. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  157. tico/serialize/operators/op_rsqrt.py +53 -0
  158. tico/serialize/operators/op_scalar_tensor.py +51 -0
  159. tico/serialize/operators/op_select_copy.py +65 -0
  160. tico/serialize/operators/op_sigmoid.py +56 -0
  161. tico/serialize/operators/op_sin.py +53 -0
  162. tico/serialize/operators/op_slice.py +155 -0
  163. tico/serialize/operators/op_softmax.py +100 -0
  164. tico/serialize/operators/op_split_with_sizes.py +96 -0
  165. tico/serialize/operators/op_sqrt.py +55 -0
  166. tico/serialize/operators/op_squeeze.py +73 -0
  167. tico/serialize/operators/op_sub.py +71 -0
  168. tico/serialize/operators/op_sum.py +63 -0
  169. tico/serialize/operators/op_tanh.py +54 -0
  170. tico/serialize/operators/op_to_copy.py +105 -0
  171. tico/serialize/operators/op_unsqueeze.py +66 -0
  172. tico/serialize/operators/op_view.py +74 -0
  173. tico/serialize/operators/op_where.py +82 -0
  174. tico/serialize/operators/utils.py +51 -0
  175. tico/serialize/pack.py +35 -0
  176. tico/serialize/quant_param.py +42 -0
  177. tico/utils/__init__.py +1 -0
  178. tico/utils/convert.py +292 -0
  179. tico/utils/define.py +35 -0
  180. tico/utils/diff_graph.py +181 -0
  181. tico/utils/errors.py +35 -0
  182. tico/utils/graph.py +200 -0
  183. tico/utils/logging.py +45 -0
  184. tico/utils/model.py +37 -0
  185. tico/utils/padding.py +47 -0
  186. tico/utils/passes.py +76 -0
  187. tico/utils/register_custom_op.py +562 -0
  188. tico/utils/trace_decorators.py +101 -0
  189. tico/utils/utils.py +314 -0
  190. tico/utils/validate_args_kwargs.py +1114 -0
  191. tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
  192. tico-0.1.0.dev250411.dist-info/METADATA +17 -0
  193. tico-0.1.0.dev250411.dist-info/RECORD +196 -0
  194. tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
  195. tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
  196. tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
tico/passes/decompose_fake_quantize_tensor_qparams.py
@@ -0,0 +1,270 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING, Union
+
+ if TYPE_CHECKING:
+     import torch._ops
+     import torch.fx
+ import torch
+
+ # To import torch.ops.quantized_decomposed related operators
+ from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
+ from torch.export import ExportedProgram
+
+ from tico.utils import logging
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import (
+     trace_const_diff_on_pass,
+     trace_graph_diff_on_pass,
+ )
+ from tico.utils.validate_args_kwargs import FakeQuantizePerTensorTQParamArgs
+
+
+ def get_quant_type(min: int, max: int) -> torch.dtype:
+     if min == 0 and max == 255:
+         return torch.uint8
+     if min == -32768 and max == 32767:
+         return torch.int16
+     if min == -32767 and max == 32767:
+         return torch.int16
+
+     raise RuntimeError("Not supported min/max values")
+
+
+ def get_constant_from_tensor(
+     node: Union[torch.fx.Node, float], ep: ExportedProgram
+ ) -> Union[torch.fx.Node, float]:
+     """
+     There are some nodes that can be constant-folded.
+     Case 1. With constant tensors
+     Case 2. With `torch.ones` or `torch.zeros`
+
+     Please refer to the `DecomposeFakeQuantizeTensorQParams` docstring below for detailed explanations.
+     """
+     if isinstance(node, float):
+         return node
+     assert isinstance(node.target, torch._ops.OpOverload)
+     if node.target.__name__ == "mul.Tensor":
+         assert len(node.args) == 2
+         x = get_constant_from_tensor(node.args[0], ep)  # type: ignore[arg-type]
+         y = get_constant_from_tensor(node.args[1], ep)  # type: ignore[arg-type]
+         return x * y  # type: ignore[operator]
+     if node.target.__name__ == "zeros.default":
+         assert len(node.args) == 1
+         assert node.args[0] == [1]
+         return 0
+     if node.target.__name__ == "ones.default":
+         assert len(node.args) == 1
+         assert node.args[0] == [1]
+         return 1
+     if node.target.__name__ == "view.default":
+         assert len(node.args) == 2
+         tensor, shape = node.args
+         assert shape == [-1]
+         return get_constant_from_tensor(tensor, ep)  # type: ignore[arg-type]
+     if node.target.__name__ == "_to_copy.default":
+         assert len(node.args) == 1
+         return get_constant_from_tensor(node.args[0], ep)  # type: ignore[arg-type]
+     if node.target.__name__ == "lift_fresh_copy.default":
+         assert len(node.args) == 1
+         assert isinstance(node.args[0], torch.fx.Node)
+         lifted_tensor: torch.fx.Node = node.args[0]
+         lifted_tensor_constants = ep.graph_signature.inputs_to_lifted_tensor_constants
+         assert lifted_tensor.name in lifted_tensor_constants
+         tensor_name = lifted_tensor_constants[lifted_tensor.name]
+         value = ep.constants[tensor_name].cpu().detach().numpy()
+         return value
+     if node.target.__name__ in ["detach.default", "detach_.default"]:
+         assert len(node.args) == 1
+         return get_constant_from_tensor(node.args[0], ep)  # type: ignore[arg-type]
+
+     raise RuntimeError(f"Not supported node {node.target.__name__}")
+
+
+ @trace_const_diff_on_pass
+ @trace_graph_diff_on_pass
+ class DecomposeFakeQuantizeTensorQParams(PassBase):
+     """
+     Decompose the fake quantize operator with tensor QParams into quant/dequant operators.
+     Otherwise, it can't be converted to the edge IR because the fake quantize operator is not Aten canonical.
+
+     As of now, we don't support (de)quantize ops whose scale/zp are tensors; they must be scalars.
+     However, fake quantize with tensor QParams can be decomposed only when those tensors can be removed by constant folding.
+
+     We consider the cases below for now.
+
+     [CASE 1] With constant tensors
+
+     s = torch.tensor(0.1)
+     zp = torch.tensor(0)
+     fq_enabled = torch.tensor(True)
+     x = torch._fake_quantize_per_tensor_affine_cachemask_tensor_qparams(
+         x, s, zp, fq_enabled, 0, 255
+     )
+
+     [Before pass]
+
+     def forward(self, c_lifted_tensor_0, c_lifted_tensor_1, c_lifted_tensor_2, x):
+         lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
+         lift_fresh_copy_1 = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_1); c_lifted_tensor_1 = None
+         lift_fresh_copy_2 = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_2); c_lifted_tensor_2 = None
+         _fake_quantize_per_tensor_affine_cachemask_tensor_qparams = torch.ops.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default(x, lift_fresh_copy, lift_fresh_copy_1, lift_fresh_copy_2, quant_min, quant_max); x = lift_fresh_copy = lift_fresh_copy_1 = lift_fresh_copy_2 = None
+         getitem = _fake_quantize_per_tensor_affine_cachemask_tensor_qparams[0]
+         getitem_1 = _fake_quantize_per_tensor_affine_cachemask_tensor_qparams[1]; _fake_quantize_per_tensor_affine_cachemask_tensor_qparams = None
+         return (getitem, getitem_1)
+
+     [After pass]
+
+     def forward(self, c_lifted_tensor_0, c_lifted_tensor_1, c_lifted_tensor_2, x):
+         lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
+         lift_fresh_copy_1 = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_1); c_lifted_tensor_1 = None
+         quantize_per_tensor_tensor = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(x, lift_fresh_copy, lift_fresh_copy_1, quant_min, quant_max, dtype = ${torch.dtype}); x = None
+         dequantize_per_tensor_tensor = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor(quantize_per_tensor_tensor, lift_fresh_copy, lift_fresh_copy_1, quant_min, quant_max, dtype = ${torch.dtype}); quantize_per_tensor_tensor = lift_fresh_copy = lift_fresh_copy_1 = None
+         return (dequantize_per_tensor_tensor,)
+
+     `s` and `zp` are tensors, but they can be removed by constant folding. When they are transformed into an fx graph, they are
+     lifted as placeholders and become arguments of `aten.lift_fresh_copy`.
+
+
+     [CASE 2] With `torch.ones` or `torch.zeros`
+
+     n_bits=16
+     scale=torch.ones([1])
+     Qp = 2**(n_bits-1)-1
+     scale=scale*(1/Qp)
+     z = torch.fake_quantize_per_tensor_affine(x, scale, torch.zeros([1]).int().view(-1), -Qp, Qp)
+
+     `torch.ones([1])` or `torch.zeros([1])` is just the number 1 or 0, but it is transformed into an aten IR node, which prevents it from
+     being pre-computed to a plain number.
+
+     For example, `n_bits * 1` would simply be the number 16 after the transformation, but `n_bits * torch.ones([1])`
+     would be `aten.Mul(16, aten.full)`, which is why `torch.fake_quantize_per_tensor_affine` is transformed to
+     `aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams`, whose scale/zp arguments are tensors rather than scalars.
+
+     So, if we manually compute expressions like `n_bits * torch.ones([1])`, we can decompose fake quantize with qparam tensors.
+
+     [Before pass]
+
+     def forward(self, x):
+         ones = torch.ops.aten.ones.default([1], device = device(type='cpu'), pin_memory = False)
+         mul = torch.ops.aten.mul.Tensor(ones, 3.051850947599719e-05); ones = None
+         zeros = torch.ops.aten.zeros.default([1], device = device(type='cpu'), pin_memory = False)
+         _to_copy = torch.ops.aten._to_copy.default(zeros, dtype = torch.int32); zeros = None
+         view = torch.ops.aten.view.default(_to_copy, [-1]); _to_copy = None
+         ones_1 = torch.ops.aten.ones.default([1], dtype = torch.int64, layout = torch.strided, device = device(type='cpu'))
+         _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_default = torch.ops.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default(x, mul, view, ones_1, -32767, 32767); x = mul = view = ones_1 = None
+         getitem = _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_default[0]; _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_default = None
+         return (getitem,)
+
+     [After pass]
+     def forward(self, x: "f32[4]"):
+         quantize_per_tensor_default = torch.ops.quantized_decomposed.quantize_per_tensor.default(x, 3.051850947599719e-05, 0, -32767, 32767, dtype = torch.int16); x = None
+         dequantize_per_tensor_default: "f32[4]" = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default, 3.051850947599719e-05, 0, -32767, 32767, dtype = torch.int16); quantize_per_tensor_default = None
+         return (dequantize_per_tensor_default,)
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         modified = False
+
+         gm = exported_program.graph_module
+         qd = torch.ops.quantized_decomposed  # type: ignore[return]
+         for node in gm.graph.nodes:
+             if (
+                 node.op == "call_function"
+                 and node.target
+                 == torch.ops.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default
+             ):
+                 # tensor, scale, zero_p, fake_quant_enabled, quant_min, quant_max
+                 # TODO Support `fake_quant_enabled`
+                 assert len(node.args) == 6
+                 tensor, s, zp, _, quant_min, quant_max = node.args
+                 # Get constant tensors
+                 ep = exported_program
+                 s_value = get_constant_from_tensor(s, ep)
+                 zp_value = get_constant_from_tensor(zp, ep)
+                 # This op has one user: `getitem` for the output.
+                 # TODO Investigate why the op is generated like this.
+                 # node.users = {getitem: None}
+                 get_item, *mask = node.users.keys()
+                 # assert len(mask) == 0, "Not supported yet."
+                 quant_kwargs = {
+                     **node.kwargs,
+                     **{"dtype": get_quant_type(quant_min, quant_max)},
+                 }
+                 with gm.graph.inserting_before(node):
+                     quant = gm.graph.call_function(
+                         qd.quantize_per_tensor.default,
+                         args=(tensor, s_value, zp_value, quant_min, quant_max),
+                         kwargs=quant_kwargs,
+                     )
+                     dequant = gm.graph.call_function(
+                         qd.dequantize_per_tensor.default,
+                         args=(quant, *quant.args[1:]),
+                         kwargs=quant.kwargs,
+                     )
+                     # Do not set meta here; the replaced `get_item`'s meta is propagated instead.
+                     get_item.replace_all_uses_with(dequant, propagate_meta=True)
+                     # `mask` can be a graph output, which prevents `eliminate_dead_code()` from eliminating it.
+                     # So, let's remove `mask` from the output's args first.
+                     # mask_user(output).args == (dequantize_per_tensor.tensor, mask)
+                     if mask:
+                         assert len(mask) == 1
+                         mask_user = list(mask[0].users.keys())[0]
+                         assert len(mask_user.args) == 1
+                         mask_user.args = ((mask_user.args[0][0],),)
+                     modified = True
+             if (
+                 node.op == "call_function"
+                 and node.target
+                 == torch.ops.aten.fake_quantize_per_tensor_affine.tensor_qparams
+             ):
+                 fq_args = FakeQuantizePerTensorTQParamArgs(*node.args, **node.kwargs)
+                 tensor = fq_args.input
+                 s = fq_args.scale
+                 zp = fq_args.zero_point
+                 quant_min = fq_args.quant_min
+                 quant_max = fq_args.quant_max
+
+                 # Get constant tensors
+                 ep = exported_program
+                 s_value = get_constant_from_tensor(s, ep)
+                 zp_value = get_constant_from_tensor(zp, ep)
+                 quant_kwargs = {
+                     **node.kwargs,
+                     **{"dtype": get_quant_type(quant_min, quant_max)},
+                 }
+                 with gm.graph.inserting_before(node):
+                     quant = gm.graph.call_function(
+                         qd.quantize_per_tensor.default,
+                         args=(tensor, s_value, zp_value, quant_min, quant_max),
+                         kwargs=quant_kwargs,
+                     )
+                     dequant = gm.graph.call_function(
+                         qd.dequantize_per_tensor.default,
+                         args=(quant, *quant.args[1:]),
+                         kwargs=quant.kwargs,
+                     )
+                     # Do not set meta here; the replaced node's meta is propagated instead.
+                     node.replace_all_uses_with(dequant, propagate_meta=True)
+                     modified = True
+
+         gm.graph.eliminate_dead_code()
+         gm.graph.lint()
+         gm.recompile()
+
+         return PassResult(modified)
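
A minimal usage sketch (not part of the package; the exact operators `torch.export` emits depend on the PyTorch version): the hypothetical `FakeQuantModel` below builds its scale/zero-point from `torch.ones`/`torch.zeros` as in [CASE 2], so the exported graph is expected to contain `_fake_quantize_per_tensor_affine_cachemask_tensor_qparams`, which this pass then rewrites into a `quantized_decomposed.quantize_per_tensor`/`dequantize_per_tensor` pair.

import torch
from torch.export import export

from tico.passes.decompose_fake_quantize_tensor_qparams import (
    DecomposeFakeQuantizeTensorQParams,
)


class FakeQuantModel(torch.nn.Module):  # hypothetical example model
    def __init__(self, n_bits: int = 16):
        super().__init__()
        self.qp = 2 ** (n_bits - 1) - 1  # 32767 for 16 bits

    def forward(self, x):
        # Scale/zero-point built from torch.ones/torch.zeros, as in [CASE 2].
        scale = torch.ones([1]) * (1 / self.qp)
        zero_point = torch.zeros([1]).int().view(-1)
        return torch.fake_quantize_per_tensor_affine(
            x, scale, zero_point, -self.qp, self.qp
        )


ep = export(FakeQuantModel(), (torch.randn(4),))
DecomposeFakeQuantizeTensorQParams().call(ep)
print(ep.graph)  # expected: quantize_per_tensor/dequantize_per_tensor instead of fake quantize
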
tico/passes/decompose_group_norm.py
@@ -0,0 +1,258 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ import operator
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.serialize.circle_mapping import extract_shape
+ from tico.utils import logging
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.validate_args_kwargs import NativeGroupNormArgs, NativeLayerNormArgs
+
+
+ @trace_graph_diff_on_pass
+ class DecomposeGroupNorm(PassBase):
+     """
+     This pass decomposes Group normalization operators.
+
+     LayerNorm is Group normalization with group=1.
+
+     [LayerNorm, GroupNorm]
+
+     The two normalizations result in the same nodes but have different normalization shapes.
+
+     [before]
+
+     input (tensor, normalized_shape, weight, bias, eps)
+        |
+     NativeLayerNorm or GroupNorm
+        |
+     output
+
+     [after]
+
+                      input
+                     (tensor)
+                        |
+                     reshape
+                        |
+               +--------+--------+
+               |                 |
+             mean                |
+               |                 |
+            reshape              |
+               |                 |
+               +------>sub<------+
+                        |
+               +--------+--------+
+               |                 |
+              pow                |
+      input    |                 |
+      (eps)   mean               |
+        |      |                 |
+        +---->add                |
+               |                 |      input
+             rsqrt               |     (weight)
+               |                 |        |        input
+            reshape              |     reshape     (bias)
+               |                 |        |           |
+               +------->mul<-----+      expand     reshape
+                         |                |           |
+                         +------>mul<-----+         expand
+                                  |                    |
+                                  +------->add<--------+
+                                            |
+                                         reshape
+                                            |
+                                          output
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         gm = exported_program.graph_module
+         graph: torch.fx.Graph = gm.graph
+         modified = False
+
+         for node in graph.nodes:
+             if node.op != "call_function":
+                 continue
+
+             if node.target not in [
+                 torch.ops.aten.native_layer_norm.default,
+                 torch.ops.aten.native_group_norm.default,
+             ]:
+                 continue
+
+             if node.target == torch.ops.aten.native_layer_norm.default:
+                 ln_args = NativeLayerNormArgs(*node.args, **node.kwargs)
+                 x = ln_args.input
+                 normalized_shape = ln_args.normalized_shape
+                 weight = ln_args.weight
+                 bias = ln_args.bias
+                 eps = ln_args.eps
+
+                 if weight:
+                     weight_shape = extract_shape(weight)
+                     assert list(weight_shape) == normalized_shape
+                 if bias:
+                     bias_shape = extract_shape(bias)
+                     assert list(bias_shape) == normalized_shape
+
+                 x_val = x.meta.get("val")
+                 assert isinstance(x_val, torch.Tensor)
+                 x_shape = list(x_val.size())
+                 x_dim = len(x_shape)
+                 normalized_dim = len(normalized_shape)
+                 assert x_dim >= normalized_dim
+                 idx_normalize_start = x_dim - normalized_dim
+
+                 norm_size = math.prod(normalized_shape)
+                 layer_size = math.prod(x_shape[:idx_normalize_start])
+             elif node.target == torch.ops.aten.native_group_norm.default:
+                 gn_args = NativeGroupNormArgs(*node.args, **node.kwargs)
+                 x = gn_args.input
+                 weight = gn_args.weight
+                 bias = gn_args.bias
+                 N = gn_args.N
+                 C = gn_args.C
+                 HW = gn_args.HxW
+                 group = gn_args.group
+                 eps = gn_args.eps
+
+                 x_shape = list(extract_shape(x))
+                 assert len(x_shape) == 4 or len(x_shape) == 3
+                 assert x_shape[0] == N
+                 assert x_shape[1] == C
+
+                 assert C % group == 0
+                 norm_size = int((C / group) * HW)
+                 layer_size = N * group
+             else:
+                 assert False, "Unreachable"
+
+             pack_shape = [layer_size, norm_size]
+
+             with gm.graph.inserting_before(node):
+                 layer = graph.call_function(
+                     # Sometimes, `x` has a stride for NHWC, which can't be reshaped with `aten.view.default`.
+                     # TODO Find out how to handle such cases properly.
+                     torch.ops.aten.reshape.default,
+                     (x, pack_shape),
+                 )
+                 layer_mean = graph.call_function(
+                     torch.ops.aten.mean.dim,
+                     (layer, [-1]),
+                 )
+                 layer_mean_reshape = graph.call_function(
+                     torch.ops.aten.view.default,
+                     (layer_mean, [layer_size, 1]),
+                 )
+                 layer_deviation = graph.call_function(
+                     torch.ops.aten.sub.Tensor,
+                     (layer, layer_mean_reshape),
+                 )
+                 layer_sqr_diff = graph.call_function(
+                     torch.ops.aten.pow.Tensor_Scalar,
+                     (layer_deviation, 2),
+                 )
+                 var = graph.call_function(
+                     torch.ops.aten.mean.dim,
+                     (layer_sqr_diff, [-1]),
+                 )
+                 var_eps = graph.call_function(
+                     torch.ops.aten.add.Tensor,
+                     (var, eps),
+                 )
+                 rstd = graph.call_function(
+                     torch.ops.aten.rsqrt.default,
+                     (var_eps,),
+                 )
+                 rstd_reshape = graph.call_function(
+                     torch.ops.aten.view.default,
+                     (rstd, [layer_size, 1]),
+                 )
+                 layer_norm = graph.call_function(
+                     torch.ops.aten.mul.Tensor,
+                     (layer_deviation, rstd_reshape),
+                 )
+                 layer_norm = graph.call_function(
+                     torch.ops.aten.view.default,
+                     (layer_norm, x_shape),
+                 )
+
+                 # weight
+                 if weight:
+                     if node.target == torch.ops.aten.native_group_norm.default:
+                         weight_shape = extract_shape(weight)
+                         assert weight_shape[0] == C
+                         reshape_size = [1] * len(x_shape)
+                         reshape_size[1] = C
+                         weight = graph.call_function(
+                             torch.ops.aten.view.default,
+                             (weight, reshape_size),
+                         )
+                     layer_norm = graph.call_function(
+                         torch.ops.aten.mul.Tensor,
+                         (layer_norm, weight),
+                     )
+
+                 # bias
+                 if bias:
+                     if node.target == torch.ops.aten.native_group_norm.default:
+                         bias_shape = extract_shape(bias)
+                         assert bias_shape[0] == C
+                         reshape_size = [1] * len(x_shape)
+                         reshape_size[1] = C
+                         bias = graph.call_function(
+                             torch.ops.aten.view.default,
+                             (bias, reshape_size),
+                         )
+                     layer_norm = graph.call_function(
+                         torch.ops.aten.add.Tensor,
+                         (layer_norm, bias),
+                     )
+
+                 # Clear the last node's meta so that the replaced node's meta is propagated instead.
+                 layer_norm.meta = {}
+
+             # NOTE Why select the user `getitem` here?
+             # `native_layer_norm` and `native_group_norm` require a `getitem`
+             # to select the first output and discard the remaining unused outputs.
+             # To replace those operators, it's necessary to replace the corresponding
+             # `getitem` node as well.
+             get_item = next(iter(node.users))
+             assert (
+                 get_item.target == operator.getitem
+             ), "First user of native_group/layer_norm should be getitem"
+
+             get_item.replace_all_uses_with(layer_norm, propagate_meta=True)
+
+             modified = True
+
+         gm.graph.eliminate_dead_code()
+         gm.graph.lint()
+         gm.recompile()
+
+         return PassResult(modified)
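
The arithmetic of this decomposition can be sanity-checked in eager mode against `torch.nn.functional.group_norm`; the sketch below (hypothetical shapes, not part of the package) mirrors the reshape -> mean -> sub -> pow -> mean -> add(eps) -> rsqrt -> mul -> reshape chain that the pass inserts, plus the weight/bias reshapes used for group norm.

import torch
import torch.nn.functional as F

# Hypothetical shapes for illustration.
N, C, H, W, group, eps = 2, 8, 4, 4, 4, 1e-5
x = torch.randn(N, C, H, W)
weight = torch.randn(C)
bias = torch.randn(C)

# Same quantities the pass computes for native_group_norm.
norm_size = (C // group) * H * W
layer_size = N * group

layer = x.reshape(layer_size, norm_size)
mean = layer.mean(dim=-1).view(layer_size, 1)
deviation = layer - mean
rstd = torch.rsqrt(deviation.pow(2).mean(dim=-1) + eps).view(layer_size, 1)
out = (deviation * rstd).view(N, C, H, W)
out = out * weight.view(1, C, 1, 1) + bias.view(1, C, 1, 1)

# Should match the fused operator up to floating-point tolerance.
print(torch.allclose(out, F.group_norm(x, group, weight, bias, eps), atol=1e-5))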