ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/__init__.py +19 -0
- ai_edge_quantizer/algorithm_manager.py +167 -0
- ai_edge_quantizer/algorithm_manager_api.py +271 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
- ai_edge_quantizer/algorithms/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
- ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
- ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
- ai_edge_quantizer/calibrator.py +288 -0
- ai_edge_quantizer/calibrator_test.py +297 -0
- ai_edge_quantizer/conftest.py +22 -0
- ai_edge_quantizer/default_policy.py +310 -0
- ai_edge_quantizer/model_modifier.py +176 -0
- ai_edge_quantizer/model_modifier_test.py +130 -0
- ai_edge_quantizer/model_validator.py +357 -0
- ai_edge_quantizer/model_validator_test.py +354 -0
- ai_edge_quantizer/params_generator.py +361 -0
- ai_edge_quantizer/params_generator_test.py +1041 -0
- ai_edge_quantizer/qtyping.py +483 -0
- ai_edge_quantizer/quantizer.py +372 -0
- ai_edge_quantizer/quantizer_test.py +532 -0
- ai_edge_quantizer/recipe.py +67 -0
- ai_edge_quantizer/recipe_manager.py +245 -0
- ai_edge_quantizer/recipe_manager_test.py +815 -0
- ai_edge_quantizer/recipe_test.py +97 -0
- ai_edge_quantizer/transformation_instruction_generator.py +584 -0
- ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
- ai_edge_quantizer/transformation_performer.py +278 -0
- ai_edge_quantizer/transformation_performer_test.py +344 -0
- ai_edge_quantizer/transformations/__init__.py +15 -0
- ai_edge_quantizer/transformations/dequant_insert.py +87 -0
- ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
- ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
- ai_edge_quantizer/transformations/quant_insert.py +100 -0
- ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
- ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
- ai_edge_quantizer/transformations/transformation_utils.py +132 -0
- ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
- ai_edge_quantizer/utils/__init__.py +15 -0
- ai_edge_quantizer/utils/calibration_utils.py +86 -0
- ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
- ai_edge_quantizer/utils/test_utils.py +107 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
- ai_edge_quantizer/utils/validation_utils.py +125 -0
- ai_edge_quantizer/utils/validation_utils_test.py +87 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
ai_edge_quantizer/params_generator.py
@@ -0,0 +1,361 @@
+# Copyright 2024 The AI Edge Quantizer Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Generate model tensor level quantization config."""
+
+import copy
+from typing import Any, Optional, Union
+
+from ai_edge_quantizer import algorithm_manager
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer import recipe_manager
+from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+
+_QuantTrans = qtyping.QuantTransformation
+_OpName = qtyping.TFLOperationName
+
+
+class ParamsGenerator:
+  """Generate model tensor level quantization parameters."""
+
+  def __init__(self, float_tflite: Union[str, bytes]):
+    self.flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
+
+    if not tfl_flatbuffer_utils.is_float_model(self.flatbuffer_model):
+      raise ValueError(
+          'The input model for quantization parameters generation is not a'
+          ' float model. Please check the model (e.g., if it is already'
+          ' quantized).'
+      )
+    self._check_tensor_names_are_unique()
+    self.buffer_to_tensors: dict[int, list[Any]] = (
+        tfl_flatbuffer_utils.buffer_to_tensors(self.flatbuffer_model)
+    )
+    self.model_quant_results: dict[str, qtyping.TensorTransformationParams] = {}
+
+  def generate_quantization_parameters(
+      self,
+      model_recipe_manager: recipe_manager.RecipeManager,
+      model_qsvs: Optional[dict[str, qtyping.QSV]] = None,
+  ) -> dict[str, qtyping.TensorTransformationParams]:
+    """Generate the quantization parameters for the model.
+
+    Args:
+      model_recipe_manager: The recipe manager for the model.
+      model_qsvs: Quantization statistics values (QSVs) for the model. This is
+        obtained through calibration process.
+
+    Returns:
+      model_quant_results: The quantization parameters for tensors in the model.
+
+    Raises:
+      RuntimeError: If the calibration dataset is required but not provided.
+    """
+    if model_recipe_manager.need_calibration() and not model_qsvs:
+      raise RuntimeError(
+          'Model quantization statistics values (QSVs) are required for the'
+          ' input recipe. This can be obtained by running calibration on sample'
+          ' dataset.'
+      )
+
+    if model_qsvs is None:
+      model_qsvs = {}
+
+    op_codes = self.flatbuffer_model.operatorCodes
+    for subgraph in self.flatbuffer_model.subgraphs:
+      graph_info = qtyping.GraphInfo(
+          subgraph.tensors, self.flatbuffer_model.buffers
+      )
+      # Add input/output operators to the subgraph.
+      subgraph.operators += (
+          tfl_flatbuffer_utils.get_subgraph_input_output_operators(subgraph)
+      )
+      for subgraph_op_id, op in enumerate(subgraph.operators):
+        # Get the op key.
+        if isinstance(op, qtyping.IOOperator):
+          op_key = op.op_key
+          subgraph_op_id = -1  # Virtual op, no real id.
+        else:
+          op_code = op_codes[op.opcodeIndex].builtinCode
+          # Do not quantize unknown ops.
+          if op_code not in tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME:
+            op_quant_results = self._get_params_for_no_quant_op(
+                subgraph_op_id, op, subgraph.tensors
+            )
+            self._update_model_quant_results(op_quant_results)
+            continue
+          op_key = tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME[op_code]
+
+        # Step1: query the quantization_recipe to get op config.
+        op_scope = self._get_op_scope(op, subgraph.tensors)
+        algorithm_name, op_quant_config = (
+            model_recipe_manager.get_quantization_configs(op_key, op_scope)
+        )
+        if algorithm_name == algorithm_manager.AlgorithmName.NO_QUANTIZE:
+          op_quant_results = self._get_params_for_no_quant_op(
+              subgraph_op_id, op, subgraph.tensors
+          )
+        else:
+          op_info = qtyping.OpInfo(op, op_key, subgraph_op_id, op_quant_config)
+          # Step2: query algorithm_manager to get/call the related function.
+          materialize_func = algorithm_manager.get_quantization_func(
+              algorithm_name,
+              op_key,
+              qtyping.QuantizeMode.MATERIALIZE,
+          )
+          op_quant_results = materialize_func(
+              op_info,
+              graph_info,
+              model_qsvs,
+          )
+        # Step3: update the results.
+        self._update_model_quant_results(op_quant_results)
+    self._post_process_results()
+    return self.model_quant_results
+
+  def _check_tensor_names_are_unique(self):
+    """Checks if the tensor names are unique in the model."""
+    global_tensor_names = set()
+    for subgraph in self.flatbuffer_model.subgraphs:
+      for tensor in subgraph.tensors:
+        tensor_name = tfl_flatbuffer_utils.get_tensor_name(tensor)
+        if tensor_name in global_tensor_names:
+          raise ValueError(
+              'Tensor name %s is not unique in the model. Please check your'
+              ' model and rename the tensor as ParamsGenerator assumes tensor'
+              ' names are unique.' % tensor_name
+          )
+        global_tensor_names.add(tensor_name)
+
+  def _post_process_results(self) -> None:
+    """Post process the quantization results.
+
+    Raises:
+      RuntimeError: If the tensors sharing the same buffer have different
+        quantization settings.
+    """
+    self._check_buffer_sharing()
+
+  def _update_model_quant_results(
+      self,
+      op_tensor_results: list[qtyping.TensorTransformationParams],
+  ) -> None:
+    """Update the op quantization results to the final output.
+
+    Args:
+      op_tensor_results: Tensor level quantization params for the op.
+
+    Raises:
+      RuntimeError: If the same tensor has multiple quantization configs.
+    """
+
+    for op_tensor_result in op_tensor_results:
+      tensor_name = op_tensor_result.tensor_name
+      if tensor_name not in self.model_quant_results:
+        self.model_quant_results[tensor_name] = copy.deepcopy(op_tensor_result)
+      else:
+        tensor_params = self.model_quant_results[tensor_name]
+        # Set source op.
+        if op_tensor_result.producer is not None:
+          # Src params must be unique (a tensor can only be produced by one op).
+          if tensor_params.producer is not None:
+            raise RuntimeError(
+                'Tensor %s received multiple quantization parameters from the'
+                ' source op, which should not happen as every tensor should'
+                ' have only one source op.' % tensor_name
+            )
+          tensor_params.producer = copy.deepcopy(op_tensor_result.producer)
+        # Set target op, which can be multiple (a tensor can be consumed by
+        # multiple ops).
+        if op_tensor_result.consumers is not None:
+          if tensor_params.consumers is None:
+            tensor_params.consumers = copy.deepcopy(op_tensor_result.consumers)
+          else:
+            tensor_params.consumers += copy.deepcopy(op_tensor_result.consumers)
+        self.model_quant_results[tensor_name] = tensor_params
+
+  def _get_op_scope(self, op: Any, subgraph_tensors: list[Any]) -> str:
+    """Get the op scope.
+
+    Op scope is defined by the output tensor names (following the Model
+    Explorer).
+
+    Args:
+      op: The op that needs to be parsed.
+      subgraph_tensors: Tensors in the subgraph.
+
+    Returns:
+      Scope for the op.
+    """
+    scope = ''
+    # Op scope is determined by output tensors.
+    for output_tensor_idx in op.outputs:
+      if output_tensor_idx != -1:
+        scope += tfl_flatbuffer_utils.get_tensor_name(
+            subgraph_tensors[output_tensor_idx]
+        )
+        scope += ';'  # Split names.
+    return scope
+
+  def _get_params_for_no_quant_op(
+      self,
+      subgraph_op_id: int,
+      op: Any,
+      subgraph_tensors: list[Any],
+  ) -> list[qtyping.TensorTransformationParams]:
+    """Get the quantization parameters for ops require no quantization.
+
+    Args:
+      subgraph_op_id: The op id in the subgraph.
+      op: The op that needs to be parsed.
+      subgraph_tensors: Tensors in the subgraph.
+
+    Returns:
+      Tensor level quantization params for the op.
+    """
+
+    def no_quant_tensor_params():
+      return qtyping.OpToTensorParams(
+          subgraph_op_id=subgraph_op_id,
+          transformations=[_QuantTrans.NO_QUANTIZE],
+      )
+
+    tensor_params = []
+    for input_tensor_idx in op.inputs:
+      if input_tensor_idx != -1:
+        tensor = subgraph_tensors[input_tensor_idx]
+        input_tensor_params = qtyping.TensorTransformationParams(
+            tensor_name=tfl_flatbuffer_utils.get_tensor_name(tensor),
+            consumers=[no_quant_tensor_params()],
+        )
+        tensor_params.append(input_tensor_params)
+
+    for output_tensor_idx in op.outputs:
+      if output_tensor_idx != -1:
+        tensor = subgraph_tensors[output_tensor_idx]
+        output_tensor_params = qtyping.TensorTransformationParams(
+            tensor_name=tfl_flatbuffer_utils.get_tensor_name(tensor),
+            producer=no_quant_tensor_params(),
+        )
+        tensor_params.append(output_tensor_params)
+    return tensor_params
+
+  def _check_buffer_sharing(self) -> None:
+    """Check if tensors sharing the same buffer have the same quantization.
+
+    Raises:
+      RuntimeError: If the tensors sharing the same buffer have different
+        quantization settings.
+    """
+    for tensors in self.buffer_to_tensors.values():
+      if len(tensors) <= 1:
+        continue
+      first_tensor = tensors[0]
+      first_tensor_params = self.model_quant_results[
+          tfl_flatbuffer_utils.get_tensor_name(first_tensor)
+      ]
+      for tensor in tensors[1:]:
+        tensor_params = self.model_quant_results[
+            tfl_flatbuffer_utils.get_tensor_name(tensor)
+        ]
+        if not _compatible_tensor_transformation_params(
+            first_tensor_params, tensor_params
+        ):
+          error_msg = (
+              f'The tensors {first_tensor.name} and {tensor.name} do not have'
+              ' the same quantization parameters even though they share the'
+              ' same buffer. Please modify your quantization recipe to make'
+              ' sure the two tensors have the same quantization settings.'
+          )
+          raise RuntimeError(error_msg)
+
+
+def _compatible_tensor_transformation_params(
+    params1: qtyping.TensorTransformationParams,
+    params2: qtyping.TensorTransformationParams,
+) -> bool:
+  """Check if two tensor transformation params are compatible."""
+  if params1.producer is None or params2.producer is None:
+    if params1.producer != params2.producer:
+      return False
+  elif not _compatible_tensor_params(params1.producer, params2.producer):
+    return False
+  if params1.consumers is None or params2.consumers is None:
+    if params1.consumers != params2.consumers:
+      return False
+  else:
+    # Check all consumers within each params are compatible.
+    for params1_consumer in params1.consumers:
+      if not _compatible_tensor_params(params1_consumer, params1.consumers[0]):
+        return False
+    for params2_consumer in params2.consumers:
+      if not _compatible_tensor_params(params2_consumer, params2.consumers[0]):
+        return False
+    if not _compatible_tensor_params(
+        params1.consumers[0], params2.consumers[0]
+    ):
+      return False
+  return True
+
+
+def _same_tensor_params_except_id(
+    params1: qtyping.OpToTensorParams,
+    params2: qtyping.OpToTensorParams,
+) -> bool:
+  """Check if two op to tensor params are the same except for subgraph_op_id."""
+  return params1.transformations == params2.transformations and (
+      params1.parameters == params2.parameters
+      or params1.parameters is None
+      and params2.parameters is None
+  )
+
+
+def _compatible_tensor_params(
+    params1: qtyping.OpToTensorParams,
+    params2: qtyping.OpToTensorParams,
+) -> bool:
+  """Check if two op to tensor params are compatible."""
+  float_source_transformations = [
+      _QuantTrans.ADD_QUANTIZE,
+      _QuantTrans.NO_QUANTIZE,
+  ]
+  quantized_source_transformations = [
+      _QuantTrans.QUANTIZE_TENSOR,
+      _QuantTrans.ADD_DEQUANTIZE,
+  ]
+  if _same_tensor_params_except_id(params1, params2):
+    return True
+  if (
+      params1.transformations[0] != _QuantTrans.NO_QUANTIZE
+      and params2.transformations[0] != _QuantTrans.NO_QUANTIZE
+  ):
+    # NO_QUANTIZE has no parameters. So only if both params aren't NO_QUANTIZE
+    # do we expect the parameters to be the same.
+    if params1.parameters != params2.parameters:
+      return False
+  # We only need to check the first transformation because transformations are
+  # applied in order, and as long as the one that's immediately after the tensor
+  # is the same, it's compatible.
+  if (
+      params1.transformations[0] in float_source_transformations
+      and params2.transformations[0] in float_source_transformations
+  ):
+    return True
+  if (
+      params1.transformations[0] in quantized_source_transformations
+      and params2.transformations[0] in quantized_source_transformations
+  ):
+    return True
+  return False
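For orientation, a minimal usage sketch of the ParamsGenerator API shown in the file above (not part of the package diff): the module and method names (params_generator.ParamsGenerator, recipe_manager.RecipeManager, need_calibration, generate_quantization_parameters) come from this source file, while the recipe setup, calibration step, and model path are placeholders, since recipe_manager.py and calibrator.py are listed in the wheel but not shown in this hunk.

from ai_edge_quantizer import params_generator
from ai_edge_quantizer import recipe_manager

recipes = recipe_manager.RecipeManager()
# Populate the recipe manager here; the loading API lives in recipe_manager.py
# and recipe.py (listed above) and is not shown in this hunk.

generator = params_generator.ParamsGenerator('/path/to/float_model.tflite')

model_qsvs = None
if recipes.need_calibration():
  # The recipe requires calibration statistics (QSVs); obtain them with the
  # calibrator (calibrator.py), otherwise generation raises RuntimeError.
  model_qsvs = ...

# Maps each tensor name to its qtyping.TensorTransformationParams.
quant_params = generator.generate_quantization_parameters(
    recipes, model_qsvs=model_qsvs
)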