ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +224 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
- ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
- ai_edge_quantizer/calibrator.py +58 -94
- ai_edge_quantizer/calibrator_test.py +5 -74
- ai_edge_quantizer/default_policy.py +108 -16
- ai_edge_quantizer/model_modifier.py +132 -8
- ai_edge_quantizer/model_modifier_test.py +81 -1
- ai_edge_quantizer/model_validator.py +38 -10
- ai_edge_quantizer/model_validator_test.py +2 -1
- ai_edge_quantizer/params_generator.py +230 -47
- ai_edge_quantizer/params_generator_test.py +366 -261
- ai_edge_quantizer/qtyping.py +92 -6
- ai_edge_quantizer/quantizer.py +167 -23
- ai_edge_quantizer/quantizer_test.py +288 -26
- ai_edge_quantizer/recipe.py +156 -21
- ai_edge_quantizer/recipe_manager.py +158 -1
- ai_edge_quantizer/recipe_manager_test.py +146 -32
- ai_edge_quantizer/recipe_test.py +93 -17
- ai_edge_quantizer/transformation_instruction_generator.py +313 -46
- ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
- ai_edge_quantizer/transformation_performer.py +112 -58
- ai_edge_quantizer/transformation_performer_test.py +176 -4
- ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
- ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
- ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
- ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
- ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
- ai_edge_quantizer/transformations/transformation_utils.py +157 -11
- ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
- ai_edge_quantizer/utils/calibration_utils.py +263 -1
- ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
- ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
- ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
- ai_edge_quantizer/utils/test_utils.py +191 -58
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
- ai_edge_quantizer/utils/validation_utils.py +114 -4
- ai_edge_quantizer/utils/validation_utils_test.py +80 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
- ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
- ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
- ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/transformation_instruction_generator.py:

```diff
@@ -23,10 +23,16 @@ from collections.abc import Iterator
 import dataclasses
 from typing import Optional
 from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.algorithms.utils import common_utils
+from ai_edge_quantizer.utils import constrained_ops_utils
 from ai_edge_quantizer.utils import tfl_flatbuffer_utils
 from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
 
 
+_OpQuantConstraint = common_utils.OpQuantConstraint
+_QuantTransformation = qtyping.QuantTransformation
+
+
 # When a tensor has no producer, we'll assign -1 to the producer field
 # When a tensor is a graph output, we'll also include a -1 in the consumer list
 def check_horizontal_optimization(
```
```diff
@@ -48,6 +54,15 @@ def check_horizontal_optimization(
   Returns:
     True if the two transformations can be merged, False otherwise
   """
+  if (
+      isinstance(param1.parameters, qtyping.UniformQuantParams)
+      and param1.parameters.hadamard is not None
+  ):
+    if (
+        isinstance(param2.parameters, qtyping.UniformQuantParams)
+        and param2.parameters.hadamard is not None
+    ):
+      return True
   return (
       param1.parameters == param2.parameters
       and len(param1.transformations) > index
```
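The early exit added above lets horizontal optimization merge two consumer transformations whenever both parameter sets carry Hadamard rotation metadata, even if the parameters would not compare equal. A minimal sketch of that predicate, with a hypothetical `SimpleParams` standing in for `qtyping.UniformQuantParams` (not the library's API):

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass
class SimpleParams:
  num_bits: int
  hadamard: Optional[object] = None  # rotation metadata, if present


def can_merge(p1: SimpleParams, p2: SimpleParams) -> bool:
  # Both consumers carry Hadamard metadata: treat as mergeable even if the
  # metadata objects themselves do not compare equal.
  if p1.hadamard is not None and p2.hadamard is not None:
    return True
  # Otherwise fall back to strict parameter equality.
  return p1 == p2


assert can_merge(SimpleParams(8, object()), SimpleParams(8, object()))
assert can_merge(SimpleParams(8), SimpleParams(8))
assert not can_merge(SimpleParams(8), SimpleParams(4))
```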
```diff
@@ -162,6 +177,16 @@ class TransformationInstructionsGenerator:
     else:
       self.flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
     self._create_tensor_name_to_graph_info_map()
+    self._same_as_input_scale_ops = (
+        constrained_ops_utils.get_constrained_op_list(
+            _OpQuantConstraint.SAME_AS_INPUT_SCALE
+        )
+    )
+    self._same_as_output_scale_ops = (
+        constrained_ops_utils.get_constrained_op_list(
+            _OpQuantConstraint.SAME_AS_OUTPUT_SCALE
+        )
+    )
 
   @dataclasses.dataclass(frozen=True)
   class TensorGraphInfo:
```
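These cached lists feed the `_is_op_constrained` check introduced later in this diff: a tensor produced by a scale-constrained op must keep its producer's scale, so requantization fusion skips it. A hedged sketch of the idea; the op names below are illustrative examples, not the authoritative lists returned by `constrained_ops_utils.get_constrained_op_list`:

```python
# Illustrative stand-ins for the cached constraint lists.
SAME_AS_INPUT_SCALE_OPS = {"RESHAPE", "TRANSPOSE", "SLICE"}
SAME_AS_OUTPUT_SCALE_OPS = {"CONCATENATION"}


def is_op_constrained(op_name: str) -> bool:
  # An op in either set must share scale with its input/output, so tensors
  # it produces are excluded from the requantization fusion shown below.
  return (
      op_name in SAME_AS_INPUT_SCALE_OPS
      or op_name in SAME_AS_OUTPUT_SCALE_OPS
  )


assert is_op_constrained("RESHAPE")
assert not is_op_constrained("FULLY_CONNECTED")
```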
```diff
@@ -183,11 +208,13 @@ class TransformationInstructionsGenerator:
       A tuple of tensor_name and TensorGraphInfo.
     """
     for tensor_id, tensor in enumerate(subgraph.tensors):
-      consumers = [
-          op_id
-          for op_id, op in enumerate(subgraph.operators)
-          if tensor_id in op.inputs
-      ]
+      consumers = []
+      for op_id, op in enumerate(subgraph.operators):
+        # Some ops may use the same input tensor multiple times,
+        # and we should handle each time independently.
+        for op_input in op.inputs:
+          if op_input == tensor_id:
+            consumers.append(op_id)
       producer = -1
       for op_id, op in enumerate(subgraph.operators):
         if tensor_id in op.outputs:
```
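The rewritten consumer scan records an op once per use of the tensor, where the old membership test recorded it at most once. A small runnable illustration, with plain `(op_id, inputs)` tuples standing in for flatbuffer operators:

```python
# Toy illustration of the consumer-scan change above.
ops = [
    (0, [5, 5]),  # op 0 reads tensor 5 through two inputs, e.g. MUL(x, x)
    (1, [5, 7]),  # op 1 reads tensor 5 once
]
tensor_id = 5

# Old behavior: a membership test collapses repeated uses into one entry.
old_consumers = [op_id for op_id, inputs in ops if tensor_id in inputs]
assert old_consumers == [0, 1]

# New behavior: one consumer entry per use of the tensor.
new_consumers = []
for op_id, inputs in ops:
  for op_input in inputs:
    if op_input == tensor_id:
      new_consumers.append(op_id)
assert new_consumers == [0, 0, 1]
```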
```diff
@@ -454,53 +481,181 @@ class TransformationInstructionsGenerator:
       transformations.insert(0, producer_trans_rule)
     return transformations
 
+  def _remove_last_tensor_duplication(
+      self, tensor_trans_insts: qtyping.TensorTransformationInsts
+  ) -> None:
+    """Remove the last tensor duplication so the original tensor can be reused."""
+    instructions = tensor_trans_insts.instructions
+    if not instructions:
+      return
+    for i in range(len(instructions) - 1, -1, -1):
+      if (
+          instructions[i].transformation
+          == _QuantTransformation.DUPLICATE_TENSOR
+      ):
+        instructions.pop(i)
+        return
+
+  def _remove_unnecessary_buffer_duplication(
+      self, tensor_trans_insts: qtyping.TensorTransformationInsts
+  ) -> None:
+    """Remove buffer duplications that come after a tensor duplication.
+
+    When a tensor is duplicated, a new buffer is created for it. Therefore, a
+    buffer duplication transformation that comes after it is unnecessary.
+
+    Args:
+      tensor_trans_insts: Transformation instructions for a tensor.
+    """
+    instructions = tensor_trans_insts.instructions
+    if not instructions:
+      return
+
+    # Find all consumers that have a tensor duplication.
+    consumers_with_tensor_duplication = set()
+    for instr in instructions:
+      if instr.transformation == _QuantTransformation.DUPLICATE_TENSOR:
+        consumers_with_tensor_duplication.update(instr.consumers)
+    if not consumers_with_tensor_duplication:
+      return
+
+    # Remove a buffer duplication that comes with a tensor duplication.
+    for i in range(len(instructions) - 1, -1, -1):
+      instr = instructions[i]
+      if (
+          instr.transformation == _QuantTransformation.DUPLICATE_BUFFER
+          and consumers_with_tensor_duplication.issuperset(instr.consumers)
+      ):
+        instructions.pop(i)
+
+  def _is_valid_quantize_requantize_pair(
+      self,
+      instr_0: qtyping.TransformationInst,
+      instr_1: qtyping.TransformationInst,
+  ) -> bool:
+    """Checks if the two instructions form a valid quantize and requantize pair."""
+    return (
+        instr_0.transformation == _QuantTransformation.QUANTIZE_TENSOR
+        and instr_1.transformation == _QuantTransformation.ADD_QUANTIZE
+        and instr_0.consumers == instr_1.consumers
+    )
+
+  def _is_op_constrained(
+      self, subgraph_id: int, op_index: int
+  ) -> bool:
+    """Checks if the op has same as input or output scale constraints."""
+    op_name = tfl_flatbuffer_utils.get_op_name_by_index(
+        self.flatbuffer_model, subgraph_id, op_index
+    )
+    return (
+        op_name in self._same_as_input_scale_ops
+        or op_name in self._same_as_output_scale_ops
+    )
+
+  def _are_quant_params_compatible(
+      self,
+      params_0: qtyping.UniformQuantParams,
+      params_1: qtyping.UniformQuantParams,
+  ) -> bool:
+    """Checks if quant params are the same except for the scale and zero point."""
+    ignore_set = {"scale", "zero_point"}
+    for field_info in dataclasses.fields(qtyping.UniformQuantParams):
+      field_name = field_info.name
+      if field_name in ignore_set:
+        continue
+      if getattr(params_0, field_name) != getattr(params_1, field_name):
+        return False
+    return True
+
+  def _eliminate_requantization_for_nonconstrained_provider(
+      self, tensor_trans_insts: qtyping.TensorTransformationInsts
+  ) -> None:
+    """Removes requantization for tensors with a non-constrained provider.
+
+    Fuses [QUANTIZE_TENSOR, ADD_QUANTIZE] instructions when a tensor has a
+    provider op without same as input/output scale constraints. Quant params
+    from the second instruction are copied to the first one and ADD_QUANTIZE
+    is removed.
+
+    Args:
+      tensor_trans_insts: Transformation instructions for a tensor.
+    """
+    instructions = tensor_trans_insts.instructions
+    if instructions is None or len(instructions) != 2:
+      return
+
+    instr_0, instr_1 = instructions
+    params_0 = instr_0.parameters
+    params_1 = instr_1.parameters
+    producer_op_index = instr_0.producer
+    if (
+        not isinstance(params_0, qtyping.UniformQuantParams)
+        or not isinstance(params_1, qtyping.UniformQuantParams)
+        or not self._is_valid_quantize_requantize_pair(instr_0, instr_1)
+        or not self._are_quant_params_compatible(params_0, params_1)
+        # To avoid fusion when subgraph inputs connected to the main subgraph
+        # (e.g. while_body), we skip all tensors with no producer.
+        or producer_op_index == -1
+        # Can't apply fusion to tensors with a constrained producer since that
+        # will break the constraint.
+        or self._is_op_constrained(
+            tensor_trans_insts.subgraph_id, producer_op_index
+        )
+    ):
+      return
+
+    # Fuse the quantize and requantize.
+    instr_0.parameters = dataclasses.replace(
+        params_0, scale=params_1.scale, zero_point=params_1.zero_point
+    )
+    # Remove the requantize instruction.
+    instructions.pop(1)
+
   def _quant_params_to_transformation_insts(
       self,
       param: qtyping.TensorTransformationParams,
   ) -> qtyping.TensorTransformationInsts:
-    """
+    """Convert single tensor quant params to transformation instructions.
 
     Args:
-      param:
+      param: Quantization parameters of a tensor in the graph.
 
     Returns:
-
+      Transformations to be applied to the given tensor.
     """
-    #
+    # Setup the structure.
     tensor_info = self._tensor_name_to_graph_info[param.tensor_name]
     tensor_trans_insts = qtyping.TensorTransformationInsts(
         param.tensor_name, tensor_info.subgraph_id, []
     )
 
-    #
-    consumer_group = self._group_consumer_transformations(param)
-    # at this point, starting from index 1 of consumer_group, we're having sets
-    # that represents transformations that can be grouped together
-    transformations_available_for_vertical_optimization = (
-        self._produce_transformation_for_vertical_opt(consumer_group, param)
-    )
-    other_consumer_transformations = (
-        self._produce_consumer_transformations_unavailable_for_vertical_opt(
-            consumer_group, param
-        )
-    )
-
+    # Add all producer rules.
     transformations = []
-
-    producer_params = param.producer
-    if producer_params:
-      for transformation in producer_params.transformations:
+    if param.producer:
+      for transformation in param.producer.transformations:
         transformations.append(
             qtyping.TransformationInst(
                 transformation,
                 tensor_info.tensor_id,
                 tensor_info.producer,
                 tensor_info.consumers,
-                producer_params.parameters,
+                param.producer.parameters,
             )
         )
 
-    #
+    # Horizontal optimization.
+    consumer_group = self._group_consumer_transformations(param)
+    # At this point, starting from index 1 of consumer_group, we're having sets
+    # that represent transformations that can be grouped together.
+    transformations_available_for_vertical_optimization = (
+        self._produce_transformation_for_vertical_opt(consumer_group, param)
+    )
+    other_consumer_transformations = (
+        self._produce_consumer_transformations_unavailable_for_vertical_opt(
+            consumer_group, param
+        )
+    )
+    # Apply vertical optimization.
     last_producer_rule_idx = len(transformations) - 1
     if last_producer_rule_idx >= 0:
       transformations += self._apply_vertical_optimization(
```
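Taken together, these helpers implement a peephole fusion: when a tensor's only two instructions are a QUANTIZE_TENSOR followed by an ADD_QUANTIZE aimed at the same consumers, and their parameters differ only in scale and zero point, the requantize collapses into the quantize. A self-contained sketch under those assumptions; the `Params` and `Inst` dataclasses here are hypothetical stand-ins, not `qtyping`'s types:

```python
import dataclasses


@dataclasses.dataclass
class Params:
  num_bits: int
  scale: float
  zero_point: int


@dataclasses.dataclass
class Inst:
  kind: str  # "QUANTIZE_TENSOR" or "ADD_QUANTIZE"
  consumers: tuple[int, ...]
  params: Params


def fuse_requantize(insts: list[Inst]) -> None:
  if len(insts) != 2:
    return
  first, second = insts
  if (
      first.kind == "QUANTIZE_TENSOR"
      and second.kind == "ADD_QUANTIZE"
      and first.consumers == second.consumers
      and first.params.num_bits == second.params.num_bits
  ):
    # Adopt the downstream scale/zero point, then drop the requantize.
    first.params = dataclasses.replace(
        first.params,
        scale=second.params.scale,
        zero_point=second.params.zero_point,
    )
    insts.pop(1)


insts = [
    Inst("QUANTIZE_TENSOR", (1,), Params(8, 0.5, 0)),
    Inst("ADD_QUANTIZE", (1,), Params(8, 0.25, 3)),
]
fuse_requantize(insts)
assert len(insts) == 1 and insts[0].params.scale == 0.25
```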
```diff
@@ -509,30 +664,127 @@ class TransformationInstructionsGenerator:
       )
     else:
       transformations += transformations_available_for_vertical_optimization
-    # Adding other consumers rules
+    # Adding other consumers rules.
    transformations += other_consumer_transformations
     tensor_trans_insts.instructions = transformations
+
+    # Now, when all optimizations are done, we can remove the last tensor
+    # duplication instruction, so the original tensor can be reused.
+    self._remove_last_tensor_duplication(tensor_trans_insts)
+    # With the tensor duplication instructions finalized, we can remove
+    # unnecessary buffer duplications applied to the same duplicated tensors.
+    # This is not a part of a vertical optimization because vertical
+    # optimization only works between producers & consumers, and this is
+    # between consumers only. Also this can't be done during the params
+    # generation because removing the last tensor duplication has to happen
+    # first.
+    self._remove_unnecessary_buffer_duplication(tensor_trans_insts)
+
     # Check the generated transformation instructions are valid, the function
-    # will raise an error if the instructions are not valid
+    # will raise an error if the instructions are not valid.
     self._check_tensor_transformation_instructions_valid(tensor_trans_insts)
 
+    # Remove unnecessary [QUANTIZE_TENSOR, ADD_QUANTIZE] pairs for tensors with
+    # providers without same as input/output scale constraints.
+    self._eliminate_requantization_for_nonconstrained_provider(
+        tensor_trans_insts
+    )
+
     return tensor_trans_insts
 
-  def _check_tensor_transformation_instructions_valid(
-      self, instructions: qtyping.TensorTransformationInsts
-  ):
-    """Check if the tensor transformation instructions are valid.
+  def _split_instructions_by_tensor_duplication(
+      self,
+      instructions: qtyping.TensorTransformationInsts,
+  ) -> list[list[qtyping.TransformationInst]]:
+    """Split the instructions into subsets by tensor duplication.
+
+    Splits the instructions into subsets based on which tensor (original or one
+    of duplicated ones) they will be applied to.
+
+    The first subset is for the original tensor. The following subsets are for
+    the duplicated tensors. The order of instructions in each subset is
+    preserved.
+
+    Enforced constraints for each duplicated tensor's instructions subset:
+      1. The first instruction must be a `DUPLICATE_TENSOR` one.
+      2. No other `DUPLICATE_TENSOR` instructions can be present.
+
+    For the following instructions:
+    [
+        (transformation=DUPLICATE_TENSOR, consumers=[1, 2, 3]),
+        (transformation=DUPLICATE_TENSOR, consumers=[4]),
+        (transformation=T1, consumers=[1, 2]),
+        (transformation=T2, consumers=[3]),
+        (transformation=T3, consumers=[4]),
+        (transformation=T4, consumers=[5])
+    ]
+
+    `instruction_subsets` will be:
+    [
+        [(transformation=T4, consumers=[5])],
+        [
+            (transformation=DUPLICATE_TENSOR, consumers=[1, 2, 3]),
+            (transformation=T1, consumers=[1, 2]),
+            (transformation=T2, consumers=[3])
+        ],
+        [
+            (transformation=DUPLICATE_TENSOR, consumers=[4]),
+            (transformation=T3, consumers=[4])
+        ]
+    ]
 
     Args:
       instructions: Transformation instructions for a tensor.
 
+    Returns:
+      A list of subsets of transformation instructions, where the first subset
+      is for the original tensor, and the following subsets are for the
+      duplicated tensors.
+
     Raises:
-      ValueError: If the instructions are not valid.
+      ValueError: If DUPLICATE_TENSOR is found and it's not the first
+        transformation for its consumers.
+    """
+    original_tensor_subset_idx = 0
+    instruction_subsets = [[]]
+    consumer_to_subset_idx = {}
+    for instruction in instructions.instructions:
+      if instruction.transformation == _QuantTransformation.DUPLICATE_TENSOR:
+        instruction_subsets.append([instruction])
+        subset_idx = len(instruction_subsets) - 1
+        for consumer in instruction.consumers:
+          if consumer in consumer_to_subset_idx:
+            raise ValueError(
+                f"Tensor {instructions.tensor_name}: duplicate tensor should"
+                " be the first instruction for its consumers."
+            )
+          else:
+            consumer_to_subset_idx[consumer] = subset_idx
+      else:
+        first_consumer = instruction.consumers[0]
+        if first_consumer not in consumer_to_subset_idx:
+          consumer_to_subset_idx[first_consumer] = original_tensor_subset_idx
+        subset_idx = consumer_to_subset_idx[first_consumer]
+        instruction_subsets[subset_idx].append(instruction)
+
+    return instruction_subsets
+
+  def _check_subset_of_tensor_transformation_instructions_valid(
+      self,
+      instructions: Optional[list[qtyping.TransformationInst]],
+      tensor_name: str,
+  ):
+    """Check if a subset of tensor transformation instructions is valid.
+
+    Args:
+      instructions: A subset of transformation instructions for a tensor.
+      tensor_name: The name of the tensor.
+
+    Raises:
+      ValueError: If the subset of instructions is not valid.
     """
     is_tensor_unquantized = False
     is_tensor_quantized = False
-    is_operator_emulated = False
-    for instruction in instructions.instructions:
+    for instruction in instructions:
       transform_type = instruction.transformation
       if transform_type == qtyping.QuantTransformation.NO_QUANTIZE:
         is_tensor_unquantized = True
```
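The splitting logic above can be exercised independently of the quantizer. A runnable sketch that mirrors the docstring example, using hypothetical `(kind, consumers)` tuples in place of `TransformationInst`:

```python
def split_by_duplication(instructions):
  subsets = [[]]  # subset 0 belongs to the original tensor
  consumer_to_subset = {}
  for inst in instructions:
    kind, consumers = inst
    if kind == "DUPLICATE_TENSOR":
      # Each duplication opens a new subset owning its consumers.
      subsets.append([inst])
      idx = len(subsets) - 1
      for c in consumers:
        if c in consumer_to_subset:
          raise ValueError("duplicate tensor must come first for a consumer")
        consumer_to_subset[c] = idx
    else:
      # Route by the first consumer; unseen consumers use the original tensor.
      first = consumers[0]
      idx = consumer_to_subset.setdefault(first, 0)
      subsets[idx].append(inst)
  return subsets


insts = [
    ("DUPLICATE_TENSOR", [1, 2, 3]),
    ("DUPLICATE_TENSOR", [4]),
    ("T1", [1, 2]),
    ("T2", [3]),
    ("T3", [4]),
    ("T4", [5]),
]
subsets = split_by_duplication(insts)
assert subsets[0] == [("T4", [5])]
assert subsets[1][0] == ("DUPLICATE_TENSOR", [1, 2, 3])
assert subsets[2] == [("DUPLICATE_TENSOR", [4]), ("T3", [4])]
```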
```diff
@@ -541,18 +793,33 @@ class TransformationInstructionsGenerator:
           or transform_type == qtyping.QuantTransformation.ADD_DEQUANTIZE
       ):
         is_tensor_quantized = True
-      elif transform_type == qtyping.QuantTransformation.EMULATED_SUBCHANNEL:
-        is_operator_emulated = True
     if is_tensor_unquantized and is_tensor_quantized:
       raise ValueError(
-          "Tensor %s can not be both quantized and unquantized"
-          % instructions.tensor_name
+          "Tensor %s can not be both quantized and unquantized" % tensor_name
       )
-    if is_operator_emulated and len(instructions.instructions) > 1:
-      raise ValueError(
-          "Tensor %s: emulated subchannel transformation can not be"
-          " combined with other transformations"
-          % instructions.tensor_name
+
+  def _check_tensor_transformation_instructions_valid(
+      self,
+      instructions: qtyping.TensorTransformationInsts,
+  ):
+    """Check if the tensor transformation instructions are valid.
+
+    Args:
+      instructions: Transformation instructions for a tensor.
+
+    Raises:
+      ValueError: If the instructions are not valid.
+    """
+    # Split the instructions into subsets based on which tensor (original or
+    # one of duplicated ones) they will be applied to.
+    instruction_subsets = self._split_instructions_by_tensor_duplication(
+        instructions
+    )
+    # Check that each subset of instructions is valid.
+    for instruction_subset in instruction_subsets:
+      self._check_subset_of_tensor_transformation_instructions_valid(
+          instruction_subset,
+          instructions.tensor_name,
       )
 
   def quant_params_to_transformation_insts(
```
|