ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +158 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
- ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
- ai_edge_quantizer/calibrator.py +11 -60
- ai_edge_quantizer/calibrator_test.py +4 -73
- ai_edge_quantizer/default_policy.py +61 -26
- ai_edge_quantizer/model_modifier.py +97 -7
- ai_edge_quantizer/model_modifier_test.py +81 -1
- ai_edge_quantizer/model_validator.py +31 -8
- ai_edge_quantizer/params_generator.py +17 -10
- ai_edge_quantizer/params_generator_test.py +2 -7
- ai_edge_quantizer/qtyping.py +86 -6
- ai_edge_quantizer/quantizer.py +166 -21
- ai_edge_quantizer/quantizer_test.py +284 -16
- ai_edge_quantizer/recipe.py +154 -42
- ai_edge_quantizer/recipe_manager.py +158 -1
- ai_edge_quantizer/recipe_manager_test.py +146 -32
- ai_edge_quantizer/recipe_test.py +93 -17
- ai_edge_quantizer/transformation_instruction_generator.py +118 -13
- ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
- ai_edge_quantizer/transformation_performer.py +55 -25
- ai_edge_quantizer/transformation_performer_test.py +127 -5
- ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
- ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
- ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
- ai_edge_quantizer/transformations/transformation_utils.py +129 -6
- ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
- ai_edge_quantizer/utils/calibration_utils.py +263 -1
- ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
- ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
- ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
- ai_edge_quantizer/utils/test_utils.py +75 -2
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
- ai_edge_quantizer/utils/validation_utils.py +114 -4
- ai_edge_quantizer/utils/validation_utils_test.py +80 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
- ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
- ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
- ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
|
@@ -14,11 +14,67 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
16
|
from absl.testing import parameterized
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
17
19
|
from tensorflow.python.platform import googletest
|
|
20
|
+
from ai_edge_quantizer import quantizer
|
|
18
21
|
from ai_edge_quantizer.utils import calibration_utils
|
|
22
|
+
from ai_edge_quantizer.utils import test_utils
|
|
23
|
+
from ai_edge_quantizer.utils import tfl_interpreter_utils
|
|
24
|
+
|
|
25
|
+
_RNG = np.random.default_rng(66)
|
|
26
|
+
|
|
27
|
+
_CALIBRATION_DATASET = {
|
|
28
|
+
"signature_1": [{
|
|
29
|
+
"cache_0": np.zeros(shape=(1, 100, 4, 4), dtype=np.float32),
|
|
30
|
+
"cache_1": np.zeros(shape=(1, 100, 4, 4), dtype=np.float32),
|
|
31
|
+
"positions": np.zeros(shape=(1, 100), dtype=np.int32),
|
|
32
|
+
"tokens": np.zeros(shape=(1, 100), dtype=np.int32),
|
|
33
|
+
}],
|
|
34
|
+
"signature_2": [{
|
|
35
|
+
"cache_0": _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
|
|
36
|
+
"cache_1": _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
|
|
37
|
+
"positions": (
|
|
38
|
+
_RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
|
|
39
|
+
),
|
|
40
|
+
"tokens": _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32),
|
|
41
|
+
}],
|
|
42
|
+
}
|
|
43
|
+
|
|
19
44
|
|
|
45
|
+
def _get_quant_parameters(
    quantized_model: bytes, signature_data: dict[str, list[str]]
) -> list[np.ndarray]:
  """Returns the quantization scales of selected signature tensors.

  Args:
    quantized_model: The serialized quantized TFLite model.
    signature_data: Maps a signature key to the list of input/output names
      (as exposed by that signature's runner) whose scales are collected.

  Returns:
    The squeezed `scales` array of each requested tensor, ordered by the
    iteration order of `signature_data` and then of each name list.

  Raises:
    ValueError: If a name is neither an input nor an output of its signature.
  """
  quant_params = []
  tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
      quantized_model
  )
  for signature_key, signature_names in signature_data.items():
    signature_runner = tfl_interpreter.get_signature_runner(signature_key)
    # The details depend only on the signature, not on the requested name, so
    # fetch them once per signature instead of once (or twice) per name.
    input_details = signature_runner.get_input_details()
    output_details = signature_runner.get_output_details()
    for signature_name in signature_names:
      # Inputs take precedence over outputs, as in the lookup order below.
      if signature_name in input_details:
        tensor_details = input_details[signature_name]
      elif signature_name in output_details:
        tensor_details = output_details[signature_name]
      else:
        raise ValueError(
            f"Signature name {signature_name} not found in the model."
        )
      quant_params.append(
          tensor_details["quantization_parameters"]["scales"].squeeze()
      )
  return quant_params
