ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +158 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
- ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
- ai_edge_quantizer/calibrator.py +11 -60
- ai_edge_quantizer/calibrator_test.py +4 -73
- ai_edge_quantizer/default_policy.py +61 -26
- ai_edge_quantizer/model_modifier.py +97 -7
- ai_edge_quantizer/model_modifier_test.py +81 -1
- ai_edge_quantizer/model_validator.py +31 -8
- ai_edge_quantizer/params_generator.py +17 -10
- ai_edge_quantizer/params_generator_test.py +2 -7
- ai_edge_quantizer/qtyping.py +86 -6
- ai_edge_quantizer/quantizer.py +166 -21
- ai_edge_quantizer/quantizer_test.py +284 -16
- ai_edge_quantizer/recipe.py +154 -42
- ai_edge_quantizer/recipe_manager.py +158 -1
- ai_edge_quantizer/recipe_manager_test.py +146 -32
- ai_edge_quantizer/recipe_test.py +93 -17
- ai_edge_quantizer/transformation_instruction_generator.py +118 -13
- ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
- ai_edge_quantizer/transformation_performer.py +55 -25
- ai_edge_quantizer/transformation_performer_test.py +127 -5
- ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
- ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
- ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
- ai_edge_quantizer/transformations/transformation_utils.py +129 -6
- ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
- ai_edge_quantizer/utils/calibration_utils.py +263 -1
- ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
- ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
- ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
- ai_edge_quantizer/utils/test_utils.py +75 -2
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
- ai_edge_quantizer/utils/validation_utils.py +114 -4
- ai_edge_quantizer/utils/validation_utils_test.py +80 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
- ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
- ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
- ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py

```diff
@@ -36,6 +36,13 @@ _OpQuantConstraint = common_utils.OpQuantConstraint
 _ComputePrecision = qtyping.ComputePrecision


+def check_if_quantized(tensor: Any) -> bool:
+  """Checks if the tensor is quantized."""
+  return (
+      tensor.quantization is not None and tensor.quantization.scale is not None
+  )
+
+
 def check_op_quantization_config(
     op_name: _TFLOpName,
     op_quant_config: qtyping.OpQuantizationConfig,
```

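The new `check_if_quantized` helper gates most of the behavior changes below: tensors that already carry scales are treated as pre-quantized and left alone. A minimal, package-free sketch of the predicate (the `SimpleNamespace` mocks stand in for the flatbuffer `TensorT` objects the real helper receives):

```python
import types

def check_if_quantized(tensor) -> bool:
  """Mirrors the helper above: quantization params with scales must exist."""
  return (
      tensor.quantization is not None and tensor.quantization.scale is not None
  )

pre_quantized = types.SimpleNamespace(
    quantization=types.SimpleNamespace(scale=[0.02]))
float_tensor = types.SimpleNamespace(quantization=None)

print(check_if_quantized(pre_quantized))  # True
print(check_if_quantized(float_tensor))   # False
```
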
```diff
@@ -271,7 +278,7 @@ def materialize_average_pool_2d(
   )


-def _materialize_bias_for_conv_ops(
+def _materialize_bias_for_fc_conv_ops(
     op_info: qtyping.OpInfo,
     graph_info: qtyping.GraphInfo,
     op_tensor_params: list[qtyping.TensorTransformationParams],
```

```diff
@@ -290,14 +297,16 @@ def _materialize_bias_for_conv_ops(
     op_weight_index: Index for the weight tensor in the op.
     op_bias_index: Index for the bias tensor in the op.
   """
-  _, _, bias_tensor, _ = tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
-      op_info.op,
-      graph_info.subgraph_tensors,
-      op_input_index,
-      op_weight_index,
-      op_bias_index,
-  )
-  if bias_tensor is not None:
+  _, weight_tensor, bias_tensor, _ = (
+      tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
+          op_info.op,
+          graph_info.subgraph_tensors,
+          op_input_index,
+          op_weight_index,
+          op_bias_index,
+      )
+  )
+  if bias_tensor is not None and not check_if_quantized(bias_tensor):
     bias_quant_params = None
     # Fused bias needs to be quantized for SRQ.
     # Check if SRQ.
```

```diff
@@ -309,13 +318,41 @@ def _materialize_bias_for_conv_ops(
           bias_tensor,
           graph_info.buffers,
       )
-      bias_quant_params = (
-          uniform_quantize_tensor.symmetric_quantize_bias_tensor(
-              bias_content,
-              op_tensor_params[op_input_index].consumers[0].parameters,
-              op_tensor_params[op_weight_index].consumers[0].parameters,
-          )
+      input_consumer_params = (
+          op_tensor_params[op_input_index].consumers[0].parameters
+      )
+      weight_consumer_params = (
+          op_tensor_params[op_weight_index].consumers[0].parameters
       )
+      if weight_consumer_params is None and check_if_quantized(weight_tensor):
+        quant_params = weight_tensor.quantization
+        if op_info.op_quant_config.weight_tensor_config is None:
+          raise ValueError(
+              "weight_tensor_config cannot be None when weight tensor is"
+              " quantized."
+          )
+        weight_consumer_params = qtyping.UniformQuantParams(
+            num_bits=op_info.op_quant_config.weight_tensor_config.num_bits,
+            scale=quant_params.scale,
+            zero_point=quant_params.zeroPoint,
+            quantized_dimension=quant_params.quantizedDimension,
+        )
+      try:
+        # Bias quantization is using fixed quantization scale:
+        # input_scale * weight_scale. To avoid hidden numerics error, we check
+        # the quantization error in bias quantization.
+        bias_quant_params = (
+            uniform_quantize_tensor.symmetric_quantize_bias_tensor(
+                bias_content,
+                input_consumer_params,
+                weight_consumer_params,
+            )
+        )
+      except ValueError as e:
+        raise ValueError(
+            f"Failed to quantize bias tensor for op {op_info.op_name} with op"
+            f" id {op_info.subgraph_op_index}."
+        ) from e
       # We only quantize bias under SRQ. Setting is_constant=True for SRQ only
       # to avoid quantize bias for DRQ and weight-only cases.
       is_constant = (
```

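The comment inside the new `try` block is the key numeric fact: bias uses the fixed scale `input_scale * weight_scale`, so a mismatch surfaces as quantization error instead of silent numeric drift. A worked example with made-up scales (not values from the package):

```python
import numpy as np

s_input, s_weight = 0.02, 0.005       # assumed activation and weight scales
s_bias = s_input * s_weight           # fixed bias scale: 1e-4
bias = np.array([0.01, -0.02, 0.03], dtype=np.float32)

q_bias = np.round(bias / s_bias).astype(np.int32)   # [100, -200, 300]
recovered = q_bias.astype(np.float32) * s_bias
print(np.abs(recovered - bias).max())  # the error the new ValueError guards
```
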
```diff
@@ -371,6 +408,25 @@ def materialize_slice(
   )


+def materialize_select(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.select."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
+      inputs_to_ignore=[
+          0,
+      ],  # Condition tensor does not need to be quantized.
+  )
+
+
 def materialize_select_v2(
     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
     op_info: qtyping.OpInfo,
```

```diff
@@ -416,12 +472,21 @@ def materialize_sum(
     tensor_name_to_qsv: dict[str, Any],
 ) -> list[qtyping.TensorTransformationParams]:
   """Materialize tensors in tfl.sum."""
+  # For 8 bits the reference kernel calls a function without input/output
+  # constraints. For all others it calls a function that enforces input/output
+  # scale/zero point checks. See:
+  # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/reduce.cc#L909
+  activation_config = op_info.op_quant_config.activation_tensor_config
+  if activation_config is not None and activation_config.num_bits == 8:
+    constraint = _OpQuantConstraint.NO_CONSTRAIN
+  else:
+    constraint = _OpQuantConstraint.SAME_AS_INPUT_SCALE
   return common_utils.materialize_standard_op(
       op_info,
       graph_info,
       tensor_name_to_qsv,
       get_tensor_quant_params_fn,
-      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      constraint=constraint,
       inputs_to_ignore=[1],  # Axis index does not need to be quantized.
   )

```

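Why the non-8-bit paths pin `SAME_AS_INPUT_SCALE`: when input and output share one scale (zero point 0 here for simplicity), the integer kernel can accumulate raw quantized values with no requantization step. A minimal sketch with assumed numbers:

```python
import numpy as np

s = 0.05                                   # shared input/output scale (assumed)
x = np.array([0.1, 0.2, 0.3], dtype=np.float32)
q = np.round(x / s).astype(np.int32)       # quantized input: [2, 4, 6]
q_sum = q.sum()                            # pure integer accumulation: 12
print(q_sum * s, x.sum())                  # both ~0.6
```
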
```diff
@@ -452,7 +517,13 @@ def materialize_fc_conv(
       weights, bias).
   """
   ignored_inputs = [bias_index]  # Bias tensor is quantized separately.
-  if _are_weights_too_small(op_info, graph_info, weight_index):
+  should_ignore_weight = False
+  if graph_info:
+    w_tensor = graph_info.subgraph_tensors[op_info.op.inputs[weight_index]]
+    should_ignore_weight = check_if_quantized(w_tensor)
+  if should_ignore_weight or _are_weights_too_small(
+      op_info, graph_info, weight_index
+  ):
     ignored_inputs.append(weight_index)

   op_tensor_params = common_utils.materialize_standard_op(
```

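A sketch of what the new gating does (mock objects, not the real flatbuffer types): a weight tensor that already carries scales joins `ignored_inputs`, so `materialize_standard_op` skips it and the bias helper above reuses its stored scales:

```python
import types

def check_if_quantized(t) -> bool:
  return t.quantization is not None and t.quantization.scale is not None

weight = types.SimpleNamespace(            # pre-quantized weight (mock)
    quantization=types.SimpleNamespace(scale=[0.004]))

bias_index, weight_index = 2, 1
ignored_inputs = [bias_index]              # bias is always handled separately
if check_if_quantized(weight):             # mirrors the new gating above
  ignored_inputs.append(weight_index)
print(ignored_inputs)                      # [2, 1]
```
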
```diff
@@ -463,7 +534,7 @@ def materialize_fc_conv(
       inputs_to_ignore=ignored_inputs,
   )

-  _materialize_bias_for_conv_ops(
+  _materialize_bias_for_fc_conv_ops(
       op_info,
       graph_info,
       op_tensor_params,
```

```diff
@@ -518,7 +589,7 @@ def materialize_conv2d_transpose(
         "Materialize standard op should return at least two tensors for"
         " conv2d_transpose."
     )
-  _materialize_bias_for_conv_ops(
+  _materialize_bias_for_fc_conv_ops(
       op_info,
       graph_info,
       op_tensor_params,
```

```diff
@@ -671,6 +742,366 @@ def materialize_split(
   )


+def materialize_pad(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.pad."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Paddings tensor does not need to be quantized.
+  )
+
+
+def materialize_padv2(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.padv2."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
+      inputs_to_ignore=[1],  # Paddings tensor does not need to be quantized.
+  )
+
+
+def materialize_mirror_pad(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.mirror_pad.
+
+  Args:
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
+    op_info: Aggregated information about the op (e.g., quantization config).
+    graph_info: Graph information needed to perform quantization for the op.
+    tensor_name_to_qsv: A map of tensor name to quantization parameters.
+
+  Returns:
+    A list of `qtyping.TensorTransformationParams` for the tensors in the op.
+  """
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Paddings tensor does not need to be quantized.
+  )
+
+
+def materialize_space_to_depth(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.space_to_depth.
+
+  Args:
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
+    op_info: Aggregated information about the op (e.g., quantization config).
+    graph_info: Graph information needed to perform quantization for the op.
+    tensor_name_to_qsv: A map of tensor name to quantization parameters.
+
+  Returns:
+    A list of `qtyping.TensorTransformationParams` for the tensors in the op.
+  """
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+  )
+
+
+def materialize_squared_difference(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.squared_difference."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_max_pool_2d(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.max_pool_2d."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+  )
+
+
+def materialize_resize_bilinear(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.resize_bilinear."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Resize size does not need to be quantized.
+  )
+
+
+def materialize_resize_nearest_neighbor(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.resize_nearest_neighbor."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Resize size does not need to be quantized.
+  )
+
+
+def materialize_gather_nd(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.gather_nd."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Gather indices do not need to be quantized.
+  )
+
+
+def materialize_maximum(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.maximum."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
+  )
+
+
+def materialize_pack(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.pack."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
+  )
+
+
+def materialize_unpack(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.unpack."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+  )
+
+
+def materialize_div(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.div."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_broadcast_to(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.broadcast_to."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Shape tensor does not need to be quantized.
+  )
+
+
+def materialize_sqrt(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.sqrt."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_hard_swish(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.hard_swish."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_gather(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.gather."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Indices do not need to be quantized.
+  )
+
+
+def materialize_reduce_min(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.reduce_min."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Axis index does not need to be quantized.
+  )
+
+
+def materialize_equal(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.equal."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_not_equal(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.not_equal."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_relu(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.relu."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
 def _get_tensor_shape_for_blockwise(
     tensor_shape: Sequence[int], quantized_dim: int, block_size: int
 ) -> list[int]:
```

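All of the new handlers above share one shape and differ only in the constraint and the ignored inputs. A hypothetical, package-free rendering of the pattern (the `materialize_standard_op` below is a stub, not the real `common_utils` function):

```python
from functools import partial

def materialize_standard_op(op, *, constraint=None, inputs_to_ignore=None):
  """Stub standing in for common_utils.materialize_standard_op."""
  return {"op": op, "constraint": constraint, "ignore": inputs_to_ignore}

# e.g. tfl.gather pins the output scale to the input scale and skips its
# integer indices, while tfl.maximum propagates the output scale to inputs.
materialize_gather = partial(
    materialize_standard_op,
    constraint="SAME_AS_INPUT_SCALE",
    inputs_to_ignore=[1],
)
materialize_maximum = partial(
    materialize_standard_op, constraint="SAME_AS_OUTPUT_SCALE")

print(materialize_gather("tfl.gather"))
```
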
```diff
@@ -700,18 +1131,29 @@ def _get_tensor_shape_for_blockwise(


 def _reshape_data_for_blockwise(
-    tensor_data: np.ndarray, quantized_dim: int, block_size: int
+    tensor_data: np.ndarray,
+    quantized_dim: int,
+    block_size: int,
 ) -> tuple[np.ndarray, int]:
   """Reshapes data for blockwise quantization.

   Args:
     tensor_data: The original tensor data.
     quantized_dim: The dimension to be quantized blockwise.
-    block_size: The size of the block.
+    block_size: The size of the block. `block_size` must be a multiple of 32.
+      The tensor quantized dimension shape must be divisible by `block_size`.

   Returns:
     A tuple containing the reshaped tensor data and the new reduce dimension.
   """
+
+  # TODO: b/417508018 - create AEQ specific error class instead of
+  # using generic ValueError.
+  if tensor_data.shape[quantized_dim] % block_size != 0:
+    raise ValueError(
+        "Tensor quantization dimension must be divisible by block size for"
+        " blockwise quantization."
+    )
   new_shape = _get_tensor_shape_for_blockwise(
       tensor_data.shape, quantized_dim, block_size
   )
```

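What the reshape does, in plain numpy (the shapes match the new unit test further down): a `(24, 128)` tensor blocked at 32 along dim 1 becomes `(24, 4, 32)`, and min/max are then reduced over the new last axis:

```python
import numpy as np

data = np.arange(24 * 128, dtype=np.float32).reshape(24, 128)
block_size, quantized_dim = 32, 1

blocked = data.reshape(24, 128 // block_size, block_size)  # (24, 4, 32)
reduce_dim = 2                                             # per-block axis
block_min = blocked.min(axis=reduce_dim)                   # shape (24, 4)
block_max = blocked.max(axis=reduce_dim)
print(blocked.shape, block_min.shape)
```
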
```diff
@@ -783,42 +1225,36 @@ def init_tensor_min_max(
     A dictionary containing the min/max values for the tensor, or an empty
     dictionary if the tensor data is None.
   """
-  if tensor_data is None:
+  weight_tensor_config = op_info.op_quant_config.weight_tensor_config
+  if tensor_data is None or weight_tensor_config is None:
     return {}
   else:
-
-
-
-
-
+    # Get reduce dimension for min/max calculation based on quantization
+    # granularity.
+    granularity = weight_tensor_config.granularity
+    if granularity == qtyping.QuantGranularity.TENSORWISE:
+      reduce_dims = None
+      keep_dims = True
+    elif granularity == qtyping.QuantGranularity.CHANNELWISE:
       quantized_dim = common_utils.get_weight_quantized_dim(
-          op_info, tensor_data
-      )
-    if (
-        weight_tensor_config is not None
-        and weight_tensor_config.granularity
-        == qtyping.QuantGranularity.BLOCKWISE
-    ):
-      quantized_dim = (
-          tfl_flatbuffer_utils.TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM[
-              op_info.op_name
-          ]
-      )
-      reshaped_data, reduce_dims = _reshape_data_for_blockwise(
-          tensor_data,
-          quantized_dim,
-          weight_tensor_config.block_size,
+          op_info, tensor_data, weight_tensor_config.granularity
       )
-      return {
-          "min": np.min(reshaped_data, axis=reduce_dims, keepdims=False),
-          "max": np.max(reshaped_data, axis=reduce_dims, keepdims=False),
-      }
-
-    else:
       reduce_dims = common_utils.get_reduce_dims(
           quantized_dim, tensor_data.shape
       )
-      return {
-          "min": np.min(tensor_data, axis=reduce_dims, keepdims=True),
-          "max": np.max(tensor_data, axis=reduce_dims, keepdims=True),
-      }
+      keep_dims = True
+    elif uniform_quantize_tensor.is_blockwise(granularity):
+      tensor_data, reduce_dims = (
+          uniform_quantize_tensor.reshape_data_for_blockwise(
+              tensor_data,
+              op_info.op_name,
+              granularity,
+          )
+      )
+      keep_dims = False
+    else:
+      raise ValueError(f"Unsupported granularity: {granularity}")
+    return {
+        "min": np.min(tensor_data, axis=reduce_dims, keepdims=keep_dims),
+        "max": np.max(tensor_data, axis=reduce_dims, keepdims=keep_dims),
+    }
```

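The rewritten `init_tensor_min_max` reduces over different axes per granularity. A numpy sketch of the three reductions (mock data; the real code derives `quantized_dim` from op metadata):

```python
import numpy as np

data = np.random.randn(8, 64).astype(np.float32)  # mock weight tensor

# TENSORWISE: one min/max for the whole tensor (reduce_dims=None, keepdims).
t_min = np.min(data, axis=None, keepdims=True)          # shape (1, 1)

# CHANNELWISE, quantized_dim=0: reduce every other dimension.
c_min = np.min(data, axis=(1,), keepdims=True)          # shape (8, 1)

# BLOCKWISE, block_size=32 on dim 1: reshape first, then reduce the block axis.
b_min = np.min(data.reshape(8, 2, 32), axis=2)          # shape (8, 2)
print(t_min.shape, c_min.shape, b_min.shape)
```
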
ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py

```diff
@@ -31,8 +31,7 @@ _TensorQuantConfig = qtyping.TensorQuantizationConfig


 class CommonQuantizeTest(parameterized.TestCase):
-  """Tests for general quantize functions.
-  """
+  """Tests for general quantize functions."""

   def setUp(self):
     super().setUp()
```

```diff
@@ -69,6 +68,34 @@ class CommonQuantizeTest(parameterized.TestCase):
         default_policy.DEFAULT_CONFIG_CHECK_POLICY,
     )

+  def test_reshape_data_for_blockwise_raises_error_when_quantized_dim_not_divisible_by_block_size(
+      self,
+  ):
+    tensor_data = np.ones((24, 128), dtype=np.float32)
+    block_size = 256
+    quantized_dim = 1
+    with self.assertRaisesWithPredicateMatch(
+        ValueError,
+        lambda err: (
+            "Tensor quantization dimension must be divisible by block"
+            " size for blockwise quantization."
+        )
+        in str(err),
+    ):
+      common_quantize._reshape_data_for_blockwise(
+          tensor_data, quantized_dim, block_size
+      )
+
+  def test_reshape_data_for_blockwise_returns_correct_values(self):
+    tensor_data = np.ones((24, 128), dtype=np.float32)
+    block_size = 32
+    quantized_dim = 1
+    new_tensor_data, reduce_dim = common_quantize._reshape_data_for_blockwise(
+        tensor_data, quantized_dim, block_size
+    )
+    self.assertEqual(new_tensor_data.shape, (24, 4, 32))
+    self.assertEqual(reduce_dim, 2)
+

 if __name__ == "__main__":
   googletest.main()
```