ai-edge-quantizer-nightly 0.1.0.dev20250407__py3-none-any.whl → 0.1.0.dev20250410__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,6 +16,7 @@
16
16
  """Performs naive min/max uniform quantization."""
17
17
 
18
18
  from typing import Any, Optional
19
+ import ml_dtypes
19
20
  import numpy as np
20
21
  from ai_edge_quantizer import qtyping
21
22
  from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
@@ -73,12 +74,25 @@ def get_tensor_quant_params(
73
74
  " parameters. Check if the correct calibration results are passed into"
74
75
  " the ParamsGenerator."
75
76
  )
77
+ clipping_values = None
78
+ # Blockwise quantization uses float16 scale, with 7 bit mantissa,
79
+ # so the maximum representable value is 65280.
80
+ if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
81
+ clipping_values = np.broadcast_to(
82
+ np.array(65280), tensor_min_max["min"].shape
83
+ )
76
84
  zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
77
85
  tensor_min_max["min"],
78
86
  tensor_min_max["max"],
79
87
  tensor_quant_config.num_bits,
80
88
  tensor_quant_config.symmetric,
89
+ clipping_values,
81
90
  )
91
+ # Round the scale values to 7 bit mantissa.
92
+ if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
93
+ scale = (
94
+ scale.astype(ml_dtypes.bfloat16).astype(np.float16).astype(np.float32)
95
+ )
82
96
  quantized_dim = None
83
97
  if tensor_quant_config.granularity == qtyping.QuantGranularity.CHANNELWISE:
84
98
  quantized_dim = common_utils.get_weight_quantized_dim(
@@ -22,7 +22,6 @@ import numpy as np
22
22
  from tensorflow.python.platform import googletest
23
23
  from ai_edge_quantizer import qtyping
24
24
  from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
25
- from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor
26
25
  from ai_edge_quantizer.utils import test_utils
27
26
  from ai_edge_quantizer.utils import tfl_flatbuffer_utils
28
27
 
@@ -185,14 +184,13 @@ class NaiveMinMaxQuantizeTest(parameterized.TestCase):
185
184
  )
186
185
  scale = quant_params.scale
187
186
  zp = quant_params.zero_point
188
- expected_zp, expected_scale = (
189
- uniform_quantize_tensor.tensor_zp_scale_from_min_max(
190
- min_value=np.array([[-7], [-4], [-4], [7]]),
191
- max_value=np.array([[7], [4], [4], [7]]),
192
- num_bits=4,
193
- symmetric=True,
194
- )
195
- )
187
+ expected_scale = np.array([
188
+ [1],
189
+ [0.5703125],
190
+ [0.5703125],
191
+ [1],
192
+ ])
193
+ expected_zp = np.zeros([4, 1])
196
194
  self.assertTrue(np.array_equal(zp, expected_zp))
197
195
  self.assertTrue(np.array_equal(scale, expected_scale))
198
196
  self.assertIsNotNone(quant_params.quantized_data)
@@ -344,22 +344,58 @@ def _compatible_tensor_transformation_params(
344
344
  params2: qtyping.TensorTransformationParams,
345
345
  ) -> bool:
346
346
  """Check if two tensor transformation params are compatible."""
347
+ return (
348
+ _are_tensor_consumer_params_compatible(params1)
349
+ and _are_tensor_consumer_params_compatible(params2)
350
+ and _are_self_compatible_tensors_compatible_to_each_other(
351
+ params1, params2
352
+ )
353
+ )
354
+
355
+
356
def _are_tensor_consumer_params_compatible(
    params: qtyping.TensorTransformationParams,
) -> bool:
  """Returns whether every consumer of a tensor agrees on quantization params.

  The first consumer serves as the reference, and each remaining consumer is
  compared against it with `_compatible_tensor_params`. A tensor whose
  `consumers` is None, empty, or holds a single entry is trivially
  self-compatible.

  Args:
    params: The tensor transformation params whose consumers are checked.

  Returns:
    True if all consumers are compatible with the first consumer, else False.
  """
  consumers = params.consumers
  if not consumers:
    return True
  reference = consumers[0]
  # `all` stops at the first incompatible consumer, just like an explicit loop
  # with an early `return False`.
  return all(
      _compatible_tensor_params(other, reference) for other in consumers[1:]
  )
367
+
368
+
369
+ def _are_self_compatible_tensors_compatible_to_each_other(
370
+ params1: qtyping.TensorTransformationParams,
371
+ params2: qtyping.TensorTransformationParams,
372
+ ) -> bool:
373
+ """Check if two self compatible tensors are compatible to each other.
374
+
375
+ Self compatible means that all tensor's consumers have the same quantization
376
+ parameters.
377
+
378
+ Args:
379
+ params1: The first tensor transformation params.
380
+ params2: The second tensor transformation params.
381
+
382
+ Returns:
383
+ Whether the two tensors are compatible to each other.
384
+ """
385
+ # Check the producer.
347
386
  if params1.producer is None or params2.producer is None:
