ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +158 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
- ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
- ai_edge_quantizer/calibrator.py +11 -60
- ai_edge_quantizer/calibrator_test.py +4 -73
- ai_edge_quantizer/default_policy.py +61 -26
- ai_edge_quantizer/model_modifier.py +97 -7
- ai_edge_quantizer/model_modifier_test.py +81 -1
- ai_edge_quantizer/model_validator.py +31 -8
- ai_edge_quantizer/params_generator.py +17 -10
- ai_edge_quantizer/params_generator_test.py +2 -7
- ai_edge_quantizer/qtyping.py +86 -6
- ai_edge_quantizer/quantizer.py +166 -21
- ai_edge_quantizer/quantizer_test.py +284 -16
- ai_edge_quantizer/recipe.py +154 -42
- ai_edge_quantizer/recipe_manager.py +158 -1
- ai_edge_quantizer/recipe_manager_test.py +146 -32
- ai_edge_quantizer/recipe_test.py +93 -17
- ai_edge_quantizer/transformation_instruction_generator.py +118 -13
- ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
- ai_edge_quantizer/transformation_performer.py +55 -25
- ai_edge_quantizer/transformation_performer_test.py +127 -5
- ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
- ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
- ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
- ai_edge_quantizer/transformations/transformation_utils.py +129 -6
- ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
- ai_edge_quantizer/utils/calibration_utils.py +263 -1
- ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
- ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
- ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
- ai_edge_quantizer/utils/test_utils.py +75 -2
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
- ai_edge_quantizer/utils/validation_utils.py +114 -4
- ai_edge_quantizer/utils/validation_utils_test.py +80 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
- ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
- ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
- ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
|
@@ -23,10 +23,13 @@ from collections.abc import Iterator
|
|
|
23
23
|
import dataclasses
|
|
24
24
|
from typing import Optional
|
|
25
25
|
from ai_edge_quantizer import qtyping
|
|
26
|
+
from ai_edge_quantizer.algorithms.utils import common_utils
|
|
27
|
+
from ai_edge_quantizer.utils import constrained_ops_utils
|
|
26
28
|
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
|
|
27
29
|
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
|
|
28
30
|
|
|
29
31
|
|
|
32
|
+
_OpQuantConstraint = common_utils.OpQuantConstraint
|
|
30
33
|
_QuantTransformation = qtyping.QuantTransformation
|
|
31
34
|
|
|
32
35
|
|
|
@@ -51,6 +54,15 @@ def check_horizontal_optimization(
|
|
|
51
54
|
Returns:
|
|
52
55
|
True if the two transformations can be merged, False otherwise
|
|
53
56
|
"""
|
|
57
|
+
if (
|
|
58
|
+
isinstance(param1.parameters, qtyping.UniformQuantParams)
|
|
59
|
+
and param1.parameters.hadamard is not None
|
|
60
|
+
):
|
|
61
|
+
if (
|
|
62
|
+
isinstance(param2.parameters, qtyping.UniformQuantParams)
|
|
63
|
+
and param2.parameters.hadamard is not None
|
|
64
|
+
):
|
|
65
|
+
return True
|
|
54
66
|
return (
|
|
55
67
|
param1.parameters == param2.parameters
|
|
56
68
|
and len(param1.transformations) > index
|
|
@@ -165,6 +177,16 @@ class TransformationInstructionsGenerator:
|
|
|
165
177
|
else:
|
|
166
178
|
self.flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
|
|
167
179
|
self._create_tensor_name_to_graph_info_map()
|
|
180
|
+
self._same_as_input_scale_ops = (
|
|
181
|
+
constrained_ops_utils.get_constrained_op_list(
|
|
182
|
+
_OpQuantConstraint.SAME_AS_INPUT_SCALE
|
|
183
|
+
)
|
|
184
|
+
)
|
|
185
|
+
self._same_as_output_scale_ops = (
|
|
186
|
+
constrained_ops_utils.get_constrained_op_list(
|
|
187
|
+
_OpQuantConstraint.SAME_AS_OUTPUT_SCALE
|
|
188
|
+
)
|
|
189
|
+
)
|
|
168
190
|
|
|
169
191
|
@dataclasses.dataclass(frozen=True)
|
|
170
192
|
class TensorGraphInfo:
|
|
@@ -186,11 +208,13 @@ class TransformationInstructionsGenerator:
|
|
|
186
208
|
A tuple of tensor_name and TensorGraphInfo.
|
|
187
209
|
"""
|
|
188
210
|
for tensor_id, tensor in enumerate(subgraph.tensors):
|
|
189
|
-
consumers = [
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
211
|
+
consumers = []
|
|
212
|
+
for op_id, op in enumerate(subgraph.operators):
|
|
213
|
+
# Some ops may use the same input tensor multiple times,
|
|
214
|
+
# and we should handle each time independently.
|
|
215
|
+
for op_input in op.inputs:
|
|
216
|
+
if op_input == tensor_id:
|
|
217
|
+
consumers.append(op_id)
|
|
194
218
|
producer = -1
|
|
195
219
|
for op_id, op in enumerate(subgraph.operators):
|
|
196
220
|
if tensor_id in op.outputs:
|
|
@@ -504,6 +528,89 @@ class TransformationInstructionsGenerator:
|
|
|
504
528
|
):
|
|
505
529
|
instructions.pop(i)
|
|
506
530
|
|
|
531
|
+
def _is_valid_quantize_requantize_pair(
|
|
532
|
+
self,
|
|
533
|
+
instr_0: qtyping.TransformationInst,
|
|
534
|
+
instr_1: qtyping.TransformationInst,
|
|
535
|
+
) -> bool:
|
|
536
|
+
"""Checks if the two instructions form a valid quantize and requantize pair."""
|
|
537
|
+
return (
|
|
538
|
+
instr_0.transformation == _QuantTransformation.QUANTIZE_TENSOR
|
|
539
|
+
and instr_1.transformation == _QuantTransformation.ADD_QUANTIZE
|
|
540
|
+
and instr_0.consumers == instr_1.consumers
|
|
541
|
+
)
|
|
542
|
+
|
|
543
|
+
def _is_op_constrained(
|
|
544
|
+
self, subgraph_id: int, op_index: int
|
|
545
|
+
) -> bool:
|
|
546
|
+
"""Checks if the op has same as input or output scale constraints."""
|
|
547
|
+
op_name = tfl_flatbuffer_utils.get_op_name_by_index(
|
|
548
|
+
self.flatbuffer_model, subgraph_id, op_index
|
|
549
|
+
)
|
|
550
|
+
return (
|
|
551
|
+
op_name in self._same_as_input_scale_ops
|
|
552
|
+
or op_name in self._same_as_output_scale_ops
|
|
553
|
+
)
|
|
554
|
+
|
|
555
|
+
def _are_quant_params_compatible(
|
|
556
|
+
self,
|
|
557
|
+
params_0: qtyping.UniformQuantParams,
|
|
558
|
+
params_1: qtyping.UniformQuantParams,
|
|
559
|
+
) -> bool:
|
|
560
|
+
"""Checks if quant params are the same except for the scale and zero point."""
|
|
561
|
+
ignore_set = {"scale", "zero_point"}
|
|
562
|
+
for field_info in dataclasses.fields(qtyping.UniformQuantParams):
|
|
563
|
+
field_name = field_info.name
|
|
564
|
+
if field_name in ignore_set:
|
|
565
|
+
continue
|
|
566
|
+
if getattr(params_0, field_name) != getattr(params_1, field_name):
|
|
567
|
+
return False
|
|
568
|
+
return True
|
|
569
|
+
|
|
570
|
+
def _eliminate_requantization_for_nonconstrained_provider(
|
|
571
|
+
self, tensor_trans_insts: qtyping.TensorTransformationInsts
|
|
572
|
+
) -> None:
|
|
573
|
+
"""Removes requantization for tensors with a non-constrained provider.
|
|
574
|
+
|
|
575
|
+
Fuses [QUANTIZE_TENSOR, ADD_QUANTIZE] instructions when a tensor has a
|
|
576
|
+
provider op without same as input/output scale constraints. Quant params from
|
|
577
|
+
the second instruction are copied to the first one and ADD_QUANTIZE is
|
|
578
|
+
removed.
|
|
579
|
+
|
|
580
|
+
Args:
|
|
581
|
+
tensor_trans_insts: Transformation instructions for a tensor.
|
|
582
|
+
"""
|
|
583
|
+
instructions = tensor_trans_insts.instructions
|
|
584
|
+
if instructions is None or len(instructions) != 2:
|
|
585
|
+
return
|
|
586
|
+
|
|
587
|
+
instr_0, instr_1 = instructions
|
|
588
|
+
params_0 = instr_0.parameters
|
|
589
|
+
params_1 = instr_1.parameters
|
|
590
|
+
producer_op_index = instr_0.producer
|
|
591
|
+
if (
|
|
592
|
+
not isinstance(params_0, qtyping.UniformQuantParams)
|
|
593
|
+
or not isinstance(params_1, qtyping.UniformQuantParams)
|
|
594
|
+
or not self._is_valid_quantize_requantize_pair(instr_0, instr_1)
|
|
595
|
+
or not self._are_quant_params_compatible(params_0, params_1)
|
|
596
|
+
# To avoid fusion when subgraph inputs connected to the main subgraph
|
|
597
|
+
# (e.g. while_body), we skip all tensors with no producer.
|
|
598
|
+
or producer_op_index == -1
|
|
599
|
+
# Can't apply fusion to tensors with a constrained producer since that
|
|
600
|
+
# will break the constraint.
|
|
601
|
+
or self._is_op_constrained(
|
|
602
|
+
tensor_trans_insts.subgraph_id, producer_op_index
|
|
603
|
+
)
|
|
604
|
+
):
|
|
605
|
+
return
|
|
606
|
+
|
|
607
|
+
# Fuse the quantize and requantize.
|
|
608
|
+
instr_0.parameters = dataclasses.replace(
|
|
609
|
+
params_0, scale=params_1.scale, zero_point=params_1.zero_point
|
|
610
|
+
)
|
|
611
|
+
# Remove the requantize instruction.
|
|
612
|
+
instructions.pop(1)
|
|
613
|
+
|
|
507
614
|
def _quant_params_to_transformation_insts(
|
|
508
615
|
self,
|
|
509
616
|
param: qtyping.TensorTransformationParams,
|
|
@@ -576,6 +683,12 @@ class TransformationInstructionsGenerator:
|
|
|
576
683
|
# will raise an error if the instructions are not valid.
|
|
577
684
|
self._check_tensor_transformation_instructions_valid(tensor_trans_insts)
|
|
578
685
|
|
|
686
|
+
# Remove unnecessary [QUANTIZE_TENSOR, ADD_QUANTIZE] pairs for tensors with
|
|
687
|
+
# providers without same as input/output scale constraints.
|
|
688
|
+
self._eliminate_requantization_for_nonconstrained_provider(
|
|
689
|
+
tensor_trans_insts
|
|
690
|
+
)
|
|
691
|
+
|
|
579
692
|
return tensor_trans_insts
|
|
580
693
|
|
|
581
694
|
def _split_instructions_by_tensor_duplication(
|
|
@@ -671,7 +784,6 @@ class TransformationInstructionsGenerator:
|
|
|
671
784
|
"""
|
|
672
785
|
is_tensor_unquantized = False
|
|
673
786
|
is_tensor_quantized = False
|
|
674
|
-
is_operator_emulated = False
|
|
675
787
|
for instruction in instructions:
|
|
676
788
|
transform_type = instruction.transformation
|
|
677
789
|
if transform_type == qtyping.QuantTransformation.NO_QUANTIZE:
|
|
@@ -681,17 +793,10 @@ class TransformationInstructionsGenerator:
|
|
|
681
793
|
or transform_type == qtyping.QuantTransformation.ADD_DEQUANTIZE
|
|
682
794
|
):
|
|
683
795
|
is_tensor_quantized = True
|
|
684
|
-
elif transform_type == qtyping.QuantTransformation.EMULATED_SUBCHANNEL:
|
|
685
|
-
is_operator_emulated = True
|
|
686
796
|
if is_tensor_unquantized and is_tensor_quantized:
|
|
687
797
|
raise ValueError(
|
|
688
798
|
"Tensor %s can not be both quantized and unquantized" % tensor_name
|
|
689
799
|
)
|
|
690
|
-
if is_operator_emulated and len(instructions) > 1:
|
|
691
|
-
raise ValueError(
|
|
692
|
-
"Tensor %s : op replacement transformation can not be combined with"
|
|
693
|
-
" other transformations." % tensor_name
|
|
694
|
-
)
|
|
695
800
|
|
|
696
801
|
def _check_tensor_transformation_instructions_valid(
|
|
697
802
|
self,
|
|
@@ -15,7 +15,9 @@
|
|
|
15
15
|
|
|
16
16
|
"""Tests for instruction_generator."""
|
|
17
17
|
|
|
18
|
+
from collections.abc import Sequence
|
|
18
19
|
import os
|
|
20
|
+
from typing import Optional
|
|
19
21
|
|
|
20
22
|
import numpy as np
|
|
21
23
|
|
|
@@ -953,33 +955,6 @@ class InstructionGeneratorTest(parameterized.TestCase):
|
|
|
953
955
|
instructions["StatefulPartitionedCall:0"], output_transformation
|
|
954
956
|
)
|
|
955
957
|
|
|
956
|
-
def test_raise_error_on_op_replacement_transformation_is_not_unique(self):
|
|
957
|
-
test_model_path = os.path.join(
|
|
958
|
-
TEST_DATA_PREFIX_PATH, "tests/models/insert_dequant_test.tflite"
|
|
959
|
-
)
|
|
960
|
-
quant_parameters = {}
|
|
961
|
-
quant_parameters["tfl.quantize"] = qtyping.TensorTransformationParams(
|
|
962
|
-
"tfl.quantize",
|
|
963
|
-
qtyping.OpToTensorParams(
|
|
964
|
-
subgraph_op_id=0,
|
|
965
|
-
transformations=[
|
|
966
|
-
qtyping.QuantTransformation.ADD_DEQUANTIZE,
|
|
967
|
-
qtyping.QuantTransformation.EMULATED_SUBCHANNEL,
|
|
968
|
-
],
|
|
969
|
-
parameters=qtyping.UniformQuantParams(
|
|
970
|
-
8, None, np.array([1]), np.array([0])
|
|
971
|
-
),
|
|
972
|
-
),
|
|
973
|
-
[],
|
|
974
|
-
)
|
|
975
|
-
ins_gen = instruction_generator.TransformationInstructionsGenerator(
|
|
976
|
-
test_model_path
|
|
977
|
-
)
|
|
978
|
-
with self.assertRaisesRegex(
|
|
979
|
-
ValueError, "op replacement transformation can not be combined"
|
|
980
|
-
):
|
|
981
|
-
ins_gen.quant_params_to_transformation_insts(quant_parameters)
|
|
982
|
-
|
|
983
958
|
def test_raise_error_on_no_quant_conflict(self):
|
|
984
959
|
test_model_path = os.path.join(
|
|
985
960
|
TEST_DATA_PREFIX_PATH, "tests/models/insert_dequant_test.tflite"
|
|
@@ -1364,5 +1339,166 @@ class InstructionGeneratorTest(parameterized.TestCase):
|
|
|
1364
1339
|
)
|
|
1365
1340
|
|
|
1366
1341
|
|
|
1342
|
+
class EliminateUnnecessaryRequantizationTest(parameterized.TestCase):
|
|
1343
|
+
|
|
1344
|
+
def setUp(self):
|
|
1345
|
+
super().setUp()
|
|
1346
|
+
self.ins_gen = instruction_generator.TransformationInstructionsGenerator(
|
|
1347
|
+
os.path.join(TEST_DATA_PREFIX_PATH, "tests/models/conv_fc_mnist.tflite")
|
|
1348
|
+
)
|
|
1349
|
+
|
|
1350
|
+
def _get_test_instruction(
|
|
1351
|
+
self,
|
|
1352
|
+
transformation: qtyping.QuantTransformation,
|
|
1353
|
+
producer: int = -1,
|
|
1354
|
+
consumers: Optional[Sequence[int]] = None,
|
|
1355
|
+
qparams: Optional[qtyping.UniformQuantParams] = None,
|
|
1356
|
+
) -> qtyping.TransformationInst:
|
|
1357
|
+
if consumers is None:
|
|
1358
|
+
consumers = []
|
|
1359
|
+
if qparams is None:
|
|
1360
|
+
qparams = qtyping.UniformQuantParams(
|
|
1361
|
+
num_bits=8,
|
|
1362
|
+
quantized_dimension=None,
|
|
1363
|
+
scale=np.array([1]),
|
|
1364
|
+
zero_point=np.array([0]),
|
|
1365
|
+
)
|
|
1366
|
+
return qtyping.TransformationInst(
|
|
1367
|
+
transformation=transformation,
|
|
1368
|
+
producer=producer,
|
|
1369
|
+
consumers=consumers,
|
|
1370
|
+
parameters=qparams,
|
|
1371
|
+
# Dummy values below.
|
|
1372
|
+
tensor_id=0,
|
|
1373
|
+
)
|
|
1374
|
+
|
|
1375
|
+
def _create_test_insts(
|
|
1376
|
+
self, instructions: list[qtyping.TransformationInst]
|
|
1377
|
+
) -> qtyping.TensorTransformationInsts:
|
|
1378
|
+
return qtyping.TensorTransformationInsts(
|
|
1379
|
+
tensor_name="test_tensor", subgraph_id=0, instructions=instructions
|
|
1380
|
+
)
|
|
1381
|
+
|
|
1382
|
+
def test_no_fusion_when_too_few_instructions(self):
|
|
1383
|
+
tensor_insts = self._create_test_insts([
|
|
1384
|
+
self._get_test_instruction(_QTransf.QUANTIZE_TENSOR),
|
|
1385
|
+
])
|
|
1386
|
+
self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
|
|
1387
|
+
tensor_insts
|
|
1388
|
+
)
|
|
1389
|
+
self.assertLen(tensor_insts.instructions, 1)
|
|
1390
|
+
|
|
1391
|
+
def test_no_fusion_when_too_many_instructions(self):
|
|
1392
|
+
tensor_insts = self._create_test_insts([
|
|
1393
|
+
self._get_test_instruction(_QTransf.QUANTIZE_TENSOR),
|
|
1394
|
+
self._get_test_instruction(_QTransf.ADD_QUANTIZE),
|
|
1395
|
+
self._get_test_instruction(_QTransf.ADD_DEQUANTIZE),
|
|
1396
|
+
])
|
|
1397
|
+
self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
|
|
1398
|
+
tensor_insts
|
|
1399
|
+
)
|
|
1400
|
+
self.assertLen(tensor_insts.instructions, 3)
|
|
1401
|
+
|
|
1402
|
+
def test_no_fusion_when_invalid_transformation_pair(self):
|
|
1403
|
+
tensor_insts = self._create_test_insts([
|
|
1404
|
+
self._get_test_instruction(_QTransf.ADD_DEQUANTIZE),
|
|
1405
|
+
self._get_test_instruction(_QTransf.ADD_QUANTIZE),
|
|
1406
|
+
])
|
|
1407
|
+
self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
|
|
1408
|
+
tensor_insts
|
|
1409
|
+
)
|
|
1410
|
+
self.assertLen(tensor_insts.instructions, 2)
|
|
1411
|
+
|
|
1412
|
+
def test_no_fusion_when_consumers_mismatch(self):
|
|
1413
|
+
tensor_insts = self._create_test_insts([
|
|
1414
|
+
self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, consumers=[0]),
|
|
1415
|
+
self._get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[1]),
|
|
1416
|
+
])
|
|
1417
|
+
self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
|
|
1418
|
+
tensor_insts
|
|
1419
|
+
)
|
|
1420
|
+
self.assertLen(tensor_insts.instructions, 2)
|
|
1421
|
+
|
|
1422
|
+
def test_no_fusion_when_no_producer(self):
|
|
1423
|
+
producer = -1
|
|
1424
|
+
tensor_insts = self._create_test_insts([
|
|
1425
|
+
self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, producer),
|
|
1426
|
+
self._get_test_instruction(_QTransf.ADD_QUANTIZE, producer),
|
|
1427
|
+
])
|
|
1428
|
+
self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
|
|
1429
|
+
tensor_insts
|
|
1430
|
+
)
|
|
1431
|
+
self.assertLen(tensor_insts.instructions, 2)
|
|
1432
|
+
|
|
1433
|
+
def test_no_fusion_when_quant_params_are_incompatible(self):
|
|
1434
|
+
params_8_bits = qtyping.UniformQuantParams(
|
|
1435
|
+
8, None, np.array([1]), np.array([0])
|
|
1436
|
+
)
|
|
1437
|
+
params_16_bits = qtyping.UniformQuantParams(
|
|
1438
|
+
16, None, np.array([1]), np.array([0])
|
|
1439
|
+
)
|
|
1440
|
+
tensor_insts = self._create_test_insts([
|
|
1441
|
+
self._get_test_instruction(
|
|
1442
|
+
_QTransf.QUANTIZE_TENSOR, qparams=params_8_bits
|
|
1443
|
+
),
|
|
1444
|
+
self._get_test_instruction(
|
|
1445
|
+
_QTransf.ADD_QUANTIZE, qparams=params_16_bits
|
|
1446
|
+
),
|
|
1447
|
+
])
|
|
1448
|
+
self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
|
|
1449
|
+
tensor_insts
|
|
1450
|
+
)
|
|
1451
|
+
self.assertLen(tensor_insts.instructions, 2)
|
|
1452
|
+
|
|
1453
|
+
def test_no_fusion_when_producer_constrained(self):
|
|
1454
|
+
# Reshape op (op index 2) has same as input scale constraint.
|
|
1455
|
+
tensor_insts = self._create_test_insts([
|
|
1456
|
+
self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, producer=2),
|
|
1457
|
+
self._get_test_instruction(_QTransf.ADD_QUANTIZE, producer=2),
|
|
1458
|
+
])
|
|
1459
|
+
self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
|
|
1460
|
+
tensor_insts
|
|
1461
|
+
)
|
|
1462
|
+
self.assertLen(tensor_insts.instructions, 2)
|
|
1463
|
+
|
|
1464
|
+
def test_fusion_succeeds(self):
|
|
1465
|
+
producer = 0
|
|
1466
|
+
consumers = [1]
|
|
1467
|
+
params_0 = qtyping.UniformQuantParams(
|
|
1468
|
+
num_bits=8,
|
|
1469
|
+
quantized_dimension=None,
|
|
1470
|
+
scale=np.array([1]),
|
|
1471
|
+
zero_point=np.array([0]),
|
|
1472
|
+
)
|
|
1473
|
+
params_1 = qtyping.UniformQuantParams(
|
|
1474
|
+
num_bits=8,
|
|
1475
|
+
quantized_dimension=None,
|
|
1476
|
+
scale=np.array([2]),
|
|
1477
|
+
zero_point=np.array([1]),
|
|
1478
|
+
)
|
|
1479
|
+
inst_0 = self._get_test_instruction(
|
|
1480
|
+
_QTransf.QUANTIZE_TENSOR, producer, consumers, params_0
|
|
1481
|
+
)
|
|
1482
|
+
inst_1 = self._get_test_instruction(
|
|
1483
|
+
_QTransf.ADD_QUANTIZE, producer, consumers, params_1
|
|
1484
|
+
)
|
|
1485
|
+
tensor_insts = self._create_test_insts([inst_0, inst_1])
|
|
1486
|
+
self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
|
|
1487
|
+
tensor_insts
|
|
1488
|
+
)
|
|
1489
|
+
|
|
1490
|
+
self.assertLen(tensor_insts.instructions, 1)
|
|
1491
|
+
result_inst = tensor_insts.instructions[0]
|
|
1492
|
+
self.assertEqual(result_inst.transformation, _QTransf.QUANTIZE_TENSOR)
|
|
1493
|
+
|
|
1494
|
+
result_params = result_inst.parameters
|
|
1495
|
+
# Explicitly narrow the type for pytype.
|
|
1496
|
+
if not isinstance(result_params, qtyping.UniformQuantParams):
|
|
1497
|
+
self.fail("Fused instruction parameters are not UniformQuantParams")
|
|
1498
|
+
|
|
1499
|
+
self.assertEqual(result_params.scale, params_1.scale)
|
|
1500
|
+
self.assertEqual(result_params.zero_point, params_1.zero_point)
|
|
1501
|
+
|
|
1502
|
+
|
|
1367
1503
|
if __name__ == "__main__":
|
|
1368
1504
|
googletest.main()
|
|
@@ -24,7 +24,8 @@ from ai_edge_quantizer import qtyping
|
|
|
24
24
|
from ai_edge_quantizer.transformations import dequant_insert
|
|
25
25
|
from ai_edge_quantizer.transformations import duplicate_buffer
|
|
26
26
|
from ai_edge_quantizer.transformations import duplicate_tensor
|
|
27
|
-
from ai_edge_quantizer.transformations import
|
|
27
|
+
from ai_edge_quantizer.transformations import insert_decomposed_hadamard_rotation
|
|
28
|
+
from ai_edge_quantizer.transformations import insert_hadamard_rotation
|
|
28
29
|
from ai_edge_quantizer.transformations import quant_insert
|
|
29
30
|
from ai_edge_quantizer.transformations import quantize_tensor
|
|
30
31
|
from ai_edge_quantizer.transformations import transformation_utils
|
|
@@ -71,7 +72,7 @@ class TransformationPerformer:
|
|
|
71
72
|
quantize_tensor.quantize_tensor
|
|
72
73
|
),
|
|
73
74
|
qtyping.QuantTransformation.EMULATED_SUBCHANNEL: (
|
|
74
|
-
|
|
75
|
+
transformation_utils.raise_deprecated_error
|
|
75
76
|
),
|
|
76
77
|
qtyping.QuantTransformation.ADD_QUANTIZE: quant_insert.insert_quant,
|
|
77
78
|
qtyping.QuantTransformation.DUPLICATE_BUFFER: (
|
|
@@ -80,6 +81,12 @@ class TransformationPerformer:
|
|
|
80
81
|
qtyping.QuantTransformation.DUPLICATE_TENSOR: (
|
|
81
82
|
duplicate_tensor.duplicate_tensor
|
|
82
83
|
),
|
|
84
|
+
qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION: (
|
|
85
|
+
insert_hadamard_rotation.insert_hadamard_rotation
|
|
86
|
+
),
|
|
87
|
+
qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION: (
|
|
88
|
+
insert_decomposed_hadamard_rotation.insert_decomposed_hadamard_rotation
|
|
89
|
+
),
|
|
83
90
|
}
|
|
84
91
|
# transformations are separated into two categories:
|
|
85
92
|
# op_insertion_transformations are transformations that only insert ops
|
|
@@ -91,6 +98,8 @@ class TransformationPerformer:
|
|
|
91
98
|
qtyping.QuantTransformation.ADD_QUANTIZE,
|
|
92
99
|
qtyping.QuantTransformation.DUPLICATE_BUFFER,
|
|
93
100
|
qtyping.QuantTransformation.DUPLICATE_TENSOR,
|
|
101
|
+
qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
|
|
102
|
+
qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION,
|
|
94
103
|
])
|
|
95
104
|
self._op_replacement_transformations = set(
|
|
96
105
|
[qtyping.QuantTransformation.EMULATED_SUBCHANNEL]
|
|
@@ -180,6 +189,38 @@ class TransformationPerformer:
|
|
|
180
189
|
)
|
|
181
190
|
transformation.tensor_id = trans_info.output_tensor_id
|
|
182
191
|
|
|
192
|
+
def _get_updated_producer_id(
|
|
193
|
+
self, original_producer_id: int, subgraph_id: int
|
|
194
|
+
) -> int:
|
|
195
|
+
"""Update the producer of a transformation instruction."""
|
|
196
|
+
if original_producer_id is None or original_producer_id < 0:
|
|
197
|
+
producer = -1
|
|
198
|
+
elif original_producer_id < len(self._original_op_id_map[subgraph_id]):
|
|
199
|
+
producer = self._original_op_id_map[subgraph_id][original_producer_id]
|
|
200
|
+
else:
|
|
201
|
+
# If the producer id is not in the original op map, it's an added op,
|
|
202
|
+
# go to the added op map to find the producer.
|
|
203
|
+
producer = self._added_op_id_map[subgraph_id][
|
|
204
|
+
original_producer_id - len(self._original_op_id_map[subgraph_id])
|
|
205
|
+
]
|
|
206
|
+
return producer
|
|
207
|
+
|
|
208
|
+
def _get_updated_consumer_ids(
|
|
209
|
+
self,
|
|
210
|
+
original_consumer_ids: list[int],
|
|
211
|
+
subgraph_id: int,
|
|
212
|
+
) -> list[int]:
|
|
213
|
+
"""Update the consumers of a transformation instruction."""
|
|
214
|
+
consumers = []
|
|
215
|
+
for original_op_id in original_consumer_ids:
|
|
216
|
+
new_consumer_id = (
|
|
217
|
+
-1
|
|
218
|
+
if original_op_id == -1
|
|
219
|
+
else self._original_op_id_map[subgraph_id][original_op_id]
|
|
220
|
+
)
|
|
221
|
+
consumers.append(new_consumer_id)
|
|
222
|
+
return consumers
|
|
223
|
+
|
|
183
224
|
def _apply_single_transformation(
|
|
184
225
|
self,
|
|
185
226
|
transformation_inst: qtyping.TensorTransformationInsts,
|
|
@@ -198,28 +239,12 @@ class TransformationPerformer:
|
|
|
198
239
|
None, update the transformation_inst & tflite_model in place
|
|
199
240
|
"""
|
|
200
241
|
instruction = transformation_inst.instructions[transformation_index]
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
instruction.producer
|
|
208
|
-
]
|
|
209
|
-
else:
|
|
210
|
-
# if the producer id is not in the original op map, it's an added op,
|
|
211
|
-
# go the corresponding new maps
|
|
212
|
-
producer = self._added_op_id_map[transformation_inst.subgraph_id][
|
|
213
|
-
instruction.producer
|
|
214
|
-
- len(self._original_op_id_map[transformation_inst.subgraph_id])
|
|
215
|
-
]
|
|
216
|
-
consumers = []
|
|
217
|
-
for original_op_id in instruction.consumers:
|
|
218
|
-
consumers.append(
|
|
219
|
-
self._original_op_id_map[transformation_inst.subgraph_id][
|
|
220
|
-
original_op_id
|
|
221
|
-
]
|
|
222
|
-
)
|
|
242
|
+
producer = self._get_updated_producer_id(
|
|
243
|
+
instruction.producer, transformation_inst.subgraph_id
|
|
244
|
+
)
|
|
245
|
+
consumers = self._get_updated_consumer_ids(
|
|
246
|
+
instruction.consumers, transformation_inst.subgraph_id
|
|
247
|
+
)
|
|
223
248
|
trans_info = self._transformation_registration[instruction.transformation](
|
|
224
249
|
transformation_utils.TransformationInput(
|
|
225
250
|
instruction.tensor_id,
|
|
@@ -239,7 +264,12 @@ class TransformationPerformer:
|
|
|
239
264
|
)
|
|
240
265
|
self._update_op_id_map(
|
|
241
266
|
transformation_inst.subgraph_id,
|
|
242
|
-
|
|
267
|
+
# The added op must be right before the most immediate consumer, unless
|
|
268
|
+
# the consumer is the graph output (id=-1), then use the producer's
|
|
269
|
+
# index instead.
|
|
270
|
+
min(instruction.consumers)
|
|
271
|
+
if min(instruction.consumers) >= 0
|
|
272
|
+
else instruction.producer + 1,
|
|
243
273
|
trans_info.num_ops_added,
|
|
244
274
|
)
|
|
245
275
|
|