ai-edge-quantizer-nightly 0.3.0.dev20250614__py3-none-any.whl → 0.3.0.dev20250616__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ai_edge_quantizer/recipe_manager_test.py

@@ -29,19 +29,6 @@ _AlgorithmName = recipe_manager.AlgorithmName
 _QuantGranularity = qtyping.QuantGranularity
 
 
-# Sample functions for test cases.
-def _sample_init_qsvs(*_, **__):
-  return 1.0, dict()
-
-
-def _sample_calibration_func(*_, **__):
-  return 2.0, dict()
-
-
-def _sample_materialize_func(*_, **__):
-  return 3.0, dict()
-
-
 def _sample_check_op_config_func(op_name, op_config, _):
   if (
       op_config.weight_tensor_config is not None
@@ -67,6 +54,16 @@ def _add_default_int8xint8_integer_recipe(recipe_manager_object):
 
 # register some currently unsupported ops for testing purposes
 def _register_testing_op(algorithm_key, tfl_op):
+  # Sample functions for test cases.
+  def _sample_init_qsvs(*_, **__):
+    return {'name': dict()}
+
+  def _sample_calibration_func(*_, **__):
+    return {'name2': dict()}
+
+  def _sample_materialize_func(*_, **__):
+    return []
+
   algorithm_manager.register_op_quant_config_validation_func(
       algorithm_key, _sample_check_op_config_func
   )
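Note: the relocated sample functions also change their return shapes to match the interfaces they appear to stand in for: QSV initialization and calibration return tensor-name-to-statistics mappings, while materialization returns a list of per-tensor transformation parameters. A minimal sketch of those shapes (the tensor name and min/max values below are made up for illustration and are not part of the package):

  # Hypothetical illustration of the interfaces the sample functions mimic.
  def _illustrative_init_qsvs(*_, **__):
    # tensor_name -> initial quantization statistics.
    return {"tensor_0": {"min": 0.0, "max": 1.0}}

  def _illustrative_calibration_func(*_, **__):
    # tensor_name -> statistics observed during one calibration pass.
    return {"tensor_0": {"min": -0.5, "max": 0.5}}

  def _illustrative_materialize_func(*_, **__):
    # List of per-tensor transformation parameters; empty is a valid no-op.
    return []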
ai_edge_quantizer/utils/calibration_utils.py

@@ -15,9 +15,26 @@
 
 """Utilities for model calibration."""
 
-from typing import Union
+import copy
+from typing import Any, Union
+
 import numpy as np
+
+from ai_edge_quantizer import algorithm_manager
 from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
+from ai_edge_quantizer.algorithms.utils import common_utils
+from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+from ai_edge_quantizer.utils import tfl_interpreter_utils
+from ai_edge_litert import schema_py_generated as schema_fb  # pylint: disable=g-direct-tensorflow-import
+from tensorflow.lite.tools import flatbuffer_utils  # pylint: disable=g-direct-tensorflow-import
+
+
+_SignatureInput = dict[str, Any]
+_OpQuantConstraint = common_utils.OpQuantConstraint
+_SignatureData = dict[
+    str, list[str]
+]  # signature_key -> list of signature_names.
 
 
 def _update_moving_average(
@@ -84,3 +101,327 @@ def min_max_update(qsv: qtyping.QSV, new_qsv: qtyping.QSV) -> qtyping.QSV:
   updated_qsv["min"] = np.minimum(qsv["min"], new_qsv["min"])
   updated_qsv["max"] = np.maximum(qsv["max"], new_qsv["max"])
   return updated_qsv
+
+
+def _find_overall_min_max(
+    qsv: qtyping.QSV, tensor_names: list[str]
+) -> tuple[np.ndarray, np.ndarray]:
+  """Finds the overall minimum and maximum values for the given tensors.
+
+  Args:
+    qsv: The quantization statistical value of the tensor (min/max).
+    tensor_names: The list of tensor names to find the minimum and maximum
+      values.
+
+  Returns:
+    The minimum and maximum values for the given tensors.
+  """
+  min_value = np.inf
+  max_value = -np.inf
+  for tensor_name in tensor_names:
+    min_value = min(min_value, qsv[tensor_name]["min"])
+    max_value = max(max_value, qsv[tensor_name]["max"])
+  return min_value, max_value
+
+
+class CalibrationQsvAlignmentUtils:
+  """Calibration utils for alignment of QSVs.
+
+  This class is used to align QSVs for a given model. It builds a list of ops
+  that need to be constrained to the same scale as the input. Based on this
+  list, it finds the corresponding tensor names for a given signature data.
+  """
+
+  def __init__(self, model_path: str):
+    self._same_as_input_scale_ops = []
+
+    tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(model_path)
+    self._flatbuffer_object = tfl_flatbuffer_utils.read_model(model_path)
+
+    signature_keys = list(tfl_interpreter.get_signature_list().keys())
+
+    # Build a dict of signature runners.
+    self._signature_runners = {}
+    for signature_key in signature_keys:
+      signature_runner = tfl_interpreter.get_signature_runner(signature_key)
+      self._signature_runners[signature_key] = signature_runner
+
+    # Make a list of `SAME_AS_INPUT_SCALE` operators. This is used to identify
+    # the operators that need to be constrained to the same scale as the input.
+    self._build_same_as_input_scale_op_list()
+
+  def _build_same_as_input_scale_op_list(self, verbose: bool = False):
+    """Constructs a list of SAME_AS_INPUT_SCALE operators.
+
+    This is achieved by invoking all materialization functions and extracting
+    the constraint argument, using monkey patching to redirect logic to wrapper
+    functions.
+
+    Args:
+      verbose: Flag to enable verbose output.
+    """
+
+    def materialize_standard_op_wrapper(
+        op_info: qtyping.OpInfo,
+        *_args,
+        constraint: _OpQuantConstraint = _OpQuantConstraint.NO_CONSTRAIN,
+        **_kwargs,
+    ) -> list[qtyping.TensorTransformationParams]:
+      if constraint == _OpQuantConstraint.SAME_AS_INPUT_SCALE:
+        self._same_as_input_scale_ops.append(op_info.op_name)
+      # Return dummy values to avoid exceptions.
+      dummy_value = [qtyping.TensorTransformationParams("")] * 2
+      return dummy_value
+
+    # Dummy implementation of the `_are_weights_too_small` function to support
+    # `materialize_standard_op_wrapper` above.
+    def are_weights_too_small_wrapper(*_args, **_kwargs) -> bool:
+      return False
+
+    # Dummy implementation of the `_materialize_bias_for_conv_ops` function to
+    # support `materialize_standard_op_wrapper` above.
+    def materialize_bias_for_conv_ops_wrapper(*_args, **_kwargs):
+      return
+
+    # Monkey patch the `materialize_standard_op` function and its helpers,
+    # redirecting them to the wrappers above.
+    original_materialize_standard_op = common_utils.materialize_standard_op
+    original_are_weights_too_small = common_quantize._are_weights_too_small  # pylint: disable=protected-access
+    original_materialize_bias_for_conv_ops = (
+        common_quantize._materialize_bias_for_conv_ops  # pylint: disable=protected-access
+    )
+    common_utils.materialize_standard_op = materialize_standard_op_wrapper
+    common_quantize._are_weights_too_small = are_weights_too_small_wrapper  # pylint: disable=protected-access
+    common_quantize._materialize_bias_for_conv_ops = (  # pylint: disable=protected-access
+        materialize_bias_for_conv_ops_wrapper
+    )
+    minmax_func_dict = algorithm_manager.MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT
+
+    # Loop over all available materialization functions to build up a list of
+    # `SAME_AS_INPUT_SCALE` constrained ops.
+    for op, materialize_fn in minmax_func_dict.items():
+      # Create a dummy op info to trigger the materialization.
+      mock_op = schema_fb.OperatorT()
+      mock_op.inputs = [0]
+      mock_op.outputs = [0]
+      op_info = qtyping.OpInfo(
+          op=mock_op,
+          op_name=op,
+          subgraph_op_index=0,
+          op_quant_config=qtyping.OpQuantizationConfig(),
+      )
+      materialize_fn(
+          get_tensor_quant_params_fn=None,
+          op_info=op_info,
+          graph_info=None,
+          tensor_name_to_qsv=None,
+      )
+
+    if verbose:
+      print(f" Constrained op list: {self._same_as_input_scale_ops}")
+
+    # Restore the original functions.
+    common_utils.materialize_standard_op = original_materialize_standard_op
+    common_quantize._are_weights_too_small = original_are_weights_too_small  # pylint: disable=protected-access
+    common_quantize._materialize_bias_for_conv_ops = (  # pylint: disable=protected-access
+        original_materialize_bias_for_conv_ops
+    )
+
+  def _search_tensor_by_signature_name(
+      self, signature_key: str, signature_input_output_name: str, verbose=False
+  ) -> list[str]:
+    """Searches for a tensor name for a given signature by signature input or output name.
+
+    Args:
+      signature_key: Name of the signature.
+      signature_input_output_name: Name of the input or output in the signature.
+      verbose: Flag to enable verbose output.
+
+    Returns:
+      A list with one or two tensor names. The first one is the input tensor
+      name, and the second one is the output tensor name.
+    """
+
+    if verbose:
+      print("Searching tensor by signature name.")
+
+    tensor_names = []
+
+    # Search among inputs.
+    input_details = self._signature_runners[signature_key].get_input_details()
+    if signature_input_output_name in input_details.keys():
+      tensor_names.append(input_details[signature_input_output_name]["name"])
+
+    # Search among outputs.
+    output_details = self._signature_runners[signature_key].get_output_details()
+    if signature_input_output_name not in output_details:
+      if not tensor_names:
+        raise ValueError(
+            f"Signature {signature_key} does not have input or output"
+            f" `{signature_input_output_name}`"
+        )
+      return tensor_names
+
+    output_tensor_name = output_details[signature_input_output_name]["name"]
+    if verbose:
+      print(
+          ">> Starting recursive search for the output tensor name:"
+          f" {output_tensor_name}"
+      )
+
+    idx = self._signature_runners[signature_key]._subgraph_index  # pylint: disable=protected-access
+    subgraph = self._flatbuffer_object.subgraphs[idx]
+    graph_info = qtyping.GraphInfo(
+        subgraph.tensors, self._flatbuffer_object.buffers
+    )
+
+    # Recursively search the graph for the output tensor name since it may be
+    # `SAME_AS_INPUT` constrained.
+    operators = copy.deepcopy(subgraph.operators)
+    tensor_name = self._search_reverse_order_recursively(
+        graph_info, operators, output_tensor_name, indent=" ", verbose=verbose
+    )
+    tensor_names.append(tensor_name)
+
+    if verbose:
+      print(f"\n\nFound tensor name: {tensor_name}")
+
+    return tensor_names
+
+  def _search_reverse_order_recursively(
+      self,
+      graph_info: qtyping.GraphInfo,
+      operators: list[Any],
+      output_tensor_name: str,
+      indent: str,
+      verbose: bool = False,
+  ):
+    """Searches for a tensor name in reverse order recursively.
+
+    The stop criterion is met when the tensor belongs to an operator that is
+    not `SAME_AS_INPUT` constrained.
+
+    Args:
+      graph_info: Graph information.
+      operators: List of operators.
+      output_tensor_name: Name of the output tensor to search for.
+      indent: Indentation string for debug output.
+      verbose: Flag to enable verbose output.
+
+    Returns:
+      The name of the tensor found, or None if not found.
+    """
+    op_codes = self._flatbuffer_object.operatorCodes
+
+    while operators:
+      op = operators.pop()
+      op_code = op_codes[op.opcodeIndex].builtinCode
+      op_name = flatbuffer_utils.opcode_to_name(
+          self._flatbuffer_object, op.opcodeIndex
+      )
+      if op_code not in tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME:
+        continue
+      for output_idx in op.outputs:
+        if output_tensor_name == tfl_flatbuffer_utils.get_tensor_name(
+            graph_info.subgraph_tensors[output_idx]
+        ):
+          dbg_str = (
+              f"{indent}>> Found `{op_name}`, output tensor"
+              f" '{output_tensor_name}'"
+          )
+
+          if op_name not in self._same_as_input_scale_ops:
+            if verbose:
+              print(f"{dbg_str}, returning...")
+            return output_tensor_name
+
+          if verbose:
+            print(f"{dbg_str}, with SAME_AS_INPUT, search recursively among:")
+          for input_idx in op.inputs:
+            input_tensor_name = graph_info.subgraph_tensors[
+                input_idx
+            ].name.decode("utf-8")
+
+            if verbose:
+              print(f"{indent} Input: {input_tensor_name}")
+
+            return self._search_reverse_order_recursively(
+                graph_info,
+                operators,
+                input_tensor_name,
+                indent=f"{indent} ",
+                verbose=verbose,
+            )
+    return output_tensor_name
+
+  def align_quant_stats(
+      self, qsv: dict[str, Any], signature_data: _SignatureData
+  ) -> tuple[np.ndarray, np.ndarray]:
+    """Aligns quantization statistics for a given signature data.
+
+    This function takes quantization statistics and signature data as input,
+    identifies the tensors associated with the signature data, and aligns
+    the quantization statistics of these tensors by setting their minimum
+    and maximum values to the same value. This ensures that the tensors
+    have the same quantization parameters.
+
+    Args:
+      qsv: Quantization statistics.
+      signature_data: Signature data.
+
+    Returns:
+      Tuple of min and max values.
+    """
+    # Go over all signature info and find the corresponding tensor names.
+    tensor_names = []
+    for signature_key, signature_names in signature_data.items():
+      for signature_name in signature_names:
+        tensor_name = self._search_tensor_by_signature_name(
+            signature_key, signature_name
+        )
+        tensor_names.extend(tensor_name)
+
+    # Find the min and max values across all tensors.
+    min_value, max_value = _find_overall_min_max(qsv, tensor_names)
+
+    # Overwrite the min and max values in the QSV.
+    for tensor_name in tensor_names:
+      qsv[tensor_name]["min"] = min_value
+      qsv[tensor_name]["max"] = max_value
+
+    return min_value, max_value
+
+  def update_quant_stats(
+      self,
+      qsv: dict[str, Any],
+      signature_data: _SignatureData,
+      min_value: np.ndarray,
+      max_value: np.ndarray,
+  ):
+    """Updates quantization statistics for a given signature data.
+
+    This function updates the quantization statistics with the provided min
+    and max values for the tensors specified in the signature data.
+
+    Args:
+      qsv: Quantization statistics.
+      signature_data: Signature data.
+      min_value: Minimum value to update.
+      max_value: Maximum value to update.
+
+    Returns:
+      Updated quantization statistics.
+    """
+    # Go over all signature info and find the corresponding tensor names.
+    tensor_names = []
+    for signature_key, signature_names in signature_data.items():
+      for signature_name in signature_names:
+        tensor_name = self._search_tensor_by_signature_name(
+            signature_key, signature_name
+        )
+        tensor_names.extend(tensor_name)
+
+    # Overwrite the min and max values in the QSV.
+    for tensor_name in tensor_names:
+      qsv[tensor_name]["min"] = min_value
+      qsv[tensor_name]["max"] = max_value
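Note: the class added above enables a calibrate, align, then quantize workflow for multi-signature models (for example, signatures sharing KV-cache tensors). A minimal usage sketch follows; the model path, recipe path, input name, and signature names are hypothetical, while the calls themselves mirror the tests in the next file:

  import numpy as np

  from ai_edge_quantizer import quantizer
  from ai_edge_quantizer.utils import calibration_utils

  # Hypothetical model with two signatures that produce a shared output.
  model_path = "multi_signature_model.tflite"
  calibration_data = {
      "signature_1": [{"x": np.zeros((1, 8), dtype=np.float32)}],
      "signature_2": [{"x": np.ones((1, 8), dtype=np.float32)}],
  }

  qt = quantizer.Quantizer(model_path, "a8w8_recipe.json")
  qsv = qt.calibrate(calibration_data)

  # Force the named outputs onto a shared min/max, hence a shared scale.
  calib_utils = calibration_utils.CalibrationQsvAlignmentUtils(model_path)
  signature_data = {"signature_1": ["output_0"], "signature_2": ["output_0"]}
  min_val, max_val = calib_utils.align_quant_stats(qsv, signature_data)

  # Or propagate the aligned stats onto additional signature tensors.
  calib_utils.update_quant_stats(
      qsv, {"signature_1": ["output_0"]}, min_val, max_val
  )
  quantized_model = qt.quantize(qsv).quantized_model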
ai_edge_quantizer/utils/calibration_utils_test.py

@@ -14,11 +14,68 @@
 # ==============================================================================
 
 from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf
+
 from tensorflow.python.platform import googletest
+from ai_edge_quantizer import quantizer
 from ai_edge_quantizer.utils import calibration_utils
+from ai_edge_quantizer.utils import test_utils
+from ai_edge_quantizer.utils import tfl_interpreter_utils
+
+_RNG = np.random.default_rng(66)
+
+_CALIBRATION_DATASET = {
+    "signature_1": [{
+        "cache_0": np.zeros(shape=(1, 100, 4, 4), dtype=np.float32),
+        "cache_1": np.zeros(shape=(1, 100, 4, 4), dtype=np.float32),
+        "positions": np.zeros(shape=(1, 100), dtype=np.int32),
+        "tokens": np.zeros(shape=(1, 100), dtype=np.int32),
+    }],
+    "signature_2": [{
+        "cache_0": _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
+        "cache_1": _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
+        "positions": (
+            _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
+        ),
+        "tokens": _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32),
+    }],
+}
+
 
+def _get_quant_parameters(
+    quantized_model: bytes, signature_data: dict[str, list[str]]
+) -> list[np.ndarray]:
+  """Returns the quantization parameters from the quantized model."""
+  quant_params = []
+  tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
+      quantized_model
+  )
+  for signature_key, signature_names in signature_data.items():
+    signature_runner = tfl_interpreter.get_signature_runner(signature_key)
+
+    for signature_name in signature_names:
+      input_details = signature_runner.get_input_details()
+      output_details = signature_runner.get_output_details()
+      if signature_name in input_details.keys():
+        quant_param = input_details[signature_name]["quantization_parameters"][
+            "scales"
+        ].squeeze()
+        quant_params.append(quant_param)
+      elif signature_name in output_details.keys():
+        output_details = signature_runner.get_output_details()
+        quant_param = output_details[signature_name]["quantization_parameters"][
+            "scales"
+        ].squeeze()
+        quant_params.append(quant_param)
+      else:
+        raise ValueError(
+            f"Signature name {signature_name} not found in the model."
+        )
+  return quant_params
 
-class CalibrationUtilsTest(parameterized.TestCase):
+
+class CalibrationQsvAlignmentUtilsTest(parameterized.TestCase):
 
   @parameterized.named_parameters(
       dict(
@@ -66,12 +123,126 @@ class CalibrationUtilsTest(parameterized.TestCase):
   def test_update_tensor_qsv_min_max(self, old_qsv, new_qsv, expected_qsv):
     updated_qsv = calibration_utils.min_max_update(old_qsv, new_qsv)
     if isinstance(expected_qsv["min"], list):
-      self.assertListEqual(list(updated_qsv["min"]), expected_qsv["min"])
-      self.assertListEqual(list(updated_qsv["max"]), expected_qsv["max"])
+      self.assertEqual(list(updated_qsv["min"]), expected_qsv["min"])
+      self.assertEqual(list(updated_qsv["max"]), expected_qsv["max"])
     else:
       self.assertEqual(updated_qsv["min"], expected_qsv["min"])
       self.assertEqual(updated_qsv["max"], expected_qsv["max"])
 
+  def test_calibration_utils_init_fails(self):
+    model_path = "non_existent_model.tflite"
+    with self.assertRaisesWithPredicateMatch(
+        tf.errors.NotFoundError, lambda err: f"{model_path}" in str(err)
+    ):
+      calibration_utils.CalibrationQsvAlignmentUtils(model_path)
+
+  def test_calibration_utils_init_succeeds(self):
+    model_path = test_utils.get_path_to_datafile(
+        "../tests/models/single_add.tflite"
+    )
+    calib_utils = calibration_utils.CalibrationQsvAlignmentUtils(model_path)
+    self.assertNotEmpty(calib_utils._signature_runners)
+    self.assertNotEmpty(calib_utils._same_as_input_scale_ops)
+
+  def test_search_tensor_by_signature_name_succeeds_on_unconstrained_op(self):
+    model_path = test_utils.get_path_to_datafile(
+        "../tests/models/single_add.tflite"
+    )
+    expected_tensor_name = "PartitionedCall:0"
+    calib_utils = calibration_utils.CalibrationQsvAlignmentUtils(model_path)
+    tensor_name = calib_utils._search_tensor_by_signature_name(
+        "serving_default", "add"
+    )
+    self.assertEqual(tensor_name, [expected_tensor_name])
+
+  def test_search_tensor_by_signature_name_succeeds_on_constrained_op(self):
+    model_path = test_utils.get_path_to_datafile(
+        "../tests/models/single_slice.tflite"
+    )
+    expected_tensor_name = "slice_input_tensor:0"
+    calib_utils = calibration_utils.CalibrationQsvAlignmentUtils(model_path)
+    tensor_name = calib_utils._search_tensor_by_signature_name(
+        "slice", "output_0"
+    )
+    self.assertEqual(tensor_name, [expected_tensor_name])
+
+  def test_align_quant_stats_succeeds(self):
+    model_path = test_utils.get_path_to_datafile(
+        "../tests/models/toy_model_with_kv_cache_multi_signature.tflite"
+    )
+    recipe_path = test_utils.get_path_to_datafile(
+        "../recipes/default_a8w8_recipe.json"
+    )
+    signature_data = {
+        "signature_1": ["output_1_1"],
+        "signature_2": ["output_1_1"],
+    }
+
+    # Obtain the calibration results.
+    qt = quantizer.Quantizer(model_path, recipe_path)
+    qsv = qt.calibrate(_CALIBRATION_DATASET)
+
+    # First quantize the model without aligning the quantization parameters.
+    quantized_model = qt.quantize(qsv).quantized_model
+    quant_params = _get_quant_parameters(quantized_model, signature_data)
+    self.assertFalse(
+        all(x == quant_params[0] for x in quant_params)
+    )  # Unequal quantization params.
+
+    # Align the quantization parameters and quantize again.
+    calib_utils = calibration_utils.CalibrationQsvAlignmentUtils(model_path)
+    calib_utils.align_quant_stats(qsv, signature_data)
+    quantized_model = qt.quantize(qsv).quantized_model
+    quant_params = _get_quant_parameters(quantized_model, signature_data)
+    self.assertTrue(
+        all(x == quant_params[0] for x in quant_params)
+    )  # Equal quantization params.
+
+  def test_update_quant_stats_succeeds(self):
+    model_path = test_utils.get_path_to_datafile(
+        "../tests/models/toy_model_with_kv_cache_multi_signature.tflite"
+    )
+    recipe_path = test_utils.get_path_to_datafile(
+        "../recipes/default_a8w8_recipe.json"
+    )
+    signature_data = {
+        "signature_1": ["output_1_1"],
+        "signature_2": ["output_1_1"],
+    }
+
+    # Obtain the calibration results.
+    qt = quantizer.Quantizer(model_path, recipe_path)
+    qsv = qt.calibrate(_CALIBRATION_DATASET)
+
+    # First quantize the model without updating `signature_1`.
+    quantized_model = qt.quantize(qsv).quantized_model
+    quant_params = _get_quant_parameters(quantized_model, signature_data)
+    self.assertFalse(
+        all(x == quant_params[0] for x in quant_params)
+    )  # Unequal quantization params.
+
+    # Update `signature_1` with stats from `signature_2`.
+    calib_utils = calibration_utils.CalibrationQsvAlignmentUtils(model_path)
+    min_val, max_val = calib_utils.align_quant_stats(  # For min and max only.
+        qsv,
+        {
+            "signature_2": ["output_1_1"],
+        },
+    )
+    calib_utils.update_quant_stats(
+        qsv,
+        {
+            "signature_1": ["output_1_1"],
+        },
+        min_val,
+        max_val,
+    )
+    quantized_model = qt.quantize(qsv).quantized_model
+    quant_params = _get_quant_parameters(quantized_model, signature_data)
+    self.assertTrue(
+        all(x == quant_params[0] for x in quant_params)
+    )  # Equal quantization params.
+
 
 if __name__ == "__main__":
   googletest.main()
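Note: the scale-equality assertions above read quantization parameters back through signature runners; the same pattern can be used to spot-check any quantized model. A short sketch, with a hypothetical model path, signature key, and output name:

  from ai_edge_quantizer.utils import tfl_interpreter_utils

  # Accepts a file path or, as in the tests above, quantized model bytes.
  interp = tfl_interpreter_utils.create_tfl_interpreter("model_quantized.tflite")
  runner = interp.get_signature_runner("signature_1")
  detail = runner.get_output_details()["output_0"]
  print(detail["quantization_parameters"]["scales"])
  print(detail["quantization_parameters"]["zero_points"])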
ai_edge_quantizer_nightly-0.3.0.dev20250616.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-quantizer-nightly
-Version: 0.3.0.dev20250614
+Version: 0.3.0.dev20250616
 Summary: A quantizer for advanced developers to quantize converted AI Edge models.
 Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
 Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI

ai_edge_quantizer_nightly-0.3.0.dev20250616.dist-info/RECORD

@@ -17,7 +17,7 @@ ai_edge_quantizer/quantizer.py,sha256=g3DMqFMrMpt9jQttCE0WcdNbMtk0JZnmN5MmCHrNdy
 ai_edge_quantizer/quantizer_test.py,sha256=K_HBA56JkFI3HL8VLWCqGEfC0ISh5ldMKoNyBdGRAJg,20368
 ai_edge_quantizer/recipe.py,sha256=FR0uJceumZrnle2VRSOQZ1uXup4S1cTYKRH-N53mWRo,2919
 ai_edge_quantizer/recipe_manager.py,sha256=qcGUD7e7BISKdsY9WH2rdaRR3acmzSA5qMezGNbzlpo,8931
-ai_edge_quantizer/recipe_manager_test.py,sha256=LulVxsYp6TBGFI2PLCUCd4VsFq8ELpC7kMNkUjsLgbo,32230
+ai_edge_quantizer/recipe_manager_test.py,sha256=GVOfGFZPRciUb4EF4GkSi6d96LdjS6PbUkAJ0ayy0k8,32243
 ai_edge_quantizer/recipe_test.py,sha256=Fg_sfxovI2fRjk5qdu18ghOvXdUvhDR1TxbE0GHDczc,3381
 ai_edge_quantizer/transformation_instruction_generator.py,sha256=B_TQQe9_Qs7UKXLjMMuz5lORUvXyZOxBS2SpntTnkI8,28077
 ai_edge_quantizer/transformation_instruction_generator_test.py,sha256=E0QSDCav6N6izlJ-a1ZJOsb2VEUxuxBmTbt0-EgDdxY,49890
@@ -61,8 +61,8 @@ ai_edge_quantizer/transformations/quantize_tensor_test.py,sha256=mHLO3_MRt36A8-Z
 ai_edge_quantizer/transformations/transformation_utils.py,sha256=GwIaKVsePZYgVG2lSanOswcaZYMjvgyqstDVwXl9DGY,6923
 ai_edge_quantizer/transformations/transformation_utils_test.py,sha256=MWgq29t7rvxRQIfi4ny9IoODFCTcbpjnIwoCL40zDKk,8698
 ai_edge_quantizer/utils/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
-ai_edge_quantizer/utils/calibration_utils.py,sha256=1Fj9MIO6aLZIRgyd4axvZN4S_O64nB_-Miu1WP664js,2536
-ai_edge_quantizer/utils/calibration_utils_test.py,sha256=Z-AcdTieesWFKyKBb08ZXm4Mgu6cvJ4bg2-MJ7hLD10,2856
+ai_edge_quantizer/utils/calibration_utils.py,sha256=e3dG7Nm94Ix0hkTWTWPUhEG6a8QR_cAM3PSwblfJV5g,15106
+ai_edge_quantizer/utils/calibration_utils_test.py,sha256=4BlksXl7b4yptL8xPR67hmJCnjhN9V10a2PunzfHrUE,9372
 ai_edge_quantizer/utils/test_utils.py,sha256=Y2pdMvn1k4gmqDo3noJfzx3fJcDHX_1hcsP6oiIz65Y,8240
 ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=pZv8FMWyjBSLN5MGJ2K_dZ6oqkJGbp9RI4CfnlPuPII,10830
 ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py,sha256=K1SbK8q92qYVtiVj0I0GtugsPTkpIpEKv9zakvFV_Sc,8555
@@ -70,8 +70,8 @@ ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=EtOv6cpKM_F0uv2bWuSXylYm
 ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=6fjkM-rycZ95L4yfvlr0TN6RlrhfPzxNUYrZaYO_F0A,12013
 ai_edge_quantizer/utils/validation_utils.py,sha256=oYw33Sg547AqtGw-choPUJmp9SAKkV46J_ddqSsum2Q,3950
 ai_edge_quantizer/utils/validation_utils_test.py,sha256=V_qNDikPD4OPB-siOLQCWNVWTAu87h2IgNYt7teFd-o,2934
-ai_edge_quantizer_nightly-0.3.0.dev20250614.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-ai_edge_quantizer_nightly-0.3.0.dev20250614.dist-info/METADATA,sha256=5ZPSscczc1tLmVN4sCf-xtX2qvmabAWOAkIjZVCb_7U,1528
-ai_edge_quantizer_nightly-0.3.0.dev20250614.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_quantizer_nightly-0.3.0.dev20250614.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
-ai_edge_quantizer_nightly-0.3.0.dev20250614.dist-info/RECORD,,
+ai_edge_quantizer_nightly-0.3.0.dev20250616.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ai_edge_quantizer_nightly-0.3.0.dev20250616.dist-info/METADATA,sha256=D9hP8s7AVoaH6P69HoNYu7IGINXb8uTfzcBnJamMmJw,1528
+ai_edge_quantizer_nightly-0.3.0.dev20250616.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_quantizer_nightly-0.3.0.dev20250616.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
+ai_edge_quantizer_nightly-0.3.0.dev20250616.dist-info/RECORD,,