348
387
  if params1.producer != params2.producer:
349
388
  return False
350
389
  elif not _compatible_tensor_params(params1.producer, params2.producer):
351
390
  return False
391
+
392
+ # Check the consumers.
352
393
  if params1.consumers is None or params2.consumers is None:
353
394
  if params1.consumers != params2.consumers:
354
395
  return False
355
396
  else:
356
- # Check all consumers within each params are compatible.
357
- for params1_consumer in params1.consumers:
358
- if not _compatible_tensor_params(params1_consumer, params1.consumers[0]):
359
- return False
360
- for params2_consumer in params2.consumers:
361
- if not _compatible_tensor_params(params2_consumer, params2.consumers[0]):
362
- return False
397
+ # Since all consumer params within each tensor are the same, it's enough to
398
+ # check only the first consumers.
363
399
  if not _compatible_tensor_params(
364
400
  params1.consumers[0], params2.consumers[0]
365
401
  ):
@@ -37,8 +37,11 @@ _TensorQuantConfig = qtyping.TensorQuantizationConfig
37
37
  _QuantTransformation = qtyping.QuantTransformation
38
38
  _AlgorithmName = recipe_manager.AlgorithmName
39
39
  _QuantGranularity = qtyping.QuantGranularity
40
+ _QTransf = qtyping.QuantTransformation
41
+
40
42
 
41
43
  TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('')
44
+ _PARAMS_8BIT = qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0]))
42
45
 
43
46
 
44
47
  def _single_fc_model_representative_dataset_gen(num_samples=5):
@@ -63,6 +66,20 @@ def _get_calibration_data(
63
66
  return calibration_data
64
67
 
65
68
 
69
def _get_test_consumers(
    transformations_per_consumer: list[list[_QTransf]],
    params_per_consumer: list[qtyping.UniformQuantParams],
) -> list[qtyping.OpToTensorParams]:
  """Builds consumer `OpToTensorParams` entries for the compatibility tests.

  Consumer `i` (0-based) gets `subgraph_op_id=i + 1`. The ids start at 1 so
  they do not clash with `subgraph_op_id=0`, which the producers in these test
  cases use.

  Args:
    transformations_per_consumer: One list of transformations per consumer.
      Its length determines how many consumers are created.
    params_per_consumer: Quantization parameters for each consumer, indexed in
      parallel with `transformations_per_consumer`. Extra trailing entries are
      ignored; some test cases pass `[_PARAMS_8BIT] * 4` for three consumers.

  Returns:
    One `OpToTensorParams` per entry of `transformations_per_consumer`.

  Raises:
    IndexError: If `params_per_consumer` is shorter than
      `transformations_per_consumer`.
  """
  return [
      qtyping.OpToTensorParams(
          subgraph_op_id=op_id,
          transformations=transformations,
          parameters=params_per_consumer[op_id - 1],
      )
      for op_id, transformations in enumerate(
          transformations_per_consumer, start=1
      )
  ]
81
+
82
+
66
83
  class ParamsGeneratorTest(parameterized.TestCase):
67
84
 
68
85
  def setUp(self):
@@ -635,12 +652,12 @@ class ParamsGeneratorTest(parameterized.TestCase):
635
652
  self.assertNotEmpty(consumer.transformations)
636
653
  self.assertEqual(
637
654
  consumer.transformations[0],
638
- qtyping.QuantTransformation.DUPLICATE_BUFFER,
655
+ _QTransf.DUPLICATE_BUFFER,
639
656
  )
640
657
  elif quant_params[tensor_name].consumers is not None:
641
658
  for consumer in quant_params[tensor_name].consumers:
642
659
  self.assertNotIn(
643
- qtyping.QuantTransformation.DUPLICATE_BUFFER,
660
+ _QTransf.DUPLICATE_BUFFER,
644
661
  consumer.transformations,
645
662
  )
646
663
 
@@ -651,328 +668,182 @@ class ParamsGeneratorTest(parameterized.TestCase):
651
668
  tensor_name='tfl.quantize',
652
669
  producer=qtyping.OpToTensorParams(
653
670
  subgraph_op_id=0,
654
- transformations=[qtyping.QuantTransformation.ADD_DEQUANTIZE],
655
- parameters=qtyping.UniformQuantParams(
656
- 8, None, np.array([1]), np.array([0])
657
- ),
671
+ transformations=[_QTransf.ADD_DEQUANTIZE],
672
+ parameters=_PARAMS_8BIT,
673
+ ),
674
+ consumers=_get_test_consumers(
675
+ transformations_per_consumer=[
676
+ [_QTransf.ADD_QUANTIZE],
677
+ [_QTransf.ADD_QUANTIZE, _QTransf.ADD_DEQUANTIZE],
678
+ [_QTransf.ADD_QUANTIZE, _QTransf.ADD_DEQUANTIZE],
679
+ [_QTransf.NO_QUANTIZE],
680
+ ],
681
+ params_per_consumer=[_PARAMS_8BIT] * 4,
658
682
  ),
659
- consumers=[
660
- qtyping.OpToTensorParams(
661
- subgraph_op_id=1,
662
- transformations=[
663
- qtyping.QuantTransformation.ADD_QUANTIZE
664
- ],
665
- parameters=qtyping.UniformQuantParams(
666
- 8, None, np.array([1]), np.array([0])
667
- ),
668
- ),
669
- qtyping.OpToTensorParams(
670
- subgraph_op_id=2,
671
- transformations=[
672
- qtyping.QuantTransformation.ADD_QUANTIZE,
673
- qtyping.QuantTransformation.ADD_DEQUANTIZE,
674
- ],
675
- parameters=qtyping.UniformQuantParams(
676
- 8, None, np.array([1]), np.array([0])
677
- ),
678
- ),
679
- qtyping.OpToTensorParams(
680
- subgraph_op_id=3,
681
- transformations=[
682
- qtyping.QuantTransformation.ADD_QUANTIZE,
683
- qtyping.QuantTransformation.ADD_DEQUANTIZE,
684
- ],
685
- parameters=qtyping.UniformQuantParams(
686
- 8, None, np.array([1]), np.array([0])
687
- ),
688
- ),
689
- qtyping.OpToTensorParams(
690
- subgraph_op_id=4,
691
- transformations=[
692
- qtyping.QuantTransformation.NO_QUANTIZE,
693
- ],
694
- parameters=qtyping.UniformQuantParams(
695
- 8, None, np.array([1]), np.array([0])
696
- ),
697
- ),
698
- ],
699
683
  ),
