ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. ai_edge_quantizer/algorithm_manager.py +224 -0
  2. ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
  3. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
  5. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  6. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
  7. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
  8. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
  13. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
  14. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
  15. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
  16. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
  17. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
  18. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
  19. ai_edge_quantizer/calibrator.py +58 -94
  20. ai_edge_quantizer/calibrator_test.py +5 -74
  21. ai_edge_quantizer/default_policy.py +108 -16
  22. ai_edge_quantizer/model_modifier.py +132 -8
  23. ai_edge_quantizer/model_modifier_test.py +81 -1
  24. ai_edge_quantizer/model_validator.py +38 -10
  25. ai_edge_quantizer/model_validator_test.py +2 -1
  26. ai_edge_quantizer/params_generator.py +230 -47
  27. ai_edge_quantizer/params_generator_test.py +366 -261
  28. ai_edge_quantizer/qtyping.py +92 -6
  29. ai_edge_quantizer/quantizer.py +167 -23
  30. ai_edge_quantizer/quantizer_test.py +288 -26
  31. ai_edge_quantizer/recipe.py +156 -21
  32. ai_edge_quantizer/recipe_manager.py +158 -1
  33. ai_edge_quantizer/recipe_manager_test.py +146 -32
  34. ai_edge_quantizer/recipe_test.py +93 -17
  35. ai_edge_quantizer/transformation_instruction_generator.py +313 -46
  36. ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
  37. ai_edge_quantizer/transformation_performer.py +112 -58
  38. ai_edge_quantizer/transformation_performer_test.py +176 -4
  39. ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
  40. ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
  41. ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
  42. ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
  43. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  44. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  45. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  46. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  47. ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
  48. ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
  49. ai_edge_quantizer/transformations/transformation_utils.py +157 -11
  50. ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
  51. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  52. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  53. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  54. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  55. ai_edge_quantizer/utils/test_utils.py +191 -58
  56. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
  57. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
  58. ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
  59. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  60. ai_edge_quantizer/utils/validation_utils.py +114 -4
  61. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  62. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
  63. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  64. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  65. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  66. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  67. ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
  68. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  69. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- """Tests for params_generator."""
17
-
18
16
  from collections.abc import Generator
19
17
  import os
20
18
  from typing import Any
@@ -38,8 +36,11 @@ _TensorQuantConfig = qtyping.TensorQuantizationConfig
38
36
  _QuantTransformation = qtyping.QuantTransformation
39
37
  _AlgorithmName = recipe_manager.AlgorithmName
40
38
  _QuantGranularity = qtyping.QuantGranularity
39
+ _QTransf = qtyping.QuantTransformation
40
+
41
41
 
42
42
  TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('')
43
+ _PARAMS_8BIT = qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0]))
43
44
 
44
45
 
45
46
  def _single_fc_model_representative_dataset_gen(num_samples=5):
@@ -64,6 +65,20 @@ def _get_calibration_data(
64
65
  return calibration_data
65
66
 
66
67
 
68
+ def _get_test_consumers(
69
+ transformations_per_consumer: list[list[_QTransf]],
70
+ params_per_consumer: list[qtyping.OpToTensorParams],
71
+ ) -> list[qtyping.OpToTensorParams]:
72
+ return [
73
+ qtyping.OpToTensorParams(
74
+ subgraph_op_id=i + 1,
75
+ transformations=transformations_per_consumer[i],
76
+ parameters=params_per_consumer[i],
77
+ )
78
+ for i in range(len(transformations_per_consumer))
79
+ ]
80
+
81
+
67
82
  class ParamsGeneratorTest(parameterized.TestCase):
68
83
 
69
84
  def setUp(self):
@@ -570,9 +585,27 @@ class ParamsGeneratorTest(parameterized.TestCase):
570
585
  )
571
586
  self.assertLen(quant_params, 6)
572
587
 
