ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. ai_edge_quantizer/algorithm_manager.py +224 -0
  2. ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
  3. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
  5. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  6. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
  7. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
  8. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
  13. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
  14. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
  15. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
  16. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
  17. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
  18. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
  19. ai_edge_quantizer/calibrator.py +58 -94
  20. ai_edge_quantizer/calibrator_test.py +5 -74
  21. ai_edge_quantizer/default_policy.py +108 -16
  22. ai_edge_quantizer/model_modifier.py +132 -8
  23. ai_edge_quantizer/model_modifier_test.py +81 -1
  24. ai_edge_quantizer/model_validator.py +38 -10
  25. ai_edge_quantizer/model_validator_test.py +2 -1
  26. ai_edge_quantizer/params_generator.py +230 -47
  27. ai_edge_quantizer/params_generator_test.py +366 -261
  28. ai_edge_quantizer/qtyping.py +92 -6
  29. ai_edge_quantizer/quantizer.py +167 -23
  30. ai_edge_quantizer/quantizer_test.py +288 -26
  31. ai_edge_quantizer/recipe.py +156 -21
  32. ai_edge_quantizer/recipe_manager.py +158 -1
  33. ai_edge_quantizer/recipe_manager_test.py +146 -32
  34. ai_edge_quantizer/recipe_test.py +93 -17
  35. ai_edge_quantizer/transformation_instruction_generator.py +313 -46
  36. ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
  37. ai_edge_quantizer/transformation_performer.py +112 -58
  38. ai_edge_quantizer/transformation_performer_test.py +176 -4
  39. ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
  40. ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
  41. ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
  42. ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
  43. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  44. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  45. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  46. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  47. ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
  48. ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
  49. ai_edge_quantizer/transformations/transformation_utils.py +157 -11
  50. ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
  51. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  52. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  53. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  54. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  55. ai_edge_quantizer/utils/test_utils.py +191 -58
  56. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
  57. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
  58. ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
  59. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  60. ai_edge_quantizer/utils/validation_utils.py +114 -4
  61. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  62. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
  63. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  64. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  65. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  66. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  67. ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
  68. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  69. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
@@ -23,7 +23,7 @@ to implement the get_tensor_quant_params_fn with the
 qtyping.GetTensorQuantParamsFuncSignature signature.
 """

-from typing import Any
+from typing import Any, Optional, Sequence
 import numpy as np
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor
@@ -36,6 +36,13 @@ _OpQuantConstraint = common_utils.OpQuantConstraint
 _ComputePrecision = qtyping.ComputePrecision


+def check_if_quantized(tensor: Any) -> bool:
+  """Checks if the tensor is quantized."""
+  return (
+      tensor.quantization is not None and tensor.quantization.scale is not None
+  )
+
+
 def check_op_quantization_config(
     op_name: _TFLOpName,
     op_quant_config: qtyping.OpQuantizationConfig,
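
For intuition, here is a tiny self-contained sketch of the `check_if_quantized` predicate added above; `FakeQuantization` and `FakeTensor` are hypothetical stand-ins for the TFLite flatbuffer objects the real code receives:

import dataclasses
from typing import Any, Optional


@dataclasses.dataclass
class FakeQuantization:  # stand-in for flatbuffer QuantizationParameters
  scale: Optional[list[float]] = None


@dataclasses.dataclass
class FakeTensor:  # stand-in for a flatbuffer Tensor
  quantization: Optional[FakeQuantization] = None


def check_if_quantized(tensor: Any) -> bool:
  """Same logic as the helper added in this hunk."""
  return (
      tensor.quantization is not None and tensor.quantization.scale is not None
  )


assert not check_if_quantized(FakeTensor())                      # no quant metadata
assert not check_if_quantized(FakeTensor(FakeQuantization()))    # metadata, no scale
assert check_if_quantized(FakeTensor(FakeQuantization([0.02])))  # has a scale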
@@ -110,6 +117,21 @@ def materialize_output(
   )


+def materialize_composite(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in the virtual output op."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
 def materialize_add(
     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
     op_info: qtyping.OpInfo,
@@ -256,7 +278,7 @@ def materialize_average_pool_2d(
   )


-def _materialize_bias_for_conv_ops(
+def _materialize_bias_for_fc_conv_ops(
     op_info: qtyping.OpInfo,
     graph_info: qtyping.GraphInfo,
     op_tensor_params: list[qtyping.TensorTransformationParams],
@@ -275,14 +297,16 @@ def _materialize_bias_for_conv_ops(
     op_weight_index: Index for the weight tensor in the op.
     op_bias_index: Index for the bias tensor in the op.
   """
-  _, _, bias_tensor, _ = tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
-      op_info.op,
-      graph_info.subgraph_tensors,
-      op_input_index,
-      op_weight_index,
-      op_bias_index,
-  )
-  if bias_tensor is not None:
+  _, weight_tensor, bias_tensor, _ = (
+      tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
+          op_info.op,
+          graph_info.subgraph_tensors,
+          op_input_index,
+          op_weight_index,
+          op_bias_index,
+      )
+  )
+  if bias_tensor is not None and not check_if_quantized(bias_tensor):
     bias_quant_params = None
     # Fused bias needs to be quantized for SRQ.
     # Check if SRQ.
@@ -294,13 +318,41 @@ def _materialize_bias_for_conv_ops(
           bias_tensor,
           graph_info.buffers,
       )
-      bias_quant_params = (
-          uniform_quantize_tensor.symmetric_quantize_bias_tensor(
-              bias_content,
-              op_tensor_params[op_input_index].consumers[0].parameters,
-              op_tensor_params[op_weight_index].consumers[0].parameters,
-          )
+      input_consumer_params = (
+          op_tensor_params[op_input_index].consumers[0].parameters
+      )
+      weight_consumer_params = (
+          op_tensor_params[op_weight_index].consumers[0].parameters
       )
+      if weight_consumer_params is None and check_if_quantized(weight_tensor):
+        quant_params = weight_tensor.quantization
+        if op_info.op_quant_config.weight_tensor_config is None:
+          raise ValueError(
+              "weight_tensor_config cannot be None when weight tensor is"
+              " quantized."
+          )
+        weight_consumer_params = qtyping.UniformQuantParams(
+            num_bits=op_info.op_quant_config.weight_tensor_config.num_bits,
+            scale=quant_params.scale,
+            zero_point=quant_params.zeroPoint,
+            quantized_dimension=quant_params.quantizedDimension,
+        )
+      try:
+        # Bias quantization is using fixed quantization scale:
+        # input_scale * weight_scale. To avoid hidden numerics error, we check
+        # the quantization error in bias quantization.
+        bias_quant_params = (
+            uniform_quantize_tensor.symmetric_quantize_bias_tensor(
+                bias_content,
+                input_consumer_params,
+                weight_consumer_params,
+            )
+        )
+      except ValueError as e:
+        raise ValueError(
+            f"Failed to quantize bias tensor for op {op_info.op_name} with op"
+            f" id {op_info.subgraph_op_index}."
+        ) from e
     # We only quantize bias under SRQ. Setting is_constant=True for SRQ only
     # to avoid quantize bias for DRQ and weight-only cases.
     is_constant = (
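
The comment above pins down the contract: the bias scale is not free, it is fixed to input_scale * weight_scale, and the quantization error is checked so a bad fixed scale fails loudly instead of silently degrading accuracy. A minimal numpy sketch of that scheme; `quantize_bias_fixed_scale` and its tolerance are illustrative stand-ins, not the package's `symmetric_quantize_bias_tensor` API:

import numpy as np


def quantize_bias_fixed_scale(
    bias: np.ndarray,
    input_scale: np.ndarray,
    weight_scale: np.ndarray,
    err_tolerance: float = 1e-2,  # illustrative threshold
) -> np.ndarray:
  """Symmetric int32 bias quantization with scale = input_scale * weight_scale."""
  bias_scale = input_scale * weight_scale  # fixed, not derived from bias stats
  q = np.round(bias / bias_scale).astype(np.int64)
  q = np.clip(q, np.iinfo(np.int32).min, np.iinfo(np.int32).max).astype(np.int32)
  # Check round-trip error so hidden numerics issues surface as an exception.
  err = np.max(np.abs(q * bias_scale - bias))
  if err > err_tolerance:
    raise ValueError(f"Bias quantization error too large: {err}")
  return q


bias = np.array([0.5, -1.25, 3.0], dtype=np.float32)
print(quantize_bias_fixed_scale(bias, np.float32(0.05), np.float32(0.01)))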
@@ -356,6 +408,25 @@ def materialize_slice(
   )


+def materialize_select(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.select."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
+      inputs_to_ignore=[
+          0,
+      ],  # Condition tensor does not need to be quantized.
+  )
+
+
 def materialize_select_v2(
     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
     op_info: qtyping.OpInfo,
@@ -375,6 +446,25 @@ def materialize_select_v2(
   )


+def materialize_dynamic_update_slice(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.dynamic_update_slice."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
+      inputs_to_ignore=[
+          2,
+      ],  # start_indices do not need to be quantized.
+  )
+
+
 def materialize_sum(
     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
     op_info: qtyping.OpInfo,
@@ -382,12 +472,21 @@ def materialize_sum(
     tensor_name_to_qsv: dict[str, Any],
 ) -> list[qtyping.TensorTransformationParams]:
   """Materialize tensors in tfl.sum."""
+  # For 8 bits the reference kernel calls a function without input/output
+  # constraints. For all others it calls a function that enforces input/output
+  # scale/zero point checks. See:
+  # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/reduce.cc#L909
+  activation_config = op_info.op_quant_config.activation_tensor_config
+  if activation_config is not None and activation_config.num_bits == 8:
+    constraint = _OpQuantConstraint.NO_CONSTRAIN
+  else:
+    constraint = _OpQuantConstraint.SAME_AS_INPUT_SCALE
   return common_utils.materialize_standard_op(
       op_info,
       graph_info,
       tensor_name_to_qsv,
       get_tensor_quant_params_fn,
-      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      constraint=constraint,
       inputs_to_ignore=[1],  # Axis index does not need to be quantized.
   )

@@ -418,7 +517,13 @@ def materialize_fc_conv(
     weights, bias).
   """
   ignored_inputs = [bias_index]  # Bias tensor is quantized separately.
-  if _are_weights_too_small(op_info, graph_info, weight_index):
+  should_ignore_weight = False
+  if graph_info:
+    w_tensor = graph_info.subgraph_tensors[op_info.op.inputs[weight_index]]
+    should_ignore_weight = check_if_quantized(w_tensor)
+  if should_ignore_weight or _are_weights_too_small(
+      op_info, graph_info, weight_index
+  ):
     ignored_inputs.append(weight_index)

   op_tensor_params = common_utils.materialize_standard_op(
@@ -429,7 +534,7 @@ def materialize_fc_conv(
       inputs_to_ignore=ignored_inputs,
   )

-  _materialize_bias_for_conv_ops(
+  _materialize_bias_for_fc_conv_ops(
       op_info,
       graph_info,
       op_tensor_params,
@@ -484,7 +589,7 @@ def materialize_conv2d_transpose(
         "Materialize standard op should return at least two tensors for"
        " conv2d_transpose."
    )
-  _materialize_bias_for_conv_ops(
+  _materialize_bias_for_fc_conv_ops(
      op_info,
      graph_info,
      op_tensor_params,
@@ -635,3 +740,521 @@ def materialize_split(
       constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
       inputs_to_ignore=[0],  # Split dimension does not need to be quantized.
   )
+
+
+def materialize_pad(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.pad."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Paddings tensor does not need to be quantized.
+  )
+
+
+def materialize_padv2(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.padv2."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
+      inputs_to_ignore=[1],  # Paddings tensor does not need to be quantized.
+  )
+
+
+def materialize_mirror_pad(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.mirror_pad.
+
+  Args:
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
+    op_info: Aggregated information about the op (e.g., quantization config).
+    graph_info: Graph information needed to perform quantization for the op.
+    tensor_name_to_qsv: A map of tensor name to quantization parameters.
+
+  Returns:
+    A list of `qtyping.TensorTransformationParams` for the tensors in the op.
+  """
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Paddings tensor does not need to be quantized.
+  )
+
+
+def materialize_space_to_depth(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.space_to_depth.
+
+  Args:
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
+    op_info: Aggregated information about the op (e.g., quantization config).
+    graph_info: Graph information needed to perform quantization for the op.
+    tensor_name_to_qsv: A map of tensor name to quantization parameters.
+
+  Returns:
+    A list of `qtyping.TensorTransformationParams` for the tensors in the op.
+  """
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+  )
+
+
+def materialize_squared_difference(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.squared_difference."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_max_pool_2d(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.max_pool_2d."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+  )
+
+
+def materialize_resize_bilinear(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.resize_bilinear."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Resize size does not need to be quantized.
+  )
+
+
+def materialize_resize_nearest_neighbor(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.resize_nearest_neighbor."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Resize size does not need to be quantized.
+  )
+
+
+def materialize_gather_nd(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.gather_nd."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Gather indices do not need to be quantized.
+  )
+
+
+def materialize_maximum(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.maximum."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
+  )
+
+
+def materialize_pack(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.pack."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
+  )
+
+
+def materialize_unpack(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.unpack."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+  )
+
+
+def materialize_div(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.div."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_broadcast_to(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.broadcast_to."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Shape tensor does not need to be quantized.
+  )
+
+
+def materialize_sqrt(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.sqrt."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_hard_swish(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.hard_swish."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_gather(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.gather."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Indices do not need to be quantized.
+  )
+
+
+def materialize_reduce_min(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.reduce_min."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Axis index does not need to be quantized.
+  )
+
+
+def materialize_equal(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.equal."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_not_equal(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.not_equal."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_relu(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.relu."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def _get_tensor_shape_for_blockwise(
+    tensor_shape: Sequence[int], quantized_dim: int, block_size: int
+) -> list[int]:
+  """Get the tensor shape for blockwise quantization.
+
+  This function splits the quantized dimension of the tensor into
+  dim/block_size blocks of size block_size. Hence, min/max of the tensor can
+  be calculated for each block using existing functions.
+
+  Args:
+    tensor_shape: The original shape of the tensor.
+    quantized_dim: The dimension to be quantized blockwise.
+    block_size: The size of the block.
+
+  Returns:
+    The new tensor shape for calculating scale and zp for blockwise
+    quantization.
+  """
+  new_shape = []
+  for index, val in enumerate(tensor_shape):
+    if index == quantized_dim:
+      new_shape.append(int(val / block_size))
+      new_shape.append(block_size)
+    else:
+      new_shape.append(val)
+  return new_shape
+
+
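
Concretely, only the quantized dimension is split: a (64, 128) weight with quantized_dim=1 and block_size=32 becomes (64, 4, 32), so each length-32 block gets its own axis for the min/max reduction. A quick sketch of the same arithmetic (shapes are illustrative):

shape, quantized_dim, block_size = (64, 128), 1, 32
new_shape = []
for index, val in enumerate(shape):
  if index == quantized_dim:
    new_shape += [val // block_size, block_size]  # split dim into (blocks, block)
  else:
    new_shape.append(val)
assert new_shape == [64, 4, 32]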
+def _reshape_data_for_blockwise(
+    tensor_data: np.ndarray,
+    quantized_dim: int,
+    block_size: int,
+) -> tuple[np.ndarray, int]:
+  """Reshapes data for blockwise quantization.
+
+  Args:
+    tensor_data: The original tensor data.
+    quantized_dim: The dimension to be quantized blockwise.
+    block_size: The size of the block. `block_size` must be a multiple of 32,
+      and the tensor's quantized dimension must be divisible by `block_size`.
+
+  Returns:
+    A tuple containing the reshaped tensor data and the new reduce dimension.
+  """
+
+  # TODO: b/417508018 - create AEQ specific error class instead of
+  # using generic ValueError.
+  if tensor_data.shape[quantized_dim] % block_size != 0:
+    raise ValueError(
+        "Tensor quantization dimension must be divisible by block size for"
+        " blockwise quantization."
+    )
+  new_shape = _get_tensor_shape_for_blockwise(
+      tensor_data.shape, quantized_dim, block_size
+  )
+  reshaped_data = tensor_data.reshape(new_shape)
+  return reshaped_data, quantized_dim + 1
+
+
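
The returned reduce dimension is just `quantized_dim + 1`, i.e. the freshly created block axis, so per-block min/max falls out of plain numpy reductions. A sketch under the same (64, 128) assumption, not the package API:

import numpy as np

data = np.arange(64 * 128, dtype=np.float32).reshape(64, 128)
quantized_dim, block_size = 1, 32
reshaped = data.reshape(64, 128 // block_size, block_size)  # (64, 4, 32)
reduce_dim = quantized_dim + 1                              # the new block axis
block_min = np.min(reshaped, axis=reduce_dim)               # shape (64, 4)
block_max = np.max(reshaped, axis=reduce_dim)               # one pair per block
assert block_min.shape == block_max.shape == (64, 4)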
+def broadcast_scale_zp_for_blockwise(
+    tensor_content: np.ndarray,
+    quant_params: qtyping.UniformQuantParams,
+) -> qtyping.UniformQuantParams:
+  """Broadcasts scale and zp for blockwise quantization.
+
+  Args:
+    tensor_content: The original tensor data.
+    quant_params: The quantization parameters.
+      `quant_params.quantized_dimension` must be specified.
+      `quant_params.block_size` must be specified and positive.
+
+  Returns:
+    The updated quantization parameters with broadcasted scale and zp for
+    correct constant quantization.
+  """
+  if quant_params.quantized_dimension is None:
+    raise ValueError("Quantized dimension must be specified.")
+  if quant_params.block_size is None or quant_params.block_size <= 0:
+    raise ValueError("Block size must be specified and positive.")
+  quantized_dim = quant_params.quantized_dimension
+  expanded_tensor_shape = _get_tensor_shape_for_blockwise(
+      tensor_content.shape, quantized_dim, quant_params.block_size
+  )
+  expanded_scale = np.reshape(
+      np.broadcast_to(
+          np.expand_dims(quant_params.scale, quantized_dim + 1),
+          expanded_tensor_shape,
+      ),
+      tensor_content.shape,
+  )
+  expanded_zp = np.reshape(
+      np.broadcast_to(
+          np.expand_dims(quant_params.zero_point, quantized_dim + 1),
+          expanded_tensor_shape,
+      ),
+      tensor_content.shape,
+  )
+  return qtyping.UniformQuantParams(
+      scale=expanded_scale,
+      zero_point=expanded_zp,
+      num_bits=quant_params.num_bits,
+      symmetric=quant_params.symmetric,
+      quantized_dimension=quantized_dim,
+      block_size=quant_params.block_size,
+  )
+
+
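
Going the other way, the broadcast expands the per-block scale back to the full tensor shape so elementwise constant quantization works unchanged. A numpy sketch of just the broadcast step, with illustrative shapes:

import numpy as np

tensor_shape = (64, 128)
quantized_dim, block_size = 1, 32
scale = np.full((64, 4), 0.02, dtype=np.float32)  # one scale per block
expanded_shape = (64, 4, 32)                      # blockwise view of the tensor
expanded_scale = np.reshape(
    np.broadcast_to(np.expand_dims(scale, quantized_dim + 1), expanded_shape),
    tensor_shape,
)
assert expanded_scale.shape == tensor_shape  # every element now has its scale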
+def init_tensor_min_max(
+    tensor_data: Optional[np.ndarray],
+    op_info: qtyping.OpInfo,
+) -> qtyping.QSV:
+  """Initialize the min/max for a tensor.
+
+  This function initializes the min/max values for a tensor.
+
+  Args:
+    tensor_data: The tensor data.
+    op_info: Aggregated information about the op (e.g., quantization config).
+
+  Returns:
+    A dictionary containing the min/max values for the tensor, or an empty
+    dictionary if the tensor data is None.
+  """
+  weight_tensor_config = op_info.op_quant_config.weight_tensor_config
+  if tensor_data is None or weight_tensor_config is None:
+    return {}
+  else:
+    # Get reduce dimension for min/max calculation based on quantization
+    # granularity.
+    granularity = weight_tensor_config.granularity
+    if granularity == qtyping.QuantGranularity.TENSORWISE:
+      reduce_dims = None
+      keep_dims = True
+    elif granularity == qtyping.QuantGranularity.CHANNELWISE:
+      quantized_dim = common_utils.get_weight_quantized_dim(
+          op_info, tensor_data, weight_tensor_config.granularity
+      )
+      reduce_dims = common_utils.get_reduce_dims(
+          quantized_dim, tensor_data.shape
+      )
+      keep_dims = True
+    elif uniform_quantize_tensor.is_blockwise(granularity):
+      tensor_data, reduce_dims = (
+          uniform_quantize_tensor.reshape_data_for_blockwise(
+              tensor_data,
+              op_info.op_name,
+              granularity,
+          )
+      )
+      keep_dims = False
+    else:
+      raise ValueError(f"Unsupported granularity: {granularity}")
+    return {
+        "min": np.min(tensor_data, axis=reduce_dims, keepdims=keep_dims),
+        "max": np.max(tensor_data, axis=reduce_dims, keepdims=keep_dims),
+    }
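
For the channelwise branch, the reduce dimensions are every axis except the quantized one, with keepdims=True so the stats stay broadcastable against the weights. A helper-free sketch of that case (the real code delegates to `common_utils.get_reduce_dims`; shapes here are illustrative):

import numpy as np

weights = np.random.randn(8, 3, 3, 16).astype(np.float32)  # e.g. a conv kernel
quantized_dim = 0                                          # per output channel
reduce_dims = tuple(i for i in range(weights.ndim) if i != quantized_dim)
qsv = {
    "min": np.min(weights, axis=reduce_dims, keepdims=True),  # shape (8, 1, 1, 1)
    "max": np.max(weights, axis=reduce_dims, keepdims=True),
}
assert qsv["min"].shape == (8, 1, 1, 1)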