700
684
  param2=qtyping.TensorTransformationParams(
701
- 'tfl.other_quantize',
702
- qtyping.OpToTensorParams(
685
+ tensor_name='tfl.other_quantize',
686
+ producer=qtyping.OpToTensorParams(
703
687
  subgraph_op_id=0,
704
- transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
705
- parameters=qtyping.UniformQuantParams(
706
- 8, None, np.array([1]), np.array([0])
707
- ),
688
+ transformations=[_QTransf.NO_QUANTIZE],
689
+ parameters=_PARAMS_8BIT,
690
+ ),
691
+ consumers=_get_test_consumers(
692
+ transformations_per_consumer=[
693
+ [_QTransf.ADD_QUANTIZE],
694
+ [_QTransf.ADD_QUANTIZE, _QTransf.ADD_DEQUANTIZE],
695
+ [_QTransf.ADD_QUANTIZE, _QTransf.ADD_DEQUANTIZE],
696
+ ],
697
+ params_per_consumer=[_PARAMS_8BIT] * 4,
708
698
  ),
709
- [
710
- qtyping.OpToTensorParams(
711
- subgraph_op_id=1,
712
- transformations=[
713
- qtyping.QuantTransformation.ADD_QUANTIZE
714
- ],
715
- parameters=qtyping.UniformQuantParams(
716
- 8, None, np.array([1]), np.array([0])
717
- ),
718
- ),
719
- qtyping.OpToTensorParams(
720
- subgraph_op_id=2,
721
- transformations=[
722
- qtyping.QuantTransformation.ADD_QUANTIZE,
723
- qtyping.QuantTransformation.ADD_DEQUANTIZE,
724
- ],
725
- parameters=qtyping.UniformQuantParams(
726
- 8, None, np.array([1]), np.array([0])
727
- ),
728
- ),
729
- qtyping.OpToTensorParams(
730
- subgraph_op_id=3,
731
- transformations=[
732
- qtyping.QuantTransformation.ADD_QUANTIZE,
733
- qtyping.QuantTransformation.ADD_DEQUANTIZE,
734
- ],
735
- parameters=qtyping.UniformQuantParams(
736
- 8, None, np.array([1]), np.array([0])
737
- ),
738
- ),
739
- ],
740
699
  ),
741
700
  expected=False,
742
701
  ),
