ai-edge-quantizer-nightly 0.0.1.dev20250221__py3-none-any.whl → 0.0.1.dev20250223__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +133 -14
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +102 -11
- {ai_edge_quantizer_nightly-0.0.1.dev20250221.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20250223.dist-info}/METADATA +1 -1
- {ai_edge_quantizer_nightly-0.0.1.dev20250221.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20250223.dist-info}/RECORD +7 -7
- {ai_edge_quantizer_nightly-0.0.1.dev20250221.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20250223.dist-info}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250221.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20250223.dist-info}/WHEEL +0 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250221.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20250223.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py
@@ -15,11 +15,13 @@
 
 """Recovers quantized weights from dequantized weights (often from QAT)."""
 
-
+import dataclasses
+from typing import Any, Optional
 import numpy as np
 from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
 from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor
-
+from ai_edge_quantizer.algorithms.utils import common_utils
 
 ALGORITHM_KEY = "dequantized_weight_recovery"
 _TFLOpName = qtyping.TFLOperationName
@@ -27,24 +29,26 @@ _QuantTransformation = qtyping.QuantTransformation
 _IntType = uniform_quantize_tensor.IntType
 
 
-def
-
+def _validate_recovered_weights(
+    original_vals: np.ndarray,
+    quant_vals: np.ndarray,
+    scale: np.ndarray,
+    tol: float = 1e-4,
 ):
-  """Validates if the
+  """Validates if recovered weights (from the quantized values) are close enough to the original ones.
 
   Args:
-
-
-
-
+    original_vals: Original values before quantization.
+    quant_vals: Quantized values.
+    scale: Scale used for quantization.
+    tol: Tolerance for the difference between original and recovered values.
 
   Raises:
-
-
+    RuntimeError: If the maximum difference between original and recovered
+      values exceeds the tolerance.
   """
-  quant_vals = np.round(dequant_vals / scale)  # no need to clamp.
   recovered_vals = quant_vals * scale
-  diff = np.abs(recovered_vals -
+  diff = np.abs(recovered_vals - original_vals).flatten()
   max_diff = diff.max()
   if max_diff > tol:
     raise RuntimeError(
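For context, the reworked helper now receives the already-recovered integer values together with the original weights and simply checks that quant_vals * scale reproduces the dequantized weights within tol. A tiny standalone numpy illustration of that check (the values and scales here are made up, not taken from the package):

import numpy as np

scale = np.array([[0.1875], [0.01]])      # hypothetical per-row scales
quant_vals = np.array([[1, -2], [6, 7]])  # recovered integer weights
original_vals = quant_vals * scale        # dequantized weights fed to recovery

# Same check as _validate_recovered_weights: dequantize the recovered
# integers again and compare against the original values.
recovered_vals = quant_vals * scale
max_diff = np.abs(recovered_vals - original_vals).flatten().max()
print(max_diff <= 1e-4)  # True: recovery is exact for these values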
@@ -127,5 +131,120 @@ def get_zp_scale_from_2d_dequantized_symmetric_weights(
 )
 
   zero_points = np.zeros_like(scales, dtype=np.int32)
-  _validate_recovered_scale(dequant_vals, scales)
   return zero_points, scales
+
+
+def get_tensor_quant_params(
+    op_info: qtyping.OpInfo,
+    tensor_quant_config: qtyping.TensorQuantizationConfig,
+    tensor_content: Optional[np.ndarray] = None,
+    tensor_qsv: Optional[dict[str, Any]] = None,
+) -> qtyping.UniformQuantParams:
+  """Get the quantization parameters for a tensor.
+
+  Args:
+    op_info: Aggregated information about the op (e.g., quantization config).
+    tensor_quant_config: The quantization config for the tensor.
+    tensor_content: The content of the tensor.
+    tensor_qsv: A dictionary containing the min/max of the tensor.
+
+  Returns:
+    The quantization parameters for the tensor.
+
+  Raises:
+    ValueError: If the quantization granularity is blockwise, or if the tensor
+      is not a 2D symmetric weight tensor.
+  """
+  # Fallback to naive_min_max_quantize.py for non-weight tensors.
+  if tensor_content is None:
+    return naive_min_max_quantize.get_tensor_quant_params(
+        op_info, tensor_quant_config, tensor_content, tensor_qsv
+    )
+
+  if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
+    raise ValueError(
+        "Blockwise quantization is not supported for dequantized weight"
+        " recovery."
+    )
+  if tensor_content.ndim != 2 or not tensor_quant_config.symmetric:
+    raise ValueError(
+        "Only 2D symmetric weights are supported for dequantized weight"
+        " recovery."
+    )
+
+  quantized_dim = None
+  if tensor_quant_config.granularity == qtyping.QuantGranularity.CHANNELWISE:
+    quantized_dim = common_utils.get_weight_quantized_dim(
+        op_info, tensor_content
+    )
+
+  zp, scale = get_zp_scale_from_2d_dequantized_symmetric_weights(
+      dequant_vals=tensor_content,
+      quantized_dimension=quantized_dim,
+  )
+  quant_params = qtyping.UniformQuantParams(
+      scale=scale,
+      zero_point=zp,
+      num_bits=tensor_quant_config.num_bits,
+      symmetric=tensor_quant_config.symmetric,
+      quantized_dimension=quantized_dim,
+  )
+  quantized_vars = uniform_quantize_tensor.uniform_quantize(
+      tensor_content, quant_params
+  )
+  _validate_recovered_weights(tensor_content, quantized_vars, scale)
+  return dataclasses.replace(quant_params, quantized_data=quantized_vars)
+
+
+def calibrate(
+    tfl_op: Any,
+    graph_info: qtyping.GraphInfo,
+    tensor_content_map: dict[str, np.ndarray],
+    inputs_to_ignore: Optional[list[int]] = None,
+    outputs_to_ignore: Optional[list[int]] = None,
+) -> dict[str, qtyping.QSV]:
+  """Collect quantization statistics variable (QSV, e.g., min/max) for the op.
+
+  Args:
+    tfl_op: The tfl operation.
+    graph_info: Graph information needed to perform quantization for the op.
+    tensor_content_map: A map of tensor name to tensor content.
+    inputs_to_ignore: Input tensor indices to ignore.
+    outputs_to_ignore: Output tensor indices to ignore.
+
+  Returns:
+    A dictionary with key as tensor name and value as the collected QSV.
+  """
+  # Reuse the min/max calibration algorithm from naive_min_max_quantize.py since
+  # only weights need to be handled differently.
+  return naive_min_max_quantize.min_max_calibrate(
+      tfl_op,
+      graph_info,
+      tensor_content_map,
+      inputs_to_ignore,
+      outputs_to_ignore,
+  )
+
+
+def init_qsvs(
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    inputs_to_ignore: Optional[list[int]] = None,
+    outputs_to_ignore: Optional[list[int]] = None,
+) -> qtyping.QSV:
+  """Initialize the QSVs.
+
+  Args:
+    op_info: Aggregated information about the op (e.g., quantization config).
+    graph_info: Graph information needed to perform quantization for the op.
+    inputs_to_ignore: Input tensor indices to ignore.
+    outputs_to_ignore: Output tensor indices to ignore.
+
+  Returns:
+    QSVs.
+  """
+  # Reuse the min/max calibration algorithm from naive_min_max_quantize.py since
+  # only weights need to be handeled differently.
+  return naive_min_max_quantize.init_qsvs(
+      op_info, graph_info, inputs_to_ignore, outputs_to_ignore
+  )
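Taken together, the additions above make dequantized weight recovery usable as a complete quantization algorithm: get_tensor_quant_params recovers per-tensor or per-channel scales (with all-zero zero points, since only symmetric weights are supported) from weights that were already fake-quantized, while calibrate and init_qsvs delegate to naive_min_max_quantize. A minimal usage sketch, mirroring the test setup shown further below; the weights and scales are illustrative, not from the package:

import numpy as np
from ai_edge_quantizer import qtyping
from ai_edge_quantizer.algorithms.uniform_quantize import dequantized_weight_recovery

# Fake-quantized weights: integers times a known per-channel scale, as a
# QAT-exported checkpoint would contain (values are made up).
scale = np.array([[0.1875], [0.01], [12.3]])
dequant_weights = scale * np.array(
    [[1, -2, 3, 4], [6, 7, -6, 5], [2, -6, -7, -4]]
)

op_info = qtyping.OpInfo(
    op=None,
    op_name=qtyping.TFLOperationName.FULLY_CONNECTED,
    subgraph_op_index=0,
    op_quant_config=qtyping.OpQuantizationConfig(),
)
tensor_quant_config = qtyping.TensorQuantizationConfig(
    num_bits=4,
    granularity=qtyping.QuantGranularity.CHANNELWISE,
)
params = dequantized_weight_recovery.get_tensor_quant_params(
    op_info, tensor_quant_config, dequant_weights
)
# Expect params.scale to match the original per-channel scale and
# params.zero_point to be all zeros (symmetric quantization).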
ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py
@@ -19,7 +19,6 @@ import numpy as np
 from tensorflow.python.platform import googletest
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer.algorithms.uniform_quantize import dequantized_weight_recovery
-from ai_edge_quantizer.utils import test_utils
 
 _TFLOpName = qtyping.TFLOperationName
 _TensorQuantConfig = qtyping.TensorQuantizationConfig
@@ -31,9 +30,15 @@ class DequantizedWeightRecoveryTest(parameterized.TestCase):
     super().setUp()
     self._dummy_quantized_weights = np.array([
         [1, -2, 3, 4],
-        [6, 7, -
-        [
+        [6, 7, -6, 5],
+        [2, -6, -7, -4],
     ])
+    self._dummy_op_info = qtyping.OpInfo(
+        op=None,
+        op_name=_TFLOpName.FULLY_CONNECTED,
+        subgraph_op_index=0,
+        op_quant_config=qtyping.OpQuantizationConfig(),
+    )
 
   @parameterized.named_parameters(
       dict(
@@ -96,18 +101,104 @@ class DequantizedWeightRecoveryTest(parameterized.TestCase):
 
   @parameterized.named_parameters(
       dict(
-          testcase_name="
-
+          testcase_name="tensor-recovery-tensor-quant",
+          tensor_quant_config=qtyping.TensorQuantizationConfig(
+              num_bits=4,
+              granularity=qtyping.QuantGranularity.TENSORWISE,
+          ),
+          scale=np.array([0.1875]).reshape(1, 1),
+      ),
+      dict(
+          testcase_name="channel-recovery-channel-quant",
+          tensor_quant_config=qtyping.TensorQuantizationConfig(
+              num_bits=4,
+              granularity=qtyping.QuantGranularity.CHANNELWISE,
+          ),
           scale=np.array([0.1875, 1e-4, 12.3]).reshape(3, 1),
       ),
       dict(
-          testcase_name="
-
+          testcase_name="channel-recovery-excessive-bits",
+          tensor_quant_config=qtyping.TensorQuantizationConfig(
+              num_bits=8,  # int4 is enough for the sample weights.
+              granularity=qtyping.QuantGranularity.CHANNELWISE,
+          ),
+          scale=np.array([0.1875, 1e-4, 12.3]).reshape(3, 1),
+      ),
+  )
+  def test_get_tensor_quant_params_success_with_dequantized_weights(
+      self, tensor_quant_config, scale
+  ):
+    dequant_vals = scale * self._dummy_quantized_weights
+    tensor_quant_params = dequantized_weight_recovery.get_tensor_quant_params(
+        self._dummy_op_info, tensor_quant_config, dequant_vals
+    )
+
+    if tensor_quant_config.granularity is qtyping.QuantGranularity.TENSORWISE:
+      self.assertIsNone(tensor_quant_params.quantized_dimension)
+    else:
+      self.assertEqual(tensor_quant_params.quantized_dimension, 0)
+
+    recovered_scale = tensor_quant_params.scale
+    self.assertEqual(recovered_scale.shape, scale.shape)
+    self.assertSequenceAlmostEqual(recovered_scale.flatten(), scale.flatten())
+
+    # Zero point should be zero for symmetric quantization.
+    recovered_zp = tensor_quant_params.zero_point
+    self.assertEqual(np.sum(recovered_zp), 0)
+    self.assertEqual(recovered_zp.shape, scale.shape)
+
+  def test_get_tensor_quant_params_success_with_qsv(self):
+    # Fall back to naive_min_max_quantize.py for non-weight tensors.
+    tensor_quant_params = dequantized_weight_recovery.get_tensor_quant_params(
+        self._dummy_op_info,
+        tensor_quant_config=qtyping.TensorQuantizationConfig(
+            num_bits=8,
+            granularity=qtyping.QuantGranularity.TENSORWISE,
+        ),
+        tensor_qsv={
+            "min": np.array([-1]),
+            "max": np.array([1]),
+        },
+    )
+
+    self.assertIsNone(tensor_quant_params.quantized_dimension)
+    recovered_scale = tensor_quant_params.scale
+    self.assertEqual(recovered_scale.shape, (1,))
+    self.assertSequenceAlmostEqual(recovered_scale.flatten(), [1 / 127])
+
+    # Zero point should be zero for symmetric quantization.
+    recovered_zp = tensor_quant_params.zero_point
+    self.assertEqual(np.sum(recovered_zp), 0)
+    self.assertEqual(recovered_zp.shape, (1,))
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name="recovery_on_wrong_dimension",
+          tensor_quant_config=qtyping.TensorQuantizationConfig(
+              num_bits=4,
+              granularity=qtyping.QuantGranularity.CHANNELWISE,
+          ),
           scale=np.array([0.003, 1.234, 12.65, 2.24e-4]).reshape(1, 4),
       ),
+      dict(
+          testcase_name="tensor_recovery_for_channel_quantization",
+          tensor_quant_config=qtyping.TensorQuantizationConfig(
+              num_bits=4,
+              granularity=qtyping.QuantGranularity.TENSORWISE,
+          ),
+          scale=np.array([0.1875, 1e-2, 12.3]).reshape(3, 1),
+      ),
+      dict(
+          testcase_name="insufficient_bits",
+          tensor_quant_config=qtyping.TensorQuantizationConfig(
+              num_bits=2,
+              granularity=qtyping.QuantGranularity.CHANNELWISE,
+          ),
+          scale=np.array([0.1875, 1e-2, 12.3]).reshape(3, 1),
+      ),
   )
-  def
-      self,
+  def test_get_tensor_quant_params_raises_error_big_recovery_error(
+      self, tensor_quant_config, scale
   ):
     dequant_vals = scale * self._dummy_quantized_weights
     with self.assertRaisesRegex(
@@ -115,8 +206,8 @@ class DequantizedWeightRecoveryTest(parameterized.TestCase):
         "Failed to recover the original quantized values from dequantized"
         " values. Max diff between recovered and original values: ",
     ):
-      dequantized_weight_recovery.
-
+      dequantized_weight_recovery.get_tensor_quant_params(
+          self._dummy_op_info, tensor_quant_config, dequant_vals
       )
 
 
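The error-path cases above all fail in the same place: _validate_recovered_weights raises once the scale grid cannot reproduce the input, whether the scales sit on the wrong axis or the bit width is too small for the integer range. A rough standalone illustration of the insufficient_bits case; the exact 2-bit clamping range is an assumption here, the point is only that the round trip diverges:

import numpy as np

scale = 0.1875
q = np.array([1, -6, 7])  # integer weights that need at least 4 bits
w = q * scale             # dequantized weights fed to recovery

# With only 2 bits a symmetric grid holds roughly [-1, 1], so the large
# integers are clamped and dequantizing them no longer matches the input.
q2 = np.clip(np.round(w / scale), -1, 1)
max_diff = np.abs(q2 * scale - w).max()
print(max_diff)           # ~1.125, far above the 1e-4 tolerance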
{ai_edge_quantizer_nightly-0.0.1.dev20250221.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20250223.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-quantizer-nightly
-Version: 0.0.1.dev20250221
+Version: 0.0.1.dev20250223
 Summary: A quantizer for advanced developers to quantize converted AI Edge models.
 Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
 Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
{ai_edge_quantizer_nightly-0.0.1.dev20250221.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20250223.dist-info}/RECORD
@@ -30,8 +30,8 @@ ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py,sha256=s64
 ai_edge_quantizer/algorithms/uniform_quantize/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
 ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=wPZevOuowJczG9t4Gynzv7tIeH6zhOnaKPsfr2K_fsk,21259
 ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py,sha256=qMmKbWqxrCoVKbLKHn9WuCrGKPfHkEyU0Nmhokh8Qeo,2597
-ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py,sha256=
-ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py,sha256=
+ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py,sha256=OTXjEZ3Ctq3ffYzisX-6HwgK_DuA7uos_aap5PiIUPE,8686
+ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py,sha256=y7BK11fkF63Ex_Jzg3fbIdy0D_Ca6HuvChVZR7Uwggc,8073
 ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py,sha256=fBqSidFVKZmdO-xIFfwZPdIN1eLJjOik8mUZxZj2ljk,12149
 ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py,sha256=Hok09dloSyBfD0oDM5VABdSZjM9JWSQhm_hDHNbFujA,7640
 ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py,sha256=Q_vx7YN7KMpjubsngxRdJ4bfdSIV-gmXjtVuxIkZuX4,11078
@@ -60,8 +60,8 @@ ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=SM8H4i7Jq_nfdsJpImopHndN
 ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=Op3JxtOqlrjzmYF18jnnstL1k9xiY9kKJ8S2vklKGkc,11327
 ai_edge_quantizer/utils/validation_utils.py,sha256=oYw33Sg547AqtGw-choPUJmp9SAKkV46J_ddqSsum2Q,3950
 ai_edge_quantizer/utils/validation_utils_test.py,sha256=V_qNDikPD4OPB-siOLQCWNVWTAu87h2IgNYt7teFd-o,2934
-ai_edge_quantizer_nightly-0.0.1.
-ai_edge_quantizer_nightly-0.0.1.
-ai_edge_quantizer_nightly-0.0.1.
-ai_edge_quantizer_nightly-0.0.1.
-ai_edge_quantizer_nightly-0.0.1.
+ai_edge_quantizer_nightly-0.0.1.dev20250223.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ai_edge_quantizer_nightly-0.0.1.dev20250223.dist-info/METADATA,sha256=OatGT9K83-Q_E2SdpfZiyfqURscnos2yBXEZqQoyygQ,1484
+ai_edge_quantizer_nightly-0.0.1.dev20250223.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_quantizer_nightly-0.0.1.dev20250223.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
+ai_edge_quantizer_nightly-0.0.1.dev20250223.dist-info/RECORD,,