ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/algorithm_manager.py +158 -0
  2. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  3. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  5. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
  6. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  7. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  8. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
  12. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
  13. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
  14. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
  15. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
  16. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
  17. ai_edge_quantizer/calibrator.py +11 -60
  18. ai_edge_quantizer/calibrator_test.py +4 -73
  19. ai_edge_quantizer/default_policy.py +61 -26
  20. ai_edge_quantizer/model_modifier.py +97 -7
  21. ai_edge_quantizer/model_modifier_test.py +81 -1
  22. ai_edge_quantizer/model_validator.py +31 -8
  23. ai_edge_quantizer/params_generator.py +17 -10
  24. ai_edge_quantizer/params_generator_test.py +2 -7
  25. ai_edge_quantizer/qtyping.py +86 -6
  26. ai_edge_quantizer/quantizer.py +166 -21
  27. ai_edge_quantizer/quantizer_test.py +284 -16
  28. ai_edge_quantizer/recipe.py +154 -42
  29. ai_edge_quantizer/recipe_manager.py +158 -1
  30. ai_edge_quantizer/recipe_manager_test.py +146 -32
  31. ai_edge_quantizer/recipe_test.py +93 -17
  32. ai_edge_quantizer/transformation_instruction_generator.py +118 -13
  33. ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
  34. ai_edge_quantizer/transformation_performer.py +55 -25
  35. ai_edge_quantizer/transformation_performer_test.py +127 -5
  36. ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
  37. ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
  38. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  39. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  40. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  41. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  42. ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
  43. ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
  44. ai_edge_quantizer/transformations/transformation_utils.py +129 -6
  45. ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
  46. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  47. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  48. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  49. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  50. ai_edge_quantizer/utils/test_utils.py +75 -2
  51. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
  52. ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
  53. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  54. ai_edge_quantizer/utils/validation_utils.py +114 -4
  55. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  56. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
  57. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  58. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  59. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  60. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  61. ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
  62. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  63. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/quantizer_test.py
@@ -51,6 +51,30 @@ def _get_calibration_data(num_samples: int = 16):
   return calibration_data
 
 
+def _is_all_signature_defs_inputs_float(model_content: bytes):
+  tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(model_content)
+  for signature_key in tfl_interpreter.get_signature_list():
+    input_details = tfl_interpreter.get_signature_runner(
+        signature_key
+    ).get_input_details()
+    for tensor_details in input_details.values():
+      if tensor_details['dtype'] != np.float32:
+        return False
+  return True
+
+
+def _is_all_signature_defs_outputs_float(model_content: bytes):
+  tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(model_content)
+  for signature_key in tfl_interpreter.get_signature_list():
+    output_details = tfl_interpreter.get_signature_runner(
+        signature_key
+    ).get_output_details()
+    for tensor_details in output_details.values():
+      if tensor_details['dtype'] != np.float32:
+        return False
+  return True
+
+
 class QuantizerTest(parameterized.TestCase):
 
   def setUp(self):
@@ -92,6 +116,76 @@ class QuantizerTest(parameterized.TestCase):
         new_op_config.compute_precision,
     )
 
+  def test_add_dynamic_config_succeeds(self):
+    self._quantizer.load_quantization_recipe(self._test_recipe_path)
+    scope_regex = '.*/Dense/.*'
+    self._quantizer.add_dynamic_config(
+        regex=scope_regex,
+        operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
+        num_bits=8,
+    )
+    updated_recipe = self._quantizer.get_quantization_recipe()
+    self.assertLen(updated_recipe, 2)
+
+    added_config = updated_recipe[-1]
+    self.assertEqual(added_config['regex'], scope_regex)
+    self.assertEqual(
+        added_config['op_config']['compute_precision'],
+        qtyping.ComputePrecision.INTEGER,
+    )
+    self.assertFalse(added_config['op_config']['explicit_dequantize'])
+    self.assertEqual(
+        added_config['op_config']['weight_tensor_config']['num_bits'], 8
+    )
+
+  def test_add_weight_only_config_succeeds(self):
+    self._quantizer.load_quantization_recipe(self._test_recipe_path)
+    scope_regex = '.*/Dense/.*'
+    self._quantizer.add_weight_only_config(
+        regex=scope_regex,
+        operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
+        num_bits=4,
+    )
+    updated_recipe = self._quantizer.get_quantization_recipe()
+    self.assertLen(updated_recipe, 2)
+
+    added_config = updated_recipe[-1]
+    self.assertEqual(added_config['regex'], scope_regex)
+    self.assertEqual(
+        added_config['op_config']['compute_precision'],
+        qtyping.ComputePrecision.FLOAT,
+    )
+    self.assertTrue(added_config['op_config']['explicit_dequantize'])
+    self.assertEqual(
+        added_config['op_config']['weight_tensor_config']['num_bits'], 4
+    )
+
+  def test_add_static_config_succeeds(self):
+    self._quantizer.load_quantization_recipe(self._test_recipe_path)
+    scope_regex = '.*/Dense/.*'
+    self._quantizer.add_static_config(
+        regex=scope_regex,
+        operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
+        activation_num_bits=8,
+        weight_num_bits=4,
+    )
+    updated_recipe = self._quantizer.get_quantization_recipe()
+    self.assertLen(updated_recipe, 2)
+
+    added_config = updated_recipe[-1]
+    self.assertEqual(added_config['regex'], scope_regex)
+    self.assertEqual(
+        added_config['op_config']['compute_precision'],
+        qtyping.ComputePrecision.INTEGER,
+    )
+    self.assertFalse(added_config['op_config']['explicit_dequantize'])
+    self.assertEqual(
+        added_config['op_config']['activation_tensor_config']['num_bits'], 8
+    )
+    self.assertEqual(
+        added_config['op_config']['weight_tensor_config']['num_bits'], 4
+    )
+
   def test_load_quantization_recipe_succeeds(self):
     qt = quantizer.Quantizer(self._test_model_path, None)
     qt.load_quantization_recipe(self._test_recipe_path)
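The three tests above exercise the new Quantizer.add_dynamic_config, add_weight_only_config, and add_static_config helpers, each of which appends a scoped op config to the currently loaded recipe. A minimal usage sketch, assuming only what these tests demonstrate (the model path, recipe path, and scope regex below are placeholders):

from ai_edge_quantizer import qtyping
from ai_edge_quantizer import quantizer

qt = quantizer.Quantizer('model.tflite')  # placeholder float TFLite model
qt.load_quantization_recipe('default_a8w8_recipe.json')  # placeholder recipe path

# Dynamic-range int8 weights for FULLY_CONNECTED ops whose scope matches the regex.
qt.add_dynamic_config(
    regex='.*/Dense/.*',
    operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
    num_bits=8,
)

# The appended entry is the last item of the updated recipe.
print(qt.get_quantization_recipe()[-1])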
@@ -118,7 +212,7 @@ class QuantizerTest(parameterized.TestCase):
     # Calibrate with empty state.
     calib_data = _get_calibration_data()
     calibration_result = self._quantizer.calibrate(calib_data)
-    self.assertLen(calibration_result, 13)
+    self.assertLen(calibration_result, 7)
 
   @parameterized.parameters(
       'recipes/default_a8w8_recipe.json',
@@ -133,7 +227,7 @@ class QuantizerTest(parameterized.TestCase):
     updated_calibration_result = self._quantizer.calibrate(
         calib_data, previous_calibration_result=calibration_result
     )
-    self.assertLen(updated_calibration_result, 13)
+    self.assertLen(updated_calibration_result, 7)
     self.assertNotEqual(
         calibration_result['StatefulPartitionedCall:0'],
         updated_calibration_result['StatefulPartitionedCall:0'],
@@ -215,6 +309,44 @@ class QuantizerTest(parameterized.TestCase):
       saved_recipe = json.load(json_file)
     self.assertEqual(saved_recipe, self._test_recipe)
 
+  def test_saved_legacy_recipe_lacks_block_size(self):
+    model_name = 'test_model'
+    legacy_recipe_path = os.path.join(
+        TEST_DATA_PREFIX_PATH,
+        'recipes/dynamic_legacy_wi8_afp32_recipe.json',
+    )
+    self._quantizer.load_quantization_recipe(legacy_recipe_path)
+    result = self._quantizer.quantize()
+    result.save(self._tmp_save_path, model_name)
+    saved_recipe_path = os.path.join(
+        self._tmp_save_path, model_name + '_recipe.json'
+    )
+    with open(saved_recipe_path) as json_file:
+      saved_recipe = json.load(json_file)
+    with open(legacy_recipe_path) as json_file:
+      legacy_recipe = json.load(json_file)
+
+    self.assertNotEqual(saved_recipe, legacy_recipe)
+
+    # Verify that the default test recipe contains 'block_size'.
+    has_block_size = False
+    for config in legacy_recipe:
+      op_config = config.get('op_config')
+      if op_config:
+        weight_config = op_config.get('weight_tensor_config')
+        if weight_config and 'block_size' in weight_config:
+          has_block_size = True
+          break
+    self.assertTrue(has_block_size)
+
+    # Verify that the saved recipe does not have 'block_size'.
+    for config in saved_recipe:
+      op_config = config.get('op_config')
+      if op_config:
+        weight_config = op_config.get('weight_tensor_config')
+        if weight_config:
+          self.assertNotIn('block_size', weight_config)
+
   def test_save_no_quantize_raise_error(self):
     error_message = 'No quantized model to save.'
     with self.assertRaisesWithPredicateMatch(
@@ -243,6 +375,34 @@ class QuantizerTest(parameterized.TestCase):
         'sequential/dense_1/MatMul', validation_result.intermediate_tensors
     )
 
+  def test_validate_output_tensors_only_succeeds(self):
+    self._quantizer.quantize()
+    validation_result = self._quantizer.validate(
+        validate_output_tensors_only=True
+    )
+    validation_result = validation_result.get_signature_comparison_result()
+    self.assertIsNotNone(validation_result)
+    self.assertEmpty(validation_result.input_tensors)
+    self.assertEmpty(validation_result.constant_tensors)
+    self.assertEmpty(validation_result.intermediate_tensors)
+    self.assertNotEmpty(validation_result.output_tensors)
+    self.assertIn('StatefulPartitionedCall:0', validation_result.output_tensors)
+
+  def test_validate_with_quantized_model_arg_succeeds(self):
+    self._quantizer.quantize()
+    quantized_model = self._quantizer._result.quantized_model
+    self.assertIsNotNone(quantized_model)
+
+    new_quantizer = quantizer.Quantizer(
+        self._test_model_path, previous_quantized_model=quantized_model
+    )
+    validation_result = new_quantizer.validate()
+    validation_result = validation_result.get_signature_comparison_result()
+    self.assertIsNotNone(validation_result)
+    self.assertIn(
+        'sequential/dense_1/MatMul', validation_result.intermediate_tensors
+    )
+
   def test_load_custom_policies_succeeds(self):
 
     test_op_config = qtyping.OpQuantizationConfig(
@@ -284,6 +444,33 @@ class QuantizerTest(parameterized.TestCase):
         op_config=test_op_config,
     )
 
+  def test_two_pass_quantization_with_conv_and_fc_succeeds(self):
+    float_model_path = self._test_model_path
+
+    drq_recipe_path = os.path.join(
+        TEST_DATA_PREFIX_PATH, 'recipes/dynamic_wi8_afp32_hadamard_recipe.json'
+    )
+    drq_quantizer = quantizer.Quantizer(float_model_path)
+    drq_quantizer.load_quantization_recipe(drq_recipe_path)
+    drq_result = drq_quantizer.quantize()
+    drq_model_path = os.path.join(self._tmp_save_path, 'drq_model.tflite')
+    drq_result.export_model(drq_model_path)
+
+    srq_recipe_path = os.path.join(
+        TEST_DATA_PREFIX_PATH, 'recipes/default_a8w8_recipe.json'
+    )
+    srq_quantizer = quantizer.Quantizer(drq_model_path)
+    srq_quantizer.load_quantization_recipe(srq_recipe_path)
+    representative_dataset = (
+        tfl_interpreter_utils.create_random_normal_input_data(
+            drq_model_path, num_samples=1
+        )
+    )
+    calibration_result = srq_quantizer.calibrate(representative_dataset)
+    srq_result = srq_quantizer.quantize(calibration_result)
+    srq_model_path = os.path.join(self._tmp_save_path, 'srq_model.tflite')
+    srq_result.export_model(srq_model_path)
+
 
 class QuantizerBytearrayInputs(googletest.TestCase):
 
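test_two_pass_quantization_with_conv_and_fc_succeeds chains two quantizers: the first pass applies a dynamic-range weight quantization recipe with Hadamard rotation and exports the model, and the second pass calibrates that exported model and applies a static a8w8 recipe on top. A condensed sketch of the same flow, with placeholder model and recipe paths rather than the shipped test data:

from ai_edge_quantizer import quantizer
from ai_edge_quantizer.utils import tfl_interpreter_utils

# Pass 1: dynamic-range quantize the weights (Hadamard recipe) and export.
drq = quantizer.Quantizer('float_model.tflite')
drq.load_quantization_recipe('dynamic_wi8_afp32_hadamard_recipe.json')
drq.quantize().export_model('drq_model.tflite')

# Pass 2: calibrate the weight-quantized model and statically quantize activations.
srq = quantizer.Quantizer('drq_model.tflite')
srq.load_quantization_recipe('default_a8w8_recipe.json')
calib = srq.calibrate(
    tfl_interpreter_utils.create_random_normal_input_data(
        'drq_model.tflite', num_samples=1
    )
)
srq.quantize(calib).export_model('srq_model.tflite')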
@@ -426,14 +613,12 @@ class QuantizerMultiSignatureModelTest(parameterized.TestCase):
                 'symmetric': False,
                 'granularity': 'TENSORWISE',
                 'dtype': 'INT',
-                'block_size': 0,
             },
             'weight_tensor_config': {
                 'num_bits': 8,
                 'symmetric': True,
                 'granularity': 'CHANNELWISE',
                 'dtype': 'INT',
-                'block_size': 0,
             },
             'compute_precision': 'INTEGER',
             'explicit_dequantize': False,
@@ -454,8 +639,7 @@ class QuantizerMultiSignatureModelTest(parameterized.TestCase):
 
     # Quantize and expect an error about missing signature in calibration data.
     error_message = (
-        'Missing QSVs (min/max) for tensor multiply_x:0 in Signature'
-        " 'multiply'."
+        'MUL(index: 0) not found in tensor_name_to_qsv'
     )
     with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: error_message in str(err)
@@ -477,21 +661,21 @@ class QuantizerToyGemma2Test(parameterized.TestCase):
         'signature_1': [{
             'cache_0': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
             'cache_1': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
-            'positions': _RNG.integers(low=0, high=10, size=(1, 100)).astype(
-                np.int32
+            'positions': (
+                _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
             ),
-            'tokens': _RNG.integers(low=0, high=10, size=(1, 100)).astype(
-                np.int32
+            'tokens': (
+                _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
             ),
         }],
         'signature_2': [{
             'cache_0': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
             'cache_1': _RNG.random(size=(1, 100, 4, 4), dtype=np.float32),
-            'positions': _RNG.integers(low=0, high=10, size=(1, 100)).astype(
-                np.int32
+            'positions': (
+                _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
             ),
-            'tokens': _RNG.integers(low=0, high=10, size=(1, 100)).astype(
-                np.int32
+            'tokens': (
+                _RNG.integers(low=0, high=10, size=(1, 100)).astype(np.int32)
             ),
         }],
     }
@@ -508,8 +692,8 @@ class QuantizerToyGemma2Test(parameterized.TestCase):
     )
 
     self._quantizer.update_quantization_recipe(
-        regex='StatefulPartitionedCall',
-        operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
+        regex='.*',
+        operation_name=qtyping.TFLOperationName.OUTPUT,
         algorithm_key=_AlgorithmName.NO_QUANTIZE,
     )
 
@@ -521,6 +705,90 @@ class QuantizerToyGemma2Test(parameterized.TestCase):
     self._quantizer.quantize(calib_result)
     self.assertIsNotNone(self._quantizer._result.quantized_model)
 
+  def test_toy_gemma2_update_signature_defs_succeeds(self):
+
+    self.assertTrue(
+        _is_all_signature_defs_outputs_float(
+            open(self._test_model_path, 'rb').read()
+        )
+    )
+    calib_result = self._quantizer.calibrate(
+        self._toy_gemma2_calibration_dataset
+    )
+    self.assertIsNotNone(calib_result)
+    self._quantizer.quantize(calib_result)
+    self.assertIsNotNone(self._quantizer._result.quantized_model)
+    self.assertTrue(
+        _is_all_signature_defs_outputs_float(
+            self._quantizer._result.quantized_model
+        )
+    )
+
+
+class QuantizerFullyConnectedTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self._tmp_save_path = self.create_tempdir().full_path
+    self._test_model_path = os.path.join(
+        TEST_DATA_PREFIX_PATH,
+        'tests/models/single_fc.tflite',
+    )
+
+    self._test_recipe_path = os.path.join(
+        TEST_DATA_PREFIX_PATH,
+        'recipes/default_a8w8_recipe.json',
+    )
+    with open(self._test_recipe_path) as json_file:
+      self._test_recipe = json.load(json_file)
+
+    self._quantizer = quantizer.Quantizer(
+        self._test_model_path, self._test_recipe_path
+    )
+
+    self._quantizer.update_quantization_recipe(
+        regex='.*',
+        operation_name=qtyping.TFLOperationName.INPUT,
+        algorithm_key=_AlgorithmName.NO_QUANTIZE,
+    )
+    self._quantizer.update_quantization_recipe(
+        regex='.*',
+        operation_name=qtyping.TFLOperationName.OUTPUT,
+        algorithm_key=_AlgorithmName.NO_QUANTIZE,
+    )
+
+  def test_fully_connected_quantization_succeeds(self):
+    calib_result = self._quantizer.calibrate(
+        tfl_interpreter_utils.create_random_normal_input_data(
+            self._test_model_path, num_samples=4
+        )
+    )
+    self.assertIsNotNone(calib_result)
+    self._quantizer.quantize(calib_result)
+    self.assertIsNotNone(self._quantizer._result.quantized_model)
+
+  def test_fully_connected_quantization_update_signature_defs_succeeds(self):
+
+    model_content = open(self._test_model_path, 'rb').read()
+    self.assertTrue(_is_all_signature_defs_inputs_float(model_content))
+    self.assertTrue(_is_all_signature_defs_outputs_float(model_content))
+
+    calib_result = self._quantizer.calibrate(
+        tfl_interpreter_utils.create_random_normal_input_data(
+            self._test_model_path, num_samples=4
+        )
+    )
+    self.assertIsNotNone(calib_result)
+    quant_result = self._quantizer.quantize(calib_result)
+    self.assertIsNotNone(quant_result.quantized_model)
+
+    self.assertTrue(
+        _is_all_signature_defs_inputs_float(quant_result.quantized_model)
+    )
+    self.assertTrue(
+        _is_all_signature_defs_outputs_float(quant_result.quantized_model)
+    )
+
 
 if __name__ == '__main__':
   googletest.main()
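The new QuantizerFullyConnectedTest keeps signature inputs and outputs in float32 by overriding the INPUT and OUTPUT ops with NO_QUANTIZE on top of a static a8w8 recipe; the _is_all_signature_defs_* helpers then check the signature dtypes before and after quantization. A condensed sketch of that pattern (paths are placeholders, and it assumes the test's _AlgorithmName alias resolves to algorithm_manager.AlgorithmName, as in the recipe.py hunk below):

from ai_edge_quantizer import algorithm_manager
from ai_edge_quantizer import qtyping
from ai_edge_quantizer import quantizer
from ai_edge_quantizer.utils import tfl_interpreter_utils

# Placeholders: a float model plus a static a8w8 recipe JSON.
qt = quantizer.Quantizer('single_fc.tflite', 'default_a8w8_recipe.json')

# Keep the graph inputs/outputs unquantized so signature dtypes stay float32.
for op_name in (qtyping.TFLOperationName.INPUT, qtyping.TFLOperationName.OUTPUT):
  qt.update_quantization_recipe(
      regex='.*',
      operation_name=op_name,
      algorithm_key=algorithm_manager.AlgorithmName.NO_QUANTIZE,
  )

calib = qt.calibrate(
    tfl_interpreter_utils.create_random_normal_input_data(
        'single_fc.tflite', num_samples=4
    )
)
result = qt.quantize(calib)  # quantized interior, float32 signature I/O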
ai_edge_quantizer/recipe.py
@@ -15,51 +15,163 @@
 
 """Quantization recipe module."""
 
+from ai_edge_quantizer import algorithm_manager
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer import recipe_manager
 
-def dynamic_wi8_afp32():
-  """Returns a dynamic quantization recipe with int8 weights and float32 activation."""
-  return [
-      dict({
-          'regex': '.*',
-          'operation': '*',
-          'algorithm_key': 'min_max_uniform_quantize',
-          'op_config': {
-              'weight_tensor_config': {
-                  'num_bits': 8,
-                  'symmetric': True,
-                  'granularity': 'CHANNELWISE',
-                  'dtype': 'INT',
-                  'block_size': 0,
-              },
-              'compute_precision': 'INTEGER',
-              'explicit_dequantize': False,
-              'skip_checks': False,
-          },
-      })
-  ]
+AlgorithmName = algorithm_manager.AlgorithmName
 
 
-def dynamic_wi4_afp32():
-  """Returns a dynamic quantization recipe with int4 weights and float32 activation."""
-  return [
-      dict({
-          'regex': '.*',
-          'operation': '*',
-          'algorithm_key': 'min_max_uniform_quantize',
-          'op_config': {
-              'weight_tensor_config': {
-                  'num_bits': 4,
-                  'symmetric': True,
-                  'granularity': 'CHANNELWISE',
-                  'dtype': 'INT',
-                  'block_size': 0,
-              },
-              'compute_precision': 'INTEGER',
-              'explicit_dequantize': False,
-              'skip_checks': False,
-          },
-      })
-  ]
+def dynamic_wi8_afp32(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a dynamic quantization recipe with int8 weights and float32 activation.
+
+  All supported ops will be quantized with int8 weights and float32 activations,
+  which will be dynamically quantized to int8 during inference to enable int8
+  compute. The model quality may suffer due to the on-the-fly quantization. If
+  quality is a concern, consider using weight-only quantization.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A dynamic quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_dynamic_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      num_bits=8,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
+
+
+def dynamic_wi4_afp32(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a dynamic quantization recipe with int4 weights and float32 activation.
+
+  All supported ops will be quantized with int4 weights and float32 activations,
+  which will be dynamically quantized to int4 during inference to enable int4
+  compute.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A dynamic quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_dynamic_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      num_bits=4,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
+
+
+def weight_only_wi8_afp32(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a weight-only quantization recipe with int8 weights and float32 activation.
+
+  All supported ops will be quantized with int8 weights and float32 activations.
+  The weights will be explicitly dequantized before being fed into the op to
+  enable float compute thus retain model quality. If latency is a concern,
+  consider using dynamic range quantization.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A weight-only quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_weight_only_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      num_bits=8,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
+
+
+def weight_only_wi4_afp32(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a weight-only quantization recipe with int4 weights and float32 activation.
+
+  All supported ops will be quantized with int4 weights and float32 activations.
+  The weights will be explicitly dequantized before being fed into the op to
+  enable float compute thus retain model quality.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A weight-only quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_weight_only_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      num_bits=4,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
+
+
+def static_wi8_ai8(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a static quantization recipe with int8 weights and int8 activations.
+
+  All supported ops will be quantized with int8 weights and int8 activations.
+  Calibration is needed to use this recipe.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A static quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_static_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      activation_num_bits=8,
+      weight_num_bits=8,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
+
+
+def static_wi8_ai16(
+    algorithm_key: AlgorithmName = AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+):
+  """Returns a static quantization recipe with int8 weights and int16 activations.
+
+  All supported ops will be quantized with int8 weights and int16 activations.
+  Calibration is needed to use this recipe.
+
+  Args:
+    algorithm_key: The algorithm to use for quantization.
+
+  Returns:
+    A static quantization recipe.
+  """
+  rp_manager = recipe_manager.RecipeManager()
+  rp_manager.add_static_config(
+      regex='.*',
+      operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
+      activation_num_bits=16,
+      weight_num_bits=8,
+      algorithm_key=algorithm_key,
+  )
+  return rp_manager.get_quantization_recipe()
 
 
 def dynamic_legacy_wi8_afp32():
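The rewritten recipe.py builds recipes through RecipeManager instead of hard-coded dicts and adds weight-only and static variants alongside the dynamic ones. A rough application sketch, assuming only what the tests in this diff demonstrate (recipes are loaded from a JSON path; file names below are placeholders):

import json

from ai_edge_quantizer import quantizer
from ai_edge_quantizer import recipe

# Build a recipe (a list of scoped config dicts) with one of the new helpers.
dynamic_recipe = recipe.dynamic_wi8_afp32()

# Persist it and hand the path to a Quantizer, as the tests above do.
with open('/tmp/dynamic_wi8_afp32_recipe.json', 'w') as f:
  json.dump(dynamic_recipe, f)

qt = quantizer.Quantizer('model.tflite')  # placeholder float TFLite model
qt.load_quantization_recipe('/tmp/dynamic_wi8_afp32_recipe.json')
result = qt.quantize()  # dynamic-range recipes do not require calibration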