ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
- ai_edge_quantizer/algorithm_manager.py +224 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
- ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
- ai_edge_quantizer/calibrator.py +58 -94
- ai_edge_quantizer/calibrator_test.py +5 -74
- ai_edge_quantizer/default_policy.py +108 -16
- ai_edge_quantizer/model_modifier.py +132 -8
- ai_edge_quantizer/model_modifier_test.py +81 -1
- ai_edge_quantizer/model_validator.py +38 -10
- ai_edge_quantizer/model_validator_test.py +2 -1
- ai_edge_quantizer/params_generator.py +230 -47
- ai_edge_quantizer/params_generator_test.py +366 -261
- ai_edge_quantizer/qtyping.py +92 -6
- ai_edge_quantizer/quantizer.py +167 -23
- ai_edge_quantizer/quantizer_test.py +288 -26
- ai_edge_quantizer/recipe.py +156 -21
- ai_edge_quantizer/recipe_manager.py +158 -1
- ai_edge_quantizer/recipe_manager_test.py +146 -32
- ai_edge_quantizer/recipe_test.py +93 -17
- ai_edge_quantizer/transformation_instruction_generator.py +313 -46
- ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
- ai_edge_quantizer/transformation_performer.py +112 -58
- ai_edge_quantizer/transformation_performer_test.py +176 -4
- ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
- ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
- ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
- ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
- ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
- ai_edge_quantizer/transformations/transformation_utils.py +157 -11
- ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
- ai_edge_quantizer/utils/calibration_utils.py +263 -1
- ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
- ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
- ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
- ai_edge_quantizer/utils/test_utils.py +191 -58
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
- ai_edge_quantizer/utils/validation_utils.py +114 -4
- ai_edge_quantizer/utils/validation_utils_test.py +80 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
- ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
- ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
- ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/quantizer_test.py
CHANGED

@@ -51,6 +51,30 @@ def _get_calibration_data(num_samples: int = 16):
   return calibration_data


+def _is_all_signature_defs_inputs_float(model_content: bytes):
+  tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(model_content)
+  for signature_key in tfl_interpreter.get_signature_list():
+    input_details = tfl_interpreter.get_signature_runner(
+        signature_key
+    ).get_input_details()
+    for tensor_details in input_details.values():
+      if tensor_details['dtype'] != np.float32:
+        return False
+  return True
+
+
+def _is_all_signature_defs_outputs_float(model_content: bytes):
+  tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(model_content)
+  for signature_key in tfl_interpreter.get_signature_list():
+    output_details = tfl_interpreter.get_signature_runner(
+        signature_key
+    ).get_output_details()
+    for tensor_details in output_details.values():
+      if tensor_details['dtype'] != np.float32:
+        return False
+  return True
+
+
 class QuantizerTest(parameterized.TestCase):

   def setUp(self):
@@ -92,6 +116,76 @@ class QuantizerTest(parameterized.TestCase):
         new_op_config.compute_precision,
     )

+  def test_add_dynamic_config_succeeds(self):
+    self._quantizer.load_quantization_recipe(self._test_recipe_path)
+    scope_regex = '.*/Dense/.*'
+    self._quantizer.add_dynamic_config(
+        regex=scope_regex,
+        operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
+        num_bits=8,
+    )
+    updated_recipe = self._quantizer.get_quantization_recipe()
+    self.assertLen(updated_recipe, 2)
+
+    added_config = updated_recipe[-1]
+    self.assertEqual(added_config['regex'], scope_regex)
+    self.assertEqual(
+        added_config['op_config']['compute_precision'],
+        qtyping.ComputePrecision.INTEGER,
+    )
+    self.assertFalse(added_config['op_config']['explicit_dequantize'])
+    self.assertEqual(
+        added_config['op_config']['weight_tensor_config']['num_bits'], 8
+    )
+
+  def test_add_weight_only_config_succeeds(self):
+    self._quantizer.load_quantization_recipe(self._test_recipe_path)
+    scope_regex = '.*/Dense/.*'
+    self._quantizer.add_weight_only_config(
+        regex=scope_regex,
+        operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
+        num_bits=4,
+    )
+    updated_recipe = self._quantizer.get_quantization_recipe()
+    self.assertLen(updated_recipe, 2)
+
+    added_config = updated_recipe[-1]
+    self.assertEqual(added_config['regex'], scope_regex)
+    self.assertEqual(
+        added_config['op_config']['compute_precision'],
+        qtyping.ComputePrecision.FLOAT,
+    )
+    self.assertTrue(added_config['op_config']['explicit_dequantize'])
+    self.assertEqual(
+        added_config['op_config']['weight_tensor_config']['num_bits'], 4
+    )
+
+  def test_add_static_config_succeeds(self):
+    self._quantizer.load_quantization_recipe(self._test_recipe_path)
+    scope_regex = '.*/Dense/.*'
+    self._quantizer.add_static_config(
+        regex=scope_regex,
+        operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
+        activation_num_bits=8,
+        weight_num_bits=4,
+    )
+    updated_recipe = self._quantizer.get_quantization_recipe()
+    self.assertLen(updated_recipe, 2)
+
+    added_config = updated_recipe[-1]
+    self.assertEqual(added_config['regex'], scope_regex)
+    self.assertEqual(
+        added_config['op_config']['compute_precision'],
+        qtyping.ComputePrecision.INTEGER,
+    )
+    self.assertFalse(added_config['op_config']['explicit_dequantize'])
+    self.assertEqual(
+        added_config['op_config']['activation_tensor_config']['num_bits'], 8
+    )
+    self.assertEqual(
+        added_config['op_config']['weight_tensor_config']['num_bits'], 4
+    )
+
   def test_load_quantization_recipe_succeeds(self):
     qt = quantizer.Quantizer(self._test_model_path, None)
     qt.load_quantization_recipe(self._test_recipe_path)
@@ -118,7 +212,7 @@ class QuantizerTest(parameterized.TestCase):
     # Calibrate with empty state.
     calib_data = _get_calibration_data()
     calibration_result = self._quantizer.calibrate(calib_data)
-    self.assertLen(calibration_result,
+    self.assertLen(calibration_result, 7)

   @parameterized.parameters(
       'recipes/default_a8w8_recipe.json',
@@ -133,7 +227,7 @@ class QuantizerTest(parameterized.TestCase):
     updated_calibration_result = self._quantizer.calibrate(
         calib_data, previous_calibration_result=calibration_result
     )
-    self.assertLen(updated_calibration_result,
+    self.assertLen(updated_calibration_result, 7)
     self.assertNotEqual(
         calibration_result['StatefulPartitionedCall:0'],
         updated_calibration_result['StatefulPartitionedCall:0'],
@@ -215,6 +309,44 @@ class QuantizerTest(parameterized.TestCase):
       saved_recipe = json.load(json_file)
     self.assertEqual(saved_recipe, self._test_recipe)

+  def test_saved_legacy_recipe_lacks_block_size(self):
+    model_name = 'test_model'
+    legacy_recipe_path = os.path.join(
+        TEST_DATA_PREFIX_PATH,
+        'recipes/dynamic_legacy_wi8_afp32_recipe.json',
+    )
+    self._quantizer.load_quantization_recipe(legacy_recipe_path)
+    result = self._quantizer.quantize()
+    result.save(self._tmp_save_path, model_name)
+    saved_recipe_path = os.path.join(
+        self._tmp_save_path, model_name + '_recipe.json'
+    )
+    with open(saved_recipe_path) as json_file:
+      saved_recipe = json.load(json_file)
+    with open(legacy_recipe_path) as json_file:
+      legacy_recipe = json.load(json_file)
+
+    self.assertNotEqual(saved_recipe, legacy_recipe)
+
+    # Verify that the default test recipe contains 'block_size'.
+    has_block_size = False
+    for config in legacy_recipe:
+      op_config = config.get('op_config')
+      if op_config:
+        weight_config = op_config.get('weight_tensor_config')
+        if weight_config and 'block_size' in weight_config:
+          has_block_size = True
+          break
+    self.assertTrue(has_block_size)
+
+    # Verify that the saved recipe does not have 'block_size'.
+    for config in saved_recipe:
+      op_config = config.get('op_config')
+      if op_config:
+        weight_config = op_config.get('weight_tensor_config')
+        if weight_config:
+          self.assertNotIn('block_size', weight_config)
+
   def test_save_no_quantize_raise_error(self):
     error_message = 'No quantized model to save.'
     with self.assertRaisesWithPredicateMatch(
@@ -243,6 +375,34 @@ class QuantizerTest(parameterized.TestCase):
         'sequential/dense_1/MatMul', validation_result.intermediate_tensors
     )

+  def test_validate_output_tensors_only_succeeds(self):
+    self._quantizer.quantize()
+    validation_result = self._quantizer.validate(
+        validate_output_tensors_only=True
+    )
+    validation_result = validation_result.get_signature_comparison_result()
+    self.assertIsNotNone(validation_result)
+    self.assertEmpty(validation_result.input_tensors)
+    self.assertEmpty(validation_result.constant_tensors)
+    self.assertEmpty(validation_result.intermediate_tensors)
+    self.assertNotEmpty(validation_result.output_tensors)
+    self.assertIn('StatefulPartitionedCall:0', validation_result.output_tensors)
+
+  def test_validate_with_quantized_model_arg_succeeds(self):
+    self._quantizer.quantize()
+    quantized_model = self._quantizer._result.quantized_model
+    self.assertIsNotNone(quantized_model)
+
+    new_quantizer = quantizer.Quantizer(
+        self._test_model_path, previous_quantized_model=quantized_model
+    )
+    validation_result = new_quantizer.validate()
+    validation_result = validation_result.get_signature_comparison_result()
+    self.assertIsNotNone(validation_result)
+    self.assertIn(
+        'sequential/dense_1/MatMul', validation_result.intermediate_tensors
+    )
+
   def test_load_custom_policies_succeeds(self):

     test_op_config = qtyping.OpQuantizationConfig(
@@ -284,6 +444,33 @@ class QuantizerTest(parameterized.TestCase):
         op_config=test_op_config,
     )

+  def test_two_pass_quantization_with_conv_and_fc_succeeds(self):
+    float_model_path = self._test_model_path
+
+    drq_recipe_path = os.path.join(
+        TEST_DATA_PREFIX_PATH, 'recipes/dynamic_wi8_afp32_hadamard_recipe.json'
+    )
+    drq_quantizer = quantizer.Quantizer(float_model_path)
+    drq_quantizer.load_quantization_recipe(drq_recipe_path)
+    drq_result = drq_quantizer.quantize()
+    drq_model_path = os.path.join(self._tmp_save_path, 'drq_model.tflite')
+    drq_result.export_model(drq_model_path)
+
+    srq_recipe_path = os.path.join(
+        TEST_DATA_PREFIX_PATH, 'recipes/default_a8w8_recipe.json'
+    )
+    srq_quantizer = quantizer.Quantizer(drq_model_path)
+    srq_quantizer.load_quantization_recipe(srq_recipe_path)
+    representative_dataset = (
+        tfl_interpreter_utils.create_random_normal_input_data(
+            drq_model_path, num_samples=1
+        )
+    )
+    calibration_result = srq_quantizer.calibrate(representative_dataset)
+    srq_result = srq_quantizer.quantize(calibration_result)
+    srq_model_path = os.path.join(self._tmp_save_path, 'srq_model.tflite')
+    srq_result.export_model(srq_model_path)
+

 class QuantizerBytearrayInputs(googletest.TestCase):

@@ -412,7 +599,9 @@ class QuantizerMultiSignatureModelTest(parameterized.TestCase):
     available_signatures = validation_result.available_signature_keys()
     self.assertLen(available_signatures, 2)

-  def
+  def test_constant_buffer_shared_by_tensors_with_different_quantization_params_succeeds(
+      self,
+  ):
     recipe = [
         dict({
             'regex': '.*',
@@ -424,14 +613,12 @@ class QuantizerMultiSignatureModelTest(parameterized.TestCase):
                 'symmetric': False,
                 'granularity': 'TENSORWISE',
                 'dtype': 'INT',
-                'block_size': 0,
             },
             'weight_tensor_config': {
                 'num_bits': 8,
                 'symmetric': True,
                 'granularity': 'CHANNELWISE',
                 'dtype': 'INT',
-                'block_size': 0,
             },
             'compute_precision': 'INTEGER',
             'explicit_dequantize': False,
@@ -439,17 +626,9 @@ class QuantizerMultiSignatureModelTest(parameterized.TestCase):
             },
         })
     ]
-
     qt = quantizer.Quantizer(self._test_model_path, recipe)
     calib_result = qt.calibrate(_MULTI_SIGNATURE_CALIBRATION_DATASET)
-
-    error_message = (
-        "The tensors b'Add/y' and b'Mul/y' do not have the same quantization"
-    )
-    with self.assertRaisesWithPredicateMatch(
-        RuntimeError, lambda err: error_message in str(err)
-    ):
-      qt.quantize(calib_result)
+    self.assertIsNotNone(qt.quantize(calib_result).quantized_model)

   def test_quantization_with_insufficient_calibration(self):
     # Run calibration for one signature only.
@@ -460,8 +639,7 @@ class QuantizerMultiSignatureModelTest(parameterized.TestCase):

     # Quantize and expect an error about missing signature in calibration data.
     error_message = (
-        '
-        " 'multiply'."
+        'MUL(index: 0) not found in tensor_name_to_qsv'
     )
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
@@ -483,21 +661,21 @@ class QuantizerToyGemma2Test(parameterized.TestCase):
         'signature_1': [{
             'cache_0': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
             'cache_1': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
-            'positions':
-                np.int32
+            'positions': (
+                _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
             ),
-            'tokens':
-                np.int32
+            'tokens': (
+                _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
             ),
         }],
         'signature_2': [{
             'cache_0': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
             'cache_1': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
-            'positions':
-                np.int32
+            'positions': (
+                _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
             ),
-            'tokens':
-                np.int32
+            'tokens': (
+                _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
             ),
         }],
     }
@@ -514,8 +692,8 @@ class QuantizerToyGemma2Test(parameterized.TestCase):
     )

     self._quantizer.update_quantization_recipe(
-        regex='
-        operation_name=qtyping.TFLOperationName.
+        regex='.*',
+        operation_name=qtyping.TFLOperationName.OUTPUT,
         algorithm_key=_AlgorithmName.NO_QUANTIZE,
     )

@@ -527,6 +705,90 @@ class QuantizerToyGemma2Test(parameterized.TestCase):
     self._quantizer.quantize(calib_result)
     self.assertIsNotNone(self._quantizer._result.quantized_model)

+  def test_toy_gemma2_update_signature_defs_succeeds(self):
+
+    self.assertTrue(
+        _is_all_signature_defs_outputs_float(
+            open(self._test_model_path, 'rb').read()
+        )
+    )
+    calib_result = self._quantizer.calibrate(
+        self._toy_gemma2_calibration_dataset
+    )
+    self.assertIsNotNone(calib_result)
+    self._quantizer.quantize(calib_result)
+    self.assertIsNotNone(self._quantizer._result.quantized_model)
+    self.assertTrue(
+        _is_all_signature_defs_outputs_float(
+            self._quantizer._result.quantized_model
+        )
+    )
+
+
+class QuantizerFullyConnectedTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self._tmp_save_path = self.create_tempdir().full_path
+    self._test_model_path = os.path.join(
+        TEST_DATA_PREFIX_PATH,
+        'tests/models/single_fc.tflite',
+    )
+
+    self._test_recipe_path = os.path.join(
+        TEST_DATA_PREFIX_PATH,
+        'recipes/default_a8w8_recipe.json',
+    )
+    with open(self._test_recipe_path) as json_file:
+      self._test_recipe = json.load(json_file)
+
+    self._quantizer = quantizer.Quantizer(
+        self._test_model_path, self._test_recipe_path
+    )
+
+    self._quantizer.update_quantization_recipe(
+        regex='.*',
+        operation_name=qtyping.TFLOperationName.INPUT,
+        algorithm_key=_AlgorithmName.NO_QUANTIZE,
+    )
+    self._quantizer.update_quantization_recipe(
+        regex='.*',
+        operation_name=qtyping.TFLOperationName.OUTPUT,
+        algorithm_key=_AlgorithmName.NO_QUANTIZE,
+    )
+
+  def test_fully_connected_quantization_succeeds(self):
+    calib_result = self._quantizer.calibrate(
+        tfl_interpreter_utils.create_random_normal_input_data(
+            self._test_model_path, num_samples=4
+        )
+    )
+    self.assertIsNotNone(calib_result)
+    self._quantizer.quantize(calib_result)
+    self.assertIsNotNone(self._quantizer._result.quantized_model)
+
+  def test_fully_connected_quantization_update_signature_defs_succeeds(self):
+
+    model_content = open(self._test_model_path, 'rb').read()
+    self.assertTrue(_is_all_signature_defs_inputs_float(model_content))
+    self.assertTrue(_is_all_signature_defs_outputs_float(model_content))
+
+    calib_result = self._quantizer.calibrate(
+        tfl_interpreter_utils.create_random_normal_input_data(
+            self._test_model_path, num_samples=4
+        )
+    )
+    self.assertIsNotNone(calib_result)
+    quant_result = self._quantizer.quantize(calib_result)
+    self.assertIsNotNone(quant_result.quantized_model)
+
+    self.assertTrue(
+        _is_all_signature_defs_inputs_float(quant_result.quantized_model)
+    )
+    self.assertTrue(
+        _is_all_signature_defs_outputs_float(quant_result.quantized_model)
+    )
+

 if __name__ == '__main__':
   googletest.main()
ai_edge_quantizer/recipe.py
CHANGED
@@ -15,28 +15,163 @@

 """Quantization recipe module."""

+from ai_edge_quantizer import algorithm_manager
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer import recipe_manager

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+AlgorithmName = algorithm_manager.AlgorithmName
+
+
+def dynamic_wi8_afp32(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a dynamic quantization recipe with int8 weights and float32 activation.
+
+  All supported ops will be quantized with int8 weights and float32 activations,
+  which will be dynamically quantized to int8 during inference to enable int8
+  compute. The model quality may suffer due to the on-the-fly quantization. If
+  quality is a concern, consider using weight-only quantization.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A dynamic quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_dynamic_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      num_bits=8,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
+
+
+def dynamic_wi4_afp32(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a dynamic quantization recipe with int4 weights and float32 activation.
+
+  All supported ops will be quantized with int4 weights and float32 activations,
+  which will be dynamically quantized to int4 during inference to enable int4
+  compute.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A dynamic quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_dynamic_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      num_bits=4,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
+
+
+def weight_only_wi8_afp32(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a weight-only quantization recipe with int8 weights and float32 activation.
+
+  All supported ops will be quantized with int8 weights and float32 activations.
+  The weights will be explicitly dequantized before being fed into the op to
+  enable float compute thus retain model quality. If latency is a concern,
+  consider using dynamic range quantization.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A weight-only quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_weight_only_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      num_bits=8,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
+
+
+def weight_only_wi4_afp32(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a weight-only quantization recipe with int4 weights and float32 activation.
+
+  All supported ops will be quantized with int4 weights and float32 activations.
+  The weights will be explicitly dequantized before being fed into the op to
+  enable float compute thus retain model quality.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A weight-only quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_weight_only_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      num_bits=4,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
+
+
+def static_wi8_ai8(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a static quantization recipe with int8 weights and int8 activations.
+
+  All supported ops will be quantized with int8 weights and int8 activations.
+  Calibration is needed to use this recipe.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A static quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_static_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      activation_num_bits=8,
+      weight_num_bits=8,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
+
+
+def static_wi8_ai16(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a static quantization recipe with int8 weights and int16 activations.
+
+  All supported ops will be quantized with int8 weights and int16 activations.
+  Calibration is needed to use this recipe.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A static quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_static_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      activation_num_bits=16,
+      weight_num_bits=8,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()


 def dynamic_legacy_wi8_afp32():
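The helpers above build their recipes through recipe_manager.RecipeManager and return the result of get_quantization_recipe(), which is the same list-of-dicts form that the tests pass directly to Quantizer. A minimal sketch of how they might be combined follows; the model path is a placeholder and the import paths are inferred from the package layout, so this is an illustration under those assumptions rather than documented usage.

    # Sketch only: 'model.tflite' is a placeholder, not a file from the package.
    from ai_edge_quantizer import quantizer
    from ai_edge_quantizer import recipe
    from ai_edge_quantizer.utils import tfl_interpreter_utils

    # Dynamic-range preset: int8 weights, float32 activations, no calibration.
    qt = quantizer.Quantizer('model.tflite', recipe.dynamic_wi8_afp32())
    qt.quantize().export_model('model_drq.tflite')

    # Static preset: int8 weights and activations, calibration required.
    qt = quantizer.Quantizer('model.tflite', recipe.static_wi8_ai8())
    calib = qt.calibrate(
        tfl_interpreter_utils.create_random_normal_input_data(
            'model.tflite', num_samples=16
        )
    )
    qt.quantize(calib).export_model('model_srq.tflite')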