ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
ai_edge_quantizer/calibrator.py
@@ -0,0 +1,288 @@
+ # Copyright 2024 The AI Edge Quantizer Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Quantization Calibration."""
+
+ from collections.abc import Callable, Iterable
+ import copy
+ from typing import Any, Union
+
+ from absl import logging
+ import numpy as np
+
+ from ai_edge_quantizer import algorithm_manager
+ from ai_edge_quantizer import qtyping
+ from ai_edge_quantizer import recipe_manager
+ from ai_edge_quantizer.utils import calibration_utils
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+ from ai_edge_quantizer.utils import tfl_interpreter_utils
+
+ _SignatureInput = dict[str, Any]  # input_argument_name -> tensor_value.
+ _SignatureOutput = dict[
+     str, np.ndarray
+ ]  # output_argument_name -> tensor_value.
+
+
+ class Calibrator:
+   """Calibrator for TFLite model."""
+
+   def __init__(
+       self,
+       float_tflite: Union[str, bytes],
+       num_threads: int = 16,
+   ):
+     self._flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
+
+     if not tfl_flatbuffer_utils.is_float_model(self._flatbuffer_model):
+       raise ValueError(
+           "The input model for calibration is not a float model. Please check"
+           " the model (e.g., if it is already quantized)."
+       )
+     self._tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
+         float_tflite, use_xnnpack=True, num_threads=num_threads
+     )
+     # Tensor name to tensor content.
+     self._tensor_content_map: dict[str, Any] = {}
+     # QSV of all the tensors in the model.
+     self._model_qsvs: dict[str, qtyping.QSV] = {}
+     # Cached output of the model.
+     self._cached_output: list[_SignatureOutput] = []
+
+   # TODO: b/330740605 - Collect multiple QSVs in one run to save compute.
+   def calibrate(
+       self,
+       calibration_dataset: dict[str, Iterable[_SignatureInput]],
+       model_recipe_manager: recipe_manager.RecipeManager,
+       cache_output: bool = False,
+       qsv_update_func: Callable[
+           [qtyping.QSV, qtyping.QSV],
+           qtyping.QSV,
+       ] = calibration_utils.moving_average_update,
+   ) -> None:
+     """Calibrates the model using the given dataset for a model signature.
+
+     The process is:
+     0. Initialize quantization statistics values (QSVs) using the
+        initialization function (from AlgorithmManager) for the op if needed.
+     1. Invoke the TFL interpreter on the calibration data.
+     2. Go through each op and ask the RecipeManager for the quantization
+        setting of the op.
+     3. Ask the AlgorithmManager for the calibration function of the op given
+        the quantization setting.
+     4. Apply the function to the op to obtain quantization statistics (QSVs)
+        for the tensors associated with the op.
+     5. Update the global QSV dictionary.
+     6. Start another round of calibration.
+
+     Args:
+       calibration_dataset: A dictionary of input data for calibration for the
+         given model signature.
+       model_recipe_manager: A RecipeManager object that contains the
+         quantization recipe.
+       cache_output: Whether to cache the output of the model during the
+         calibration process. This is useful if there are dependencies between
+         signatures/models (e.g., decode requires encode output).
+       qsv_update_func: The function used to update the QSVs.
+     """
+     op_codes = self._flatbuffer_model.operatorCodes
+     if not self._model_qsvs:
+       self._initialize_model_qsvs(model_recipe_manager)
+     else:
+       logging.warning(
+           "Calibrator contains non-empty model qsvs, and the current"
+           " calibration process will start on top of this state (i.e., update"
+           " the existing qsvs). If this is an unintended behavior please call"
+           " reset_model_qsvs to reset model qsvs."
+       )
+
+     # TODO: b/329322226 - Enable parallel calibration.
+     for signature_key, dataset in calibration_dataset.items():
+       # Step0: get subgraph index.
+       subgraph_idx = tfl_interpreter_utils.get_signature_main_subgraph_index(
+           self._tfl_interpreter, signature_key
+       )
+
+       for data in dataset:
+         # Initialize tensor names that are updated in this round of
+         # calibration.
+         updated_tensor_names = set()
+
+         # Step1: run tfl interpreter on subgraph to get tensor content.
+         signature_output = tfl_interpreter_utils.invoke_interpreter_signature(
+             self._tfl_interpreter, data, signature_key
+         )
+         if cache_output:
+           self._cached_output.append(signature_output)
+         self._tensor_content_map.update(
+             tfl_interpreter_utils.get_tensor_name_to_content_map(
+                 self._tfl_interpreter, subgraph_idx
+             )
+         )
+
+         # Step2: go through each op in subgraph to update quantization
+         # statistic values.
+         subgraph = self._flatbuffer_model.subgraphs[subgraph_idx]
+         graph_info = qtyping.GraphInfo(
+             subgraph.tensors, self._flatbuffer_model.buffers
+         )
+         # Add input/output operators to the subgraph.
+         subgraph.operators += (
+             tfl_flatbuffer_utils.get_subgraph_input_output_operators(subgraph)
+         )
+         for op in subgraph.operators:
+           if isinstance(op, qtyping.IOOperator):
+             op_key = op.op_key
+           else:
+             op_code = op_codes[op.opcodeIndex].builtinCode
+             if op_code not in tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME:
+               continue
+             op_key = tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME[op_code]
+           # Step2.1: query the quantization_recipe to get op quantization
+           # settings.
+           op_scope = self._get_op_scope(op, subgraph.tensors)
+           algorithm_name, _ = model_recipe_manager.get_quantization_configs(
+               op_key, op_scope
+           )
+           if algorithm_name == algorithm_manager.AlgorithmName.NO_QUANTIZE:
+             continue
+           # Step2.2: query algorithm_manager to get/call the related
+           # calibration function.
+           calibrate_func = algorithm_manager.get_quantization_func(
+               algorithm_name, op_key, qtyping.QuantizeMode.CALIBRATE
+           )
+           op_qsvs = calibrate_func(op, graph_info, self._tensor_content_map)
+           # Step3: Update tensor qsvs with the new values. Ignore the tensor
+           # names that are already updated in this round of calibration.
+           op_updated_tensor_name = self._update_qsvs(
+               op_qsvs, updated_tensor_names, qsv_update_func
+           )
+           updated_tensor_names.update(op_updated_tensor_name)
+         # Reset interpreter after one round of calibration.
+         self._tfl_interpreter.reset_all_variables()
+
+   def get_model_qsvs(self) -> dict[str, qtyping.QSV]:
+     """Get the model qsvs.
+
+     Returns:
+       A dictionary of tensor name to QSV.
+     """
+     return self._model_qsvs
+
+   def get_cached_output(self) -> list[_SignatureOutput]:
+     """Get the cached output of the model."""
+     return self._cached_output
+
+   def clear_cached_output(self) -> None:
+     """Clear the cached output of the model."""
+     self._cached_output = []
+
+   def reset_model_qsvs(self) -> None:
+     """Reset the model qsvs."""
+     self._model_qsvs = {}
+
+   def load_model_qsvs(self, model_qsvs: dict[str, qtyping.QSV]) -> None:
+     """Load the model qsvs.
+
+     Args:
+       model_qsvs: A dictionary of tensor name to QSV.
+     """
+     self._model_qsvs = copy.deepcopy(model_qsvs)
+
+   def _update_qsvs(
+       self,
+       op_qsvs: dict[str, qtyping.QSV],
+       ignore_tensor_names: set[str],
+       qsv_update_func: Callable[[qtyping.QSV, qtyping.QSV], qtyping.QSV],
+   ) -> set[str]:
+     """Update the model qsvs with the new values.
+
+     Args:
+       op_qsvs: A dictionary of tensor name to QSV.
+       ignore_tensor_names: A set of tensor names to ignore.
+       qsv_update_func: The function to update the QSVs.
+
+     Returns:
+       A set of tensor names that are updated.
+     """
+     updated_tensor_names = set()
+     for tensor_name, qsv in op_qsvs.items():
+       if tensor_name in ignore_tensor_names:
+         continue
+       if tensor_name not in self._model_qsvs:
+         self._model_qsvs[tensor_name] = qsv
+       else:
+         updated_qsv = qsv_update_func(self._model_qsvs[tensor_name], qsv)
+         self._model_qsvs[tensor_name] = updated_qsv
+       updated_tensor_names.add(tensor_name)
+     return updated_tensor_names
+
+   def _get_op_scope(self, op, subgraph_tensors) -> str:
+     """Get the scope of the op.
+
+     The scope is the concatenated names of the op's output tensors.
+
+     Args:
+       op: The op to get the scope.
+       subgraph_tensors: The tensors in the subgraph.
+
+     Returns:
+       The scope of the op.
+     """
+     scope = ""
+     for output_tensor_idx in op.outputs:
+       if output_tensor_idx != -1:
+         output_tensor = subgraph_tensors[output_tensor_idx]
+         scope += tfl_flatbuffer_utils.get_tensor_name(output_tensor)
+     return scope
+
+   # TODO: b/354224138 - Remove code duplication between calibrate and
+   # _initialize_model_qsvs.
+   def _initialize_model_qsvs(
+       self, model_recipe_manager: recipe_manager.RecipeManager
+   ) -> None:
+     """Initialize the model qsvs.
+
+     Args:
+       model_recipe_manager: A RecipeManager object that contains the
+         quantization recipe.
+     """
+     op_codes = self._flatbuffer_model.operatorCodes
+     for subgraph in self._flatbuffer_model.subgraphs:
+       graph_info = qtyping.GraphInfo(
+           subgraph.tensors, self._flatbuffer_model.buffers
+       )
+       for subgraph_op_id, op in enumerate(subgraph.operators):
+         op_code = op_codes[op.opcodeIndex].builtinCode
+         if op_code not in tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME:
+           continue
+         op_key = tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME[op_code]
+         # Step1: query the quantization_recipe to get op quantization
+         # settings.
+         op_scope = self._get_op_scope(op, subgraph.tensors)
+         algorithm_name, op_quant_config = (
+             model_recipe_manager.get_quantization_configs(op_key, op_scope)
+         )
+         if algorithm_name == algorithm_manager.AlgorithmName.NO_QUANTIZE:
+           continue
+         # Step2: query algorithm_manager to get/call the related qsv init
+         # function.
+         qsv_init_func = algorithm_manager.get_init_qsv_func(
+             algorithm_name, op_key
+         )
+         op_info = qtyping.OpInfo(op, op_key, subgraph_op_id, op_quant_config)
+         op_qsvs = qsv_init_func(op_info, graph_info)
+         # Step3: initialize tensor qsvs.
+         for tensor_name, qsv in op_qsvs.items():
+           if tensor_name not in self._model_qsvs:
+             self._model_qsvs[tensor_name] = qsv
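For context, a minimal usage sketch of the Calibrator API above. The recipe setup mirrors the test file below; the model path, signature key, and input name ("model.tflite", "serving_default", "input_1") are hypothetical placeholders, not values from this package:

    import numpy as np

    from ai_edge_quantizer import calibrator
    from ai_edge_quantizer import qtyping
    from ai_edge_quantizer import recipe_manager

    # Build a recipe that quantizes every supported op (int8 activations/weights).
    recipe = recipe_manager.RecipeManager()
    recipe.add_quantization_config(
        regex=".*",
        operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
        algorithm_key=recipe_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT,
        op_config=qtyping.OpQuantizationConfig(
            activation_tensor_config=qtyping.TensorQuantizationConfig(
                num_bits=8, symmetric=False
            ),
            weight_tensor_config=qtyping.TensorQuantizationConfig(
                num_bits=8, symmetric=True
            ),
            compute_precision=qtyping.ComputePrecision.INTEGER,
        ),
    )

    calib = calibrator.Calibrator("model.tflite")  # Placeholder path.
    # calibration_dataset maps signature key -> iterable of input dicts.
    dataset = {
        "serving_default": [
            {"input_1": np.random.rand(1, 8).astype(np.float32)}
            for _ in range(10)
        ]
    }
    calib.calibrate(dataset, model_recipe_manager=recipe)
    qsvs = calib.get_model_qsvs()  # Tensor name -> QSV (e.g., {"min": ..., "max": ...}).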
ai_edge_quantizer/calibrator_test.py
@@ -0,0 +1,297 @@
+ # Copyright 2024 The AI Edge Quantizer Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Tests for calibrator."""
+
+ from collections.abc import Generator
+ import os
+ from typing import Any
+
+ import numpy as np
+
+ from tensorflow.python.platform import googletest
+ from ai_edge_quantizer import calibrator
+ from ai_edge_quantizer import qtyping
+ from ai_edge_quantizer import recipe_manager
+ from ai_edge_quantizer.utils import test_utils
+ from ai_edge_quantizer.utils import tfl_interpreter_utils
+
+ _ComputePrecision = qtyping.ComputePrecision
+ _AlgorithmName = recipe_manager.AlgorithmName
+
+ TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("")
+ _TENSOR_QUANT_CONFIG = qtyping.TensorQuantizationConfig
+
+ TEST_MIN_VAL, TEST_MAX_VAL = -1, 1
+
+ _RNG = np.random.default_rng(66)
+
+
+ def _representative_dataset_gen(size=(1, 8), num_samples=10):
+   for _ in range(num_samples):
+     vals = np.random.rand(*size).astype(np.float32)
+     vals[0][0], vals[0][1] = (
+         TEST_MIN_VAL,
+         TEST_MAX_VAL,
+     )  # fix min/max for testing
+     yield {"input_1": vals}
+
+
+ def _get_calibration_data(
+     dataset_gen: Generator[dict[str, Any], Any, None],
+ ) -> dict[str, Any]:
+   calibration_samples = [sample for sample in dataset_gen]
+   calibration_data = {
+       tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY: calibration_samples,
+   }
+   return calibration_data
+
+
+ def _add_default_int8xint8_integer_recipe(recipe_manager_object):
+   recipe_manager_object.add_quantization_config(
+       regex=".*",
+       operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+       algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+       op_config=qtyping.OpQuantizationConfig(
+           activation_tensor_config=_TENSOR_QUANT_CONFIG(
+               num_bits=8, symmetric=False
+           ),
+           weight_tensor_config=_TENSOR_QUANT_CONFIG(
+               num_bits=8, symmetric=True
+           ),
+           compute_precision=_ComputePrecision.INTEGER,
+       ),
+   )
+
+
+ class CalibratorTest(googletest.TestCase):
+
+   def setUp(self):
+     super().setUp()
+     np.random.seed(0)
+     self._test_model_path = os.path.join(
+         TEST_DATA_PREFIX_PATH, "tests/models/single_fc.tflite"
+     )
+     self._calibrator = calibrator.Calibrator(self._test_model_path)
+     self._recipe_manager = recipe_manager.RecipeManager()
+     dataset_gen = _representative_dataset_gen()
+     self._representative_dataset = _get_calibration_data(dataset_gen)
+
+   def test_calibrator_state_manipulation(self):
+     # Load/get qsvs.
+     sample_qsv = {"serving_default_input_1:0": {"min": -10, "max": 8}}
+     self._calibrator.load_model_qsvs(sample_qsv)
+     model_tensor_qsvs = self._calibrator.get_model_qsvs()
+     self.assertLen(model_tensor_qsvs, 1)
+     self.assertIn("serving_default_input_1:0", model_tensor_qsvs)  # input
+     input_qsv = model_tensor_qsvs["serving_default_input_1:0"]
+     self.assertEqual(input_qsv["min"], -10)
+     self.assertEqual(input_qsv["max"], 8)
+
+     # Reset qsvs.
+     self._calibrator.reset_model_qsvs()
+     model_tensor_qsvs = self._calibrator.get_model_qsvs()
+     self.assertEmpty(model_tensor_qsvs)
+
+   def test_calibrator_initialize_qsv(self):
+     _add_default_int8xint8_integer_recipe(self._recipe_manager)
+     # Overwrite the setting for the single fc op.
+     self._recipe_manager.add_quantization_config(
+         regex=".*Stateful.*",
+         operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
+         algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+         op_config=qtyping.OpQuantizationConfig(
+             weight_tensor_config=_TENSOR_QUANT_CONFIG(
+                 num_bits=4,
+                 granularity=qtyping.QuantGranularity.CHANNELWISE,
+             ),
+             compute_precision=_ComputePrecision.INTEGER,
+         ),
+     )
+     self._calibrator._initialize_model_qsvs(self._recipe_manager)
+     model_tensor_qsvs = self._calibrator.get_model_qsvs()
+
+     self.assertLen(model_tensor_qsvs, 4)
+     self.assertIn("serving_default_input_1:0", model_tensor_qsvs)  # input
+     input_qsv = model_tensor_qsvs["serving_default_input_1:0"]
+     self.assertEmpty(input_qsv)
+
+     self.assertIn("sequential/dense/MatMul", model_tensor_qsvs)  # weight
+     weight_tensor_qsv = model_tensor_qsvs["sequential/dense/MatMul"]
+     mins_maxs_shape = (16, 1)
+     self.assertTupleEqual(weight_tensor_qsv["min"].shape, mins_maxs_shape)
+     self.assertAlmostEqual(weight_tensor_qsv["min"][0][0], -0.40436327)
+     self.assertTupleEqual(weight_tensor_qsv["max"].shape, mins_maxs_shape)
+     self.assertAlmostEqual(weight_tensor_qsv["max"][0][0], 0.46138108)
+
+     self.assertIn(
+         "sequential/dense/BiasAdd/ReadVariableOp", model_tensor_qsvs
+     )  # bias
+     bias_tensor_qsv = model_tensor_qsvs[
+         "sequential/dense/BiasAdd/ReadVariableOp"
+     ]
+     mins_maxs_shape = (16,)
+     self.assertTupleEqual(bias_tensor_qsv["min"].shape, mins_maxs_shape)
+     self.assertAlmostEqual(bias_tensor_qsv["min"][0], -0.26978338)
+     self.assertTupleEqual(bias_tensor_qsv["max"].shape, mins_maxs_shape)
+     # Here bias min/max will be the same since each element is a scalar.
+     # Bias will be quantized with input_scale * weight_scale.
+     self.assertSequenceEqual(
+         list(bias_tensor_qsv["max"].flatten()),
+         list(bias_tensor_qsv["min"].flatten()),
+     )
+
+     self.assertIn("StatefulPartitionedCall:0", model_tensor_qsvs)  # output
+     output_qsv = model_tensor_qsvs["StatefulPartitionedCall:0"]
+     self.assertEmpty(output_qsv)
+
+   def test_calibrate_single_fc_success(self):
+     _add_default_int8xint8_integer_recipe(self._recipe_manager)
+     self._calibrator.calibrate(
+         self._representative_dataset, self._recipe_manager
+     )
+     model_tensor_qsvs = self._calibrator.get_model_qsvs()
+
+     self.assertLen(model_tensor_qsvs, 4)
+     self.assertIn("serving_default_input_1:0", model_tensor_qsvs)  # input
+     input_qsv = model_tensor_qsvs["serving_default_input_1:0"]
+     self.assertSequenceAlmostEqual(
+         input_qsv["min"].flatten(), [TEST_MIN_VAL], delta=1e-5
+     )
+     self.assertSequenceAlmostEqual(
+         input_qsv["max"].flatten(), [TEST_MAX_VAL], delta=1e-5
+     )
+
+     self.assertIn("sequential/dense/MatMul", model_tensor_qsvs)  # weight
+     weight_qsv = model_tensor_qsvs["sequential/dense/MatMul"]
+     self.assertSequenceAlmostEqual(weight_qsv["min"].flatten(), [-0.49114203])
+     self.assertSequenceAlmostEqual(weight_qsv["max"].flatten(), [0.4903704])
+
+     self.assertIn(
+         "sequential/dense/BiasAdd/ReadVariableOp", model_tensor_qsvs
+     )  # bias
+     bias_qsv = model_tensor_qsvs["sequential/dense/BiasAdd/ReadVariableOp"]
+     self.assertSequenceAlmostEqual(bias_qsv["min"].flatten(), [-0.38401994])
+     self.assertSequenceAlmostEqual(bias_qsv["max"].flatten(), [0.31727126])
+
+     self.assertIn("StatefulPartitionedCall:0", model_tensor_qsvs)  # output
+     output_qsv = model_tensor_qsvs["StatefulPartitionedCall:0"]
+     # Relu output, so only check the min.
+     self.assertSequenceAlmostEqual(output_qsv["min"].flatten(), [0])
+
+   def test_calibration_cache_is_empty_when_off(self):
+     _add_default_int8xint8_integer_recipe(self._recipe_manager)
+     self.assertEmpty(self._calibrator.get_cached_output())
+     self._calibrator.calibrate(
+         self._representative_dataset, self._recipe_manager, cache_output=False
+     )
+     self.assertEmpty(self._calibrator.get_cached_output())
+
+   def test_calibration_cache_when_on(self):
+     _add_default_int8xint8_integer_recipe(self._recipe_manager)
+     self.assertEmpty(self._calibrator.get_cached_output())
+     self._calibrator.calibrate(
+         self._representative_dataset, self._recipe_manager, cache_output=True
+     )
+     self.assertLen(self._calibrator.get_cached_output(), 10)
+
+   def test_calibration_cache_is_empty_after_reset(self):
+     _add_default_int8xint8_integer_recipe(self._recipe_manager)
+     self._calibrator.calibrate(
+         self._representative_dataset, self._recipe_manager, cache_output=True
+     )
+     self._calibrator.clear_cached_output()
+     self.assertEmpty(self._calibrator.get_cached_output())
+
+   def test_calibrate_unsupported_ops_success(self):
+     # Many ops in the following model are not supported currently.
+     test_model_path = os.path.join(
+         TEST_DATA_PREFIX_PATH, "tests/models/branching_conv_fc.tflite"
+     )
+     test_calibrator = calibrator.Calibrator(test_model_path)
+     _add_default_int8xint8_integer_recipe(self._recipe_manager)
+     dataset_gen = _representative_dataset_gen(size=(3, 4, 4, 1))
+     test_calibrator.calibrate(
+         _get_calibration_data(dataset_gen),
+         self._recipe_manager,
+         cache_output=True,
+     )
+     self.assertLen(test_calibrator.get_cached_output(), 10)
+
+
+ class CalibratorAlreadyQuantizedModelTest(googletest.TestCase):
+
+   def test_check_is_float_model_succeeds_when_model_is_float(self):
+     test_model_path = os.path.join(
+         TEST_DATA_PREFIX_PATH, "tests/models/conv_fc_mnist.tflite"
+     )
+     _ = calibrator.Calibrator(test_model_path)
+
+   def test_check_is_float_model_raises_error_when_model_is_quantized(self):
+     test_model_path = os.path.join(
+         TEST_DATA_PREFIX_PATH, "tests/models/mnist_quantized.tflite"
+     )
+     with self.assertRaisesRegex(
+         ValueError,
+         "The input model for calibration is not a float model.",
+     ):
+       _ = calibrator.Calibrator(test_model_path)
+
+
+ class CalibratorToyGemma2Test(googletest.TestCase):
+
+   def setUp(self):
+     super().setUp()
+     np.random.seed(0)
+
+     self._test_model_path = os.path.join(
+         TEST_DATA_PREFIX_PATH,
+         "tests/models/toy_model_with_kv_cache_multi_signature.tflite",
+     )
+
+     self._toy_gemma2_calibration_dataset = {
+         "signature_1": [{
+             "cache_0": _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
+             "cache_1": _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
+             "positions": _RNG.integers(low=0, high=10, size=(1, 100)).astype(
+                 np.int32
+             ),
+             "tokens": _RNG.integers(low=0, high=10, size=(1, 100)).astype(
+                 np.int32
+             ),
+         }],
+         "signature_2": [{
+             "cache_0": _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
+             "cache_1": _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
+             "positions": _RNG.integers(low=0, high=10, size=(1, 100)).astype(
+                 np.int32
+             ),
+             "tokens": _RNG.integers(low=0, high=10, size=(1, 100)).astype(
+                 np.int32
+             ),
+         }],
+     }
+
+   def test_toy_gemma2_calibration_success(self):
+     calib = calibrator.Calibrator(self._test_model_path)
+     recipe_mngr = recipe_manager.RecipeManager()
+     _add_default_int8xint8_integer_recipe(recipe_mngr)
+     calib.calibrate(
+         self._toy_gemma2_calibration_dataset,
+         model_recipe_manager=recipe_mngr,
+     )
+     self.assertLen(calib.get_model_qsvs(), 282)
+
+
+ if __name__ == "__main__":
+   googletest.main()
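The tests above exercise the default `calibration_utils.moving_average_update` (its body is not part of this diff). Since `calibrate` accepts any `Callable[[QSV, QSV], QSV]`, a custom update rule can be swapped in. A minimal sketch, assuming QSVs are dicts of "min"/"max" numpy arrays as the tests above show:

    import numpy as np

    def running_envelope_update(old_qsv, new_qsv):
      """Keeps the widest min/max envelope seen across calibration rounds (sketch)."""
      return {
          "min": np.minimum(old_qsv["min"], new_qsv["min"]),
          "max": np.maximum(old_qsv["max"], new_qsv["max"]),
      }

    # Hypothetical usage:
    # calib.calibrate(dataset, recipe, qsv_update_func=running_envelope_update)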
ai_edge_quantizer/conftest.py
@@ -0,0 +1,22 @@
+ # Copyright 2024 The AI Edge Quantizer Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Configuration file for pytest."""
+
+ from absl import flags
+
+
+ def pytest_configure():
+   flags.FLAGS.mark_as_parsed()
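A note on why this hook exists: absl flags raise `UnparsedFlagAccessError` when read before parsing, and pytest never routes through `absl.app.run()`, so `pytest_configure` marks the flag registry as parsed before any test touches a flag. A minimal illustration (the flag name here is hypothetical, not from this package):

    from absl import flags

    _NUM_SAMPLES = flags.DEFINE_integer("num_samples", 10, "Hypothetical flag.")

    def test_flag_access():
      # Without the pytest_configure() hook above, this read would raise
      # flags.UnparsedFlagAccessError under pytest.
      assert _NUM_SAMPLES.value == 10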