ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +158 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
- ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
- ai_edge_quantizer/calibrator.py +11 -60
- ai_edge_quantizer/calibrator_test.py +4 -73
- ai_edge_quantizer/default_policy.py +61 -26
- ai_edge_quantizer/model_modifier.py +97 -7
- ai_edge_quantizer/model_modifier_test.py +81 -1
- ai_edge_quantizer/model_validator.py +31 -8
- ai_edge_quantizer/params_generator.py +17 -10
- ai_edge_quantizer/params_generator_test.py +2 -7
- ai_edge_quantizer/qtyping.py +86 -6
- ai_edge_quantizer/quantizer.py +166 -21
- ai_edge_quantizer/quantizer_test.py +284 -16
- ai_edge_quantizer/recipe.py +154 -42
- ai_edge_quantizer/recipe_manager.py +158 -1
- ai_edge_quantizer/recipe_manager_test.py +146 -32
- ai_edge_quantizer/recipe_test.py +93 -17
- ai_edge_quantizer/transformation_instruction_generator.py +118 -13
- ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
- ai_edge_quantizer/transformation_performer.py +55 -25
- ai_edge_quantizer/transformation_performer_test.py +127 -5
- ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
- ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
- ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
- ai_edge_quantizer/transformations/transformation_utils.py +129 -6
- ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
- ai_edge_quantizer/utils/calibration_utils.py +263 -1
- ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
- ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
- ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
- ai_edge_quantizer/utils/test_utils.py +75 -2
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
- ai_edge_quantizer/utils/validation_utils.py +114 -4
- ai_edge_quantizer/utils/validation_utils_test.py +80 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
- ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
- ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
- ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/calibrator_test.py

@@ -103,58 +103,6 @@ class CalibratorTest(googletest.TestCase):
     model_tensor_qsvs = self._calibrator.get_model_qsvs()
     self.assertEmpty(model_tensor_qsvs)
 
-  def test_calibrator_initialize_qsv(self):
-    _add_default_int8xint8_integer_recipe(self._recipe_manager)
-    # Overwrite the single op to fc
-    self._recipe_manager.add_quantization_config(
-        regex=".*Stateful.*",
-        operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
-        algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
-        op_config=qtyping.OpQuantizationConfig(
-            weight_tensor_config=_TENSOR_QUANT_CONFIG(
-                num_bits=4,
-                granularity=qtyping.QuantGranularity.CHANNELWISE,
-            ),
-            compute_precision=_ComputePrecision.INTEGER,
-        ),
-    )
-    self._calibrator._initialize_model_qsvs(self._recipe_manager)
-    model_tensor_qsvs = self._calibrator.get_model_qsvs()
-
-    self.assertLen(model_tensor_qsvs, 4)
-    self.assertIn("serving_default_input_1:0", model_tensor_qsvs)  # input
-    input_qsv = model_tensor_qsvs["serving_default_input_1:0"]
-    self.assertEmpty(input_qsv)
-
-    self.assertIn("sequential/dense/MatMul", model_tensor_qsvs)  # weight
-    weight_tensor_qsv = model_tensor_qsvs["sequential/dense/MatMul"]
-    mins_maxs_shape = (16, 1)
-    self.assertTupleEqual(weight_tensor_qsv["min"].shape, mins_maxs_shape)
-    self.assertAlmostEqual(weight_tensor_qsv["min"][0][0], -0.40436327)
-    self.assertTupleEqual(weight_tensor_qsv["max"].shape, mins_maxs_shape)
-    self.assertAlmostEqual(weight_tensor_qsv["max"][0][0], 0.46138108)
-
-    self.assertIn(
-        "sequential/dense/BiasAdd/ReadVariableOp", model_tensor_qsvs
-    )  # bias
-    bias_tensor_qsv = model_tensor_qsvs[
-        "sequential/dense/BiasAdd/ReadVariableOp"
-    ]
-    mins_maxs_shape = (16,)
-    self.assertTupleEqual(bias_tensor_qsv["min"].shape, mins_maxs_shape)
-    self.assertAlmostEqual(bias_tensor_qsv["min"][0], -0.26978338)
-    self.assertTupleEqual(bias_tensor_qsv["max"].shape, mins_maxs_shape)
-    # Here bias min/max will be the same as each element is a scalar
-    # Bias will be quantized with input_scale * weight_scale.
-    self.assertSequenceEqual(
-        list(bias_tensor_qsv["max"].flatten()),
-        list(bias_tensor_qsv["min"].flatten()),
-    )
-
-    self.assertIn("StatefulPartitionedCall:0", model_tensor_qsvs)  # output
-    output_qsv = model_tensor_qsvs["StatefulPartitionedCall:0"]
-    self.assertEmpty(output_qsv)
-
   def test_calibrate_single_fc_success(self):
     _add_default_int8xint8_integer_recipe(self._recipe_manager)
     self._calibrator.calibrate(
@@ -162,7 +110,7 @@ class CalibratorTest(googletest.TestCase):
     )
     model_tensor_qsvs = self._calibrator.get_model_qsvs()
 
-    self.assertLen(model_tensor_qsvs,
+    self.assertLen(model_tensor_qsvs, 2)
     self.assertIn("serving_default_input_1:0", model_tensor_qsvs)  # input
     input_qsv = model_tensor_qsvs["serving_default_input_1:0"]
     self.assertSequenceAlmostEqual(
@@ -171,19 +119,6 @@ class CalibratorTest(googletest.TestCase):
     self.assertSequenceAlmostEqual(
         input_qsv["max"].flatten(), [TEST_MAX_VAL], delta=1e-5
     )
-
-    self.assertIn("sequential/dense/MatMul", model_tensor_qsvs)  # weight
-    weight_qsv = model_tensor_qsvs["sequential/dense/MatMul"]
-    self.assertSequenceAlmostEqual(weight_qsv["min"].flatten(), [-0.49114203])
-    self.assertSequenceAlmostEqual(weight_qsv["max"].flatten(), [0.4903704])
-
-    self.assertIn(
-        "sequential/dense/BiasAdd/ReadVariableOp", model_tensor_qsvs
-    )  # bias
-    bias_qsv = model_tensor_qsvs["sequential/dense/BiasAdd/ReadVariableOp"]
-    self.assertSequenceAlmostEqual(bias_qsv["min"].flatten(), [-0.38401994])
-    self.assertSequenceAlmostEqual(bias_qsv["max"].flatten(), [0.31727126])
-
     self.assertIn("StatefulPartitionedCall:0", model_tensor_qsvs)  # output
     output_qsv = model_tensor_qsvs["StatefulPartitionedCall:0"]
     # Relu, only check the min
@@ -249,15 +184,11 @@ class CalibratorAlreadyQuantizedModelTest(googletest.TestCase):
     )
     _ = calibrator.Calibrator(test_model_path)
 
-  def
+  def test_check_is_quantized_model_succeeds_when_model_is_quantized(self):
     test_model_path = os.path.join(
         TEST_DATA_PREFIX_PATH, "tests/models/mnist_quantized.tflite"
     )
-    with self.assertRaisesRegex(
-        ValueError,
-        "The input model for calibration is not a float model.",
-    ):
-      _ = calibrator.Calibrator(test_model_path)
+    _ = calibrator.Calibrator(test_model_path)
 
 
 class CalibratorToyGemma2Test(googletest.TestCase):
@@ -302,7 +233,7 @@ class CalibratorToyGemma2Test(googletest.TestCase):
         self._toy_gemma2_calibration_dataset,
         model_recipe_manager=recipe_mngr,
     )
-    self.assertLen(calib.get_model_qsvs(),
+    self.assertLen(calib.get_model_qsvs(), 202)
 
 
 if __name__ == "__main__":
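With this release the calibrator keeps QSVs (quantization statistical values) only for activation tensors, so the weight/bias assertions above are deleted and the expected counts shrink. For orientation, a minimal sketch of walking a QSV dict of the shape these tests assert (tensor names and values are illustrative only, not package fixtures):

```python
import numpy as np

# Shape asserted by the tests: tensor name -> {"min": array, "max": array},
# with an empty dict for tensors that have no recorded statistics yet.
qsvs = {
    "serving_default_input_1:0": {
        "min": np.array([-1.0]),
        "max": np.array([1.0]),
    },
    "StatefulPartitionedCall:0": {},
}

for name, stats in qsvs.items():
    if not stats:
        print(name, "no statistics recorded")
        continue
    print(name, "min", stats["min"].min(), "max", stats["max"].max())
```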
ai_edge_quantizer/default_policy.py

@@ -19,9 +19,9 @@ import collections
 import copy
 import json
 from typing import Any, Union
+from ai_edge_litert.tools import flatbuffer_utils
 from ai_edge_quantizer import qtyping
 from ai_edge_litert import schema_py_generated as schema  # pylint:disable=g-direct-tensorflow-import
-from tensorflow.lite.tools import flatbuffer_utils  # pylint: disable=g-direct-tensorflow-import
 
 _TFLOpName = qtyping.TFLOperationName
 _OpQuantizationConfig = qtyping.OpQuantizationConfig
@@ -61,9 +61,8 @@ DEFAULT_JSON_POLICY = """
       "weight_tensor_config": {
         "num_bits": 4,
         "symmetric": [true],
-        "granularity": ["
-        "dtype": "INT"
-        "block_size": [32, 64, 96, 128, 256]
+        "granularity": ["BLOCKWISE_32", "BLOCKWISE_64", "BLOCKWISE_128", "BLOCKWISE_256"],
+        "dtype": "INT"
       },
       "explicit_dequantize": false,
       "compute_precision": "INTEGER"
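The policy now encodes the block size directly in the granularity token (`BLOCKWISE_32` through `BLOCKWISE_256`) rather than in a separate `block_size` list. A small sketch of splitting such a token back into kind and size; `parse_granularity` is our illustration, not an API of the package:

```python
def parse_granularity(token: str) -> tuple[str, int | None]:
    """Splits a policy granularity token such as 'BLOCKWISE_32'.

    Returns the granularity kind plus the block size, or None when the
    token carries no size (e.g. 'CHANNELWISE').
    """
    kind, sep, size = token.rpartition("_")
    if sep and size.isdigit():
        return kind, int(size)
    return token, None


assert parse_granularity("BLOCKWISE_32") == ("BLOCKWISE", 32)
assert parse_granularity("CHANNELWISE") == ("CHANNELWISE", None)
```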
@@ -178,12 +177,30 @@ DEFAULT_JSON_POLICY = """
       "INPUT",
       "OUTPUT",
       "SLICE",
-      "EMBEDDING_LOOKUP",
       "SUM",
+      "SELECT",
       "SELECT_V2",
       "DYNAMIC_UPDATE_SLICE",
       "SELECT_V2",
-      "STABLEHLO_COMPOSITE"
+      "STABLEHLO_COMPOSITE",
+      "PAD",
+      "MAX_POOL_2D",
+      "RESIZE_BILINEAR",
+      "RESIZE_NEAREST_NEIGHBOR",
+      "GATHER_ND",
+      "PACK",
+      "UNPACK",
+      "DIV",
+      "BROADCAST_TO",
+      "SQRT",
+      "GATHER",
+      "MAXIMUM",
+      "PADV2",
+      "REDUCE_MIN",
+      "EQUAL",
+      "NOT_EQUAL",
+      "MIRROR_PAD",
+      "RELU"
     ],
     "static_wi8_ai8": [
       "ADD",
@@ -209,15 +226,36 @@ DEFAULT_JSON_POLICY = """
       "INPUT",
       "OUTPUT",
       "SLICE",
-      "EMBEDDING_LOOKUP",
       "SUM",
+      "SELECT",
       "SELECT_V2",
       "DYNAMIC_UPDATE_SLICE",
       "SELECT_V2",
-      "STABLEHLO_COMPOSITE"
+      "STABLEHLO_COMPOSITE",
+      "PAD",
+      "SQUARED_DIFFERENCE",
+      "MAX_POOL_2D",
+      "RESIZE_BILINEAR",
+      "RESIZE_NEAREST_NEIGHBOR",
+      "GATHER_ND",
+      "PACK",
+      "UNPACK",
+      "DIV",
+      "BROADCAST_TO",
+      "SQRT",
+      "GATHER",
+      "HARD_SWISH",
+      "MAXIMUM",
+      "PADV2",
+      "REDUCE_MIN",
+      "EQUAL",
+      "NOT_EQUAL",
+      "MIRROR_PAD",
+      "SPACE_TO_DEPTH",
+      "RELU"
     ],
-    "static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"
-    "static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"
+    "static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"],
+    "static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"],
     "dynamic_wi8_afp32": [
       "BATCH_MATMUL",
       "CONV_2D",
@@ -240,6 +278,11 @@ DEFAULT_JSON_POLICY = """
   }
 }
 """
+QUANTIZABLE_COMPOSITES = [
+    "od" + "ml.npu_call",
+    "od" + "ml.rms_norm",
+    "od" + "ml.l2_norm",
+]
 
 
 def _unroll_json_config(
@@ -280,16 +323,9 @@ def _unroll_json_config(
           "granularity": granularity,
           "dtype": json_config["weight_tensor_config"]["dtype"],
       }
-
-
-
-      weight_configs.append(
-          qtyping.TensorQuantizationConfig.from_dict(tensor_config)
-      )
-    else:
-      weight_configs.append(
-          qtyping.TensorQuantizationConfig.from_dict(tensor_config)
-      )
+      weight_configs.append(
+          qtyping.TensorQuantizationConfig.from_dict(tensor_config)
+      )
 
   if activation_configs:
     for activation_config in activation_configs:
@@ -317,10 +353,10 @@ def _unroll_json_config(
 
 
 # TODO: b/401024954 - Have a better way to specify recipes based on op options.
-def is_conditionally_unquantized(
+def is_non_quantizable_composite_op(
     op: Union[schema.Operator, schema.OperatorT],
 ) -> bool:
-  """Checks if the operator is
+  """Checks if the operator is a non-quantizable composite op.
 
   We may want to quantize an op only when its has certain options.
   Policies/recipes
@@ -335,10 +371,9 @@ def is_conditionally_unquantized(
   if opts := flatbuffer_utils.get_options_as(
       op, schema.StableHLOCompositeOptionsT
   ):
-    name = opts.name
-
-
-    return ("od" + "ml.npu_call") not in name.decode("utf-8")
+    name = opts.name.decode("utf-8")
+    if name not in QUANTIZABLE_COMPOSITES:
+      return True
 
   return False
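The check is also generalized: instead of special-casing `odml.npu_call` inline, any StableHLO composite whose name is absent from `QUANTIZABLE_COMPOSITES` is now reported as non-quantizable. A self-contained sketch of that allowlist logic with the flatbuffer plumbing stubbed out (the helper name and the bytes-or-None argument are our simplification):

```python
# Mirrors QUANTIZABLE_COMPOSITES above; the "od" + "ml..." concatenation is
# kept verbatim in the source, presumably to dodge a string-match check.
QUANTIZABLE_COMPOSITES = ["odml.npu_call", "odml.rms_norm", "odml.l2_norm"]


def is_non_quantizable_composite(composite_name: bytes | None) -> bool:
    """composite_name stands in for StableHLOCompositeOptionsT.name.

    None models an operator that is not a StableHLO composite at all, which
    the real function also treats as quantizable (it returns False).
    """
    if composite_name is None:
        return False
    return composite_name.decode("utf-8") not in QUANTIZABLE_COMPOSITES


assert not is_non_quantizable_composite(b"odml.rms_norm")
assert is_non_quantizable_composite(b"vendor.custom_block")
assert not is_non_quantizable_composite(None)
```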
ai_edge_quantizer/model_modifier.py

@@ -17,15 +17,21 @@
 
 from collections.abc import Sequence
 import copy
+import logging
 
 import numpy as np
 
+from ai_edge_litert.tools import flatbuffer_utils
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer import transformation_instruction_generator
 from ai_edge_quantizer import transformation_performer
 from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+from ai_edge_quantizer.utils import tfl_interpreter_utils
+from ai_edge_litert import interpreter as tfl  # pylint: disable=g-direct-tensorflow-import
 from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
-
+
+
+_DEQUANT_SUFFIX = "_dequant"
 
 
 class ModelModifier:
@@ -104,11 +110,95 @@ class ModelModifier:
         instructions, quantized_model, tensor_processing_order
     )
     constant_buffer_size = self._process_constant_map(quantized_model)
-    # we leave
-
-
-
-
+    # we leave 256MB for the model architecture.
+    serialize_fun = (
+        self._serialize_large_model
+        if constant_buffer_size > 2**31 - 2**28
+        else self._serialize_small_model
+    )
+    serialized_quantized_model = serialize_fun(quantized_model)
+
+    # Update signature defs if dequant is inserted before output.
+    if self._has_dequant_before_output(instructions):
+      quantized_model = self._update_signature_defs_for_dequant_output(
+          quantized_model, serialized_quantized_model
+      )
+      serialized_quantized_model = serialize_fun(quantized_model)
+
+    return serialized_quantized_model
+
+  def _update_signature_defs_for_dequant_output(
+      self, model: schema_py_generated.ModelT, serialized_model: bytearray
+  ):
+    """Updates the signature definitions in the model.
+
+    This function is called when a dequantize operation is inserted before
+    an output tensor. It updates the tensor index in the signature
+    definitions to point to the newly inserted dequantize output tensor.
+
+    Args:
+      model: The TFlite ModelT object.
+      serialized_model: The serialized bytearray of the TFlite model.
+
+    Returns:
+      The updated TFlite ModelT object.
+    """
+    interpreter = tfl.Interpreter(model_content=bytes(serialized_model))
+
+    for signature_def in model.signatureDefs:
+      signature_key = signature_def.signatureKey.decode("utf-8")
+      logging.info("Signature = %s", signature_key)
+      subgraph_idx = tfl_interpreter_utils.get_signature_main_subgraph_index(
+          interpreter, signature_key
+      )
+      output_details = interpreter.get_signature_runner(
+          signature_key
+      ).get_output_details()
+      subgraph = model.subgraphs[subgraph_idx]
+      graph_info = qtyping.GraphInfo(subgraph.tensors, model.buffers)
+
+      for output in subgraph.outputs:
+        tensor_name = tfl_flatbuffer_utils.get_tensor_name(
+            graph_info.subgraph_tensors[output]
+        )
+        logging.info("\tOutput tensor = `%s`", tensor_name)
+
+        for signature_name, tensor_details in output_details.items():
+          if tensor_details["name"] + _DEQUANT_SUFFIX == tensor_name:
+            logging.info(
+                "\t\tfound tensor mapping: `%s`->`%s` for signature name: `%s`",
+                tensor_details["name"],
+                tensor_name,
+                signature_name,
+            )
+            for signature_item in signature_def.outputs:
+              if signature_item.name.decode("utf-8") == signature_name:
+                signature_item.tensorIndex = output
+                logging.info(
+                    "\t\t\tswapped tensor index: %s->%s",
+                    tensor_details["index"],
+                    output,
+                )
+                break
+            break
+
+    return model
+
+  def _has_dequant_before_output(
+      self, instructions: dict[str, qtyping.TensorTransformationInsts]
+  ) -> bool:
+    """Check if the model has dequant insert to output."""
+    for tensor_name, tensor_trans_insts in instructions.items():
+      for instr in tensor_trans_insts.instructions:
+        if (
+            qtyping.QuantTransformation.ADD_DEQUANTIZE == instr.transformation
+            and instr.consumers == [-1]
+        ):
+          logging.info(
+              "Found dequant insert to output for tensor: %s", tensor_name
+          )
+          return True
+    return False
 
   def _process_constant_map(
       self, quantized_model: schema_py_generated.ModelT
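The cutover `constant_buffer_size > 2**31 - 2**28` compares against 1,879,048,192 bytes: the 2 GiB flatbuffer ceiling minus the 256 MB that the comment reserves for the model architecture. A quick sketch of the same dispatch with stand-in serializers:

```python
# 2 GiB signed-offset flatbuffer cap minus 256 MiB of headroom, per the
# "we leave 256MB for the model architecture" comment in the diff.
LARGE_MODEL_THRESHOLD = 2**31 - 2**28
assert LARGE_MODEL_THRESHOLD == 1_879_048_192


def pick_serializer(constant_buffer_size: int) -> str:
    # Stand-ins for _serialize_large_model / _serialize_small_model.
    return "large" if constant_buffer_size > LARGE_MODEL_THRESHOLD else "small"


assert pick_serializer(2**30) == "small"          # 1 GiB of constants
assert pick_serializer(2**31 - 2**27) == "large"  # within 128 MiB of the cap
```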
@@ -142,7 +232,7 @@ class ModelModifier:
     remainder = len(bytearr) % 16
     if remainder != 0:
       padding_size = 16 - remainder
-      bytearr.extend(b
+      bytearr.extend(b"\0" * padding_size)
 
   # TODO: b/333797307 - support > 2GB output model
   def _serialize_large_model(
ai_edge_quantizer/model_modifier_test.py

@@ -19,13 +19,13 @@ import os
 import tracemalloc
 from tensorflow.python.platform import googletest
 from absl.testing import parameterized
+from ai_edge_litert.tools import flatbuffer_utils
 from ai_edge_quantizer import model_modifier
 from ai_edge_quantizer import params_generator
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer import recipe_manager
 from ai_edge_quantizer.utils import test_utils
 from ai_edge_quantizer.utils import tfl_flatbuffer_utils
-from tensorflow.lite.tools import flatbuffer_utils  # pylint: disable=g-direct-tensorflow-import
 
 TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('.')
 
@@ -125,6 +125,86 @@ class ModelModifierTest(parameterized.TestCase):
     loosen_mem_use_factor = 4.5
     self.assertLess(mem_peak / len(self._model_content), loosen_mem_use_factor)
 
+  def test_has_dequant_before_output_true(self):
+    instructions = {
+        'tensor1': qtyping.TensorTransformationInsts(
+            'tensor1',
+            0,
+            instructions=[
+                qtyping.TransformationInst(
+                    transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
+                    tensor_id=0,
+                    producer=0,
+                    consumers=[-1],
+                )
+            ],
+        )
+    }
+    self.assertTrue(
+        self._model_modifier._has_dequant_before_output(instructions)
+    )
+
+  def test_has_dequant_before_output_false(self):
+    instructions = {
+        'tensor1': qtyping.TensorTransformationInsts(
+            'tensor1',
+            0,
+            instructions=[
+                qtyping.TransformationInst(
+                    transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
+                    tensor_id=0,
+                    producer=0,
+                    consumers=[1],
+                )
+            ],
+        )
+    }
+    self.assertFalse(
+        self._model_modifier._has_dequant_before_output(instructions)
+    )
+
+  def test_pad_bytearray(self):
+    arr = bytearray(b'\x01\x02\x03')
+    self._model_modifier._pad_bytearray(arr)
+    self.assertLen(arr, 16)
+    self.assertEqual(arr, b'\x01\x02\x03' + b'\0' * 13)
+
+    arr = bytearray(b'\x01' * 16)
+    self._model_modifier._pad_bytearray(arr)
+    self.assertLen(arr, 16)
+
+    arr = bytearray(b'\x01' * 17)
+    self._model_modifier._pad_bytearray(arr)
+    self.assertLen(arr, 32)
+
+
+class ModelModifierTestWithSignature(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self._model_path = os.path.join(
+        TEST_DATA_PREFIX_PATH,
+        'tests/models/single_fc.tflite',
+    )
+    self._model_content: bytes = tfl_flatbuffer_utils.get_model_content(
+        self._model_path
+    )
+    self._model_modifier = model_modifier.ModelModifier(self._model_content)
+
+  def test_update_signature_defs_for_dequant_output_succeeds(self):
+    # This is a simplified test that only checks if the function runs without
+    # crashing and returns a model. A more thorough test with a model
+    # with a known signature was added in `quantizer_test`.
+    model_bytearray = flatbuffer_utils.read_model_from_bytearray(
+        self._model_content
+    )
+    updated_model = (
+        self._model_modifier._update_signature_defs_for_dequant_output(
+            model_bytearray, bytearray(self._model_content)
+        )
+    )
+    self.assertIsNotNone(updated_model)
+
 
 if __name__ == '__main__':
   googletest.main()
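`test_pad_bytearray` pins down the 16-byte alignment contract (3 pads to 16, exact multiples stay untouched, 17 pads to 32). A free-standing sketch of a helper that satisfies those assertions; the real method is `ModelModifier._pad_bytearray`:

```python
def pad_bytearray(bytearr: bytearray) -> None:
    """Zero-pads a buffer in place to the next multiple of 16 bytes."""
    remainder = len(bytearr) % 16
    if remainder != 0:
        bytearr.extend(b"\0" * (16 - remainder))


arr = bytearray(b"\x01\x02\x03")
pad_bytearray(arr)
assert arr == b"\x01\x02\x03" + b"\0" * 13  # 3 -> 16

arr = bytearray(b"\x01" * 16)
pad_bytearray(arr)
assert len(arr) == 16  # already aligned, unchanged

arr = bytearray(b"\x01" * 17)
pad_bytearray(arr)
assert len(arr) == 32  # 17 -> 32
```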
ai_edge_quantizer/model_validator.py

@@ -25,7 +25,7 @@ from typing import Any, Optional, Union
 import numpy as np
 
 from ai_edge_quantizer.utils import tfl_interpreter_utils as utils
-
+import os  # tensorflow.python.platform.gfile  # pylint: disable=g-direct-tensorflow-import
 
 
 _DEFAULT_SIGNATURE_KEY = utils.DEFAULT_SIGNATURE_KEY
@@ -118,7 +118,8 @@ class ComparisonResult:
     for name in utils.get_input_tensor_names(
         self._reference_model, signature_key
     ):
-
+      if name in result:
+        input_tensor_results[name] = result.pop(name)
 
     output_tensor_results = {}
     for name in utils.get_output_tensor_names(
@@ -136,7 +137,8 @@ class ComparisonResult:
         self._reference_model,
         subgraph_index,
     ):
-
+      if name in result:
+        constant_tensor_results[name] = result.pop(name)
 
     self._comparison_results[signature_key] = SingleSignatureComparisonResult(
         error_metric=error_metric,
@@ -192,7 +194,7 @@ class ComparisonResult:
     result_save_path = os.path.join(
         save_folder, model_name + '_comparison_result.json'
     )
-    with
+    with open(result_save_path, 'w') as output_file_handle:
       output_file_handle.write(json.dumps(result))
 
     # TODO: b/365578554 - Remove after ME is updated to use the new json format.
@@ -204,7 +206,7 @@ class ComparisonResult:
     json_save_path = os.path.join(
         save_folder, model_name + '_comparison_result_me_input.json'
     )
-    with
+    with open(json_save_path, 'w') as output_file_handle:
       output_file_handle.write(json_object)
@@ -214,6 +216,7 @@ def _setup_validation_interpreter(
     signature_key: Optional[str],
     use_xnnpack: bool,
     num_threads: int,
+    preserve_all_tensors: bool = True,
 ) -> tuple[Any, int, dict[str, Any]]:
   """Setup the interpreter for validation given a signature key.
 
@@ -224,13 +227,17 @@ def _setup_validation_interpreter(
       model only has one signature, this can be set to None.
     use_xnnpack: Whether to use xnnpack for the interpreter.
     num_threads: The number of threads to use for the interpreter.
+    preserve_all_tensors: Whether to preserve all tensors.
 
   Returns:
     A tuple of interpreter, subgraph_index and tensor_name_to_details.
   """
 
   interpreter = utils.create_tfl_interpreter(
-      tflite_model=model,
+      tflite_model=model,
+      use_xnnpack=use_xnnpack,
+      num_threads=num_threads,
+      preserve_all_tensors=preserve_all_tensors,
   )
   utils.invoke_interpreter_signature(
       interpreter, signature_input, signature_key
@@ -255,6 +262,7 @@ def compare_model(
     compare_fn: Callable[[Any, Any], float],
     use_xnnpack: bool = True,
     num_threads: int = 16,
+    validate_output_tensors_only: bool = False,
 ) -> ComparisonResult:
   """Compares model tensors over a model signature using the compare_fn.
 
@@ -275,10 +283,13 @@ def compare_model(
       single float value.
     use_xnnpack: Whether to use xnnpack for the interpreter.
     num_threads: The number of threads to use for the interpreter.
+    validate_output_tensors_only: If True, only compare output tensors.
+      Otherwise, compare all tensors.
 
   Returns:
     A ComparisonResult object.
   """
+  preserve_all_tensors = not validate_output_tensors_only
   model_comparion_result = ComparisonResult(reference_model, target_model)
   for signature_key, signature_inputs in test_data.items():
     comparison_results = {}
@@ -291,6 +302,7 @@ def compare_model(
             signature_key,
             use_xnnpack,
             num_threads,
+            preserve_all_tensors=preserve_all_tensors,
         )
     )
     targ_interpreter, targ_subgraph_index, targ_tensor_name_to_details = (
@@ -300,12 +312,23 @@ def compare_model(
             signature_key,
             use_xnnpack,
             num_threads,
+            preserve_all_tensors=preserve_all_tensors,
         )
     )
-    # Compare the cached tensor
-
+    # Compare the cached tensor value
+    tensor_names_to_compare = (
+        utils.get_output_tensor_names(reference_model, signature_key)
+        if validate_output_tensors_only
+        else list(ref_tensor_name_to_details.keys())
+    )
+
+    for tensor_name in tensor_names_to_compare:
+      detail = ref_tensor_name_to_details[tensor_name]
       if detail['dtype'] == np.object_:
         continue
+      # Ignore tensors where any dimension of the shape is 0.
+      if not np.all(detail['shape']):
+        continue
       if tensor_name in targ_tensor_name_to_details:
        if tensor_name not in comparison_results:
           comparison_results[tensor_name] = []
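The new `validate_output_tensors_only` flag disables `preserve_all_tensors` and restricts the comparison to signature outputs, trading per-tensor coverage for speed and memory. A hypothetical call sketch; the model bytes, dataset, and metric below are placeholders rather than package fixtures:

```python
import numpy as np


def mean_squared_difference(ref, target) -> float:
    """A simple compare_fn candidate: lower is better."""
    ref = np.asarray(ref, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    return float(np.mean((ref - target) ** 2))


# Assumed call shape, based only on the signature visible in this diff:
# result = model_validator.compare_model(
#     reference_model=float_model_bytes,
#     target_model=quantized_model_bytes,
#     test_data={"serving_default": [{"input_1": sample_batch}]},
#     compare_fn=mean_squared_difference,
#     validate_output_tensors_only=True,  # new in this release
# )
```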
ai_edge_quantizer/params_generator.py

@@ -35,12 +35,12 @@ class ParamsGenerator:
   def __init__(self, float_tflite: Union[str, bytes]):
     self.flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
 
-    if not tfl_flatbuffer_utils.is_float_model(self.flatbuffer_model):
-
-
-
-
-
+    # if not tfl_flatbuffer_utils.is_float_model(self.flatbuffer_model):
+    #   raise ValueError(
+    #       'The input model for quantization parameters generation is not a'
+    #       ' float model. Please check the model (e.g., if it is already'
+    #       ' quantized).'
+    #   )
     self._check_tensor_names_are_unique()
     self.buffer_to_tensors: dict[int, list[Any]] = (
         tfl_flatbuffer_utils.buffer_to_tensors(self.flatbuffer_model)
@@ -78,8 +78,6 @@ class ParamsGenerator:
     skip_subgraphs = set()
     op_codes = self.flatbuffer_model.operatorCodes
     for sg_ind, subgraph in enumerate(self.flatbuffer_model.subgraphs):
-      if sg_ind in skip_subgraphs:
-        continue
 
       graph_info = qtyping.GraphInfo(
           subgraph.tensors, self.flatbuffer_model.buffers
@@ -109,7 +107,10 @@ class ParamsGenerator:
         algorithm_name, op_quant_config = (
            model_recipe_manager.get_quantization_configs(op_key, op_scope)
         )
-
+
+        if sg_ind in skip_subgraphs or policy.is_non_quantizable_composite_op(
+            op
+        ):
           algorithm_name = algorithm_manager.AlgorithmName.NO_QUANTIZE
 
         if algorithm_name == algorithm_manager.AlgorithmName.NO_QUANTIZE:
@@ -408,7 +409,11 @@ class ParamsGenerator:
     buffers_to_duplicate = []
     tensor_names_to_duplicate = []
     for buffer_idx, tensors in self.buffer_to_tensors.items():
-      if
+      # TODO: b/458797890 - Investigate if skipping buffer_idx == 0 is a
+      # correct fix, or if it just covers up a deeper issue. This is only
+      # required when statically quantizing models that have already been
+      # quantized dynamically.
+      if not tensors or buffer_idx == 0:
         continue
       # Check if any of the tensors needs to be duplicated.
       for tensor in tensors:
@@ -508,6 +513,8 @@ def _compatible_tensor_params(
   float_source_transformations = [
       _QuantTrans.ADD_QUANTIZE,
       _QuantTrans.NO_QUANTIZE,
+      _QuantTrans.INSERT_HADAMARD_ROTATION,
+      _QuantTrans.INSERT_DECOMPOSED_HADAMARD_ROTATION,
   ]
   quantized_source_transformations = [
       _QuantTrans.QUANTIZE_TENSOR,