tico 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +42 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/quantize_bias.py +123 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +191 -0
- tico/passes/cast_mixed_type_args.py +187 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +160 -0
- tico/passes/convert_layout_op_to_reshape.py +85 -0
- tico/passes/convert_repeat_to_expand_copy.py +89 -0
- tico/passes/convert_to_relu6.py +181 -0
- tico/passes/decompose_addmm.py +124 -0
- tico/passes/decompose_batch_norm.py +192 -0
- tico/passes/decompose_fake_quantize.py +134 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
- tico/passes/decompose_group_norm.py +275 -0
- tico/passes/decompose_grouped_conv2d.py +209 -0
- tico/passes/decompose_slice_scatter.py +169 -0
- tico/passes/extract_dtype_kwargs.py +122 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +108 -0
- tico/passes/legalize_predefined_layout_operators.py +386 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
- tico/passes/lower_to_slice.py +230 -0
- tico/passes/merge_consecutive_cat.py +80 -0
- tico/passes/ops.py +78 -0
- tico/passes/remove_nop.py +84 -0
- tico/passes/remove_redundant_assert_nodes.py +51 -0
- tico/passes/remove_redundant_expand.py +66 -0
- tico/passes/remove_redundant_permute.py +122 -0
- tico/passes/remove_redundant_reshape.py +436 -0
- tico/passes/remove_redundant_slice.py +62 -0
- tico/passes/remove_redundant_to_copy.py +86 -0
- tico/passes/restore_linear.py +115 -0
- tico/passes/segment_index_select.py +145 -0
- tico/pt2_to_circle.py +105 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +319 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +240 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_abs.py +53 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +150 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +192 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +126 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +186 -0
- tico/serialize/operators/op_copy.py +164 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +95 -0
- tico/serialize/operators/op_depthwise_conv2d.py +199 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_leaky_relu.py +60 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +86 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_dim.py +70 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +177 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +141 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +100 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +99 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +94 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +296 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +282 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/mx/__init__.py +1 -0
- tico/utils/mx/elemwise_ops.py +267 -0
- tico/utils/mx/formats.py +125 -0
- tico/utils/mx/mx_ops.py +270 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +609 -0
- tico/utils/serialize.py +42 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +406 -0
- tico/utils/validate_args_kwargs.py +1149 -0
- tico-0.1.0.dist-info/LICENSE +241 -0
- tico-0.1.0.dist-info/METADATA +354 -0
- tico-0.1.0.dist-info/RECORD +206 -0
- tico-0.1.0.dist-info/WHEEL +5 -0
- tico-0.1.0.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,294 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING, Union
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from torch._export.utils import (
|
22
|
+
get_buffer,
|
23
|
+
get_lifted_tensor_constant,
|
24
|
+
is_buffer,
|
25
|
+
is_lifted_tensor_constant,
|
26
|
+
)
|
27
|
+
|
28
|
+
# To import torch.ops.quantized_decomposed related operator
|
29
|
+
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
|
30
|
+
from torch.export import ExportedProgram
|
31
|
+
|
32
|
+
from tico.utils import logging
|
33
|
+
from tico.utils.graph import create_node
|
34
|
+
from tico.utils.passes import PassBase, PassResult
|
35
|
+
from tico.utils.trace_decorators import (
|
36
|
+
trace_const_diff_on_pass,
|
37
|
+
trace_graph_diff_on_pass,
|
38
|
+
)
|
39
|
+
from tico.utils.validate_args_kwargs import FakeQuantizePerTensorTQParamArgs
|
40
|
+
|
41
|
+
|
42
|
+
def get_quant_type(min: int, max: int) -> torch.dtype:
    """
    Return the torch dtype used to store values quantized into [min, max].

    Supported ranges:
      - [0, 15] and [0, 255]                    -> torch.uint8
      - [-32768, 32767] and [-32767, 32767]     -> torch.int16

    Raises:
        RuntimeError: if the (min, max) pair is not one of the ranges above.
    """
    # torch can't represent "uint4", so the 4-bit range [0, 15] is reported as
    # torch.uint8; the real bit width is inferred from quant_min/quant_max instead.
    supported_ranges = (
        (0, 15, torch.uint8),
        (0, 255, torch.uint8),
        (-32768, 32767, torch.int16),
        (-32767, 32767, torch.int16),
    )
    for lower, upper, dtype in supported_ranges:
        if min == lower and max == upper:
            return dtype

    raise RuntimeError("Not supported min/max values")
