ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/__init__.py +19 -0
- ai_edge_quantizer/algorithm_manager.py +167 -0
- ai_edge_quantizer/algorithm_manager_api.py +271 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
- ai_edge_quantizer/algorithms/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
- ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
- ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
- ai_edge_quantizer/calibrator.py +288 -0
- ai_edge_quantizer/calibrator_test.py +297 -0
- ai_edge_quantizer/conftest.py +22 -0
- ai_edge_quantizer/default_policy.py +310 -0
- ai_edge_quantizer/model_modifier.py +176 -0
- ai_edge_quantizer/model_modifier_test.py +130 -0
- ai_edge_quantizer/model_validator.py +357 -0
- ai_edge_quantizer/model_validator_test.py +354 -0
- ai_edge_quantizer/params_generator.py +361 -0
- ai_edge_quantizer/params_generator_test.py +1041 -0
- ai_edge_quantizer/qtyping.py +483 -0
- ai_edge_quantizer/quantizer.py +372 -0
- ai_edge_quantizer/quantizer_test.py +532 -0
- ai_edge_quantizer/recipe.py +67 -0
- ai_edge_quantizer/recipe_manager.py +245 -0
- ai_edge_quantizer/recipe_manager_test.py +815 -0
- ai_edge_quantizer/recipe_test.py +97 -0
- ai_edge_quantizer/transformation_instruction_generator.py +584 -0
- ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
- ai_edge_quantizer/transformation_performer.py +278 -0
- ai_edge_quantizer/transformation_performer_test.py +344 -0
- ai_edge_quantizer/transformations/__init__.py +15 -0
- ai_edge_quantizer/transformations/dequant_insert.py +87 -0
- ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
- ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
- ai_edge_quantizer/transformations/quant_insert.py +100 -0
- ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
- ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
- ai_edge_quantizer/transformations/transformation_utils.py +132 -0
- ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
- ai_edge_quantizer/utils/__init__.py +15 -0
- ai_edge_quantizer/utils/calibration_utils.py +86 -0
- ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
- ai_edge_quantizer/utils/test_utils.py +107 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
- ai_edge_quantizer/utils/validation_utils.py +125 -0
- ai_edge_quantizer/utils/validation_utils_test.py +87 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,357 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""function for validating output models."""
|
17
|
+
|
18
|
+
from collections.abc import Callable, Iterable
|
19
|
+
import dataclasses
|
20
|
+
import json
|
21
|
+
import math
|
22
|
+
import os
|
23
|
+
from typing import Any, Optional, Union
|
24
|
+
|
25
|
+
import numpy as np
|
26
|
+
|
27
|
+
from ai_edge_quantizer.utils import tfl_interpreter_utils as utils
|
28
|
+
from tensorflow.python.platform import gfile # pylint: disable=g-direct-tensorflow-import
|
29
|
+
|
30
|
+
|
31
|
+
# Default for the `signature_key` parameters in this module; taken from
# tfl_interpreter_utils so both modules use the same default signature.
_DEFAULT_SIGNATURE_KEY = utils.DEFAULT_SIGNATURE_KEY
|
32
|
+
|
33
|
+
|
34
|
+
@dataclasses.dataclass(frozen=True)
class SingleSignatureComparisonResult:
  """Comparison result for a single signature.

  Each tensor dictionary maps a tensor name to the error-metric value computed
  for that tensor (not the tensor's data). When populated through
  `ComparisonResult.add_new_signature_results`, the values are Python floats.

  Attributes:
    error_metric: The name of the error metric used for comparison.
    input_tensors: A dictionary of input tensor name and its error value.
    output_tensors: A dictionary of output tensor name and its error value.
    constant_tensors: A dictionary of constant tensor name and its error value.
    intermediate_tensors: A dictionary of intermediate tensor name and its
      error value, i.e. every compared tensor that is not an input, output or
      constant tensor.
  """

  error_metric: str
  input_tensors: dict[str, Any]
  output_tensors: dict[str, Any]
  constant_tensors: dict[str, Any]
  intermediate_tensors: dict[str, Any]
|
52
|
+
|
53
|
+
|
54
|
+
class ComparisonResult:
  """Comparison result for a model.

  Stores, per signature key, the error-metric results obtained by comparing a
  target model against a reference model.

  Attributes:
    comparison_results: A dictionary of signature key and its comparison result
      (stored privately as `_comparison_results`).
  """

  def __init__(self, reference_model: bytes, target_model: bytes):
    """Initialize the ComparisonResult object.

    Args:
      reference_model: Model which will be used as the reference.
      target_model: Target model which will be compared against the reference.
        We expect target_model and reference_model to have the same graph
        structure.
    """
    self._reference_model = reference_model
    self._target_model = target_model
    self._comparison_results: dict[str, SingleSignatureComparisonResult] = {}

  def get_signature_comparison_result(
      self, signature_key: str = _DEFAULT_SIGNATURE_KEY
  ) -> SingleSignatureComparisonResult:
    """Get the comparison result for a signature.

    Args:
      signature_key: The signature key to be used for invoking the models.

    Returns:
      A SingleSignatureComparisonResult object.

    Raises:
      ValueError: If no result has been added for `signature_key`.
    """
    if signature_key not in self._comparison_results:
      raise ValueError(
          f'{signature_key} is not in the comparison_results. Available'
          f' signature keys are: {self.available_signature_keys()}'
      )
    return self._comparison_results[signature_key]

  def available_signature_keys(self) -> list[str]:
    """Get the available signature keys in the comparison result."""
    return list(self._comparison_results.keys())

  def add_new_signature_results(
      self,
      error_metric: str,
      comparison_result: dict[str, float],
      signature_key: str = _DEFAULT_SIGNATURE_KEY,
  ):
    """Add a new signature result to the comparison result.

    The flat `comparison_result` mapping is split into input, output, constant
    and intermediate tensor buckets. It is split using the reference model's
    input/output tensor names for the signature and the constant tensor names
    of the signature's main subgraph. Whatever is left over is stored as
    intermediate tensors.

    Names that the reference model reports but that have no entry in
    `comparison_result` are omitted from their bucket instead of raising
    `KeyError`. `compare_model`, for example, never scores `np.object_`-dtype
    tensors or tensors missing from the target model.

    Args:
      error_metric: The name of the error metric used for comparison.
      comparison_result: A dictionary of tensor name and its error value.
      signature_key: The model signature that the comparison_result belongs to.

    Raises:
      ValueError: If the signature_key is already in the comparison_results.
    """
    if signature_key in self._comparison_results:
      raise ValueError(f'{signature_key} is already in the comparison_results.')

    # Copy (and normalize to float) so the caller's dict is never mutated.
    remaining = {key: float(value) for key, value in comparison_result.items()}

    def _extract(names: Iterable[str]) -> dict[str, float]:
      # Moves each listed name that has a result out of `remaining`.
      extracted = {}
      for name in names:
        if name in remaining:
          extracted[name] = remaining.pop(name)
      return extracted

    input_tensor_results = _extract(
        utils.get_input_tensor_names(self._reference_model, signature_key)
    )
    output_tensor_results = _extract(
        utils.get_output_tensor_names(self._reference_model, signature_key)
    )

    # Only get constant tensors from the main subgraph of the signature.
    subgraph_index = utils.get_signature_main_subgraph_index(
        utils.create_tfl_interpreter(self._reference_model),
        signature_key,
    )
    constant_tensor_results = _extract(
        utils.get_constant_tensor_names(
            self._reference_model,
            subgraph_index,
        )
    )

    self._comparison_results[signature_key] = SingleSignatureComparisonResult(
        error_metric=error_metric,
        input_tensors=input_tensor_results,
        output_tensors=output_tensor_results,
        constant_tensors=constant_tensor_results,
        intermediate_tensors=remaining,
    )

  def get_all_tensor_results(self) -> dict[str, Any]:
    """Get all the tensor results in a single dictionary.

    Results of every signature and every tensor bucket are merged into one
    dictionary. If a tensor name appears more than once, the entry merged last
    wins.

    Returns:
      A dictionary of tensor name and its value.
    """
    result = {}
    for signature_comparison_result in self._comparison_results.values():
      result.update(signature_comparison_result.input_tensors)
      result.update(signature_comparison_result.output_tensors)
      result.update(signature_comparison_result.constant_tensors)
      result.update(signature_comparison_result.intermediate_tensors)
    return result

  def save(self, save_folder: str, model_name: str) -> None:
    """Saves the model comparison result.

    Two files are written into `save_folder`:
      * `<model_name>_comparison_result.json`: size reduction plus the
        per-signature results.
      * `<model_name>_comparison_result_me_input.json`: the Model Explorer
        overlay produced by `create_json_for_model_explorer`.

    Args:
      save_folder: Path to the folder to save the comparison result.
      model_name: Name of the model.
    """
    reduced_model_size = len(self._reference_model) - len(self._target_model)
    reduction_ratio = reduced_model_size / len(self._reference_model) * 100
    result = {
        'reduced_size_bytes': reduced_model_size,
        'reduced_size_percentage': reduction_ratio,
    }
    for signature, comparison_result in self._comparison_results.items():
      result[str(signature)] = {
          'error_metric': comparison_result.error_metric,
          'input_tensors': comparison_result.input_tensors,
          'output_tensors': comparison_result.output_tensors,
          'constant_tensors': comparison_result.constant_tensors,
          'intermediate_tensors': comparison_result.intermediate_tensors,
      }
    result_save_path = os.path.join(
        save_folder, model_name + '_comparison_result.json'
    )
    with gfile.GFile(result_save_path, 'w') as output_file_handle:
      output_file_handle.write(json.dumps(result))

    # TODO: b/365578554 - Remove after ME is updated to use the new json format.
    color_threshold = [0.05, 0.1, 0.2, 0.4, 1, 10, 100]
    json_object = create_json_for_model_explorer(
        self,
        threshold=color_threshold,
    )
    json_save_path = os.path.join(
        save_folder, model_name + '_comparison_result_me_input.json'
    )
    with gfile.GFile(json_save_path, 'w') as output_file_handle:
      output_file_handle.write(json_object)
|
204
|
+
|
205
|
+
|
206
|
+
def _setup_validation_interpreter(
    model: bytes,
    signature_input: dict[str, Any],
    signature_key: Optional[str],
    use_xnnpack: bool,
    num_threads: int,
) -> tuple[Any, int, dict[str, Any]]:
  """Builds an interpreter for `model`, runs one signature and indexes tensors.

  Args:
    model: The model to be validated.
    signature_input: A dictionary of input tensor name and its value.
    signature_key: The signature key to be used for invoking the models. If the
      model only has one signature, this can be set to None.
    use_xnnpack: Whether to use xnnpack for the interpreter.
    num_threads: The number of threads to use for the interpreter.

  Returns:
    A tuple of (interpreter, subgraph_index, tensor_name_to_details). Here
    subgraph_index is the signature's main subgraph, and tensor_name_to_details
    maps tensor names in that subgraph to their details.
  """
  tfl_interpreter = utils.create_tfl_interpreter(
      tflite_model=model,
      use_xnnpack=use_xnnpack,
      num_threads=num_threads,
  )
  utils.invoke_interpreter_signature(
      tfl_interpreter, signature_input, signature_key
  )

  # Validation is restricted to tensors of the signature's main subgraph.
  main_subgraph = utils.get_signature_main_subgraph_index(
      tfl_interpreter, signature_key
  )
  details_by_name = utils.get_tensor_name_to_details_map(
      tfl_interpreter, main_subgraph
  )
  return tfl_interpreter, main_subgraph, details_by_name
|
242
|
+
|
243
|
+
|
244
|
+
# TODO: b/330797129 - Enable multi-threaded evaluation.
|
245
|
+
def compare_model(
    reference_model: bytes,
    target_model: bytes,
    test_data: dict[str, Iterable[dict[str, Any]]],
    error_metric: str,
    compare_fn: Callable[[Any, Any], float],
    use_xnnpack: bool = True,
    num_threads: int = 16,
) -> ComparisonResult:
  """Compares model tensors over model signatures using the compare_fn.

  For every signature key in `test_data`, each test input is run through both
  models. Every tensor of the signature's main subgraph that exists in both
  models (object-dtype tensors excluded) is compared with `compare_fn` (e.g.,
  mean squared error). The per-input scores are then averaged per tensor.

  Args:
    reference_model: Model which will be used as the reference.
    target_model: Target model which will be compared against the reference. We
      expect reference_model and target_model to have the same input and output
      signatures.
    test_data: A dictionary of signature key and its corresponding test input
      data that will be used for comparison.
    error_metric: The name of the error metric used for comparison.
    compare_fn: A comparison function used to calculate the statistics. It is
      called as `compare_fn(target_data, reference_data)` with two ArrayLike
      structures and must return a single float value.
    use_xnnpack: Whether to use xnnpack for the interpreter.
    num_threads: The number of threads to use for the interpreter.

  Returns:
    A ComparisonResult object.
  """
  model_comparison_result = ComparisonResult(reference_model, target_model)
  for signature_key, signature_inputs in test_data.items():
    # Tensor name -> list of per-input comparison scores.
    per_tensor_scores: dict[str, list[float]] = {}
    for signature_input in signature_inputs:
      # A fresh pair of interpreters is built and invoked for every input
      # (presumably so no interpreter state carries over between inputs).
      ref_interpreter, ref_subgraph_index, ref_tensor_name_to_details = (
          _setup_validation_interpreter(
              reference_model,
              signature_input,
              signature_key,
              use_xnnpack,
              num_threads,
          )
      )
      targ_interpreter, targ_subgraph_index, targ_tensor_name_to_details = (
          _setup_validation_interpreter(
              target_model,
              signature_input,
              signature_key,
              use_xnnpack,
              num_threads,
          )
      )
      # Compare the cached tensor values.
      for tensor_name, detail in ref_tensor_name_to_details.items():
        # Object-dtype tensors cannot be compared numerically.
        if detail['dtype'] == np.object_:
          continue
        # Only tensors present in both models can be compared.
        if tensor_name not in targ_tensor_name_to_details:
          continue
        reference_data = utils.get_tensor_data(
            ref_interpreter, detail, ref_subgraph_index
        )
        target_data = utils.get_tensor_data(
            targ_interpreter,
            targ_tensor_name_to_details[tensor_name],
            targ_subgraph_index,
        )
        per_tensor_scores.setdefault(tensor_name, []).append(
            compare_fn(target_data, reference_data)
        )

    aggregated_results = {
        tensor_name: np.mean(scores)
        for tensor_name, scores in per_tensor_scores.items()
    }
    model_comparison_result.add_new_signature_results(
        error_metric,
        aggregated_results,
        signature_key,
    )
  return model_comparison_result
|
329
|
+
|
330
|
+
|
331
|
+
def create_json_for_model_explorer(
    data: ComparisonResult, threshold: list[Union[int, float]]
) -> str:
  """Creates a JSON string that Model Explorer can use as an overlay.

  Args:
    data: Output from the compare_model function.
    threshold: A list of numbers representing thresholds for Model Explorer to
      display different colors. The thresholds are assigned background colors
      in order, with the green channel starting at 255 and decreasing in equal
      steps.

  Returns:
    A JSON-formatted string with two keys. `results` maps each tensor name to
    `{'value': float}`. `thresholds` is a list of `{'value', 'bgColor'}` entries,
    which is empty when `threshold` is empty.
  """
  # Keep `data` as the ComparisonResult; flatten into a separate local.
  tensor_results = data.get_all_tensor_results()
  results = {
      name: {'value': float(value)} for name, value in tensor_results.items()
  }

  color_scheme = []
  if threshold:
    step = math.floor(255 / len(threshold))
    green = 255
    for val in threshold:
      color_scheme.append({'value': val, 'bgColor': f'rgb(200, {green}, 0)'})
      green = max(0, green - step)

  return json.dumps({
      'results': results,
      'thresholds': color_scheme,
  })
|