ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff shows the content changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
Files changed (63)
  1. ai_edge_quantizer/algorithm_manager.py +158 -0
  2. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  3. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  5. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
  6. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  7. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  8. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
  12. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
  13. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
  14. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
  15. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
  16. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
  17. ai_edge_quantizer/calibrator.py +11 -60
  18. ai_edge_quantizer/calibrator_test.py +4 -73
  19. ai_edge_quantizer/default_policy.py +61 -26
  20. ai_edge_quantizer/model_modifier.py +97 -7
  21. ai_edge_quantizer/model_modifier_test.py +81 -1
  22. ai_edge_quantizer/model_validator.py +31 -8
  23. ai_edge_quantizer/params_generator.py +17 -10
  24. ai_edge_quantizer/params_generator_test.py +2 -7
  25. ai_edge_quantizer/qtyping.py +86 -6
  26. ai_edge_quantizer/quantizer.py +166 -21
  27. ai_edge_quantizer/quantizer_test.py +284 -16
  28. ai_edge_quantizer/recipe.py +154 -42
  29. ai_edge_quantizer/recipe_manager.py +158 -1
  30. ai_edge_quantizer/recipe_manager_test.py +146 -32
  31. ai_edge_quantizer/recipe_test.py +93 -17
  32. ai_edge_quantizer/transformation_instruction_generator.py +118 -13
  33. ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
  34. ai_edge_quantizer/transformation_performer.py +55 -25
  35. ai_edge_quantizer/transformation_performer_test.py +127 -5
  36. ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
  37. ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
  38. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  39. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  40. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  41. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  42. ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
  43. ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
  44. ai_edge_quantizer/transformations/transformation_utils.py +129 -6
  45. ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
  46. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  47. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  48. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  49. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  50. ai_edge_quantizer/utils/test_utils.py +75 -2
  51. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
  52. ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
  53. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  54. ai_edge_quantizer/utils/validation_utils.py +114 -4
  55. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  56. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
  57. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  58. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  59. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  60. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  61. ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
  62. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  63. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
@@ -36,6 +36,13 @@ _OpQuantConstraint = common_utils.OpQuantConstraint
 _ComputePrecision = qtyping.ComputePrecision
 
 
+def check_if_quantized(tensor: Any) -> bool:
+  """Checks if the tensor is quantized."""
+  return (
+      tensor.quantization is not None and tensor.quantization.scale is not None
+  )
+
+
 def check_op_quantization_config(
     op_name: _TFLOpName,
     op_quant_config: qtyping.OpQuantizationConfig,
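Note on the new check_if_quantized helper: it treats a tensor as quantized only when both the quantization block and its scale are populated. A minimal self-contained sketch of that behavior, using hypothetical stand-in dataclasses rather than the real TFLite flatbuffer classes:

import dataclasses
from typing import Any, Optional

@dataclasses.dataclass
class FakeQuantization:  # stand-in for the flatbuffer quantization params
  scale: Optional[list] = None

@dataclasses.dataclass
class FakeTensor:  # stand-in for the flatbuffer tensor
  quantization: Optional[FakeQuantization] = None

def check_if_quantized(tensor: Any) -> bool:
  return (
      tensor.quantization is not None and tensor.quantization.scale is not None
  )

assert not check_if_quantized(FakeTensor())  # no quantization block
assert not check_if_quantized(FakeTensor(FakeQuantization()))  # block, no scale
assert check_if_quantized(FakeTensor(FakeQuantization([0.05])))  # quantized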
@@ -271,7 +278,7 @@ def materialize_average_pool_2d(
   )
 
 
-def _materialize_bias_for_conv_ops(
+def _materialize_bias_for_fc_conv_ops(
     op_info: qtyping.OpInfo,
     graph_info: qtyping.GraphInfo,
     op_tensor_params: list[qtyping.TensorTransformationParams],
@@ -290,14 +297,16 @@ def _materialize_bias_for_conv_ops(
     op_weight_index: Index for the weight tensor in the op.
     op_bias_index: Index for the bias tensor in the op.
   """
-  _, _, bias_tensor, _ = tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
-      op_info.op,
-      graph_info.subgraph_tensors,
-      op_input_index,
-      op_weight_index,
-      op_bias_index,
-  )
-  if bias_tensor is not None:
+  _, weight_tensor, bias_tensor, _ = (
+      tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
+          op_info.op,
+          graph_info.subgraph_tensors,
+          op_input_index,
+          op_weight_index,
+          op_bias_index,
+      )
+  )
+  if bias_tensor is not None and not check_if_quantized(bias_tensor):
     bias_quant_params = None
     # Fused bias needs to be quantized for SRQ.
     # Check if SRQ.
@@ -309,13 +318,41 @@ def _materialize_bias_for_conv_ops(
           bias_tensor,
           graph_info.buffers,
       )
-      bias_quant_params = (
-          uniform_quantize_tensor.symmetric_quantize_bias_tensor(
-              bias_content,
-              op_tensor_params[op_input_index].consumers[0].parameters,
-              op_tensor_params[op_weight_index].consumers[0].parameters,
-          )
+      input_consumer_params = (
+          op_tensor_params[op_input_index].consumers[0].parameters
+      )
+      weight_consumer_params = (
+          op_tensor_params[op_weight_index].consumers[0].parameters
       )
+      if weight_consumer_params is None and check_if_quantized(weight_tensor):
+        quant_params = weight_tensor.quantization
+        if op_info.op_quant_config.weight_tensor_config is None:
+          raise ValueError(
+              "weight_tensor_config cannot be None when weight tensor is"
+              " quantized."
+          )
+        weight_consumer_params = qtyping.UniformQuantParams(
+            num_bits=op_info.op_quant_config.weight_tensor_config.num_bits,
+            scale=quant_params.scale,
+            zero_point=quant_params.zeroPoint,
+            quantized_dimension=quant_params.quantizedDimension,
+        )
+      try:
+        # Bias quantization is using fixed quantization scale:
+        # input_scale * weight_scale. To avoid hidden numerics error, we check
+        # the quantization error in bias quantization.
+        bias_quant_params = (
+            uniform_quantize_tensor.symmetric_quantize_bias_tensor(
+                bias_content,
+                input_consumer_params,
+                weight_consumer_params,
+            )
+        )
+      except ValueError as e:
+        raise ValueError(
+            f"Failed to quantize bias tensor for op {op_info.op_name} with op"
+            f" id {op_info.subgraph_op_index}."
+        ) from e
     # We only quantize bias under SRQ. Setting is_constant=True for SRQ only
     # to avoid quantize bias for DRQ and weight-only cases.
     is_constant = (
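The comment added in this hunk notes that bias quantization uses the fixed scale input_scale * weight_scale; a poor fit therefore shows up directly as reconstruction error, which is what the new try/except surfaces as a per-op ValueError. A simplified numpy sketch of that fixed-scale scheme (illustrative only, not the library's symmetric_quantize_bias_tensor):

import numpy as np

def quantize_bias_fixed_scale(
    bias: np.ndarray, input_scale: float, weight_scale: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
  # Bias scale is fixed to input_scale * weight_scale (per channel);
  # symmetric int32 quantization with zero point 0.
  bias_scale = input_scale * weight_scale
  q = np.clip(
      np.round(bias / bias_scale),
      np.iinfo(np.int32).min,
      np.iinfo(np.int32).max,
  )
  return q.astype(np.int32), bias_scale

bias = np.array([0.25, -1.5, 3.0], dtype=np.float32)
q, bias_scale = quantize_bias_fixed_scale(
    bias, input_scale=0.02, weight_scale=np.array([0.1, 0.1, 0.2])
)
# A large reconstruction error here signals a hidden numerics problem.
max_error = np.abs(q * bias_scale - bias).max()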
@@ -371,6 +408,25 @@ def materialize_slice(
   )
 
 
+def materialize_select(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.select."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
+      inputs_to_ignore=[
+          0,
+      ],  # Condition tensor does not need to be quantized.
+  )
+
+
 def materialize_select_v2(
     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
     op_info: qtyping.OpInfo,
@@ -416,12 +472,21 @@ def materialize_sum(
     tensor_name_to_qsv: dict[str, Any],
 ) -> list[qtyping.TensorTransformationParams]:
   """Materialize tensors in tfl.sum."""
+  # For 8 bits the reference kernel calls a function without input/output
+  # constraints. For all others it calls a function that enforces input/output
+  # scale/zero point checks. See:
+  # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/reduce.cc#L909
+  activation_config = op_info.op_quant_config.activation_tensor_config
+  if activation_config is not None and activation_config.num_bits == 8:
+    constraint = _OpQuantConstraint.NO_CONSTRAIN
+  else:
+    constraint = _OpQuantConstraint.SAME_AS_INPUT_SCALE
   return common_utils.materialize_standard_op(
       op_info,
       graph_info,
       tensor_name_to_qsv,
       get_tensor_quant_params_fn,
-      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      constraint=constraint,
       inputs_to_ignore=[1],  # Axis index does not need to be quantized.
   )
 
@@ -452,7 +517,13 @@ def materialize_fc_conv(
     weights, bias).
   """
   ignored_inputs = [bias_index]  # Bias tensor is quantized separately.
-  if _are_weights_too_small(op_info, graph_info, weight_index):
+  should_ignore_weight = False
+  if graph_info:
+    w_tensor = graph_info.subgraph_tensors[op_info.op.inputs[weight_index]]
+    should_ignore_weight = check_if_quantized(w_tensor)
+  if should_ignore_weight or _are_weights_too_small(
+      op_info, graph_info, weight_index
+  ):
     ignored_inputs.append(weight_index)
 
   op_tensor_params = common_utils.materialize_standard_op(
@@ -463,7 +534,7 @@ def materialize_fc_conv(
       inputs_to_ignore=ignored_inputs,
   )
 
-  _materialize_bias_for_conv_ops(
+  _materialize_bias_for_fc_conv_ops(
       op_info,
       graph_info,
       op_tensor_params,
@@ -518,7 +589,7 @@ def materialize_conv2d_transpose(
         "Materialize standard op should return at least two tensors for"
         " conv2d_transpose."
     )
-  _materialize_bias_for_conv_ops(
+  _materialize_bias_for_fc_conv_ops(
       op_info,
       graph_info,
       op_tensor_params,
@@ -671,6 +742,366 @@ def materialize_split(
   )
 
 
+def materialize_pad(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.pad."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Paddings tensor does not need to be quantized.
+  )
+
+
+def materialize_padv2(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.padv2."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
+      inputs_to_ignore=[1],  # Paddings tensor does not need to be quantized.
+  )
+
+
+def materialize_mirror_pad(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.mirror_pad.
+
+  Args:
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
+    op_info: Aggregated information about the op (e.g., quantization config).
+    graph_info: Graph information needed to perform quantization for the op.
+    tensor_name_to_qsv: A map of tensor name to quantization parameters.
+
+  Returns:
+    A list of `qtyping.TensorTransformationParams` for the tensors in the op.
+  """
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Paddings tensor does not need to be quantized.
+  )
+
+
+def materialize_space_to_depth(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.space_to_depth.
+
+  Args:
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
+    op_info: Aggregated information about the op (e.g., quantization config).
+    graph_info: Graph information needed to perform quantization for the op.
+    tensor_name_to_qsv: A map of tensor name to quantization parameters.
+
+  Returns:
+    A list of `qtyping.TensorTransformationParams` for the tensors in the op.
+  """
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+  )
+
+
+def materialize_squared_difference(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.squared_difference."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_max_pool_2d(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.max_pool_2d."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+  )
+
+
+def materialize_resize_bilinear(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.resize_bilinear."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Resize size does not need to be quantized.
+  )
+
+
+def materialize_resize_nearest_neighbor(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.resize_nearest_neighbor."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Resize size does not need to be quantized.
+  )
+
+
+def materialize_gather_nd(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.gather_nd."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Gather indices do not need to be quantized.
+  )
+
+
+def materialize_maximum(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.maximum."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
+  )
+
+
+def materialize_pack(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.pack."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
+  )
+
+
+def materialize_unpack(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.unpack."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+  )
+
+
+def materialize_div(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.div."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_broadcast_to(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.broadcast_to."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Shape tensor does not need to be quantized.
+  )
+
+
+def materialize_sqrt(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.sqrt."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_hard_swish(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.hard_swish."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_gather(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.gather."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Indices do not need to be quantized.
+  )
+
+
+def materialize_reduce_min(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.reduce_min."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Axis index does not need to be quantized.
+  )
+
+
+def materialize_equal(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.equal."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_not_equal(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.not_equal."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_relu(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.relu."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
 def _get_tensor_shape_for_blockwise(
     tensor_shape: Sequence[int], quantized_dim: int, block_size: int
 ) -> list[int]:
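Most of the ops added in this hunk are data-movement or element-selection ops, which is why they carry the SAME_AS_INPUT_SCALE or SAME_AS_OUTPUT_SCALE constraint: when input and output share the same (scale, zero_point), the integer kernel can move values without requantizing. A small numpy illustration of the idea (an editorial sketch, not library code):

import numpy as np

scale, zero_point = 0.05, -3
x = np.array([0.1, -0.2, 0.4], dtype=np.float32)
q = (np.round(x / scale) + zero_point).astype(np.int8)

# tfl.pad with shared quant params is pure data movement on int8 values;
# the padding value is the zero point, which encodes float 0.0 exactly.
q_padded = np.pad(q, (1, 1), constant_values=zero_point)
dequantized = (q_padded.astype(np.float32) - zero_point) * scale
# dequantized -> [0.0, 0.1, -0.2, 0.4, 0.0]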
@@ -700,18 +1131,29 @@ def _get_tensor_shape_for_blockwise(
 
 
 def _reshape_data_for_blockwise(
-    tensor_data: np.ndarray, quantized_dim: int, block_size: int
+    tensor_data: np.ndarray,
+    quantized_dim: int,
+    block_size: int,
 ) -> tuple[np.ndarray, int]:
   """Reshapes data for blockwise quantization.
 
   Args:
     tensor_data: The original tensor data.
     quantized_dim: The dimension to be quantized blockwise.
-    block_size: The size of the block.
+    block_size: The size of the block. `block_size` must be a multiple of 32.
+      The tensor quantized dimension shape must be divisible by `block_size`.
 
   Returns:
     A tuple containing the reshaped tensor data and the new reduce dimension.
   """
+
+  # TODO: b/417508018 - create AEQ specific error class instead of
+  # using generic ValueError.
+  if tensor_data.shape[quantized_dim] % block_size != 0:
+    raise ValueError(
+        "Tensor quantization dimension must be divisible by block size for"
+        " blockwise quantization."
+    )
   new_shape = _get_tensor_shape_for_blockwise(
       tensor_data.shape, quantized_dim, block_size
   )
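The shape mechanics of the blockwise reshape are easy to picture: the quantized dimension is split into (num_blocks, block_size), and min/max are later reduced over the trailing block axis. A numpy sketch mirroring the new test further down (illustrative, not the library implementation):

import numpy as np

tensor = np.ones((24, 128), dtype=np.float32)
quantized_dim, block_size = 1, 32

# Split dim 1 (length 128) into 128 // 32 = 4 blocks of 32 elements.
reshaped = tensor.reshape(24, 128 // block_size, block_size)
reduce_dim = 2  # the block axis, reduced when computing min/max
assert reshaped.shape == (24, 4, 32)
block_min = reshaped.min(axis=reduce_dim)  # shape (24, 4): one value per block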
@@ -783,42 +1225,36 @@ def init_tensor_min_max(
     A dictionary containing the min/max values for the tensor, or an empty
     dictionary if the tensor data is None.
   """
-  if tensor_data is None:
+  weight_tensor_config = op_info.op_quant_config.weight_tensor_config
+  if tensor_data is None or weight_tensor_config is None:
     return {}
   else:
-    weight_tensor_config = op_info.op_quant_config.weight_tensor_config
-    quantized_dim = None
-    if weight_tensor_config is not None and (
-        weight_tensor_config.granularity == qtyping.QuantGranularity.CHANNELWISE
-    ):
+    # Get reduce dimension for min/max calculation based on quantization
+    # granularity.
+    granularity = weight_tensor_config.granularity
+    if granularity == qtyping.QuantGranularity.TENSORWISE:
+      reduce_dims = None
+      keep_dims = True
+    elif granularity == qtyping.QuantGranularity.CHANNELWISE:
       quantized_dim = common_utils.get_weight_quantized_dim(
-          op_info, tensor_data
-      )
-    if (
-        weight_tensor_config is not None
-        and weight_tensor_config.granularity
-        == qtyping.QuantGranularity.BLOCKWISE
-    ):
-      quantized_dim = (
-          tfl_flatbuffer_utils.TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM[
-              op_info.op_name
-          ]
-      )
-      reshaped_data, reduce_dims = _reshape_data_for_blockwise(
-          tensor_data,
-          quantized_dim,
-          weight_tensor_config.block_size,
+          op_info, tensor_data, weight_tensor_config.granularity
       )
-      return {
-          "min": np.min(reshaped_data, axis=reduce_dims, keepdims=False),
-          "max": np.max(reshaped_data, axis=reduce_dims, keepdims=False),
-      }
-
-    else:
       reduce_dims = common_utils.get_reduce_dims(
           quantized_dim, tensor_data.shape
       )
-      return {
-          "min": np.min(tensor_data, axis=reduce_dims, keepdims=True),
-          "max": np.max(tensor_data, axis=reduce_dims, keepdims=True),
-      }
+      keep_dims = True
+    elif uniform_quantize_tensor.is_blockwise(granularity):
+      tensor_data, reduce_dims = (
+          uniform_quantize_tensor.reshape_data_for_blockwise(
+              tensor_data,
+              op_info.op_name,
+              granularity,
+          )
+      )
+      keep_dims = False
+    else:
+      raise ValueError(f"Unsupported granularity: {granularity}")
+    return {
+        "min": np.min(tensor_data, axis=reduce_dims, keepdims=keep_dims),
+        "max": np.max(tensor_data, axis=reduce_dims, keepdims=keep_dims),
+    }
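The rewritten init_tensor_min_max reduces to choosing reduce_dims and keep_dims per granularity. A compact numpy illustration of the resulting min/max shapes for an assumed (3, 4) weight with channel dimension 0 (editorial sketch, not the library function):

import numpy as np

w = np.arange(12, dtype=np.float32).reshape(3, 4)

# TENSORWISE: reduce over all axes with keepdims=True -> shape (1, 1).
tensorwise_min = np.min(w, axis=None, keepdims=True)

# CHANNELWISE (channel dim 0): reduce over the remaining axes -> shape (3, 1).
channelwise_min = np.min(w, axis=(1,), keepdims=True)

# BLOCKWISE: the data is reshaped first (see reshape_data_for_blockwise),
# then reduced over the block axis with keepdims=False.
assert tensorwise_min.shape == (1, 1)
assert channelwise_min.shape == (3, 1)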
@@ -31,8 +31,7 @@ _TensorQuantConfig = qtyping.TensorQuantizationConfig
 
 
 class CommonQuantizeTest(parameterized.TestCase):
-  """Tests for general quantize functions.
-  """
+  """Tests for general quantize functions."""
 
   def setUp(self):
     super().setUp()
@@ -69,6 +68,34 @@ class CommonQuantizeTest(parameterized.TestCase):
         default_policy.DEFAULT_CONFIG_CHECK_POLICY,
     )
 
+  def test_reshape_data_for_blockwise_raises_error_when_quantized_dim_not_divisible_by_block_size(
+      self,
+  ):
+    tensor_data = np.ones((24, 128), dtype=np.float32)
+    block_size = 256
+    quantized_dim = 1
+    with self.assertRaisesWithPredicateMatch(
+        ValueError,
+        lambda err: (
+            "Tensor quantization dimension must be divisible by block"
+            " size for blockwise quantization."
+        )
+        in str(err),
+    ):
+      common_quantize._reshape_data_for_blockwise(
+          tensor_data, quantized_dim, block_size
+      )
+
+  def test_reshape_data_for_blockwise_returns_correct_values(self):
+    tensor_data = np.ones((24, 128), dtype=np.float32)
+    block_size = 32
+    quantized_dim = 1
+    new_tensor_data, reduce_dim = common_quantize._reshape_data_for_blockwise(
+        tensor_data, quantized_dim, block_size
+    )
+    self.assertEqual(new_tensor_data.shape, (24, 4, 32))
+    self.assertEqual(reduce_dim, 2)
+
 
 if __name__ == "__main__":
   googletest.main()