ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl
This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +224 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
- ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
- ai_edge_quantizer/calibrator.py +58 -94
- ai_edge_quantizer/calibrator_test.py +5 -74
- ai_edge_quantizer/default_policy.py +108 -16
- ai_edge_quantizer/model_modifier.py +132 -8
- ai_edge_quantizer/model_modifier_test.py +81 -1
- ai_edge_quantizer/model_validator.py +38 -10
- ai_edge_quantizer/model_validator_test.py +2 -1
- ai_edge_quantizer/params_generator.py +230 -47
- ai_edge_quantizer/params_generator_test.py +366 -261
- ai_edge_quantizer/qtyping.py +92 -6
- ai_edge_quantizer/quantizer.py +167 -23
- ai_edge_quantizer/quantizer_test.py +288 -26
- ai_edge_quantizer/recipe.py +156 -21
- ai_edge_quantizer/recipe_manager.py +158 -1
- ai_edge_quantizer/recipe_manager_test.py +146 -32
- ai_edge_quantizer/recipe_test.py +93 -17
- ai_edge_quantizer/transformation_instruction_generator.py +313 -46
- ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
- ai_edge_quantizer/transformation_performer.py +112 -58
- ai_edge_quantizer/transformation_performer_test.py +176 -4
- ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
- ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
- ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
- ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
- ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
- ai_edge_quantizer/transformations/transformation_utils.py +157 -11
- ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
- ai_edge_quantizer/utils/calibration_utils.py +263 -1
- ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
- ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
- ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
- ai_edge_quantizer/utils/test_utils.py +191 -58
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
- ai_edge_quantizer/utils/validation_utils.py +114 -4
- ai_edge_quantizer/utils/validation_utils_test.py +80 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
- ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
- ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
- ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
--- a/ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py
@@ -23,7 +23,7 @@ to implement the get_tensor_quant_params_fn with the
 qtyping.GetTensorQuantParamsFuncSignature signature.
 """
 
-from typing import Any
+from typing import Any, Optional, Sequence
 import numpy as np
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor
@@ -36,6 +36,13 @@ _OpQuantConstraint = common_utils.OpQuantConstraint
 _ComputePrecision = qtyping.ComputePrecision
 
 
+def check_if_quantized(tensor: Any) -> bool:
+  """Checks if the tensor is quantized."""
+  return (
+      tensor.quantization is not None and tensor.quantization.scale is not None
+  )
+
+
 def check_op_quantization_config(
     op_name: _TFLOpName,
     op_quant_config: qtyping.OpQuantizationConfig,
@@ -110,6 +117,21 @@ def materialize_output(
   )
 
 
+def materialize_composite(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in the virtual output op."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
 def materialize_add(
     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
     op_info: qtyping.OpInfo,
@@ -256,7 +278,7 @@ def materialize_average_pool_2d(
   )
 
 
-def _materialize_bias_for_conv_ops(
+def _materialize_bias_for_fc_conv_ops(
     op_info: qtyping.OpInfo,
     graph_info: qtyping.GraphInfo,
     op_tensor_params: list[qtyping.TensorTransformationParams],
@@ -275,14 +297,16 @@ def _materialize_bias_for_conv_ops(
     op_weight_index: Index for the weight tensor in the op.
     op_bias_index: Index for the bias tensor in the op.
   """
-  _, _, bias_tensor, _ = tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
-      op_info.op,
-      graph_info.subgraph_tensors,
-      op_input_index,
-      op_weight_index,
-      op_bias_index,
-  )
-  if bias_tensor is not None:
+  _, weight_tensor, bias_tensor, _ = (
+      tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
+          op_info.op,
+          graph_info.subgraph_tensors,
+          op_input_index,
+          op_weight_index,
+          op_bias_index,
+      )
+  )
+  if bias_tensor is not None and not check_if_quantized(bias_tensor):
     bias_quant_params = None
     # Fused bias needs to be quantized for SRQ.
     # Check if SRQ.
@@ -294,13 +318,41 @@ def _materialize_bias_for_conv_ops(
           bias_tensor,
           graph_info.buffers,
       )