|
|
20
75
|
|
|
21
|
-
|
|
76
|
+
|
|
77
|
+
class CalibrationQsvAlignmentUtilsTest(parameterized.TestCase):
|
|
22
78
|
|
|
23
79
|
@parameterized.named_parameters(
|
|
24
80
|
dict(
|
|
@@ -66,12 +122,126 @@ class CalibrationUtilsTest(parameterized.TestCase):
|
|
|
66
122
|
def test_update_tensor_qsv_min_max(self, old_qsv, new_qsv, expected_qsv):
|
|
67
123
|
updated_qsv = calibration_utils.min_max_update(old_qsv, new_qsv)
|
|
68
124
|
if isinstance(expected_qsv["min"], list):
|
|
69
|
-
self.
|
|
70
|
-
self.
|
|
125
|
+
self.assertEqual(list(updated_qsv["min"]), expected_qsv["min"])
|
|
126
|
+
self.assertEqual(list(updated_qsv["max"]), expected_qsv["max"])
|
|
71
127
|
else:
|
|
72
128
|
self.assertEqual(updated_qsv["min"], expected_qsv["min"])
|
|
73
129
|
self.assertEqual(updated_qsv["max"], expected_qsv["max"])
|
|
74
130
|
|
|
131
|
+
def test_calibration_utils_init_fails(self):
  """Constructing the utils from a missing file raises an error naming it."""
  missing_path = "non_existent_model.tflite"

  def _mentions_missing_path(err) -> bool:
    return missing_path in str(err)

  with self.assertRaisesWithPredicateMatch(Exception, _mentions_missing_path):
    calibration_utils.CalibrationQsvAlignmentUtils(missing_path)
|
|
137
|
+
|
|
138
|
+
def test_calibration_utils_init_succeeds(self):
  """A valid model populates the signature runners and constrained-op set."""
  utils_instance = calibration_utils.CalibrationQsvAlignmentUtils(
      test_utils.get_path_to_datafile("../tests/models/single_add.tflite")
  )
  self.assertNotEmpty(utils_instance._signature_runners)
  self.assertNotEmpty(utils_instance._same_as_input_scale_ops)
|
|
145
|
+
|
|
146
|
+
def test_search_tensor_by_signature_name_succeeds_on_unconstrained_op(self):
  """Searching the `add` signature output resolves to its own tensor."""
  utils_instance = calibration_utils.CalibrationQsvAlignmentUtils(
      test_utils.get_path_to_datafile("../tests/models/single_add.tflite")
  )
  found_tensors = utils_instance._search_tensor_by_signature_name(
      "serving_default", "add"
  )
  self.assertEqual(found_tensors, ["PartitionedCall:0"])
|
|
156
|
+
|
|
157
|
+
def test_search_tensor_by_signature_name_succeeds_on_constrained_op(self):
  """For the slice model, the output search resolves to the input tensor."""
  utils_instance = calibration_utils.CalibrationQsvAlignmentUtils(
      test_utils.get_path_to_datafile("../tests/models/single_slice.tflite")
  )
  found_tensors = utils_instance._search_tensor_by_signature_name(
      "slice", "output_0"
  )
  self.assertEqual(found_tensors, ["slice_input_tensor:0"])
|
|
167
|
+
|
|
168
|
+
def test_align_quant_stats_succeeds(self):
  """Aligning stats across two signatures makes their output scales match."""
  model_path = test_utils.get_path_to_datafile(
      "../tests/models/toy_model_with_kv_cache_multi_signature.tflite"
  )
  recipe_path = test_utils.get_path_to_datafile(
      "../recipes/default_a8w8_recipe.json"
  )
  signature_data = {
      "signature_1": ["output_1_1"],
      "signature_2": ["output_1_1"],
  }

  # Calibrate once; `qsv` is then read by every quantization below.
  qt = quantizer.Quantizer(model_path, recipe_path)
  qsv = qt.calibrate(_CALIBRATION_DATASET)

  def scales_all_equal() -> bool:
    # Re-quantizes from the current `qsv` and compares all collected scales.
    scales = _get_quant_parameters(
        qt.quantize(qsv).quantized_model, signature_data
    )
    reference = scales[0]
    return all(scale == reference for scale in scales)

  # Before alignment the signatures have differing quantization params.
  self.assertFalse(scales_all_equal())

  # Align (this test relies on `qsv` being updated in place), then re-check.
  aligner = calibration_utils.CalibrationQsvAlignmentUtils(model_path)
  aligner.align_quant_stats(qsv, signature_data)
  self.assertTrue(scales_all_equal())
|
|
199
|
+
|
|
200
|
+
def test_update_quant_stats_succeeds(self):
  """Copying signature_2's min/max into signature_1 makes the scales match."""
  model_path = test_utils.get_path_to_datafile(
      "../tests/models/toy_model_with_kv_cache_multi_signature.tflite"
  )
  recipe_path = test_utils.get_path_to_datafile(
      "../recipes/default_a8w8_recipe.json"
  )
  signature_data = {
      "signature_1": ["output_1_1"],
      "signature_2": ["output_1_1"],
  }

  # Calibrate once; `qsv` is then read by every quantization below.
  qt = quantizer.Quantizer(model_path, recipe_path)
  qsv = qt.calibrate(_CALIBRATION_DATASET)

  def scales_all_equal() -> bool:
    # Re-quantizes from the current `qsv` and compares all collected scales.
    scales = _get_quant_parameters(
        qt.quantize(qsv).quantized_model, signature_data
    )
    reference = scales[0]
    return all(scale == reference for scale in scales)

  # Before updating `signature_1` the quantization params differ.
  self.assertFalse(scales_all_equal())

  aligner = calibration_utils.CalibrationQsvAlignmentUtils(model_path)
  # Used here only for the (min, max) pair it returns for `signature_2`.
  source_min, source_max = aligner.align_quant_stats(
      qsv, {"signature_2": ["output_1_1"]}
  )
  aligner.update_quant_stats(
      qsv, {"signature_1": ["output_1_1"]}, source_min, source_max
  )
  self.assertTrue(scales_all_equal())
|
|
244
|
+
|
|
75
245
|
|
|
76
246
|
if __name__ == "__main__":
|
|
77
247
|
googletest.main()
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Utils for handling operators with quantization constraints."""
|
|
17
|
+
|
|
18
|
+
from ai_edge_quantizer import algorithm_manager
|
|
19
|
+
from ai_edge_quantizer import qtyping
|
|
20
|
+
from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
|
|
21
|
+
from ai_edge_quantizer.algorithms.utils import common_utils
|
|
22
|
+
from ai_edge_litert import schema_py_generated as schema_fb # pylint: disable=g-direct-tensorflow-import
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
_OpQuantConstraint = common_utils.OpQuantConstraint
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_constrained_op_list(
    quant_constraint: _OpQuantConstraint, verbose: bool = False
) -> list[str]:
  """Constructs and returns a list of constrained operators.

  Every function in `algorithm_manager.MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT`
  is invoked on a dummy op while `common_utils.materialize_standard_op` is
  monkey patched with a wrapper. The wrapper records the op name whenever the
  `constraint` keyword it receives equals `quant_constraint`. Two helpers in
  `common_quantize` are patched with no-op stand-ins so the dummy invocation
  does not fail. All patched attributes are restored before returning, even if
  a materialization function raises.

  Args:
    quant_constraint: The quantization constraint to filter operators by.
    verbose: Flag to enable verbose output.

  Returns:
    A list containing operators with the specified constraint.
  """
  constrained_ops = []

  def materialize_standard_op_wrapper(
      op_info: qtyping.OpInfo,
      *_args,
      constraint: _OpQuantConstraint = _OpQuantConstraint.NO_CONSTRAIN,
      **_kwargs,
  ) -> list[qtyping.TensorTransformationParams]:
    if constraint == quant_constraint:
      constrained_ops.append(op_info.op_name)
    # Return dummy values to avoid exceptions.
    dummy_value = [qtyping.TensorTransformationParams("")] * 2
    return dummy_value

  # Dummy implementation of the `_are_weights_too_small` function to support
  # `materialize_standard_op_wrapper` above.
  def are_weights_too_small_wrapper(*_args, **_kwargs) -> bool:
    return False

  # Dummy implementation of the `_materialize_bias_for_fc_conv_ops` function to
  # support `materialize_standard_op_wrapper` above.
  def materialize_bias_for_fc_conv_ops_wrapper(*_args, **_kwargs):
    return

  # Save the originals before patching so they can always be restored.
  original_materialize_standard_op = common_utils.materialize_standard_op
  original_are_weights_too_small = common_quantize._are_weights_too_small  # pylint: disable=protected-access
  original_materialize_bias_for_fc_conv_ops = (
      common_quantize._materialize_bias_for_fc_conv_ops  # pylint: disable=protected-access
  )
  try:
    # Monkey patch to intercept `materialize_standard_op` and bypass helpers
    # that would need real tensors/graph info.
    common_utils.materialize_standard_op = materialize_standard_op_wrapper
    common_quantize._are_weights_too_small = are_weights_too_small_wrapper  # pylint: disable=protected-access
    common_quantize._materialize_bias_for_fc_conv_ops = (  # pylint: disable=protected-access
        materialize_bias_for_fc_conv_ops_wrapper
    )
    minmax_func_dict = algorithm_manager.MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT

    # Loop over all available materialization functions to build up a list of
    # ops with the given constraint.
    for op, materialize_fn in minmax_func_dict.items():
      # Create a dummy op info to trigger the materialization.
      mock_op = schema_fb.OperatorT()
      mock_op.inputs = [0]
      mock_op.outputs = [0]
      op_info = qtyping.OpInfo(
          op=mock_op,
          op_name=op,
          subgraph_op_index=0,
          op_quant_config=qtyping.OpQuantizationConfig(),
      )
      materialize_fn(
          get_tensor_quant_params_fn=None,
          op_info=op_info,
          graph_info=None,
          tensor_name_to_qsv=None,
      )
  finally:
    # Restore the original functions so the patches never leak out of this
    # call, including when a materialization function raises.
    common_utils.materialize_standard_op = original_materialize_standard_op
    common_quantize._are_weights_too_small = original_are_weights_too_small  # pylint: disable=protected-access
    common_quantize._materialize_bias_for_fc_conv_ops = (  # pylint: disable=protected-access
        original_materialize_bias_for_fc_conv_ops
    )

  if verbose:
    print(f" {quant_constraint} op list: {constrained_ops}")

  return constrained_ops
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from tensorflow.python.platform import googletest
|
|
17
|
+
from absl.testing import parameterized
|
|
18
|
+
from ai_edge_quantizer.algorithms.utils import common_utils
|
|
19
|
+
from ai_edge_quantizer.utils import constrained_ops_utils
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
_OpQuantConstraint = common_utils.OpQuantConstraint
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ConstrainedOpsUtilsTest(parameterized.TestCase):
  """Checks the number of ops reported for each quantization constraint."""

  @parameterized.named_parameters(
      dict(
          testcase_name="same_as_input_scale",
          constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
          expected_num_ops=18,
      ),
      dict(
          testcase_name="same_as_output_scale",
          constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
          expected_num_ops=7,
      ),
      dict(
          testcase_name="no_constrain",
          constraint=_OpQuantConstraint.NO_CONSTRAIN,
          expected_num_ops=25,
      ),
  )
  def test_get_constrained_op_list(self, constraint, expected_num_ops):
    self.assertLen(
        constrained_ops_utils.get_constrained_op_list(constraint),
        expected_num_ops,
    )
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
if __name__ == "__main__":
|
|
50
|
+
googletest.main()
|
|
@@ -18,6 +18,7 @@
|
|
|
18
18
|
import inspect as _inspect
|
|
19
19
|
import os.path as _os_path
|
|
20
20
|
import sys as _sys
|
|
21
|
+
from typing import Optional, Union
|
|
21
22
|
|
|
22
23
|
from absl.testing import parameterized
|
|
23
24
|
|
|
@@ -31,6 +32,40 @@ _OpName = qtyping.TFLOperationName
|
|
|
31
32
|
_TensorQuantConfig = qtyping.TensorQuantizationConfig
|
|
32
33
|
_OpQuantConfig = qtyping.OpQuantizationConfig
|
|
33
34
|
_AlgorithmName = quantizer.AlgorithmName
|
|
35
|
+
_Numeric = Union[int, float]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
DEFAULT_ACTIVATION_QUANT_SETTING = _TensorQuantConfig(
|
|
39
|
+
num_bits=8,
|
|
40
|
+
symmetric=False,
|
|
41
|
+
granularity=qtyping.QuantGranularity.TENSORWISE,
|
|
42
|
+
)
|
|
43
|
+
DEFAULT_WEIGHT_QUANT_SETTING = _TensorQuantConfig(
|
|
44
|
+
num_bits=8,
|
|
45
|
+
symmetric=True,
|
|
46
|
+
granularity=qtyping.QuantGranularity.CHANNELWISE,
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def get_static_activation_quant_setting(
    num_bits: int, symmetric: bool
) -> _TensorQuantConfig:
  """Builds a tensor-wise activation quantization config.

  Args:
    num_bits: Bit width of the quantized activations.
    symmetric: Whether the quantization range is symmetric.

  Returns:
    A `_TensorQuantConfig` with TENSORWISE granularity.
  """
  tensorwise = qtyping.QuantGranularity.TENSORWISE
  return _TensorQuantConfig(
      granularity=tensorwise, num_bits=num_bits, symmetric=symmetric
  )
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def get_static_op_quant_config(
    activation_config: _TensorQuantConfig = DEFAULT_ACTIVATION_QUANT_SETTING,
    weight_config: _TensorQuantConfig = DEFAULT_WEIGHT_QUANT_SETTING,
) -> _OpQuantConfig:
  """Builds an integer-compute op config from activation/weight configs.

  Args:
    activation_config: Quantization config applied to activation tensors.
    weight_config: Quantization config applied to weight tensors.

  Returns:
    An `OpQuantizationConfig` with INTEGER compute precision.
  """
  config_kwargs = dict(
      activation_tensor_config=activation_config,
      weight_tensor_config=weight_config,
      compute_precision=_ComputePrecision.INTEGER,
  )
  return qtyping.OpQuantizationConfig(**config_kwargs)
|
|
34
69
|
|
|
35
70
|
|
|
36
71
|
def get_path_to_datafile(path):
|
|
@@ -64,7 +99,9 @@ class BaseOpTestCase(parameterized.TestCase):
|
|
|
64
99
|
op_name: _OpName,
|
|
65
100
|
op_config: _OpQuantConfig,
|
|
66
101
|
num_validation_samples: int = 4,
|
|
102
|
+
num_calibration_samples: Optional[int] = None,
|
|
67
103
|
error_metric: str = 'mse',
|
|
104
|
+
min_max_range: Optional[tuple[_Numeric, _Numeric]] = None,
|
|
68
105
|
) -> model_validator.ComparisonResult:
|
|
69
106
|
"""Quantizes and validates the given model with the given configurations.
|
|
70
107
|
|
|
@@ -74,7 +111,10 @@ class BaseOpTestCase(parameterized.TestCase):
|
|
|
74
111
|
op_name: The name of the operation to be quantized.
|
|
75
112
|
op_config: The configuration for the operation to be quantized.
|
|
76
113
|
num_validation_samples: The number of samples to use for validation.
|
|
114
|
+
num_calibration_samples: The number of samples to use for calibration. If
|
|
115
|
+
None then it will be set to num_validation_samples * 8.
|
|
77
116
|
error_metric: The error error_metric to use for validation.
|
|
117
|
+
min_max_range: The min and max of the input range.
|
|
78
118
|
|
|
79
119
|
Returns:
|
|
80
120
|
The comparison result of the validation.
|
|
@@ -87,15 +127,21 @@ class BaseOpTestCase(parameterized.TestCase):
|
|
|
87
127
|
op_config=op_config,
|
|
88
128
|
)
|
|
89
129
|
if quantizer_instance.need_calibration:
|
|
130
|
+
if num_calibration_samples is None:
|
|
131
|
+
num_calibration_samples = num_validation_samples * 8
|
|
90
132
|
calibration_data = tfl_interpreter_utils.create_random_normal_input_data(
|
|
91
|
-
quantizer_instance.float_model,
|
|
133
|
+
quantizer_instance.float_model,
|
|
134
|
+
num_samples=num_calibration_samples,
|
|
135
|
+
min_max_range=min_max_range,
|
|
92
136
|
)
|
|
93
137
|
calibration_result = quantizer_instance.calibrate(calibration_data)
|
|
94
138
|
quantization_result = quantizer_instance.quantize(calibration_result)
|
|
95
139
|
else:
|
|
96
140
|
quantization_result = quantizer_instance.quantize()
|
|
97
141
|
test_data = tfl_interpreter_utils.create_random_normal_input_data(
|
|
98
|
-
quantization_result.quantized_model,
|
|
142
|
+
quantization_result.quantized_model,
|
|
143
|
+
num_samples=num_validation_samples,
|
|
144
|
+
min_max_range=min_max_range,
|
|
99
145
|
)
|
|
100
146
|
return quantizer_instance.validate(test_data, error_metric)
|
|
101
147
|
|
|
@@ -145,6 +191,7 @@ class BaseOpTestCase(parameterized.TestCase):
|
|
|
145
191
|
expected_model_size_reduction: float,
|
|
146
192
|
weight_tolerance: float = 1e-4,
|
|
147
193
|
output_tolerance: float = 1e-4,
|
|
194
|
+
min_max_range: Optional[tuple[_Numeric, _Numeric]] = None,
|
|
148
195
|
):
|
|
149
196
|
"""Check if the quantization is successful and the result is valid."""
|
|
150
197
|
validation_result = self.quantize_and_validate(
|
|
@@ -152,6 +199,7 @@ class BaseOpTestCase(parameterized.TestCase):
|
|
|
152
199
|
algorithm_key=algorithm_key,
|
|
153
200
|
op_name=op_name,
|
|
154
201
|
op_config=op_config,
|
|
202
|
+
min_max_range=min_max_range,
|
|
155
203
|
)
|
|
156
204
|
with self.subTest(name='ModelSizeReduction'):
|
|
157
205
|
self.assert_model_size_reduction_above_min_pct(
|
|
@@ -165,3 +213,28 @@ class BaseOpTestCase(parameterized.TestCase):
|
|
|
165
213
|
self.assert_output_errors_below_tolerance(
|
|
166
214
|
validation_result, output_tolerance
|
|
167
215
|
)
|
|
216
|
+
|
|
217
|
+
def assert_quantization_accuracy(
    self,
    algorithm_key: _AlgorithmName,
    model_path: str,
    op_name: _OpName,
    op_config: _OpQuantConfig,
    num_validation_samples: int = 4,
    num_calibration_samples: Optional[int] = None,
    output_tolerance: float = 1e-4,
    min_max_range: Optional[tuple[_Numeric, _Numeric]] = None,
):
  """Checks if the output errors after quantization are within the tolerance.

  Args:
    algorithm_key: The quantization algorithm to use.
    model_path: Path to the float model.
    op_name: The name of the operation to be quantized.
    op_config: The configuration for the operation to be quantized.
    num_validation_samples: The number of samples to use for validation.
    num_calibration_samples: The number of samples to use for calibration;
      forwarded unchanged to `quantize_and_validate`.
    output_tolerance: Maximum allowed output error.
    min_max_range: The min and max of the input range.
  """
  forwarded_kwargs = dict(
      model_path=model_path,
      algorithm_key=algorithm_key,
      num_validation_samples=num_validation_samples,
      num_calibration_samples=num_calibration_samples,
      op_name=op_name,
      op_config=op_config,
      min_max_range=min_max_range,
  )
  result = self.quantize_and_validate(**forwarded_kwargs)
  self.assert_output_errors_below_tolerance(result, output_tolerance)
|
|
@@ -20,10 +20,10 @@ from typing import Any, Optional, Union
|
|
|
20
20
|
import immutabledict
|
|
21
21
|
import numpy as np
|
|
22
22
|
|
|
23
|
+
from ai_edge_litert.tools import flatbuffer_utils
|
|
23
24
|
from ai_edge_quantizer import qtyping
|
|
24
25
|
from ai_edge_litert import schema_py_generated as schema # pylint:disable=g-direct-tensorflow-import
|
|
25
|
-
|
|
26
|
-
from tensorflow.python.platform import gfile # pylint: disable=g-direct-tensorflow-import
|
|
26
|
+
import os # tensorflow.python.platform.gfile # pylint: disable=g-direct-tensorflow-import
|
|
27
27
|
|
|
28
28
|
_TFLOpName = qtyping.TFLOperationName
|
|
29
29
|
|
|
@@ -51,11 +51,35 @@ TFL_OP_NAME_TO_CODE = immutabledict.immutabledict({
|
|
|
51
51
|
_TFLOpName.LOGISTIC: schema.BuiltinOperator.LOGISTIC,
|
|
52
52
|
_TFLOpName.SLICE: schema.BuiltinOperator.SLICE,
|
|
53
53
|
_TFLOpName.SUM: schema.BuiltinOperator.SUM,
|
|
54
|
+
_TFLOpName.SELECT: schema.BuiltinOperator.SELECT,
|
|
54
55
|
_TFLOpName.SELECT_V2: schema.BuiltinOperator.SELECT_V2,
|
|
55
56
|
_TFLOpName.STABLEHLO_COMPOSITE: schema.BuiltinOperator.STABLEHLO_COMPOSITE,
|
|
56
57
|
_TFLOpName.DYNAMIC_UPDATE_SLICE: (
|
|
57
58
|
schema.BuiltinOperator.DYNAMIC_UPDATE_SLICE
|
|
58
59
|
),
|
|
60
|
+
_TFLOpName.PAD: schema.BuiltinOperator.PAD,
|
|
61
|
+
_TFLOpName.SQUARED_DIFFERENCE: schema.BuiltinOperator.SQUARED_DIFFERENCE,
|
|
62
|
+
_TFLOpName.MAX_POOL_2D: schema.BuiltinOperator.MAX_POOL_2D,
|
|
63
|
+
_TFLOpName.RESIZE_BILINEAR: schema.BuiltinOperator.RESIZE_BILINEAR,
|
|
64
|
+
_TFLOpName.RESIZE_NEAREST_NEIGHBOR: (
|
|
65
|
+
schema.BuiltinOperator.RESIZE_NEAREST_NEIGHBOR
|
|
66
|
+
),
|
|
67
|
+
_TFLOpName.GATHER_ND: schema.BuiltinOperator.GATHER_ND,
|
|
68
|
+
_TFLOpName.PACK: schema.BuiltinOperator.PACK,
|
|
69
|
+
_TFLOpName.UNPACK: schema.BuiltinOperator.UNPACK,
|
|
70
|
+
_TFLOpName.DIV: schema.BuiltinOperator.DIV,
|
|
71
|
+
_TFLOpName.BROADCAST_TO: schema.BuiltinOperator.BROADCAST_TO,
|
|
72
|
+
_TFLOpName.SQRT: schema.BuiltinOperator.SQRT,
|
|
73
|
+
_TFLOpName.GATHER: schema.BuiltinOperator.GATHER,
|
|
74
|
+
_TFLOpName.HARD_SWISH: schema.BuiltinOperator.HARD_SWISH,
|
|
75
|
+
_TFLOpName.MAXIMUM: schema.BuiltinOperator.MAXIMUM,
|
|
76
|
+
_TFLOpName.PADV2: schema.BuiltinOperator.PADV2,
|
|
77
|
+
_TFLOpName.REDUCE_MIN: schema.BuiltinOperator.REDUCE_MIN,
|
|
78
|
+
_TFLOpName.EQUAL: schema.BuiltinOperator.EQUAL,
|
|
79
|
+
_TFLOpName.NOT_EQUAL: schema.BuiltinOperator.NOT_EQUAL,
|
|
80
|
+
_TFLOpName.MIRROR_PAD: schema.BuiltinOperator.MIRROR_PAD,
|
|
81
|
+
_TFLOpName.SPACE_TO_DEPTH: schema.BuiltinOperator.SPACE_TO_DEPTH,
|
|
82
|
+
_TFLOpName.RELU: schema.BuiltinOperator.RELU,
|
|
59
83
|
})
|
|
60
84
|
|
|
61
85
|
TFL_OP_CODE_TO_NAME = immutabledict.immutabledict(
|
|
@@ -86,7 +110,7 @@ TENSOR_TYPE_TO_CODE = immutabledict.immutabledict(
|
|
|
86
110
|
(reversed(item) for item in TENSOR_CODE_TO_TYPE.items())
|
|
87
111
|
)
|
|
88
112
|
|
|
89
|
-
# Expose functions in
|
|
113
|
+
# Expose functions in litert.python.tools.flatbuffer_utils
|
|
90
114
|
write_model = flatbuffer_utils.write_model
|
|
91
115
|
|
|
92
116
|
|
|
@@ -121,7 +145,7 @@ def get_model_content(tflite_path: str) -> bytes:
|
|
|
121
145
|
Returns:
|
|
122
146
|
The model bytes.
|
|
123
147
|
"""
|
|
124
|
-
with
|
|
148
|
+
with open(tflite_path, "rb") as tflite_file:
|
|
125
149
|
return tflite_file.read()
|
|
126
150
|
|
|
127
151
|
|
|
@@ -134,7 +158,7 @@ def get_model_buffer(tflite_path: str) -> bytearray:
|
|
|
134
158
|
Returns:
|
|
135
159
|
model_buffer: the model buffer.
|
|
136
160
|
"""
|
|
137
|
-
with
|
|
161
|
+
with open(tflite_path, "rb") as tflite_file:
|
|
138
162
|
return bytearray(tflite_file.read())
|
|
139
163
|
|
|
140
164
|
|
|
@@ -187,7 +211,7 @@ def parse_fc_bmm_conv_tensors(
|
|
|
187
211
|
return input_tensor, weight_tensor, bias_tensor, output_tensor
|
|
188
212
|
|
|
189
213
|
|
|
190
|
-
# flatbuffer_model has Any type since
|
|
214
|
+
# flatbuffer_model has Any type since litert.python.tools.flatbuffer_utils
|
|
191
215
|
# is not type annotated.
|
|
192
216
|
def buffer_to_tensors(flatbuffer_model: Any) -> dict[int, list[Any]]:
|
|
193
217
|
"""Returns a map from buffer id to tensors that use it."""
|
|
@@ -328,3 +352,12 @@ def get_op_side_effect_subgraphs(
|
|
|
328
352
|
return [opts.decompositionSubgraphIndex]
|
|
329
353
|
# Can add other nested ops here (control flow ops, etc).
|
|
330
354
|
return []
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
def get_op_name_by_index(
    flatbuffer_model: Any, subgraph_id: int, op_index: int
) -> str:
  """Returns the op name of an operator located by subgraph and op index.

  The operator's opcode index selects an entry of `operatorCodes`, whose
  `builtinCode` is translated via `TFL_OP_CODE_TO_NAME` (a KeyError is raised
  for builtin codes missing from that map).
  """
  subgraph = flatbuffer_model.subgraphs[subgraph_id]
  opcode_index = subgraph.operators[op_index].opcodeIndex
  operator_code = flatbuffer_model.operatorCodes[opcode_index]
  return TFL_OP_CODE_TO_NAME[operator_code.builtinCode]
|