ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/algorithm_manager.py +158 -0
  2. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  3. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  5. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
  6. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  7. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  8. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
  12. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
  13. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
  14. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
  15. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
  16. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
  17. ai_edge_quantizer/calibrator.py +11 -60
  18. ai_edge_quantizer/calibrator_test.py +4 -73
  19. ai_edge_quantizer/default_policy.py +61 -26
  20. ai_edge_quantizer/model_modifier.py +97 -7
  21. ai_edge_quantizer/model_modifier_test.py +81 -1
  22. ai_edge_quantizer/model_validator.py +31 -8
  23. ai_edge_quantizer/params_generator.py +17 -10
  24. ai_edge_quantizer/params_generator_test.py +2 -7
  25. ai_edge_quantizer/qtyping.py +86 -6
  26. ai_edge_quantizer/quantizer.py +166 -21
  27. ai_edge_quantizer/quantizer_test.py +284 -16
  28. ai_edge_quantizer/recipe.py +154 -42
  29. ai_edge_quantizer/recipe_manager.py +158 -1
  30. ai_edge_quantizer/recipe_manager_test.py +146 -32
  31. ai_edge_quantizer/recipe_test.py +93 -17
  32. ai_edge_quantizer/transformation_instruction_generator.py +118 -13
  33. ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
  34. ai_edge_quantizer/transformation_performer.py +55 -25
  35. ai_edge_quantizer/transformation_performer_test.py +127 -5
  36. ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
  37. ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
  38. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  39. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  40. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  41. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  42. ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
  43. ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
  44. ai_edge_quantizer/transformations/transformation_utils.py +129 -6
  45. ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
  46. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  47. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  48. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  49. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  50. ai_edge_quantizer/utils/test_utils.py +75 -2
  51. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
  52. ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
  53. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  54. ai_edge_quantizer/utils/validation_utils.py +114 -4
  55. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  56. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
  57. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  58. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  59. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  60. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  61. ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
  62. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  63. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
@@ -14,11 +14,67 @@
14
14
  # ==============================================================================
15
15
 
16
16
  from absl.testing import parameterized
17
+ import numpy as np
18
+
17
19
  from tensorflow.python.platform import googletest
20
+ from ai_edge_quantizer import quantizer
18
21
  from ai_edge_quantizer.utils import calibration_utils
22
+ from ai_edge_quantizer.utils import test_utils
23
+ from ai_edge_quantizer.utils import tfl_interpreter_utils
24
+
25
+ _RNG = np.random.default_rng(66)
26
+
27
+ _CALIBRATION_DATASET = {
28
+ "signature_1": [{
29
+ "cache_0": np.zeros(shape=(1, 100, 4, 4), dtype=np.float32),
30
+ "cache_1": np.zeros(shape=(1, 100, 4, 4), dtype=np.float32),
31
+ "positions": np.zeros(shape=(1, 100), dtype=np.int32),
32
+ "tokens": np.zeros(shape=(1, 100), dtype=np.int32),
33
+ }],
34
+ "signature_2": [{
35
+ "cache_0": _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
36
+ "cache_1": _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
37
+ "positions": (
38
+ _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
39
+ ),
40
+ "tokens": _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32),
41
+ }],
42
+ }
43
+
19
44
 
