ai_edge_quantizer_nightly-0.0.1.dev20250115-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
ai_edge_quantizer/params_generator.py
@@ -0,0 +1,361 @@
+ # Copyright 2024 The AI Edge Quantizer Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Generate model tensor level quantization config."""
+
+ import copy
+ from typing import Any, Optional, Union
+
+ from ai_edge_quantizer import algorithm_manager
+ from ai_edge_quantizer import qtyping
+ from ai_edge_quantizer import recipe_manager
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+
+ _QuantTrans = qtyping.QuantTransformation
+ _OpName = qtyping.TFLOperationName
+
+
+ class ParamsGenerator:
+   """Generate model tensor level quantization parameters."""
+
+   def __init__(self, float_tflite: Union[str, bytes]):
+     self.flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
+
+     if not tfl_flatbuffer_utils.is_float_model(self.flatbuffer_model):
+       raise ValueError(
+           'The input model for quantization parameters generation is not a'
+           ' float model. Please check the model (e.g., if it is already'
+           ' quantized).'
+       )
+     self._check_tensor_names_are_unique()
+     self.buffer_to_tensors: dict[int, list[Any]] = (
+         tfl_flatbuffer_utils.buffer_to_tensors(self.flatbuffer_model)
+     )
+     self.model_quant_results: dict[str, qtyping.TensorTransformationParams] = {}
+
+   def generate_quantization_parameters(
+       self,
+       model_recipe_manager: recipe_manager.RecipeManager,
+       model_qsvs: Optional[dict[str, qtyping.QSV]] = None,
+   ) -> dict[str, qtyping.TensorTransformationParams]:
+     """Generate the quantization parameters for the model.
+
+     Args:
+       model_recipe_manager: The recipe manager for the model.
+       model_qsvs: Quantization statistics values (QSVs) for the model. This is
+         obtained through calibration process.
+
+     Returns:
+       model_quant_results: The quantization parameters for tensors in the model.
+
+     Raises:
+       RuntimeError: If the calibration dataset is required but not provided.
+     """
+     if model_recipe_manager.need_calibration() and not model_qsvs:
+       raise RuntimeError(
+           'Model quantization statistics values (QSVs) are required for the'
+           ' input recipe. This can be obtained by running calibration on sample'
+           ' dataset.'
+       )
+
+     if model_qsvs is None:
+       model_qsvs = {}
+
+     op_codes = self.flatbuffer_model.operatorCodes
+     for subgraph in self.flatbuffer_model.subgraphs:
+       graph_info = qtyping.GraphInfo(
+           subgraph.tensors, self.flatbuffer_model.buffers
+       )
+       # Add input/output operators to the subgraph.
+       subgraph.operators += (
+           tfl_flatbuffer_utils.get_subgraph_input_output_operators(subgraph)
+       )
+       for subgraph_op_id, op in enumerate(subgraph.operators):
+         # Get the op key.
+         if isinstance(op, qtyping.IOOperator):
+           op_key = op.op_key
+           subgraph_op_id = -1  # Virtual op, no real id.
+         else:
+           op_code = op_codes[op.opcodeIndex].builtinCode
+           # Do not quantize unknown ops.
+           if op_code not in tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME:
+             op_quant_results = self._get_params_for_no_quant_op(
+                 subgraph_op_id, op, subgraph.tensors
+             )
+             self._update_model_quant_results(op_quant_results)
+             continue
+           op_key = tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME[op_code]
+
+         # Step1: query the quantization_recipe to get op config.
+         op_scope = self._get_op_scope(op, subgraph.tensors)
+         algorithm_name, op_quant_config = (
+             model_recipe_manager.get_quantization_configs(op_key, op_scope)
+         )
+         if algorithm_name == algorithm_manager.AlgorithmName.NO_QUANTIZE:
+           op_quant_results = self._get_params_for_no_quant_op(
+               subgraph_op_id, op, subgraph.tensors
+           )
+         else:
+           op_info = qtyping.OpInfo(op, op_key, subgraph_op_id, op_quant_config)
+           # Step2: query algorithm_manager to get/call the related function.
+           materialize_func = algorithm_manager.get_quantization_func(
+               algorithm_name,
+               op_key,
+               qtyping.QuantizeMode.MATERIALIZE,
+           )
+           op_quant_results = materialize_func(
+               op_info,
+               graph_info,
+               model_qsvs,
+           )
+         # Step3: update the results.
+         self._update_model_quant_results(op_quant_results)
+     self._post_process_results()
+     return self.model_quant_results
+
+   def _check_tensor_names_are_unique(self):
+     """Checks if the tensor names are unique in the model."""
+     global_tensor_names = set()
+     for subgraph in self.flatbuffer_model.subgraphs:
+       for tensor in subgraph.tensors:
+         tensor_name = tfl_flatbuffer_utils.get_tensor_name(tensor)
+         if tensor_name in global_tensor_names:
+           raise ValueError(
+               'Tensor name %s is not unique in the model. Please check your'
+               ' model and rename the tensor as ParamsGenerator assumes tensor'
+               ' names are unique.' % tensor_name
+           )
+         global_tensor_names.add(tensor_name)
+
+   def _post_process_results(self) -> None:
+     """Post process the quantization results.
+
+     Raises:
+       RuntimeError: If the tensors sharing the same buffer have different
+         quantization settings.
+     """
+     self._check_buffer_sharing()
+
+   def _update_model_quant_results(
+       self,
+       op_tensor_results: list[qtyping.TensorTransformationParams],
+   ) -> None:
+     """Update the op quantization results to the final output.
+
+     Args:
+       op_tensor_results: Tensor level quantization params for the op.
+
+     Raises:
+       RuntimeError: If the same tensor has multiple quantization configs.
+     """
+
+     for op_tensor_result in op_tensor_results:
+       tensor_name = op_tensor_result.tensor_name
+       if tensor_name not in self.model_quant_results:
+         self.model_quant_results[tensor_name] = copy.deepcopy(op_tensor_result)
+       else:
+         tensor_params = self.model_quant_results[tensor_name]
+         # Set source op.
+         if op_tensor_result.producer is not None:
+           # Src params must be unique (a tensor can only be produced by one op).
+           if tensor_params.producer is not None:
+             raise RuntimeError(
+                 'Tensor %s received multiple quantization parameters from the'
+                 ' source op, which should not happen as every tensor should'
+                 ' have only one source op.' % tensor_name
+             )
+           tensor_params.producer = copy.deepcopy(op_tensor_result.producer)
+         # Set target op, which can be multiple (a tensor can be consumed by
+         # multiple ops).
+         if op_tensor_result.consumers is not None:
+           if tensor_params.consumers is None:
+             tensor_params.consumers = copy.deepcopy(op_tensor_result.consumers)
+           else:
+             tensor_params.consumers += copy.deepcopy(op_tensor_result.consumers)
+         self.model_quant_results[tensor_name] = tensor_params
+
+   def _get_op_scope(self, op: Any, subgraph_tensors: list[Any]) -> str:
+     """Get the op scope.
+
+     Op scope is defined by the output tensor names (following the Model
+     Explorer).
+
+     Args:
+       op: The op that needs to be parsed.
+       subgraph_tensors: Tensors in the subgraph.
+
+     Returns:
+       Scope for the op.
+     """
+     scope = ''
+     # Op scope is determined by output tensors.
+     for output_tensor_idx in op.outputs:
+       if output_tensor_idx != -1:
+         scope += tfl_flatbuffer_utils.get_tensor_name(
+             subgraph_tensors[output_tensor_idx]
+         )
+         scope += ';'  # Split names.
+     return scope
+
+   def _get_params_for_no_quant_op(
+       self,
+       subgraph_op_id: int,
+       op: Any,
+       subgraph_tensors: list[Any],
+   ) -> list[qtyping.TensorTransformationParams]:
+     """Get the quantization parameters for ops require no quantization.
+
+     Args:
+       subgraph_op_id: The op id in the subgraph.
+       op: The op that needs to be parsed.
+       subgraph_tensors: Tensors in the subgraph.
+
+     Returns:
+       Tensor level quantization params for the op.
+     """
+
+     def no_quant_tensor_params():
+       return qtyping.OpToTensorParams(
+           subgraph_op_id=subgraph_op_id,
+           transformations=[_QuantTrans.NO_QUANTIZE],
+       )
+
+     tensor_params = []
+     for input_tensor_idx in op.inputs:
+       if input_tensor_idx != -1:
+         tensor = subgraph_tensors[input_tensor_idx]
+         input_tensor_params = qtyping.TensorTransformationParams(
+             tensor_name=tfl_flatbuffer_utils.get_tensor_name(tensor),
+             consumers=[no_quant_tensor_params()],
+         )
+         tensor_params.append(input_tensor_params)
+
+     for output_tensor_idx in op.outputs:
+       if output_tensor_idx != -1:
+         tensor = subgraph_tensors[output_tensor_idx]
+         output_tensor_params = qtyping.TensorTransformationParams(
+             tensor_name=tfl_flatbuffer_utils.get_tensor_name(tensor),
+             producer=no_quant_tensor_params(),
+         )
+         tensor_params.append(output_tensor_params)
+     return tensor_params
+
+   def _check_buffer_sharing(self) -> None:
+     """Check if tensors sharing the same buffer have the same quantization.
+
+     Raises:
+       RuntimeError: If the tensors sharing the same buffer have different
+         quantization settings.
+     """
+     for tensors in self.buffer_to_tensors.values():
+       if len(tensors) <= 1:
+         continue
+       first_tensor = tensors[0]
+       first_tensor_params = self.model_quant_results[
+           tfl_flatbuffer_utils.get_tensor_name(first_tensor)
+       ]
+       for tensor in tensors[1:]:
+         tensor_params = self.model_quant_results[
+             tfl_flatbuffer_utils.get_tensor_name(tensor)
+         ]
+         if not _compatible_tensor_transformation_params(
+             first_tensor_params, tensor_params
+         ):
+           error_msg = (
+               f'The tensors {first_tensor.name} and {tensor.name} do not have'
+               ' the same quantization parameters even though they share the'
+               ' same buffer. Please modify your quantization recipe to make'
+               ' sure the two tensors have the same quantization settings.'
+           )
+           raise RuntimeError(error_msg)
+
+
+ def _compatible_tensor_transformation_params(
+     params1: qtyping.TensorTransformationParams,
+     params2: qtyping.TensorTransformationParams,
+ ) -> bool:
+   """Check if two tensor transformation params are compatible."""
+   if params1.producer is None or params2.producer is None:
+     if params1.producer != params2.producer:
+       return False
+   elif not _compatible_tensor_params(params1.producer, params2.producer):
+     return False
+   if params1.consumers is None or params2.consumers is None:
+     if params1.consumers != params2.consumers:
+       return False
+   else:
+     # Check all consumers within each params are compatible.
+     for params1_consumer in params1.consumers:
+       if not _compatible_tensor_params(params1_consumer, params1.consumers[0]):
+         return False
+     for params2_consumer in params2.consumers:
+       if not _compatible_tensor_params(params2_consumer, params2.consumers[0]):
+         return False
+     if not _compatible_tensor_params(
+         params1.consumers[0], params2.consumers[0]
+     ):
+       return False
+   return True
+
+
+ def _same_tensor_params_except_id(
+     params1: qtyping.OpToTensorParams,
+     params2: qtyping.OpToTensorParams,
+ ) -> bool:
+   """Check if two op to tensor params are the same except for subgraph_op_id."""
+   return params1.transformations == params2.transformations and (
+       params1.parameters == params2.parameters
+       or params1.parameters is None
+       and params2.parameters is None
+   )
+
+
+ def _compatible_tensor_params(
+     params1: qtyping.OpToTensorParams,
+     params2: qtyping.OpToTensorParams,
+ ) -> bool:
+   """Check if two op to tensor params are compatible."""
+   float_source_transformations = [
+       _QuantTrans.ADD_QUANTIZE,
+       _QuantTrans.NO_QUANTIZE,
+   ]
+   quantized_source_transformations = [
+       _QuantTrans.QUANTIZE_TENSOR,
+       _QuantTrans.ADD_DEQUANTIZE,
+   ]
+   if _same_tensor_params_except_id(params1, params2):
+     return True
+   if (
+       params1.transformations[0] != _QuantTrans.NO_QUANTIZE
+       and params2.transformations[0] != _QuantTrans.NO_QUANTIZE
+   ):
+     # NO_QUANTIZE has no parameters. So only if both params aren't NO_QUANTIZE
+     # do we expect the parameters to be the same.
+     if params1.parameters != params2.parameters:
+       return False
+   # We only need to check the first transformation because transformations are
+   # applied in order, and as long as the one that's immediately after the tensor
+   # is the same, it's compatible.
+   if (
+       params1.transformations[0] in float_source_transformations
+       and params2.transformations[0] in float_source_transformations
+   ):
+     return True
+   if (
+       params1.transformations[0] in quantized_source_transformations
+       and params2.transformations[0] in quantized_source_transformations
+   ):
+     return True
+   return False
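
For orientation, the sketch below shows one way the ParamsGenerator added in this file could be driven end to end. It relies only on signatures visible in this diff (ParamsGenerator(float_tflite) accepting a path or bytes, and generate_quantization_parameters(model_recipe_manager, model_qsvs) returning a dict keyed by tensor name); the recipe JSON file and the RecipeManager.load_quantization_recipe() call are assumptions about the recipe_manager module that are not shown here.

import json

from ai_edge_quantizer import params_generator
from ai_edge_quantizer import recipe_manager

# Hypothetical: load a quantization recipe and hand it to a RecipeManager.
# The file name and load_quantization_recipe() are assumed, not shown above.
with open('recipe.json') as f:
  recipe = json.load(f)
manager = recipe_manager.RecipeManager()
manager.load_quantization_recipe(recipe)

# Visible in this diff: build the generator from a float TFLite model and
# produce tensor-level transformation parameters keyed by tensor name.
generator = params_generator.ParamsGenerator('float_model.tflite')
params = generator.generate_quantization_parameters(
    manager,
    model_qsvs=None,  # only required when the recipe needs calibration
)
for tensor_name, tensor_params in params.items():
  print(tensor_name, tensor_params.producer, tensor_params.consumers)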