ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. ai_edge_quantizer/algorithm_manager.py +224 -0
  2. ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
  3. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
  5. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  6. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
  7. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
  8. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
  13. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
  14. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
  15. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
  16. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
  17. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
  18. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
  19. ai_edge_quantizer/calibrator.py +58 -94
  20. ai_edge_quantizer/calibrator_test.py +5 -74
  21. ai_edge_quantizer/default_policy.py +108 -16
  22. ai_edge_quantizer/model_modifier.py +132 -8
  23. ai_edge_quantizer/model_modifier_test.py +81 -1
  24. ai_edge_quantizer/model_validator.py +38 -10
  25. ai_edge_quantizer/model_validator_test.py +2 -1
  26. ai_edge_quantizer/params_generator.py +230 -47
  27. ai_edge_quantizer/params_generator_test.py +366 -261
  28. ai_edge_quantizer/qtyping.py +92 -6
  29. ai_edge_quantizer/quantizer.py +167 -23
  30. ai_edge_quantizer/quantizer_test.py +288 -26
  31. ai_edge_quantizer/recipe.py +156 -21
  32. ai_edge_quantizer/recipe_manager.py +158 -1
  33. ai_edge_quantizer/recipe_manager_test.py +146 -32
  34. ai_edge_quantizer/recipe_test.py +93 -17
  35. ai_edge_quantizer/transformation_instruction_generator.py +313 -46
  36. ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
  37. ai_edge_quantizer/transformation_performer.py +112 -58
  38. ai_edge_quantizer/transformation_performer_test.py +176 -4
  39. ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
  40. ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
  41. ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
  42. ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
  43. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  44. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  45. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  46. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  47. ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
  48. ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
  49. ai_edge_quantizer/transformations/transformation_utils.py +157 -11
  50. ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
  51. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  52. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  53. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  54. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  55. ai_edge_quantizer/utils/test_utils.py +191 -58
  56. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
  57. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
  58. ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
  59. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  60. ai_edge_quantizer/utils/validation_utils.py +114 -4
  61. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  62. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
  63. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  64. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  65. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  66. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  67. ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
  68. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  69. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
@@ -51,6 +51,30 @@ def _get_calibration_data(num_samples: int = 16):
51
51
  return calibration_data
52
52
 
53
53
 
54
+ def _is_all_signature_defs_inputs_float(model_content: bytes):
55
+ tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(model_content)
56
+ for signature_key in tfl_interpreter.get_signature_list():
57
+ input_details = tfl_interpreter.get_signature_runner(
58
+ signature_key
59
+ ).get_input_details()
60
+ for tensor_details in input_details.values():
61
+ if tensor_details['dtype'] != np.float32:
62
+ return False
63
+ return True
64
+
65
+
66
+ def _is_all_signature_defs_outputs_float(model_content: bytes):
67
+ tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(model_content)
68
+ for signature_key in tfl_interpreter.get_signature_list():
69
+ output_details = tfl_interpreter.get_signature_runner(
70
+ signature_key
71
+ ).get_output_details()
72
+ for tensor_details in output_details.values():
73
+ if tensor_details['dtype'] != np.float32:
74
+ return False
75
+ return True
76
+
77
+
54
78
  class QuantizerTest(parameterized.TestCase):
55
79
 
56
80
  def setUp(self):
@@ -92,6 +116,76 @@ class QuantizerTest(parameterized.TestCase):
92
116
  new_op_config.compute_precision,
93
117
  )
94
118
 