743
702
  dict(
744
- testcase_name='param2_consumer_incompatible',
703
+ testcase_name='compatible',
745
704
  param1=qtyping.TensorTransformationParams(
746
705
  tensor_name='tfl.quantize',
747
- producer=qtyping.OpToTensorParams(
748
- subgraph_op_id=0,
749
- transformations=[qtyping.QuantTransformation.ADD_QUANTIZE],
750
- parameters=qtyping.UniformQuantParams(
751
- 8, None, np.array([1]), np.array([0])
752
- ),
706
+ producer=None,
707
+ consumers=_get_test_consumers(
708
+ transformations_per_consumer=[
709
+ [_QTransf.ADD_QUANTIZE],
710
+ [_QTransf.NO_QUANTIZE, _QTransf.ADD_QUANTIZE],
711
+ [_QTransf.NO_QUANTIZE],
712
+ ],
713
+ params_per_consumer=[_PARAMS_8BIT] * 4,
753
714
  ),
754
- consumers=[
755
- qtyping.OpToTensorParams(
756
- subgraph_op_id=1,
757
- transformations=[
758
- qtyping.QuantTransformation.ADD_QUANTIZE
759
- ],
760
- parameters=qtyping.UniformQuantParams(
761
- 8, None, np.array([1]), np.array([0])
762
- ),
763
- ),
764
- qtyping.OpToTensorParams(
765
- subgraph_op_id=2,
766
- transformations=[
767
- qtyping.QuantTransformation.ADD_QUANTIZE,
768
- qtyping.QuantTransformation.ADD_DEQUANTIZE,
769
- ],
770
- parameters=qtyping.UniformQuantParams(
771
- 8, None, np.array([1]), np.array([0])
772
- ),
773
- ),
774
- qtyping.OpToTensorParams(
775
- subgraph_op_id=3,
776
- transformations=[
777
- qtyping.QuantTransformation.ADD_QUANTIZE,
778
- qtyping.QuantTransformation.ADD_DEQUANTIZE,
779
- ],
780
- parameters=qtyping.UniformQuantParams(
781
- 8, None, np.array([1]), np.array([0])
782
- ),
783
- ),
784
- ],
785
715
  ),
786
716
  param2=qtyping.TensorTransformationParams(
787
- 'tfl.other_quantize',
788
- qtyping.OpToTensorParams(
789
- subgraph_op_id=0,
790
- transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
791
- parameters=qtyping.UniformQuantParams(
792
- 8, None, np.array([1]), np.array([0])
793
- ),
717
+ tensor_name='tfl.other_quantize',
718
+ producer=None,
719
+ consumers=_get_test_consumers(
720
+ transformations_per_consumer=[
721
+ [_QTransf.ADD_QUANTIZE],
722
+ [_QTransf.ADD_QUANTIZE, _QTransf.ADD_DEQUANTIZE],
723
+ [_QTransf.ADD_QUANTIZE, _QTransf.ADD_DEQUANTIZE],
724
+ [_QTransf.ADD_QUANTIZE],
725
+ ],
726
+ params_per_consumer=[_PARAMS_8BIT] * 4,
794
727
  ),
795
- [
796
- qtyping.OpToTensorParams(
797
- subgraph_op_id=1,
798
- transformations=[
799
- qtyping.QuantTransformation.ADD_QUANTIZE
800
- ],
801
- parameters=qtyping.UniformQuantParams(
802
- 8, None, np.array([1]), np.array([0])
803
- ),
804
- ),
805
- qtyping.OpToTensorParams(
806
- subgraph_op_id=2,
807
- transformations=[
808
- qtyping.QuantTransformation.ADD_QUANTIZE,
809
- qtyping.QuantTransformation.ADD_DEQUANTIZE,
810
- ],
811
- parameters=qtyping.UniformQuantParams(
812
- 8, None, np.array([1]), np.array([0])
813
- ),
814
- ),
815
- qtyping.OpToTensorParams(
816
- subgraph_op_id=3,
817
- transformations=[
818
- qtyping.QuantTransformation.ADD_QUANTIZE,
819
- qtyping.QuantTransformation.ADD_DEQUANTIZE,
820
- ],
821
- parameters=qtyping.UniformQuantParams(
822
- 8, None, np.array([1]), np.array([0])
823
- ),
824
- ),
825
- qtyping.OpToTensorParams(
826
- subgraph_op_id=4,
827
- transformations=[
828
- qtyping.QuantTransformation.QUANTIZE_TENSOR,
829
- ],
830
- parameters=qtyping.UniformQuantParams(
831
- 8, None, np.array([1]), np.array([0])
832
- ),
833
- ),
834
- ],
835
728
  ),
836
- expected=False,
729
+ expected=True,
837
730
  ),
838
731
  dict(
839
- testcase_name='compatible',
732
+ testcase_name='compatible_no_numeric_check',
840
733
  param1=qtyping.TensorTransformationParams(
841
734
  tensor_name='tfl.quantize',
842
735
  producer=None,
843
- consumers=[
844
- qtyping.OpToTensorParams(
845
- subgraph_op_id=2,
846
- transformations=[
847
- qtyping.QuantTransformation.ADD_QUANTIZE,
848
- ],
849
- parameters=qtyping.UniformQuantParams(
850
- 8, None, np.array([1]), np.array([0])
736
+ consumers=_get_test_consumers(
737
+ transformations_per_consumer=[
738
+ [_QTransf.ADD_QUANTIZE],
739
+ [_QTransf.ADD_QUANTIZE],
740
+ ],
741
+ params_per_consumer=[
742
+ qtyping.UniformQuantParams(
743
+ 8, None, np.array([0.00028806]), np.array([0])
851
744
  ),
852
- ),
853
- qtyping.OpToTensorParams(
854
- subgraph_op_id=3,
855
- transformations=[
856
- qtyping.QuantTransformation.NO_QUANTIZE,
857
- qtyping.QuantTransformation.ADD_QUANTIZE,
858
- ],
859
- parameters=qtyping.UniformQuantParams(
860
- 8, None, np.array([1]), np.array([0])
745
+ qtyping.UniformQuantParams(
746
+ 8, None, np.array([0.00027501]), np.array([0])
861
747
  ),
862
- ),
863
- qtyping.OpToTensorParams(
864
- subgraph_op_id=4,
865
- transformations=[
866
- qtyping.QuantTransformation.NO_QUANTIZE,
867
- ],
868
- ),
869
- ],
748
+ ],
749
+ ),
870
750
  ),
871
751
  param2=qtyping.TensorTransformationParams(
872
- 'tfl.other_quantize',
873
- None,
874
- [
875
- qtyping.OpToTensorParams(
876
- subgraph_op_id=1,
877
- transformations=[
878
- qtyping.QuantTransformation.ADD_QUANTIZE
879
- ],
880
- parameters=qtyping.UniformQuantParams(
881
- 8, None, np.array([1]), np.array([0])
882
- ),
883
- ),
884
- qtyping.OpToTensorParams(
885
- subgraph_op_id=2,
886
- transformations=[
887
- qtyping.QuantTransformation.ADD_QUANTIZE,
888
- qtyping.QuantTransformation.ADD_DEQUANTIZE,
889
- ],
890
- parameters=qtyping.UniformQuantParams(
891
- 8, None, np.array([1]), np.array([0])
892
- ),
893
- ),
894
- qtyping.OpToTensorParams(
895
- subgraph_op_id=3,
896
- transformations=[
897
- qtyping.QuantTransformation.ADD_QUANTIZE,
898
- qtyping.QuantTransformation.ADD_DEQUANTIZE,
899
- ],
900
- parameters=qtyping.UniformQuantParams(
901
- 8, None, np.array([1]), np.array([0])
752
+ tensor_name='tfl.quantize',
753
+ producer=None,
754
+ consumers=_get_test_consumers(
755
+ transformations_per_consumer=[
756
+ [_QTransf.ADD_QUANTIZE],
757
+ [_QTransf.ADD_QUANTIZE],
758
+ ],
759
+ params_per_consumer=[
760
+ qtyping.UniformQuantParams(
761
+ 8, None, np.array([0.00028806]), np.array([0])
902
762
  ),
903
- ),
904
- qtyping.OpToTensorParams(
905
- subgraph_op_id=4,
906
- transformations=[
907
- qtyping.QuantTransformation.ADD_QUANTIZE,
908
- ],
909
- parameters=qtyping.UniformQuantParams(
910
- 8, None, np.array([1]), np.array([0])
763
+ qtyping.UniformQuantParams(
764
+ 8, None, np.array([0.00027501]), np.array([0])
911
765
  ),
912
- ),
913
- ],
766
+ ],
767
+ ),
914
768
  ),
915
769
  expected=True,
916
770
  ),
771
+ )
772
def test__are_self_compatible_tensors_compatible_to_each_other(
    self, param1, param2, expected
):
  """Checks the pairwise compatibility verdict for two tensors."""
  check = params_generator._are_self_compatible_tensors_compatible_to_each_other
  actual = check(param1, param2)
  self.assertEqual(actual, expected)
781
+
782
+ @parameterized.named_parameters(
917
783
  dict(
918
- testcase_name='compatible_no_numeric_check',
919
- param1=qtyping.TensorTransformationParams(
784
+ testcase_name='consumer_incompatible',
785
+ params=qtyping.TensorTransformationParams(
786
+ tensor_name='tfl.quantize',
787
+ producer=qtyping.OpToTensorParams(
788
+ subgraph_op_id=0,
789
+ transformations=[_QTransf.NO_QUANTIZE],
790
+ parameters=_PARAMS_8BIT,
791
+ ),
792
+ consumers=_get_test_consumers(
793
+ transformations_per_consumer=[
794
+ [_QTransf.ADD_QUANTIZE],
795
+ [_QTransf.ADD_QUANTIZE, _QTransf.ADD_DEQUANTIZE],
796
+ [_QTransf.ADD_QUANTIZE, _QTransf.ADD_DEQUANTIZE],
797
+ [_QTransf.QUANTIZE_TENSOR],
798
+ ],
799
+ params_per_consumer=[_PARAMS_8BIT] * 4,
800
+ ),
801
+ ),
802
+ expected=False,
803
+ ),
804
+ dict(
805
+ testcase_name='compatible',
806
+ params=qtyping.TensorTransformationParams(
920
807
  tensor_name='tfl.quantize',
921
808
  producer=None,
922
- consumers=[
923
- qtyping.OpToTensorParams(
924
- subgraph_op_id=4,
925
- transformations=[
926
- qtyping.QuantTransformation.ADD_QUANTIZE,
927
- ],
928
- parameters=qtyping.UniformQuantParams(
929
- 8, None, np.array([0.00028806]), np.array([0])
930
- ),
931
- ),
932
- qtyping.OpToTensorParams(
933
- subgraph_op_id=5,
934
- transformations=[
935
- qtyping.QuantTransformation.ADD_QUANTIZE,
936
- ],
937
- parameters=qtyping.UniformQuantParams(
938
- 8, None, np.array([0.00027501]), np.array([0])
939
- ),
940
- ),
941
- ],
809
+ consumers=_get_test_consumers(
810
+ transformations_per_consumer=[
811
+ [_QTransf.ADD_QUANTIZE, _QTransf.ADD_DEQUANTIZE],
812
+ [_QTransf.ADD_QUANTIZE],
813
+ [_QTransf.NO_QUANTIZE, _QTransf.ADD_QUANTIZE],
814
+ [_QTransf.NO_QUANTIZE],
815
+ ],
816
+ params_per_consumer=[_PARAMS_8BIT] * 4,
817
+ ),
942
818
  ),
943
- param2=qtyping.TensorTransformationParams(
819
+ expected=True,
820
+ ),
821
+ dict(
822
+ testcase_name='compatible_no_numeric_check',
823
+ params=qtyping.TensorTransformationParams(
944
824
  tensor_name='tfl.quantize',
945
825
  producer=None,
946
- consumers=[
947
- qtyping.OpToTensorParams(
948
- subgraph_op_id=4,
949
- transformations=[
950
- qtyping.QuantTransformation.ADD_QUANTIZE,
951
- ],
952
- parameters=qtyping.UniformQuantParams(
826
+ consumers=_get_test_consumers(
827
+ transformations_per_consumer=[
828
+ [_QTransf.ADD_QUANTIZE],
829
+ [_QTransf.ADD_QUANTIZE],
830
+ ],
831
+ params_per_consumer=[
832
+ qtyping.UniformQuantParams(
953
833
  8, None, np.array([0.00028806]), np.array([0])
954
834
  ),
955
- ),
956
- qtyping.OpToTensorParams(
957
- subgraph_op_id=5,
958
- transformations=[
959
- qtyping.QuantTransformation.ADD_QUANTIZE,
960
- ],
961
- parameters=qtyping.UniformQuantParams(
835
+ qtyping.UniformQuantParams(
962
836
  8, None, np.array([0.00027501]), np.array([0])
963
837
  ),
964
- ),
965
- ],
838
+ ],
839
+ ),
966
840
  ),
967
841
  expected=True,
968
842
  ),
969
843
  )
970
- def test_params_compatible(self, param1, param2, expected):
971
- # adding a test to make production coverage happy.
844
+ def test__are_tensor_consumer_params_compatible(self, params, expected):
972
845
  self.assertEqual(
973
- params_generator._compatible_tensor_transformation_params(
974
- param1, param2
975
- ),
846
+ params_generator._are_tensor_consumer_params_compatible(params),
976
847
  expected,
977
848
  )
978
849
 
@@ -16,6 +16,7 @@
16
16
  """quantize a given tensor."""
17
17
 
18
18
  from typing import Optional, cast
19
+ import ml_dtypes
19
20
  import numpy as np
20
21
  from ai_edge_quantizer import qtyping
21
22
  from ai_edge_quantizer.transformations import transformation_utils
@@ -121,26 +122,6 @@ def _perform_channelwise_quantization(
121
122
  return flatbuffer_quantization
122
123
 
123
124
 
124
- def _downcast_and_truncate_scale(input_scale: np.ndarray) -> np.ndarray:
125
- """Given a fp32 scale, downcast it to fp16 and truncate mantissa to 7 bits.
126
-
127
- CPU kernel can only utilize 7 bits of mantissa for fp16, so we want to produce
128
- scale this way to unify behaviours across different platforms.
129
-
130
- Args:
131
- input_scale: The input scale in fp32.
132
-
133
- Returns:
134
- The downcasted & truncated scale in fp16.
135
- """
136
-
137
- # A regular fp16 has 10 bits of mantissa, so we need to zero out the 3 least
138
- # significant bits.
139
- return (
140
- input_scale.astype(np.float16).view(dtype=np.uint16) & np.uint16(0xFFF8)
141
- ).view(dtype=np.float16)
142
-
143
-
144
125
  def _perform_blockwise_quantization(
145
126
  transformation_input: transformation_utils.TransformationInput,
146
127
  ) -> schema_py_generated.QuantizationParametersT():
@@ -162,13 +143,12 @@ def _perform_blockwise_quantization(
162
143
  )
163
144
  tensor = transformation_input.subgraph.tensors[transformation_input.tensor_id]
164
145
  blockwise_details = schema_py_generated.BlockwiseQuantizationT()
165
- # Downcast and truncate the scale to fp16.
166
- downcasted_scale = _downcast_and_truncate_scale(
167
- transformation_input.quant_params.scale
168
- )
146
+ # Downcast and round the scale to fp16 with 7 bit mantissa.
169
147
  scale_tensor_id = transformation_utils.add_new_constant_tensor(
170
148
  tensor.name + b"_scales",
171
- downcasted_scale,
149
+ transformation_input.quant_params.scale.astype(ml_dtypes.bfloat16).astype(
150
+ np.float16
151
+ ),
172
152
  schema_py_generated.TensorType.FLOAT16,
173
153
  transformation_input.subgraph,
174
154
  transformation_input.buffers,
@@ -231,7 +231,8 @@ def get_tensor_data(tensor: Any, buffers: list[Any]) -> Optional[np.ndarray]:
231
231
  data = np.frombuffer(
232
232
  buffer_data, dtype=TENSOR_CODE_TO_TYPE[tensor.type].lower()
233
233
  )
234
- data = np.reshape(data, tensor.shape)
234
+ if tensor.shape is not None:
235
+ data = np.reshape(data, tensor.shape)
235
236
  return data
236
237
 
237
238
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-quantizer-nightly
3
- Version: 0.1.0.dev20250407
3
+ Version: 0.1.0.dev20250410
4
4
  Summary: A quantizer for advanced developers to quantize converted AI Edge models.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
6
6
  Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
@@ -10,8 +10,8 @@ ai_edge_quantizer/model_modifier.py,sha256=SPt9X-xBzRvcd4xIS24zLHt3aUS2QwsNDqweF
10
10
  ai_edge_quantizer/model_modifier_test.py,sha256=cJd04SLOG-fQZZNZPcisoBLx3cLtWEwGqUBbLb-pif4,4751
11
11
  ai_edge_quantizer/model_validator.py,sha256=fRNz0jO54cthPTibsCuViUXUuFRHl_fbvEiCukIVy20,13030
12
12
  ai_edge_quantizer/model_validator_test.py,sha256=EeqOP_mrZsnZ3rug756s0ryDDqd2KgIDld5Lm_gDuWY,13020
13
- ai_edge_quantizer/params_generator.py,sha256=46XDjnP4R3m4xsoXNp7brv0sNQPdQMg217_CbEl-Wgg,15780
14
- ai_edge_quantizer/params_generator_test.py,sha256=9WTUl87XqbM4NruX5ypLuVRtuhcw-CmxndsMOUzZ92Q,43171
13
+ ai_edge_quantizer/params_generator.py,sha256=PeIwoNYg4kJq0cMPucTvyxXTqD0I1Sr8vm5xHZCQ518,16774
14
+ ai_edge_quantizer/params_generator_test.py,sha256=DhULRWs1-UuO55zuuxocMjWDClcjcaKaue6mOcoHq9E,37186
15
15
  ai_edge_quantizer/qtyping.py,sha256=FqelZu7j0fGBRSCv_VVsuf3VmbfVlYJGgsjvdMXGgaw,15284
16
16
  ai_edge_quantizer/quantizer.py,sha256=g3DMqFMrMpt9jQttCE0WcdNbMtk0JZnmN5MmCHrNdyM,13202
17
17
  ai_edge_quantizer/quantizer_test.py,sha256=K_HBA56JkFI3HL8VLWCqGEfC0ISh5ldMKoNyBdGRAJg,20368
@@ -32,8 +32,8 @@ ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=SVu1RSX5
32
32
  ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py,sha256=qMmKbWqxrCoVKbLKHn9WuCrGKPfHkEyU0Nmhokh8Qeo,2597
33
33
  ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py,sha256=Fk3s9Qy2A_hjUepFOUmTwIZ_wKYVPbdDX4eoP-eoAQU,8726
34
34
  ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py,sha256=sT5eX5TLZEHTtPfnSkCPDlS0sQxlTFWbCsbvOuj--yY,8889
35
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py,sha256=cbyyYAoQnEraOYSV00wZ557ElBndHduVGeHikYUEFCE,7995
36
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py,sha256=B30SEISYZ9DPs3suKeG2elgXylR98pCEMWSEGgZo20o,7648
35
+ ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py,sha256=Divlsn3NjNGtH0vlvE91wxL-VHb4q1nUE0JTDGiEtYc,8572
36
+ ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py,sha256=zoF_EHjYqsKkuev8wfuutIITEmp_maa70IpJI_Df3ck,7431
37
37
  ai_edge_quantizer/algorithms/uniform_quantize/octav.py,sha256=e5wYtki-vl739gSVAZHAKcs2hA87GvFUjVoSUPlnkyM,6433
38
38
  ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py,sha256=IcTOaJ1pxtqsitqxOEP9LROVEP_19VFutHalqNied4I,6940
39
39
  ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py,sha256=WmZzKQlzfu9gFr9SbUDoPY3rFqTl363om8-0rTLwotw,11629
@@ -52,7 +52,7 @@ ai_edge_quantizer/transformations/emulated_subchannel.py,sha256=HVaRxoC8PCAvy3xe
52
52
  ai_edge_quantizer/transformations/emulated_subchannel_test.py,sha256=gZP6u9NdPXl7s19qB_Un8evou9ZZV6I9Gy0E1rdobHM,7722
53
53
  ai_edge_quantizer/transformations/quant_insert.py,sha256=jn6HsJaV-sqBiFPY-Aqbd64t8zgcYVkEkZI375x_FWY,3958
54
54
  ai_edge_quantizer/transformations/quant_insert_test.py,sha256=X9ptPDvJCFkR5tejKnD1SlHFGPazQTW-wNNMV9MEAuw,10107
55
- ai_edge_quantizer/transformations/quantize_tensor.py,sha256=y6As38mTzhva50YvNQ7p0SFpuWet3LPqFwE3qIO0gEQ,8231
55
+ ai_edge_quantizer/transformations/quantize_tensor.py,sha256=kjaNrw9mnrn0t8u0vey9S_uPz3iVUicwy4rluxVqV3E,7617
56
56
  ai_edge_quantizer/transformations/quantize_tensor_test.py,sha256=mHLO3_MRt36A8-ZN8ADn5tBBJlqjTWa7ZUN8Mmu5Rcw,9116
57
57
  ai_edge_quantizer/transformations/transformation_utils.py,sha256=5w0fG6TP362elTHs-JZokl24fuK4Gv6DGyIpybQYb3g,4885
58
58
  ai_edge_quantizer/transformations/transformation_utils_test.py,sha256=xH64SF3UHDh84vYbt-WvmXNjM-Jg-mefES1ACO1tkqw,6269
@@ -60,14 +60,14 @@ ai_edge_quantizer/utils/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V
60
60
  ai_edge_quantizer/utils/calibration_utils.py,sha256=1Fj9MIO6aLZIRgyd4axvZN4S_O64nB_-Miu1WP664js,2536
61
61
  ai_edge_quantizer/utils/calibration_utils_test.py,sha256=Z-AcdTieesWFKyKBb08ZXm4Mgu6cvJ4bg2-MJ7hLD10,2856
62
62
  ai_edge_quantizer/utils/test_utils.py,sha256=HwZCIpO9fJRAhuN6t6voXKOYQtcioFtt_tpkAlDsAYk,6205
63
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=irrGbbOt14PLFcS4538II0dB4Q7YJMgGvpBERVHevXM,10535
63
+ ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=NKtw60BJAjIE6Yww8B1vJpxXwp4MSERmpKajXJWm5rI,10568
64
64
  ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py,sha256=K1SbK8q92qYVtiVj0I0GtugsPTkpIpEKv9zakvFV_Sc,8555
65
65
  ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=x2xA2CFPpe_2trcV8v5xGaBETvVCfwAcJuq6yieGJ0Y,12687
66
66
  ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=Op3JxtOqlrjzmYF18jnnstL1k9xiY9kKJ8S2vklKGkc,11327
67
67
  ai_edge_quantizer/utils/validation_utils.py,sha256=oYw33Sg547AqtGw-choPUJmp9SAKkV46J_ddqSsum2Q,3950
68
68
  ai_edge_quantizer/utils/validation_utils_test.py,sha256=V_qNDikPD4OPB-siOLQCWNVWTAu87h2IgNYt7teFd-o,2934
69
- ai_edge_quantizer_nightly-0.1.0.dev20250407.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
70
- ai_edge_quantizer_nightly-0.1.0.dev20250407.dist-info/METADATA,sha256=u28A5AptBWTj0kkQD0QfUKvfwYUGbzTmi1NzqiItobc,1527
71
- ai_edge_quantizer_nightly-0.1.0.dev20250407.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
72
- ai_edge_quantizer_nightly-0.1.0.dev20250407.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
73
- ai_edge_quantizer_nightly-0.1.0.dev20250407.dist-info/RECORD,,
69
+ ai_edge_quantizer_nightly-0.1.0.dev20250410.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
70
+ ai_edge_quantizer_nightly-0.1.0.dev20250410.dist-info/METADATA,sha256=PeTpfQxSLhSk9QyphqCPCHKAytqdjDAf9FzubQofMmM,1527
71
+ ai_edge_quantizer_nightly-0.1.0.dev20250410.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
72
+ ai_edge_quantizer_nightly-0.1.0.dev20250410.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
73
+ ai_edge_quantizer_nightly-0.1.0.dev20250410.dist-info/RECORD,,