ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in those public registries.
- ai_edge_quantizer/__init__.py +19 -0
- ai_edge_quantizer/algorithm_manager.py +167 -0
- ai_edge_quantizer/algorithm_manager_api.py +271 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
- ai_edge_quantizer/algorithms/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
- ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
- ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
- ai_edge_quantizer/calibrator.py +288 -0
- ai_edge_quantizer/calibrator_test.py +297 -0
- ai_edge_quantizer/conftest.py +22 -0
- ai_edge_quantizer/default_policy.py +310 -0
- ai_edge_quantizer/model_modifier.py +176 -0
- ai_edge_quantizer/model_modifier_test.py +130 -0
- ai_edge_quantizer/model_validator.py +357 -0
- ai_edge_quantizer/model_validator_test.py +354 -0
- ai_edge_quantizer/params_generator.py +361 -0
- ai_edge_quantizer/params_generator_test.py +1041 -0
- ai_edge_quantizer/qtyping.py +483 -0
- ai_edge_quantizer/quantizer.py +372 -0
- ai_edge_quantizer/quantizer_test.py +532 -0
- ai_edge_quantizer/recipe.py +67 -0
- ai_edge_quantizer/recipe_manager.py +245 -0
- ai_edge_quantizer/recipe_manager_test.py +815 -0
- ai_edge_quantizer/recipe_test.py +97 -0
- ai_edge_quantizer/transformation_instruction_generator.py +584 -0
- ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
- ai_edge_quantizer/transformation_performer.py +278 -0
- ai_edge_quantizer/transformation_performer_test.py +344 -0
- ai_edge_quantizer/transformations/__init__.py +15 -0
- ai_edge_quantizer/transformations/dequant_insert.py +87 -0
- ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
- ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
- ai_edge_quantizer/transformations/quant_insert.py +100 -0
- ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
- ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
- ai_edge_quantizer/transformations/transformation_utils.py +132 -0
- ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
- ai_edge_quantizer/utils/__init__.py +15 -0
- ai_edge_quantizer/utils/calibration_utils.py +86 -0
- ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
- ai_edge_quantizer/utils/test_utils.py +107 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
- ai_edge_quantizer/utils/validation_utils.py +125 -0
- ai_edge_quantizer/utils/validation_utils_test.py +87 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
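The hunks reproduced below cover three of these files. Matching the added-line counts, they are ai_edge_quantizer/calibrator.py (+288), ai_edge_quantizer/calibrator_test.py (+297), and ai_edge_quantizer/conftest.py (+22); the file header above each hunk is inferred on that basis.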
ai_edge_quantizer/calibrator.py
@@ -0,0 +1,288 @@
# Copyright 2024 The AI Edge Quantizer Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Quantization Calibration."""

from collections.abc import Callable, Iterable
import copy
from typing import Any, Union

from absl import logging
import numpy as np

from ai_edge_quantizer import algorithm_manager
from ai_edge_quantizer import qtyping
from ai_edge_quantizer import recipe_manager
from ai_edge_quantizer.utils import calibration_utils
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
from ai_edge_quantizer.utils import tfl_interpreter_utils

_SignatureInput = dict[str, Any]  # input_argument_name -> tensor_value.
_SignatureOutput = dict[
    str, np.ndarray
]  # output_argument_name -> tensor_value.


class Calibrator:
  """Calibrator for TFLite model."""

  def __init__(
      self,
      float_tflite: Union[str, bytes],
      num_threads: int = 16,
  ):
    self._flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)

    if not tfl_flatbuffer_utils.is_float_model(self._flatbuffer_model):
      raise ValueError(
          "The input model for calibration is not a float model. Please check"
          " the model (e.g., if it is already quantized)."
      )
    self._tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
        float_tflite, use_xnnpack=True, num_threads=num_threads
    )
    # Tensor name to tensor content.
    self._tensor_content_map: dict[str, Any] = {}
    # QSV of all the tensors in the model.
    self._model_qsvs: dict[str, qtyping.QSV] = {}
    # Cached output of the model.
    self._cached_output: list[_SignatureOutput] = []

  # TODO(b/330740605)- Collect multiple QSVs in one run to save compute.
  def calibrate(
      self,
      calibration_dataset: dict[str, Iterable[_SignatureInput]],
      model_recipe_manager: recipe_manager.RecipeManager,
      cache_output: bool = False,
      qsv_update_func: Callable[
          [qtyping.QSV, qtyping.QSV],
          qtyping.QSV,
      ] = calibration_utils.moving_average_update,
  ) -> None:
    """Calibrates the model using the given dataset for a model signature.

    The process is
    0. Initialize quantization statistics values (QSVs) using the initialization
      function (from AlgorithmManager) for the op if needed.
    1. Invoke TFL interpreter on the calibration data.
    2. Go through each op, ask RecipeManager for the quantization setting
      of the op.
    3. Ask AlgorithmManager for the calibration function of the op given the
      quantization setting.
    4. Apply the function to the op to obtain quantization statistics (qsvs) for
      the tensors associated with the op.
    5. Update the global qsv dictionary
    6. Start another round of calibration.

    Args:
      calibration_dataset: A dictionary of input data for calibration for the
        given model signature.
      model_recipe_manager: A RecipeManager object that contains the
        quantization recipe.
      cache_output: Whether to cache the output of the model during the
        calibration process. This is useful if there are dependencies between
        signatures/models (e.g., decode requires encode output).
      qsv_update_func: The function to update the QSVs.
    """
    op_codes = self._flatbuffer_model.operatorCodes
    if not self._model_qsvs:
      self._initialize_model_qsvs(model_recipe_manager)
    else:
      logging.warning(
          "Calibrator contains non-empty model qsvs, and the current"
          " calibration process will start on top of this state (i.e., update"
          " the existing qsvs). If this is an unintended behavior please call"
          " reset_model_qsvs to reset model qsvs."
      )

    # TODO: b/329322226 - Enable parallel calibration.
    for signature_key, dataset in calibration_dataset.items():
      # Step0: get subgraph index.
      subgraph_idx = tfl_interpreter_utils.get_signature_main_subgraph_index(
          self._tfl_interpreter, signature_key
      )

      for data in dataset:
        # Initialize tensor names that are updated in this round of calibration.
        updated_tensor_names = set()

        # Step1: run tfl interpreter on subgraph to get tensor content.
        signature_output = tfl_interpreter_utils.invoke_interpreter_signature(
            self._tfl_interpreter, data, signature_key
        )
        if cache_output:
          self._cached_output.append(signature_output)
        self._tensor_content_map.update(
            tfl_interpreter_utils.get_tensor_name_to_content_map(
                self._tfl_interpreter, subgraph_idx
            )
        )

        # Step2: go through each op in subgraph to update quantization
        # statistic values.
        subgraph = self._flatbuffer_model.subgraphs[subgraph_idx]
        graph_info = qtyping.GraphInfo(
            subgraph.tensors, self._flatbuffer_model.buffers
        )
        # Add input/output operators to the subgraph.
        subgraph.operators += (
            tfl_flatbuffer_utils.get_subgraph_input_output_operators(subgraph)
        )
        for op in subgraph.operators:
          if isinstance(op, qtyping.IOOperator):
            op_key = op.op_key
          else:
            op_code = op_codes[op.opcodeIndex].builtinCode
            if op_code not in tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME:
              continue
            op_key = tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME[op_code]
          # Step2.1: query the quantization_recipe to get op quantization
          # settings.
          op_scope = self._get_op_scope(op, subgraph.tensors)
          algorithm_name, _ = model_recipe_manager.get_quantization_configs(
              op_key, op_scope
          )
          if algorithm_name == algorithm_manager.AlgorithmName.NO_QUANTIZE:
            continue
          # Step2.2: query algorithm_manager to get/call the related calibration
          # function.
          calibrate_func = algorithm_manager.get_quantization_func(
              algorithm_name, op_key, qtyping.QuantizeMode.CALIBRATE
          )
          op_qsvs = calibrate_func(op, graph_info, self._tensor_content_map)
          # Step3: Update tensor qsvs with the new values. Ignore the tensor
          # names that are already updated in this round of calibration.
          op_updated_tensor_name = self._update_qsvs(
              op_qsvs, updated_tensor_names, qsv_update_func
          )
          updated_tensor_names.update(op_updated_tensor_name)
      # Reset interpreter after one round of calibration.
      self._tfl_interpreter.reset_all_variables()

  def get_model_qsvs(self) -> dict[str, qtyping.QSV]:
    """Get the model qsvs.

    Returns:
      A dictionary of tensor name to QSV.
    """
    return self._model_qsvs

  def get_cached_output(self) -> list[_SignatureOutput]:
    """Get the cached output of the model."""
    return self._cached_output

  def clear_cached_output(self) -> None:
    """Clear the cached output of the model."""
    self._cached_output = []

  def reset_model_qsvs(self) -> None:
    """Reset the model qsvs."""
    self._model_qsvs = {}

  def load_model_qsvs(self, model_qsvs: dict[str, qtyping.QSV]) -> None:
    """Load the model qsvs.

    Args:
      model_qsvs: A dictionary of tensor name to QSV.
    """
    self._model_qsvs = copy.deepcopy(model_qsvs)

  def _update_qsvs(
      self,
      op_qsvs: dict[str, qtyping.QSV],
      ignore_tensor_names: set[str],
      qsv_update_func: Callable[[qtyping.QSV, qtyping.QSV], qtyping.QSV],
  ) -> set[str]:
    """Update the model qsvs with the new values.

    Args:
      op_qsvs: A dictionary of tensor name to QSV.
      ignore_tensor_names: A set of tensor names to ignore.
      qsv_update_func: The function to update the QSVs.

    Returns:
      A set of tensor names that are updated.
    """
    updated_tensor_names = set()
    for tensor_name, qsv in op_qsvs.items():
      if tensor_name in ignore_tensor_names:
        continue
      if tensor_name not in self._model_qsvs:
        self._model_qsvs[tensor_name] = qsv
      else:
        updated_qsv = qsv_update_func(self._model_qsvs[tensor_name], qsv)
        self._model_qsvs[tensor_name] = updated_qsv
      updated_tensor_names.add(tensor_name)
    return updated_tensor_names

  def _get_op_scope(self, op, subgraph_tensors) -> str:
    """Get the scope of the op.

    The scope is the name of the output tensor of the op.

    Args:
      op: The op to get the scope.
      subgraph_tensors: The tensors in the subgraph.

    Returns:
      The scope of the op.
    """
    scope = ""
    for output_tensor_idx in op.outputs:
      if output_tensor_idx != -1:
        output_tensor = subgraph_tensors[output_tensor_idx]
        scope += tfl_flatbuffer_utils.get_tensor_name(output_tensor)
    return scope

  # TODO: b/354224138 - Remove code duplication between calibrate and
  # _initialize_model_qsvs.
  def _initialize_model_qsvs(
      self, model_recipe_manager: recipe_manager.RecipeManager
  ) -> None:
    """Initialize the model qsvs.

    Args:
      model_recipe_manager: A RecipeManager object that contains the
        quantization recipe.
    """
    op_codes = self._flatbuffer_model.operatorCodes
    for subgraph in self._flatbuffer_model.subgraphs:
      graph_info = qtyping.GraphInfo(
          subgraph.tensors, self._flatbuffer_model.buffers
      )
      for subgraph_op_id, op in enumerate(subgraph.operators):
        op_code = op_codes[op.opcodeIndex].builtinCode
        if op_code not in tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME:
          continue
        op_key = tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME[op_code]
        # Step1: query the quantization_recipe to get op quantization
        # settings.
        op_scope = self._get_op_scope(op, subgraph.tensors)
        algorithm_name, op_quant_config = (
            model_recipe_manager.get_quantization_configs(op_key, op_scope)
        )
        if algorithm_name == algorithm_manager.AlgorithmName.NO_QUANTIZE:
          continue
        # Step2: query algorithm_manager to get/call the related qsv init
        # function.
        qsv_init_func = algorithm_manager.get_init_qsv_func(
            algorithm_name, op_key
        )
        op_info = qtyping.OpInfo(op, op_key, subgraph_op_id, op_quant_config)
        op_qsvs = qsv_init_func(op_info, graph_info)
        # Step3: initialize tensor qsvs.
        for tensor_name, qsv in op_qsvs.items():
          if tensor_name not in self._model_qsvs:
            self._model_qsvs[tensor_name] = qsv
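For context, a minimal usage sketch of the class above, not part of the package: it assumes a float TFLite model at "model.tflite" with a single input named "input_1" (both placeholders), and uses only APIs that appear elsewhere in this diff (RecipeManager.add_quantization_config as exercised in calibrator_test.py below, Calibrator.calibrate, and get_model_qsvs).

import numpy as np

from ai_edge_quantizer import calibrator
from ai_edge_quantizer import qtyping
from ai_edge_quantizer import recipe_manager
from ai_edge_quantizer.utils import tfl_interpreter_utils

# Recipe: int8 min/max uniform quantization for all supported ops.
recipe = recipe_manager.RecipeManager()
recipe.add_quantization_config(
    regex=".*",
    operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
    algorithm_key=recipe_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT,
    op_config=qtyping.OpQuantizationConfig(
        activation_tensor_config=qtyping.TensorQuantizationConfig(
            num_bits=8, symmetric=False
        ),
        weight_tensor_config=qtyping.TensorQuantizationConfig(
            num_bits=8, symmetric=True
        ),
        compute_precision=qtyping.ComputePrecision.INTEGER,
    ),
)

# Calibration data: signature key -> iterable of input dicts.
# "model.tflite" and "input_1" are illustrative placeholders.
samples = [
    {"input_1": np.random.rand(1, 8).astype(np.float32)} for _ in range(10)
]
calib = calibrator.Calibrator("model.tflite")
calib.calibrate(
    {tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY: samples},
    model_recipe_manager=recipe,
)
qsvs = calib.get_model_qsvs()  # Tensor name -> QSV statistics.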
ai_edge_quantizer/calibrator_test.py
@@ -0,0 +1,297 @@
# Copyright 2024 The AI Edge Quantizer Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for calibrator."""

from collections.abc import Generator
import os
from typing import Any

import numpy as np

from tensorflow.python.platform import googletest
from ai_edge_quantizer import calibrator
from ai_edge_quantizer import qtyping
from ai_edge_quantizer import recipe_manager
from ai_edge_quantizer.utils import test_utils
from ai_edge_quantizer.utils import tfl_interpreter_utils

_ComputePrecision = qtyping.ComputePrecision
_AlgorithmName = recipe_manager.AlgorithmName

TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("")
_TENSOR_QUANT_CONFIG = qtyping.TensorQuantizationConfig

TEST_MIN_VAL, TEST_MAX_VAL = -1, 1

_RNG = np.random.default_rng(66)


def _representative_dataset_gen(size=(1, 8), num_samples=10):
  for _ in range(num_samples):
    vals = np.random.rand(*size).astype(np.float32)
    vals[0][0], vals[0][1] = (
        TEST_MIN_VAL,
        TEST_MAX_VAL,
    )  # fix min/max for testing
    yield {"input_1": vals}


def _get_calibration_data(
    dataset_gen: Generator[dict[str, Any], Any, None],
) -> dict[str, Any]:
  calibration_samples = [sample for sample in dataset_gen]
  calibration_data = {
      tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY: calibration_samples,
  }
  return calibration_data


def _add_default_int8xint8_integer_recipe(recipe_manager_object):
  recipe_manager_object.add_quantization_config(
      regex=".*",
      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
      algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
      op_config=qtyping.OpQuantizationConfig(
          activation_tensor_config=_TENSOR_QUANT_CONFIG(
              num_bits=8, symmetric=False
          ),
          weight_tensor_config=_TENSOR_QUANT_CONFIG(num_bits=8, symmetric=True),
          compute_precision=_ComputePrecision.INTEGER,
      ),
  )


class CalibratorTest(googletest.TestCase):

  def setUp(self):
    super().setUp()
    np.random.seed(0)
    self._test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, "tests/models/single_fc.tflite"
    )
    self._calibrator = calibrator.Calibrator(self._test_model_path)
    self._recipe_manager = recipe_manager.RecipeManager()
    dataset_gen = _representative_dataset_gen()
    self._representative_dataset = _get_calibration_data(dataset_gen)

  def test_calibrator_state_manipulation(self):
    # load/get qsvs
    sample_qsv = {"serving_default_input_1:0": {"min": -10, "max": 8}}
    self._calibrator.load_model_qsvs(sample_qsv)
    model_tensor_qsvs = self._calibrator.get_model_qsvs()
    self.assertLen(model_tensor_qsvs, 1)
    self.assertIn("serving_default_input_1:0", model_tensor_qsvs)  # input
    input_qsv = model_tensor_qsvs["serving_default_input_1:0"]
    self.assertEqual(input_qsv["min"], -10)
    self.assertEqual(input_qsv["max"], 8)

    # reset qsvs
    self._calibrator.reset_model_qsvs()
    model_tensor_qsvs = self._calibrator.get_model_qsvs()
    self.assertEmpty(model_tensor_qsvs)

  def test_calibrator_initialize_qsv(self):
    _add_default_int8xint8_integer_recipe(self._recipe_manager)
    # Overwrite the single op to fc
    self._recipe_manager.add_quantization_config(
        regex=".*Stateful.*",
        operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
        algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
        op_config=qtyping.OpQuantizationConfig(
            weight_tensor_config=_TENSOR_QUANT_CONFIG(
                num_bits=4,
                granularity=qtyping.QuantGranularity.CHANNELWISE,
            ),
            compute_precision=_ComputePrecision.INTEGER,
        ),
    )
    self._calibrator._initialize_model_qsvs(self._recipe_manager)
    model_tensor_qsvs = self._calibrator.get_model_qsvs()

    self.assertLen(model_tensor_qsvs, 4)
    self.assertIn("serving_default_input_1:0", model_tensor_qsvs)  # input
    input_qsv = model_tensor_qsvs["serving_default_input_1:0"]
    self.assertEmpty(input_qsv)

    self.assertIn("sequential/dense/MatMul", model_tensor_qsvs)  # weight
    weight_tensor_qsv = model_tensor_qsvs["sequential/dense/MatMul"]
    mins_maxs_shape = (16, 1)
    self.assertTupleEqual(weight_tensor_qsv["min"].shape, mins_maxs_shape)
    self.assertAlmostEqual(weight_tensor_qsv["min"][0][0], -0.40436327)
    self.assertTupleEqual(weight_tensor_qsv["max"].shape, mins_maxs_shape)
    self.assertAlmostEqual(weight_tensor_qsv["max"][0][0], 0.46138108)

    self.assertIn(
        "sequential/dense/BiasAdd/ReadVariableOp", model_tensor_qsvs
    )  # bias
    bias_tensor_qsv = model_tensor_qsvs[
        "sequential/dense/BiasAdd/ReadVariableOp"
    ]
    mins_maxs_shape = (16,)
    self.assertTupleEqual(bias_tensor_qsv["min"].shape, mins_maxs_shape)
    self.assertAlmostEqual(bias_tensor_qsv["min"][0], -0.26978338)
    self.assertTupleEqual(bias_tensor_qsv["max"].shape, mins_maxs_shape)
    # Here bias min/max will be the same as each element is a scalar
    # Bias will be quantized with input_scale * weight_scale.
    self.assertSequenceEqual(
        list(bias_tensor_qsv["max"].flatten()),
        list(bias_tensor_qsv["min"].flatten()),
    )

    self.assertIn("StatefulPartitionedCall:0", model_tensor_qsvs)  # output
    output_qsv = model_tensor_qsvs["StatefulPartitionedCall:0"]
    self.assertEmpty(output_qsv)

  def test_calibrate_single_fc_success(self):
    _add_default_int8xint8_integer_recipe(self._recipe_manager)
    self._calibrator.calibrate(
        self._representative_dataset, self._recipe_manager
    )
    model_tensor_qsvs = self._calibrator.get_model_qsvs()

    self.assertLen(model_tensor_qsvs, 4)
    self.assertIn("serving_default_input_1:0", model_tensor_qsvs)  # input
    input_qsv = model_tensor_qsvs["serving_default_input_1:0"]
    self.assertSequenceAlmostEqual(
        input_qsv["min"].flatten(), [TEST_MIN_VAL], delta=1e-5
    )
    self.assertSequenceAlmostEqual(
        input_qsv["max"].flatten(), [TEST_MAX_VAL], delta=1e-5
    )

    self.assertIn("sequential/dense/MatMul", model_tensor_qsvs)  # weight
    weight_qsv = model_tensor_qsvs["sequential/dense/MatMul"]
    self.assertSequenceAlmostEqual(weight_qsv["min"].flatten(), [-0.49114203])
    self.assertSequenceAlmostEqual(weight_qsv["max"].flatten(), [0.4903704])

    self.assertIn(
        "sequential/dense/BiasAdd/ReadVariableOp", model_tensor_qsvs
    )  # bias
    bias_qsv = model_tensor_qsvs["sequential/dense/BiasAdd/ReadVariableOp"]
    self.assertSequenceAlmostEqual(bias_qsv["min"].flatten(), [-0.38401994])
    self.assertSequenceAlmostEqual(bias_qsv["max"].flatten(), [0.31727126])

    self.assertIn("StatefulPartitionedCall:0", model_tensor_qsvs)  # output
    output_qsv = model_tensor_qsvs["StatefulPartitionedCall:0"]
    # Relu, only check the min
    self.assertSequenceAlmostEqual(output_qsv["min"].flatten(), [0])

  def test_calibration_cache_is_empty_when_off(self):
    _add_default_int8xint8_integer_recipe(self._recipe_manager)
    self.assertEmpty(self._calibrator.get_cached_output())
    self._calibrator.calibrate(
        self._representative_dataset, self._recipe_manager, cache_output=False
    )
    self.assertEmpty(self._calibrator.get_cached_output())

  def test_calibration_cache_when_on(self):
    _add_default_int8xint8_integer_recipe(self._recipe_manager)
    self.assertEmpty(self._calibrator.get_cached_output())
    self._calibrator.calibrate(
        self._representative_dataset, self._recipe_manager, cache_output=True
    )
    self.assertLen(self._calibrator.get_cached_output(), 10)

  def test_calibration_cache_is_empty_after_reset(self):
    _add_default_int8xint8_integer_recipe(self._recipe_manager)
    self._calibrator.calibrate(
        self._representative_dataset, self._recipe_manager, cache_output=True
    )
    self._calibrator.clear_cached_output()
    self.assertEmpty(self._calibrator.get_cached_output())

  def test_calibrate_unsupported_ops_success(self):
    # Many ops in the following model are not supported currently.
    test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, "tests/models/branching_conv_fc.tflite"
    )
    test_calibrator = calibrator.Calibrator(test_model_path)
    _add_default_int8xint8_integer_recipe(self._recipe_manager)
    dataset_gen = _representative_dataset_gen(size=(3, 4, 4, 1))
    test_calibrator.calibrate(
        _get_calibration_data(dataset_gen),
        self._recipe_manager,
        cache_output=True,
    )
    self.assertLen(test_calibrator.get_cached_output(), 10)


class CalibratorAlreadyQuantizedModelTest(googletest.TestCase):

  def test_check_is_float_model_succeeds_when_model_is_float(self):
    test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, "tests/models/conv_fc_mnist.tflite"
    )
    _ = calibrator.Calibrator(test_model_path)

  def test_check_is_float_model_raises_error_when_model_is_quantized(self):
    test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, "tests/models/mnist_quantized.tflite"
    )
    with self.assertRaisesRegex(
        ValueError,
        "The input model for calibration is not a float model.",
    ):
      _ = calibrator.Calibrator(test_model_path)


class CalibratorToyGemma2Test(googletest.TestCase):

  def setUp(self):
    super().setUp()
    np.random.seed(0)

    self._test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH,
        "tests/models/toy_model_with_kv_cache_multi_signature.tflite",
    )

    self._toy_gemma2_calibration_dataset = {
        "signature_1": [{
            "cache_0": _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
            "cache_1": _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
            "positions": _RNG.integers(low=0, high=10, size=(1, 100)).astype(
                np.int32
            ),
            "tokens": _RNG.integers(low=0, high=10, size=(1, 100)).astype(
                np.int32
            ),
        }],
        "signature_2": [{
            "cache_0": _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
            "cache_1": _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
            "positions": _RNG.integers(low=0, high=10, size=(1, 100)).astype(
                np.int32
            ),
            "tokens": _RNG.integers(low=0, high=10, size=(1, 100)).astype(
                np.int32
            ),
        }],
    }

  def test_toy_gemma2_calibration_success(self):
    calib = calibrator.Calibrator(self._test_model_path)
    recipe_mngr = recipe_manager.RecipeManager()
    _add_default_int8xint8_integer_recipe(recipe_mngr)
    calib.calibrate(
        self._toy_gemma2_calibration_dataset,
        model_recipe_manager=recipe_mngr,
    )
    self.assertLen(calib.get_model_qsvs(), 282)


if __name__ == "__main__":
  googletest.main()
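One thing the tests above do not exercise: calibrate() accepts any Callable[[QSV, QSV], QSV] as qsv_update_func, with calibration_utils.moving_average_update as the default. Below is a sketch of an alternative updater, under the assumption (consistent with the tests above) that a QSV is a dict of elementwise "min"/"max" arrays; the function name is hypothetical and not part of the package.

import numpy as np


def running_extrema_update(old_qsv, new_qsv):
  # Hypothetical alternative to the default moving-average update: keep the
  # elementwise running extrema observed across all calibration samples.
  return {
      "min": np.minimum(old_qsv["min"], new_qsv["min"]),
      "max": np.maximum(old_qsv["max"], new_qsv["max"]),
  }


# Usage: calib.calibrate(dataset, recipe, qsv_update_func=running_extrema_update)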
ai_edge_quantizer/conftest.py
@@ -0,0 +1,22 @@
# Copyright 2024 The AI Edge Quantizer Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Configuration file for pytest."""

from absl import flags


def pytest_configure():
  flags.FLAGS.mark_as_parsed()
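A likely reason for this hook: absl flags raise an error when their values are read before flag parsing, and pytest does not route through absl's app.run(), so marking FLAGS as parsed in pytest_configure lets the absl-based test modules above run under pytest.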