119
+ def test_add_dynamic_config_succeeds(self):
120
+ self._quantizer.load_quantization_recipe(self._test_recipe_path)
121
+ scope_regex = '.*/Dense/.*'
122
+ self._quantizer.add_dynamic_config(
123
+ regex=scope_regex,
124
+ operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
125
+ num_bits=8,
126
+ )
127
+ updated_recipe = self._quantizer.get_quantization_recipe()
128
+ self.assertLen(updated_recipe, 2)
129
+
130
+ added_config = updated_recipe[-1]
131
+ self.assertEqual(added_config['regex'], scope_regex)
132
+ self.assertEqual(
133
+ added_config['op_config']['compute_precision'],
134
+ qtyping.ComputePrecision.INTEGER,
135
+ )
136
+ self.assertFalse(added_config['op_config']['explicit_dequantize'])
137
+ self.assertEqual(
138
+ added_config['op_config']['weight_tensor_config']['num_bits'], 8
139
+ )
140
+
141
+ def test_add_weight_only_config_succeeds(self):
142
+ self._quantizer.load_quantization_recipe(self._test_recipe_path)
143
+ scope_regex = '.*/Dense/.*'
144
+ self._quantizer.add_weight_only_config(
145
+ regex=scope_regex,
146
+ operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
147
+ num_bits=4,
148
+ )
149
+ updated_recipe = self._quantizer.get_quantization_recipe()
150
+ self.assertLen(updated_recipe, 2)
151
+
152
+ added_config = updated_recipe[-1]
153
+ self.assertEqual(added_config['regex'], scope_regex)
154
+ self.assertEqual(
155
+ added_config['op_config']['compute_precision'],
156
+ qtyping.ComputePrecision.FLOAT,
157
+ )
158
+ self.assertTrue(added_config['op_config']['explicit_dequantize'])
159
+ self.assertEqual(
160
+ added_config['op_config']['weight_tensor_config']['num_bits'], 4
161
+ )
162
+
163
+ def test_add_static_config_succeeds(self):
164
+ self._quantizer.load_quantization_recipe(self._test_recipe_path)
165
+ scope_regex = '.*/Dense/.*'
166
+ self._quantizer.add_static_config(
167
+ regex=scope_regex,
168
+ operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
169
+ activation_num_bits=8,
170
+ weight_num_bits=4,
171
+ )
172
+ updated_recipe = self._quantizer.get_quantization_recipe()
173
+ self.assertLen(updated_recipe, 2)
174
+
175
+ added_config = updated_recipe[-1]
176
+ self.assertEqual(added_config['regex'], scope_regex)
177
+ self.assertEqual(
178
+ added_config['op_config']['compute_precision'],
179
+ qtyping.ComputePrecision.INTEGER,
180
+ )
181
+ self.assertFalse(added_config['op_config']['explicit_dequantize'])
182
+ self.assertEqual(
183
+ added_config['op_config']['activation_tensor_config']['num_bits'], 8
184
+ )
185
+ self.assertEqual(
186
+ added_config['op_config']['weight_tensor_config']['num_bits'], 4
187
+ )
188
+
95
189
  def test_load_quantization_recipe_succeeds(self):
96
190
  qt = quantizer.Quantizer(self._test_model_path, None)
97
191
  qt.load_quantization_recipe(self._test_recipe_path)
@@ -118,7 +212,7 @@ class QuantizerTest(parameterized.TestCase):
118
212
  # Calibrate with empty state.
119
213
  calib_data = _get_calibration_data()
120
214
  calibration_result = self._quantizer.calibrate(calib_data)
121
- self.assertLen(calibration_result, 13)
215
+ self.assertLen(calibration_result, 7)
122
216
 