45
def _get_quant_parameters(
    quantized_model: bytes, signature_data: dict[str, list[str]]
) -> list[np.ndarray]:
  """Returns the quantization scales of selected signature tensors.

  Args:
    quantized_model: The serialized quantized model.
    signature_data: Maps a signature key to the names (as exposed by that
      signature's runner) of input and/or output tensors to inspect.

  Returns:
    The squeezed `scales` array of every requested tensor, in the iteration
    order of `signature_data` and of each name list.

  Raises:
    ValueError: If a requested name is neither an input nor an output of its
      signature.
  """
  tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
      quantized_model
  )
  quant_params = []
  for signature_key, signature_names in signature_data.items():
    signature_runner = tfl_interpreter.get_signature_runner(signature_key)
    # The details are independent of the requested name, so fetch them once
    # per signature instead of once per name.
    input_details = signature_runner.get_input_details()
    output_details = signature_runner.get_output_details()
    for signature_name in signature_names:
      if signature_name in input_details:
        tensor_details = input_details[signature_name]
      elif signature_name in output_details:
        tensor_details = output_details[signature_name]
      else:
        raise ValueError(
            f"Signature name {signature_name} not found in the model."
        )
      quant_params.append(
          tensor_details["quantization_parameters"]["scales"].squeeze()
      )
  return quant_params
20
75
 
21
- class CalibrationUtilsTest(parameterized.TestCase):
76
+
77
+ class CalibrationQsvAlignmentUtilsTest(parameterized.TestCase):
22
78
 
