ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,357 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Functions for validating output models."""
17
+
18
+ from collections.abc import Callable, Iterable
19
+ import dataclasses
20
+ import json
21
+ import math
22
+ import os
23
+ from typing import Any, Optional, Union
24
+
25
+ import numpy as np
26
+
27
+ from ai_edge_quantizer.utils import tfl_interpreter_utils as utils
28
+ from tensorflow.python.platform import gfile # pylint: disable=g-direct-tensorflow-import
29
+
30
+
31
+ _DEFAULT_SIGNATURE_KEY = utils.DEFAULT_SIGNATURE_KEY
32
+
33
+
34
@dataclasses.dataclass(frozen=True)
class SingleSignatureComparisonResult:
  """Immutable comparison record for one model signature.

  The per-tensor values are grouped by the role the tensor plays in the
  signature.

  Attributes:
    error_metric: Name of the error metric the comparison was made with.
    input_tensors: Maps each input tensor name to its value.
    output_tensors: Maps each output tensor name to its value.
    constant_tensors: Maps each constant tensor name to its value.
    intermediate_tensors: Maps each intermediate tensor name to its value.
  """

  error_metric: str
  input_tensors: dict[str, Any]
  output_tensors: dict[str, Any]
  constant_tensors: dict[str, Any]
  intermediate_tensors: dict[str, Any]
52
+
53
+
54
class ComparisonResult:
  """Collects per-signature comparison results for a reference/target pair.

  Results are stored privately, keyed by signature key, and are read back
  through `get_signature_comparison_result`, `available_signature_keys` and
  `get_all_tensor_results`.
  """

  def __init__(self, reference_model: bytes, target_model: bytes):
    """Creates an empty result container for the two models.

    Args:
      reference_model: Serialized model used as the reference.
      target_model: Serialized model compared against the reference. It is
        expected to have the same graph structure as reference_model.
    """
    self._comparison_results: dict[str, SingleSignatureComparisonResult] = {}
    self._reference_model = reference_model
    self._target_model = target_model

  def get_signature_comparison_result(
      self, signature_key: str = _DEFAULT_SIGNATURE_KEY
  ) -> SingleSignatureComparisonResult:
    """Returns the stored comparison result of one signature.

    Args:
      signature_key: Key of the signature whose result is requested.

    Returns:
      The SingleSignatureComparisonResult stored for signature_key.

    Raises:
      ValueError: If no result has been added for signature_key.
    """
    if signature_key in self._comparison_results:
      return self._comparison_results[signature_key]
    message = (
        f'{signature_key} is not in the comparison_results. Available'
        f' signature keys are: {self.available_signature_keys()}'
    )
    raise ValueError(message)

  def available_signature_keys(self) -> list[str]:
    """Returns the signature keys that currently have a stored result."""
    return list(self._comparison_results)

  def add_new_signature_results(
      self,
      error_metric: str,
      comparison_result: dict[str, float],
      signature_key: str = _DEFAULT_SIGNATURE_KEY,
  ):
    """Stores the result of a signature, split by tensor role.

    Every value is converted to a Python float. Entries named as inputs,
    outputs or constants of the reference model are moved into their own
    groups; whatever is left is recorded as intermediate tensors.

    Args:
      error_metric: The name of the error metric used for comparison.
      comparison_result: A dictionary of tensor name and its value.
      signature_key: The model signature that the comparison_result belongs to.

    Raises:
      ValueError: If the signature_key is already in the comparison_results.
      KeyError: If an input, output or constant tensor name of the reference
        model has no entry in comparison_result.
    """
    if signature_key in self._comparison_results:
      raise ValueError(f'{signature_key} is already in the comparison_results.')

    remaining = {
        name: float(value) for name, value in comparison_result.items()
    }

    def _claim(tensor_names):
      # Moves the named entries out of `remaining` so that each tensor ends up
      # in exactly one group.
      return {name: remaining.pop(name) for name in tensor_names}

    input_tensor_results = _claim(
        utils.get_input_tensor_names(self._reference_model, signature_key)
    )
    output_tensor_results = _claim(
        utils.get_output_tensor_names(self._reference_model, signature_key)
    )

    # Only get constant tensors from the main subgraph of the signature.
    reference_interpreter = utils.create_tfl_interpreter(self._reference_model)
    main_subgraph_index = utils.get_signature_main_subgraph_index(
        reference_interpreter, signature_key
    )
    constant_tensor_results = _claim(
        utils.get_constant_tensor_names(
            self._reference_model, main_subgraph_index
        )
    )

    self._comparison_results[signature_key] = SingleSignatureComparisonResult(
        error_metric=error_metric,
        input_tensors=input_tensor_results,
        output_tensors=output_tensor_results,
        constant_tensors=constant_tensor_results,
        intermediate_tensors=remaining,
    )

  def get_all_tensor_results(self) -> dict[str, Any]:
    """Merges the tensor results of every signature into one dictionary.

    When the same tensor name appears more than once, the entry merged last
    wins.

    Returns:
      A dictionary of tensor name and its value.
    """
    merged = {}
    for signature_result in self._comparison_results.values():
      for tensor_group in (
          signature_result.input_tensors,
          signature_result.output_tensors,
          signature_result.constant_tensors,
          signature_result.intermediate_tensors,
      ):
        merged.update(tensor_group)
    return merged

  def save(self, save_folder: str, model_name: str) -> None:
    """Writes the comparison result to two JSON files in save_folder.

    `<model_name>_comparison_result.json` holds the size reduction of the
    target model relative to the reference model plus the grouped results of
    every signature. `<model_name>_comparison_result_me_input.json` holds the
    output of `create_json_for_model_explorer`.

    Args:
      save_folder: Path to the folder to save the comparison result.
      model_name: Name of the model, used as the file name prefix.
    """
    reference_size = len(self._reference_model)
    size_delta = reference_size - len(self._target_model)
    report = {
        'reduced_size_bytes': size_delta,
        'reduced_size_percentage': size_delta / reference_size * 100,
    }
    section_names = (
        'error_metric',
        'input_tensors',
        'output_tensors',
        'constant_tensors',
        'intermediate_tensors',
    )
    for signature, signature_result in self._comparison_results.items():
      report[str(signature)] = {
          section: getattr(signature_result, section)
          for section in section_names
      }
    report_path = os.path.join(
        save_folder, model_name + '_comparison_result.json'
    )
    with gfile.GFile(report_path, 'w') as report_file:
      report_file.write(json.dumps(report))

    # TODO: b/365578554 - Remove after ME is updated to use the new json format.
    color_threshold = [0.05, 0.1, 0.2, 0.4, 1, 10, 100]
    model_explorer_json = create_json_for_model_explorer(
        self,
        threshold=color_threshold,
    )
    model_explorer_path = os.path.join(
        save_folder, model_name + '_comparison_result_me_input.json'
    )
    with gfile.GFile(model_explorer_path, 'w') as model_explorer_file:
      model_explorer_file.write(model_explorer_json)
204
+
205
+
206
def _setup_validation_interpreter(
    model: bytes,
    signature_input: dict[str, Any],
    signature_key: Optional[str],
    use_xnnpack: bool,
    num_threads: int,
) -> tuple[Any, int, dict[str, Any]]:
  """Builds an interpreter for the model and invokes one signature on it.

  Args:
    model: The model to be validated.
    signature_input: A dictionary of input tensor name and its value.
    signature_key: The signature key to be used for invoking the models. If the
      model only has one signature, this can be set to None.
    use_xnnpack: Whether to use xnnpack for the interpreter.
    num_threads: The number of threads to use for the interpreter.

  Returns:
    A tuple of the invoked interpreter, the index of the signature's main
    subgraph, and the tensor-name-to-details map of that subgraph.
  """
  validation_interpreter = utils.create_tfl_interpreter(
      tflite_model=model, use_xnnpack=use_xnnpack, num_threads=num_threads
  )
  utils.invoke_interpreter_signature(
      validation_interpreter, signature_input, signature_key
  )
  # Only validate tensors from the main subgraph of the signature.
  main_subgraph_index = utils.get_signature_main_subgraph_index(
      validation_interpreter, signature_key
  )
  return (
      validation_interpreter,
      main_subgraph_index,
      utils.get_tensor_name_to_details_map(
          validation_interpreter,
          main_subgraph_index,
      ),
  )
242
+
243
+
244
+ # TODO: b/330797129 - Enable multi-threaded evaluation.
245
def compare_model(
    reference_model: bytes,
    target_model: bytes,
    test_data: dict[str, Iterable[dict[str, Any]]],
    error_metric: str,
    compare_fn: Callable[[Any, Any], float],
    use_xnnpack: bool = True,
    num_threads: int = 16,
) -> ComparisonResult:
  """Compares the tensors of two models, signature by signature.

  For every signature in test_data, both models are invoked on each provided
  input and every tensor found under the same name in both models is compared
  with compare_fn (e.g., mean square error). The per-input values of a tensor
  are then averaged into a single number.

  Args:
    reference_model: Model which will be used as the reference.
    target_model: Target model which will be compared against the reference.
      reference_model and target_model are expected to have the same input and
      output signatures.
    test_data: A dictionary of signature key and its corresponding test input
      data that will be used for comparison.
    error_metric: The name of the error metric used for comparison.
    compare_fn: A comparison function used for calculating the statistics. It
      is called as compare_fn(target_data, reference_data) with two ArrayLike
      structures and must return a single float value.
    use_xnnpack: Whether to use xnnpack for the interpreter.
    num_threads: The number of threads to use for the interpreter.

  Returns:
    A ComparisonResult object.
  """
  model_comparison_result = ComparisonResult(reference_model, target_model)
  for signature_key, signature_inputs in test_data.items():
    errors_by_tensor = {}
    for signature_input in signature_inputs:
      # Invoke the signature on both interpreters.
      ref_interpreter, ref_subgraph_index, ref_details_by_name = (
          _setup_validation_interpreter(
              reference_model,
              signature_input,
              signature_key,
              use_xnnpack,
              num_threads,
          )
      )
      targ_interpreter, targ_subgraph_index, targ_details_by_name = (
          _setup_validation_interpreter(
              target_model,
              signature_input,
              signature_key,
              use_xnnpack,
              num_threads,
          )
      )
      # Compare the cached tensor values.
      for tensor_name, ref_detail in ref_details_by_name.items():
        # Tensors with object dtype are not compared.
        if ref_detail['dtype'] == np.object_:
          continue
        # Only tensors present under the same name in both models are compared.
        if tensor_name not in targ_details_by_name:
          continue

        reference_data = utils.get_tensor_data(
            ref_interpreter, ref_detail, ref_subgraph_index
        )
        target_data = utils.get_tensor_data(
            targ_interpreter,
            targ_details_by_name[tensor_name],
            targ_subgraph_index,
        )
        errors_by_tensor.setdefault(tensor_name, []).append(
            compare_fn(target_data, reference_data)
        )

    # Collapse the per-input values of each tensor into their mean.
    aggregated_results = {
        tensor_name: np.mean(errors)
        for tensor_name, errors in errors_by_tensor.items()
    }
    model_comparison_result.add_new_signature_results(
        error_metric,
        aggregated_results,
        signature_key,
    )
  return model_comparison_result
329
+
330
+
331
def create_json_for_model_explorer(
    data: ComparisonResult, threshold: list[Union[int, float]]
) -> str:
  """Serializes a comparison result into the json format of model_explorer.

  Each threshold is paired with a background color whose green channel starts
  at 255 and decreases by floor(255 / len(threshold)) per threshold.

  Args:
    data: Output from compare_model function.
    threshold: A list of numbers representing thresholds for model_explorer to
      display different colors. An empty list produces no color entries.

  Returns:
    A string in the json format accepted by model_explorer.
  """
  tensor_values = data.get_all_tensor_results()
  results = {
      name: {'value': float(value)} for name, value in tensor_values.items()
  }

  color_scheme = []
  if threshold:
    green_step = 255 // len(threshold)
    color_scheme = [
        {
            'value': bound,
            'bgColor': f'rgb(200, {max(0, 255 - position * green_step)}, 0)',
        }
        for position, bound in enumerate(threshold)
    ]

  return json.dumps({'results': results, 'thresholds': color_scheme})