123
217
  @parameterized.parameters(
124
218
  'recipes/default_a8w8_recipe.json',
@@ -133,7 +227,7 @@ class QuantizerTest(parameterized.TestCase):
133
227
  updated_calibration_result = self._quantizer.calibrate(
134
228
  calib_data, previous_calibration_result=calibration_result
135
229
  )
136
- self.assertLen(updated_calibration_result, 13)
230
+ self.assertLen(updated_calibration_result, 7)
137
231
  self.assertNotEqual(
138
232
  calibration_result['StatefulPartitionedCall:0'],
139
233
  updated_calibration_result['StatefulPartitionedCall:0'],
@@ -215,6 +309,44 @@ class QuantizerTest(parameterized.TestCase):
215
309
  saved_recipe = json.load(json_file)
216
310
  self.assertEqual(saved_recipe, self._test_recipe)
217
311
 
312
+ def test_saved_legacy_recipe_lacks_block_size(self):
313
+ model_name = 'test_model'
314
+ legacy_recipe_path = os.path.join(
315
+ TEST_DATA_PREFIX_PATH,
316
+ 'recipes/dynamic_legacy_wi8_afp32_recipe.json',
317
+ )
318
+ self._quantizer.load_quantization_recipe(legacy_recipe_path)
319
+ result = self._quantizer.quantize()
320
+ result.save(self._tmp_save_path, model_name)
321
+ saved_recipe_path = os.path.join(
322
+ self._tmp_save_path, model_name + '_recipe.json'
323
+ )
324
+ with open(saved_recipe_path) as json_file:
325
+ saved_recipe = json.load(json_file)
326
+ with open(legacy_recipe_path) as json_file:
327
+ legacy_recipe = json.load(json_file)
328
+
329
+ self.assertNotEqual(saved_recipe, legacy_recipe)
330
+
331
+ # Verify that the legacy recipe loaded from disk contains 'block_size'.
332
+ has_block_size = False
333
+ for config in legacy_recipe:
334
+ op_config = config.get('op_config')
335
+ if op_config:
336
+ weight_config = op_config.get('weight_tensor_config')
337
+ if weight_config and 'block_size' in weight_config:
338
+ has_block_size = True
339
+ break
340
+ self.assertTrue(has_block_size)
341
+
342
+ # Verify that the saved recipe does not have 'block_size'.
343
+ for config in saved_recipe:
344
+ op_config = config.get('op_config')
345
+ if op_config:
346
+ weight_config = op_config.get('weight_tensor_config')
347
+ if weight_config:
348
+ self.assertNotIn('block_size', weight_config)
349
+
218
350
  def test_save_no_quantize_raise_error(self):
219
351
  error_message = 'No quantized model to save.'
220
352
  with self.assertRaisesWithPredicateMatch(
@@ -243,6 +375,34 @@ class QuantizerTest(parameterized.TestCase):
243
375
  'sequential/dense_1/MatMul', validation_result.intermediate_tensors
244
376
  )
245
377
 
378
+ def test_validate_output_tensors_only_succeeds(self):
379
+ self._quantizer.quantize()
380
+ validation_result = self._quantizer.validate(
381
+ validate_output_tensors_only=True
382
+ )
383
+ validation_result = validation_result.get_signature_comparison_result()
384
+ self.assertIsNotNone(validation_result)
385
+ self.assertEmpty(validation_result.input_tensors)
386
+ self.assertEmpty(validation_result.constant_tensors)
387
+ self.assertEmpty(validation_result.intermediate_tensors)
388
+ self.assertNotEmpty(validation_result.output_tensors)
389
+ self.assertIn('StatefulPartitionedCall:0', validation_result.output_tensors)
390
+
391
+ def test_validate_with_quantized_model_arg_succeeds(self):
392
+ self._quantizer.quantize()
393
+ quantized_model = self._quantizer._result.quantized_model
394
+ self.assertIsNotNone(quantized_model)
395
+
396
+ new_quantizer = quantizer.Quantizer(
397
+ self._test_model_path, previous_quantized_model=quantized_model
398
+ )
399
+ validation_result = new_quantizer.validate()
400
+ validation_result = validation_result.get_signature_comparison_result()
401
+ self.assertIsNotNone(validation_result)
402
+ self.assertIn(
403
+ 'sequential/dense_1/MatMul', validation_result.intermediate_tensors
404
+ )
405
+
246
406
  def test_load_custom_policies_succeeds(self):
247
407
 
248
408
  test_op_config = qtyping.OpQuantizationConfig(
@@ -284,6 +444,33 @@ class QuantizerTest(parameterized.TestCase):
284
444
  op_config=test_op_config,
285
445
  )
286
446
 
447
+ def test_two_pass_quantization_with_conv_and_fc_succeeds(self):
448
+ float_model_path = self._test_model_path
449
+
450
+ drq_recipe_path = os.path.join(
451
+ TEST_DATA_PREFIX_PATH, 'recipes/dynamic_wi8_afp32_hadamard_recipe.json'
452
+ )
453
+ drq_quantizer = quantizer.Quantizer(float_model_path)
454
+ drq_quantizer.load_quantization_recipe(drq_recipe_path)
455
+ drq_result = drq_quantizer.quantize()
456
+ drq_model_path = os.path.join(self._tmp_save_path, 'drq_model.tflite')
457
+ drq_result.export_model(drq_model_path)
458
+
459
+ srq_recipe_path = os.path.join(
460
+ TEST_DATA_PREFIX_PATH, 'recipes/default_a8w8_recipe.json'
461
+ )
462
+ srq_quantizer = quantizer.Quantizer(drq_model_path)
463
+ srq_quantizer.load_quantization_recipe(srq_recipe_path)
464
+ representative_dataset = (
465
+ tfl_interpreter_utils.create_random_normal_input_data(
466
+ drq_model_path, num_samples=1
467
+ )
468
+ )
469
+ calibration_result = srq_quantizer.calibrate(representative_dataset)
470
+ srq_result = srq_quantizer.quantize(calibration_result)
471
+ srq_model_path = os.path.join(self._tmp_save_path, 'srq_model.tflite')
472
+ srq_result.export_model(srq_model_path)
473
+
287
474
 
288
475
  class QuantizerBytearrayInputs(googletest.TestCase):
289
476
 
@@ -412,7 +599,9 @@ class QuantizerMultiSignatureModelTest(parameterized.TestCase):
412
599
  available_signatures = validation_result.available_signature_keys()
413
600
  self.assertLen(available_signatures, 2)
414
601
 
415
- def test_recipe_conflict_raises_error(self):
602
+ def test_constant_buffer_shared_by_tensors_with_different_quantization_params_succeeds(
603
+ self,
604
+ ):
416
605
  recipe = [
417
606
  dict({
418
607
  'regex': '.*',
@@ -424,14 +613,12 @@ class QuantizerMultiSignatureModelTest(parameterized.TestCase):
424
613
  'symmetric': False,
425
614
  'granularity': 'TENSORWISE',
426
615
  'dtype': 'INT',
427
- 'block_size': 0,
428
616
  },
429
617
  'weight_tensor_config': {
430
618
  'num_bits': 8,
431
619
  'symmetric': True,
432
620
  'granularity': 'CHANNELWISE',
433
621
  'dtype': 'INT',
434
- 'block_size': 0,
435
622
  },
436
623
  'compute_precision': 'INTEGER',
437
624
  'explicit_dequantize': False,
@@ -439,17 +626,9 @@ class QuantizerMultiSignatureModelTest(parameterized.TestCase):
439
626
  },
440
627
  })
441
628
  ]
