ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry, and is provided for informational purposes only.
- ai_edge_quantizer/algorithm_manager.py +158 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
- ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
- ai_edge_quantizer/calibrator.py +11 -60
- ai_edge_quantizer/calibrator_test.py +4 -73
- ai_edge_quantizer/default_policy.py +61 -26
- ai_edge_quantizer/model_modifier.py +97 -7
- ai_edge_quantizer/model_modifier_test.py +81 -1
- ai_edge_quantizer/model_validator.py +31 -8
- ai_edge_quantizer/params_generator.py +17 -10
- ai_edge_quantizer/params_generator_test.py +2 -7
- ai_edge_quantizer/qtyping.py +86 -6
- ai_edge_quantizer/quantizer.py +166 -21
- ai_edge_quantizer/quantizer_test.py +284 -16
- ai_edge_quantizer/recipe.py +154 -42
- ai_edge_quantizer/recipe_manager.py +158 -1
- ai_edge_quantizer/recipe_manager_test.py +146 -32
- ai_edge_quantizer/recipe_test.py +93 -17
- ai_edge_quantizer/transformation_instruction_generator.py +118 -13
- ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
- ai_edge_quantizer/transformation_performer.py +55 -25
- ai_edge_quantizer/transformation_performer_test.py +127 -5
- ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
- ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
- ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
- ai_edge_quantizer/transformations/transformation_utils.py +129 -6
- ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
- ai_edge_quantizer/utils/calibration_utils.py +263 -1
- ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
- ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
- ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
- ai_edge_quantizer/utils/test_utils.py +75 -2
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
- ai_edge_quantizer/utils/validation_utils.py +114 -4
- ai_edge_quantizer/utils/validation_utils_test.py +80 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
- ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
- ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
- ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
@@ -51,6 +51,30 @@ def _get_calibration_data(num_samples: int = 16):
   return calibration_data
 
 
+def _is_all_signature_defs_inputs_float(model_content: bytes):
+  tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(model_content)
+  for signature_key in tfl_interpreter.get_signature_list():
+    input_details = tfl_interpreter.get_signature_runner(
+        signature_key
+    ).get_input_details()
+    for tensor_details in input_details.values():
+      if tensor_details['dtype'] != np.float32:
+        return False
+  return True
+
+
+def _is_all_signature_defs_outputs_float(model_content: bytes):
+  tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(model_content)
+  for signature_key in tfl_interpreter.get_signature_list():
+    output_details = tfl_interpreter.get_signature_runner(
+        signature_key
+    ).get_output_details()
+    for tensor_details in output_details.values():
+      if tensor_details['dtype'] != np.float32:
+        return False
+  return True
+
+
 class QuantizerTest(parameterized.TestCase):
 
   def setUp(self):
@@ -92,6 +116,76 @@ class QuantizerTest(parameterized.TestCase):
         new_op_config.compute_precision,
     )
 
+  def test_add_dynamic_config_succeeds(self):
+    self._quantizer.load_quantization_recipe(self._test_recipe_path)
+    scope_regex = '.*/Dense/.*'
+    self._quantizer.add_dynamic_config(
+        regex=scope_regex,
+        operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
+        num_bits=8,
+    )
+    updated_recipe = self._quantizer.get_quantization_recipe()
+    self.assertLen(updated_recipe, 2)
+
+    added_config = updated_recipe[-1]
+    self.assertEqual(added_config['regex'], scope_regex)
+    self.assertEqual(
+        added_config['op_config']['compute_precision'],
+        qtyping.ComputePrecision.INTEGER,
+    )
+    self.assertFalse(added_config['op_config']['explicit_dequantize'])
+    self.assertEqual(
+        added_config['op_config']['weight_tensor_config']['num_bits'], 8
+    )
+
+  def test_add_weight_only_config_succeeds(self):
+    self._quantizer.load_quantization_recipe(self._test_recipe_path)
+    scope_regex = '.*/Dense/.*'
+    self._quantizer.add_weight_only_config(
+        regex=scope_regex,
+        operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
+        num_bits=4,
+    )
+    updated_recipe = self._quantizer.get_quantization_recipe()
+    self.assertLen(updated_recipe, 2)
+
+    added_config = updated_recipe[-1]
+    self.assertEqual(added_config['regex'], scope_regex)
+    self.assertEqual(
+        added_config['op_config']['compute_precision'],
+        qtyping.ComputePrecision.FLOAT,
+    )
+    self.assertTrue(added_config['op_config']['explicit_dequantize'])
+    self.assertEqual(
+        added_config['op_config']['weight_tensor_config']['num_bits'], 4
+    )
+
+  def test_add_static_config_succeeds(self):
+    self._quantizer.load_quantization_recipe(self._test_recipe_path)
+    scope_regex = '.*/Dense/.*'
+    self._quantizer.add_static_config(
+        regex=scope_regex,
+        operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
+        activation_num_bits=8,
+        weight_num_bits=4,
+    )
+    updated_recipe = self._quantizer.get_quantization_recipe()
+    self.assertLen(updated_recipe, 2)
+
+    added_config = updated_recipe[-1]
+    self.assertEqual(added_config['regex'], scope_regex)
+    self.assertEqual(
+        added_config['op_config']['compute_precision'],
+        qtyping.ComputePrecision.INTEGER,
+    )
+    self.assertFalse(added_config['op_config']['explicit_dequantize'])
+    self.assertEqual(
+        added_config['op_config']['activation_tensor_config']['num_bits'], 8
+    )
+    self.assertEqual(
+        added_config['op_config']['weight_tensor_config']['num_bits'], 4
+    )
+
   def test_load_quantization_recipe_succeeds(self):
     qt = quantizer.Quantizer(self._test_model_path, None)
     qt.load_quantization_recipe(self._test_recipe_path)
@@ -118,7 +212,7 @@ class QuantizerTest(parameterized.TestCase):
     # Calibrate with empty state.
     calib_data = _get_calibration_data()
     calibration_result = self._quantizer.calibrate(calib_data)
-    self.assertLen(calibration_result,
+    self.assertLen(calibration_result, 7)
 
   @parameterized.parameters(
       'recipes/default_a8w8_recipe.json',
@@ -133,7 +227,7 @@ class QuantizerTest(parameterized.TestCase):
     updated_calibration_result = self._quantizer.calibrate(
         calib_data, previous_calibration_result=calibration_result
     )
-    self.assertLen(updated_calibration_result,
+    self.assertLen(updated_calibration_result, 7)
     self.assertNotEqual(
         calibration_result['StatefulPartitionedCall:0'],
         updated_calibration_result['StatefulPartitionedCall:0'],
@@ -215,6 +309,44 @@ class QuantizerTest(parameterized.TestCase):
       saved_recipe = json.load(json_file)
     self.assertEqual(saved_recipe, self._test_recipe)
 
+  def test_saved_legacy_recipe_lacks_block_size(self):
+    model_name = 'test_model'
+    legacy_recipe_path = os.path.join(
+        TEST_DATA_PREFIX_PATH,
+        'recipes/dynamic_legacy_wi8_afp32_recipe.json',
+    )
+    self._quantizer.load_quantization_recipe(legacy_recipe_path)
+    result = self._quantizer.quantize()
+    result.save(self._tmp_save_path, model_name)
+    saved_recipe_path = os.path.join(
+        self._tmp_save_path, model_name + '_recipe.json'
+    )
+    with open(saved_recipe_path) as json_file:
+      saved_recipe = json.load(json_file)
+    with open(legacy_recipe_path) as json_file:
+      legacy_recipe = json.load(json_file)
+
+    self.assertNotEqual(saved_recipe, legacy_recipe)
+
+    # Verify that the default test recipe contains 'block_size'.
+    has_block_size = False
+    for config in legacy_recipe:
+      op_config = config.get('op_config')
+      if op_config:
+        weight_config = op_config.get('weight_tensor_config')
+        if weight_config and 'block_size' in weight_config:
+          has_block_size = True
+          break
+    self.assertTrue(has_block_size)
+
+    # Verify that the saved recipe does not have 'block_size'.
+    for config in saved_recipe:
+      op_config = config.get('op_config')
+      if op_config:
+        weight_config = op_config.get('weight_tensor_config')
+        if weight_config:
+          self.assertNotIn('block_size', weight_config)
+
   def test_save_no_quantize_raise_error(self):
     error_message = 'No quantized model to save.'
     with self.assertRaisesWithPredicateMatch(
@@ -243,6 +375,34 @@ class QuantizerTest(parameterized.TestCase):
         'sequential/dense_1/MatMul', validation_result.intermediate_tensors
     )
 
+  def test_validate_output_tensors_only_succeeds(self):
+    self._quantizer.quantize()
+    validation_result = self._quantizer.validate(
+        validate_output_tensors_only=True
+    )
+    validation_result = validation_result.get_signature_comparison_result()
+    self.assertIsNotNone(validation_result)
+    self.assertEmpty(validation_result.input_tensors)
+    self.assertEmpty(validation_result.constant_tensors)
+    self.assertEmpty(validation_result.intermediate_tensors)
+    self.assertNotEmpty(validation_result.output_tensors)
+    self.assertIn('StatefulPartitionedCall:0', validation_result.output_tensors)
+
+  def test_validate_with_quantized_model_arg_succeeds(self):
+    self._quantizer.quantize()
+    quantized_model = self._quantizer._result.quantized_model
+    self.assertIsNotNone(quantized_model)
+
+    new_quantizer = quantizer.Quantizer(
+        self._test_model_path, previous_quantized_model=quantized_model
+    )
+    validation_result = new_quantizer.validate()
+    validation_result = validation_result.get_signature_comparison_result()
+    self.assertIsNotNone(validation_result)
+    self.assertIn(
+        'sequential/dense_1/MatMul', validation_result.intermediate_tensors
+    )
+
   def test_load_custom_policies_succeeds(self):
 
     test_op_config = qtyping.OpQuantizationConfig(
@@ -284,6 +444,33 @@ class QuantizerTest(parameterized.TestCase):
         op_config=test_op_config,
     )
 
+  def test_two_pass_quantization_with_conv_and_fc_succeeds(self):
+    float_model_path = self._test_model_path
+
+    drq_recipe_path = os.path.join(
+        TEST_DATA_PREFIX_PATH, 'recipes/dynamic_wi8_afp32_hadamard_recipe.json'
+    )
+    drq_quantizer = quantizer.Quantizer(float_model_path)
+    drq_quantizer.load_quantization_recipe(drq_recipe_path)
+    drq_result = drq_quantizer.quantize()
+    drq_model_path = os.path.join(self._tmp_save_path, 'drq_model.tflite')
+    drq_result.export_model(drq_model_path)
+
+    srq_recipe_path = os.path.join(
+        TEST_DATA_PREFIX_PATH, 'recipes/default_a8w8_recipe.json'
+    )
+    srq_quantizer = quantizer.Quantizer(drq_model_path)
+    srq_quantizer.load_quantization_recipe(srq_recipe_path)
+    representative_dataset = (
+        tfl_interpreter_utils.create_random_normal_input_data(
+            drq_model_path, num_samples=1
+        )
+    )
+    calibration_result = srq_quantizer.calibrate(representative_dataset)
+    srq_result = srq_quantizer.quantize(calibration_result)
+    srq_model_path = os.path.join(self._tmp_save_path, 'srq_model.tflite')
+    srq_result.export_model(srq_model_path)
+
 
 class QuantizerBytearrayInputs(googletest.TestCase):
 
@@ -426,14 +613,12 @@ class QuantizerMultiSignatureModelTest(parameterized.TestCase):
             'symmetric': False,
             'granularity': 'TENSORWISE',
             'dtype': 'INT',
-            'block_size': 0,
         },
         'weight_tensor_config': {
             'num_bits': 8,
             'symmetric': True,
            'granularity': 'CHANNELWISE',
             'dtype': 'INT',
-            'block_size': 0,
         },
         'compute_precision': 'INTEGER',
         'explicit_dequantize': False,
@@ -454,8 +639,7 @@
 
     # Quantize and expect an error about missing signature in calibration data.
     error_message = (
-        '
-        " 'multiply'."
+        'MUL(index: 0) not found in tensor_name_to_qsv'
     )
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
@@ -477,21 +661,21 @@ class QuantizerToyGemma2Test(parameterized.TestCase):
         'signature_1': [{
             'cache_0': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
             'cache_1': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
-            'positions':
-                np.int32
+            'positions': (
+                _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
             ),
-            'tokens':
-                np.int32
+            'tokens': (
+                _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
             ),
         }],
         'signature_2': [{
             'cache_0': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
             'cache_1': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
-            'positions':
-                np.int32
+            'positions': (
+                _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
             ),
-            'tokens':
-                np.int32
+            'tokens': (
+                _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
             ),
         }],
     }
@@ -508,8 +692,8 @@ class QuantizerToyGemma2Test(parameterized.TestCase):
     )
 
     self._quantizer.update_quantization_recipe(
-        regex='
-        operation_name=qtyping.TFLOperationName.
+        regex='.*',
+        operation_name=qtyping.TFLOperationName.OUTPUT,
         algorithm_key=_AlgorithmName.NO_QUANTIZE,
     )
 
@@ -521,6 +705,90 @@ class QuantizerToyGemma2Test(parameterized.TestCase):
     self._quantizer.quantize(calib_result)
     self.assertIsNotNone(self._quantizer._result.quantized_model)
 
+  def test_toy_gemma2_update_signature_defs_succeeds(self):
+
+    self.assertTrue(
+        _is_all_signature_defs_outputs_float(
+            open(self._test_model_path, 'rb').read()
+        )
+    )
+    calib_result = self._quantizer.calibrate(
+        self._toy_gemma2_calibration_dataset
+    )
+    self.assertIsNotNone(calib_result)
+    self._quantizer.quantize(calib_result)
+    self.assertIsNotNone(self._quantizer._result.quantized_model)
+    self.assertTrue(
+        _is_all_signature_defs_outputs_float(
+            self._quantizer._result.quantized_model
+        )
+    )
+
+
+class QuantizerFullyConnectedTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self._tmp_save_path = self.create_tempdir().full_path
+    self._test_model_path = os.path.join(
+        TEST_DATA_PREFIX_PATH,
+        'tests/models/single_fc.tflite',
+    )
+
+    self._test_recipe_path = os.path.join(
+        TEST_DATA_PREFIX_PATH,
+        'recipes/default_a8w8_recipe.json',
+    )
+    with open(self._test_recipe_path) as json_file:
+      self._test_recipe = json.load(json_file)
+
+    self._quantizer = quantizer.Quantizer(
+        self._test_model_path, self._test_recipe_path
+    )
+
+    self._quantizer.update_quantization_recipe(
+        regex='.*',
+        operation_name=qtyping.TFLOperationName.INPUT,
+        algorithm_key=_AlgorithmName.NO_QUANTIZE,
+    )
+    self._quantizer.update_quantization_recipe(
+        regex='.*',
+        operation_name=qtyping.TFLOperationName.OUTPUT,
+        algorithm_key=_AlgorithmName.NO_QUANTIZE,
+    )
+
+  def test_fully_connected_quantization_succeeds(self):
+    calib_result = self._quantizer.calibrate(
+        tfl_interpreter_utils.create_random_normal_input_data(
+            self._test_model_path, num_samples=4
+        )
+    )
+    self.assertIsNotNone(calib_result)
+    self._quantizer.quantize(calib_result)
+    self.assertIsNotNone(self._quantizer._result.quantized_model)
+
+  def test_fully_connected_quantization_update_signature_defs_succeeds(self):
+
+    model_content = open(self._test_model_path, 'rb').read()
+    self.assertTrue(_is_all_signature_defs_inputs_float(model_content))
+    self.assertTrue(_is_all_signature_defs_outputs_float(model_content))
+
+    calib_result = self._quantizer.calibrate(
+        tfl_interpreter_utils.create_random_normal_input_data(
+            self._test_model_path, num_samples=4
+        )
+    )
+    self.assertIsNotNone(calib_result)
+    quant_result = self._quantizer.quantize(calib_result)
+    self.assertIsNotNone(quant_result.quantized_model)
+
+    self.assertTrue(
+        _is_all_signature_defs_inputs_float(quant_result.quantized_model)
+    )
+    self.assertTrue(
+        _is_all_signature_defs_outputs_float(quant_result.quantized_model)
+    )
+
 
 if __name__ == '__main__':
   googletest.main()
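The tests above exercise the convenience methods added to Quantizer in this release (add_dynamic_config, add_weight_only_config, add_static_config, and validate(validate_output_tensors_only=True)). A minimal usage sketch outside the test harness, with placeholder file paths and call signatures taken only from the test code above:

    from ai_edge_quantizer import qtyping
    from ai_edge_quantizer import quantizer

    qt = quantizer.Quantizer('model.tflite')  # placeholder model path

    # Dynamic-range int8 weights for FULLY_CONNECTED ops whose scope matches
    # the regex, mirroring test_add_dynamic_config_succeeds.
    qt.add_dynamic_config(
        regex='.*/Dense/.*',
        operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
        num_bits=8,
    )

    result = qt.quantize()  # dynamic-range quantization needs no calibration
    result.export_model('model_quantized.tflite')

    # Compare only the output tensors of the float and quantized models,
    # mirroring test_validate_output_tensors_only_succeeds.
    validation = qt.validate(validate_output_tensors_only=True)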
ai_edge_quantizer/recipe.py
CHANGED
@@ -15,51 +15,163 @@
 
 """Quantization recipe module."""
 
+from ai_edge_quantizer import algorithm_manager
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer import recipe_manager
 
-def dynamic_wi8_afp32():
-  """Returns a dynamic quantization recipe with int8 weights and float32 activation."""
-  return [
-      dict({
-          'regex': '.*',
-          'operation': '*',
-          'algorithm_key': 'min_max_uniform_quantize',
-          'op_config': {
-              'weight_tensor_config': {
-                  'num_bits': 8,
-                  'symmetric': True,
-                  'granularity': 'CHANNELWISE',
-                  'dtype': 'INT',
-                  'block_size': 0,
-              },
-              'compute_precision': 'INTEGER',
-              'explicit_dequantize': False,
-              'skip_checks': False,
-          },
-      })
-  ]
+AlgorithmName = algorithm_manager.AlgorithmName
 
 
-def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+def dynamic_wi8_afp32(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a dynamic quantization recipe with int8 weights and float32 activation.
+
+  All supported ops will be quantized with int8 weights and float32 activations,
+  which will be dynamically quantized to int8 during inference to enable int8
+  compute. The model quality may suffer due to the on-the-fly quantization. If
+  quality is a concern, consider using weight-only quantization.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A dynamic quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_dynamic_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      num_bits=8,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
+
+
+def dynamic_wi4_afp32(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a dynamic quantization recipe with int4 weights and float32 activation.
+
+  All supported ops will be quantized with int4 weights and float32 activations,
+  which will be dynamically quantized to int4 during inference to enable int4
+  compute.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A dynamic quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_dynamic_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      num_bits=4,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
+
+
+def weight_only_wi8_afp32(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a weight-only quantization recipe with int8 weights and float32 activation.
+
+  All supported ops will be quantized with int8 weights and float32 activations.
+  The weights will be explicitly dequantized before being fed into the op to
+  enable float compute thus retain model quality. If latency is a concern,
+  consider using dynamic range quantization.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A weight-only quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_weight_only_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      num_bits=8,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
+
+
+def weight_only_wi4_afp32(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a weight-only quantization recipe with int4 weights and float32 activation.
+
+  All supported ops will be quantized with int4 weights and float32 activations.
+  The weights will be explicitly dequantized before being fed into the op to
+  enable float compute thus retain model quality.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A weight-only quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_weight_only_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      num_bits=4,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
+
+
+def static_wi8_ai8(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a static quantization recipe with int8 weights and int8 activations.
+
+  All supported ops will be quantized with int8 weights and int8 activations.
+  Calibration is needed to use this recipe.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A static quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_static_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      activation_num_bits=8,
+      weight_num_bits=8,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
+
+
+def static_wi8_ai16(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a static quantization recipe with int8 weights and int16 activations.
+
+  All supported ops will be quantized with int8 weights and int16 activations.
+  Calibration is needed to use this recipe.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A static quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_static_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      activation_num_bits=16,
+      weight_num_bits=8,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
 
 
 def dynamic_legacy_wi8_afp32():
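Together with the Quantizer API shown earlier, these new preset helpers in recipe.py reduce a common configuration to a single call. A minimal sketch: the model path is a placeholder, and passing the in-memory recipe returned by the preset directly to load_quantization_recipe is an assumption (the tests above only load recipes from JSON file paths):

    from ai_edge_quantizer import quantizer
    from ai_edge_quantizer import recipe

    qt = quantizer.Quantizer('model.tflite')  # placeholder model path

    # dynamic_wi8_afp32() returns a recipe: a list of per-op configuration dicts
    # covering all supported ops with int8 weights and float32 activations.
    qt.load_quantization_recipe(recipe.dynamic_wi8_afp32())

    result = qt.quantize()  # dynamic-range quantization needs no calibration
    result.export_model('model_int8_dynamic.tflite')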
|