23
79
  @parameterized.named_parameters(
24
80
  dict(
@@ -66,12 +122,126 @@ class CalibrationUtilsTest(parameterized.TestCase):
66
122
  def test_update_tensor_qsv_min_max(self, old_qsv, new_qsv, expected_qsv):
67
123
  updated_qsv = calibration_utils.min_max_update(old_qsv, new_qsv)
68
124
  if isinstance(expected_qsv["min"], list):
69
- self.assertListEqual(list(updated_qsv["min"]), expected_qsv["min"])
70
- self.assertListEqual(list(updated_qsv["max"]), expected_qsv["max"])
125
+ self.assertEqual(list(updated_qsv["min"]), expected_qsv["min"])
126
+ self.assertEqual(list(updated_qsv["max"]), expected_qsv["max"])
71
127
  else:
72
128
  self.assertEqual(updated_qsv["min"], expected_qsv["min"])
73
129
  self.assertEqual(updated_qsv["max"], expected_qsv["max"])
74
130
 
131
def test_calibration_utils_init_fails(self):
  """Constructing the utils on a missing model raises an error naming it."""
  missing_model = "non_existent_model.tflite"

  def _mentions_missing_model(err):
    return missing_model in str(err)

  with self.assertRaisesWithPredicateMatch(Exception, _mentions_missing_model):
    calibration_utils.CalibrationQsvAlignmentUtils(missing_model)
137
+
138
def test_calibration_utils_init_succeeds(self):
  """Constructing the utils on a valid model populates its internal state."""
  calib_utils = calibration_utils.CalibrationQsvAlignmentUtils(
      test_utils.get_path_to_datafile("../tests/models/single_add.tflite")
  )
  # Both internal collections should be non-empty after construction.
  for internal_state in (
      calib_utils._signature_runners,  # pylint: disable=protected-access
      calib_utils._same_as_input_scale_ops,  # pylint: disable=protected-access
  ):
    self.assertNotEmpty(internal_state)
145
+
146
def test_search_tensor_by_signature_name_succeeds_on_unconstrained_op(self):
  """A signature output resolves to the tensor of an unconstrained op."""
  calib_utils = calibration_utils.CalibrationQsvAlignmentUtils(
      test_utils.get_path_to_datafile("../tests/models/single_add.tflite")
  )
  found = calib_utils._search_tensor_by_signature_name(  # pylint: disable=protected-access
      "serving_default", "add"
  )
  self.assertEqual(found, ["PartitionedCall:0"])
156
+
157
def test_search_tensor_by_signature_name_succeeds_on_constrained_op(self):
  """A signature output behind a constrained op resolves to its input tensor."""
  calib_utils = calibration_utils.CalibrationQsvAlignmentUtils(
      test_utils.get_path_to_datafile("../tests/models/single_slice.tflite")
  )
  found = calib_utils._search_tensor_by_signature_name(  # pylint: disable=protected-access
      "slice", "output_0"
  )
  self.assertEqual(found, ["slice_input_tensor:0"])
167
+
168
def test_align_quant_stats_succeeds(self):
  """Aligning stats makes the shared output use one scale in all signatures."""
  model_path = test_utils.get_path_to_datafile(
      "../tests/models/toy_model_with_kv_cache_multi_signature.tflite"
  )
  recipe_path = test_utils.get_path_to_datafile(
      "../recipes/default_a8w8_recipe.json"
  )
  # The same output tensor name, looked up in both signatures.
  signature_data = {
      "signature_1": ["output_1_1"],
      "signature_2": ["output_1_1"],
  }

  def scales_all_equal(model):
    scales = _get_quant_parameters(model, signature_data)
    return all(scale == scales[0] for scale in scales)

  qt = quantizer.Quantizer(model_path, recipe_path)
  qsv = qt.calibrate(_CALIBRATION_DATASET)

  # Without alignment, the signatures end up with different scales.
  self.assertFalse(scales_all_equal(qt.quantize(qsv).quantized_model))

  # After aligning `qsv` in place, re-quantizing yields matching scales.
  aligner = calibration_utils.CalibrationQsvAlignmentUtils(model_path)
  aligner.align_quant_stats(qsv, signature_data)
  self.assertTrue(scales_all_equal(qt.quantize(qsv).quantized_model))
199
+
200
def test_update_quant_stats_succeeds(self):
  """Copying one signature's min/max onto another makes their scales match."""
  model_path = test_utils.get_path_to_datafile(
      "../tests/models/toy_model_with_kv_cache_multi_signature.tflite"
  )
  recipe_path = test_utils.get_path_to_datafile(
      "../recipes/default_a8w8_recipe.json"
  )
  signature_data = {
      "signature_1": ["output_1_1"],
      "signature_2": ["output_1_1"],
  }

  def scales_all_equal(model):
    scales = _get_quant_parameters(model, signature_data)
    return all(scale == scales[0] for scale in scales)

  qt = quantizer.Quantizer(model_path, recipe_path)
  qsv = qt.calibrate(_CALIBRATION_DATASET)

  # Before `signature_1` is updated, the scales differ.
  self.assertFalse(scales_all_equal(qt.quantize(qsv).quantized_model))

  # Read min/max from `signature_2` and write them into `signature_1`.
  calib_utils = calibration_utils.CalibrationQsvAlignmentUtils(model_path)
  source_signature = {"signature_2": ["output_1_1"]}
  target_signature = {"signature_1": ["output_1_1"]}
  min_val, max_val = calib_utils.align_quant_stats(qsv, source_signature)
  calib_utils.update_quant_stats(qsv, target_signature, min_val, max_val)

  self.assertTrue(scales_all_equal(qt.quantize(qsv).quantized_model))
244
+
75
245
 
76
246
  if __name__ == "__main__":
77
247
  googletest.main()
@@ -0,0 +1,111 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Utils for handling operators with quantization constraints."""
17
+
18
+ from ai_edge_quantizer import algorithm_manager
19
+ from ai_edge_quantizer import qtyping
20
+ from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
21
+ from ai_edge_quantizer.algorithms.utils import common_utils
22
+ from ai_edge_litert import schema_py_generated as schema_fb # pylint: disable=g-direct-tensorflow-import
23
+
24
+
25
+ _OpQuantConstraint = common_utils.OpQuantConstraint
26
+
27
+
28
def get_constrained_op_list(
    quant_constraint: _OpQuantConstraint, verbose: bool = False
) -> list[str]:
  """Constructs and returns a list of constrained operators.

  Every materialization function in
  `algorithm_manager.MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT` is invoked on a
  dummy op. `common_utils.materialize_standard_op` is monkey patched to record
  the `constraint` keyword each function passes to it. Two `common_quantize`
  helpers are also patched with no-op stand-ins so the dummy op does not
  trigger exceptions. The original functions are restored in a `finally`
  block, so they are restored even if a materialization function raises.

  Because this temporarily replaces module-level functions, it must not run
  concurrently with other code that uses them.

  Args:
    quant_constraint: The quantization constraint to filter operators by.
    verbose: Flag to enable verbose output.

  Returns:
    The `op_name` of every dummy op whose materialization reported
    `quant_constraint`.
  """
  constrained_ops = []

  def materialize_standard_op_wrapper(
      op_info: qtyping.OpInfo,
      *_args,
      constraint: _OpQuantConstraint = _OpQuantConstraint.NO_CONSTRAIN,
      **_kwargs,
  ) -> list[qtyping.TensorTransformationParams]:
    if constraint == quant_constraint:
      constrained_ops.append(op_info.op_name)
    # Return dummy values to avoid exceptions in the calling code.
    return [qtyping.TensorTransformationParams("")] * 2

  # Stand-in for `_are_weights_too_small`, used alongside the wrapper above.
  def are_weights_too_small_wrapper(*_args, **_kwargs) -> bool:
    return False

  # Stand-in for `_materialize_bias_for_fc_conv_ops`, used alongside the
  # wrapper above.
  def materialize_bias_for_fc_conv_ops_wrapper(*_args, **_kwargs):
    return

  original_materialize_standard_op = common_utils.materialize_standard_op
  original_are_weights_too_small = common_quantize._are_weights_too_small  # pylint: disable=protected-access
  original_materialize_bias_for_fc_conv_ops = (
      common_quantize._materialize_bias_for_fc_conv_ops  # pylint: disable=protected-access
  )
  try:
    common_utils.materialize_standard_op = materialize_standard_op_wrapper
    common_quantize._are_weights_too_small = are_weights_too_small_wrapper  # pylint: disable=protected-access
    common_quantize._materialize_bias_for_fc_conv_ops = (  # pylint: disable=protected-access
        materialize_bias_for_fc_conv_ops_wrapper
    )
    minmax_func_dict = algorithm_manager.MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT

    # Call each materialization function once to find out which constraint
    # it applies.
    for op, materialize_fn in minmax_func_dict.items():
      # Create a dummy op info to trigger the materialization.
      mock_op = schema_fb.OperatorT()
      mock_op.inputs = [0]
      mock_op.outputs = [0]
      op_info = qtyping.OpInfo(
          op=mock_op,
          op_name=op,
          subgraph_op_index=0,
          op_quant_config=qtyping.OpQuantizationConfig(),
      )
      materialize_fn(
          get_tensor_quant_params_fn=None,
          op_info=op_info,
          graph_info=None,
          tensor_name_to_qsv=None,
      )
  finally:
    # Always undo the monkey patching, even if a materialization raised.
    common_utils.materialize_standard_op = original_materialize_standard_op
    common_quantize._are_weights_too_small = original_are_weights_too_small  # pylint: disable=protected-access
    common_quantize._materialize_bias_for_fc_conv_ops = (  # pylint: disable=protected-access
        original_materialize_bias_for_fc_conv_ops
    )

  if verbose:
    print(f" {quant_constraint} op list: {constrained_ops}")
  return constrained_ops
@@ -0,0 +1,50 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from tensorflow.python.platform import googletest
17
+ from absl.testing import parameterized
18
+ from ai_edge_quantizer.algorithms.utils import common_utils
19
+ from ai_edge_quantizer.utils import constrained_ops_utils
20
+
21
+
22
+ _OpQuantConstraint = common_utils.OpQuantConstraint
23
+
24
+
25
class ConstrainedOpsUtilsTest(parameterized.TestCase):
  """Tests for `constrained_ops_utils.get_constrained_op_list`."""

  # Each case: (testcase_name, constraint, expected_num_ops).
  @parameterized.named_parameters(
      ("same_as_input_scale", _OpQuantConstraint.SAME_AS_INPUT_SCALE, 18),
      ("same_as_output_scale", _OpQuantConstraint.SAME_AS_OUTPUT_SCALE, 7),
      ("no_constrain", _OpQuantConstraint.NO_CONSTRAIN, 25),
  )
  def test_get_constrained_op_list(self, constraint, expected_num_ops):
    # Only the number of matching ops is pinned, not their identities.
    op_list = constrained_ops_utils.get_constrained_op_list(constraint)
    self.assertLen(op_list, expected_num_ops)
+
48
+
49
+ if __name__ == "__main__":
50
+ googletest.main()
@@ -18,6 +18,7 @@
18
18
  import inspect as _inspect
19
19
  import os.path as _os_path
20
20
  import sys as _sys
21
+ from typing import Optional, Union
21
22
 
22
23
  from absl.testing import parameterized
23
24
 
@@ -31,6 +32,40 @@ _OpName = qtyping.TFLOperationName
31
32
  _TensorQuantConfig = qtyping.TensorQuantizationConfig
32
33
  _OpQuantConfig = qtyping.OpQuantizationConfig
33
34
  _AlgorithmName = quantizer.AlgorithmName
35
# Scalar type accepted for each bound of a `min_max_range`.
_Numeric = Union[int, float]


# Default activation setting: 8-bit, asymmetric, tensor-wise granularity.
DEFAULT_ACTIVATION_QUANT_SETTING = _TensorQuantConfig(
    granularity=qtyping.QuantGranularity.TENSORWISE,
    symmetric=False,
    num_bits=8,
)
# Default weight setting: 8-bit, symmetric, channel-wise granularity.
DEFAULT_WEIGHT_QUANT_SETTING = _TensorQuantConfig(
    granularity=qtyping.QuantGranularity.CHANNELWISE,
    symmetric=True,
    num_bits=8,
)
48
+
49
+
50
def get_static_activation_quant_setting(
    num_bits: int, symmetric: bool
) -> _TensorQuantConfig:
  """Builds a tensor-wise activation quantization config.

  Args:
    num_bits: Bit width for the quantized activations.
    symmetric: Whether the activations use symmetric quantization.

  Returns:
    A `TensorQuantizationConfig` with `TENSORWISE` granularity.
  """
  tensorwise = qtyping.QuantGranularity.TENSORWISE
  return _TensorQuantConfig(
      granularity=tensorwise,
      symmetric=symmetric,
      num_bits=num_bits,
  )
58
+
59
+
60
def get_static_op_quant_config(
    activation_config: _TensorQuantConfig = DEFAULT_ACTIVATION_QUANT_SETTING,
    weight_config: _TensorQuantConfig = DEFAULT_WEIGHT_QUANT_SETTING,
) -> _OpQuantConfig:
  """Builds an integer-precision op config from activation/weight settings.

  Args:
    activation_config: Quantization config for the op's activations.
    weight_config: Quantization config for the op's weights.

  Returns:
    An `OpQuantizationConfig` whose compute precision is `INTEGER`.
  """
  integer_precision = _ComputePrecision.INTEGER
  return qtyping.OpQuantizationConfig(
      compute_precision=integer_precision,
      weight_tensor_config=weight_config,
      activation_tensor_config=activation_config,
  )
34
69
 
35
70
 
36
71
  def get_path_to_datafile(path):
@@ -64,7 +99,9 @@ class BaseOpTestCase(parameterized.TestCase):
64
99
  op_name: _OpName,
65
100
  op_config: _OpQuantConfig,
66
101
  num_validation_samples: int = 4,
102
+ num_calibration_samples: Optional[int] = None,
67
103
  error_metric: str = 'mse',
104
+ min_max_range: Optional[tuple[_Numeric, _Numeric]] = None,
68
105
  ) -> model_validator.ComparisonResult:
69
106
  """Quantizes and validates the given model with the given configurations.
70
107
 
@@ -74,7 +111,10 @@ class BaseOpTestCase(parameterized.TestCase):
74
111
  op_name: The name of the operation to be quantized.
75
112
  op_config: The configuration for the operation to be quantized.
76
113
  num_validation_samples: The number of samples to use for validation.
114
+ num_calibration_samples: The number of samples to use for calibration. If
115
+ None then it will be set to num_validation_samples * 8.
77
116
  error_metric: The error error_metric to use for validation.
117
+ min_max_range: The min and max of the input range.
78
118
 
79
119
  Returns:
80
120
  The comparison result of the validation.
@@ -87,15 +127,21 @@ class BaseOpTestCase(parameterized.TestCase):
87
127
  op_config=op_config,
88
128
  )
89
129
  if quantizer_instance.need_calibration:
130
+ if num_calibration_samples is None:
131
+ num_calibration_samples = num_validation_samples * 8
90
132
  calibration_data = tfl_interpreter_utils.create_random_normal_input_data(
91
- quantizer_instance.float_model, num_samples=num_validation_samples * 8
133
+ quantizer_instance.float_model,
134
+ num_samples=num_calibration_samples,
135
+ min_max_range=min_max_range,
92
136
  )
93
137
  calibration_result = quantizer_instance.calibrate(calibration_data)
94
138
  quantization_result = quantizer_instance.quantize(calibration_result)
95
139
  else:
96
140
  quantization_result = quantizer_instance.quantize()
97
141
  test_data = tfl_interpreter_utils.create_random_normal_input_data(
98
- quantization_result.quantized_model, num_samples=num_validation_samples
142
+ quantization_result.quantized_model,
143
+ num_samples=num_validation_samples,
144
+ min_max_range=min_max_range,
99
145
  )
100
146
  return quantizer_instance.validate(test_data, error_metric)
101
147
 
@@ -145,6 +191,7 @@ class BaseOpTestCase(parameterized.TestCase):
145
191
  expected_model_size_reduction: float,
146
192
  weight_tolerance: float = 1e-4,
147
193
  output_tolerance: float = 1e-4,
194
+ min_max_range: Optional[tuple[_Numeric, _Numeric]] = None,
148
195
  ):
