tico-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
tico/passes/decompose_fake_quantize_tensor_qparams.py
@@ -0,0 +1,294 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING, Union
+
+ if TYPE_CHECKING:
+     import torch._ops
+     import torch.fx
+ import torch
+ from torch._export.utils import (
+     get_buffer,
+     get_lifted_tensor_constant,
+     is_buffer,
+     is_lifted_tensor_constant,
+ )
+
+ # To import torch.ops.quantized_decomposed related operator
+ from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
+ from torch.export import ExportedProgram
+
+ from tico.utils import logging
+ from tico.utils.graph import create_node
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import (
+     trace_const_diff_on_pass,
+     trace_graph_diff_on_pass,
+ )
+ from tico.utils.validate_args_kwargs import FakeQuantizePerTensorTQParamArgs
+
+
+ def get_quant_type(min: int, max: int) -> torch.dtype:
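+     # Descriptive note: maps a (quant_min, quant_max) range to the torch dtype that is
+     # passed as the `dtype` kwarg of the quantize/dequantize nodes created below,
+     # e.g. (0, 255) -> torch.uint8 and (-32767, 32767) -> torch.int16. Since torch
+     # can't represent uint4 (per the comment below), (0, 15) also returns torch.uint8
+     # and relies on quant_min/quant_max downstream.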
+     if min == 0 and max == 15:
+         # torch can't represent "uint4".
+         # Let's set torch.uint8 and infer dtype with quant_min/quant_max instead.
+         return torch.uint8
+     if min == 0 and max == 255:
+         return torch.uint8
+     if min == -32768 and max == 32767:
+         return torch.int16
+     if min == -32767 and max == 32767:
+         return torch.int16
+
+     raise RuntimeError("Not supported min/max values")
+
+
+ def get_constant_from_tensor(
+     node: Union[torch.fx.Node, float], ep: ExportedProgram
+ ) -> Union[torch.fx.Node, float]:
+     """
+     There are some nodes that can be constant-folded.
+     Case 1. With constant tensors
+     Case 2. With `torch.ones` or `torch.zeros`
+
+     Please refer to the `DecomposeFakeQuantizeTensorQParams` docstring below for detailed explanations.
+     """
+     if isinstance(node, float):
+         return node
+     if is_buffer(ep, node):
+         buf = get_buffer(ep, node)
+         assert isinstance(buf, torch.Tensor)
+         return buf.item()
+     elif is_lifted_tensor_constant(ep, node):
+         lifted = get_lifted_tensor_constant(ep, node)
+         assert isinstance(lifted, torch.Tensor)
+         return lifted.item()
+     assert isinstance(node.target, torch._ops.OpOverload)
+     if node.target.__name__ == "mul.Tensor":
+         assert len(node.args) == 2
+         x = get_constant_from_tensor(node.args[0], ep)  # type: ignore[arg-type]
+         y = get_constant_from_tensor(node.args[1], ep)  # type: ignore[arg-type]
+         return x * y  # type: ignore[operator]
+     if node.target.__name__ == "zeros.default":
+         assert len(node.args) == 1
+         assert node.args[0] == [1]
+         return 0
+     if node.target.__name__ == "ones.default":
+         assert len(node.args) == 1
+         assert node.args[0] == [1]
+         return 1
+     if node.target.__name__ == "view.default":
+         assert len(node.args) == 2
+         tensor, shape = node.args
+         assert shape == [-1]
+         return get_constant_from_tensor(tensor, ep)  # type: ignore[arg-type]
+     if node.target.__name__ == "_to_copy.default":
+         assert len(node.args) == 1
+         return get_constant_from_tensor(node.args[0], ep)  # type: ignore[arg-type]
+     if node.target.__name__ == "lift_fresh_copy.default":
+         assert len(node.args) == 1
+         assert isinstance(node.args[0], torch.fx.Node)
+         lifted_tensor: torch.fx.Node = node.args[0]
+         lifted_tensor_constants = ep.graph_signature.inputs_to_lifted_tensor_constants
+         assert lifted_tensor.name in lifted_tensor_constants
+         tensor_name = lifted_tensor_constants[lifted_tensor.name]
+         value = ep.constants[tensor_name].item()
+         return value
+     if node.target.__name__ in ["detach.default", "detach_.default"]:
+         assert len(node.args) == 1
+         return get_constant_from_tensor(node.args[0], ep)  # type: ignore[arg-type]
+
+     raise RuntimeError(f"Not supported node {node.target.__name__}")
+
+
+ @trace_const_diff_on_pass
+ @trace_graph_diff_on_pass
+ class DecomposeFakeQuantizeTensorQParams(PassBase):
+     """
+     Decompose fake quantize with tensor QParams operators into quantize/dequantize operators.
+     Otherwise, it can't be converted to the edge IR, because the fake quantize operator is not ATen canonical.
+
+     As of now, we don't support (de)quantize ops whose scale/zp are tensors; they should be scalars.
+     However, fake quantize with tensor QParams can be decomposed only when those tensors can be removed by constant folding.
+
+     We consider the cases below for now.
+
+     [CASE 1] With constant tensors
+
+         s = torch.tensor(0.1)
+         zp = torch.tensor(0)
+         fq_enabled = torch.tensor(True)
+         x = torch._fake_quantize_per_tensor_affine_cachemask_tensor_qparams(
+             x, s, zp, fq_enabled, 0, 255
+         )
+
+     [Before pass]
+
+         def forward(self, c_lifted_tensor_0, c_lifted_tensor_1, c_lifted_tensor_2, x):
+             lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
+             lift_fresh_copy_1 = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_1); c_lifted_tensor_1 = None
+             lift_fresh_copy_2 = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_2); c_lifted_tensor_2 = None
+             _fake_quantize_per_tensor_affine_cachemask_tensor_qparams = torch.ops.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default(x, lift_fresh_copy, lift_fresh_copy_1, lift_fresh_copy_2, quant_min, quant_max); x = lift_fresh_copy = lift_fresh_copy_1 = lift_fresh_copy_2 = None
+             getitem = _fake_quantize_per_tensor_affine_cachemask_tensor_qparams[0]
+             getitem_1 = _fake_quantize_per_tensor_affine_cachemask_tensor_qparams[1]; _fake_quantize_per_tensor_affine_cachemask_tensor_qparams = None
+             return (getitem, getitem_1)
+
+     [After pass]
+
+         def forward(self, c_lifted_tensor_0, c_lifted_tensor_1, c_lifted_tensor_2, x):
+             lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
+             lift_fresh_copy_1 = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_1); c_lifted_tensor_1 = None
+             quantize_per_tensor_tensor = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(x, lift_fresh_copy, lift_fresh_copy_1, quant_min, quant_max, dtype = ${torch.dtype}); x = None
+             dequantize_per_tensor_tensor = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor(quantize_per_tensor_tensor, lift_fresh_copy, lift_fresh_copy_1, quant_min, quant_max, dtype = ${torch.dtype}); quantize_per_tensor_tensor = lift_fresh_copy = lift_fresh_copy_1 = None
+             return (dequantize_per_tensor_tensor,)
+
+     `s` and `zp` are tensors, but they can be removed by constant folding. When they are transformed to an fx graph,
+     they are lifted as placeholders and become arguments of `aten.lift_fresh_copy`.
+
+
+     [CASE 2] With `torch.ones` or `torch.zeros`
+
+         n_bits = 16
+         scale = torch.ones([1])
+         Qp = 2 ** (n_bits - 1) - 1
+         scale = scale * (1 / Qp)
+         z = torch.fake_quantize_per_tensor_affine(x, scale, torch.zeros([1]).int().view(-1), -Qp, Qp)
+
+     `torch.ones([1])` or `torch.zeros([1])` is just the number 1 or 0, but it is transformed to an aten IR node,
+     which prevents it from being pre-calculated into a plain number.
+
+     For example, `n_bits * 1` would simply be the number 16 after the transformation, but `n_bits * torch.ones([1])`
+     becomes `aten.Mul(16, aten.full)`, which is why `torch.fake_quantize_per_tensor_affine` is transformed to
+     `aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams`, whose scale/zp arguments are tensors rather than scalars.
+
+     So, if we manually compute expressions like `n_bits * torch.ones([1])`, we can decompose fake quantize with qparam tensors.
+
+     [Before pass]
+
+         def forward(self, x):
+             ones = torch.ops.aten.ones.default([1], device = device(type='cpu'), pin_memory = False)
+             mul = torch.ops.aten.mul.Tensor(ones, 3.051850947599719e-05); ones = None
+             zeros = torch.ops.aten.zeros.default([1], device = device(type='cpu'), pin_memory = False)
+             _to_copy = torch.ops.aten._to_copy.default(zeros, dtype = torch.int32); zeros = None
+             view = torch.ops.aten.view.default(_to_copy, [-1]); _to_copy = None
+             ones_1 = torch.ops.aten.ones.default([1], dtype = torch.int64, layout = torch.strided, device = device(type='cpu'))
+             _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_default = torch.ops.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default(x, mul, view, ones_1, -32767, 32767); x = mul = view = ones_1 = None
+             getitem = _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_default[0]; _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_default = None
+             return (getitem,)
+
+     [After pass]
+         def forward(self, x: "f32[4]"):
+             quantize_per_tensor_default = torch.ops.quantized_decomposed.quantize_per_tensor.default(x, 3.051850947599719e-05, 0, -32767, 32767, dtype = torch.int16); x = None
+             dequantize_per_tensor_default: "f32[4]" = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default, 3.051850947599719e-05, 0, -32767, 32767, dtype = torch.int16); quantize_per_tensor_default = None
+             return (dequantize_per_tensor_default,)
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         modified = False
+
+         gm = exported_program.graph_module
+         g = gm.graph
+         qd = torch.ops.quantized_decomposed  # type: ignore[return]
+         for node in gm.graph.nodes:
+             if node.op != "call_function":
+                 continue
+             if (
+                 node.target
+                 == torch.ops.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default
+             ):
+                 # tensor, scale, zero_p, fake_quant_enabled, quant_min, quant_max
+                 # TODO Support `fake_quant_enabled`
+                 assert len(node.args) == 6
+                 tensor, s, zp, _, quant_min, quant_max = node.args
+                 # Get constant tensors
+                 ep = exported_program
+                 s_value = get_constant_from_tensor(s, ep)
+                 zp_value = get_constant_from_tensor(zp, ep)
+                 # This op has one user: `getitem` for the output.
+                 # TODO Investigate why the op is generated like this.
+                 # node.users = {getitem: None}
+                 get_item, *mask = node.users.keys()
+                 # assert len(mask) == 0, "Not supported yet."
+                 quant_kwargs = {
+                     **node.kwargs,
+                     **{"dtype": get_quant_type(quant_min, quant_max)},
+                 }
+                 with gm.graph.inserting_before(node):
+                     quant = create_node(
+                         g,
+                         qd.quantize_per_tensor.default,
+                         args=(tensor, s_value, zp_value, quant_min, quant_max),
+                         kwargs=quant_kwargs,
+                         origin=node,
+                     )
+                     dequant = create_node(
+                         g,
+                         qd.dequantize_per_tensor.default,
+                         args=(quant, *quant.args[1:]),
+                         kwargs=quant.kwargs,
+                     )
+                 get_item.replace_all_uses_with(dequant, propagate_meta=True)
+                 # `mask` can be a graph output, which prevents `eliminate_dead_code()`
+                 # from eliminating it. So, remove `mask` from the output's args first.
+                 # mask_user(output).args == (dequantize_per_tensor.tensor, mask)
+                 if mask:
+                     assert len(mask) == 1
+                     mask_user = list(mask[0].users.keys())[0]
+                     assert len(mask_user.args) == 1
+                     mask_user.args = ((mask_user.args[0][0],),)
+                 modified = True
+             if (
+                 node.target
+                 == torch.ops.aten.fake_quantize_per_tensor_affine.tensor_qparams
+             ):
+                 fq_args = FakeQuantizePerTensorTQParamArgs(*node.args, **node.kwargs)
+                 tensor = fq_args.input
+                 s = fq_args.scale
+                 zp = fq_args.zero_point
+                 quant_min = fq_args.quant_min
+                 quant_max = fq_args.quant_max
+
+                 # Get constant tensors
+                 ep = exported_program
+                 s_value = get_constant_from_tensor(s, ep)
+                 zp_value = get_constant_from_tensor(zp, ep)
+                 quant_kwargs = {
+                     **node.kwargs,
+                     **{"dtype": get_quant_type(quant_min, quant_max)},
+                 }
+                 with gm.graph.inserting_before(node):
+                     quant = create_node(
+                         g,
+                         qd.quantize_per_tensor.default,
+                         args=(tensor, s_value, zp_value, quant_min, quant_max),
+                         kwargs=quant_kwargs,
+                         origin=node,
+                     )
+                     dequant = create_node(
+                         g,
+                         qd.dequantize_per_tensor.default,
+                         args=(quant, *quant.args[1:]),
+                         kwargs=quant.kwargs,
+                     )
+                 node.replace_all_uses_with(dequant, propagate_meta=True)
+                 modified = True
+
+         gm.graph.eliminate_dead_code()
+         gm.graph.lint()
+         gm.recompile()
+
+         return PassResult(modified)
tico/passes/decompose_group_norm.py
@@ -0,0 +1,275 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ import operator
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.serialize.circle_mapping import extract_shape
+ from tico.utils import logging
+ from tico.utils.graph import create_node
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.utils import is_target_node
+ from tico.utils.validate_args_kwargs import NativeGroupNormArgs, NativeLayerNormArgs
+
+
+ @trace_graph_diff_on_pass
+ class DecomposeGroupNorm(PassBase):
+     """
+     This pass decomposes Group normalization operators.
+
+     LayerNorm is Group normalization with group=1.
+
+     [LayerNorm, GroupNorm]
+
+     The two normalizations decompose into the same nodes but use different normalization shapes.
+
+     [before]
+
+         input (tensor, normalized_shape, weight, bias, eps)
+           |
+         NativeLayerNorm or GroupNorm
+           |
+         output
+
+     [after]
+
+              input
+             (tensor)
+                 |
+              reshape
+                 |
+           +-----+------+
+           |            |
+         mean           |
+           |            |
+        reshape         |
+           |            |
+           +---->sub<---+
+                  |
+            +-----+------+
+            |            |
+           pow           |
+            |            |
+           mean          |
+            |            |
+      input |            |
+      (eps) |            |
+        |   |            |
+        +-->add          |
+             |           |
+           rsqrt         |           input
+             |           |          (weight)
+          reshape        |             |          input
+             |           |          reshape       (bias)
+             +--->mul<---+             |             |
+                   |                 expand        reshape
+                   +-------->mul<------+             |
+                              |                    expand
+                              +---------->add<------+
+                                           |
+                                        reshape
+                                           |
+                                         output
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def _insert_norm(self, graph, tensor, eps, origin):
+         """
+         Insert `(tensor - mean) / sqrt(var + eps)` into the graph
+         and return the normalized tensor node.
+         """
+         mean = create_node(
+             graph,
+             torch.ops.aten.mean.dim,
+             (tensor, [-1]),
+             {"keepdim": True},
+             origin=origin,
+         )
+         deviation = create_node(
+             graph, torch.ops.aten.sub.Tensor, (tensor, mean), origin=origin
+         )
+         squared = create_node(
+             graph, torch.ops.aten.pow.Tensor_Scalar, (deviation, 2), origin=origin
+         )
+         var = create_node(
+             graph,
+             torch.ops.aten.mean.dim,
+             (squared, [-1]),
+             {"keepdim": True},
+             origin=origin,
+         )
+         inverse_std = create_node(
+             graph,
+             torch.ops.aten.rsqrt.default,
+             (create_node(graph, torch.ops.aten.add.Tensor, (var, eps), origin=origin),),
+             origin=origin,
+         )
+         return create_node(
+             graph, torch.ops.aten.mul.Tensor, (deviation, inverse_std), origin=origin
+         )
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         gm = exported_program.graph_module
+         graph: torch.fx.Graph = gm.graph
+         modified = False
+
+         for node in graph.nodes:
+             if not is_target_node(
+                 node,
+                 [
+                     torch.ops.aten.native_layer_norm.default,
+                     torch.ops.aten.native_group_norm.default,
+                 ],
+             ):
+                 continue
+
+             if node.target == torch.ops.aten.native_layer_norm.default:
+                 ln_args = NativeLayerNormArgs(*node.args, **node.kwargs)
+                 x = ln_args.input
+                 normalized_shape = ln_args.normalized_shape
+                 weight = ln_args.weight
+                 bias = ln_args.bias
+                 eps = ln_args.eps
+
+                 if weight:
+                     weight_shape = extract_shape(weight)
+                     assert list(weight_shape) == normalized_shape
+                 if bias:
+                     bias_shape = extract_shape(bias)
+                     assert list(bias_shape) == normalized_shape
+
+                 x_val = x.meta.get("val")
+                 assert isinstance(x_val, torch.Tensor)
+                 x_shape = list(x_val.size())
+                 x_dim = len(x_shape)
+                 normalized_dim = len(normalized_shape)
+                 assert x_dim >= normalized_dim
+                 idx_normalize_start = x_dim - normalized_dim
+
+                 norm_size = math.prod(normalized_shape)
+                 layer_size = math.prod(x_shape[:idx_normalize_start])
+             elif node.target == torch.ops.aten.native_group_norm.default:
+                 gn_args = NativeGroupNormArgs(*node.args, **node.kwargs)
+                 x = gn_args.input
+                 weight = gn_args.weight
+                 bias = gn_args.bias
+                 N = gn_args.N
+                 C = gn_args.C
+                 HW = gn_args.HxW
+                 group = gn_args.group
+                 eps = gn_args.eps
+
+                 x_shape = list(extract_shape(x))
+                 assert len(x_shape) == 4 or len(x_shape) == 3
+                 assert x_shape[0] == N
+                 assert x_shape[1] == C
+
+                 assert C % group == 0
+                 norm_size = int((C / group) * HW)
+                 layer_size = N * group
+             else:
+                 assert False, "Unreachable"
+
+             pack_shape = [layer_size, norm_size]
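+             # Worked example (illustrative): native_group_norm with N=2, C=8,
+             # HxW=16, group=4 gives norm_size = (8 / 4) * 16 = 32 and
+             # layer_size = 2 * 4 = 8, so the input is normalized as shape [8, 32].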
+
+             with gm.graph.inserting_before(node):
+                 # Branch only on whether a reshape is needed; the normalization is shared.
+                 if norm_size != x_shape[-1]:
+                     # Pack groups so that the last dimension equals norm_size.
+                     packed = create_node(
+                         graph,
+                         torch.ops.aten.reshape.default,
+                         (x, pack_shape),
+                         origin=node,
+                     )
+                     normed = self._insert_norm(graph, packed, eps, origin=node)
+                     # Restore the original shape after normalization.
+                     layer_norm = create_node(
+                         graph,
+                         torch.ops.aten.reshape.default,
+                         (normed, x_shape),
+                         origin=node,
+                     )
+                 else:
+                     # The input already has norm_size in the last dimension.
+                     layer_norm = self._insert_norm(graph, x, eps, origin=node)
+
+                 # weight
+                 if weight:
+                     if node.target == torch.ops.aten.native_group_norm.default:
+                         weight_shape = extract_shape(weight)
+                         assert weight_shape[0] == C
+                         reshape_size = [1] * len(x_shape)
+                         reshape_size[1] = C
+                         weight = create_node(
+                             graph,
+                             torch.ops.aten.view.default,
+                             (weight, reshape_size),
+                             origin=node,
+                         )
+                     layer_norm = create_node(
+                         graph,
+                         torch.ops.aten.mul.Tensor,
+                         (layer_norm, weight),
+                         origin=node,
+                     )
+
+                 # bias
+                 if bias:
+                     if node.target == torch.ops.aten.native_group_norm.default:
+                         bias_shape = extract_shape(bias)
+                         assert bias_shape[0] == C
+                         reshape_size = [1] * len(x_shape)
+                         reshape_size[1] = C
+                         bias = create_node(
+                             graph,
+                             torch.ops.aten.view.default,
+                             (bias, reshape_size),
+                             origin=node,
+                         )
+                     layer_norm = create_node(
+                         graph,
+                         torch.ops.aten.add.Tensor,
+                         (layer_norm, bias),
+                     )
+                 # Reset the last node's meta so that the replaced node's meta can be propagated.
+                 layer_norm.meta = {}
+
+             # NOTE Why select the `getitem` user here?
+             # `native_layer_norm` and `native_group_norm` require `getitem`
+             # to select the first output and discard the remaining unused outputs.
+             # To replace those operators, it's necessary to replace the corresponding
+             # `getitem` node as well.
+             get_item = next(iter(node.users))
+             assert (
+                 get_item.target == operator.getitem
+             ), "First user of native_group/layer_norm should be getitem"
+             get_item.replace_all_uses_with(layer_norm, propagate_meta=True)
+
+             modified = True
+
+         gm.graph.eliminate_dead_code()
+         gm.graph.lint()
+         gm.recompile()
+
+         return PassResult(modified)
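+
+ # Usage sketch (an assumption for illustration, not part of the packaged file):
+ # this pass is normally driven by tico's own conversion pipeline, but it can also
+ # be invoked directly on an exported program, e.g.
+ #
+ #   from tico.passes.decompose_group_norm import DecomposeGroupNorm
+ #
+ #   ep = torch.export.export(model, example_inputs)
+ #   result = DecomposeGroupNorm().call(ep)  # returns a PassResult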