tico 0.1.0.dev250411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +31 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +185 -0
- tico/passes/cast_mixed_type_args.py +186 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +151 -0
- tico/passes/convert_layout_op_to_reshape.py +84 -0
- tico/passes/convert_repeat_to_expand_copy.py +90 -0
- tico/passes/convert_to_relu6.py +180 -0
- tico/passes/decompose_addmm.py +127 -0
- tico/passes/decompose_batch_norm.py +198 -0
- tico/passes/decompose_fake_quantize.py +126 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
- tico/passes/decompose_group_norm.py +258 -0
- tico/passes/decompose_grouped_conv2d.py +202 -0
- tico/passes/decompose_slice_scatter.py +167 -0
- tico/passes/extract_dtype_kwargs.py +121 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +113 -0
- tico/passes/legalize_predefined_layout_operators.py +383 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
- tico/passes/lower_to_slice.py +112 -0
- tico/passes/merge_consecutive_cat.py +82 -0
- tico/passes/ops.py +75 -0
- tico/passes/remove_nop.py +85 -0
- tico/passes/remove_redundant_assert_nodes.py +50 -0
- tico/passes/remove_redundant_expand.py +70 -0
- tico/passes/remove_redundant_permute.py +102 -0
- tico/passes/remove_redundant_reshape.py +431 -0
- tico/passes/remove_redundant_slice.py +64 -0
- tico/passes/remove_redundant_to_copy.py +84 -0
- tico/passes/restore_linear.py +113 -0
- tico/passes/segment_index_select.py +143 -0
- tico/pt2_to_circle.py +101 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +264 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +232 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +142 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +112 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +123 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +181 -0
- tico/serialize/operators/op_copy.py +162 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +92 -0
- tico/serialize/operators/op_depthwise_conv2d.py +198 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +83 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +174 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +138 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +99 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +96 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +51 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +292 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +200 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +562 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +314 -0
- tico/utils/validate_args_kwargs.py +1114 -0
- tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
- tico-0.1.0.dev250411.dist-info/METADATA +17 -0
- tico-0.1.0.dev250411.dist-info/RECORD +196 -0
- tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
- tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
tico/passes/decompose_fake_quantize_tensor_qparams.py
@@ -0,0 +1,270 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Union
+
+if TYPE_CHECKING:
+    import torch._ops
+    import torch.fx
+import torch
+
+# To import torch.ops.quantized_decomposed related operator
+from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
+from torch.export import ExportedProgram
+
+from tico.utils import logging
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import (
+    trace_const_diff_on_pass,
+    trace_graph_diff_on_pass,
+)
+from tico.utils.validate_args_kwargs import FakeQuantizePerTensorTQParamArgs
+
+
+def get_quant_type(min: int, max: int) -> torch.dtype:
+    if min == 0 and max == 255:
+        return torch.uint8
+    if min == -32768 and max == 32767:
+        return torch.int16
+    if min == -32767 and max == 32767:
+        return torch.int16
+
+    raise RuntimeError("Not supported min/max values")
+
+
+def get_constant_from_tensor(
+    node: Union[torch.fx.Node, float], ep: ExportedProgram
+) -> Union[torch.fx.Node, float]:
+    """
+    There are some nodes that can do constant folding.
+    Case 1. With constant tensors
+    Case 2. With `torch.ones` or `torch.zeros`
+
+    Please refer to the below `DecomposeFakeQuantizeTensorQParams` docs for the detailed explanations.
+    """
+    if isinstance(node, float):
+        return node
+    assert isinstance(node.target, torch._ops.OpOverload)
+    if node.target.__name__ == "mul.Tensor":
+        assert len(node.args) == 2
+        x = get_constant_from_tensor(node.args[0], ep)  # type: ignore[arg-type]
+        y = get_constant_from_tensor(node.args[1], ep)  # type: ignore[arg-type]
+        return x * y  # type: ignore[operator]
+    if node.target.__name__ == "zeros.default":
+        assert len(node.args) == 1
+        assert node.args[0] == [1]
+        return 0
+    if node.target.__name__ == "ones.default":
+        assert len(node.args) == 1
+        assert node.args[0] == [1]
+        return 1
+    if node.target.__name__ == "view.default":
+        assert len(node.args) == 2
+        tensor, shape = node.args
+        assert shape == [-1]
+        return get_constant_from_tensor(tensor, ep)  # type: ignore[arg-type]
+    if node.target.__name__ == "_to_copy.default":
+        assert len(node.args) == 1
+        return get_constant_from_tensor(node.args[0], ep)  # type: ignore[arg-type]
+    if node.target.__name__ == "lift_fresh_copy.default":
+        assert len(node.args) == 1
+        assert isinstance(node.args[0], torch.fx.Node)
+        lifted_tensor: torch.fx.Node = node.args[0]
+        lifted_tensor_constants = ep.graph_signature.inputs_to_lifted_tensor_constants
+        assert lifted_tensor.name in lifted_tensor_constants
+        tensor_name = lifted_tensor_constants[lifted_tensor.name]
+        value = ep.constants[tensor_name].cpu().detach().numpy()
+        return value
+    if node.target.__name__ in ["detach.default", "detach_.default"]:
+        assert len(node.args) == 1
+        return get_constant_from_tensor(node.args[0], ep)  # type: ignore[arg-type]
+
+    raise RuntimeError(f"Not supported node {node.target.__name__}")
+
+
+@trace_const_diff_on_pass
+@trace_graph_diff_on_pass
+class DecomposeFakeQuantizeTensorQParams(PassBase):
+    """
+    Decompose fake quantize with tensor QParams operator to quant/dequant operators.
+    Otherwise, it can't be converted to the edge IR because fake quantize operator is not Aten Canonical.
+
+    As of now, we don't support the (de)quantize op that has scale/zp whose dtypes are tensors. They should be scalars.
+    But, fake quantize with tensor QParams can be decomposed only when those tensors can be removed by constant foldings.
+
+    We consider below cases for now.
+
+    [CASE 1] With constant tensors
+
+    s = torch.tensor(0.1)
+    zp = torch.tensor(0)
+    fq_enabled = torch.tensor(True)
+    x = torch._fake_quantize_per_tensor_affine_cachemask_tensor_qparams(
+        x, s, zp, fq_enabled, 0, 255
+    )
+
+    [Before pass]
+
+    def forward(self, c_lifted_tensor_0, c_lifted_tensor_1, c_lifted_tensor_2, x):
+        lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
+        lift_fresh_copy_1 = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_1); c_lifted_tensor_1 = None
+        lift_fresh_copy_2 = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_2); c_lifted_tensor_2 = None
+        _fake_quantize_per_tensor_affine_cachemask_tensor_qparams = torch.ops.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default(x, lift_fresh_copy, lift_fresh_copy_1, lift_fresh_copy_2, quant_min, quant_max); x = lift_fresh_copy = lift_fresh_copy_1 = lift_fresh_copy_2 = None
+        getitem = _fake_quantize_per_tensor_affine_cachemask_tensor_qparams[0]
+        getitem_1 = _fake_quantize_per_tensor_affine_cachemask_tensor_qparams[1]; _fake_quantize_per_tensor_affine_cachemask_tensor_qparams = None
+        return (getitem, getitem_1)
+
+    [After pass]
+
+    def forward(self, c_lifted_tensor_0, c_lifted_tensor_1, c_lifted_tensor_2, x):
+        lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
+        lift_fresh_copy_1 = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_1); c_lifted_tensor_1 = None
+        quantize_per_tensor_tensor = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(x, lift_fresh_copy, lift_fresh_copy_1, quant_min, quant_max, dtype = ${torch.dtype}); x = None
+        dequantize_per_tensor_tensor = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor(quantize_per_tensor_tensor, lift_fresh_copy, lift_fresh_copy_1, quant_min, quant_max, dtype = ${torch.dtype}); quantize_per_tensor_tensor = lift_fresh_copy = lift_fresh_copy_1 = None
+        return (dequantize_per_tensor_tensor,)
+
+    `s` and `zp` are tensors but they can be removed after constant foldings. When they are transformed to fx graph, they are
+    lifted as a placeholder and become an argument of the `aten.lift_fresh_copy`.
+
+
+    [CASE 2] With `torch.ones` or `torch.zeros`
+
+    n_bits=16
+    scale=torch.ones([1])
+    Qp = 2**(n_bits-1)-1
+    scale=scale*(1/Qp)
+    z = torch.fake_quantize_per_tensor_affine(x, scale, torch.zeros([1]).int().view(-1), -Qp, Qp)
+
+    `torch.ones([1])` or `torch.zeros([1])` is just number 1 or 0 but it is transformed to aten IR node, which prevents it from
+    being pre-calculated to the number.
+
+    For example, `n_bits * 1` would be just number 16 when the transformation, but `n_bits * torch.ones([1])`
+    would be `aten.Mul(16, aten.full)`, which is the reason why `torch.fake_quantize_per_tensor_affine` is trasnformed to
+    `aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams` whose scale/zp argument types are tensors rather than scalars.
+
+    So, if we manually compute such things like `n_bits * torch.ones([1])`, we can decompose fake quantize with qparam tensors.
+
+    [Before pass]
+
+    def forward(self, x):
+        ones = torch.ops.aten.ones.default([1], device = device(type='cpu'), pin_memory = False)
+        mul = torch.ops.aten.mul.Tensor(ones, 3.051850947599719e-05); ones = None
+        zeros = torch.ops.aten.zeros.default([1], device = device(type='cpu'), pin_memory = False)
+        _to_copy = torch.ops.aten._to_copy.default(zeros, dtype = torch.int32); zeros = None
+        view = torch.ops.aten.view.default(_to_copy, [-1]); _to_copy = None
+        ones_1 = torch.ops.aten.ones.default([1], dtype = torch.int64, layout = torch.strided, device = device(type='cpu'))
+        _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_default = torch.ops.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default(x, mul, view, ones_1, -32767, 32767); x = mul = view = ones_1 = None
+        getitem = _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_default[0]; _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_default = None
+        return (getitem,)
+
+    [After pass]
+    def forward(self, x: "f32[4]"):
+        quantize_per_tensor_default = torch.ops.quantized_decomposed.quantize_per_tensor.default(x, 3.051850947599719e-05, 0, -32767, 32767, dtype = torch.int16); x = None
+        dequantize_per_tensor_default: "f32[4]" = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default, 3.051850947599719e-05, 0, -32767, 32767, dtype = torch.int16); quantize_per_tensor_default = None
+        return (dequantize_per_tensor_default,)
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        modified = False
+
+        gm = exported_program.graph_module
+        qd = torch.ops.quantized_decomposed  # type: ignore[return]
+        for node in gm.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target
+                == torch.ops.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default
+            ):
+                # tensor, scale, zero_p, fake_quant_enabled, quant_min, quant_max
+                # TODO Support `fake_quant_enabled`
+                assert len(node.args) == 6
+                tensor, s, zp, _, quant_min, quant_max = node.args
+                # Get constant tensors
+                ep = exported_program
+                s_value = get_constant_from_tensor(s, ep)
+                zp_value = get_constant_from_tensor(zp, ep)
+                # This op has one user: `getitem` for the output.
+                # TODO Investigate why the op is generated like this.
+                # node.users = {getitem: None}
+                get_item, *mask = node.users.keys()
+                # assert len(mask) == 0, "Not supported yet."
+                quant_kwargs = {
+                    **node.kwargs,
+                    **{"dtype": get_quant_type(quant_min, quant_max)},
+                }
+                with gm.graph.inserting_before(node):
+                    quant = gm.graph.call_function(
+                        qd.quantize_per_tensor.default,
+                        args=(tensor, s_value, zp_value, quant_min, quant_max),
+                        kwargs=quant_kwargs,
+                    )
+                    dequant = gm.graph.call_function(
+                        qd.dequantize_per_tensor.default,
+                        args=(quant, *quant.args[1:]),
+                        kwargs=quant.kwargs,
+                    )
+                # Not set meta for propagating replacing get_item's meta.
+                get_item.replace_all_uses_with(dequant, propagate_meta=True)
+                # If `mask` can be graph output, which prevents `eliminate_dead_code()` from eliminating `mask`.
+                # So, let's remove `mask` from the output.args first.
+                # mask_user(output).args == (dequantize_per_tensor.tensor, mask)
+                if mask:
+                    len(mask) == 1
+                    mask_user = list(mask[0].users.keys())[0]
+                    assert len(mask_user.args) == 1
+                    mask_user.args = ((mask_user.args[0][0],),)
+                modified = True
+            if (
+                node.op == "call_function"
+                and node.target
+                == torch.ops.aten.fake_quantize_per_tensor_affine.tensor_qparams
+            ):
+                fq_args = FakeQuantizePerTensorTQParamArgs(*node.args, **node.kwargs)
+                tensor = fq_args.input
+                s = fq_args.scale
+                zp = fq_args.zero_point
+                quant_min = fq_args.quant_min
+                quant_max = fq_args.quant_max
+
+                # Get constant tensors
+                ep = exported_program
+                s_value = get_constant_from_tensor(s, ep)
+                zp_value = get_constant_from_tensor(zp, ep)
+                quant_kwargs = {
+                    **node.kwargs,
+                    **{"dtype": get_quant_type(quant_min, quant_max)},
+                }
+                with gm.graph.inserting_before(node):
+                    quant = gm.graph.call_function(
+                        qd.quantize_per_tensor.default,
+                        args=(tensor, s_value, zp_value, quant_min, quant_max),
+                        kwargs=quant_kwargs,
+                    )
+                    dequant = gm.graph.call_function(
+                        qd.dequantize_per_tensor.default,
+                        args=(quant, *quant.args[1:]),
+                        kwargs=quant.kwargs,
+                    )
+                # Not set meta for propagating replacing get_item's meta.
+                node.replace_all_uses_with(dequant, propagate_meta=True)
+                modified = True
+
+        gm.graph.eliminate_dead_code()
+        gm.graph.lint()
+        gm.recompile()
+
+        return PassResult(modified)
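To make the rewrite concrete, the sketch below (not part of the package) reproduces CASE 2 from the docstring above and applies the pass directly to the exported program. The module name, tensor values, input shape, and the standalone `.call()` invocation are illustrative assumptions; in the released wheel the pass is presumably driven by tico's conversion pipeline (tico/utils/convert.py) rather than invoked by hand.

import torch
from torch.export import export

from tico.passes.decompose_fake_quantize_tensor_qparams import (
    DecomposeFakeQuantizeTensorQParams,
)


class FakeQuantExample(torch.nn.Module):  # hypothetical module, for illustration only
    def forward(self, x):
        n_bits = 16
        Qp = 2 ** (n_bits - 1) - 1
        # Tensor-valued scale/zero-point force the tensor-qparams variant of fake quantize.
        scale = torch.ones([1]) * (1 / Qp)
        zero_point = torch.zeros([1]).int().view(-1)
        return torch.fake_quantize_per_tensor_affine(x, scale, zero_point, -Qp, Qp)


ep = export(FakeQuantExample(), (torch.randn(4),))
# The exported graph contains aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams
# (see [Before pass] in the docstring above).
DecomposeFakeQuantizeTensorQParams().call(ep)
# After the pass, the graph holds quantized_decomposed.quantize_per_tensor /
# dequantize_per_tensor with scalar scale and zero-point (see [After pass] above).
print(ep.graph_module.graph)

With quant_min/quant_max of -32767/32767, get_quant_type in the file above selects torch.int16 for the inserted quantize/dequantize pair.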
tico/passes/decompose_group_norm.py
@@ -0,0 +1,258 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import operator
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch.export import ExportedProgram
+
+from tico.serialize.circle_mapping import extract_shape
+from tico.utils import logging
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.validate_args_kwargs import NativeGroupNormArgs, NativeLayerNormArgs
+
+
+@trace_graph_diff_on_pass
+class DecomposeGroupNorm(PassBase):
+    """
+    This pass decomposes Group normalization operators.
+
+    LayerNorm is group=1 Group normalization.
+
+    [LayerNorm, GroupNorm]
+
+    Two normalzations result in same nodes but have different normalization shapes.
+
+    [before]
+
+        input (tensor, normalized_shape, weight, bias, eps)
+          |
+        NativeLayerNorm or GroupNorm
+          |
+        output
+
+    [after]
+
+          input
+         (tensor)
+            |
+         reshape
+            |
+            +------------+
+            |            |
+           mean          |
+            |            |
+         reshape         |
+            |            |
+            + --->sub<---+
+                   |
+            +------+
+            |      |
+           pow     |
+ input      |      |
+ (eps)     mean    |
+   |        |      |
+   +-->add<-+      |
+        |          |     input
+      rsqrt        |    (weight)
+        |          |       |      input
+     reshape       |    reshape   (bias)
+        |          |       |        |
+        +--->mul<--+     expand  reshape
+               |           |        |
+               +---->mul<--+      expand
+                      |             |
+                      +----->add<---+
+                              |
+                           reshape
+                              |
+                            output
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        gm = exported_program.graph_module
+        graph: torch.fx.Graph = gm.graph
+        modified = False
+
+        for node in graph.nodes:
+            if node.op != "call_function":
+                continue
+
+            if node.target not in [
+                torch.ops.aten.native_layer_norm.default,
+                torch.ops.aten.native_group_norm.default,
+            ]:
+                continue
+
+            if node.target == torch.ops.aten.native_layer_norm.default:
+                ln_args = NativeLayerNormArgs(*node.args, **node.kwargs)
+                x = ln_args.input
+                normalized_shape = ln_args.normalized_shape
+                weight = ln_args.weight
+                bias = ln_args.bias
+                eps = ln_args.eps
+
+                if weight:
+                    weight_shape = extract_shape(weight)
+                    assert list(weight_shape) == normalized_shape
+                if bias:
+                    bias_shape = extract_shape(bias)
+                    assert list(bias_shape) == normalized_shape
+
+                x_val = x.meta.get("val")
+                assert isinstance(x_val, torch.Tensor)
+                x_shape = list(x_val.size())
+                x_dim = len(x_shape)
+                normalized_dim = len(normalized_shape)
+                assert x_dim >= normalized_dim
+                idx_normalize_start = x_dim - normalized_dim
+
+                norm_size = math.prod(normalized_shape)
+                layer_size = math.prod(x_shape[:idx_normalize_start])
+            elif node.target == torch.ops.aten.native_group_norm.default:
+                gn_args = NativeGroupNormArgs(*node.args, **node.kwargs)
+                x = gn_args.input
+                weight = gn_args.weight
+                bias = gn_args.bias
+                N = gn_args.N
+                C = gn_args.C
+                HW = gn_args.HxW
+                group = gn_args.group
+                eps = gn_args.eps
+
+                x_shape = list(extract_shape(x))
+                assert len(x_shape) == 4 or len(x_shape) == 3
+                assert x_shape[0] == N
+                assert x_shape[1] == C
+
+                assert C % group == 0
+                norm_size = int((C / group) * HW)
+                layer_size = N * group
+            else:
+                assert False, "Unreachable"
+
+            pack_shape = [layer_size, norm_size]
+
+            with gm.graph.inserting_before(node):
+                layer = graph.call_function(
+                    # Sometimes, `x` has a stride for NHWC, which can't be reshaped with `aten.view.default`.
+                    # TODO Find out how to process such case properly.
+                    torch.ops.aten.reshape.default,
+                    (x, pack_shape),
+                )
+                layer_mean = graph.call_function(
+                    torch.ops.aten.mean.dim,
+                    (layer, [-1]),
+                )
+                layer_mean_reshape = graph.call_function(
+                    torch.ops.aten.view.default,
+                    (layer_mean, [layer_size, 1]),
+                )
+                layer_deviation = graph.call_function(
+                    torch.ops.aten.sub.Tensor,
+                    (layer, layer_mean_reshape),
+                )
+                layer_sqr_diff = graph.call_function(
+                    torch.ops.aten.pow.Tensor_Scalar,
+                    (layer_deviation, 2),
+                )
+                var = graph.call_function(
+                    torch.ops.aten.mean.dim,
+                    (layer_sqr_diff, [-1]),
+                )
+                var_eps = graph.call_function(
+                    torch.ops.aten.add.Tensor,
+                    (var, eps),
+                )
+                rstd = graph.call_function(
+                    torch.ops.aten.rsqrt.default,
+                    (var_eps,),
+                )
+                rstd_reshape = graph.call_function(
+                    torch.ops.aten.view.default,
+                    (rstd, [layer_size, 1]),
+                )
+                layer_norm = graph.call_function(
+                    torch.ops.aten.mul.Tensor,
+                    (layer_deviation, rstd_reshape),
+                )
+                layer_norm = graph.call_function(
+                    torch.ops.aten.view.default,
+                    (layer_norm, x_shape),
+                )
+
+                # weight
+                if weight:
+                    if node.target == torch.ops.aten.native_group_norm.default:
+                        weight_shape = extract_shape(weight)
+                        assert weight_shape[0] == C
+                        reshape_size = [1] * len(x_shape)
+                        reshape_size[1] = C
+                        weight = graph.call_function(
+                            torch.ops.aten.view.default,
+                            (weight, reshape_size),
+                        )
+                    layer_norm = graph.call_function(
+                        torch.ops.aten.mul.Tensor,
+                        (layer_norm, weight),
+                    )
+
+                # bias
+                if bias:
+                    if node.target == torch.ops.aten.native_group_norm.default:
+                        bias_shape = extract_shape(bias)
+                        assert bias_shape[0] == C
+                        reshape_size = [1] * len(x_shape)
+                        reshape_size[1] = C
+                        bias = graph.call_function(
+                            torch.ops.aten.view.default,
+                            (bias, reshape_size),
+                        )
+                    layer_norm = graph.call_function(
+                        torch.ops.aten.add.Tensor,
+                        (layer_norm, bias),
+                    )
+
+            # Reset last node's meta for propagating replacing node's meta.
+            layer_norm.meta = {}
+
+            # NOTE Why select user `getitem` here?
+            # `native_layer_norm` and `native_group_norm` requires `getitem`
+            # to select the first output and discard the rest unused outputs.
+            # To replace those operators, it's necessary to replace the corresponding
+            # `getitem` node as well.
+            get_item = next(iter(node.users))
+            assert (
+                get_item.target == operator.getitem
+            ), "First user of native_group/layer_norm should be getitem"
+
+            get_item.replace_all_uses_with(layer_norm, propagate_meta=True)
+
+            modified = True
+
+        gm.graph.eliminate_dead_code()
+        gm.graph.lint()
+        gm.recompile()
+
+        return PassResult(modified)
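The rewrite above is purely structural, so its arithmetic can be checked independently of the graph machinery. The standalone snippet below (illustrative, not part of the package; shapes, group count, and parameter values are arbitrary) mirrors the decomposition in eager PyTorch and compares it against torch.nn.functional.group_norm.

import torch

# Arbitrary example sizes (assumptions for this sketch only).
N, C, H, W, group, eps = 2, 8, 4, 4, 4, 1e-5
x = torch.randn(N, C, H, W)
weight = torch.randn(C)
bias = torch.randn(C)

# Decomposed form: pack to [layer_size, norm_size], normalize per row,
# then apply the affine parameters reshaped to [1, C, 1, 1] as the pass does.
layer_size = N * group
norm_size = (C // group) * H * W
layer = x.reshape(layer_size, norm_size)
mean = layer.mean(dim=-1).reshape(layer_size, 1)
deviation = layer - mean
var = deviation.pow(2).mean(dim=-1)
rstd = (var + eps).rsqrt().reshape(layer_size, 1)
out = (deviation * rstd).reshape(N, C, H, W)
out = out * weight.view(1, C, 1, 1) + bias.view(1, C, 1, 1)

# Reference result from the fused operator.
expected = torch.nn.functional.group_norm(x, group, weight, bias, eps)
torch.testing.assert_close(out, expected)

The native_layer_norm branch feeds the same node sequence, with layer_size taken from the leading (unnormalized) dimensions and norm_size from normalized_shape, which is why the docstring treats LayerNorm as the group=1 case.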