ai-edge-quantizer-nightly 0.3.0.dev20250613__py3-none-any.whl → 0.3.0.dev20250615__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +2 -0
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +17 -0
- ai_edge_quantizer/default_policy.py +4 -2
- ai_edge_quantizer/qtyping.py +1 -0
- ai_edge_quantizer/recipe_manager_test.py +10 -13
- ai_edge_quantizer/utils/calibration_utils.py +342 -1
- ai_edge_quantizer/utils/calibration_utils_test.py +174 -3
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +1 -0
- {ai_edge_quantizer_nightly-0.3.0.dev20250613.dist-info → ai_edge_quantizer_nightly-0.3.0.dev20250615.dist-info}/METADATA +1 -1
- {ai_edge_quantizer_nightly-0.3.0.dev20250613.dist-info → ai_edge_quantizer_nightly-0.3.0.dev20250615.dist-info}/RECORD +13 -13
- {ai_edge_quantizer_nightly-0.3.0.dev20250613.dist-info → ai_edge_quantizer_nightly-0.3.0.dev20250615.dist-info}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.3.0.dev20250613.dist-info → ai_edge_quantizer_nightly-0.3.0.dev20250615.dist-info}/WHEEL +0 -0
- {ai_edge_quantizer_nightly-0.3.0.dev20250613.dist-info → ai_edge_quantizer_nightly-0.3.0.dev20250615.dist-info}/top_level.txt +0 -0
@@ -112,6 +112,7 @@ MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT = {
|
|
112
112
|
common_quantize.materialize_squared_difference
|
113
113
|
),
|
114
114
|
_TFLOpName.MAX_POOL_2D: common_quantize.materialize_max_pool_2d,
|
115
|
+
_TFLOpName.RESIZE_BILINEAR: common_quantize.materialize_resize_bilinear,
|
115
116
|
}
|
116
117
|
for op_name, materialize_func in MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT.items():
|
117
118
|
register_quantized_op(
|
@@ -250,6 +251,7 @@ _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
|
|
250
251
|
common_quantize.materialize_squared_difference
|
251
252
|
),
|
252
253
|
_TFLOpName.MAX_POOL_2D: common_quantize.materialize_max_pool_2d,
|
254
|
+
_TFLOpName.RESIZE_BILINEAR: common_quantize.materialize_resize_bilinear,
|
253
255
|
})
|
254
256
|
|
255
257
|
for op_name, materialize_func in _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT.items():
|
@@ -728,6 +728,23 @@ def materialize_max_pool_2d(
|
|
728
728
|
)
|
729
729
|
|
730
730
|
|
731
|
+
def materialize_resize_bilinear(
|
732
|
+
get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
|
733
|
+
op_info: qtyping.OpInfo,
|
734
|
+
graph_info: qtyping.GraphInfo,
|
735
|
+
tensor_name_to_qsv: dict[str, Any],
|
736
|
+
) -> list[qtyping.TensorTransformationParams]:
|
737
|
+
"""Materialize tensors in tfl.resize_bilinear."""
|
738
|
+
return common_utils.materialize_standard_op(
|
739
|
+
op_info,
|
740
|
+
graph_info,
|
741
|
+
tensor_name_to_qsv,
|
742
|
+
get_tensor_quant_params_fn,
|
743
|
+
constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
|
744
|
+
inputs_to_ignore=[1], # Resize size does not need to be quantized.
|
745
|
+
)
|
746
|
+
|
747
|
+
|
731
748
|
def _get_tensor_shape_for_blockwise(
|
732
749
|
tensor_shape: Sequence[int], quantized_dim: int, block_size: int
|
733
750
|
) -> list[int]:
|
@@ -185,7 +185,8 @@ DEFAULT_JSON_POLICY = """
|
|
185
185
|
"SELECT_V2",
|
186
186
|
"STABLEHLO_COMPOSITE",
|
187
187
|
"PAD",
|
188
|
-
"MAX_POOL_2D"
|
188
|
+
"MAX_POOL_2D",
|
189
|
+
"RESIZE_BILINEAR"
|
189
190
|
],
|
190
191
|
"static_wi8_ai8": [
|
191
192
|
"ADD",
|
@@ -219,7 +220,8 @@ DEFAULT_JSON_POLICY = """
|
|
219
220
|
"STABLEHLO_COMPOSITE",
|
220
221
|
"PAD",
|
221
222
|
"SQUARED_DIFFERENCE",
|
222
|
-
"MAX_POOL_2D"
|
223
|
+
"MAX_POOL_2D",
|
224
|
+
"RESIZE_BILINEAR"
|
223
225
|
],
|
224
226
|
"static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
|
225
227
|
"static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
|
ai_edge_quantizer/qtyping.py
CHANGED
@@ -29,19 +29,6 @@ _AlgorithmName = recipe_manager.AlgorithmName
|
|
29
29
|
_QuantGranularity = qtyping.QuantGranularity
|
30
30
|
|
31
31
|
|
32
|
-
# Sample functions for test cases.
|
33
|
-
def _sample_init_qsvs(*_, **__):
|
34
|
-
return 1.0, dict()
|
35
|
-
|
36
|
-
|
37
|
-
def _sample_calibration_func(*_, **__):
|
38
|
-
return 2.0, dict()
|
39
|
-
|
40
|
-
|
41
|
-
def _sample_materialize_func(*_, **__):
|
42
|
-
return 3.0, dict()
|
43
|
-
|
44
|
-
|
45
32
|
def _sample_check_op_config_func(op_name, op_config, _):
|
46
33
|
if (
|
47
34
|
op_config.weight_tensor_config is not None
|
@@ -67,6 +54,16 @@ def _add_default_int8xint8_integer_recipe(recipe_manager_object):
|
|
67
54
|
|
68
55
|
# register some currently unsupported ops for testing purposes
|
69
56
|
def _register_testing_op(algorithm_key, tfl_op):
|
57
|
+
# Sample functions for test cases.
|
58
|
+
def _sample_init_qsvs(*_, **__):
|
59
|
+
return {'name': dict()}
|
60
|
+
|
61
|
+
def _sample_calibration_func(*_, **__):
|
62
|
+
return {'name2': dict()}
|
63
|
+
|
64
|
+
def _sample_materialize_func(*_, **__):
|
65
|
+
return []
|
66
|
+
|
70
67
|
algorithm_manager.register_op_quant_config_validation_func(
|
71
68
|
algorithm_key, _sample_check_op_config_func
|
72
69
|
)
|
@@ -15,9 +15,26 @@
|
|
15
15
|
|
16
16
|
"""Utilities for model calibration."""
|
17
17
|
|
18
|
-
|
18
|
+
import copy
|
19
|
+
from typing import Any, Union
|
20
|
+
|
19
21
|
import numpy as np
|
22
|
+
|
23
|
+
from ai_edge_quantizer import algorithm_manager
|
20
24
|
from ai_edge_quantizer import qtyping
|
25
|
+
from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
|
26
|
+
from ai_edge_quantizer.algorithms.utils import common_utils
|
27
|
+
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
|
28
|
+
from ai_edge_quantizer.utils import tfl_interpreter_utils
|
29
|
+
from ai_edge_litert import schema_py_generated as schema_fb # pylint: disable=g-direct-tensorflow-import
|
30
|
+
from tensorflow.lite.tools import flatbuffer_utils # pylint: disable=g-direct-tensorflow-import
|
31
|
+
|
32
|
+
|
33
|
+
_SignatureInput = dict[str, Any]
|
34
|
+
_OpQuantConstraint = common_utils.OpQuantConstraint
|
35
|
+
_SignatureData = dict[
|
36
|
+
str, list[str]
|
37
|
+
] # signature_key -> list of signature_names.
|
21
38
|
|
22
39
|
|
23
40
|
def _update_moving_average(
|
@@ -84,3 +101,327 @@ def min_max_update(qsv: qtyping.QSV, new_qsv: qtyping.QSV) -> qtyping.QSV:
|
|
84
101
|
updated_qsv["min"] = np.minimum(qsv["min"], new_qsv["min"])
|
85
102
|
updated_qsv["max"] = np.maximum(qsv["max"], new_qsv["max"])
|
86
103
|
return updated_qsv
|
104
|
+
|
105
|
+
|
106
|
+
def _find_overall_min_max(
|
107
|
+
qsv: qtyping.QSV, tensor_names: list[str]
|
108
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
109
|
+
"""Finds the overall minimum and maximum values for the given tensors.
|
110
|
+
|
111
|
+
Args:
|
112
|
+
qsv: The quantization statistical value of the tensor (min/max).
|
113
|
+
tensor_names: The list of tensor names to find the minimum and maximum
|
114
|
+
values.
|
115
|
+
|
116
|
+
Returns:
|
117
|
+
The minimum and maximum values for the given tensors.
|
118
|
+
"""
|
119
|
+
min_value = np.inf
|
120
|
+
max_value = -np.inf
|
121
|
+
for tensor_name in tensor_names:
|
122
|
+
min_value = min(min_value, qsv[tensor_name]["min"])
|
123
|
+
max_value = max(max_value, qsv[tensor_name]["max"])
|
124
|
+
return min_value, max_value
|
125
|
+
|
126
|
+
|
127
|
+
class CalibrationQsvAlignmentUtils:
|
128
|
+
"""Calibration utils for alignment of QSVs.
|
129
|
+
|
130
|
+
This class is used to align QSVs for a given model. It builds a list of ops
|
131
|
+
that need to be constrained to the same scale as the input. Based on this
|
132
|
+
list, it finds the corresponding tensor names for a given signature data.
|
133
|
+
"""
|
134
|
+
|
135
|
+
def __init__(self, model_path: str):
|
136
|
+
self._same_as_input_scale_ops = []
|
137
|
+
|
138
|
+
tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(model_path)
|
139
|
+
self._flatbuffer_object = tfl_flatbuffer_utils.read_model(model_path)
|
140
|
+
|
141
|
+
signature_keys = list(tfl_interpreter.get_signature_list().keys())
|
142
|
+
|
143
|
+
# Build a dict of signature runners.
|
144
|
+
self._signature_runners = {}
|
145
|
+
for signature_key in signature_keys:
|
146
|
+
signature_runner = tfl_interpreter.get_signature_runner(signature_key)
|
147
|
+
self._signature_runners[signature_key] = signature_runner
|
148
|
+
|
149
|
+
# Make a list of `SAME_AS_INPUT_SCALE` operators. This is used to identify
|
150
|
+
# the operators that need to be constrained to the same scale as the input.
|
151
|
+
self._build_same_as_input_scale_op_list()
|
152
|
+
|
153
|
+
def _build_same_as_input_scale_op_list(self, verbose: bool = False):
|
154
|
+
"""Constructs a list of SAME_AS_INPUT_SCALE operators.
|
155
|
+
|
156
|
+
This is achieved by invoking all materialization functions and extracting
|
157
|
+
the constraint argument, using monkey patching to redirect logic to wrapper
|
158
|
+
functions.
|
159
|
+
|
160
|
+
Args:
|
161
|
+
verbose: Flag to enable verbose output.
|
162
|
+
"""
|
163
|
+
|
164
|
+
def materialize_standard_op_wrapper(
|
165
|
+
op_info: qtyping.OpInfo,
|
166
|
+
*_args,
|
167
|
+
constraint: _OpQuantConstraint = _OpQuantConstraint.NO_CONSTRAIN,
|
168
|
+
**_kwargs,
|
169
|
+
) -> list[qtyping.TensorTransformationParams]:
|
170
|
+
if constraint == _OpQuantConstraint.SAME_AS_INPUT_SCALE:
|
171
|
+
self._same_as_input_scale_ops.append(op_info.op_name)
|
172
|
+
# Return dummy values to avoid exceptions.
|
173
|
+
dummy_value = [qtyping.TensorTransformationParams("")] * 2
|
174
|
+
return dummy_value
|
175
|
+
|
176
|
+
# Dummy implementation of the `_are_weights_too_small` function to support
|
177
|
+
# `materialize_standard_op_wrapper` above.
|
178
|
+
def are_weights_too_small_wrapper(*_args, **_kwargs) -> bool:
|
179
|
+
return False
|
180
|
+
|
181
|
+
# Dummy implementation of the `_materialize_bias_for_conv_ops` function to
|
182
|
+
# support `materialize_standard_op_wrapper` above.
|
183
|
+
def materialize_bias_for_conv_ops_wrapper(*_args, **_kwargs):
|
184
|
+
return
|
185
|
+
|
186
|
+
# Do monkey patch to intercept the `materialize_standard_op` function to
|
187
|
+
# support `materialize_standard_op_wrapper` above.
|
188
|
+
original_materialize_standard_op = common_utils.materialize_standard_op
|
189
|
+
original_are_weights_too_small = common_quantize._are_weights_too_small # pylint: disable=protected-access
|
190
|
+
original_materialize_bias_for_conv_ops = (
|
191
|
+
common_quantize._materialize_bias_for_conv_ops # pylint: disable=protected-access
|
192
|
+
)
|
193
|
+
common_utils.materialize_standard_op = materialize_standard_op_wrapper
|
194
|
+
common_quantize._are_weights_too_small = are_weights_too_small_wrapper # pylint: disable=protected-access
|
195
|
+
common_quantize._materialize_bias_for_conv_ops = ( # pylint: disable=protected-access
|
196
|
+
materialize_bias_for_conv_ops_wrapper
|
197
|
+
)
|
198
|
+
minmax_func_dict = algorithm_manager.MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT
|
199
|
+
|
200
|
+
# Loop over all available materialization functions to build up a list of
|
201
|
+
# `SAME_AS_INPUT_SCALE` constrained ops.
|
202
|
+
for op, materialize_fn in minmax_func_dict.items():
|
203
|
+
# Create a dummy op info to trigger the materialization.
|
204
|
+
mock_op = schema_fb.OperatorT()
|
205
|
+
mock_op.inputs = [0]
|
206
|
+
mock_op.outputs = [0]
|
207
|
+
op_info = qtyping.OpInfo(
|
208
|
+
op=mock_op,
|
209
|
+
op_name=op,
|
210
|
+
subgraph_op_index=0,
|
211
|
+
op_quant_config=qtyping.OpQuantizationConfig(),
|
212
|
+
)
|
213
|
+
materialize_fn(
|
214
|
+
get_tensor_quant_params_fn=None,
|
215
|
+
op_info=op_info,
|
216
|
+
graph_info=None,
|
217
|
+
tensor_name_to_qsv=None,
|
218
|
+
)
|
219
|
+
|
220
|
+
if verbose:
|
221
|
+
print(f" Constrained op list: {self._same_as_input_scale_ops}")
|
222
|
+
|
223
|
+
# Restore the original functions.
|
224
|
+
common_utils.materialize_standard_op = original_materialize_standard_op
|
225
|
+
common_quantize._are_weights_too_small = original_are_weights_too_small # pylint: disable=protected-access
|
226
|
+
common_quantize._materialize_bias_for_conv_ops = ( # pylint: disable=protected-access
|
227
|
+
original_materialize_bias_for_conv_ops
|
228
|
+
)
|
229
|
+
|
230
|
+
def _search_tensor_by_signature_name(
|
231
|
+
self, signature_key: str, signature_input_output_name: str, verbose=False
|
232
|
+
) -> list[str]:
|
233
|
+
"""Searches for a tensor name for a given signature by signature input or output name.
|
234
|
+
|
235
|
+
Args:
|
236
|
+
signature_key: Name of the signature.
|
237
|
+
signature_input_output_name: Name of the input or output in the signature.
|
238
|
+
verbose: Flag to enable verbose output.
|
239
|
+
|
240
|
+
Returns:
|
241
|
+
The list with one or two tensor names. The first one is the input tensor
|
242
|
+
name, and the second one is the output tensor name.
|
243
|
+
"""
|
244
|
+
|
245
|
+
if verbose:
|
246
|
+
print("Searching tensor by signature name.")
|
247
|
+
|
248
|
+
tensor_names = []
|
249
|
+
|
250
|
+
# Search among inputs.
|
251
|
+
input_details = self._signature_runners[signature_key].get_input_details()
|
252
|
+
if signature_input_output_name in input_details.keys():
|
253
|
+
tensor_names.append(input_details[signature_input_output_name]["name"])
|
254
|
+
|
255
|
+
# Search among outputs.
|
256
|
+
output_details = self._signature_runners[signature_key].get_output_details()
|
257
|
+
if signature_input_output_name not in output_details:
|
258
|
+
if not tensor_names:
|
259
|
+
raise ValueError(
|
260
|
+
f"Signature {signature_key} does not have input or output"
|
261
|
+
f" `{signature_input_output_name}`"
|
262
|
+
)
|
263
|
+
return tensor_names
|
264
|
+
|
265
|
+
output_tensor_name = output_details[signature_input_output_name]["name"]
|
266
|
+
if verbose:
|
267
|
+
print(
|
268
|
+
">> Starting recursive search for the output tensor name:"
|
269
|
+
f" {output_tensor_name}"
|
270
|
+
)
|
271
|
+
|
272
|
+
idx = self._signature_runners[signature_key]._subgraph_index # pylint: disable=protected-access
|
273
|
+
subgraph = self._flatbuffer_object.subgraphs[idx]
|
274
|
+
graph_info = qtyping.GraphInfo(
|
275
|
+
subgraph.tensors, self._flatbuffer_object.buffers
|
276
|
+
)
|
277
|
+
|
278
|
+
# Recursively search the graph for the output tensor name since it may be
|
279
|
+
# `SAME_AS_INPUT` constrainted.
|
280
|
+
operators = copy.deepcopy(subgraph.operators)
|
281
|
+
tensor_name = self._search_reverse_order_recursively(
|
282
|
+
graph_info, operators, output_tensor_name, indent=" ", verbose=verbose
|
283
|
+
)
|
284
|
+
tensor_names.append(tensor_name)
|
285
|
+
|
286
|
+
if verbose:
|
287
|
+
print(f"\n\nFound tensor name: {tensor_name}")
|
288
|
+
|
289
|
+
return tensor_names
|
290
|
+
|
291
|
+
def _search_reverse_order_recursively(
|
292
|
+
self,
|
293
|
+
graph_info: qtyping.GraphInfo,
|
294
|
+
operators: list[Any],
|
295
|
+
output_tensor_name: str,
|
296
|
+
indent: str,
|
297
|
+
verbose: bool = False,
|
298
|
+
):
|
299
|
+
"""Searches for a tensor name in reverse order recursively.
|
300
|
+
|
301
|
+
Stop criteria is when the tensor belongs to an operator that is not
|
302
|
+
`SAME_AS_INPUT` constrainted.
|
303
|
+
|
304
|
+
Args:
|
305
|
+
graph_info: Graph information.
|
306
|
+
operators: List of operators.
|
307
|
+
output_tensor_name: Name of the output tensor to search for.
|
308
|
+
indent: Indentation string for debug output.
|
309
|
+
verbose: Flag to enable verbose output.
|
310
|
+
|
311
|
+
Returns:
|
312
|
+
The name of the tensor found, or None if not found.
|
313
|
+
"""
|
314
|
+
op_codes = self._flatbuffer_object.operatorCodes
|
315
|
+
|
316
|
+
while operators:
|
317
|
+
op = operators.pop()
|
318
|
+
op_code = op_codes[op.opcodeIndex].builtinCode
|
319
|
+
op_name = flatbuffer_utils.opcode_to_name(
|
320
|
+
self._flatbuffer_object, op.opcodeIndex
|
321
|
+
)
|
322
|
+
if op_code not in tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME:
|
323
|
+
continue
|
324
|
+
for output_idx in op.outputs:
|
325
|
+
if output_tensor_name == tfl_flatbuffer_utils.get_tensor_name(
|
326
|
+
graph_info.subgraph_tensors[output_idx]
|
327
|
+
):
|
328
|
+
dbg_str = (
|
329
|
+
f"{indent}>> Found `{op_name}`, output tensor"
|
330
|
+
f" '{output_tensor_name}'"
|
331
|
+
)
|
332
|
+
|
333
|
+
if op_name not in self._same_as_input_scale_ops:
|
334
|
+
if verbose:
|
335
|
+
print(f"{dbg_str}, returning...")
|
336
|
+
return output_tensor_name
|
337
|
+
|
338
|
+
if verbose:
|
339
|
+
print(f"{dbg_str}, with SAME_AS_INPUT, search recursively among:")
|
340
|
+
for input_idx in op.inputs:
|
341
|
+
input_tensor_name = graph_info.subgraph_tensors[
|
342
|
+
input_idx
|
343
|
+
].name.decode("utf-8")
|
344
|
+
|
345
|
+
if verbose:
|
346
|
+
print(f"{indent} Input: {input_tensor_name}")
|
347
|
+
|
348
|
+
return self._search_reverse_order_recursively(
|
349
|
+
graph_info,
|
350
|
+
operators,
|
351
|
+
input_tensor_name,
|
352
|
+
indent=f"{indent} ",
|
353
|
+
verbose=verbose,
|
354
|
+
)
|
355
|
+
return output_tensor_name
|
356
|
+
|
357
|
+
def align_quant_stats(
|
358
|
+
self, qsv: dict[str, Any], signature_data: _SignatureData
|
359
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
360
|
+
"""Aligns quantization statistics for a given signature data.
|
361
|
+
|
362
|
+
This function takes quantization statistics and signature data as input,
|
363
|
+
identifies the tensors associated with the signature data, and aligns
|
364
|
+
the quantization statistics of these tensors by setting their minimum
|
365
|
+
and maximum values to the same value. This ensures that the tensors
|
366
|
+
have the same quantization parameters.
|
367
|
+
|
368
|
+
Args:
|
369
|
+
qsv: Quantization statistics.
|
370
|
+
signature_data: Signature data.
|
371
|
+
|
372
|
+
Returns:
|
373
|
+
Tuple of min and max values.
|
374
|
+
"""
|
375
|
+
# Go over all signature info and find the corresponding tensor names.
|
376
|
+
tensor_names = []
|
377
|
+
for signature_key, signature_names in signature_data.items():
|
378
|
+
for signature_name in signature_names:
|
379
|
+
tensor_name = self._search_tensor_by_signature_name(
|
380
|
+
signature_key, signature_name
|
381
|
+
)
|
382
|
+
tensor_names.extend(tensor_name)
|
383
|
+
|
384
|
+
# Find min and max values accross all tensors.
|
385
|
+
min_value, max_value = _find_overall_min_max(qsv, tensor_names)
|
386
|
+
|
387
|
+
# Overwrite the min and max values in the QSV.
|
388
|
+
for tensor_name in tensor_names:
|
389
|
+
qsv[tensor_name]["min"] = min_value
|
390
|
+
qsv[tensor_name]["max"] = max_value
|
391
|
+
|
392
|
+
return min_value, max_value
|
393
|
+
|
394
|
+
def update_quant_stats(
|
395
|
+
self,
|
396
|
+
qsv: dict[str, Any],
|
397
|
+
signature_data: _SignatureData,
|
398
|
+
min_value: np.ndarray,
|
399
|
+
max_value: np.ndarray,
|
400
|
+
):
|
401
|
+
"""Updates quantization statistics for a given signature data.
|
402
|
+
|
403
|
+
This function updates the quantization statistics with the provided min, max
|
404
|
+
values for the tensors specified in the signature data.
|
405
|
+
|
406
|
+
Args:
|
407
|
+
qsv: Quantization statistics.
|
408
|
+
signature_data: Signature data.
|
409
|
+
min_value: Minimum value to update.
|
410
|
+
max_value: Maximum value to update.
|
411
|
+
|
412
|
+
Returns:
|
413
|
+
Updated quantization statistics.
|
414
|
+
"""
|
415
|
+
# Go over all signature info and find the corresponding tensor names.
|
416
|
+
tensor_names = []
|
417
|
+
for signature_key, signature_names in signature_data.items():
|
418
|
+
for signature_name in signature_names:
|
419
|
+
tensor_name = self._search_tensor_by_signature_name(
|
420
|
+
signature_key, signature_name
|
421
|
+
)
|
422
|
+
tensor_names.extend(tensor_name)
|
423
|
+
|
424
|
+
# Overwrite the min and max values in the QSV.
|
425
|
+
for tensor_name in tensor_names:
|
426
|
+
qsv[tensor_name]["min"] = min_value
|
427
|
+
qsv[tensor_name]["max"] = max_value
|
@@ -14,11 +14,68 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
from absl.testing import parameterized
|
17
|
+
import numpy as np
|
18
|
+
import tensorflow as tf
|
19
|
+
|
17
20
|
from tensorflow.python.platform import googletest
|
21
|
+
from ai_edge_quantizer import quantizer
|
18
22
|
from ai_edge_quantizer.utils import calibration_utils
|
23
|
+
from ai_edge_quantizer.utils import test_utils
|
24
|
+
from ai_edge_quantizer.utils import tfl_interpreter_utils
|
25
|
+
|
26
|
+
_RNG = np.random.default_rng(66)
|
27
|
+
|
28
|
+
_CALIBRATION_DATASET = {
|
29
|
+
"signature_1": [{
|
30
|
+
"cache_0": np.zeros(shape=(1, 100, 4, 4), dtype=np.float32),
|
31
|
+
"cache_1": np.zeros(shape=(1, 100, 4, 4), dtype=np.float32),
|
32
|
+
"positions": np.zeros(shape=(1, 100), dtype=np.int32),
|
33
|
+
"tokens": np.zeros(shape=(1, 100), dtype=np.int32),
|
34
|
+
}],
|
35
|
+
"signature_2": [{
|
36
|
+
"cache_0": _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
|
37
|
+
"cache_1": _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
|
38
|
+
"positions": (
|
39
|
+
_RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
|
40
|
+
),
|
41
|
+
"tokens": _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32),
|
42
|
+
}],
|
43
|
+
}
|
44
|
+
|
19
45
|
|
46
|
+
def _get_quant_parameters(
|
47
|
+
quantized_model: bytes, signature_data: dict[str, list[str]]
|
48
|
+
) -> list[np.ndarray]:
|
49
|
+
"""Returns the quantization parameters from the quantized model."""
|
50
|
+
quant_params = []
|
51
|
+
tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
|
52
|
+
quantized_model
|
53
|
+
)
|
54
|
+
for signature_key, signature_names in signature_data.items():
|
55
|
+
signature_runner = tfl_interpreter.get_signature_runner(signature_key)
|
56
|
+
|
57
|
+
for signature_name in signature_names:
|
58
|
+
input_details = signature_runner.get_input_details()
|
59
|
+
output_details = signature_runner.get_output_details()
|
60
|
+
if signature_name in input_details.keys():
|
61
|
+
quant_param = input_details[signature_name]["quantization_parameters"][
|
62
|
+
"scales"
|
63
|
+
].squeeze()
|
64
|
+
quant_params.append(quant_param)
|
65
|
+
elif signature_name in output_details.keys():
|
66
|
+
output_details = signature_runner.get_output_details()
|
67
|
+
quant_param = output_details[signature_name]["quantization_parameters"][
|
68
|
+
"scales"
|
69
|
+
].squeeze()
|
70
|
+
quant_params.append(quant_param)
|
71
|
+
else:
|
72
|
+
raise ValueError(
|
73
|
+
f"Signature name {signature_name} not found in the model."
|
74
|
+
)
|
75
|
+
return quant_params
|
20
76
|
|
21
|
-
|
77
|
+
|
78
|
+
class CalibrationQsvAlignmentUtilsTest(parameterized.TestCase):
|
22
79
|
|
23
80
|
@parameterized.named_parameters(
|
24
81
|
dict(
|
@@ -66,12 +123,126 @@ class CalibrationUtilsTest(parameterized.TestCase):
|
|
66
123
|
def test_update_tensor_qsv_min_max(self, old_qsv, new_qsv, expected_qsv):
|
67
124
|
updated_qsv = calibration_utils.min_max_update(old_qsv, new_qsv)
|
68
125
|
if isinstance(expected_qsv["min"], list):
|
69
|
-
self.
|
70
|
-
self.
|
126
|
+
self.assertEqual(list(updated_qsv["min"]), expected_qsv["min"])
|
127
|
+
self.assertEqual(list(updated_qsv["max"]), expected_qsv["max"])
|
71
128
|
else:
|
72
129
|
self.assertEqual(updated_qsv["min"], expected_qsv["min"])
|
73
130
|
self.assertEqual(updated_qsv["max"], expected_qsv["max"])
|
74
131
|
|
132
|
+
def test_calibration_utils_init_fails(self):
|
133
|
+
model_path = "non_existent_model.tflite"
|
134
|
+
with self.assertRaisesWithPredicateMatch(
|
135
|
+
tf.errors.NotFoundError, lambda err: f"{model_path}" in str(err)
|
136
|
+
):
|
137
|
+
calibration_utils.CalibrationQsvAlignmentUtils(model_path)
|
138
|
+
|
139
|
+
def test_calibration_utils_init_succeeds(self):
|
140
|
+
model_path = test_utils.get_path_to_datafile(
|
141
|
+
"../tests/models/single_add.tflite"
|
142
|
+
)
|
143
|
+
calib_utils = calibration_utils.CalibrationQsvAlignmentUtils(model_path)
|
144
|
+
self.assertNotEmpty(calib_utils._signature_runners)
|
145
|
+
self.assertNotEmpty(calib_utils._same_as_input_scale_ops)
|
146
|
+
|
147
|
+
def test_search_tensor_by_signature_name_succeeds_on_unconstrained_op(self):
|
148
|
+
model_path = test_utils.get_path_to_datafile(
|
149
|
+
"../tests/models/single_add.tflite"
|
150
|
+
)
|
151
|
+
expected_tensor_name = "PartitionedCall:0"
|
152
|
+
calib_utils = calibration_utils.CalibrationQsvAlignmentUtils(model_path)
|
153
|
+
tensor_name = calib_utils._search_tensor_by_signature_name(
|
154
|
+
"serving_default", "add"
|
155
|
+
)
|
156
|
+
self.assertEqual(tensor_name, [expected_tensor_name])
|
157
|
+
|
158
|
+
def test_search_tensor_by_signature_name_succeeds_on_constrained_op(self):
|
159
|
+
model_path = test_utils.get_path_to_datafile(
|
160
|
+
"../tests/models/single_slice.tflite"
|
161
|
+
)
|
162
|
+
expected_tensor_name = "slice_input_tensor:0"
|
163
|
+
calib_utils = calibration_utils.CalibrationQsvAlignmentUtils(model_path)
|
164
|
+
tensor_name = calib_utils._search_tensor_by_signature_name(
|
165
|
+
"slice", "output_0"
|
166
|
+
)
|
167
|
+
self.assertEqual(tensor_name, [expected_tensor_name])
|
168
|
+
|
169
|
+
def test_align_quant_stats_succeeds(self):
|
170
|
+
model_path = test_utils.get_path_to_datafile(
|
171
|
+
"../tests/models/toy_model_with_kv_cache_multi_signature.tflite"
|
172
|
+
)
|
173
|
+
recipe_path = test_utils.get_path_to_datafile(
|
174
|
+
"../recipes/default_a8w8_recipe.json"
|
175
|
+
)
|
176
|
+
signature_data = {
|
177
|
+
"signature_1": ["output_1_1"],
|
178
|
+
"signature_2": ["output_1_1"],
|
179
|
+
}
|
180
|
+
|
181
|
+
# Obtain the calibration results.
|
182
|
+
qt = quantizer.Quantizer(model_path, recipe_path)
|
183
|
+
qsv = qt.calibrate(_CALIBRATION_DATASET)
|
184
|
+
|
185
|
+
# First quantize the model without aligning the quantization parameters.
|
186
|
+
quantized_model = qt.quantize(qsv).quantized_model
|
187
|
+
quant_params = _get_quant_parameters(quantized_model, signature_data)
|
188
|
+
self.assertFalse(
|
189
|
+
all(x == quant_params[0] for x in quant_params)
|
190
|
+
) # not equal quantization params.
|
191
|
+
|
192
|
+
# Align the quantization parameters and quantize again.
|
193
|
+
calib_utils = calibration_utils.CalibrationQsvAlignmentUtils(model_path)
|
194
|
+
calib_utils.align_quant_stats(qsv, signature_data)
|
195
|
+
quantized_model = qt.quantize(qsv).quantized_model
|
196
|
+
quant_params = _get_quant_parameters(quantized_model, signature_data)
|
197
|
+
self.assertTrue(
|
198
|
+
all(x == quant_params[0] for x in quant_params)
|
199
|
+
) # equal quantization params.
|
200
|
+
|
201
|
+
def test_update_quant_stats_succeeds(self):
|
202
|
+
model_path = test_utils.get_path_to_datafile(
|
203
|
+
"../tests/models/toy_model_with_kv_cache_multi_signature.tflite"
|
204
|
+
)
|
205
|
+
recipe_path = test_utils.get_path_to_datafile(
|
206
|
+
"../recipes/default_a8w8_recipe.json"
|
207
|
+
)
|
208
|
+
signature_data = {
|
209
|
+
"signature_1": ["output_1_1"],
|
210
|
+
"signature_2": ["output_1_1"],
|
211
|
+
}
|
212
|
+
|
213
|
+
# Obtain the calibration results.
|
214
|
+
qt = quantizer.Quantizer(model_path, recipe_path)
|
215
|
+
qsv = qt.calibrate(_CALIBRATION_DATASET)
|
216
|
+
|
217
|
+
# First quantize the model without updating the `signature_1`.
|
218
|
+
quantized_model = qt.quantize(qsv).quantized_model
|
219
|
+
quant_params = _get_quant_parameters(quantized_model, signature_data)
|
220
|
+
self.assertFalse(
|
221
|
+
all(x == quant_params[0] for x in quant_params)
|
222
|
+
) # not equal quantization params.
|
223
|
+
|
224
|
+
# Update the `signature_1` with stats from `signature_2`.
|
225
|
+
calib_utils = calibration_utils.CalibrationQsvAlignmentUtils(model_path)
|
226
|
+
min_val, max_val = calib_utils.align_quant_stats( # for min and max only.
|
227
|
+
qsv,
|
228
|
+
{
|
229
|
+
"signature_2": ["output_1_1"],
|
230
|
+
},
|
231
|
+
)
|
232
|
+
calib_utils.update_quant_stats(
|
233
|
+
qsv,
|
234
|
+
{
|
235
|
+
"signature_1": ["output_1_1"],
|
236
|
+
},
|
237
|
+
min_val,
|
238
|
+
max_val,
|
239
|
+
)
|
240
|
+
quantized_model = qt.quantize(qsv).quantized_model
|
241
|
+
quant_params = _get_quant_parameters(quantized_model, signature_data)
|
242
|
+
self.assertTrue(
|
243
|
+
all(x == quant_params[0] for x in quant_params)
|
244
|
+
) # equal quantization params.
|
245
|
+
|
75
246
|
|
76
247
|
if __name__ == "__main__":
|
77
248
|
googletest.main()
|
@@ -59,6 +59,7 @@ TFL_OP_NAME_TO_CODE = immutabledict.immutabledict({
|
|
59
59
|
_TFLOpName.PAD: schema.BuiltinOperator.PAD,
|
60
60
|
_TFLOpName.SQUARED_DIFFERENCE: schema.BuiltinOperator.SQUARED_DIFFERENCE,
|
61
61
|
_TFLOpName.MAX_POOL_2D: schema.BuiltinOperator.MAX_POOL_2D,
|
62
|
+
_TFLOpName.RESIZE_BILINEAR: schema.BuiltinOperator.RESIZE_BILINEAR,
|
62
63
|
})
|
63
64
|
|
64
65
|
TFL_OP_CODE_TO_NAME = immutabledict.immutabledict(
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-quantizer-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20250615
|
4
4
|
Summary: A quantizer for advanced developers to quantize converted AI Edge models.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
|
@@ -1,23 +1,23 @@
|
|
1
1
|
ai_edge_quantizer/__init__.py,sha256=4pFSkukSwahYyzwqia0yPRyz8TnFQfGRthVJhYpMWas,793
|
2
|
-
ai_edge_quantizer/algorithm_manager.py,sha256=
|
2
|
+
ai_edge_quantizer/algorithm_manager.py,sha256=rMTM89YDPkmLKlUQV_Rjr7B2KpcvldAHzfpgUqaOqdU,12216
|
3
3
|
ai_edge_quantizer/algorithm_manager_api.py,sha256=u903TG0s1uIDhJqfeJne3CFl8A93phZrwgV2-hwdcXU,9247
|
4
4
|
ai_edge_quantizer/algorithm_manager_api_test.py,sha256=w6bSONvXkX6bzXAGc0-7b6gNDt9oz9ieq97KP8Sg_JU,7666
|
5
5
|
ai_edge_quantizer/calibrator.py,sha256=Sms7_AIHPH9G5xFaz5Ef3a5gPhxuIWQI8d2LUM8C96I,12071
|
6
6
|
ai_edge_quantizer/calibrator_test.py,sha256=C_oWOaRugPKYX74jF-eRFH-k6nGOdA8I9_uPiocaOuE,11900
|
7
7
|
ai_edge_quantizer/conftest.py,sha256=SxCz-5LlRD_lQm4hQc4c6IGG7DS8d7IyEWY9gnscPN0,794
|
8
|
-
ai_edge_quantizer/default_policy.py,sha256=
|
8
|
+
ai_edge_quantizer/default_policy.py,sha256=zghBh9dTB-ouPFumV-0siBSnEbp0WxF6tGOsn3TLirg,11242
|
9
9
|
ai_edge_quantizer/model_modifier.py,sha256=teGa8I6kGvn6TQY6Xv53YFIc_pQEhNvM9Zb4bvhezyw,7110
|
10
10
|
ai_edge_quantizer/model_modifier_test.py,sha256=cJd04SLOG-fQZZNZPcisoBLx3cLtWEwGqUBbLb-pif4,4751
|
11
11
|
ai_edge_quantizer/model_validator.py,sha256=Hj0_5o-Oa3dSlJ3ryVjRhvsyelHNyek1GrtG9buMczg,13153
|
12
12
|
ai_edge_quantizer/model_validator_test.py,sha256=EeqOP_mrZsnZ3rug756s0ryDDqd2KgIDld5Lm_gDuWY,13020
|
13
13
|
ai_edge_quantizer/params_generator.py,sha256=gC7G6Ne4Fumc8RSmIAbx96ZBhszZlHqBKSmE9p6RPTo,20099
|
14
14
|
ai_edge_quantizer/params_generator_test.py,sha256=RDYoRZDJfEZRtjlTAU2kZ_4t3JHOqEHxfJX9V4ETAhg,40597
|
15
|
-
ai_edge_quantizer/qtyping.py,sha256=
|
15
|
+
ai_edge_quantizer/qtyping.py,sha256=kX1AoD-YlHYbDI1RfGVXIbPn-CYT7HUF2x77-hPtKBM,16565
|
16
16
|
ai_edge_quantizer/quantizer.py,sha256=g3DMqFMrMpt9jQttCE0WcdNbMtk0JZnmN5MmCHrNdyM,13202
|
17
17
|
ai_edge_quantizer/quantizer_test.py,sha256=K_HBA56JkFI3HL8VLWCqGEfC0ISh5ldMKoNyBdGRAJg,20368
|
18
18
|
ai_edge_quantizer/recipe.py,sha256=FR0uJceumZrnle2VRSOQZ1uXup4S1cTYKRH-N53mWRo,2919
|
19
19
|
ai_edge_quantizer/recipe_manager.py,sha256=qcGUD7e7BISKdsY9WH2rdaRR3acmzSA5qMezGNbzlpo,8931
|
20
|
-
ai_edge_quantizer/recipe_manager_test.py,sha256=
|
20
|
+
ai_edge_quantizer/recipe_manager_test.py,sha256=GVOfGFZPRciUb4EF4GkSi6d96LdjS6PbUkAJ0ayy0k8,32243
|
21
21
|
ai_edge_quantizer/recipe_test.py,sha256=Fg_sfxovI2fRjk5qdu18ghOvXdUvhDR1TxbE0GHDczc,3381
|
22
22
|
ai_edge_quantizer/transformation_instruction_generator.py,sha256=B_TQQe9_Qs7UKXLjMMuz5lORUvXyZOxBS2SpntTnkI8,28077
|
23
23
|
ai_edge_quantizer/transformation_instruction_generator_test.py,sha256=E0QSDCav6N6izlJ-a1ZJOsb2VEUxuxBmTbt0-EgDdxY,49890
|
@@ -28,7 +28,7 @@ ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py,sha256=lpq1g2ayg3lCP
|
|
28
28
|
ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py,sha256=Bs9CK7wZAw6jNaZ8xEtbwO2vM34VYXNZSMVWvxJo9nw,9297
|
29
29
|
ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py,sha256=EqIHGEZ1LgUrTN7zf880RuAzEv3Qy7kgh5ivObJGHSo,22646
|
30
30
|
ai_edge_quantizer/algorithms/uniform_quantize/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
|
31
|
-
ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=
|
31
|
+
ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=rImKK2ax7LrRx6XurSdvRTk0h6WtFGtQn9sYNJcn-uw,30222
|
32
32
|
ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py,sha256=GGf_n3wIeg3GB_eGsmyNJ0fTcxgpeMMbugTMRONK6TQ,3553
|
33
33
|
ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py,sha256=BDdn_uBZakfHyzdMJPKadsOqxqyC-s6W2ZzFH99L4fE,8652
|
34
34
|
ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py,sha256=sT5eX5TLZEHTtPfnSkCPDlS0sQxlTFWbCsbvOuj--yY,8889
|
@@ -61,17 +61,17 @@ ai_edge_quantizer/transformations/quantize_tensor_test.py,sha256=mHLO3_MRt36A8-Z
|
|
61
61
|
ai_edge_quantizer/transformations/transformation_utils.py,sha256=GwIaKVsePZYgVG2lSanOswcaZYMjvgyqstDVwXl9DGY,6923
|
62
62
|
ai_edge_quantizer/transformations/transformation_utils_test.py,sha256=MWgq29t7rvxRQIfi4ny9IoODFCTcbpjnIwoCL40zDKk,8698
|
63
63
|
ai_edge_quantizer/utils/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
|
64
|
-
ai_edge_quantizer/utils/calibration_utils.py,sha256=
|
65
|
-
ai_edge_quantizer/utils/calibration_utils_test.py,sha256=
|
64
|
+
ai_edge_quantizer/utils/calibration_utils.py,sha256=e3dG7Nm94Ix0hkTWTWPUhEG6a8QR_cAM3PSwblfJV5g,15106
|
65
|
+
ai_edge_quantizer/utils/calibration_utils_test.py,sha256=4BlksXl7b4yptL8xPR67hmJCnjhN9V10a2PunzfHrUE,9372
|
66
66
|
ai_edge_quantizer/utils/test_utils.py,sha256=Y2pdMvn1k4gmqDo3noJfzx3fJcDHX_1hcsP6oiIz65Y,8240
|
67
|
-
ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=
|
67
|
+
ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=pZv8FMWyjBSLN5MGJ2K_dZ6oqkJGbp9RI4CfnlPuPII,10830
|
68
68
|
ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py,sha256=K1SbK8q92qYVtiVj0I0GtugsPTkpIpEKv9zakvFV_Sc,8555
|
69
69
|
ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=EtOv6cpKM_F0uv2bWuSXylYmTeXT6zUc182pw4sdYSI,13889
|
70
70
|
ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=6fjkM-rycZ95L4yfvlr0TN6RlrhfPzxNUYrZaYO_F0A,12013
|
71
71
|
ai_edge_quantizer/utils/validation_utils.py,sha256=oYw33Sg547AqtGw-choPUJmp9SAKkV46J_ddqSsum2Q,3950
|
72
72
|
ai_edge_quantizer/utils/validation_utils_test.py,sha256=V_qNDikPD4OPB-siOLQCWNVWTAu87h2IgNYt7teFd-o,2934
|
73
|
-
ai_edge_quantizer_nightly-0.3.0.
|
74
|
-
ai_edge_quantizer_nightly-0.3.0.
|
75
|
-
ai_edge_quantizer_nightly-0.3.0.
|
76
|
-
ai_edge_quantizer_nightly-0.3.0.
|
77
|
-
ai_edge_quantizer_nightly-0.3.0.
|
73
|
+
ai_edge_quantizer_nightly-0.3.0.dev20250615.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
74
|
+
ai_edge_quantizer_nightly-0.3.0.dev20250615.dist-info/METADATA,sha256=IklxnJKNI7_fJW9CmL-QfF9EWmzzn8DRoGjwtpDZ8Wg,1528
|
75
|
+
ai_edge_quantizer_nightly-0.3.0.dev20250615.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
76
|
+
ai_edge_quantizer_nightly-0.3.0.dev20250615.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
|
77
|
+
ai_edge_quantizer_nightly-0.3.0.dev20250615.dist-info/RECORD,,
|
File without changes
|
File without changes
|