149
196
  """Check if the quantization is successful and the result is valid."""
150
197
  validation_result = self.quantize_and_validate(
@@ -152,6 +199,7 @@ class BaseOpTestCase(parameterized.TestCase):
152
199
  algorithm_key=algorithm_key,
153
200
  op_name=op_name,
154
201
  op_config=op_config,
202
+ min_max_range=min_max_range,
155
203
  )
156
204
  with self.subTest(name='ModelSizeReduction'):
157
205
  self.assert_model_size_reduction_above_min_pct(
@@ -165,3 +213,28 @@ class BaseOpTestCase(parameterized.TestCase):
165
213
  self.assert_output_errors_below_tolerance(
166
214
  validation_result, output_tolerance
167
215
  )
216
+
217
def assert_quantization_accuracy(
    self,
    algorithm_key: _AlgorithmName,
    model_path: str,
    op_name: _OpName,
    op_config: _OpQuantConfig,
    num_validation_samples: int = 4,
    num_calibration_samples: Optional[int] = None,
    output_tolerance: float = 1e-4,
    min_max_range: Optional[tuple[_Numeric, _Numeric]] = None,
):
  """Quantizes the model and asserts its output errors are within tolerance.

  Only output errors are checked; model size and weight errors are not.

  Args:
    algorithm_key: The quantization algorithm to use.
    model_path: Path to the float model.
    op_name: The op to quantize.
    op_config: The quantization config for `op_name`.
    num_validation_samples: Number of samples used for validation.
    num_calibration_samples: Number of samples used for calibration; passed
      through to `quantize_and_validate`.
    output_tolerance: Maximum allowed output error.
    min_max_range: Optional (min, max) range for the generated inputs.
  """
  comparison = self.quantize_and_validate(
      model_path=model_path,
      algorithm_key=algorithm_key,
      num_validation_samples=num_validation_samples,
      num_calibration_samples=num_calibration_samples,
      op_name=op_name,
      op_config=op_config,
      min_max_range=min_max_range,
  )
  self.assert_output_errors_below_tolerance(comparison, output_tolerance)
@@ -20,10 +20,10 @@ from typing import Any, Optional, Union
20
20
  import immutabledict
21
21
  import numpy as np
22
22
 
23
+ from ai_edge_litert.tools import flatbuffer_utils
23
24
  from ai_edge_quantizer import qtyping
24
25
  from ai_edge_litert import schema_py_generated as schema # pylint:disable=g-direct-tensorflow-import
25
- from tensorflow.lite.tools import flatbuffer_utils # pylint: disable=g-direct-tensorflow-import
26
- from tensorflow.python.platform import gfile # pylint: disable=g-direct-tensorflow-import
26
+ import os # tensorflow.python.platform.gfile # pylint: disable=g-direct-tensorflow-import
27
27
 
28
28
  _TFLOpName = qtyping.TFLOperationName
29
29
 
@@ -51,11 +51,35 @@ TFL_OP_NAME_TO_CODE = immutabledict.immutabledict({
51
51
  _TFLOpName.LOGISTIC: schema.BuiltinOperator.LOGISTIC,
52
52
  _TFLOpName.SLICE: schema.BuiltinOperator.SLICE,
53
53
  _TFLOpName.SUM: schema.BuiltinOperator.SUM,
54
+ _TFLOpName.SELECT: schema.BuiltinOperator.SELECT,
54
55
  _TFLOpName.SELECT_V2: schema.BuiltinOperator.SELECT_V2,
55
56
  _TFLOpName.STABLEHLO_COMPOSITE: schema.BuiltinOperator.STABLEHLO_COMPOSITE,
56
57
  _TFLOpName.DYNAMIC_UPDATE_SLICE: (
57
58
  schema.BuiltinOperator.DYNAMIC_UPDATE_SLICE
58
59
  ),
60
+ _TFLOpName.PAD: schema.BuiltinOperator.PAD,
61
+ _TFLOpName.SQUARED_DIFFERENCE: schema.BuiltinOperator.SQUARED_DIFFERENCE,
62
+ _TFLOpName.MAX_POOL_2D: schema.BuiltinOperator.MAX_POOL_2D,
63
+ _TFLOpName.RESIZE_BILINEAR: schema.BuiltinOperator.RESIZE_BILINEAR,
64
+ _TFLOpName.RESIZE_NEAREST_NEIGHBOR: (
65
+ schema.BuiltinOperator.RESIZE_NEAREST_NEIGHBOR
66
+ ),
67
+ _TFLOpName.GATHER_ND: schema.BuiltinOperator.GATHER_ND,
68
+ _TFLOpName.PACK: schema.BuiltinOperator.PACK,
69
+ _TFLOpName.UNPACK: schema.BuiltinOperator.UNPACK,
70
+ _TFLOpName.DIV: schema.BuiltinOperator.DIV,
71
+ _TFLOpName.BROADCAST_TO: schema.BuiltinOperator.BROADCAST_TO,
72
+ _TFLOpName.SQRT: schema.BuiltinOperator.SQRT,
73
+ _TFLOpName.GATHER: schema.BuiltinOperator.GATHER,
74
+ _TFLOpName.HARD_SWISH: schema.BuiltinOperator.HARD_SWISH,
75
+ _TFLOpName.MAXIMUM: schema.BuiltinOperator.MAXIMUM,
76
+ _TFLOpName.PADV2: schema.BuiltinOperator.PADV2,
77
+ _TFLOpName.REDUCE_MIN: schema.BuiltinOperator.REDUCE_MIN,
78
+ _TFLOpName.EQUAL: schema.BuiltinOperator.EQUAL,
79
+ _TFLOpName.NOT_EQUAL: schema.BuiltinOperator.NOT_EQUAL,
80
+ _TFLOpName.MIRROR_PAD: schema.BuiltinOperator.MIRROR_PAD,
81
+ _TFLOpName.SPACE_TO_DEPTH: schema.BuiltinOperator.SPACE_TO_DEPTH,
82
+ _TFLOpName.RELU: schema.BuiltinOperator.RELU,
59
83
  })
60
84
 
61
85
  TFL_OP_CODE_TO_NAME = immutabledict.immutabledict(
@@ -86,7 +110,7 @@ TENSOR_TYPE_TO_CODE = immutabledict.immutabledict(
86
110
  (reversed(item) for item in TENSOR_CODE_TO_TYPE.items())
87
111
  )
88
112
 
89
- # Expose functions in tensorflow.lite.tools.flatbuffer_utils
113
# Expose functions in ai_edge_litert.tools.flatbuffer_utils
write_model = flatbuffer_utils.write_model
91
115
 
92
116
 
@@ -121,7 +145,7 @@ def get_model_content(tflite_path: str) -> bytes:
121
145
  Returns:
122
146
  The model bytes.
123
147
  """
124
- with gfile.Open(tflite_path, "rb") as tflite_file:
148
+ with open(tflite_path, "rb") as tflite_file:
125
149
  return tflite_file.read()
126
150
 
127
151
 
@@ -134,7 +158,7 @@ def get_model_buffer(tflite_path: str) -> bytearray:
134
158
  Returns:
135
159
  model_buffer: the model buffer.
136
160
  """
137
- with gfile.Open(tflite_path, "rb") as tflite_file:
161
+ with open(tflite_path, "rb") as tflite_file:
138
162
  return bytearray(tflite_file.read())
139
163
 
140
164
 
@@ -187,7 +211,7 @@ def parse_fc_bmm_conv_tensors(
187
211
  return input_tensor, weight_tensor, bias_tensor, output_tensor
188
212
 
189
213
 
190
- # flatbuffer_model has Any type since tensorflow.lite.tools.flatbuffer_utils
214
+ # flatbuffer_model has Any type since litert.python.tools.flatbuffer_utils
191
215
  # is not type annotated.
192
216
  def buffer_to_tensors(flatbuffer_model: Any) -> dict[int, list[Any]]:
193
217
  """Returns a map from buffer id to tensors that use it."""
@@ -328,3 +352,12 @@ def get_op_side_effect_subgraphs(
328
352
  return [opts.decompositionSubgraphIndex]
329
353
  # Can add other nested ops here (control flow ops, etc).
330
354
  return []
355
+
356
+
357
def get_op_name_by_index(
    flatbuffer_model: Any, subgraph_id: int, op_index: int
) -> str:
  """Returns the op name of an operator in the flatbuffer model.

  Args:
    flatbuffer_model: The flatbuffer model (object API).
    subgraph_id: Index of the subgraph that contains the op.
    op_index: Index of the op within that subgraph.

  Returns:
    The entry of `TFL_OP_CODE_TO_NAME` for the op's builtin code.

  Raises:
    KeyError: If the op's builtin code has no entry in `TFL_OP_CODE_TO_NAME`.
  """
  op = flatbuffer_model.subgraphs[subgraph_id].operators[op_index]
  op_code = flatbuffer_model.operatorCodes[op.opcodeIndex]
  # Older models only populate `deprecatedBuiltinCode` and leave
  # `builtinCode` at its default of 0. Taking the larger of the two fields
  # follows the TFLite schema convention for recovering the real code.
  builtin_code = max(
      op_code.builtinCode, getattr(op_code, "deprecatedBuiltinCode", 0) or 0
  )
  try:
    return TFL_OP_CODE_TO_NAME[builtin_code]
  except KeyError:
    raise KeyError(
        f"Unsupported builtin op code {builtin_code} for op {op_index} in"
        f" subgraph {subgraph_id}."
    ) from None