|
55
|
+
|
56
|
+
|
57
|
+
def get_constant_from_tensor(
|
58
|
+
node: Union[torch.fx.Node, float], ep: ExportedProgram
|
59
|
+
) -> Union[torch.fx.Node, float]:
|
60
|
+
"""
|
61
|
+
There are some nodes that can do constant folding.
|
62
|
+
Case 1. With constant tensors
|
63
|
+
Case 2. With `torch.ones.` or `torch.zeros`
|
64
|
+
|
65
|
+
Please refer to the below `DecomposeFakeQuantizeTensorQParams` docs for the detailed explanations.
|
66
|
+
"""
|
67
|
+
if isinstance(node, float):
|
68
|
+
return node
|
69
|
+
if is_buffer(ep, node):
|
70
|
+
buf = get_buffer(ep, node)
|
71
|
+
assert isinstance(buf, torch.Tensor)
|
72
|
+
return buf.item()
|
73
|
+
elif is_lifted_tensor_constant(ep, node):
|
74
|
+
lifted = get_lifted_tensor_constant(ep, node)
|
75
|
+
assert isinstance(lifted, torch.Tensor)
|
76
|
+
return lifted.item()
|
77
|
+
assert isinstance(node.target, torch._ops.OpOverload)
|
78
|
+
if node.target.__name__ == "mul.Tensor":
|
79
|
+
assert len(node.args) == 2
|
80
|
+
x = get_constant_from_tensor(node.args[0], ep) # type: ignore[arg-type]
|
81
|
+
y = get_constant_from_tensor(node.args[1], ep) # type: ignore[arg-type]
|
82
|
+
return x * y # type: ignore[operator]
|
83
|
+
if node.target.__name__ == "zeros.default":
|
84
|
+
assert len(node.args) == 1
|
85
|
+
assert node.args[0] == [1]
|
86
|
+
return 0
|
87
|
+
if node.target.__name__ == "ones.default":
|
88
|
+
assert len(node.args) == 1
|
89
|
+
assert node.args[0] == [1]
|
90
|
+
return 1
|
91
|
+
if node.target.__name__ == "view.default":
|
92
|
+
assert len(node.args) == 2
|
93
|
+
tensor, shape = node.args
|
94
|
+
assert shape == [-1]
|
95
|
+
return get_constant_from_tensor(tensor, ep) # type: ignore[arg-type]
|
96
|
+
if node.target.__name__ == "_to_copy.default":
|
97
|
+
assert len(node.args) == 1
|
98
|
+
return get_constant_from_tensor(node.args[0], ep) # type: ignore[arg-type]
|
99
|
+
if node.target.__name__ == "lift_fresh_copy.default":
|
100
|
+
assert len(node.args) == 1
|
101
|
+
assert isinstance(node.args[0], torch.fx.Node)
|
102
|
+
lifted_tensor: torch.fx.Node = node.args[0]
|
103
|
+
lifted_tensor_constants = ep.graph_signature.inputs_to_lifted_tensor_constants
|
104
|
+
assert lifted_tensor.name in lifted_tensor_constants
|
105
|
+
tensor_name = lifted_tensor_constants[lifted_tensor.name]
|
106
|
+
value = ep.constants[tensor_name].item()
|
107
|
+
return value
|
108
|
+
if node.target.__name__ in ["detach.default", "detach_.default"]:
|
109
|
+
assert len(node.args) == 1
|
110
|
+
return get_constant_from_tensor(node.args[0], ep) # type: ignore[arg-type]
|
111
|
+
|
112
|
+
raise RuntimeError(f"Not supported node {node.target.__name__}")
|
113
|
+
|
114
|
+
|
115
|
+
@trace_const_diff_on_pass
@trace_graph_diff_on_pass
class DecomposeFakeQuantizeTensorQParams(PassBase):
    """
    Decompose fake quantize with tensor QParams operator to quant/dequant operators.
    Otherwise, it can't be converted to the edge IR because fake quantize operator is not Aten Canonical.

    As of now, we don't support the (de)quantize op that has scale/zp whose dtypes are tensors. They should be scalars.
    But, fake quantize with tensor QParams can be decomposed only when those tensors can be removed by constant foldings.

    We consider below cases for now.

    [CASE 1] With constant tensors

        s = torch.tensor(0.1)
        zp = torch.tensor(0)
        fq_enabled = torch.tensor(True)
        x = torch._fake_quantize_per_tensor_affine_cachemask_tensor_qparams(
            x, s, zp, fq_enabled, 0, 255
        )

    [Before pass]

        def forward(self, c_lifted_tensor_0, c_lifted_tensor_1, c_lifted_tensor_2, x):
            lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0);  c_lifted_tensor_0 = None
            lift_fresh_copy_1 = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_1);  c_lifted_tensor_1 = None
            lift_fresh_copy_2 = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_2);  c_lifted_tensor_2 = None
            _fake_quantize_per_tensor_affine_cachemask_tensor_qparams = torch.ops.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default(x, lift_fresh_copy, lift_fresh_copy_1, lift_fresh_copy_2, quant_min, quant_max);  x = lift_fresh_copy = lift_fresh_copy_1 = lift_fresh_copy_2 = None
            getitem = _fake_quantize_per_tensor_affine_cachemask_tensor_qparams[0]
            getitem_1 = _fake_quantize_per_tensor_affine_cachemask_tensor_qparams[1];  _fake_quantize_per_tensor_affine_cachemask_tensor_qparams = None
            return (getitem, getitem_1)

    [After pass]

        The scale/zero-point tensors are folded to Python scalars and passed to the
        `.default` overloads; the nodes that only fed the fake quantize op are left
        to dead code elimination.

        def forward(self, c_lifted_tensor_0, c_lifted_tensor_1, c_lifted_tensor_2, x):
            quantize_per_tensor_default = torch.ops.quantized_decomposed.quantize_per_tensor.default(x, ${scale}, ${zero_point}, quant_min, quant_max, dtype = ${torch.dtype});  x = None
            dequantize_per_tensor_default = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default, ${scale}, ${zero_point}, quant_min, quant_max, dtype = ${torch.dtype});  quantize_per_tensor_default = None
            return (dequantize_per_tensor_default,)

    `s` and `zp` are tensors but they can be removed after constant foldings. When they are transformed to fx graph, they are
    lifted as a placeholder and become an argument of the `aten.lift_fresh_copy`.


    [CASE 2] With `torch.ones` or `torch.zeros`

        n_bits=16
        scale=torch.ones([1])
        Qp = 2**(n_bits-1)-1
        scale=scale*(1/Qp)
        z = torch.fake_quantize_per_tensor_affine(x, scale, torch.zeros([1]).int().view(-1), -Qp, Qp)

    `torch.ones([1])` or `torch.zeros([1])` is just number 1 or 0 but it is transformed to aten IR node, which prevents it from
    being pre-calculated to the number.

    For example, `n_bits * 1` would be just number 16 when the transformation, but `n_bits * torch.ones([1])`
    would be `aten.Mul(16, aten.full)`, which is the reason why `torch.fake_quantize_per_tensor_affine` is transformed to
    `aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams` whose scale/zp argument types are tensors rather than scalars.

    So, if we manually compute such things like `n_bits * torch.ones([1])`, we can decompose fake quantize with qparam tensors.

    [Before pass]

        def forward(self, x):
            ones = torch.ops.aten.ones.default([1], device = device(type='cpu'), pin_memory = False)
            mul = torch.ops.aten.mul.Tensor(ones, 3.051850947599719e-05);  ones = None
            zeros = torch.ops.aten.zeros.default([1], device = device(type='cpu'), pin_memory = False)
            _to_copy = torch.ops.aten._to_copy.default(zeros, dtype = torch.int32);  zeros = None
            view = torch.ops.aten.view.default(_to_copy, [-1]);  _to_copy = None
            ones_1 = torch.ops.aten.ones.default([1], dtype = torch.int64, layout = torch.strided, device = device(type='cpu'))
            _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_default = torch.ops.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default(x, mul, view, ones_1, -32767, 32767);  x = mul = view = ones_1 = None
            getitem = _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_default[0];  _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_default = None
            return (getitem,)

    [After pass]

        def forward(self, x: "f32[4]"):
            quantize_per_tensor_default = torch.ops.quantized_decomposed.quantize_per_tensor.default(x, 3.051850947599719e-05, 0, -32767, 32767, dtype = torch.int16);  x = None
            dequantize_per_tensor_default: "f32[4]" = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default, 3.051850947599719e-05, 0, -32767, 32767, dtype = torch.int16);  quantize_per_tensor_default = None
            return (dequantize_per_tensor_default,)
    """

    def __init__(self):
        super().__init__()

    def _insert_quant_dequant(
        self, graph, node, tensor, s_value, zp_value, quant_min, quant_max
    ):
        """
        Insert a quantize_per_tensor/dequantize_per_tensor pair before `node`
        and return the dequantize node.

        The pair uses the scalar `s_value`/`zp_value`, forwards `node.kwargs`,
        and adds a `dtype` kwarg derived from (quant_min, quant_max).
        """
        qd = torch.ops.quantized_decomposed  # type: ignore[return]
        quant_kwargs = {
            **node.kwargs,
            "dtype": get_quant_type(quant_min, quant_max),
        }
        with graph.inserting_before(node):
            quant = create_node(
                graph,
                qd.quantize_per_tensor.default,
                args=(tensor, s_value, zp_value, quant_min, quant_max),
                kwargs=quant_kwargs,
                origin=node,
            )
            # The dequantize op takes the same qparams as the quantize op.
            dequant = create_node(
                graph,
                qd.dequantize_per_tensor.default,
                args=(quant, *quant.args[1:]),
                kwargs=quant.kwargs,
            )
        return dequant

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Replace every fake-quantize-with-tensor-qparams node in the graph by a
        quantize/dequantize pair with scalar qparams.

        Returns:
            PassResult carrying True if at least one node was decomposed.
        """
        modified = False

        gm = exported_program.graph_module
        g = gm.graph
        ep = exported_program
        for node in g.nodes:
            if node.op != "call_function":
                continue
            if (
                node.target
                == torch.ops.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default
            ):
                # tensor, scale, zero_p, fake_quant_enabled, quant_min, quant_max
                # TODO Support `fake_quant_enabled`
                assert len(node.args) == 6
                tensor, s, zp, _, quant_min, quant_max = node.args
                # Get constant tensors
                s_value = get_constant_from_tensor(s, ep)
                zp_value = get_constant_from_tensor(zp, ep)
                # The op returns (output, mask); its users are the `getitem` nodes.
                # The first user is taken as the output, the rest (if any) as the mask.
                # TODO Investigate why the op is generated like this.
                # node.users = {getitem: None}
                get_item, *mask = node.users.keys()
                dequant = self._insert_quant_dequant(
                    g, node, tensor, s_value, zp_value, quant_min, quant_max
                )
                get_item.replace_all_uses_with(dequant, propagate_meta=True)
                # `mask` can be a graph output, which prevents `eliminate_dead_code()`
                # from eliminating it. So, remove `mask` from the output.args first.
                # mask_user(output).args == (dequantize_per_tensor, mask)
                if mask:
                    assert len(mask) == 1
                    mask_user = list(mask[0].users.keys())[0]
                    assert len(mask_user.args) == 1
                    # Keep only the first returned value (the dequantized output).
                    mask_user.args = ((mask_user.args[0][0],),)
                modified = True
            if (
                node.target
                == torch.ops.aten.fake_quantize_per_tensor_affine.tensor_qparams
            ):
                fq_args = FakeQuantizePerTensorTQParamArgs(*node.args, **node.kwargs)
                tensor = fq_args.input
                quant_min = fq_args.quant_min
                quant_max = fq_args.quant_max

                # Get constant tensors
                s_value = get_constant_from_tensor(fq_args.scale, ep)
                zp_value = get_constant_from_tensor(fq_args.zero_point, ep)
                dequant = self._insert_quant_dequant(
                    g, node, tensor, s_value, zp_value, quant_min, quant_max
                )
                # This overload is consumed directly (no `getitem`).
                node.replace_all_uses_with(dequant, propagate_meta=True)
                modified = True

        gm.graph.eliminate_dead_code()
        gm.graph.lint()
        gm.recompile()

        return PassResult(modified)
|
@@ -0,0 +1,275 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import math
|
16
|
+
import operator
|
17
|
+
from typing import TYPE_CHECKING
|
18
|
+
|
19
|
+
if TYPE_CHECKING:
|
20
|
+
import torch.fx
|
21
|
+
import torch
|
22
|
+
from torch.export import ExportedProgram
|
23
|
+
|
24
|
+
from tico.serialize.circle_mapping import extract_shape
|
25
|
+
from tico.utils import logging
|
26
|
+
from tico.utils.graph import create_node
|
27
|
+
from tico.utils.passes import PassBase, PassResult
|
28
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
29
|
+
from tico.utils.utils import is_target_node
|
30
|
+
from tico.utils.validate_args_kwargs import NativeGroupNormArgs, NativeLayerNormArgs
|
31
|
+
|
32
|
+
|
33
|
+
@trace_graph_diff_on_pass
class DecomposeGroupNorm(PassBase):
    """
    This pass decomposes Group normalization operators.

    LayerNorm is group=1 Group normalization.

    [LayerNorm, GroupNorm]

    Two normalizations result in same nodes but have different normalization shapes.

    [before]

        input (tensor, normalized_shape, weight, bias, eps)
          |
        NativeLayerNorm or GroupNorm
          |
        output

    [after]

        input (tensor)
          |
        reshape              (only when the last dim != normalization size)
          |
          +----------+
          |          |
         mean        |
          |          |
          +-->sub<---+
               |
          +----+-----+
          |          |
         pow         |
          |          |
         mean        |
          |          |
        add (eps)    |
          |          |
        rsqrt        |
          |          |
          +-->mul<---+
               |
            reshape          (only when the input was reshaped)
               |
          mul (weight)       (optional; weight is viewed per-channel for GroupNorm)
               |
          add (bias)         (optional; bias is viewed per-channel for GroupNorm)
               |
            output
    """

    def __init__(self):
        super().__init__()

    def _insert_norm(self, graph, tensor, eps, origin):
        """
        Insert (tensor - mean) / sqrt(var + eps) into the graph
        and return the normalized tensor node.

        Mean and variance are both taken over the last dimension with keepdim=True.
        """
        mean = create_node(
            graph,
            torch.ops.aten.mean.dim,
            (tensor, [-1]),
            {"keepdim": True},
            origin=origin,
        )
        deviation = create_node(
            graph, torch.ops.aten.sub.Tensor, (tensor, mean), origin=origin
        )
        squared = create_node(
            graph, torch.ops.aten.pow.Tensor_Scalar, (deviation, 2), origin=origin
        )
        var = create_node(
            graph,
            torch.ops.aten.mean.dim,
            (squared, [-1]),
            {"keepdim": True},
            origin=origin,
        )
        inverse_std = create_node(
            graph,
            torch.ops.aten.rsqrt.default,
            (create_node(graph, torch.ops.aten.add.Tensor, (var, eps), origin=origin),),
            origin=origin,
        )
        return create_node(
            graph, torch.ops.aten.mul.Tensor, (deviation, inverse_std), origin=origin
        )

    def _view_per_channel(self, graph, param, num_channels, rank, origin):
        """
        View a GroupNorm weight/bias as [1, C, 1, ...] (length `rank`) and return
        the view node, so the parameter lines up with dim 1 of the input.
        """
        param_shape = extract_shape(param)
        assert param_shape[0] == num_channels
        view_shape = [1] * rank
        view_shape[1] = num_channels
        return create_node(
            graph,
            torch.ops.aten.view.default,
            (param, view_shape),
            origin=origin,
        )

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Replace each native_layer_norm / native_group_norm node with explicit
        mean/sub/pow/rsqrt/mul/add nodes.

        Returns:
            PassResult carrying True if at least one node was decomposed.
        """
        # NOTE(review): `logger` is never used below; kept because the effect of
        # the project's `getLogger` is not visible from this file.
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        graph: torch.fx.Graph = gm.graph
        modified = False

        layer_norm_op = torch.ops.aten.native_layer_norm.default
        group_norm_op = torch.ops.aten.native_group_norm.default

        for node in graph.nodes:
            if not is_target_node(node, [layer_norm_op, group_norm_op]):
                continue

            is_group_norm = node.target == group_norm_op

            if node.target == layer_norm_op:
                ln_args = NativeLayerNormArgs(*node.args, **node.kwargs)
                x = ln_args.input
                normalized_shape = ln_args.normalized_shape
                weight = ln_args.weight
                bias = ln_args.bias
                eps = ln_args.eps

                if weight is not None:
                    weight_shape = extract_shape(weight)
                    assert list(weight_shape) == normalized_shape
                if bias is not None:
                    bias_shape = extract_shape(bias)
                    assert list(bias_shape) == normalized_shape

                x_val = x.meta.get("val")
                assert isinstance(x_val, torch.Tensor)
                x_shape = list(x_val.size())
                x_dim = len(x_shape)
                normalized_dim = len(normalized_shape)
                assert x_dim >= normalized_dim
                idx_normalize_start = x_dim - normalized_dim

                # Trailing `normalized_shape` dims are normalized together; the
                # leading dims enumerate the independent rows.
                norm_size = math.prod(normalized_shape)
                layer_size = math.prod(x_shape[:idx_normalize_start])
            elif is_group_norm:
                gn_args = NativeGroupNormArgs(*node.args, **node.kwargs)
                x = gn_args.input
                weight = gn_args.weight
                bias = gn_args.bias
                N = gn_args.N
                C = gn_args.C
                HW = gn_args.HxW
                group = gn_args.group
                eps = gn_args.eps

                x_shape = list(extract_shape(x))
                assert len(x_shape) == 4 or len(x_shape) == 3
                assert x_shape[0] == N
                assert x_shape[1] == C

                assert C % group == 0
                # Exact integer arithmetic: C is a multiple of `group` (asserted above).
                norm_size = (C // group) * HW
                layer_size = N * group
            else:
                assert False, "Unreachable"

            pack_shape = [layer_size, norm_size]

            with gm.graph.inserting_before(node):
                # Branch only on whether a reshape is needed; the normalization is shared.
                if norm_size != x_shape[-1]:
                    # Pack groups so that the last dimension equals norm_size.
                    packed = create_node(
                        graph,
                        torch.ops.aten.reshape.default,
                        (x, pack_shape),
                        origin=node,
                    )
                    normed = self._insert_norm(graph, packed, eps, origin=node)
                    # Restore the original shape after normalization.
                    layer_norm = create_node(
                        graph,
                        torch.ops.aten.reshape.default,
                        (normed, x_shape),
                        origin=node,
                    )
                else:
                    # The input already has norm_size in the last dimension.
                    layer_norm = self._insert_norm(graph, x, eps, origin=node)

                # weight
                if weight is not None:
                    if is_group_norm:
                        weight = self._view_per_channel(
                            graph, weight, C, len(x_shape), origin=node
                        )
                    layer_norm = create_node(
                        graph,
                        torch.ops.aten.mul.Tensor,
                        (layer_norm, weight),
                        origin=node,
                    )

                # bias
                if bias is not None:
                    if is_group_norm:
                        bias = self._view_per_channel(
                            graph, bias, C, len(x_shape), origin=node
                        )
                    layer_norm = create_node(
                        graph,
                        torch.ops.aten.add.Tensor,
                        (layer_norm, bias),
                    )
                # Reset last node's meta for propagating replacing node's meta.
                layer_norm.meta = {}

            # NOTE Why select user `getitem` here?
            # `native_layer_norm` and `native_group_norm` requires `getitem`
            # to select the first output and discard the rest unused outputs.
            # To replace those operators, it's necessary to replace the corresponding
            # `getitem` node as well.
            get_item = next(iter(node.users))
            assert (
                get_item.target == operator.getitem
            ), "First user of native_group/layer_norm should be getitem"

            get_item.replace_all_uses_with(layer_norm, propagate_meta=True)

            modified = True

        gm.graph.eliminate_dead_code()
        gm.graph.lint()
        gm.recompile()

        return PassResult(modified)
|