tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +9 -1
- tico/config/base.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_aten_where_arg_type.py +1 -1
- tico/passes/cast_clamp_mixed_type_args.py +169 -0
- tico/passes/cast_mixed_type_args.py +4 -2
- tico/passes/const_prop_pass.py +1 -1
- tico/passes/convert_conv1d_to_conv2d.py +1 -1
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_addmm.py +0 -3
- tico/passes/decompose_batch_norm.py +2 -2
- tico/passes/decompose_fake_quantize.py +0 -3
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
- tico/passes/decompose_group_norm.py +0 -3
- tico/passes/legalize_predefined_layout_operators.py +2 -11
- tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
- tico/passes/lower_to_slice.py +1 -1
- tico/passes/merge_consecutive_cat.py +1 -1
- tico/passes/ops.py +1 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -6
- tico/passes/remove_redundant_reshape.py +5 -5
- tico/passes/segment_index_select.py +1 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
- tico/quantization/algorithm/gptq/quantizer.py +292 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
- tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
- tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
- tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
- tico/{experimental/quantization → quantization}/public_interface.py +19 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +125 -0
- tico/serialize/circle_graph.py +12 -4
- tico/serialize/circle_mapping.py +76 -2
- tico/serialize/circle_serializer.py +253 -148
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_any.py +7 -14
- tico/serialize/operators/op_avg_pool2d.py +11 -4
- tico/serialize/operators/op_clamp.py +5 -7
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_conv2d.py +14 -6
- tico/serialize/operators/op_copy.py +26 -3
- tico/serialize/operators/op_cumsum.py +3 -1
- tico/serialize/operators/op_depthwise_conv2d.py +17 -7
- tico/serialize/operators/op_full_like.py +0 -2
- tico/serialize/operators/op_index_select.py +8 -1
- tico/serialize/operators/op_instance_norm.py +0 -6
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_log1p.py +3 -2
- tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
- tico/serialize/operators/op_mm.py +15 -131
- tico/serialize/operators/op_mul.py +2 -8
- tico/serialize/operators/op_pow.py +3 -1
- tico/serialize/operators/op_repeat.py +12 -3
- tico/serialize/operators/op_reshape.py +1 -1
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/serialize/operators/op_softmax.py +7 -14
- tico/serialize/operators/op_split_with_sizes.py +16 -8
- tico/serialize/operators/op_transpose_conv.py +11 -8
- tico/serialize/operators/op_view.py +2 -1
- tico/serialize/quant_param.py +5 -5
- tico/utils/convert.py +30 -17
- tico/utils/dtype.py +42 -0
- tico/utils/graph.py +1 -1
- tico/utils/model.py +2 -1
- tico/utils/padding.py +2 -2
- tico/utils/pytree_utils.py +134 -0
- tico/utils/record_input.py +102 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/serialize.py +16 -3
- tico/utils/signature.py +247 -0
- tico/utils/torch_compat.py +52 -0
- tico/utils/utils.py +50 -58
- tico/utils/validate_args_kwargs.py +38 -3
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
- tico-0.1.0.dev251102.dist-info/RECORD +271 -0
- tico/experimental/quantization/__init__.py +0 -1
- tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
- tico-0.1.0.dev250714.dist-info/RECORD +0 -209
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
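The dominant structural change in this release is the promotion of the quantization stack from `tico/experimental/quantization/` to `tico/quantization/`, plus the new `wrapq` sub-package. A minimal sketch of what the move means for downstream imports, assuming the relocated modules keep their public symbols (the updated `public_interface.py` imports shown further below use exactly these paths):

```python
# Import-path migration sketch (assumption: public symbols are unchanged by the move).

# tico 0.1.0.dev250714 (old, experimental namespace):
#   from tico.experimental.quantization.algorithm.pt2e.quantizer import PT2EQuantizer
#   from tico.experimental.quantization.config import BaseConfig

# tico 0.1.0.dev251102 (new, top-level namespace):
from tico.quantization.algorithm.pt2e.quantizer import PT2EQuantizer
from tico.quantization.algorithm.gptq.quantizer import GPTQQuantizer
from tico.quantization.config.base import BaseConfig        # config.py split into config/*.py
from tico.quantization.quantizer_registry import get_quantizer  # new in this release
```

Code that still imports the `tico.experimental.quantization` paths needs only this path change; most of the moved algorithm modules themselves show minimal diffs (+1/-1).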
tico/quantization/passes/insert_quantize_on_dtype_mismatch.py
@@ -0,0 +1,459 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import copy
+
+from collections import defaultdict
+from typing import Any
+
+import torch
+from torch.export import ExportedProgram
+
+from tico.serialize.quant_param import QPARAM_KEY, QuantParam
+from tico.utils import logging
+from tico.utils.errors import NotYetSupportedError
+from tico.utils.graph import create_node
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.utils import quant_min_max, set_new_meta_val
+from tico.utils.validate_args_kwargs import (
+    AddTensorArgs,
+    BmmArgs,
+    CatArgs,
+    LinearArgs,
+    MulTensorArgs,
+    PermuteArgs,
+    ReluArgs,
+    ReshapeArgs,
+)
+
+
+def qparam_dtype(node: torch.fx.Node) -> str:
+    assert QPARAM_KEY in node.meta
+    return node.meta[QPARAM_KEY].dtype
+
+
+# Convert i16 qparam to u8 qparam
+# scale and zero_point are inferred from i16 qparam
+def _i16_to_u8(qparam: QuantParam) -> QuantParam:
+    # Assume per-tensor quantization
+    assert qparam.scale is not None and len(qparam.scale) == 1
+    assert qparam.dtype == "int16"
+
+    s16_scale = qparam.scale[0]
+    max_ = s16_scale * 32767  # numeric_limits<int16>
+    min_ = -max_
+
+    u8_scale = (max_ - min_) / 255
+    u8_zerop = round(-min_ / u8_scale)
+
+    new_qparam = QuantParam()
+    new_qparam.scale = [u8_scale]
+    new_qparam.zero_point = [u8_zerop]
+    new_qparam.dtype = "uint8"
+
+    return new_qparam
+
+
+# Convert u8 qparam to i16 qparam
+# scale is inferred from u8 qparam
+def _u8_to_i16(qparam: QuantParam) -> QuantParam:
+    # Assume per-tensor quantization
+    assert qparam.scale is not None and len(qparam.scale) == 1
+    assert qparam.zero_point is not None and len(qparam.zero_point) == 1
+    assert qparam.dtype == "uint8"
+
+    u8_scale = qparam.scale[0]
+    u8_zerop = qparam.zero_point[0]
+    max_ = u8_scale * (255 - u8_zerop)
+    min_ = u8_scale * (-u8_zerop)
+
+    abs_max = max(abs(max_), abs(min_))
+    s16_scale = abs_max / 32767
+    s16_zerop = 0
+
+    new_qparam = QuantParam()
+    new_qparam.scale = [s16_scale]
+    new_qparam.zero_point = [s16_zerop]
+    new_qparam.dtype = "int16"
+
+    return new_qparam
+
+
+def _insert_quantize_op_before(node, inp):
+    graph = node.graph
+    qparam: QuantParam = node.meta[QPARAM_KEY]
+    assert qparam.scale is not None
+    assert qparam.zero_point is not None
+    scale = qparam.scale[0]
+    zerop = qparam.zero_point[0]
+    min_, max_ = quant_min_max(qparam.dtype)
+    dtype = getattr(torch, qparam.dtype)
+
+    with graph.inserting_before(node):
+        q_args = (inp, scale, zerop, min_, max_, dtype)
+        quantize = create_node(
+            graph,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            args=q_args,
+            origin=node,
+        )
+        quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
+        set_new_meta_val(quantize)
+
+    node.replace_input_with(inp, quantize)
+
+    return quantize
+
+
+def _insert_quantize_op_after(node):
+    graph = node.graph
+    qparam: QuantParam = node.meta[QPARAM_KEY]
+    assert qparam.scale is not None
+    assert qparam.zero_point is not None
+    scale = qparam.scale[0]
+    zerop = qparam.zero_point[0]
+    min_, max_ = quant_min_max(qparam.dtype)
+    dtype = getattr(torch, qparam.dtype)
+    with graph.inserting_after(node):
+        q_args = (node, scale, zerop, min_, max_, dtype)
+        quantize = create_node(
+            graph,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            args=q_args,
+        )
+
+    node.replace_all_uses_with(quantize, propagate_meta=True)
+    quantize.replace_input_with(quantize, node)
+
+    quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
+
+    return quantize
+
+
+def _linear_handler(node, logger):
+    lin_args = LinearArgs(*node.args, **node.kwargs)
+    inp = lin_args.input
+
+    if QPARAM_KEY not in inp.meta:
+        return
+
+    if QPARAM_KEY not in node.meta:
+        return
+
+    if qparam_dtype(inp) == qparam_dtype(node):
+        return
+
+    if qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
+        quantize = _insert_quantize_op_after(node)
+
+        quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+
+        # Update node's qparam from i16 to u8
+        # NOTE This would severely degrade accuracy. It is
+        # important to mitigate this accuracy drop in backend.
+        node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
+        logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
+    else:
+        raise NotYetSupportedError(
+            f"Unsupported dtype: From {qparam_dtype(inp)} to {qparam_dtype(node)}"
+        )
+
+
+def _add_handler(node, logger):
+    add_args = AddTensorArgs(*node.args, **node.kwargs)
+    x = add_args.input
+    y = add_args.other
+
+    if not isinstance(x, torch.fx.Node):
+        return
+    if not isinstance(y, torch.fx.Node):
+        return
+
+    if QPARAM_KEY not in x.meta:
+        return
+    if QPARAM_KEY not in y.meta:
+        return
+    if QPARAM_KEY not in node.meta:
+        return
+
+    if qparam_dtype(x) == qparam_dtype(node):
+        return
+
+    if qparam_dtype(x) != qparam_dtype(y):
+        return
+
+    if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
+        quantize = _insert_quantize_op_after(node)
+
+        quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+        node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+        logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
+    else:
+        raise NotYetSupportedError("Unsupported dtype")
+
+
+def _mul_handler(node, logger):
+    mul_args = MulTensorArgs(*node.args, **node.kwargs)
+    x = mul_args.input
+    y = mul_args.other
+
+    if not isinstance(x, torch.fx.Node):
+        return
+    if not isinstance(y, torch.fx.Node):
+        return
+
+    if QPARAM_KEY not in x.meta:
+        return
+    if QPARAM_KEY not in y.meta:
+        return
+    if QPARAM_KEY not in node.meta:
+        return
+
+    if qparam_dtype(x) == qparam_dtype(node):
+        return
+
+    if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
+        quantize = _insert_quantize_op_after(node)
+
+        quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+        node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+        logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
+    else:
+        raise NotYetSupportedError("Unsupported dtype")
+
+
+def _cat_handler(node, logger):
+    cat_args = CatArgs(*node.args, **node.kwargs)
+    tensors = cat_args.tensors
+
+    if any(QPARAM_KEY not in x.meta for x in tensors):
+        return
+
+    if QPARAM_KEY not in node.meta:
+        return
+
+    assert len(tensors) > 0
+    in_dtype = qparam_dtype(tensors[0])
+    if in_dtype == qparam_dtype(node):
+        return
+
+    if any(qparam_dtype(x) != in_dtype for x in tensors):
+        return
+
+    if in_dtype == "int16" and qparam_dtype(node) == "uint8":
+        quantize = _insert_quantize_op_after(node)
+
+        quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+        node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+        logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
+    else:
+        raise NotYetSupportedError("Unsupported dtype")
+
+
+def _bmm_handler(node, logger):
+    bmm_args = BmmArgs(*node.args, **node.kwargs)
+    x = bmm_args.input
+    y = bmm_args.mat2
+
+    if QPARAM_KEY not in x.meta:
+        return
+    if QPARAM_KEY not in y.meta:
+        return
+    if QPARAM_KEY not in node.meta:
+        return
+
+    if qparam_dtype(x) == qparam_dtype(node):
+        return
+
+    if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
+        quantize = _insert_quantize_op_after(node)
+
+        quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+        node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+        logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
+    elif qparam_dtype(x) == "uint8" and qparam_dtype(node) == "int16":
+        quantize = _insert_quantize_op_after(node)
+
+        quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+        node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
+        logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
+    else:
+        raise NotYetSupportedError("Unsupported dtype")
+
+
+def _permute_handler(node, logger):
+    per_args = PermuteArgs(*node.args, **node.kwargs)
+    inp = per_args.input
+
+    if QPARAM_KEY not in inp.meta:
+        return
+
+    if QPARAM_KEY not in node.meta:
+        return
+
+    if qparam_dtype(inp) == qparam_dtype(node):
+        return
+
+    if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
+        # A new Quantize Op (s16 to u8) is inserted before (not after)
+        # permute Op to reduce tensor size ealier
+        quantize = _insert_quantize_op_before(node, inp)
+
+        quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+        logger.debug(f"quantize_per_tensor.default is inserted before {node.name}.")
+    elif qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
+        quantize = _insert_quantize_op_after(node)
+
+        quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+        node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
+        logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
+    else:
+        raise NotYetSupportedError("Unsupported dtype")
+
+
+def _reshape_handler(node, logger):
+    reshape_args = ReshapeArgs(*node.args, **node.kwargs)
+    inp = reshape_args.input
+
+    if QPARAM_KEY not in inp.meta:
+        return
+
+    if QPARAM_KEY not in node.meta:
+        return
+
+    if qparam_dtype(inp) == qparam_dtype(node):
+        return
+
+    if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
+        # A new Quantize Op (s16 to u8) is inserted before (not after)
+        # reshape Op to reduce tensor size ealier
+        quantize = _insert_quantize_op_before(node, inp)
+
+        quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+        logger.debug(f"quantize_per_tensor.default is inserted before {node.name}.")
+    elif qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
+        quantize = _insert_quantize_op_after(node)
+
+        quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+        node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
+        logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
+    else:
+        raise NotYetSupportedError("Unsupported dtype")
+
+
+def _relu_handler(node, logger):
+    relu_args = ReluArgs(*node.args, **node.kwargs)
+    inp = relu_args.input
+
+    if QPARAM_KEY not in inp.meta:
+        return
+
+    if QPARAM_KEY not in node.meta:
+        return
+
+    if qparam_dtype(inp) == qparam_dtype(node):
+        return
+
+    if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
+        quantize = _insert_quantize_op_after(node)
+
+        quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+        node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+        logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
+    elif qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
+        quantize = _insert_quantize_op_after(node)
+
+        quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+        node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
+        logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
+    else:
+        raise NotYetSupportedError("Unsupported dtype")
+
+
+_op_handler: defaultdict[Any, Any | None] = defaultdict(lambda: None)
+_op_handler[torch.ops.aten.linear.default] = _linear_handler
+_op_handler[torch.ops.aten.add.Tensor] = _add_handler
+_op_handler[torch.ops.aten.mul.Tensor] = _mul_handler
+_op_handler[torch.ops.aten.cat.default] = _cat_handler
+_op_handler[torch.ops.aten.bmm.default] = _bmm_handler
+_op_handler[torch.ops.aten.permute.default] = _permute_handler
+_op_handler[torch.ops.aten.reshape.default] = _reshape_handler
+_op_handler[torch.ops.aten.relu.default] = _relu_handler
+
+
+@trace_graph_diff_on_pass
+class InsertQuantizeOnDtypeMismatch(PassBase):
+    """
+    Insert quantize Op in the operators where circle's type inference is violated.
+    Example. FullyConnected
+    [BEFORE]
+      Op (uint8) - aten.linear.default (int16)
+    [AFTER]
+      Op (uint8) - aten.linear.default (uint8) - quantized_decomposed.quantize_per_tensor.default (int16)
+    Why is this pass necessary?
+    - For some operators, circle's type inference pass overwrites the input's dtype to
+      the output's dtype. For the above example, fully-connected layer (aten.linear.default)'s
+      output dtype (int16) is updated to the input dtype (uint8), which breaks the semantics.
+      This problem can occur in the tools (ex: circle2circle) that automatically apply type inference.
+    - To resolve the issue, we insert quantize operators not to violate circle's type inference logic.
+    - NOTE For some cases, Quantize Op is inserted before the operators.
+
+    Let's assume Reshape Op's input is int16 and output is uint8. There are two possible places to insert
+    Quantize Op.
+
+    1. Insert Quantize before Reshape.
+
+    ```
+    Predecessor (int16)-> Quantize (uint8) -> Reshape (uint8) -> ...
+    ```
+
+    2. Insert Quantize after Reshape.
+
+    ```
+    Predecessor (int16)-> Reshape (int16) -> Quantize (uint8) -> ...
+    ```
+
+    Comparing 1) and 2), the difference is that Reshape operation is conducted in uint8 or int16.
+    We go with 1), which does Reshape in uint8, for faster execution. Note that Reshape Op does not
+    change the value, so its dytpe does not affect accuracy.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph: torch.fx.Graph = graph_module.graph
+
+        for node in graph.nodes:
+            if node.op != "call_function":
+                continue
+
+            handler = _op_handler[node.target]
+            if handler is not None:
+                handler(node, logger)
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        # Run only once.
+        return PassResult(False)
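The two helpers `_i16_to_u8` and `_u8_to_i16` above only re-derive scale and zero-point; the actual values are requantized by the inserted `quantize_per_tensor` node. A standalone sketch of the arithmetic for a concrete int16 scale (plain Python, no tico imports; the scale value is illustrative):

```python
# int16 qparams in this pass are symmetric: range = [-scale * 32767, scale * 32767], zero_point = 0.
s16_scale = 0.001
max_ = s16_scale * 32767            # 32.767
min_ = -max_                        # -32.767

# _i16_to_u8: spread the same range over 255 uint8 steps; the zero-point lands mid-range.
u8_scale = (max_ - min_) / 255      # ~0.257
u8_zerop = round(-min_ / u8_scale)  # 128

# _u8_to_i16: take the larger magnitude of the uint8 range and re-center at zero.
max_u8 = u8_scale * (255 - u8_zerop)
min_u8 = u8_scale * (-u8_zerop)
s16_back = max(abs(max_u8), abs(min_u8)) / 32767  # ~0.001004, close to the original scale
```

Squeezing a 16-bit range into 255 uint8 steps makes the scale roughly 257x coarser, which is exactly the accuracy concern flagged in the NOTE inside `_linear_handler`.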
tico/quantization/public_interface.py
@@ -13,25 +13,17 @@
 # limitations under the License.
 
 import copy
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional
 
 import torch
 
-from tico.
-from tico.
-from tico.
-
-
-from tico.experimental.quantization.config import BaseConfig
-from tico.experimental.quantization.quantizer import BaseQuantizer
+from tico.quantization.algorithm.gptq.quantizer import GPTQQuantizer
+from tico.quantization.algorithm.pt2e.quantizer import PT2EQuantizer
+from tico.quantization.config.base import BaseConfig
+from tico.quantization.quantizer import BaseQuantizer
+from tico.quantization.quantizer_registry import get_quantizer
 
 
-config_to_quantizer: Dict[str, Type[BaseQuantizer]] = {
-    "pt2e": PT2EQuantizer,
-    "gptq": GPTQQuantizer,
-    "smooth_quant": SmoothQuantQuantizer,
-}
-
 QUANTIZER_ATTRIBUTE_NAME = "tico_quantizer"
 
 
@@ -40,7 +32,7 @@ def prepare(
     quant_config: BaseConfig,
     args: Optional[Any] = None,
     kwargs: Optional[Dict[str, Any]] = None,
-    inplace: Optional[bool] =
+    inplace: Optional[bool] = True,
 ):
     """
     Prepare the model for quantization using the provided configuration.
@@ -59,21 +51,24 @@ def prepare(
     Returns:
         The model prepared for quantization.
     """
-    if
+    if hasattr(model, QUANTIZER_ATTRIBUTE_NAME):
+        raise RuntimeError("prepare() already has been called.")
+    quantizer = get_quantizer(quant_config)
+
+    if isinstance(quantizer, PT2EQuantizer) and inplace:
         raise RuntimeError(
             "In-place is not supported for PT2E quantization due to limitation in the underlying Torch APIs. Please set 'inplace=False' to proceed."
         )
 
     model = model if inplace else copy.deepcopy(model)
 
-    quantizer = config_to_quantizer[quant_config.name](quant_config)
     model = quantizer.prepare(model, args, kwargs)
     setattr(model, QUANTIZER_ATTRIBUTE_NAME, quantizer)
 
     return model
 
 
-def convert(model, inplace: Optional[bool] =
+def convert(model, inplace: Optional[bool] = True):
     """
     Convert the prepared model to a quantized model using the provided configuration.
 
@@ -99,6 +94,12 @@ def convert(model, inplace: Optional[bool] = False):
         raise RuntimeError(
             "In-place is not supported for PT2E quantization due to limitation in the underlying Torch APIs. Please set 'inplace=False' to proceed."
        )
+    # deepcopy prevents the quantizer from restoring the catcher used for calibration.
+    # TODO Revisit `inplace` policy.
+    if isinstance(quantizer, GPTQQuantizer) and not inplace:
+        raise RuntimeError(
+            "GPTQ quantization only supports `in-place=True`. Please set 'inplace=True' to proceed."
+        )
 
     model = model if inplace else copy.deepcopy(model)
 
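With the registry in place, `prepare()` resolves the quantizer from the config's type instead of the removed `config_to_quantizer` dict, refuses to be called twice on the same model, and both entry points now default to `inplace=True`. A hedged usage sketch of the updated flow (the `model` argument name, the `GPTQConfig` constructor arguments, and the calibration step are assumptions not shown in this diff):

```python
import torch

from tico.quantization.config.gptq import GPTQConfig  # config class assumed to live in tico/quantization/config/gptq.py
from tico.quantization.public_interface import convert, prepare

model = torch.nn.Linear(16, 16).eval()   # stand-in float model
example = torch.randn(1, 16)             # stand-in calibration input

# prepare() attaches the resolved quantizer under QUANTIZER_ATTRIBUTE_NAME
# ("tico_quantizer"); calling it a second time on the same model raises RuntimeError.
model = prepare(model, GPTQConfig(), args=(example,))

# ... run calibration forward passes on `model` here ...

# convert() now defaults to inplace=True; GPTQ rejects inplace=False outright,
# while PT2E still requires inplace=False.
qmodel = convert(model)
```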
tico/quantization/quantizer_registry.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+from typing import Dict, Optional, Type, TypeVar
+
+from tico.quantization.config.base import BaseConfig
+from tico.quantization.quantizer import BaseQuantizer
+
+TQ = TypeVar("TQ", bound=BaseQuantizer)
+
+# Mapping: Config type -> Quantizer type
+_REGISTRY: Dict[Type[BaseConfig], Type[BaseQuantizer]] = {}
+
+
+def register_quantizer(config_cls: Type[BaseConfig]):
+    """
+    Decorator to register a quantizer for a given config class.
+    Usage:
+        @register_quantizer(GPTQConfig)
+        class GPTQQuantizer(BaseQuantizer): ...
+    """
+
+    def wrapper(quantizer_cls: Type[TQ]) -> Type[TQ]:
+        _REGISTRY[config_cls] = quantizer_cls
+        return quantizer_cls
+
+    return wrapper
+
+
+def _lookup(cfg: BaseConfig) -> Optional[Type[BaseQuantizer]]:
+    """Return a quantizer class only if the exact config type is registered."""
+    return _REGISTRY.get(type(cfg))
+
+
+def get_quantizer(cfg: BaseConfig) -> BaseQuantizer:
+    """Factory to return a quantizer instance for the given config."""
+    qcls = _lookup(cfg)
+    if qcls is not None:
+        return qcls(cfg)
+
+    # Lazy import by naming convention
+    name = getattr(cfg, "name", None)
+    if name:
+        if name == "ptq":
+            importlib.import_module(f"tico.quantization.wrapq.quantizer")
+        else:
+            try:
+                importlib.import_module(f"tico.quantization.algorithm.{name}.quantizer")
+            except Exception as e:
+                raise RuntimeError(
+                    f"Failed to import quantizer module for config name='{name}': {e}"
+                )
+
+    qcls = _lookup(cfg)
+    if qcls is not None:
+        return qcls(cfg)
+
+    raise RuntimeError(
+        f"No quantizer registered for config type {type(cfg).__name__} "
+        f"(name='{getattr(cfg,'name',None)}')."
+    )
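The registration contract is exact-type lookup with a lazy-import fallback keyed on `cfg.name`. A minimal sketch of wiring a new algorithm into it, following the decorator usage shown in the docstring above (the `MyConfig`/`MyQuantizer` pair is hypothetical, and the sketch assumes `BaseConfig`/`BaseQuantizer` impose no further required abstract methods beyond what a real quantizer would implement):

```python
from tico.quantization.config.base import BaseConfig
from tico.quantization.quantizer import BaseQuantizer
from tico.quantization.quantizer_registry import get_quantizer, register_quantizer


class MyConfig(BaseConfig):
    name = "my_algo"  # only consulted by the lazy-import fallback; hypothetical value


@register_quantizer(MyConfig)              # stores MyConfig -> MyQuantizer in _REGISTRY
class MyQuantizer(BaseQuantizer):
    def __init__(self, config: MyConfig):
        self.config = config

    # a real implementation would also provide prepare()/convert()


quantizer = get_quantizer(MyConfig())      # exact-type lookup, no lazy import needed
assert isinstance(quantizer, MyQuantizer)
```

The lazy-import branch exists so that built-in quantizers register themselves on first use: `get_quantizer` imports `tico.quantization.algorithm.<name>.quantizer` (or the `wrapq` quantizer when `name == "ptq"`) and then retries the lookup.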
@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE