ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. ai_edge_quantizer/algorithm_manager.py +224 -0
  2. ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
  3. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
  5. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  6. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
  7. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
  8. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
  13. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
  14. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
  15. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
  16. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
  17. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
  18. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
  19. ai_edge_quantizer/calibrator.py +58 -94
  20. ai_edge_quantizer/calibrator_test.py +5 -74
  21. ai_edge_quantizer/default_policy.py +108 -16
  22. ai_edge_quantizer/model_modifier.py +132 -8
  23. ai_edge_quantizer/model_modifier_test.py +81 -1
  24. ai_edge_quantizer/model_validator.py +38 -10
  25. ai_edge_quantizer/model_validator_test.py +2 -1
  26. ai_edge_quantizer/params_generator.py +230 -47
  27. ai_edge_quantizer/params_generator_test.py +366 -261
  28. ai_edge_quantizer/qtyping.py +92 -6
  29. ai_edge_quantizer/quantizer.py +167 -23
  30. ai_edge_quantizer/quantizer_test.py +288 -26
  31. ai_edge_quantizer/recipe.py +156 -21
  32. ai_edge_quantizer/recipe_manager.py +158 -1
  33. ai_edge_quantizer/recipe_manager_test.py +146 -32
  34. ai_edge_quantizer/recipe_test.py +93 -17
  35. ai_edge_quantizer/transformation_instruction_generator.py +313 -46
  36. ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
  37. ai_edge_quantizer/transformation_performer.py +112 -58
  38. ai_edge_quantizer/transformation_performer_test.py +176 -4
  39. ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
  40. ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
  41. ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
  42. ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
  43. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  44. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  45. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  46. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  47. ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
  48. ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
  49. ai_edge_quantizer/transformations/transformation_utils.py +157 -11
  50. ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
  51. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  52. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  53. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  54. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  55. ai_edge_quantizer/utils/test_utils.py +191 -58
  56. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
  57. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
  58. ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
  59. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  60. ai_edge_quantizer/utils/validation_utils.py +114 -4
  61. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  62. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
  63. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  64. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  65. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  66. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  67. ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
  68. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  69. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
@@ -15,9 +15,24 @@
15
15
 
16
16
  """Utilities for model calibration."""
17
17
 
18
- from typing import Union
18
+ import copy
19
+ from typing import Any, Union
20
+
19
21
  import numpy as np
22
+
23
+ from ai_edge_litert.tools import flatbuffer_utils
20
24
  from ai_edge_quantizer import qtyping
25
+ from ai_edge_quantizer.algorithms.utils import common_utils
26
+ from ai_edge_quantizer.utils import constrained_ops_utils
27
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
28
+ from ai_edge_quantizer.utils import tfl_interpreter_utils
29
+
30
+
31
# Mapping from string keys to arbitrary values.
# NOTE(review): presumably the inputs of one signature invocation keyed by
# input name - confirm against callers.
_SignatureInput = dict[str, Any]
# Local shorthand for `common_utils.OpQuantConstraint`.
_OpQuantConstraint = common_utils.OpQuantConstraint
_SignatureData = dict[
    str, list[str]
]  # signature_key -> list of signature_names.
21
36
 
22
37
 
23
38
  def _update_moving_average(
@@ -84,3 +99,250 @@ def min_max_update(qsv: qtyping.QSV, new_qsv: qtyping.QSV) -> qtyping.QSV:
84
99
  updated_qsv["min"] = np.minimum(qsv["min"], new_qsv["min"])
85
100
  updated_qsv["max"] = np.maximum(qsv["max"], new_qsv["max"])
86
101
  return updated_qsv
102
+
103
+
104
def _find_overall_min_max(
    qsv: dict[str, Any], tensor_names: list[str]
) -> tuple[np.ndarray, np.ndarray]:
  """Finds the overall minimum and maximum values for the given tensors.

  The statistics are combined element-wise with `np.minimum` / `np.maximum`,
  the same way `min_max_update` combines them. Unlike the builtin `min` /
  `max`, this also works when a statistic is an array with more than one
  element (the builtins fail there on an ambiguous truth value).

  Args:
    qsv: Mapping from tensor name to the quantization statistical values of
      that tensor; every looked-up entry must provide "min" and "max".
    tensor_names: The list of tensor names whose statistics are combined.

  Returns:
    A `(min, max)` tuple covering all the given tensors. When `tensor_names`
    is empty the neutral values `(inf, -inf)` are returned.

  Raises:
    KeyError: If a tensor name is not present in `qsv`.
  """
  min_value = np.inf
  max_value = -np.inf
  for tensor_name in tensor_names:
    tensor_stats = qsv[tensor_name]
    min_value = np.minimum(min_value, tensor_stats["min"])
    max_value = np.maximum(max_value, tensor_stats["max"])
  return min_value, max_value
123
+
124
+
125
class CalibrationQsvAlignmentUtils:
  """Calibration utils for alignment of QSVs.

  This class is used to align QSVs for a given model. It builds a list of ops
  that need to be constrained to the same scale as the input. Based on this
  list, it finds the corresponding tensor names for a given signature data.
  """

  def __init__(self, model_path: str):
    """Loads the model and creates one signature runner per signature.

    Args:
      model_path: Path of the model to load.
    """
    self._same_as_input_scale_ops = (
        constrained_ops_utils.get_constrained_op_list(
            _OpQuantConstraint.SAME_AS_INPUT_SCALE
        )
    )

    tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(model_path)
    self._flatbuffer_object = tfl_flatbuffer_utils.read_model(model_path)

    signature_keys = list(tfl_interpreter.get_signature_list().keys())

    # Build a dict of signature runners, keyed by signature key.
    self._signature_runners = {
        signature_key: tfl_interpreter.get_signature_runner(signature_key)
        for signature_key in signature_keys
    }

  def _search_tensor_by_signature_name(
      self,
      signature_key: str,
      signature_input_output_name: str,
      verbose: bool = False,
  ) -> list[str]:
    """Searches for a tensor name for a given signature by signature input or output name.

    Args:
      signature_key: Name of the signature.
      signature_input_output_name: Name of the input or output in the signature.
      verbose: Flag to enable verbose output.

    Returns:
      The list with one or two tensor names. If the name matches a signature
      input, the input tensor name comes first. If it matches a signature
      output, the tensor name found by the reverse graph search is appended.

    Raises:
      ValueError: If the name is neither an input nor an output of the
        signature.
    """

    if verbose:
      print("Searching tensor by signature name.")

    signature_runner = self._signature_runners[signature_key]
    tensor_names = []

    # Search among inputs.
    input_details = signature_runner.get_input_details()
    if signature_input_output_name in input_details:
      tensor_names.append(input_details[signature_input_output_name]["name"])

    # Search among outputs.
    output_details = signature_runner.get_output_details()
    if signature_input_output_name not in output_details:
      if not tensor_names:
        raise ValueError(
            f"Signature {signature_key} does not have input or output"
            f" `{signature_input_output_name}`"
        )
      return tensor_names

    output_tensor_name = output_details[signature_input_output_name]["name"]
    if verbose:
      print(
          ">> Starting recursive search for the output tensor name:"
          f" {output_tensor_name}"
      )

    idx = signature_runner._subgraph_index  # pylint: disable=protected-access
    subgraph = self._flatbuffer_object.subgraphs[idx]
    graph_info = qtyping.GraphInfo(
        subgraph.tensors, self._flatbuffer_object.buffers
    )

    # Recursively search the graph for the output tensor name since it may be
    # `SAME_AS_INPUT` constrained. The search consumes the operator list, so it
    # works on a copy.
    operators = copy.deepcopy(subgraph.operators)
    tensor_name = self._search_reverse_order_recursively(
        graph_info, operators, output_tensor_name, indent=" ", verbose=verbose
    )
    tensor_names.append(tensor_name)

    if verbose:
      print(f"\n\nFound tensor name: {tensor_name}")

    return tensor_names

  def _search_reverse_order_recursively(
      self,
      graph_info: qtyping.GraphInfo,
      operators: list[Any],
      output_tensor_name: str,
      indent: str,
      verbose: bool = False,
  ) -> str:
    """Searches for a tensor name in reverse order recursively.

    Operators are popped from the end of `operators` (the list is consumed).
    Stop criteria is when the tensor belongs to an operator that is not
    `SAME_AS_INPUT` constrained.

    Args:
      graph_info: Graph information.
      operators: List of operators; mutated by this search.
      output_tensor_name: Name of the output tensor to search for.
      indent: Indentation string for debug output.
      verbose: Flag to enable verbose output.

    Returns:
      The name of the tensor found. If no remaining operator produces
      `output_tensor_name`, that name itself is returned.
    """
    op_codes = self._flatbuffer_object.operatorCodes

    while operators:
      op = operators.pop()
      op_code = op_codes[op.opcodeIndex].builtinCode
      op_name = flatbuffer_utils.opcode_to_name(
          self._flatbuffer_object, op.opcodeIndex
      )
      if op_code not in tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME:
        continue
      for output_idx in op.outputs:
        if output_tensor_name == tfl_flatbuffer_utils.get_tensor_name(
            graph_info.subgraph_tensors[output_idx]
        ):
          dbg_str = (
              f"{indent}>> Found `{op_name}`, output tensor"
              f" '{output_tensor_name}'"
          )

          if op_name not in self._same_as_input_scale_ops:
            if verbose:
              print(f"{dbg_str}, returning...")
            return output_tensor_name

          if verbose:
            print(f"{dbg_str}, with SAME_AS_INPUT, search recursively among:")
          for input_idx in op.inputs:
            input_tensor_name = graph_info.subgraph_tensors[
                input_idx
            ].name.decode("utf-8")

            if verbose:
              print(f"{indent} Input: {input_tensor_name}")

            # The search returns on the first iteration, i.e. only the first
            # input of the constrained op is followed.
            return self._search_reverse_order_recursively(
                graph_info,
                operators,
                input_tensor_name,
                indent=f"{indent} ",
                verbose=verbose,
            )
    return output_tensor_name

  def _collect_tensor_names(self, signature_data: _SignatureData) -> list[str]:
    """Returns the tensor names for every signature name in `signature_data`."""
    tensor_names = []
    for signature_key, signature_names in signature_data.items():
      for signature_name in signature_names:
        tensor_names.extend(
            self._search_tensor_by_signature_name(signature_key, signature_name)
        )
    return tensor_names

  def align_quant_stats(
      self, qsv: dict[str, Any], signature_data: _SignatureData
  ) -> tuple[np.ndarray, np.ndarray]:
    """Aligns quantization statistics for a given signature data.

    This function takes quantization statistics and signature data as input,
    identifies the tensors associated with the signature data, and aligns
    the quantization statistics of these tensors by setting their minimum
    and maximum values to the same value. This ensures that the tensors
    have the same quantization parameters.

    Args:
      qsv: Quantization statistics; updated in place.
      signature_data: Signature data.

    Returns:
      Tuple of min and max values.
    """
    # Go over all signature info and find the corresponding tensor names.
    tensor_names = self._collect_tensor_names(signature_data)

    # Find min and max values across all tensors.
    min_value, max_value = _find_overall_min_max(qsv, tensor_names)

    # Overwrite the min and max values in the QSV.
    for tensor_name in tensor_names:
      qsv[tensor_name]["min"] = min_value
      qsv[tensor_name]["max"] = max_value

    return min_value, max_value

  def update_quant_stats(
      self,
      qsv: dict[str, Any],
      signature_data: _SignatureData,
      min_value: np.ndarray,
      max_value: np.ndarray,
  ) -> dict[str, Any]:
    """Updates quantization statistics for a given signature data.

    This function updates the quantization statistics with the provided min, max
    values for the tensors specified in the signature data.

    Args:
      qsv: Quantization statistics; updated in place.
      signature_data: Signature data.
      min_value: Minimum value to update.
      max_value: Maximum value to update.

    Returns:
      Updated quantization statistics (the same `qsv` object that was passed
      in).
    """
    # Go over all signature info and find the corresponding tensor names.
    tensor_names = self._collect_tensor_names(signature_data)

    # Overwrite the min and max values in the QSV.
    for tensor_name in tensor_names:
      qsv[tensor_name]["min"] = min_value
      qsv[tensor_name]["max"] = max_value

    return qsv
@@ -14,11 +14,67 @@
14
14
  # ==============================================================================
15
15
 
16
16
  from absl.testing import parameterized
17
+ import numpy as np
18
+
17
19
  from tensorflow.python.platform import googletest
20
+ from ai_edge_quantizer import quantizer
18
21
  from ai_edge_quantizer.utils import calibration_utils
22
+ from ai_edge_quantizer.utils import test_utils
23
+ from ai_edge_quantizer.utils import tfl_interpreter_utils
24
+
25
# Generator with a fixed seed, used for the random calibration inputs below.
_RNG = np.random.default_rng(66)

# Calibration data keyed by signature key. Each signature has a single sample:
# "signature_1" is fed all-zero inputs, "signature_2" is fed random inputs.
_CALIBRATION_DATASET = {
    "signature_1": [{
        "cache_0": np.zeros(shape=(1, 100, 4, 4), dtype=np.float32),
        "cache_1": np.zeros(shape=(1, 100, 4, 4), dtype=np.float32),
        "positions": np.zeros(shape=(1, 100), dtype=np.int32),
        "tokens": np.zeros(shape=(1, 100), dtype=np.int32),
    }],
    "signature_2": [{
        "cache_0": _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
        "cache_1": _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
        "positions": (
            _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
        ),
        "tokens": _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32),
    }],
}
43
+
19
44
 
45
def _get_quant_parameters(
    quantized_model: bytes, signature_data: dict[str, list[str]]
) -> list[np.ndarray]:
  """Returns the quantization parameters from the quantized model.

  Args:
    quantized_model: The quantized model to inspect.
    signature_data: Mapping from signature key to the signature input/output
      names whose quantization scales are collected.

  Returns:
    The squeezed "scales" of every requested name, in the iteration order of
    `signature_data`. A name is looked up among the signature inputs first and
    among the outputs second.

  Raises:
    ValueError: If a name is neither an input nor an output of its signature.
  """
  quant_params = []
  tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
      quantized_model
  )
  for signature_key, signature_names in signature_data.items():
    signature_runner = tfl_interpreter.get_signature_runner(signature_key)
    # The details only depend on the signature, so fetch them once per runner
    # instead of once per name.
    input_details = signature_runner.get_input_details()
    output_details = signature_runner.get_output_details()

    for signature_name in signature_names:
      if signature_name in input_details:
        tensor_details = input_details[signature_name]
      elif signature_name in output_details:
        tensor_details = output_details[signature_name]
      else:
        raise ValueError(
            f"Signature name {signature_name} not found in the model."
        )
      quant_params.append(
          tensor_details["quantization_parameters"]["scales"].squeeze()
      )
  return quant_params
