ai-edge-quantizer-nightly 0.5.0.dev20251209__py3-none-any.whl → 0.5.0.dev20251211__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +39 -11
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +20 -0
- ai_edge_quantizer/calibrator.py +0 -5
- ai_edge_quantizer/calibrator_test.py +2 -6
- ai_edge_quantizer/model_validator.py +25 -5
- ai_edge_quantizer/params_generator.py +11 -7
- ai_edge_quantizer/params_generator_test.py +2 -7
- ai_edge_quantizer/quantizer.py +4 -0
- ai_edge_quantizer/quantizer_test.py +40 -0
- ai_edge_quantizer/transformation_instruction_generator.py +9 -0
- ai_edge_quantizer/utils/constrained_ops_utils.py +8 -8
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +4 -1
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +10 -0
- {ai_edge_quantizer_nightly-0.5.0.dev20251209.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20251211.dist-info}/METADATA +1 -1
- {ai_edge_quantizer_nightly-0.5.0.dev20251209.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20251211.dist-info}/RECORD +18 -18
- {ai_edge_quantizer_nightly-0.5.0.dev20251209.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20251211.dist-info}/WHEEL +0 -0
- {ai_edge_quantizer_nightly-0.5.0.dev20251209.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20251211.dist-info}/licenses/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.5.0.dev20251209.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20251211.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py
CHANGED
@@ -36,6 +36,13 @@ _OpQuantConstraint = common_utils.OpQuantConstraint
_ComputePrecision = qtyping.ComputePrecision


+def check_if_quantized(tensor: Any) -> bool:
+  """Checks if the tensor is quantized."""
+  return (
+      tensor.quantization is not None and tensor.quantization.scale is not None
+  )
+
+
def check_op_quantization_config(
    op_name: _TFLOpName,
    op_quant_config: qtyping.OpQuantizationConfig,
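For orientation, the new helper treats a flatbuffer tensor as already quantized whenever its quantization parameters carry a scale. A minimal, self-contained sketch of the same predicate (the stand-in dataclasses below are hypothetical; the real helper operates on tensors produced by `tfl_flatbuffer_utils.read_model`):

```python
# Sketch only: mirrors the check_if_quantized predicate on stand-in objects.
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class FakeQuantization:  # stand-in for the flatbuffer quantization parameters
  scale: Optional[Sequence[float]] = None


@dataclass
class FakeTensor:  # stand-in for a flatbuffer tensor
  quantization: Optional[FakeQuantization] = None


def check_if_quantized(tensor) -> bool:
  """True when the tensor carries quantization parameters with a scale."""
  return tensor.quantization is not None and tensor.quantization.scale is not None


assert not check_if_quantized(FakeTensor())                           # float tensor
assert not check_if_quantized(FakeTensor(FakeQuantization()))         # no scale set
assert check_if_quantized(FakeTensor(FakeQuantization(scale=[0.1])))  # pre-quantized
```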
@@ -271,7 +278,7 @@ def materialize_average_pool_2d(
  )


-def _materialize_bias_for_conv_ops(
+def _materialize_bias_for_fc_conv_ops(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    op_tensor_params: list[qtyping.TensorTransformationParams],
@@ -290,14 +297,16 @@ def _materialize_bias_for_conv_ops(
    op_weight_index: Index for the weight tensor in the op.
    op_bias_index: Index for the bias tensor in the op.
  """
-  _, _, bias_tensor, _ = tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
-      op_info.op,
-      graph_info.subgraph_tensors,
-      op_input_index,
-      op_weight_index,
-      op_bias_index,
+  _, weight_tensor, bias_tensor, _ = (
+      tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
+          op_info.op,
+          graph_info.subgraph_tensors,
+          op_input_index,
+          op_weight_index,
+          op_bias_index,
+      )
  )
-  if bias_tensor is not None:
+  if bias_tensor is not None and not check_if_quantized(bias_tensor):
    bias_quant_params = None
    # Fused bias needs to be quantized for SRQ.
    # Check if SRQ.
@@ -315,6 +324,19 @@ def _materialize_bias_for_conv_ops(
    weight_consumer_params = (
        op_tensor_params[op_weight_index].consumers[0].parameters
    )
+    if weight_consumer_params is None and check_if_quantized(weight_tensor):
+      quant_params = weight_tensor.quantization
+      if op_info.op_quant_config.weight_tensor_config is None:
+        raise ValueError(
+            "weight_tensor_config cannot be None when weight tensor is"
+            " quantized."
+        )
+      weight_consumer_params = qtyping.UniformQuantParams(
+          num_bits=op_info.op_quant_config.weight_tensor_config.num_bits,
+          scale=quant_params.scale,
+          zero_point=quant_params.zeroPoint,
+          quantized_dimension=quant_params.quantizedDimension,
+      )
    try:
      # Bias quantization is using fixed quantization scale:
      # input_scale * weight_scale. To avoid hidden numerics error, we check
@@ -495,7 +517,13 @@ def materialize_fc_conv(
      weights, bias).
  """
  ignored_inputs = [bias_index]  # Bias tensor is quantized separately.
-  if _are_weights_too_small(op_info, graph_info, weight_index):
+  should_ignore_weight = False
+  if graph_info:
+    w_tensor = graph_info.subgraph_tensors[op_info.op.inputs[weight_index]]
+    should_ignore_weight = check_if_quantized(w_tensor)
+  if should_ignore_weight or _are_weights_too_small(
+      op_info, graph_info, weight_index
+  ):
    ignored_inputs.append(weight_index)

  op_tensor_params = common_utils.materialize_standard_op(
@@ -506,7 +534,7 @@ def materialize_fc_conv(
      inputs_to_ignore=ignored_inputs,
  )

-  _materialize_bias_for_conv_ops(
+  _materialize_bias_for_fc_conv_ops(
      op_info,
      graph_info,
      op_tensor_params,
@@ -561,7 +589,7 @@ def materialize_conv2d_transpose(
        "Materialize standard op should return at least two tensors for"
        " conv2d_transpose."
    )
-  _materialize_bias_for_conv_ops(
+  _materialize_bias_for_fc_conv_ops(
      op_info,
      graph_info,
      op_tensor_params,
ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py
CHANGED
@@ -108,6 +108,13 @@ def get_tensor_quant_params(
  return dataclasses.replace(quant_params, quantized_data=quantized_vars)


+def check_if_quantized(tensor: Any) -> bool:
+  """Checks if the tensor is quantized."""
+  return (
+      tensor.quantization is not None and tensor.quantization.scale is not None
+  )
+
+
# TODO: b/333731147 - Use named tuple to store min/max.
def init_qsvs(
    op_info: qtyping.OpInfo,
@@ -129,6 +136,13 @@ def init_qsvs(
  op_qsvs = {}

  inputs_to_ignore = inputs_to_ignore or []
+  quantized_inputs_to_ignore = [
+      opr_idx
+      for opr_idx, tensor_idx in enumerate(op_info.op.inputs)
+      if check_if_quantized(graph_info.subgraph_tensors[tensor_idx])
+  ]
+  inputs_to_ignore.extend(quantized_inputs_to_ignore)
+
  outputs_to_ignore = outputs_to_ignore or []
  for opr_idx, tensor_idx in enumerate(op_info.op.inputs):
    if tensor_idx != -1 and opr_idx not in inputs_to_ignore:
@@ -207,6 +221,12 @@ def min_max_calibrate(
  }

  inputs_to_ignore = inputs_to_ignore or []
+  quantized_inputs_to_ignore = [
+      opr_idx
+      for opr_idx, tensor_idx in enumerate(tfl_op.inputs)
+      if check_if_quantized(graph_info.subgraph_tensors[tensor_idx])
+  ]
+  inputs_to_ignore.extend(quantized_inputs_to_ignore)
  outputs_to_ignore = outputs_to_ignore or []
  for i, tensor_idx in enumerate(tfl_op.inputs):
    if tensor_idx != -1 and i not in inputs_to_ignore:
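Both `init_qsvs` and `min_max_calibrate` now extend `inputs_to_ignore` with operator inputs that are already quantized, so calibration statistics are only collected for float tensors. A small sketch of that filtering pattern in isolation (the `op_inputs` and `subgraph_tensors` names are illustrative, and a `tensor_idx != -1` guard for absent optional inputs is added here for clarity):

```python
# Sketch only: the skip-already-quantized-inputs pattern used by init_qsvs and
# min_max_calibrate in this diff (names are illustrative).
def extend_ignored_inputs(op_inputs, subgraph_tensors, inputs_to_ignore=None):
  """Returns operand indices to ignore, including pre-quantized inputs."""
  inputs_to_ignore = list(inputs_to_ignore or [])
  for opr_idx, tensor_idx in enumerate(op_inputs):
    if tensor_idx == -1:  # optional input that is not present
      continue
    tensor = subgraph_tensors[tensor_idx]
    if tensor.quantization is not None and tensor.quantization.scale is not None:
      inputs_to_ignore.append(opr_idx)
  return inputs_to_ignore
```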
ai_edge_quantizer/calibrator.py
CHANGED
@@ -46,11 +46,6 @@ class Calibrator:
  ):
    self._flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)

-    if not tfl_flatbuffer_utils.is_float_model(self._flatbuffer_model):
-      raise ValueError(
-          "The input model for calibration is not a float model. Please check"
-          " the model (e.g., if it is already quantized)."
-      )
    self._tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
        float_tflite, use_xnnpack=True, num_threads=num_threads
    )
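With the float-model guard removed, a `Calibrator` can now be constructed directly on a model that was quantized in an earlier pass; the test change below exercises exactly that. A hedged usage sketch (the model path is a placeholder):

```python
# Usage sketch (placeholder path): constructing a Calibrator on an
# already-quantized model no longer raises ValueError after this change.
from ai_edge_quantizer import calibrator

calib = calibrator.Calibrator('tests/models/mnist_quantized.tflite')
```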
ai_edge_quantizer/calibrator_test.py
CHANGED
@@ -184,15 +184,11 @@ class CalibratorAlreadyQuantizedModelTest(googletest.TestCase):
    )
    _ = calibrator.Calibrator(test_model_path)

-  def
+  def test_check_is_quantized_model_succeeds_when_model_is_quantized(self):
    test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, "tests/models/mnist_quantized.tflite"
    )
-    with self.assertRaisesRegex(
-        ValueError,
-        "The input model for calibration is not a float model.",
-    ):
-      _ = calibrator.Calibrator(test_model_path)
+    _ = calibrator.Calibrator(test_model_path)


class CalibratorToyGemma2Test(googletest.TestCase):
ai_edge_quantizer/model_validator.py
CHANGED
@@ -118,7 +118,8 @@ class ComparisonResult:
    for name in utils.get_input_tensor_names(
        self._reference_model, signature_key
    ):
-      input_tensor_results[name] = result.pop(name)
+      if name in result:
+        input_tensor_results[name] = result.pop(name)

    output_tensor_results = {}
    for name in utils.get_output_tensor_names(
@@ -136,7 +137,8 @@ class ComparisonResult:
        self._reference_model,
        subgraph_index,
    ):
-      constant_tensor_results[name] = result.pop(name)
+      if name in result:
+        constant_tensor_results[name] = result.pop(name)

    self._comparison_results[signature_key] = SingleSignatureComparisonResult(
        error_metric=error_metric,
@@ -214,6 +216,7 @@ def _setup_validation_interpreter(
    signature_key: Optional[str],
    use_xnnpack: bool,
    num_threads: int,
+    preserve_all_tensors: bool = True,
) -> tuple[Any, int, dict[str, Any]]:
  """Setup the interpreter for validation given a signature key.

@@ -224,13 +227,17 @@ def _setup_validation_interpreter(
      model only has one signature, this can be set to None.
    use_xnnpack: Whether to use xnnpack for the interpreter.
    num_threads: The number of threads to use for the interpreter.
+    preserve_all_tensors: Whether to preserve all tensors.

  Returns:
    A tuple of interpreter, subgraph_index and tensor_name_to_details.
  """

  interpreter = utils.create_tfl_interpreter(
-      tflite_model=model, use_xnnpack=use_xnnpack, num_threads=num_threads
+      tflite_model=model,
+      use_xnnpack=use_xnnpack,
+      num_threads=num_threads,
+      preserve_all_tensors=preserve_all_tensors,
  )
  utils.invoke_interpreter_signature(
      interpreter, signature_input, signature_key
@@ -255,6 +262,7 @@ def compare_model(
    compare_fn: Callable[[Any, Any], float],
    use_xnnpack: bool = True,
    num_threads: int = 16,
+    validate_output_tensors_only: bool = False,
) -> ComparisonResult:
  """Compares model tensors over a model signature using the compare_fn.

@@ -275,10 +283,13 @@ def compare_model(
      single float value.
    use_xnnpack: Whether to use xnnpack for the interpreter.
    num_threads: The number of threads to use for the interpreter.
+    validate_output_tensors_only: If True, only compare output tensors.
+      Otherwise, compare all tensors.

  Returns:
    A ComparisonResult object.
  """
+  preserve_all_tensors = not validate_output_tensors_only
  model_comparion_result = ComparisonResult(reference_model, target_model)
  for signature_key, signature_inputs in test_data.items():
    comparison_results = {}
@@ -291,6 +302,7 @@ def compare_model(
            signature_key,
            use_xnnpack,
            num_threads,
+            preserve_all_tensors=preserve_all_tensors,
        )
    )
    targ_interpreter, targ_subgraph_index, targ_tensor_name_to_details = (

@@ -300,10 +312,18 @@ def compare_model(
            signature_key,
            use_xnnpack,
            num_threads,
+            preserve_all_tensors=preserve_all_tensors,
        )
    )
-    # Compare the cached tensor
-    for detail in ref_tensor_name_to_details.values():
+    # Compare the cached tensor value
+    tensor_names_to_compare = (
+        utils.get_output_tensor_names(reference_model, signature_key)
+        if validate_output_tensors_only
+        else list(ref_tensor_name_to_details.keys())
+    )
+
+    for tensor_name in tensor_names_to_compare:
+      detail = ref_tensor_name_to_details[tensor_name]
      if detail['dtype'] == np.object_:
        continue
      # Ignore tensors where any dimension of the shape is 0.
ai_edge_quantizer/params_generator.py
CHANGED
@@ -35,12 +35,12 @@ class ParamsGenerator:
  def __init__(self, float_tflite: Union[str, bytes]):
    self.flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)

-    if not tfl_flatbuffer_utils.is_float_model(self.flatbuffer_model):
-      raise ValueError(
-          'The input model for quantization parameters generation is not a'
-          ' float model. Please check the model (e.g., if it is already'
-          ' quantized).'
-      )
+    # if not tfl_flatbuffer_utils.is_float_model(self.flatbuffer_model):
+    #   raise ValueError(
+    #       'The input model for quantization parameters generation is not a'
+    #       ' float model. Please check the model (e.g., if it is already'
+    #       ' quantized).'
+    #   )
    self._check_tensor_names_are_unique()
    self.buffer_to_tensors: dict[int, list[Any]] = (
        tfl_flatbuffer_utils.buffer_to_tensors(self.flatbuffer_model)
@@ -409,7 +409,11 @@ class ParamsGenerator:
    buffers_to_duplicate = []
    tensor_names_to_duplicate = []
    for buffer_idx, tensors in self.buffer_to_tensors.items():
-      if not tensors:
+      # TODO: b/458797890 - Investigate if skipping buffer_idx == 0 is a
+      # correct fix, or if it just covers up a deeper issue. This is only
+      # required when statically quantizing models that have already been
+      # quantized dynamically.
+      if not tensors or buffer_idx == 0:
        continue
      # Check if any of the tensors needs to be duplicated.
      for tensor in tensors:
ai_edge_quantizer/params_generator_test.py
CHANGED
@@ -1135,16 +1135,11 @@ class ParamsGeneratorAlreadyQuantizedModelTest(googletest.TestCase):
    )
    _ = params_generator.ParamsGenerator(test_model_path)

-  def
+  def test_check_is_quantized_model_succeeds_when_model_is_quantized(self):
    test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, 'tests/models/mnist_quantized.tflite'
    )
-    with self.assertRaisesRegex(
-        ValueError,
-        'The input model for quantization parameters generation is not a float'
-        ' model.',
-    ):
-      _ = params_generator.ParamsGenerator(test_model_path)
+    _ = params_generator.ParamsGenerator(test_model_path)


if __name__ == '__main__':
ai_edge_quantizer/quantizer.py
CHANGED
@@ -434,6 +434,7 @@ class Quantizer:
      error_metrics: str = 'mse',
      use_xnnpack: bool = True,
      num_threads: int = 16,
+      validate_output_tensors_only: bool = False,
  ) -> model_validator.ComparisonResult:
    """Numerical validation of the quantized model for a model signature.

@@ -452,6 +453,8 @@ class Quantizer:
      error_metrics: Error metrics to be used for comparison.
      use_xnnpack: Whether to use the xnnpack library for validation.
      num_threads: Number of threads to use for validation.
+      validate_output_tensors_only: If True, only compare output tensors.
+        Otherwise, compare all tensors.

    Returns:
      The comparison result.

@@ -476,6 +479,7 @@ class Quantizer:
        validation_utils.get_validation_func(error_metrics),
        use_xnnpack=use_xnnpack,
        num_threads=num_threads,
+        validate_output_tensors_only=validate_output_tensors_only,
    )

  def _get_quantization_params(
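`Quantizer.validate` forwards the new flag to `model_validator.compare_model`, so callers can restrict the comparison to output tensors and skip preserving every intermediate tensor in the validation interpreters. A hedged usage sketch, assuming a recipe-based flow like the tests below use (model and recipe paths are placeholders):

```python
# Usage sketch (placeholder paths): compare only the output tensors.
from ai_edge_quantizer import quantizer

qt = quantizer.Quantizer('float_model.tflite')   # hypothetical float model
qt.load_quantization_recipe('recipe.json')       # hypothetical recipe
qt.quantize()
result = qt.validate(validate_output_tensors_only=True)
# Only output tensors are compared; the input/constant/intermediate buckets of
# the per-signature comparison result stay empty.
signature_result = result.get_signature_comparison_result()
print(signature_result.output_tensors)
```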
ai_edge_quantizer/quantizer_test.py
CHANGED
@@ -375,6 +375,19 @@ class QuantizerTest(parameterized.TestCase):
        'sequential/dense_1/MatMul', validation_result.intermediate_tensors
    )

+  def test_validate_output_tensors_only_succeeds(self):
+    self._quantizer.quantize()
+    validation_result = self._quantizer.validate(
+        validate_output_tensors_only=True
+    )
+    validation_result = validation_result.get_signature_comparison_result()
+    self.assertIsNotNone(validation_result)
+    self.assertEmpty(validation_result.input_tensors)
+    self.assertEmpty(validation_result.constant_tensors)
+    self.assertEmpty(validation_result.intermediate_tensors)
+    self.assertNotEmpty(validation_result.output_tensors)
+    self.assertIn('StatefulPartitionedCall:0', validation_result.output_tensors)
+
  def test_validate_with_quantized_model_arg_succeeds(self):
    self._quantizer.quantize()
    quantized_model = self._quantizer._result.quantized_model
@@ -431,6 +444,33 @@ class QuantizerTest(parameterized.TestCase):
        op_config=test_op_config,
    )

+  def test_two_pass_quantization_with_conv_and_fc_succeeds(self):
+    float_model_path = self._test_model_path
+
+    drq_recipe_path = os.path.join(
+        TEST_DATA_PREFIX_PATH, 'recipes/dynamic_wi8_afp32_hadamard_recipe.json'
+    )
+    drq_quantizer = quantizer.Quantizer(float_model_path)
+    drq_quantizer.load_quantization_recipe(drq_recipe_path)
+    drq_result = drq_quantizer.quantize()
+    drq_model_path = os.path.join(self._tmp_save_path, 'drq_model.tflite')
+    drq_result.export_model(drq_model_path)
+
+    srq_recipe_path = os.path.join(
+        TEST_DATA_PREFIX_PATH, 'recipes/default_a8w8_recipe.json'
+    )
+    srq_quantizer = quantizer.Quantizer(drq_model_path)
+    srq_quantizer.load_quantization_recipe(srq_recipe_path)
+    representative_dataset = (
+        tfl_interpreter_utils.create_random_normal_input_data(
+            drq_model_path, num_samples=1
+        )
+    )
+    calibration_result = srq_quantizer.calibrate(representative_dataset)
+    srq_result = srq_quantizer.quantize(calibration_result)
+    srq_model_path = os.path.join(self._tmp_save_path, 'srq_model.tflite')
+    srq_result.export_model(srq_model_path)
+

class QuantizerBytearrayInputs(googletest.TestCase):
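The new test exercises the flow this release enables end to end: the model is first quantized dynamically (DRQ), exported, and the exported model is then calibrated and statically quantized (SRQ) in a second pass. Condensed into a hedged usage sketch (paths are placeholders; the recipe names are the ones the test loads from its test-data directory):

```python
# Sketch only: two-pass quantization, condensed from the test above.
from ai_edge_quantizer import quantizer
from ai_edge_quantizer.utils import tfl_interpreter_utils

# Pass 1: dynamic-range quantization of the float model.
drq = quantizer.Quantizer('float_model.tflite')  # placeholder path
drq.load_quantization_recipe('dynamic_wi8_afp32_hadamard_recipe.json')
drq.quantize().export_model('drq_model.tflite')

# Pass 2: calibrate and statically quantize the already-quantized model.
srq = quantizer.Quantizer('drq_model.tflite')
srq.load_quantization_recipe('default_a8w8_recipe.json')
data = tfl_interpreter_utils.create_random_normal_input_data(
    'drq_model.tflite', num_samples=1
)
srq.quantize(srq.calibrate(data)).export_model('srq_model.tflite')
```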
ai_edge_quantizer/transformation_instruction_generator.py
CHANGED
@@ -54,6 +54,15 @@ def check_horizontal_optimization(
  Returns:
    True if the two transformations can be merged, False otherwise
  """
+  if (
+      isinstance(param1.parameters, qtyping.UniformQuantParams)
+      and param1.parameters.hadamard is not None
+  ):
+    if (
+        isinstance(param2.parameters, qtyping.UniformQuantParams)
+        and param2.parameters.hadamard is not None
+    ):
+      return True
  return (
      param1.parameters == param2.parameters
      and len(param1.transformations) > index
ai_edge_quantizer/utils/constrained_ops_utils.py
CHANGED
@@ -60,22 +60,22 @@ def get_constrained_op_list(
  def are_weights_too_small_wrapper(*_args, **_kwargs) -> bool:
    return False

-  # Dummy implementation of the `_materialize_bias_for_conv_ops` function to
+  # Dummy implementation of the `_materialize_bias_for_fc_conv_ops` function to
  # support `materialize_standard_op_wrapper` above.
-  def materialize_bias_for_conv_ops_wrapper(*_args, **_kwargs):
+  def materialize_bias_for_fc_conv_ops_wrapper(*_args, **_kwargs):
    return

  # Do monkey patch to intercept the `materialize_standard_op` function to
  # support `materialize_standard_op_wrapper` above.
  original_materialize_standard_op = common_utils.materialize_standard_op
  original_are_weights_too_small = common_quantize._are_weights_too_small  # pylint: disable=protected-access
-  original_materialize_bias_for_conv_ops = (
-      common_quantize._materialize_bias_for_conv_ops  # pylint: disable=protected-access
+  original_materialize_bias_for_fc_conv_ops = (
+      common_quantize._materialize_bias_for_fc_conv_ops  # pylint: disable=protected-access
  )
  common_utils.materialize_standard_op = materialize_standard_op_wrapper
  common_quantize._are_weights_too_small = are_weights_too_small_wrapper  # pylint: disable=protected-access
-  common_quantize._materialize_bias_for_conv_ops = (  # pylint: disable=protected-access
-      materialize_bias_for_conv_ops_wrapper
+  common_quantize._materialize_bias_for_fc_conv_ops = (  # pylint: disable=protected-access
+      materialize_bias_for_fc_conv_ops_wrapper
  )
  minmax_func_dict = algorithm_manager.MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT

@@ -105,7 +105,7 @@ def get_constrained_op_list(
  # Restore the original functions.
  common_utils.materialize_standard_op = original_materialize_standard_op
  common_quantize._are_weights_too_small = original_are_weights_too_small  # pylint: disable=protected-access
-  common_quantize._materialize_bias_for_conv_ops = (  # pylint: disable=protected-access
-      original_materialize_bias_for_conv_ops
+  common_quantize._materialize_bias_for_fc_conv_ops = (  # pylint: disable=protected-access
+      original_materialize_bias_for_fc_conv_ops
  )
  return constrained_ops
ai_edge_quantizer/utils/tfl_interpreter_utils.py
CHANGED
@@ -35,6 +35,7 @@ def create_tfl_interpreter(
    allocate_tensors: bool = True,
    use_xnnpack: bool = True,
    num_threads: int = 16,
+    preserve_all_tensors: bool = True,
) -> tfl.Interpreter:
  """Creates a TFLite interpreter from a model file.

@@ -43,6 +44,8 @@ def create_tfl_interpreter(
    allocate_tensors: Whether to allocate tensors.
    use_xnnpack: Whether to use the XNNPACK delegate for the interpreter.
    num_threads: The number of threads to use for the interpreter.
+    preserve_all_tensors: Whether to preserve all tensors. If False, only input
+      and output tensors are preserved.

  Returns:
    A TFLite interpreter.

@@ -59,7 +62,7 @@ def create_tfl_interpreter(
        model_content=bytes(tflite_model),
        num_threads=num_threads,
        experimental_op_resolver_type=op_resolver,
-        experimental_preserve_all_tensors=True,
+        experimental_preserve_all_tensors=preserve_all_tensors,
    )
  if allocate_tensors:
    tflite_interpreter.allocate_tensors()
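`create_tfl_interpreter` now exposes `preserve_all_tensors` (default `True`, keeping the previous behavior); passing `False` builds a lighter interpreter when only input and output tensors need to be read back, which is what output-only validation relies on. A hedged usage sketch with a placeholder model path:

```python
# Usage sketch (placeholder path): keep only input/output tensors instead of
# every intermediate tensor.
from ai_edge_quantizer.utils import tfl_interpreter_utils

interpreter = tfl_interpreter_utils.create_tfl_interpreter(
    'model.tflite', preserve_all_tensors=False
)
# Reading intermediate tensor contents is expected to fail in this mode, as the
# tfl_interpreter_utils_test change below demonstrates ("Tensor data is null.").
```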
ai_edge_quantizer/utils/tfl_interpreter_utils_test.py
CHANGED
@@ -90,6 +90,16 @@ class TflUtilsSingleSignatureModelTest(googletest.TestCase):
    ]
    self.assertEqual(tuple(average_pool_res.shape), (1, 14, 14, 8))

+  def test_get_tensor_name_to_content_map_fails_no_preserve_all_tensors(self):
+    tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
+        self._test_model_path, preserve_all_tensors=False
+    )
+    tfl_interpreter_utils.invoke_interpreter_once(
+        tfl_interpreter, [self._input_data]
+    )
+    with self.assertRaisesRegex(ValueError, "Tensor data is null."):
+      tfl_interpreter_utils.get_tensor_name_to_content_map(tfl_interpreter)
+
  def test_is_tensor_quantized(self):
    tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
        self._test_model_path
{ai_edge_quantizer_nightly-0.5.0.dev20251209.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20251211.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
Metadata-Version: 2.4
Name: ai-edge-quantizer-nightly
-Version: 0.5.0.dev20251209
+Version: 0.5.0.dev20251211
Summary: A quantizer for advanced developers to quantize converted AI Edge models.
Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
{ai_edge_quantizer_nightly-0.5.0.dev20251209.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20251211.dist-info}/RECORD
CHANGED
@@ -2,24 +2,24 @@ ai_edge_quantizer/__init__.py,sha256=4pFSkukSwahYyzwqia0yPRyz8TnFQfGRthVJhYpMWas
ai_edge_quantizer/algorithm_manager.py,sha256=0jSNITKl0Ge1XeYKueOUj9brlS4B5ZcdcVQ1kZS3JKg,16518
ai_edge_quantizer/algorithm_manager_api.py,sha256=u903TG0s1uIDhJqfeJne3CFl8A93phZrwgV2-hwdcXU,9247
ai_edge_quantizer/algorithm_manager_api_test.py,sha256=w6bSONvXkX6bzXAGc0-7b6gNDt9oz9ieq97KP8Sg_JU,7666
-ai_edge_quantizer/calibrator.py,sha256=
-ai_edge_quantizer/calibrator_test.py,sha256=
+ai_edge_quantizer/calibrator.py,sha256=nkHUmxdWy16Vw3EOD3B_7EkGiX8V-XJRXXFynweGfG8,9744
+ai_edge_quantizer/calibrator_test.py,sha256=c2ZCjl7PQYU9KtAovpDO9JX8sClgaLGO0P7oqoL6rP0,8830
ai_edge_quantizer/conftest.py,sha256=SxCz-5LlRD_lQm4hQc4c6IGG7DS8d7IyEWY9gnscPN0,794
ai_edge_quantizer/default_policy.py,sha256=YcwwtVzoWUhjYgMtJ7b9f647740lURKteDOeJvwe17o,11384
ai_edge_quantizer/model_modifier.py,sha256=U70JByv6CItP8tg4bdyMfX-R3UlwylAGSviZkF_FSAM,10468
ai_edge_quantizer/model_modifier_test.py,sha256=CV4pgMEQkBJr_qbYR720TO8HBCutbEYLHptDHgdQMUE,7274
-ai_edge_quantizer/model_validator.py,sha256=
+ai_edge_quantizer/model_validator.py,sha256=HCXl8lu8wRmLn6wUaEm3I7xDOul3s7VC6XzbKjGfkuU,13945
ai_edge_quantizer/model_validator_test.py,sha256=EeqOP_mrZsnZ3rug756s0ryDDqd2KgIDld5Lm_gDuWY,13020
-ai_edge_quantizer/params_generator.py,sha256=
-ai_edge_quantizer/params_generator_test.py,sha256=
+ai_edge_quantizer/params_generator.py,sha256=-tbXB6crutiFhmLFEMe_-sxGylsvgd_cRZQ2fB67bNE,20436
+ai_edge_quantizer/params_generator_test.py,sha256=gJlq_qCPC0dWkbkyCpQiqAsmCYoWYxtxM2xYMEkrr3g,40436
ai_edge_quantizer/qtyping.py,sha256=y9KretGzUGztyLdmto2XV6U0cxrSrfLWP1UOVcwR4dY,18011
-ai_edge_quantizer/quantizer.py,sha256=
-ai_edge_quantizer/quantizer_test.py,sha256=
+ai_edge_quantizer/quantizer.py,sha256=_XRzj1UTXoPa0AeE1Ygz6XAelst2p2fGLqrhYB5MOCg,19150
+ai_edge_quantizer/quantizer_test.py,sha256=6gcOLsZO-XW9VoKmcf_9CalG-_2lSUAe_fcmH2zHcoU,30167
ai_edge_quantizer/recipe.py,sha256=MEkfQ2Sg3KAE9LAORHWcbjYNPg06EUbwc1d-VspQA2U,6461
ai_edge_quantizer/recipe_manager.py,sha256=6l2uq8KL23KLu9OQDmPGkxrFiwHrdDB9xnn-ni8WdEM,15036
ai_edge_quantizer/recipe_manager_test.py,sha256=gYK3haUJ8-AISQvTI6tD-E-drJXQPSXPqBZdgpc5QTo,36595
ai_edge_quantizer/recipe_test.py,sha256=QisyaTol8JRZFcGOGyee7QRCvqj5VbF4guKWdIoMUOE,6213
-ai_edge_quantizer/transformation_instruction_generator.py,sha256=
+ai_edge_quantizer/transformation_instruction_generator.py,sha256=YmjtOFqc4ajGzvHEWTyIUIom0I0uJtxt4Uc9nxzmw2A,31852
ai_edge_quantizer/transformation_instruction_generator_test.py,sha256=KW5-WoTTo9IqLEVnWxVC8ut8eWLi_91xfKgGqVQ9QDk,54635
ai_edge_quantizer/transformation_performer.py,sha256=mFsig0E5Isy7cnG1wMO2jzBn3Wql8fElM_PSpaL8okw,13354
ai_edge_quantizer/transformation_performer_test.py,sha256=xk6A3LStCyPclN51--9uO7XjSxNfZmpdfvrzOL0maNM,20349

@@ -28,7 +28,7 @@ ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py,sha256=lpq1g2ayg3lCP
ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py,sha256=Bs9CK7wZAw6jNaZ8xEtbwO2vM34VYXNZSMVWvxJo9nw,9297
ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py,sha256=EqIHGEZ1LgUrTN7zf880RuAzEv3Qy7kgh5ivObJGHSo,22646
ai_edge_quantizer/algorithms/uniform_quantize/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
-ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=
+ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=s4JudZaYZlL5PwdfjKV-HcbaVSzVcXXueNFdBxZDv9I,41033
ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py,sha256=GGf_n3wIeg3GB_eGsmyNJ0fTcxgpeMMbugTMRONK6TQ,3553
ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py,sha256=VjBDxGxjITHJc7xJABqBbZt6_qhobtZAl2gnVQrYJgc,8652
ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py,sha256=sT5eX5TLZEHTtPfnSkCPDlS0sQxlTFWbCsbvOuj--yY,8889

@@ -36,7 +36,7 @@ ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py,sha256=qxt9CP
ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py,sha256=1ejj5WS3GZwFk3qpsPiPS8jcmVS1-e7zRmvj2Nj8fKw,15440
ai_edge_quantizer/algorithms/uniform_quantize/mse.py,sha256=EP5yPw6khAhTo6VNTPXEE2aGKLfNnqz8COeJnTKaGWs,4641
ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py,sha256=-E1LIlxadckspltdgBWTiUzsiwbawSubndavHhWLt1g,7145
-ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py,sha256=
+ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py,sha256=NCLKwM8Teu2yI-Qd36e8KfqZWIqtHeAg_gMD7Z_sqNE,8988
ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py,sha256=Eqa4OUqoCGywbHz-HxJ9dWRj9BKlVzJPuIhVzvrpdLM,8925
ai_edge_quantizer/algorithms/uniform_quantize/octav.py,sha256=-n-QZyp9y8WCy5FPSpXZXHfOA-p-RLvfSaCzAfhHiHI,7040
ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py,sha256=6m2U-9JdNei0XzOORg2gt87TJdD0XHZ-z5h9c4g_TB4,9120

@@ -65,17 +65,17 @@ ai_edge_quantizer/transformations/transformation_utils_test.py,sha256=MWgq29t7rv
ai_edge_quantizer/utils/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
ai_edge_quantizer/utils/calibration_utils.py,sha256=iMf_bSCf-O86MzDt5D9hLKqbTydqLwirluaC6BJ9yHo,11553
ai_edge_quantizer/utils/calibration_utils_test.py,sha256=4BlksXl7b4yptL8xPR67hmJCnjhN9V10a2PunzfHrUE,9372
-ai_edge_quantizer/utils/constrained_ops_utils.py,sha256=
+ai_edge_quantizer/utils/constrained_ops_utils.py,sha256=z0sm1R9anRRVgdgI23XQKwDRcdARdpTo_6UBDB_lHXE,4502
ai_edge_quantizer/utils/constrained_ops_utils_test.py,sha256=i_uERo-KvMj0dvUSuI67kdOBHvRQETg8-qnejs_MgTE,1756
ai_edge_quantizer/utils/test_utils.py,sha256=a4Nk-wbeB09dFjTDZiA0K67d26j5DD0UDH_GIVmVG_4,8685
ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=42OWzQsRTXq3XQYmoxlz177_dw2fJfq7mDSJaU--ArQ,12076
ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py,sha256=K1SbK8q92qYVtiVj0I0GtugsPTkpIpEKv9zakvFV_Sc,8555
-ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=
-ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=
+ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=zgXVSIoNU-M2V1Wcq06M0MPoA-dCXXEZd1Y9vvors_c,15100
+ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=EPOXbmXqbt3tAewo3BQQjh2mjuxrrFit5tkF0wUVYHU,12471
ai_edge_quantizer/utils/validation_utils.py,sha256=Mr0D6X-pTDLODFAnCX3IlqdV1OL02tlq0ZjHbqx8nzg,7439
ai_edge_quantizer/utils/validation_utils_test.py,sha256=T8K5mCWeMcihND2KS_dHvCJUU9lEdG2sD95EgPkaX3w,5584
-ai_edge_quantizer_nightly-0.5.0.
-ai_edge_quantizer_nightly-0.5.0.
-ai_edge_quantizer_nightly-0.5.0.
-ai_edge_quantizer_nightly-0.5.0.
-ai_edge_quantizer_nightly-0.5.0.
+ai_edge_quantizer_nightly-0.5.0.dev20251211.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ai_edge_quantizer_nightly-0.5.0.dev20251211.dist-info/METADATA,sha256=lvKdgf6OsKuQ_r5s7ZV8Zdj64RJxlwux0Eiyo0Ao0KI,1707
+ai_edge_quantizer_nightly-0.5.0.dev20251211.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
+ai_edge_quantizer_nightly-0.5.0.dev20251211.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
+ai_edge_quantizer_nightly-0.5.0.dev20251211.dist-info/RECORD,,
File without changes

File without changes