442
-
443
629
  qt = quantizer.Quantizer(self._test_model_path, recipe)
444
630
  calib_result = qt.calibrate(_MULTI_SIGNATURE_CALIBRATION_DATASET)
445
-
446
- error_message = (
447
- "The tensors b'Add/y' and b'Mul/y' do not have the same quantization"
448
- )
449
- with self.assertRaisesWithPredicateMatch(
450
- RuntimeError, lambda err: error_message in str(err)
451
- ):
452
- qt.quantize(calib_result)
631
+ self.assertIsNotNone(qt.quantize(calib_result).quantized_model)
453
632
 
454
633
  def test_quantization_with_insufficient_calibration(self):
455
634
  # Run calibration for one signature only.
@@ -460,8 +639,7 @@ class QuantizerMultiSignatureModelTest(parameterized.TestCase):
460
639
 
461
640
  # Quantize and expect an error about missing signature in calibration data.
462
641
  error_message = (
463
- 'Missing QSVs (min/max) for tensor multiply_x:0 in Signature'
464
- " 'multiply'."
642
+ 'MUL(index: 0) not found in tensor_name_to_qsv'
465
643
  )
466
644
  with self.assertRaisesWithPredicateMatch(
467
645
  ValueError, lambda err: error_message in str(err)
@@ -483,21 +661,21 @@ class QuantizerToyGemma2Test(parameterized.TestCase):
483
661
  'signature_1': [{
484
662
  'cache_0': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
485
663
  'cache_1': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
486
- 'positions': _RNG.integers(low=0, high=10, size=(1, 100)).astype(
487
- np.int32
664
+ 'positions': (
665
+ _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
488
666
  ),
489
- 'tokens': _RNG.integers(low=0, high=10, size=(1, 100)).astype(
490
- np.int32
667
+ 'tokens': (
668
+ _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
491
669
  ),
492
670
  }],
493
671
  'signature_2': [{
494
672
  'cache_0': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
495
673
  'cache_1': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
496
- 'positions': _RNG.integers(low=0, high=10, size=(1, 100)).astype(
497
- np.int32
674
+ 'positions': (
675
+ _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
498
676
  ),
499
- 'tokens': _RNG.integers(low=0, high=10, size=(1, 100)).astype(
500
- np.int32
677
+ 'tokens': (
678
+ _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
501
679
  ),
502
680
  }],
503
681
  }
@@ -514,8 +692,8 @@ class QuantizerToyGemma2Test(parameterized.TestCase):
514
692
  )
515
693
 
516
694
  self._quantizer.update_quantization_recipe(
517
- regex='StatefulPartitionedCall',
518
- operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
695
+ regex='.*',
696
+ operation_name=qtyping.TFLOperationName.OUTPUT,
519
697
  algorithm_key=_AlgorithmName.NO_QUANTIZE,
520
698
  )
521
699
 
@@ -527,6 +705,90 @@ class QuantizerToyGemma2Test(parameterized.TestCase):
527
705
  self._quantizer.quantize(calib_result)
528
706
  self.assertIsNotNone(self._quantizer._result.quantized_model)
529
707
 
708
+ def test_toy_gemma2_update_signature_defs_succeeds(self):
709
+
710
+ self.assertTrue(
711
+ _is_all_signature_defs_outputs_float(
712
+ open(self._test_model_path, 'rb').read()
713
+ )
714
+ )
715
+ calib_result = self._quantizer.calibrate(
716
+ self._toy_gemma2_calibration_dataset
717
+ )
718
+ self.assertIsNotNone(calib_result)
719
+ self._quantizer.quantize(calib_result)
720
+ self.assertIsNotNone(self._quantizer._result.quantized_model)
721
+ self.assertTrue(
722
+ _is_all_signature_defs_outputs_float(
723
+ self._quantizer._result.quantized_model
724
+ )
725
+ )
726
+
727
+
728
+ class QuantizerFullyConnectedTest(parameterized.TestCase):
729
+
730
+ def setUp(self):
731
+ super().setUp()
732
+ self._tmp_save_path = self.create_tempdir().full_path
733
+ self._test_model_path = os.path.join(
734
+ TEST_DATA_PREFIX_PATH,
735
+ 'tests/models/single_fc.tflite',
736
+ )
737
+
738
+ self._test_recipe_path = os.path.join(
739
+ TEST_DATA_PREFIX_PATH,
740
+ 'recipes/default_a8w8_recipe.json',
741
+ )
742
+ with open(self._test_recipe_path) as json_file:
743
+ self._test_recipe = json.load(json_file)
744
+
745
+ self._quantizer = quantizer.Quantizer(
746
+ self._test_model_path, self._test_recipe_path
747
+ )
748
+
749
+ self._quantizer.update_quantization_recipe(
750
+ regex='.*',
751
+ operation_name=qtyping.TFLOperationName.INPUT,
752
+ algorithm_key=_AlgorithmName.NO_QUANTIZE,
753
+ )
754
+ self._quantizer.update_quantization_recipe(
755
+ regex='.*',
756
+ operation_name=qtyping.TFLOperationName.OUTPUT,
757
+ algorithm_key=_AlgorithmName.NO_QUANTIZE,
758
+ )
759
+
760
+ def test_fully_connected_quantization_succeeds(self):
761
+ calib_result = self._quantizer.calibrate(
762
+ tfl_interpreter_utils.create_random_normal_input_data(
763
+ self._test_model_path, num_samples=4
764
+ )
765
+ )
766
+ self.assertIsNotNone(calib_result)
767
+ self._quantizer.quantize(calib_result)
768
+ self.assertIsNotNone(self._quantizer._result.quantized_model)
769
+
770
+ def test_fully_connected_quantization_update_signature_defs_succeeds(self):
771
+
772
+ model_content = open(self._test_model_path, 'rb').read()
773
+ self.assertTrue(_is_all_signature_defs_inputs_float(model_content))
774
+ self.assertTrue(_is_all_signature_defs_outputs_float(model_content))
775
+
776
+ calib_result = self._quantizer.calibrate(
777
+ tfl_interpreter_utils.create_random_normal_input_data(
778
+ self._test_model_path, num_samples=4
779
+ )
780
+ )
781
+ self.assertIsNotNone(calib_result)
782
+ quant_result = self._quantizer.quantize(calib_result)
783
+ self.assertIsNotNone(quant_result.quantized_model)
784
+
785
+ self.assertTrue(
786
+ _is_all_signature_defs_inputs_float(quant_result.quantized_model)
787
+ )
788
+ self.assertTrue(
789
+ _is_all_signature_defs_outputs_float(quant_result.quantized_model)
790
+ )
791
+
530
792
 
531
793
  if __name__ == '__main__':
532
794
  googletest.main()
@@ -15,28 +15,163 @@
15
15
 
16
16
  """Quantization recipe module."""
17
17
 
18
+ from ai_edge_quantizer import algorithm_manager
19
+ from ai_edge_quantizer import qtyping
20
+ from ai_edge_quantizer import recipe_manager
18
21
 
19
- def dynamic_wi8_afp32():
20
- """Returns a dynamic quantization recipe with int8 weights and float32 activation."""
21
- return [
22
- dict({
23
- 'regex': '.*',
24
- 'operation': '*',
25
- 'algorithm_key': 'min_max_uniform_quantize',
26
- 'op_config': {
27
- 'weight_tensor_config': {
28
- 'num_bits': 8,
29
- 'symmetric': True,
30
- 'granularity': 'CHANNELWISE',
31
- 'dtype': 'INT',
32
- 'block_size': 0,
33
- },
34
- 'compute_precision': 'INTEGER',
35
- 'explicit_dequantize': False,
36
- 'skip_checks': False,
37
- },
38
- })
39
- ]
22
+ AlgorithmName = algorithm_manager.AlgorithmName
23
+
24
+
25
+ def dynamic_wi8_afp32(
26
+ algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
27
+ ):
28
+ """Returns a dynamic quantization recipe with int8 weights and float32 activation.
29
+
30
+ All supported ops will be quantized with int8 weights and float32 activations,
31
+ which will be dynamically quantized to int8 during inference to enable int8
32
+ compute. The model quality may suffer due to the on-the-fly quantization. If
33
+ quality is a concern, consider using weight-only quantization.
34
+
35
+ Args:
36
+ algorithm_key: The algorithm to use for quantization.
37
+
38
+ Returns:
39
+ A dynamic quantization recipe.
40
+ """
41
+ rp_manager = recipe_manager.RecipeManager()
42
+ rp_manager.add_dynamic_config(
43
+ regex='.*',
44
+ operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
45
+ num_bits=8,
46
+ algorithm_key=algorithm_key,
47
+ )
48
+ return rp_manager.get_quantization_recipe()
49
+
50
+
51
+ def dynamic_wi4_afp32(
52
+ algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
53
+ ):
54
+ """Returns a dynamic quantization recipe with int4 weights and float32 activation.
55
+
56
+ All supported ops will be quantized with int4 weights and float32 activations,
57
+ which will be dynamically quantized to int4 during inference to enable int4
58
+ compute.
59
+
60
+ Args:
61
+ algorithm_key: The algorithm to use for quantization.
62
+
63
+ Returns:
64
+ A dynamic quantization recipe.
65
+ """
66
+ rp_manager = recipe_manager.RecipeManager()
67
+ rp_manager.add_dynamic_config(
68
+ regex='.*',
69
+ operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
70
+ num_bits=4,
71
+ algorithm_key=algorithm_key,
72
+ )
73
+ return rp_manager.get_quantization_recipe()
74
+
75
+
76
+ def weight_only_wi8_afp32(
77
+ algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
78
+ ):
79
+ """Returns a weight-only quantization recipe with int8 weights and float32 activation.
80
+
81
+ All supported ops will be quantized with int8 weights and float32 activations.
82
+ The weights will be explicitly dequantized before being fed into the op to
83
+ enable float compute, thus retaining model quality. If latency is a concern,
84
+ consider using dynamic range quantization.
85
+
86
+ Args:
87
+ algorithm_key: The algorithm to use for quantization.
88
+
89
+ Returns:
90
+ A weight-only quantization recipe.
91
+ """
92
+ rp_manager = recipe_manager.RecipeManager()
93
+ rp_manager.add_weight_only_config(
94
+ regex='.*',
95
+ operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
96
+ num_bits=8,
97
+ algorithm_key=algorithm_key,
98
+ )
99
+ return rp_manager.get_quantization_recipe()
100
+
101
+
102
+ def weight_only_wi4_afp32(
103
+ algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
104
+ ):
105
+ """Returns a weight-only quantization recipe with int4 weights and float32 activation.
106
+
107
+ All supported ops will be quantized with int4 weights and float32 activations.
108
+ The weights will be explicitly dequantized before being fed into the op to
109
+ enable float compute, thus retaining model quality.
110
+
111
+ Args:
112
+ algorithm_key: The algorithm to use for quantization.
113
+
114
+ Returns:
115
+ A weight-only quantization recipe.
116
+ """
117
+ rp_manager = recipe_manager.RecipeManager()
118
+ rp_manager.add_weight_only_config(
119
+ regex='.*',
120
+ operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
121
+ num_bits=4,
122
+ algorithm_key=algorithm_key,
123
+ )
124
+ return rp_manager.get_quantization_recipe()
125
+
126
+
127
+ def static_wi8_ai8(
128
+ algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
129
+ ):
130
+ """Returns a static quantization recipe with int8 weights and int8 activations.
131
+
132
+ All supported ops will be quantized with int8 weights and int8 activations.
133
+ Calibration is needed to use this recipe.
134
+
135
+ Args:
136
+ algorithm_key: The algorithm to use for quantization.
137
+
138
+ Returns:
139
+ A static quantization recipe.
140
+ """
141
+ rp_manager = recipe_manager.RecipeManager()
142
+ rp_manager.add_static_config(
143
+ regex='.*',
144
+ operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
145
+ activation_num_bits=8,
146
+ weight_num_bits=8,
147
+ algorithm_key=algorithm_key,
148
+ )
149
+ return rp_manager.get_quantization_recipe()
150
+
151
+
152
+ def static_wi8_ai16(
153
+ algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
154
+ ):
155
+ """Returns a static quantization recipe with int8 weights and int16 activations.
156
+
157
+ All supported ops will be quantized with int8 weights and int16 activations.
158
+ Calibration is needed to use this recipe.
159
+
160
+ Args:
161
+ algorithm_key: The algorithm to use for quantization.
162
+
163
+ Returns:
164
+ A static quantization recipe.
165
+ """
166
+ rp_manager = recipe_manager.RecipeManager()
167
+ rp_manager.add_static_config(
168
+ regex='.*',
169
+ operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
170
+ activation_num_bits=16,
171
+ weight_num_bits=8,
172
+ algorithm_key=algorithm_key,
173
+ )
174
+ return rp_manager.get_quantization_recipe()
40
175
 
41
176
 
42
177
  def dynamic_legacy_wi8_afp32():