20
75
 
21
- class CalibrationUtilsTest(parameterized.TestCase):
76
+
77
+ class CalibrationQsvAlignmentUtilsTest(parameterized.TestCase):
22
78
 
23
79
  @parameterized.named_parameters(
24
80
  dict(
@@ -66,12 +122,126 @@ class CalibrationUtilsTest(parameterized.TestCase):
66
122
  def test_update_tensor_qsv_min_max(self, old_qsv, new_qsv, expected_qsv):
67
123
  updated_qsv = calibration_utils.min_max_update(old_qsv, new_qsv)
68
124
  if isinstance(expected_qsv["min"], list):
69
- self.assertListEqual(list(updated_qsv["min"]), expected_qsv["min"])
70
- self.assertListEqual(list(updated_qsv["max"]), expected_qsv["max"])
125
+ self.assertEqual(list(updated_qsv["min"]), expected_qsv["min"])
126
+ self.assertEqual(list(updated_qsv["max"]), expected_qsv["max"])
71
127
  else:
72
128
  self.assertEqual(updated_qsv["min"], expected_qsv["min"])
73
129
  self.assertEqual(updated_qsv["max"], expected_qsv["max"])
74
130
 
131
+ def test_calibration_utils_init_fails(self):
132
+ model_path = "non_existent_model.tflite"
133
+ with self.assertRaisesWithPredicateMatch(
134
+ Exception, lambda err: f"{model_path}" in str(err)
135
+ ):
136
+ calibration_utils.CalibrationQsvAlignmentUtils(model_path)
137
+
138
+ def test_calibration_utils_init_succeeds(self):
139
+ model_path = test_utils.get_path_to_datafile(
140
+ "../tests/models/single_add.tflite"
141
+ )
142
+ calib_utils = calibration_utils.CalibrationQsvAlignmentUtils(model_path)
143
+ self.assertNotEmpty(calib_utils._signature_runners)
144
+ self.assertNotEmpty(calib_utils._same_as_input_scale_ops)
145
+
146
+ def test_search_tensor_by_signature_name_succeeds_on_unconstrained_op(self):
147
+ model_path = test_utils.get_path_to_datafile(
148
+ "../tests/models/single_add.tflite"
149
+ )
150
+ expected_tensor_name = "PartitionedCall:0"
151
+ calib_utils = calibration_utils.CalibrationQsvAlignmentUtils(model_path)
152
+ tensor_name = calib_utils._search_tensor_by_signature_name(
153
+ "serving_default", "add"
154
+ )
155
+ self.assertEqual(tensor_name, [expected_tensor_name])
156
+
157
+ def test_search_tensor_by_signature_name_succeeds_on_constrained_op(self):
158
+ model_path = test_utils.get_path_to_datafile(
159
+ "../tests/models/single_slice.tflite"
160
+ )
161
+ expected_tensor_name = "slice_input_tensor:0"
162
+ calib_utils = calibration_utils.CalibrationQsvAlignmentUtils(model_path)
163
+ tensor_name = calib_utils._search_tensor_by_signature_name(
164
+ "slice", "output_0"
165
+ )
166
+ self.assertEqual(tensor_name, [expected_tensor_name])
167
+
168
+ def test_align_quant_stats_succeeds(self):
169
+ model_path = test_utils.get_path_to_datafile(
170
+ "../tests/models/toy_model_with_kv_cache_multi_signature.tflite"
171
+ )
172
+ recipe_path = test_utils.get_path_to_datafile(
173
+ "../recipes/default_a8w8_recipe.json"
174
+ )
175
+ signature_data = {
176
+ "signature_1": ["output_1_1"],
177
+ "signature_2": ["output_1_1"],
178
+ }
179
+
180
+ # Obtain the calibration results.
181
+ qt = quantizer.Quantizer(model_path, recipe_path)
182
+ qsv = qt.calibrate(_CALIBRATION_DATASET)
183
+
184
+ # First quantize the model without aligning the quantization parameters.
185
+ quantized_model = qt.quantize(qsv).quantized_model
186
+ quant_params = _get_quant_parameters(quantized_model, signature_data)
187
+ self.assertFalse(
188
+ all(x == quant_params[0] for x in quant_params)
189
+ ) # not equal quantization params.
190
+
191
+ # Align the quantization parameters and quantize again.
192
+ calib_utils = calibration_utils.CalibrationQsvAlignmentUtils(model_path)
193
+ calib_utils.align_quant_stats(qsv, signature_data)
194
+ quantized_model = qt.quantize(qsv).quantized_model
195
+ quant_params = _get_quant_parameters(quantized_model, signature_data)
196
+ self.assertTrue(
197
+ all(x == quant_params[0] for x in quant_params)
198
+ ) # equal quantization params.
199
+
200
+ def test_update_quant_stats_succeeds(self):
201
+ model_path = test_utils.get_path_to_datafile(
202
+ "../tests/models/toy_model_with_kv_cache_multi_signature.tflite"
203
+ )
204
+ recipe_path = test_utils.get_path_to_datafile(
205
+ "../recipes/default_a8w8_recipe.json"
206
+ )
207
+ signature_data = {
208
+ "signature_1": ["output_1_1"],
209
+ "signature_2": ["output_1_1"],
210
+ }
211
+
212
+ # Obtain the calibration results.
213
+ qt = quantizer.Quantizer(model_path, recipe_path)
214
+ qsv = qt.calibrate(_CALIBRATION_DATASET)
215
+
216
+ # First quantize the model without updating the `signature_1`.
217
+ quantized_model = qt.quantize(qsv).quantized_model
218
+ quant_params = _get_quant_parameters(quantized_model, signature_data)
219
+ self.assertFalse(
220
+ all(x == quant_params[0] for x in quant_params)
221
+ ) # not equal quantization params.
222
+
223
+ # Update the `signature_1` with stats from `signature_2`.
224
+ calib_utils = calibration_utils.CalibrationQsvAlignmentUtils(model_path)
225
+ min_val, max_val = calib_utils.align_quant_stats( # for min and max only.
226
+ qsv,
227
+ {
228
+ "signature_2": ["output_1_1"],
229
+ },
230
+ )
231
+ calib_utils.update_quant_stats(
232
+ qsv,
233
+ {
234
+ "signature_1": ["output_1_1"],
235
+ },
236
+ min_val,
237
+ max_val,
238
+ )
239
+ quantized_model = qt.quantize(qsv).quantized_model
240
+ quant_params = _get_quant_parameters(quantized_model, signature_data)
241
+ self.assertTrue(
242
+ all(x == quant_params[0] for x in quant_params)
243
+ ) # equal quantization params.
244
+
75
245
 
76
246
  if __name__ == "__main__":
77
247
  googletest.main()
@@ -0,0 +1,111 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Utils for handling operators with quantization constraints."""
17
+
18
+ from ai_edge_quantizer import algorithm_manager
19
+ from ai_edge_quantizer import qtyping
20
+ from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
21
+ from ai_edge_quantizer.algorithms.utils import common_utils
22
+ from ai_edge_litert import schema_py_generated as schema_fb # pylint: disable=g-direct-tensorflow-import
23
+
24
+
25
+ _OpQuantConstraint = common_utils.OpQuantConstraint
26
+
27
+
28
def get_constrained_op_list(
    quant_constraint: _OpQuantConstraint, verbose: bool = False
) -> list[str]:
  """Constructs and returns a list of constrained operators.

  This is achieved by invoking all materialization functions and extracting
  the constraint argument, using monkey patching to redirect logic to wrapper
  functions. The patched functions are always restored before returning, even
  when a materialization function raises.

  Args:
    quant_constraint: The quantization constraint to filter operators by.
    verbose: Flag to enable verbose output.

  Returns:
    A list containing operators with the specified constraint.
  """
  constrained_ops = []

  def materialize_standard_op_wrapper(
      op_info: qtyping.OpInfo,
      *_args,
      constraint: _OpQuantConstraint = _OpQuantConstraint.NO_CONSTRAIN,
      **_kwargs,
  ) -> list[qtyping.TensorTransformationParams]:
    if constraint == quant_constraint:
      constrained_ops.append(op_info.op_name)
    # Return dummy values to avoid exceptions.
    dummy_value = [qtyping.TensorTransformationParams("")] * 2
    return dummy_value

  # Dummy implementation of the `_are_weights_too_small` function to support
  # `materialize_standard_op_wrapper` above.
  def are_weights_too_small_wrapper(*_args, **_kwargs) -> bool:
    return False

  # Dummy implementation of the `_materialize_bias_for_fc_conv_ops` function to
  # support `materialize_standard_op_wrapper` above.
  def materialize_bias_for_fc_conv_ops_wrapper(*_args, **_kwargs):
    return

  # Do monkey patch to intercept the `materialize_standard_op` function to
  # support `materialize_standard_op_wrapper` above.
  original_materialize_standard_op = common_utils.materialize_standard_op
  original_are_weights_too_small = common_quantize._are_weights_too_small  # pylint: disable=protected-access
  original_materialize_bias_for_fc_conv_ops = (
      common_quantize._materialize_bias_for_fc_conv_ops  # pylint: disable=protected-access
  )
  common_utils.materialize_standard_op = materialize_standard_op_wrapper
  common_quantize._are_weights_too_small = are_weights_too_small_wrapper  # pylint: disable=protected-access
  common_quantize._materialize_bias_for_fc_conv_ops = (  # pylint: disable=protected-access
      materialize_bias_for_fc_conv_ops_wrapper
  )
  try:
    minmax_func_dict = algorithm_manager.MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT

    # Loop over all available materialization functions to build up a list of
    # ops with the given constraint.
    for op, materialize_fn in minmax_func_dict.items():
      # Create a dummy op info to trigger the materialization.
      mock_op = schema_fb.OperatorT()
      mock_op.inputs = [0]
      mock_op.outputs = [0]
      op_info = qtyping.OpInfo(
          op=mock_op,
          op_name=op,
          subgraph_op_index=0,
          op_quant_config=qtyping.OpQuantizationConfig(),
      )
      materialize_fn(
          get_tensor_quant_params_fn=None,
          op_info=op_info,
          graph_info=None,
          tensor_name_to_qsv=None,
      )
  finally:
    # Restore the original functions, so a failure above cannot leave the
    # modules patched for the rest of the process.
    common_utils.materialize_standard_op = original_materialize_standard_op
    common_quantize._are_weights_too_small = original_are_weights_too_small  # pylint: disable=protected-access
    common_quantize._materialize_bias_for_fc_conv_ops = (  # pylint: disable=protected-access
        original_materialize_bias_for_fc_conv_ops
    )

  if verbose:
    print(f" {quant_constraint} op list: {constrained_ops}")

  return constrained_ops
@@ -0,0 +1,50 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from tensorflow.python.platform import googletest
17
+ from absl.testing import parameterized
18
+ from ai_edge_quantizer.algorithms.utils import common_utils
19
+ from ai_edge_quantizer.utils import constrained_ops_utils
20
+
21
+
22
+ _OpQuantConstraint = common_utils.OpQuantConstraint
23
+
24
+
25
class ConstrainedOpsUtilsTest(parameterized.TestCase):
  """Tests for `constrained_ops_utils.get_constrained_op_list`."""

  # Each case is (testcase_name, constraint, expected_num_ops).
  @parameterized.named_parameters(
      ("same_as_input_scale", _OpQuantConstraint.SAME_AS_INPUT_SCALE, 18),
      ("same_as_output_scale", _OpQuantConstraint.SAME_AS_OUTPUT_SCALE, 7),
      ("no_constrain", _OpQuantConstraint.NO_CONSTRAIN, 25),
  )
  def test_get_constrained_op_list(self, constraint, expected_num_ops):
    """Checks how many ops are reported for the given constraint."""
    ops_with_constraint = constrained_ops_utils.get_constrained_op_list(
        constraint
    )
    self.assertLen(ops_with_constraint, expected_num_ops)
47
+
48
+
49
+ if __name__ == "__main__":
50
+ googletest.main()