-      bias_quant_params = (
-          uniform_quantize_tensor.symmetric_quantize_bias_tensor(
-              bias_content,
-              op_tensor_params[op_input_index].consumers[0].parameters,
-              op_tensor_params[op_weight_index].consumers[0].parameters,
-          )
+      input_consumer_params = (
+          op_tensor_params[op_input_index].consumers[0].parameters
+      )
+      weight_consumer_params = (
+          op_tensor_params[op_weight_index].consumers[0].parameters
       )
+      if weight_consumer_params is None and check_if_quantized(weight_tensor):
+        quant_params = weight_tensor.quantization
+        if op_info.op_quant_config.weight_tensor_config is None:
+          raise ValueError(
+              "weight_tensor_config cannot be None when weight tensor is"
+              " quantized."
+          )
+        weight_consumer_params = qtyping.UniformQuantParams(
+            num_bits=op_info.op_quant_config.weight_tensor_config.num_bits,
+            scale=quant_params.scale,
+            zero_point=quant_params.zeroPoint,
+            quantized_dimension=quant_params.quantizedDimension,
+        )
+      try:
+        # Bias quantization is using fixed quantization scale:
+        # input_scale * weight_scale. To avoid hidden numerics error, we check
+        # the quantization error in bias quantization.
+        bias_quant_params = (
+            uniform_quantize_tensor.symmetric_quantize_bias_tensor(
+                bias_content,
+                input_consumer_params,
+                weight_consumer_params,
+            )
+        )
+      except ValueError as e:
+        raise ValueError(
+            f"Failed to quantize bias tensor for op {op_info.op_name} with op"
+            f" id {op_info.subgraph_op_index}."
+        ) from e
     # We only quantize bias under SRQ. Setting is_constant=True for SRQ only
     # to avoid quantize bias for DRQ and weight-only cases.
     is_constant = (
@@ -356,6 +408,25 @@ def materialize_slice(
   )
 
 
+def materialize_select(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.select."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
+      inputs_to_ignore=[
+          0,
+      ],  # Condition tensor does not need to be quantized.
+  )
+
+
 def materialize_select_v2(
     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
     op_info: qtyping.OpInfo,
@@ -375,6 +446,25 @@ def materialize_select_v2(
   )
 
 
+def materialize_dynamic_update_slice(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.dynamic_update_slice."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
+      inputs_to_ignore=[
+          2,
+      ],  # start_indices do not need to be quantized.
+  )
+
+
 def materialize_sum(
     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
     op_info: qtyping.OpInfo,
@@ -382,12 +472,21 @@ def materialize_sum(
     tensor_name_to_qsv: dict[str, Any],
 ) -> list[qtyping.TensorTransformationParams]:
   """Materialize tensors in tfl.sum."""
+  # For 8 bits the reference kernel calls a function without input/output
+  # constraints. For all others it calls a function that enforces input/output
+  # scale/zero point checks. See:
+  # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/reduce.cc#L909
+  activation_config = op_info.op_quant_config.activation_tensor_config
+  if activation_config is not None and activation_config.num_bits == 8:
+    constraint = _OpQuantConstraint.NO_CONSTRAIN
+  else:
+    constraint = _OpQuantConstraint.SAME_AS_INPUT_SCALE
   return common_utils.materialize_standard_op(
       op_info,
       graph_info,
       tensor_name_to_qsv,
       get_tensor_quant_params_fn,
-      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      constraint=constraint,
       inputs_to_ignore=[1],  # Axis index does not need to be quantized.
   )
 
@@ -418,7 +517,13 @@ def materialize_fc_conv(
     weights, bias).
   """
   ignored_inputs = [bias_index]  # Bias tensor is quantized separately.
-  if _are_weights_too_small(op_info, graph_info, weight_index):
+  should_ignore_weight = False
+  if graph_info:
+    w_tensor = graph_info.subgraph_tensors[op_info.op.inputs[weight_index]]
+    should_ignore_weight = check_if_quantized(w_tensor)
+  if should_ignore_weight or _are_weights_too_small(
+      op_info, graph_info, weight_index
+  ):
     ignored_inputs.append(weight_index)
 
   op_tensor_params = common_utils.materialize_standard_op(
@@ -429,7 +534,7 @@
       inputs_to_ignore=ignored_inputs,
   )
 
-  _materialize_bias_for_conv_ops(
+  _materialize_bias_for_fc_conv_ops(
       op_info,
       graph_info,
       op_tensor_params,
@@ -484,7 +589,7 @@ def materialize_conv2d_transpose(
         "Materialize standard op should return at least two tensors for"
         " conv2d_transpose."
     )
-  _materialize_bias_for_conv_ops(
+  _materialize_bias_for_fc_conv_ops(
       op_info,
       graph_info,
       op_tensor_params,
@@ -635,3 +740,521 @@ def materialize_split(
       constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
       inputs_to_ignore=[0],  # Split dimension does not need to be quantized.
   )
+
+
+def materialize_pad(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.pad."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Paddings tensor does not need to be quantized.
+  )
+
+
+def materialize_padv2(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.padv2."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
+      inputs_to_ignore=[1],  # Paddings tensor does not need to be quantized.
+  )
+
+
+def materialize_mirror_pad(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.mirror_pad.
+
+  Args:
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
+    op_info: Aggregated information about the op (e.g., quantization config).
+    graph_info: Graph information needed to perform quantization for the op.
+    tensor_name_to_qsv: A map of tensor name to quantization parameters.
+
+  Returns:
+    A list of `qtyping.TensorTransformationParams` for the tensors in the op.
+  """
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Paddings tensor does not need to be quantized.
+  )
+
+
+def materialize_space_to_depth(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.space_to_depth.
+
+  Args:
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
+    op_info: Aggregated information about the op (e.g., quantization config).
+    graph_info: Graph information needed to perform quantization for the op.
+    tensor_name_to_qsv: A map of tensor name to quantization parameters.
+
+  Returns:
+    A list of `qtyping.TensorTransformationParams` for the tensors in the op.
+  """
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+  )
+
+
+def materialize_squared_difference(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.squared_difference."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_max_pool_2d(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.max_pool_2d."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+  )
+
+
+def materialize_resize_bilinear(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.resize_bilinear."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Resize size does not need to be quantized.
+  )
+
+
+def materialize_resize_nearest_neighbor(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.resize_nearest_neighbor."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Resize size does not need to be quantized.
+  )
+
+
+def materialize_gather_nd(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.gather_nd."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Gather indices do not need to be quantized.
+  )
+
+
+def materialize_maximum(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.maximum."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
+  )
+
+
+def materialize_pack(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.pack."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
+  )
+
+
+def materialize_unpack(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.unpack."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+  )
+
+
+def materialize_div(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.div."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_broadcast_to(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.broadcast_to."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Shape tensor does not need to be quantized.
+  )
+
+
+def materialize_sqrt(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.sqrt."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_hard_swish(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.hard_swish."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_gather(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.gather."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Indices do not need to be quantized.
+  )
+
+
+def materialize_reduce_min(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.reduce_min."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Axis index does not need to be quantized.
+  )
+
+
+def materialize_equal(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.equal."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_not_equal(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.not_equal."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def materialize_relu(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.relu."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+  )
+
+
+def _get_tensor_shape_for_blockwise(
+    tensor_shape: Sequence[int], quantized_dim: int, block_size: int
+) -> list[int]:
+  """Get the tensor shape for blockwise quantization.
+
+  This function splits the quantize dimension of the tensor into blocks and the
+  dim/blocks. Hence, min/max of the tensor can be calculated for each block
+  using existing functions.
+
+  Args:
+    tensor_shape: The original shape of the tensor.
+    quantized_dim: The dimension to be quantized blockwise.
+    block_size: The size of the block.
+
+  Returns:
+    The new tensor shape for calculating scale and zp for blockwise
+    quantization.
+  """
+  new_shape = []
+  for index, val in enumerate(tensor_shape):
+    if index == quantized_dim:
+      new_shape.append(int(val / block_size))
+      new_shape.append(block_size)
+    else:
+      new_shape.append(val)
+  return new_shape
+
+
+def _reshape_data_for_blockwise(
+    tensor_data: np.ndarray,
+    quantized_dim: int,
+    block_size: int,
+) -> tuple[np.ndarray, int]:
+  """Reshapes data for blockwise quantization.
+
+  Args:
+    tensor_data: The original tensor data.
+    quantized_dim: The dimension to be quantized blockwise.
+    block_size: The size of the block. `block_size must be a multiple of 32. `
+      `The tensor quantized dimension shape must be divisible by block_size.
+
+  Returns:
+    A tuple containing the reshaped tensor data and the new reduce dimension.
+  """
+
+  # TODO: b/417508018 - create AEQ specific error class instead of
+  # using generic ValueError.
+  if tensor_data.shape[quantized_dim] % block_size != 0:
+    raise ValueError(
+        "Tensor quantization dimension must be divisible by block size for"
+        " blockwise quantization."
+    )
+  new_shape = _get_tensor_shape_for_blockwise(
+      tensor_data.shape, quantized_dim, block_size
+  )
+  reshaped_data = tensor_data.reshape(new_shape)
+  return reshaped_data, quantized_dim + 1
+
+
+def broadcast_scale_zp_for_blockwise(
+    tensor_content: np.ndarray,
+    quant_params: qtyping.UniformQuantParams,
+) -> qtyping.UniformQuantParams:
+  """Broadcasts scale and zp for blockwise quantization.
+
+  Args:
+    tensor_content: The original tensor data.
+    quant_params: The quantization parameters.
+      `quant_params.quantized_dimension` must be specified.
+      `quant_params.block_size` must be specified and positive.
+
+  Returns:
+    The updated quantization parameters with broadcasted scale and zp for
+    correct constant quantization.
+  """
+  if quant_params.quantized_dimension is None:
+    raise ValueError("Quantized dimension must be specified.")
+  if quant_params.block_size is None or quant_params.block_size <= 0:
+    raise ValueError("Block size must be specified and positive.")
+  quantized_dim = quant_params.quantized_dimension
+  expanded_tensor_shape = _get_tensor_shape_for_blockwise(
+      tensor_content.shape, quantized_dim, quant_params.block_size
+  )
+  expanded_scale = np.reshape(
+      np.broadcast_to(
+          np.expand_dims(quant_params.scale, quantized_dim + 1),
+          expanded_tensor_shape,
+      ),
+      tensor_content.shape,
+  )
+  expanded_zp = np.reshape(
+      np.broadcast_to(
+          np.expand_dims(quant_params.zero_point, quantized_dim + 1),
+          expanded_tensor_shape,
+      ),
+      tensor_content.shape,
+  )
+  return qtyping.UniformQuantParams(
+      scale=expanded_scale,
+      zero_point=expanded_zp,
+      num_bits=quant_params.num_bits,
+      symmetric=quant_params.symmetric,
+      quantized_dimension=quantized_dim,
+      block_size=quant_params.block_size,
+  )
+
+
+def init_tensor_min_max(
+    tensor_data: Optional[np.ndarray],
+    op_info: qtyping.OpInfo,
+) -> qtyping.QSV:
+  """Initialize the min/max for a tensor.
+
+  This function initializes the min/max values for a tensor.
+
+  Args:
+    tensor_data: The tensor data.
+    op_info: Aggregated information about the op (e.g., quantization config).
+
+  Returns:
+    A dictionary containing the min/max values for the tensor, or an empty
+    dictionary if the tensor data is None.
+  """
+  weight_tensor_config = op_info.op_quant_config.weight_tensor_config
+  if tensor_data is None or weight_tensor_config is None:
+    return {}
+  else:
+    # Get reduce dimension for min/max calculation based on quantization
+    # granularity.
+    granularity = weight_tensor_config.granularity
+    if granularity == qtyping.QuantGranularity.TENSORWISE:
+      reduce_dims = None
+      keep_dims = True
+    elif granularity == qtyping.QuantGranularity.CHANNELWISE:
+      quantized_dim = common_utils.get_weight_quantized_dim(
+          op_info, tensor_data, weight_tensor_config.granularity
+      )
+      reduce_dims = common_utils.get_reduce_dims(
+          quantized_dim, tensor_data.shape
+      )
+      keep_dims = True
+    elif uniform_quantize_tensor.is_blockwise(granularity):
+      tensor_data, reduce_dims = (
+          uniform_quantize_tensor.reshape_data_for_blockwise(
+              tensor_data,
+              op_info.op_name,
+              granularity,
+          )
+      )
+      keep_dims = False
+    else:
+      raise ValueError(f"Unsupported granularity: {granularity}")
+    return {
+        "min": np.min(tensor_data, axis=reduce_dims, keepdims=keep_dims),
+        "max": np.max(tensor_data, axis=reduce_dims, keepdims=keep_dims),
+    }