573
- @parameterized.parameters('no_quant', 'execution_mode', 'num_bits')
574
- def test_generate_params_buffer_sharing_graphs_fails(
575
- self, the_other_fc_difference
588
+ @parameterized.named_parameters(
589
+ dict(
590
+ testcase_name='different_quant_config_fc2_no_quant',
591
+ fc_2_num_bits=None,
592
+ expected_tensor_with_buffer_duplication='BatchMatMulV3',
593
+ ),
594
+ dict(
595
+ testcase_name='different_quant_config_fc2_int4',
596
+ fc_2_num_bits=4,
597
+ expected_tensor_with_buffer_duplication='BatchMatMulV3',
598
+ ),
599
+ dict(
600
+ testcase_name='same_quant_config',
601
+ fc_2_num_bits=8,
602
+ expected_tensor_with_buffer_duplication=None,
603
+ ),
604
+ )
605
+ def test_generate_params_marks_correct_buffers_for_duplication_when_distinct_tensors_share_constant_buffer(
606
+ self,
607
+ fc_2_num_bits,
608
+ expected_tensor_with_buffer_duplication,
576
609
  ):
577
610
  model_path = os.path.join(
578
611
  TEST_DATA_PREFIX_PATH, 'tests/models/weight_sharing_fcs.tflite'
@@ -580,33 +613,204 @@ class ParamsGeneratorTest(parameterized.TestCase):
580
613
  # Setup the quantization config for the first FC.
581
614
  self._recipe_manager.add_quantization_config(
582
615
  regex='PartitionedCall:0',
583
- operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
616
+ operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
584
617
  op_config=qtyping.OpQuantizationConfig(
585
- weight_tensor_config=_TensorQuantConfig(num_bits=8),
618
+ weight_tensor_config=_TensorQuantConfig(
619
+ num_bits=8, granularity=qtyping.QuantGranularity.CHANNELWISE
620
+ ),
586
621
  compute_precision=_ComputePrecision.INTEGER,
587
622
  ),
588
623
  )
589
624
  # Setup the quantization config for the second FC (weight shared with the
590
625
  # first FC).
591
- if the_other_fc_difference == 'no_quant':
592
- pass
593
- elif the_other_fc_difference == 'num_bits':
626
+ if fc_2_num_bits is not None:
594
627
  self._recipe_manager.add_quantization_config(
595
628
  regex='PartitionedCall_1:0',
596
- operation_name=qtyping.TFLOperationName.ALL_SUPPORTED,
629
+ operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
597
630
  op_config=qtyping.OpQuantizationConfig(
598
- weight_tensor_config=_TensorQuantConfig(num_bits=4),
631
+ weight_tensor_config=_TensorQuantConfig(
632
+ num_bits=fc_2_num_bits,
633
+ granularity=qtyping.QuantGranularity.CHANNELWISE,
634
+ ),
599
635
  compute_precision=_ComputePrecision.INTEGER,
600
636
  ),
601
637
  )
602
638
  pg = params_generator.ParamsGenerator(model_path)
603
- error_message = 'do not have the same quantization parameters'
604
- with self.assertRaisesWithPredicateMatch(
605
- RuntimeError, lambda err: error_message in str(err)
606
- ):
607
- pg.generate_quantization_parameters(
608
- self._recipe_manager,
639
+ quant_params = pg.generate_quantization_parameters(
640
+ self._recipe_manager,
641
+ )
642
+ self.assertLen(quant_params, 6)
643
+
644
+ # Check that the expected tensor has buffer duplication transformation as
645
+ # the first one to be applied. And no other tensor has buffer duplication
646
+ # transformation at all.
647
+ for tensor_name in quant_params:
648
+ if tensor_name == expected_tensor_with_buffer_duplication:
649
+ self.assertIsNotNone(quant_params[tensor_name].consumers)
650
+ for consumer in quant_params[tensor_name].consumers:
651
+ self.assertNotEmpty(consumer.transformations)
652
+ self.assertEqual(
653
+ consumer.transformations[0],
654
+ _QTransf.DUPLICATE_BUFFER,
655
+ )
656
+ self.assertNotIn(
657
+ _QTransf.DUPLICATE_BUFFER, consumer.transformations[1:]
658
+ )
659
+ elif quant_params[tensor_name].consumers is not None:
660
+ for consumer in quant_params[tensor_name].consumers:
661
+ self.assertNotIn(_QTransf.DUPLICATE_BUFFER, consumer.transformations)
662
+
663
+ def _get_fc_recipe_entry(self, regex: str, num_bits: int):
664
+ return {
665
+ 'regex': regex,
666
+ 'operation': 'FULLY_CONNECTED',
667
+ 'algorithm_key': 'min_max_uniform_quantize',
668
+ 'op_config': {
669
+ 'weight_tensor_config': {
670
+ 'num_bits': num_bits,
671
+ 'symmetric': True,
672
+ 'granularity': 'CHANNELWISE',
673
+ 'dtype': 'INT',
674
+ 'block_size': 0,
675
+ },
676
+ 'compute_precision': 'INTEGER',
677
+ 'explicit_dequantize': False,
678
+ 'skip_checks': False,
679
+ 'min_weight_elements': 0,
680
+ },
681
+ }
682
+
683
+ @parameterized.named_parameters(
684
+ dict(
685
+ testcase_name='fc1_quant_fc2_no_quant',
686
+ fc1_num_bits=8,
687
+ fc2_num_bits=None,
688
+ ),
689
+ dict(
690
+ testcase_name='fc1_no_quant_fc2_quant',
691
+ fc1_num_bits=None,
692
+ fc2_num_bits=8,
693
+ ),
694
+ dict(
695
+ testcase_name='fc1_quant_fc2_quant_different_params',
696
+ fc1_num_bits=8,
697
+ fc2_num_bits=4,
698
+ ),
699
+ )
700
+ def test_generate_params_marks_correct_buffers_tensors_for_duplication(
701
+ self,
702
+ fc1_num_bits,
703
+ fc2_num_bits,
704
+ ):
705
+ model_path = os.path.join(
706
+ TEST_DATA_PREFIX_PATH,
707
+ 'tests/models/constant_tensor_and_buffer_only_sharing_weight_fcs.tflite',
708
+ )
709
+ sig1_fc1_regex = 'BatchMatMulV3;'
710
+ sig1_fc2_regex = 'PartitionedCall:0;'
711
+ recipe = []
712
+ if fc1_num_bits is not None:
713
+ recipe.append(self._get_fc_recipe_entry(sig1_fc1_regex, fc1_num_bits))
714
+ if fc2_num_bits is not None:
715
+ recipe.append(self._get_fc_recipe_entry(sig1_fc2_regex, fc2_num_bits))
716
+ self._recipe_manager.load_quantization_recipe(recipe)
717
+ pg = params_generator.ParamsGenerator(model_path)
718
+ quant_params = pg.generate_quantization_parameters(self._recipe_manager)
719
+
720
+ expected_tensor = 'arith.constant'
721
+ consumers = quant_params[expected_tensor].consumers
722
+ self.assertLen(consumers, 2)
723
+
724
+ # Check FC1 transformations.
725
+ if fc1_num_bits is None:
726
+ fc1_quant_transformation = _QTransf.NO_QUANTIZE
727
+ else:
728
+ fc1_quant_transformation = _QTransf.QUANTIZE_TENSOR
729
+ self.assertEqual(
730
+ consumers[0].transformations,
731
+ [
732
+ _QTransf.DUPLICATE_TENSOR,
733
+ _QTransf.DUPLICATE_BUFFER,
734
+ fc1_quant_transformation,
735
+ ],
736
+ )
737
+ # Check FC2 transformations.
738
+ if fc2_num_bits is None:
739
+ fc2_quant_transformation = _QTransf.NO_QUANTIZE
740
+ else:
741
+ fc2_quant_transformation = _QTransf.QUANTIZE_TENSOR
742
+ self.assertEqual(
743
+ consumers[1].transformations,
744
+ [
745
+ _QTransf.DUPLICATE_TENSOR,
746
+ _QTransf.DUPLICATE_BUFFER,
747
+ fc2_quant_transformation,
748
+ ],
749
+ )
750
+ # Check that no other tensor has tensor or buffer duplication
751
+ # transformations.
752
+ for tensor_name, params in quant_params.items():
753
+ if tensor_name == expected_tensor:
754
+ continue
755
+ for consumer in params.consumers:
756
+ self.assertNotIn(_QTransf.DUPLICATE_TENSOR, consumer.transformations)
757
+ self.assertNotIn(_QTransf.DUPLICATE_BUFFER, consumer.transformations)
758
+
759
+ def test_generate_params_returns_valid_results_when_multiple_tensor_duplication_for_one_buffer(
760
+ self,
761
+ ):
762
+ model_path = os.path.join(
763
+ TEST_DATA_PREFIX_PATH,
764
+ 'tests/models/constant_tensor_and_buffer_only_sharing_weight_fcs.tflite',
765
+ )
766
+ sig1_fc1_regex = 'BatchMatMulV3;'
767
+ sig1_fc2_regex = 'PartitionedCall:0;'
768
+ sig2_fc1_regex = 'BatchMatMulV31;'
769
+ sig2_fc2_regex = 'PartitionedCall_1:0;'
770
+ recipe = [
771
+ self._get_fc_recipe_entry(sig1_fc1_regex, num_bits=8),
772
+ self._get_fc_recipe_entry(sig1_fc2_regex, num_bits=4),
773
+ self._get_fc_recipe_entry(sig2_fc1_regex, num_bits=8),
774
+ self._get_fc_recipe_entry(sig2_fc2_regex, num_bits=4),
775
+ ]
776
+ self._recipe_manager.load_quantization_recipe(recipe)
777
+ pg = params_generator.ParamsGenerator(model_path)
778
+ quant_params = pg.generate_quantization_parameters(self._recipe_manager)
779
+ # Check transformations for sig1.
780
+ sig1_expected_tensor = 'arith.constant'
781
+ sig1_consumers = quant_params[sig1_expected_tensor].consumers
782
+ self.assertLen(sig1_consumers, 2)
783
+ sig1_expected_transformations = [
784
+ _QTransf.DUPLICATE_TENSOR,
785
+ _QTransf.DUPLICATE_BUFFER,
786
+ _QTransf.QUANTIZE_TENSOR,
787
+ ]
788
+ for sig1_consumer in sig1_consumers:
789
+ self.assertEqual(
790
+ sig1_consumer.transformations,
791
+ sig1_expected_transformations,
792
+ )
793
+ # Check transformations for sig2.
794
+ sig2_expected_tensor = 'arith.constant1'
795
+ sig2_consumers = quant_params[sig2_expected_tensor].consumers
796
+ self.assertLen(sig2_consumers, 2)
797
+ sig2_expected_transformations = [
798
+ _QTransf.DUPLICATE_TENSOR,
799
+ _QTransf.QUANTIZE_TENSOR,
800
+ ]
801
+ for sig2_consumer in sig2_consumers:
802
+ self.assertEqual(
803
+ sig2_consumer.transformations,
804
+ sig2_expected_transformations,
609
805
  )
806
+ # Check that no other tensor has tensor or buffer duplication
807
+ # transformations.
808
+ for tensor_name, params in quant_params.items():
809
+ if tensor_name in [sig1_expected_tensor, sig2_expected_tensor]:
810
+ continue
811
+ for consumer in params.consumers:
812
+ self.assertNotIn(_QTransf.DUPLICATE_TENSOR, consumer.transformations)
813
+ self.assertNotIn(_QTransf.DUPLICATE_BUFFER, consumer.transformations)
610
814
 
611
815
  @parameterized.named_parameters(
612
816
  dict(
@@ -615,279 +819,185 @@ class ParamsGeneratorTest(parameterized.TestCase):
615
819
  tensor_name='tfl.quantize',
616
820
  producer=qtyping.OpToTensorParams(
617
821
  subgraph_op_id=0,
618
- transformations=[qtyping.QuantTransformation.ADD_DEQUANTIZE],
619
- parameters=qtyping.UniformQuantParams(
620
- 8, None, np.array([1]), np.array([0])
621
- ),
822
+ transformations=[_QTransf.ADD_DEQUANTIZE],
823
+ parameters=_PARAMS_8BIT,
824
+ ),
825
+ consumers=_get_test_consumers(
826
+ transformations_per_consumer=[
827
+ [_QTransf.ADD_QUANTIZE],
828
+ [_QTransf.ADD_QUANTIZE, _QTransf.ADD_DEQUANTIZE],
829
+ [_QTransf.ADD_QUANTIZE, _QTransf.ADD_DEQUANTIZE],
830
+ [_QTransf.NO_QUANTIZE],
831
+ ],
832
+ params_per_consumer=[_PARAMS_8BIT] * 4,
622
833
  ),
623
- consumers=[
624
- qtyping.OpToTensorParams(
625
- subgraph_op_id=1,
626
- transformations=[
627
- qtyping.QuantTransformation.ADD_QUANTIZE
628
- ],
629
- parameters=qtyping.UniformQuantParams(
630
- 8, None, np.array([1]), np.array([0])
631
- ),
632
- ),
633
- qtyping.OpToTensorParams(
634
- subgraph_op_id=2,
635
- transformations=[
636
- qtyping.QuantTransformation.ADD_QUANTIZE,
637
- qtyping.QuantTransformation.ADD_DEQUANTIZE,
638
- ],
639
- parameters=qtyping.UniformQuantParams(
640
- 8, None, np.array([1]), np.array([0])
641
- ),
642
- ),
643
- qtyping.OpToTensorParams(
644
- subgraph_op_id=3,
645
- transformations=[
646
- qtyping.QuantTransformation.ADD_QUANTIZE,
647
- qtyping.QuantTransformation.ADD_DEQUANTIZE,
648
- ],
649
- parameters=qtyping.UniformQuantParams(
650
- 8, None, np.array([1]), np.array([0])
651
- ),
652
- ),
653
- qtyping.OpToTensorParams(
654
- subgraph_op_id=4,
655
- transformations=[
656
- qtyping.QuantTransformation.NO_QUANTIZE,
657
- ],
658
- parameters=qtyping.UniformQuantParams(
659
- 8, None, np.array([1]), np.array([0])
660
- ),
661
- ),
662
- ],
663
834
  ),
664
835
  param2=qtyping.TensorTransformationParams(
665
- 'tfl.other_quantize',
666
- qtyping.OpToTensorParams(
836
+ tensor_name='tfl.other_quantize',
837
+ producer=qtyping.OpToTensorParams(
667
838
  subgraph_op_id=0,
668
- transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
669
- parameters=qtyping.UniformQuantParams(
670
- 8, None, np.array([1]), np.array([0])
671
- ),
839
+ transformations=[_QTransf.NO_QUANTIZE],
840
+ parameters=_PARAMS_8BIT,
841
+ ),
842
+ consumers=_get_test_consumers(
843
+ transformations_per_consumer=[
844
+ [_QTransf.ADD_QUANTIZE],
845
+ [_QTransf.ADD_QUANTIZE, _QTransf.ADD_DEQUANTIZE],
846
+ [_QTransf.ADD_QUANTIZE, _QTransf.ADD_DEQUANTIZE],
847
+ ],
848
+ params_per_consumer=[_PARAMS_8BIT] * 4,
672
849
  ),
673
- [
674
- qtyping.OpToTensorParams(
675
- subgraph_op_id=1,
676
- transformations=[
677
- qtyping.QuantTransformation.ADD_QUANTIZE
678
- ],
679
- parameters=qtyping.UniformQuantParams(
680
- 8, None, np.array([1]), np.array([0])
681
- ),
682
- ),
683
- qtyping.OpToTensorParams(
684
- subgraph_op_id=2,
685
- transformations=[
686
- qtyping.QuantTransformation.ADD_QUANTIZE,
687
- qtyping.QuantTransformation.ADD_DEQUANTIZE,
688
- ],
689
- parameters=qtyping.UniformQuantParams(
690
- 8, None, np.array([1]), np.array([0])
691
- ),
692
- ),
693
- qtyping.OpToTensorParams(
694
- subgraph_op_id=3,
695
- transformations=[
696
- qtyping.QuantTransformation.ADD_QUANTIZE,
697
- qtyping.QuantTransformation.ADD_DEQUANTIZE,
698
- ],
699
- parameters=qtyping.UniformQuantParams(
700
- 8, None, np.array([1]), np.array([0])
701
- ),
702
- ),
703
- ],
704
850
  ),
705
851
  expected=False,
706
852
  ),
707
853
  dict(
708
- testcase_name='param2_consumer_incompatible',
854
+ testcase_name='compatible',
709
855
  param1=qtyping.TensorTransformationParams(
710
856
  tensor_name='tfl.quantize',
711
- producer=qtyping.OpToTensorParams(
712
- subgraph_op_id=0,
713
- transformations=[qtyping.QuantTransformation.ADD_QUANTIZE],
714
- parameters=qtyping.UniformQuantParams(
715
- 8, None, np.array([1]), np.array([0])
716
- ),
857
+ producer=None,
858
+ consumers=_get_test_consumers(
859
+ transformations_per_consumer=[
860
+ [_QTransf.ADD_QUANTIZE],
861
+ [_QTransf.NO_QUANTIZE, _QTransf.ADD_QUANTIZE],
862
+ [_QTransf.NO_QUANTIZE],
863
+ ],
864
+ params_per_consumer=[_PARAMS_8BIT] * 4,
717
865
  ),
718
- consumers=[
719
- qtyping.OpToTensorParams(
720
- subgraph_op_id=1,
721
- transformations=[
722
- qtyping.QuantTransformation.ADD_QUANTIZE
723
- ],
724
- parameters=qtyping.UniformQuantParams(
725
- 8, None, np.array([1]), np.array([0])
726
- ),
727
- ),
728
- qtyping.OpToTensorParams(
729
- subgraph_op_id=2,
730
- transformations=[
731
- qtyping.QuantTransformation.ADD_QUANTIZE,
732
- qtyping.QuantTransformation.ADD_DEQUANTIZE,
733
- ],
734
- parameters=qtyping.UniformQuantParams(
735
- 8, None, np.array([1]), np.array([0])
736
- ),
737
- ),
738
- qtyping.OpToTensorParams(
739
- subgraph_op_id=3,
740
- transformations=[
741
- qtyping.QuantTransformation.ADD_QUANTIZE,
742
- qtyping.QuantTransformation.ADD_DEQUANTIZE,
743
- ],
744
- parameters=qtyping.UniformQuantParams(
745
- 8, None, np.array([1]), np.array([0])
746
- ),
747
- ),
748
- ],
749
866
  ),
750
867
  param2=qtyping.TensorTransformationParams(
751
- 'tfl.other_quantize',
752
- qtyping.OpToTensorParams(
753
- subgraph_op_id=0,
754
- transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
755
- parameters=qtyping.UniformQuantParams(
756
- 8, None, np.array([1]), np.array([0])
757
- ),
868
+ tensor_name='tfl.other_quantize',
869
+ producer=None,
870
+ consumers=_get_test_consumers(
871
+ transformations_per_consumer=[
872
+ [_QTransf.ADD_QUANTIZE],
873
+ [_QTransf.ADD_QUANTIZE, _QTransf.ADD_DEQUANTIZE],
874
+ [_QTransf.ADD_QUANTIZE, _QTransf.ADD_DEQUANTIZE],
875
+ [_QTransf.ADD_QUANTIZE],
876
+ ],
877
+ params_per_consumer=[_PARAMS_8BIT] * 4,
758
878
  ),
759
- [
760
- qtyping.OpToTensorParams(
761
- subgraph_op_id=1,
762
- transformations=[
763
- qtyping.QuantTransformation.ADD_QUANTIZE
764
- ],
765
- parameters=qtyping.UniformQuantParams(
766
- 8, None, np.array([1]), np.array([0])
767
- ),
768
- ),
769
- qtyping.OpToTensorParams(
770
- subgraph_op_id=2,
771
- transformations=[
772
- qtyping.QuantTransformation.ADD_QUANTIZE,
773
- qtyping.QuantTransformation.ADD_DEQUANTIZE,
774
- ],
775
- parameters=qtyping.UniformQuantParams(
776
- 8, None, np.array([1]), np.array([0])
777
- ),
778
- ),
779
- qtyping.OpToTensorParams(
780
- subgraph_op_id=3,
781
- transformations=[
782
- qtyping.QuantTransformation.ADD_QUANTIZE,
783
- qtyping.QuantTransformation.ADD_DEQUANTIZE,
784
- ],
785
- parameters=qtyping.UniformQuantParams(
786
- 8, None, np.array([1]), np.array([0])
787
- ),
788
- ),
789
- qtyping.OpToTensorParams(
790
- subgraph_op_id=4,
791
- transformations=[
792
- qtyping.QuantTransformation.QUANTIZE_TENSOR,
793
- ],
794
- parameters=qtyping.UniformQuantParams(
795
- 8, None, np.array([1]), np.array([0])
796
- ),
797
- ),
798
- ],
799
879
  ),
800
- expected=False,
880
+ expected=True,
801
881
  ),
802
882
  dict(
803
- testcase_name='compatible',
883
+ testcase_name='compatible_no_numeric_check',
804
884
  param1=qtyping.TensorTransformationParams(
805
885
  tensor_name='tfl.quantize',
806
886
  producer=None,
807
- consumers=[
808
- qtyping.OpToTensorParams(
809
- subgraph_op_id=2,
810
- transformations=[
811
- qtyping.QuantTransformation.ADD_QUANTIZE,
812
- ],
813
- parameters=qtyping.UniformQuantParams(
814
- 8, None, np.array([1]), np.array([0])
887
+ consumers=_get_test_consumers(
888
+ transformations_per_consumer=[
889
+ [_QTransf.ADD_QUANTIZE],
890
+ [_QTransf.ADD_QUANTIZE],
891
+ ],
892
+ params_per_consumer=[
893
+ qtyping.UniformQuantParams(
894
+ 8, None, np.array([0.00028806]), np.array([0])
815
895
  ),
816
- ),
817
- qtyping.OpToTensorParams(
818
- subgraph_op_id=3,
819
- transformations=[
820
- qtyping.QuantTransformation.NO_QUANTIZE,
821
- qtyping.QuantTransformation.ADD_QUANTIZE,
822
- ],
823
- parameters=qtyping.UniformQuantParams(
824
- 8, None, np.array([1]), np.array([0])
896
+ qtyping.UniformQuantParams(
897
+ 8, None, np.array([0.00027501]), np.array([0])
825
898
  ),
826
- ),
827
- qtyping.OpToTensorParams(
828
- subgraph_op_id=4,
829
- transformations=[
830
- qtyping.QuantTransformation.NO_QUANTIZE,
831
- ],
832
- ),
833
- ],
899
+ ],
900
+ ),
834
901
  ),
835
902
  param2=qtyping.TensorTransformationParams(
836
- 'tfl.other_quantize',
837
- None,
838
- [
839
- qtyping.OpToTensorParams(
840
- subgraph_op_id=1,
841
- transformations=[
842
- qtyping.QuantTransformation.ADD_QUANTIZE
843
- ],
844
- parameters=qtyping.UniformQuantParams(
845
- 8, None, np.array([1]), np.array([0])
846
- ),
847
- ),
848
- qtyping.OpToTensorParams(
849
- subgraph_op_id=2,
850
- transformations=[
851
- qtyping.QuantTransformation.ADD_QUANTIZE,
852
- qtyping.QuantTransformation.ADD_DEQUANTIZE,
853
- ],
854
- parameters=qtyping.UniformQuantParams(
855
- 8, None, np.array([1]), np.array([0])
856
- ),
857
- ),
858
- qtyping.OpToTensorParams(
859
- subgraph_op_id=3,
860
- transformations=[
861
- qtyping.QuantTransformation.ADD_QUANTIZE,
862
- qtyping.QuantTransformation.ADD_DEQUANTIZE,
863
- ],
864
- parameters=qtyping.UniformQuantParams(
865
- 8, None, np.array([1]), np.array([0])
903
+ tensor_name='tfl.quantize',
904
+ producer=None,
905
+ consumers=_get_test_consumers(
906
+ transformations_per_consumer=[
907
+ [_QTransf.ADD_QUANTIZE],
908
+ [_QTransf.ADD_QUANTIZE],
909
+ ],
910
+ params_per_consumer=[
911
+ qtyping.UniformQuantParams(
912
+ 8, None, np.array([0.00028806]), np.array([0])
866
913
  ),
867
- ),
868
- qtyping.OpToTensorParams(
869
- subgraph_op_id=4,
870
- transformations=[
871
- qtyping.QuantTransformation.ADD_QUANTIZE,
872
- ],
873
- parameters=qtyping.UniformQuantParams(
874
- 8, None, np.array([1]), np.array([0])
914
+ qtyping.UniformQuantParams(
915
+ 8, None, np.array([0.00027501]), np.array([0])
875
916
  ),
876
- ),
877
- ],
917
+ ],
918
+ ),
878
919
  ),
879
920
  expected=True,
880
921
  ),
881
922
  )
882
- def test_params_compatible(self, param1, param2, expected):
883
- # adding a test to make production coverage happy.
923
+ def test__are_self_compatible_tensors_compatible_to_each_other(
924
+ self, param1, param2, expected
925
+ ):
884
926
  self.assertEqual(
885
- params_generator._compatible_tensor_transformation_params(
927
+ params_generator._are_self_compatible_tensors_compatible_to_each_other(
886
928
  param1, param2
887
929
  ),
888
930
  expected,
889
931
  )
890
932
 
933
+ @parameterized.named_parameters(
934
+ dict(
935
+ testcase_name='consumer_incompatible',
936
+ params=qtyping.TensorTransformationParams(
937
+ tensor_name='tfl.quantize',
938
+ producer=qtyping.OpToTensorParams(
939
+ subgraph_op_id=0,
940
+ transformations=[_QTransf.NO_QUANTIZE],
941
+ parameters=_PARAMS_8BIT,
942
+ ),
943
+ consumers=_get_test_consumers(
944
+ transformations_per_consumer=[
945
+ [_QTransf.ADD_QUANTIZE],
946
+ [_QTransf.ADD_QUANTIZE, _QTransf.ADD_DEQUANTIZE],
947
+ [_QTransf.ADD_QUANTIZE, _QTransf.ADD_DEQUANTIZE],
948
+ [_QTransf.QUANTIZE_TENSOR],
949
+ ],
950
+ params_per_consumer=[_PARAMS_8BIT] * 4,
951
+ ),
952
+ ),
953
+ expected=False,
954
+ ),
955
+ dict(
956
+ testcase_name='compatible',
957
+ params=qtyping.TensorTransformationParams(
958
+ tensor_name='tfl.quantize',
959
+ producer=None,
960
+ consumers=_get_test_consumers(
961
+ transformations_per_consumer=[
962
+ [_QTransf.ADD_QUANTIZE, _QTransf.ADD_DEQUANTIZE],
963
+ [_QTransf.ADD_QUANTIZE],
964
+ [_QTransf.NO_QUANTIZE, _QTransf.ADD_QUANTIZE],
965
+ [_QTransf.NO_QUANTIZE],
966
+ ],
967
+ params_per_consumer=[_PARAMS_8BIT] * 4,
968
+ ),
969
+ ),
970
+ expected=True,
971
+ ),
972
+ dict(
973
+ testcase_name='compatible_no_numeric_check',
974
+ params=qtyping.TensorTransformationParams(
975
+ tensor_name='tfl.quantize',
976
+ producer=None,
977
+ consumers=_get_test_consumers(
978
+ transformations_per_consumer=[
979
+ [_QTransf.ADD_QUANTIZE],
980
+ [_QTransf.ADD_QUANTIZE],
981
+ ],
982
+ params_per_consumer=[
983
+ qtyping.UniformQuantParams(
984
+ 8, None, np.array([0.00028806]), np.array([0])
985
+ ),
986
+ qtyping.UniformQuantParams(
987
+ 8, None, np.array([0.00027501]), np.array([0])
988
+ ),
989
+ ],
990
+ ),
991
+ ),
992
+ expected=True,
993
+ ),
994
+ )
995
+ def test__are_tensor_consumer_params_compatible(self, params, expected):
996
+ self.assertEqual(
997
+ params_generator._are_tensor_consumer_params_compatible(params),
998
+ expected,
999
+ )
1000
+
891
1001
  def test_model_with_duplicated_tensor_names_fails(self):
892
1002
  model_path = os.path.join(
893
1003
  TEST_DATA_PREFIX_PATH, 'tests/models/duplicated_tensor_names.tflite'
@@ -1025,16 +1135,11 @@ class ParamsGeneratorAlreadyQuantizedModelTest(googletest.TestCase):
1025
1135
  )
1026
1136
  _ = params_generator.ParamsGenerator(test_model_path)
1027
1137
 
1028
- def test_check_is_float_model_raises_error_when_model_is_quantized(self):
1138
+ def test_check_is_quantized_model_succeeds_when_model_is_quantized(self):
1029
1139
  test_model_path = os.path.join(
1030
1140
  TEST_DATA_PREFIX_PATH, 'tests/models/mnist_quantized.tflite'
1031
1141
  )
1032
- with self.assertRaisesRegex(
1033
- ValueError,
1034
- 'The input model for quantization parameters generation is not a float'
1035
- ' model.',
1036
- ):
1037
- _ = params_generator.ParamsGenerator(test_model_path)
1142
+ _ = params_generator.ParamsGenerator(test_model_path)
1038
1143
 
1039
1144
 
1040
1145
  if __name__ == '__main__':