ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl
This diff shows the content of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- ai_edge_quantizer/algorithm_manager.py +224 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
- ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
- ai_edge_quantizer/calibrator.py +58 -94
- ai_edge_quantizer/calibrator_test.py +5 -74
- ai_edge_quantizer/default_policy.py +108 -16
- ai_edge_quantizer/model_modifier.py +132 -8
- ai_edge_quantizer/model_modifier_test.py +81 -1
- ai_edge_quantizer/model_validator.py +38 -10
- ai_edge_quantizer/model_validator_test.py +2 -1
- ai_edge_quantizer/params_generator.py +230 -47
- ai_edge_quantizer/params_generator_test.py +366 -261
- ai_edge_quantizer/qtyping.py +92 -6
- ai_edge_quantizer/quantizer.py +167 -23
- ai_edge_quantizer/quantizer_test.py +288 -26
- ai_edge_quantizer/recipe.py +156 -21
- ai_edge_quantizer/recipe_manager.py +158 -1
- ai_edge_quantizer/recipe_manager_test.py +146 -32
- ai_edge_quantizer/recipe_test.py +93 -17
- ai_edge_quantizer/transformation_instruction_generator.py +313 -46
- ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
- ai_edge_quantizer/transformation_performer.py +112 -58
- ai_edge_quantizer/transformation_performer_test.py +176 -4
- ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
- ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
- ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
- ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
- ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
- ai_edge_quantizer/transformations/transformation_utils.py +157 -11
- ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
- ai_edge_quantizer/utils/calibration_utils.py +263 -1
- ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
- ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
- ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
- ai_edge_quantizer/utils/test_utils.py +191 -58
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
- ai_edge_quantizer/utils/validation_utils.py +114 -4
- ai_edge_quantizer/utils/validation_utils_test.py +80 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
- ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
- ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
- ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
```diff
--- a/ai_edge_quantizer/default_policy.py
+++ b/ai_edge_quantizer/default_policy.py
@@ -18,8 +18,10 @@
 import collections
 import copy
 import json
-from typing import Any
+from typing import Any, Union
+from ai_edge_litert.tools import flatbuffer_utils
 from ai_edge_quantizer import qtyping
+from ai_edge_litert import schema_py_generated as schema  # pylint:disable=g-direct-tensorflow-import
 
 _TFLOpName = qtyping.TFLOperationName
 _OpQuantizationConfig = qtyping.OpQuantizationConfig
```
```diff
@@ -55,6 +57,16 @@ DEFAULT_JSON_POLICY = """
       "explicit_dequantize": false,
       "compute_precision": "INTEGER"
     },
+    "dynamic_wi4_afp32_blockwise": {
+      "weight_tensor_config": {
+        "num_bits": 4,
+        "symmetric": [true],
+        "granularity": ["BLOCKWISE_32", "BLOCKWISE_64", "BLOCKWISE_128", "BLOCKWISE_256"],
+        "dtype": "INT"
+      },
+      "explicit_dequantize": false,
+      "compute_precision": "INTEGER"
+    },
     "static_wi8_ai16": {
       "activation_tensor_config": {
         "num_bits": 16,
```
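The new `dynamic_wi4_afp32_blockwise` preset stores one scale per fixed-size block of weights instead of one per tensor or channel. A minimal NumPy sketch of what symmetric int4 quantization at `BLOCKWISE_32` granularity computes (names and structure are mine for illustration, not the quantizer's internals):

```python
import numpy as np

def quantize_blockwise_int4(w: np.ndarray, block_size: int = 32):
  """Illustrative symmetric 4-bit blockwise quantization of a 2-D matrix."""
  rows, cols = w.shape
  assert cols % block_size == 0, "quantized dim must be a multiple of block"
  blocks = w.reshape(rows, cols // block_size, block_size)
  # One scale per block; symmetric int4 keeps zero at 0 and maps the block's
  # max magnitude to 7 (the positive end of the [-8, 7] int4 range).
  scales = np.abs(blocks).max(axis=-1, keepdims=True) / 7.0
  scales = np.where(scales == 0.0, 1.0, scales)  # guard all-zero blocks
  q = np.clip(np.round(blocks / scales), -8, 7).astype(np.int8)
  return q.reshape(rows, cols), scales.squeeze(-1)

w = np.random.randn(4, 64).astype(np.float32)
q, scales = quantize_blockwise_int4(w, block_size=32)
assert q.shape == (4, 64) and scales.shape == (4, 2)  # 2 blocks per row
```

Smaller blocks track local weight statistics more closely at the cost of more scale metadata, which is presumably why the policy admits block sizes from 32 up to 256.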
```diff
@@ -165,9 +177,30 @@ DEFAULT_JSON_POLICY = """
       "INPUT",
       "OUTPUT",
       "SLICE",
-      "EMBEDDING_LOOKUP",
       "SUM",
-      "
+      "SELECT",
+      "SELECT_V2",
+      "DYNAMIC_UPDATE_SLICE",
+      "SELECT_V2",
+      "STABLEHLO_COMPOSITE",
+      "PAD",
+      "MAX_POOL_2D",
+      "RESIZE_BILINEAR",
+      "RESIZE_NEAREST_NEIGHBOR",
+      "GATHER_ND",
+      "PACK",
+      "UNPACK",
+      "DIV",
+      "BROADCAST_TO",
+      "SQRT",
+      "GATHER",
+      "MAXIMUM",
+      "PADV2",
+      "REDUCE_MIN",
+      "EQUAL",
+      "NOT_EQUAL",
+      "MIRROR_PAD",
+      "RELU"
     ],
     "static_wi8_ai8": [
       "ADD",
```
```diff
@@ -193,12 +226,36 @@ DEFAULT_JSON_POLICY = """
       "INPUT",
       "OUTPUT",
       "SLICE",
-      "EMBEDDING_LOOKUP",
       "SUM",
-      "
+      "SELECT",
+      "SELECT_V2",
+      "DYNAMIC_UPDATE_SLICE",
+      "SELECT_V2",
+      "STABLEHLO_COMPOSITE",
+      "PAD",
+      "SQUARED_DIFFERENCE",
+      "MAX_POOL_2D",
+      "RESIZE_BILINEAR",
+      "RESIZE_NEAREST_NEIGHBOR",
+      "GATHER_ND",
+      "PACK",
+      "UNPACK",
+      "DIV",
+      "BROADCAST_TO",
+      "SQRT",
+      "GATHER",
+      "HARD_SWISH",
+      "MAXIMUM",
+      "PADV2",
+      "REDUCE_MIN",
+      "EQUAL",
+      "NOT_EQUAL",
+      "MIRROR_PAD",
+      "SPACE_TO_DEPTH",
+      "RELU"
     ],
-    "static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"
-    "static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"
+    "static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"],
+    "static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"],
     "dynamic_wi8_afp32": [
       "BATCH_MATMUL",
       "CONV_2D",
```
```diff
@@ -208,6 +265,7 @@ DEFAULT_JSON_POLICY = """
       "FULLY_CONNECTED"
     ],
     "dynamic_wi4_afp32": ["FULLY_CONNECTED", "EMBEDDING_LOOKUP", "CONV_2D"],
+    "dynamic_wi4_afp32_blockwise": ["EMBEDDING_LOOKUP", "FULLY_CONNECTED"],
     "weightonly_wi8_afp32": [
       "BATCH_MATMUL",
       "CONV_2D",
```
```diff
@@ -220,6 +278,11 @@ DEFAULT_JSON_POLICY = """
   }
 }
 """
+QUANTIZABLE_COMPOSITES = [
+    "od" + "ml.npu_call",
+    "od" + "ml.rms_norm",
+    "od" + "ml.l2_norm",
+]
 
 
 def _unroll_json_config(
```
```diff
@@ -251,6 +314,7 @@ def _unroll_json_config(
 
   # Then unroll weight configs and turn them into quantization configs.
   quant_configs = []
+  weight_configs = []
   for symmetric in json_config["weight_tensor_config"]["symmetric"]:
     for granularity in json_config["weight_tensor_config"]["granularity"]:
       tensor_config = {
```
```diff
@@ -259,6 +323,9 @@ def _unroll_json_config(
           "granularity": granularity,
           "dtype": json_config["weight_tensor_config"]["dtype"],
       }
+      weight_configs.append(
+          qtyping.TensorQuantizationConfig.from_dict(tensor_config)
+      )
 
   if activation_configs:
     for activation_config in activation_configs:
```
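The two hunks above collect one `TensorQuantizationConfig` per (symmetric, granularity) pair declared in a policy entry. A hedged plain-dict sketch of that expansion, using the blockwise entry added earlier (dict literals stand in for the real config classes):

```python
import itertools

# Plain-dict mirror of the unrolling: the policy entry expands into
# 1 (symmetric option) x 4 (granularities) = 4 weight configs.
entry = {
    "num_bits": 4,
    "symmetric": [True],
    "granularity": [
        "BLOCKWISE_32", "BLOCKWISE_64", "BLOCKWISE_128", "BLOCKWISE_256"
    ],
    "dtype": "INT",
}
weight_configs = [
    {
        "num_bits": entry["num_bits"],
        "symmetric": symmetric,
        "granularity": granularity,
        "dtype": entry["dtype"],
    }
    for symmetric, granularity in itertools.product(
        entry["symmetric"], entry["granularity"]
    )
]
assert len(weight_configs) == 4
```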
```diff
@@ -273,19 +340,44 @@ def _unroll_json_config(
           )
       )
   else:
-
-
-
-
-
-
-
-
-    )
+    for weight_config in weight_configs:
+      quant_configs.append(
+          qtyping.OpQuantizationConfig(
+              weight_tensor_config=weight_config,
+              compute_precision=json_config["compute_precision"],
+              explicit_dequantize=json_config["explicit_dequantize"],
+          )
+      )
 
   return quant_configs
 
 
+# TODO: b/401024954 - Have a better way to specify recipes based on op options.
+def is_non_quantizable_composite_op(
+    op: Union[schema.Operator, schema.OperatorT],
+) -> bool:
+  """Checks if the operator is a non-quantizable composite op.
+
+  We may want to quantize an op only when it has certain options.
+  Policies/recipes are not aware of op options, so the options are
+  checked here.
+
+  Args:
+    op: The operator to check.
+
+  Returns:
+    True if the operator is conditionally unquantized, False otherwise.
+  """
+  if opts := flatbuffer_utils.get_options_as(
+      op, schema.StableHLOCompositeOptionsT
+  ):
+    name = opts.name.decode("utf-8")
+    if name not in QUANTIZABLE_COMPOSITES:
+      return True
+
+  return False
+
+
 def update_default_config_policy(raw_json_policy: str):
   """Updates the default config check policy."""
   json_policy_content = json.loads(raw_json_policy)
```
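Note that `"od" + "ml.npu_call"` simply evaluates to `"odml.npu_call"`, so the allowlist names three `odml.*` composites. A self-contained sketch of the gate `is_non_quantizable_composite_op` implements, with a stand-in class replacing the real `schema.StableHLOCompositeOptionsT` table (illustrative only):

```python
from typing import Optional

QUANTIZABLE = ["odml.npu_call", "odml.rms_norm", "odml.l2_norm"]

class FakeCompositeOptions:
  """Stand-in for schema.StableHLOCompositeOptionsT (demo only)."""

  def __init__(self, name: bytes):
    self.name = name

def is_non_quantizable(opts: Optional[FakeCompositeOptions]) -> bool:
  if opts is not None:  # op is a composite; gate on its name
    return opts.name.decode("utf-8") not in QUANTIZABLE
  return False  # not a composite op: no extra restriction here

assert not is_non_quantizable(FakeCompositeOptions(b"odml.rms_norm"))
assert is_non_quantizable(FakeCompositeOptions(b"odml.something_else"))
assert not is_non_quantizable(None)
```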
```diff
--- a/ai_edge_quantizer/model_modifier.py
+++ b/ai_edge_quantizer/model_modifier.py
@@ -15,15 +15,23 @@
 
 """Model Modifier class that produces the final quantized TFlite model."""
 
+from collections.abc import Sequence
 import copy
+import logging
 
 import numpy as np
 
+from ai_edge_litert.tools import flatbuffer_utils
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer import transformation_instruction_generator
 from ai_edge_quantizer import transformation_performer
+from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+from ai_edge_quantizer.utils import tfl_interpreter_utils
+from ai_edge_litert import interpreter as tfl  # pylint: disable=g-direct-tensorflow-import
 from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
-
+
+
+_DEQUANT_SUFFIX = "_dequant"
 
 
 class ModelModifier:
```
```diff
@@ -46,6 +54,35 @@ class ModelModifier:
         transformation_performer.TransformationPerformer()
     )
 
+  def _get_tensor_processing_order(
+      self,
+      tensor_names: Sequence[str],
+      flatbuffer_model: schema_py_generated.ModelT,
+  ) -> list[str]:
+    """Get the tensor processing order obtained from `buffer_to_tensors`.
+
+    The processing order is used to ensure that the last tensor in a
+    buffer is processed last. This is required for the correctness of
+    buffer duplication, as the last tensor in a buffer won't be duplicated.
+
+    Args:
+      tensor_names: Names of the tensors that need to be processed.
+      flatbuffer_model: TFlite model.
+
+    Returns:
+      A list of tensor names in the processing order.
+    """
+    buffer_to_tensors = tfl_flatbuffer_utils.buffer_to_tensors(flatbuffer_model)
+
+    processing_order = []
+    for buffer_tensors in buffer_to_tensors.values():
+      for tensor in buffer_tensors:
+        tensor_name = tfl_flatbuffer_utils.get_tensor_name(tensor)
+        if tensor_name in tensor_names:
+          processing_order.append(tensor_name)
+
+    return processing_order
+
   def modify_model(
       self, params: dict[str, qtyping.TensorTransformationParams]
   ) -> bytearray:
```
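A stand-in sketch of the ordering guarantee `_get_tensor_processing_order` provides: tensors are emitted buffer by buffer, so within each shared buffer the original relative order, and in particular the final tensor, is preserved (plain dicts and strings replace the flatbuffer types):

```python
# Plain-dict stand-in for the flatbuffer walk: emit tensors buffer by
# buffer so the last tensor of each shared buffer comes out last.
buffer_to_tensors = {
    0: ["embedding", "lm_head"],  # two tensors sharing one buffer
    1: ["conv_filter"],
}
to_process = {"lm_head", "conv_filter", "embedding"}

processing_order = [
    name
    for tensors in buffer_to_tensors.values()
    for name in tensors
    if name in to_process
]
# "embedding" is handled before "lm_head": only the final tensor of a
# shared buffer may keep the original (un-duplicated) buffer.
assert processing_order == ["embedding", "lm_head", "conv_filter"]
```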
```diff
@@ -66,15 +103,102 @@ class ModelModifier:
         params, quantized_model
     )
 
+    tensor_processing_order = self._get_tensor_processing_order(
+        list(instructions.keys()), quantized_model
+    )
     self._transformation_performer.transform_graph(
-        instructions, quantized_model
+        instructions, quantized_model, tensor_processing_order
     )
     constant_buffer_size = self._process_constant_map(quantized_model)
-    # we leave
-
-
-
-
+    # we leave 256MB for the model architecture.
+    serialize_fun = (
+        self._serialize_large_model
+        if constant_buffer_size > 2**31 - 2**28
+        else self._serialize_small_model
+    )
+    serialized_quantized_model = serialize_fun(quantized_model)
+
+    # Update signature defs if dequant is inserted before output.
+    if self._has_dequant_before_output(instructions):
+      quantized_model = self._update_signature_defs_for_dequant_output(
+          quantized_model, serialized_quantized_model
+      )
+      serialized_quantized_model = serialize_fun(quantized_model)
+
+    return serialized_quantized_model
+
+  def _update_signature_defs_for_dequant_output(
+      self, model: schema_py_generated.ModelT, serialized_model: bytearray
+  ):
+    """Updates the signature definitions in the model.
+
+    This function is called when a dequantize operation is inserted before
+    an output tensor. It updates the tensor index in the signature
+    definitions to point to the newly inserted dequantize output tensor.
+
+    Args:
+      model: The TFlite ModelT object.
+      serialized_model: The serialized bytearray of the TFlite model.
+
+    Returns:
+      The updated TFlite ModelT object.
+    """
+    interpreter = tfl.Interpreter(model_content=bytes(serialized_model))
+
+    for signature_def in model.signatureDefs:
+      signature_key = signature_def.signatureKey.decode("utf-8")
+      logging.info("Signature = %s", signature_key)
+      subgraph_idx = tfl_interpreter_utils.get_signature_main_subgraph_index(
+          interpreter, signature_key
+      )
+      output_details = interpreter.get_signature_runner(
+          signature_key
+      ).get_output_details()
+      subgraph = model.subgraphs[subgraph_idx]
+      graph_info = qtyping.GraphInfo(subgraph.tensors, model.buffers)
+
+      for output in subgraph.outputs:
+        tensor_name = tfl_flatbuffer_utils.get_tensor_name(
+            graph_info.subgraph_tensors[output]
+        )
+        logging.info("\tOutput tensor = `%s`", tensor_name)
+
+        for signature_name, tensor_details in output_details.items():
+          if tensor_details["name"] + _DEQUANT_SUFFIX == tensor_name:
+            logging.info(
+                "\t\tfound tensor mapping: `%s`->`%s` for signature name: `%s`",
+                tensor_details["name"],
+                tensor_name,
+                signature_name,
+            )
+            for signature_item in signature_def.outputs:
+              if signature_item.name.decode("utf-8") == signature_name:
+                signature_item.tensorIndex = output
+                logging.info(
+                    "\t\t\tswapped tensor index: %s->%s",
+                    tensor_details["index"],
+                    output,
+                )
+                break
+            break
+
+    return model
+
+  def _has_dequant_before_output(
+      self, instructions: dict[str, qtyping.TensorTransformationInsts]
+  ) -> bool:
+    """Checks if the model has a dequant inserted before an output."""
+    for tensor_name, tensor_trans_insts in instructions.items():
+      for instr in tensor_trans_insts.instructions:
+        if (
+            qtyping.QuantTransformation.ADD_DEQUANTIZE == instr.transformation
+            and instr.consumers == [-1]
+        ):
+          logging.info(
+              "Found dequant insert to output for tensor: %s", tensor_name
+          )
+          return True
+    return False
 
   def _process_constant_map(
       self, quantized_model: schema_py_generated.ModelT
```
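The `2**31 - 2**28` cutoff in `modify_model` reads as: flatbuffers cap a serialized model at 2 GiB, and the comment reserves 256 MiB of that budget for the graph structure, so only the remainder is available for constant buffers. The arithmetic, spelled out:

```python
FLATBUFFER_LIMIT = 2**31       # 2 GiB hard cap on a serialized flatbuffer
ARCHITECTURE_HEADROOM = 2**28  # 256 MiB reserved for the graph structure
CUTOFF = FLATBUFFER_LIMIT - ARCHITECTURE_HEADROOM

assert CUTOFF == 1_879_048_192  # ~1.75 GiB of constant buffers
# Above the cutoff the large-model serializer is chosen; otherwise the
# small-model path suffices.
```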
```diff
@@ -108,7 +232,7 @@ class ModelModifier:
     remainder = len(bytearr) % 16
     if remainder != 0:
       padding_size = 16 - remainder
-      bytearr.extend(b
+      bytearr.extend(b"\0" * padding_size)
 
   # TODO: b/333797307 - support > 2GB output model
   def _serialize_large_model(
```
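A standalone sketch of the padding rule in this hunk, assuming the goal is 16-byte alignment of serialized buffers (the hunk itself only shows the zero-fill):

```python
def pad_to_16(bytearr: bytearray) -> None:
  """Zero-pads a bytearray in place to the next multiple of 16 bytes."""
  remainder = len(bytearr) % 16
  if remainder != 0:
    bytearr.extend(b"\0" * (16 - remainder))

buf = bytearray(b"\x01\x02\x03")
pad_to_16(buf)
assert len(buf) == 16 and buf[3:] == b"\0" * 13
pad_to_16(buf)  # already aligned: length is unchanged
assert len(buf) == 16
```

The `test_pad_bytearray` case added below exercises exactly these three situations: short, already aligned, and one byte over.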
```diff
--- a/ai_edge_quantizer/model_modifier_test.py
+++ b/ai_edge_quantizer/model_modifier_test.py
@@ -19,13 +19,13 @@ import os
 import tracemalloc
 from tensorflow.python.platform import googletest
 from absl.testing import parameterized
+from ai_edge_litert.tools import flatbuffer_utils
 from ai_edge_quantizer import model_modifier
 from ai_edge_quantizer import params_generator
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer import recipe_manager
 from ai_edge_quantizer.utils import test_utils
 from ai_edge_quantizer.utils import tfl_flatbuffer_utils
-from tensorflow.lite.tools import flatbuffer_utils  # pylint: disable=g-direct-tensorflow-import
 
 TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('.')
 
```
```diff
@@ -125,6 +125,86 @@ class ModelModifierTest(parameterized.TestCase):
     loosen_mem_use_factor = 4.5
     self.assertLess(mem_peak / len(self._model_content), loosen_mem_use_factor)
 
+  def test_has_dequant_before_output_true(self):
+    instructions = {
+        'tensor1': qtyping.TensorTransformationInsts(
+            'tensor1',
+            0,
+            instructions=[
+                qtyping.TransformationInst(
+                    transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
+                    tensor_id=0,
+                    producer=0,
+                    consumers=[-1],
+                )
+            ],
+        )
+    }
+    self.assertTrue(
+        self._model_modifier._has_dequant_before_output(instructions)
+    )
+
+  def test_has_dequant_before_output_false(self):
+    instructions = {
+        'tensor1': qtyping.TensorTransformationInsts(
+            'tensor1',
+            0,
+            instructions=[
+                qtyping.TransformationInst(
+                    transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
+                    tensor_id=0,
+                    producer=0,
+                    consumers=[1],
+                )
+            ],
+        )
+    }
+    self.assertFalse(
+        self._model_modifier._has_dequant_before_output(instructions)
+    )
+
+  def test_pad_bytearray(self):
+    arr = bytearray(b'\x01\x02\x03')
+    self._model_modifier._pad_bytearray(arr)
+    self.assertLen(arr, 16)
+    self.assertEqual(arr, b'\x01\x02\x03' + b'\0' * 13)
+
+    arr = bytearray(b'\x01' * 16)
+    self._model_modifier._pad_bytearray(arr)
+    self.assertLen(arr, 16)
+
+    arr = bytearray(b'\x01' * 17)
+    self._model_modifier._pad_bytearray(arr)
+    self.assertLen(arr, 32)
+
+
+class ModelModifierTestWithSignature(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self._model_path = os.path.join(
+        TEST_DATA_PREFIX_PATH,
+        'tests/models/single_fc.tflite',
+    )
+    self._model_content: bytes = tfl_flatbuffer_utils.get_model_content(
+        self._model_path
+    )
+    self._model_modifier = model_modifier.ModelModifier(self._model_content)
+
+  def test_update_signature_defs_for_dequant_output_succeeds(self):
+    # This is a simplified test that only checks if the function runs without
+    # crashing and returns a model. A more thorough test with a model
+    # with a known signature was added in `quantizer_test`.
+    model_bytearray = flatbuffer_utils.read_model_from_bytearray(
+        self._model_content
+    )
+    updated_model = (
+        self._model_modifier._update_signature_defs_for_dequant_output(
+            model_bytearray, bytearray(self._model_content)
+        )
+    )
+    self.assertIsNotNone(updated_model)
+
 
 if __name__ == '__main__':
   googletest.main()
```
```diff
--- a/ai_edge_quantizer/model_validator.py
+++ b/ai_edge_quantizer/model_validator.py
@@ -25,7 +25,7 @@ from typing import Any, Optional, Union
 import numpy as np
 
 from ai_edge_quantizer.utils import tfl_interpreter_utils as utils
-
+import os  # tensorflow.python.platform.gfile  # pylint: disable=g-direct-tensorflow-import
 
 
 _DEFAULT_SIGNATURE_KEY = utils.DEFAULT_SIGNATURE_KEY
```
```diff
@@ -118,7 +118,8 @@ class ComparisonResult:
     for name in utils.get_input_tensor_names(
         self._reference_model, signature_key
     ):
-
+      if name in result:
+        input_tensor_results[name] = result.pop(name)
 
     output_tensor_results = {}
     for name in utils.get_output_tensor_names(
```
```diff
@@ -136,7 +137,8 @@ class ComparisonResult:
         self._reference_model,
         subgraph_index,
     ):
-
+      if name in result:
+        constant_tensor_results[name] = result.pop(name)
 
     self._comparison_results[signature_key] = SingleSignatureComparisonResult(
         error_metric=error_metric,
```
```diff
@@ -160,6 +162,12 @@ class ComparisonResult:
       result.update(signature_comparison_result.intermediate_tensors)
     return result
 
+  def get_model_size_reduction(self) -> tuple[int, float]:
+    """Get the model size reduction in bytes and percentage."""
+    reduced_model_size = len(self._reference_model) - len(self._target_model)
+    reduction_perc = reduced_model_size / len(self._reference_model) * 100
+    return reduced_model_size, reduction_perc
+
   def save(self, save_folder: str, model_name: str) -> None:
     """Saves the model comparison result.
 
```
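Usage of the new `get_model_size_reduction` helper reduces to two `len` calls; a sketch of the same arithmetic on stand-in byte strings:

```python
reference_model = b"\0" * 1_000_000  # stand-in for the float model bytes
target_model = b"\0" * 250_000       # stand-in for the quantized model bytes

reduced_bytes = len(reference_model) - len(target_model)
reduction_pct = reduced_bytes / len(reference_model) * 100
assert (reduced_bytes, reduction_pct) == (750_000, 75.0)
```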
```diff
@@ -170,8 +178,7 @@ class ComparisonResult:
     Raises:
       RuntimeError: If no quantized model is available.
     """
-    reduced_model_size =
-    reduction_ratio = reduced_model_size / len(self._reference_model) * 100
+    reduced_model_size, reduction_ratio = self.get_model_size_reduction()
     result = {
         'reduced_size_bytes': reduced_model_size,
         'reduced_size_percentage': reduction_ratio,
```
```diff
@@ -187,7 +194,7 @@ class ComparisonResult:
     result_save_path = os.path.join(
         save_folder, model_name + '_comparison_result.json'
     )
-    with
+    with open(result_save_path, 'w') as output_file_handle:
       output_file_handle.write(json.dumps(result))
 
     # TODO: b/365578554 - Remove after ME is updated to use the new json format.
```
```diff
@@ -199,7 +206,7 @@ class ComparisonResult:
     json_save_path = os.path.join(
         save_folder, model_name + '_comparison_result_me_input.json'
     )
-    with
+    with open(json_save_path, 'w') as output_file_handle:
       output_file_handle.write(json_object)
 
 
```
```diff
@@ -209,6 +216,7 @@ def _setup_validation_interpreter(
     signature_key: Optional[str],
     use_xnnpack: bool,
     num_threads: int,
+    preserve_all_tensors: bool = True,
 ) -> tuple[Any, int, dict[str, Any]]:
   """Setup the interpreter for validation given a signature key.
 
```
```diff
@@ -219,13 +227,17 @@ def _setup_validation_interpreter(
       model only has one signature, this can be set to None.
     use_xnnpack: Whether to use xnnpack for the interpreter.
     num_threads: The number of threads to use for the interpreter.
+    preserve_all_tensors: Whether to preserve all tensors.
 
   Returns:
     A tuple of interpreter, subgraph_index and tensor_name_to_details.
   """
 
   interpreter = utils.create_tfl_interpreter(
-      tflite_model=model,
+      tflite_model=model,
+      use_xnnpack=use_xnnpack,
+      num_threads=num_threads,
+      preserve_all_tensors=preserve_all_tensors,
   )
   utils.invoke_interpreter_signature(
       interpreter, signature_input, signature_key
```
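For context, reading intermediate tensors requires the interpreter to keep them alive, which is an interpreter-level switch. A hedged sketch of what `utils.create_tfl_interpreter` presumably forwards to, assuming the LiteRT Python interpreter mirrors `tf.lite.Interpreter`'s constructor flags (`model.tflite` is a placeholder path):

```python
from ai_edge_litert import interpreter as tfl

# Hypothetical direct construction; the diff suggests the wrapper now
# threads preserve_all_tensors through to the interpreter.
interp = tfl.Interpreter(
    model_path="model.tflite",               # placeholder path
    num_threads=16,
    experimental_preserve_all_tensors=True,  # keep intermediates readable
)
interp.allocate_tensors()
```

Skipping preservation when only outputs are validated (the `validate_output_tensors_only` path below) avoids the memory overhead of retaining every intermediate buffer.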
```diff
@@ -250,6 +262,7 @@ def compare_model(
     compare_fn: Callable[[Any, Any], float],
     use_xnnpack: bool = True,
     num_threads: int = 16,
+    validate_output_tensors_only: bool = False,
 ) -> ComparisonResult:
   """Compares model tensors over a model signature using the compare_fn.
 
```
```diff
@@ -270,10 +283,13 @@ def compare_model(
       single float value.
     use_xnnpack: Whether to use xnnpack for the interpreter.
     num_threads: The number of threads to use for the interpreter.
+    validate_output_tensors_only: If True, only compare output tensors.
+      Otherwise, compare all tensors.
 
   Returns:
     A ComparisonResult object.
   """
+  preserve_all_tensors = not validate_output_tensors_only
   model_comparion_result = ComparisonResult(reference_model, target_model)
   for signature_key, signature_inputs in test_data.items():
     comparison_results = {}
```
```diff
@@ -286,6 +302,7 @@ def compare_model(
             signature_key,
             use_xnnpack,
             num_threads,
+            preserve_all_tensors=preserve_all_tensors,
         )
     )
     targ_interpreter, targ_subgraph_index, targ_tensor_name_to_details = (
```
```diff
@@ -295,12 +312,23 @@ def compare_model(
             signature_key,
             use_xnnpack,
             num_threads,
+            preserve_all_tensors=preserve_all_tensors,
         )
     )
-    # Compare the cached tensor
-
+      # Compare the cached tensor value
+      tensor_names_to_compare = (
+          utils.get_output_tensor_names(reference_model, signature_key)
+          if validate_output_tensors_only
+          else list(ref_tensor_name_to_details.keys())
+      )
+
+      for tensor_name in tensor_names_to_compare:
+        detail = ref_tensor_name_to_details[tensor_name]
         if detail['dtype'] == np.object_:
           continue
+        # Ignore tensors where any dimension of the shape is 0.
+        if not np.all(detail['shape']):
+          continue
         if tensor_name in targ_tensor_name_to_details:
           if tensor_name not in comparison_results:
             comparison_results[tensor_name] = []
```
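The shape guard added above relies on NumPy truthiness: `np.all` over a shape array is falsy as soon as any dimension is 0, so zero-sized tensors are skipped during comparison:

```python
import numpy as np

assert np.all(np.array([2, 3, 4]))      # all dims nonzero -> compared
assert not np.all(np.array([2, 0, 4]))  # a zero-sized dim -> skipped
```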
```diff
--- a/ai_edge_quantizer/model_validator_test.py
+++ b/ai_edge_quantizer/model_validator_test.py
@@ -21,6 +21,7 @@ from tensorflow.python.platform import googletest
 from ai_edge_quantizer import model_validator
 from ai_edge_quantizer.utils import test_utils
 from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+from ai_edge_quantizer.utils import tfl_interpreter_utils
 from ai_edge_quantizer.utils import validation_utils
 
 TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('.')
```
```diff
@@ -194,7 +195,7 @@ class ModelValidatorCompareTest(googletest.TestCase):
         self.target_model_path
     )
     self.signature_key = 'serving_default'  # single signature.
-    self.test_data =
+    self.test_data = tfl_interpreter_utils.create_random_normal_input_data(
         self.reference_model_path
     )
     self.test_dir = self.create_tempdir()
```