ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. ai_edge_quantizer/algorithm_manager.py +224 -0
  2. ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
  3. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
  5. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  6. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
  7. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
  8. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
  13. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
  14. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
  15. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
  16. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
  17. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
  18. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
  19. ai_edge_quantizer/calibrator.py +58 -94
  20. ai_edge_quantizer/calibrator_test.py +5 -74
  21. ai_edge_quantizer/default_policy.py +108 -16
  22. ai_edge_quantizer/model_modifier.py +132 -8
  23. ai_edge_quantizer/model_modifier_test.py +81 -1
  24. ai_edge_quantizer/model_validator.py +38 -10
  25. ai_edge_quantizer/model_validator_test.py +2 -1
  26. ai_edge_quantizer/params_generator.py +230 -47
  27. ai_edge_quantizer/params_generator_test.py +366 -261
  28. ai_edge_quantizer/qtyping.py +92 -6
  29. ai_edge_quantizer/quantizer.py +167 -23
  30. ai_edge_quantizer/quantizer_test.py +288 -26
  31. ai_edge_quantizer/recipe.py +156 -21
  32. ai_edge_quantizer/recipe_manager.py +158 -1
  33. ai_edge_quantizer/recipe_manager_test.py +146 -32
  34. ai_edge_quantizer/recipe_test.py +93 -17
  35. ai_edge_quantizer/transformation_instruction_generator.py +313 -46
  36. ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
  37. ai_edge_quantizer/transformation_performer.py +112 -58
  38. ai_edge_quantizer/transformation_performer_test.py +176 -4
  39. ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
  40. ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
  41. ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
  42. ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
  43. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  44. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  45. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  46. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  47. ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
  48. ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
  49. ai_edge_quantizer/transformations/transformation_utils.py +157 -11
  50. ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
  51. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  52. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  53. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  54. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  55. ai_edge_quantizer/utils/test_utils.py +191 -58
  56. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
  57. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
  58. ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
  59. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  60. ai_edge_quantizer/utils/validation_utils.py +114 -4
  61. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  62. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
  63. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  64. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  65. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  66. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  67. ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
  68. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  69. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
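
The hunks below are from ai_edge_quantizer/transformation_instruction_generator.py (item 35 in the list above).
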
@@ -23,10 +23,16 @@ from collections.abc import Iterator
 import dataclasses
 from typing import Optional
 from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.algorithms.utils import common_utils
+from ai_edge_quantizer.utils import constrained_ops_utils
 from ai_edge_quantizer.utils import tfl_flatbuffer_utils
 from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import


+_OpQuantConstraint = common_utils.OpQuantConstraint
+_QuantTransformation = qtyping.QuantTransformation
+
+
 # When a tensor has no producer, we'll assign -1 to the producer field
 # When a tensor is a graph output, we'll also include a -1 in the consumer list
 def check_horizontal_optimization(
@@ -48,6 +54,15 @@ def check_horizontal_optimization(
   Returns:
     True if the two transformations can be merged, False otherwise
   """
+  if (
+      isinstance(param1.parameters, qtyping.UniformQuantParams)
+      and param1.parameters.hadamard is not None
+  ):
+    if (
+        isinstance(param2.parameters, qtyping.UniformQuantParams)
+        and param2.parameters.hadamard is not None
+    ):
+      return True
   return (
       param1.parameters == param2.parameters
       and len(param1.transformations) > index
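
The hunk above extends check_horizontal_optimization so that two consumer transformations are treated as mergeable whenever both carry Hadamard rotation parameters, even if the full parameter objects differ. A minimal sketch of that predicate, using stand-in dataclasses rather than the real qtyping types (the FakeHadamard field is made up for illustration):

import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class FakeHadamard:
  # Stand-in for the Hadamard rotation payload; "block_size" is a hypothetical field.
  block_size: int


@dataclasses.dataclass(frozen=True)
class FakeUniformParams:
  # Stand-in for qtyping.UniformQuantParams.
  num_bits: int
  hadamard: Optional[FakeHadamard] = None


def can_merge(p1: FakeUniformParams, p2: FakeUniformParams) -> bool:
  # New early-exit from the hunk: two Hadamard-rotated params are mergeable
  # even when the full parameter objects are not exactly equal.
  if p1.hadamard is not None and p2.hadamard is not None:
    return True
  # Otherwise fall back to the pre-existing equality comparison.
  return p1 == p2


print(can_merge(FakeUniformParams(4, FakeHadamard(32)),
                FakeUniformParams(4, FakeHadamard(64))))      # True
print(can_merge(FakeUniformParams(8), FakeUniformParams(4)))  # False
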
@@ -162,6 +177,16 @@ class TransformationInstructionsGenerator:
     else:
       self.flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
     self._create_tensor_name_to_graph_info_map()
+    self._same_as_input_scale_ops = (
+        constrained_ops_utils.get_constrained_op_list(
+            _OpQuantConstraint.SAME_AS_INPUT_SCALE
+        )
+    )
+    self._same_as_output_scale_ops = (
+        constrained_ops_utils.get_constrained_op_list(
+            _OpQuantConstraint.SAME_AS_OUTPUT_SCALE
+        )
+    )

   @dataclasses.dataclass(frozen=True)
   class TensorGraphInfo:
@@ -183,11 +208,13 @@ class TransformationInstructionsGenerator:
       A tuple of tensor_name and TensorGraphInfo.
     """
     for tensor_id, tensor in enumerate(subgraph.tensors):
-      consumers = [
-          op_id
-          for (op_id, op) in enumerate(subgraph.operators)
-          if tensor_id in op.inputs
-      ]
+      consumers = []
+      for op_id, op in enumerate(subgraph.operators):
+        # Some ops may use the same input tensor multiple times,
+        # and we should handle each time independently.
+        for op_input in op.inputs:
+          if op_input == tensor_id:
+            consumers.append(op_id)
       producer = -1
       for op_id, op in enumerate(subgraph.operators):
         if tensor_id in op.outputs:
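
The consumer-collection change above matters when a single op reads the same tensor more than once (e.g. mul(x, x)): the old list comprehension records such an op once, while the new loop records one entry per use. A small self-contained illustration with stand-in operator objects:

class FakeOp:
  # Stand-in for a flatbuffer operator holding input tensor ids.
  def __init__(self, inputs):
    self.inputs = inputs


operators = [FakeOp(inputs=[5, 5]), FakeOp(inputs=[5, 6])]  # op 0 reads tensor 5 twice
tensor_id = 5

# Old behavior: a membership test records each op at most once.
old_consumers = [
    op_id for op_id, op in enumerate(operators) if tensor_id in op.inputs
]

# New behavior: one entry per use, so a repeated input is handled independently.
new_consumers = []
for op_id, op in enumerate(operators):
  for op_input in op.inputs:
    if op_input == tensor_id:
      new_consumers.append(op_id)

print(old_consumers)  # [0, 1]
print(new_consumers)  # [0, 0, 1]
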
@@ -454,53 +481,181 @@ class TransformationInstructionsGenerator:
       transformations.insert(0, producer_trans_rule)
     return transformations

+  def _remove_last_tensor_duplication(
+      self, tensor_trans_insts: qtyping.TensorTransformationInsts
+  ) -> None:
+    """Remove the last tensor duplication so the original tensor can be reused."""
+    instructions = tensor_trans_insts.instructions
+    if not instructions:
+      return
+    for i in range(len(instructions) - 1, -1, -1):
+      if (
+          instructions[i].transformation
+          == _QuantTransformation.DUPLICATE_TENSOR
+      ):
+        instructions.pop(i)
+        return
+
+  def _remove_unnecessary_buffer_duplication(
+      self, tensor_trans_insts: qtyping.TensorTransformationInsts
+  ) -> None:
+    """Remove buffer duplications that comes after a tensor duplication.
+
+    When a tensor is duplicated, a new buffer is created for it. Therefore,
+    buffer duplication transformation that comes after it is unnecessary.
+
+    Args:
+      tensor_trans_insts: Transformation instructions for a tensor.
+    """
+    instructions = tensor_trans_insts.instructions
+    if not instructions:
+      return
+
+    # Find all consumers that have a tensor duplication.
+    consumers_with_tensor_duplication = set()
+    for instr in instructions:
+      if instr.transformation == _QuantTransformation.DUPLICATE_TENSOR:
+        consumers_with_tensor_duplication.update(instr.consumers)
+    if not consumers_with_tensor_duplication:
+      return
+
+    # Remove a buffer duplication that comes with a tensor duplication.
+    for i in range(len(instructions) - 1, -1, -1):
+      instr = instructions[i]
+      if (
+          instr.transformation == _QuantTransformation.DUPLICATE_BUFFER
+          and consumers_with_tensor_duplication.issuperset(instr.consumers)
+      ):
+        instructions.pop(i)
+
+  def _is_valid_quantize_requantize_pair(
+      self,
+      instr_0: qtyping.TransformationInst,
+      instr_1: qtyping.TransformationInst,
+  ) -> bool:
+    """Checks if the two instructions form a valid quantize and requantize pair."""
+    return (
+        instr_0.transformation == _QuantTransformation.QUANTIZE_TENSOR
+        and instr_1.transformation == _QuantTransformation.ADD_QUANTIZE
+        and instr_0.consumers == instr_1.consumers
+    )
+
+  def _is_op_constrained(
+      self, subgraph_id: int, op_index: int
+  ) -> bool:
+    """Checks if the op has same as input or output scale constraints."""
+    op_name = tfl_flatbuffer_utils.get_op_name_by_index(
+        self.flatbuffer_model, subgraph_id, op_index
+    )
+    return (
+        op_name in self._same_as_input_scale_ops
+        or op_name in self._same_as_output_scale_ops
+    )
+
+  def _are_quant_params_compatible(
+      self,
+      params_0: qtyping.UniformQuantParams,
+      params_1: qtyping.UniformQuantParams,
+  ) -> bool:
+    """Checks if quant params are the same except for the scale and zero point."""
+    ignore_set = {"scale", "zero_point"}
+    for field_info in dataclasses.fields(qtyping.UniformQuantParams):
+      field_name = field_info.name
+      if field_name in ignore_set:
+        continue
+      if getattr(params_0, field_name) != getattr(params_1, field_name):
+        return False
+    return True
+
+  def _eliminate_requantization_for_nonconstrained_provider(
+      self, tensor_trans_insts: qtyping.TensorTransformationInsts
+  ) -> None:
+    """Removes requantization for tensors with a non-constrained provider.
+
+    Fuses [QUANTIZE_TENSOR, ADD_QUANTIZE] instructions when a tensor has a
+    provider op without same as input/ouput scale constrains. Quant params from
+    the second instruction are copied to the first one and ADD_QUANTIZE is
+    removed.
+
+    Args:
+      tensor_trans_insts: Transformation instructions for a tensor.
+    """
+    instructions = tensor_trans_insts.instructions
+    if instructions is None or len(instructions) != 2:
+      return
+
+    instr_0, instr_1 = instructions
+    params_0 = instr_0.parameters
+    params_1 = instr_1.parameters
+    producer_op_index = instr_0.producer
+    if (
+        not isinstance(params_0, qtyping.UniformQuantParams)
+        or not isinstance(params_1, qtyping.UniformQuantParams)
+        or not self._is_valid_quantize_requantize_pair(instr_0, instr_1)
+        or not self._are_quant_params_compatible(params_0, params_1)
+        # To avoid fusion when subgraph inputs connected to the main subgraph
+        # (e.g. while_body), we skip all tensors with no producer.
+        or producer_op_index == -1
+        # Can't apply fusion to tensors with a constrained producer since that
+        # will break the constraint.
+        or self._is_op_constrained(
+            tensor_trans_insts.subgraph_id, producer_op_index
+        )
+    ):
+      return
+
+    # Fuse the quantize and requantize.
+    instr_0.parameters = dataclasses.replace(
+        params_0, scale=params_1.scale, zero_point=params_1.zero_point
+    )
+    # Remove the requantize instruction.
+    instructions.pop(1)
+
   def _quant_params_to_transformation_insts(
       self,
       param: qtyping.TensorTransformationParams,
   ) -> qtyping.TensorTransformationInsts:
-    """Converts a single quantization params to transformation instructions.
+    """Convert single tensor quant params to transformation instructions.

     Args:
-      param: quantization parameter of a tensor in the graph
+      param: Quantization parameters of a tensor in the graph.

     Returns:
-      a list of transformations to be applied to the same tensor
+      Transformations to be applied to the given tensor.
     """
-    # setup the structure
+    # Setup the structure.
     tensor_info = self._tensor_name_to_graph_info[param.tensor_name]
     tensor_trans_insts = qtyping.TensorTransformationInsts(
         param.tensor_name, tensor_info.subgraph_id, []
     )

-    # horizontal optimization
-    consumer_group = self._group_consumer_transformations(param)
-    # at this point, starting from index 1 of consumer_group, we're having sets
-    # that represents transformations that can be grouped together
-    transformations_available_for_vertical_optimization = (
-        self._produce_transformation_for_vertical_opt(consumer_group, param)
-    )
-    other_consumer_transformations = (
-        self._produce_consumer_transformations_unavailable_for_vertical_opt(
-            consumer_group, param
-        )
-    )
-
+    # Add all producer rules.
     transformations = []
-    # adding all producer rules
-    producer_params = param.producer
-    if producer_params:
-      for transformation in producer_params.transformations:
+    if param.producer:
+      for transformation in param.producer.transformations:
         transformations.append(
             qtyping.TransformationInst(
                 transformation,
                 tensor_info.tensor_id,
                 tensor_info.producer,
                 tensor_info.consumers,
-                producer_params.parameters,
+                param.producer.parameters,
             )
         )

-    # apply vertical optimization
+    # Horizontal optimization.
+    consumer_group = self._group_consumer_transformations(param)
+    # At this point, starting from index 1 of consumer_group, we're having sets
+    # that represent transformations that can be grouped together.
+    transformations_available_for_vertical_optimization = (
+        self._produce_transformation_for_vertical_opt(consumer_group, param)
+    )
+    other_consumer_transformations = (
+        self._produce_consumer_transformations_unavailable_for_vertical_opt(
+            consumer_group, param
+        )
+    )
+    # Apply vertical optimization.
     last_producer_rule_idx = len(transformations) - 1
     if last_producer_rule_idx >= 0:
       transformations += self._apply_vertical_optimization(
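
The new _eliminate_requantization_for_nonconstrained_provider above fuses a [QUANTIZE_TENSOR, ADD_QUANTIZE] pair into a single quantize that already uses the consumer-side scale and zero point, provided the producer op carries no same-as-input/output scale constraint. A simplified sketch of the fusion step with stand-in dataclasses (the producer-constraint check is omitted and the compatibility test is reduced to the essentials):

import dataclasses


@dataclasses.dataclass
class FakeParams:
  # Stand-in for qtyping.UniformQuantParams (only the fields used here).
  num_bits: int
  scale: float
  zero_point: int


@dataclasses.dataclass
class FakeInst:
  # Stand-in for qtyping.TransformationInst.
  transformation: str
  consumers: list
  parameters: FakeParams


insts = [
    FakeInst("QUANTIZE_TENSOR", [3], FakeParams(8, 0.02, 0)),
    FakeInst("ADD_QUANTIZE", [3], FakeParams(8, 0.05, 0)),  # requantize for consumer 3
]

# Fusion condition: a quantize/requantize pair for the same consumers whose
# params differ only in scale/zero_point.
if (insts[0].transformation == "QUANTIZE_TENSOR"
    and insts[1].transformation == "ADD_QUANTIZE"
    and insts[0].consumers == insts[1].consumers
    and insts[0].parameters.num_bits == insts[1].parameters.num_bits):
  insts[0].parameters = dataclasses.replace(
      insts[0].parameters,
      scale=insts[1].parameters.scale,
      zero_point=insts[1].parameters.zero_point,
  )
  insts.pop(1)  # drop the now-redundant ADD_QUANTIZE

print(insts)  # one QUANTIZE_TENSOR carrying scale=0.05, zero_point=0
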
@@ -509,30 +664,127 @@ class TransformationInstructionsGenerator:
       )
     else:
       transformations += transformations_available_for_vertical_optimization
-    # Adding other consumers rules
+    # Adding other consumers rules.
     transformations += other_consumer_transformations
     tensor_trans_insts.instructions = transformations
+
+    # Now, when all optimizations are done, we can remove the last tensor
+    # duplication instruction, so the original tensor can be reused.
+    self._remove_last_tensor_duplication(tensor_trans_insts)
+    # With the tensor duplication instructions finalized, we can remove
+    # unnecessary buffer duplications applied to the same duplicated tensors.
+    # This is not a part of a vertical optimization because vertical
+    # optimization only works between producers & consumers, and this is between
+    # the consumer only. Also this can't be done during the params generation
+    # because removing last tensor duplication has to happen first.
+    self._remove_unnecessary_buffer_duplication(tensor_trans_insts)
+
     # Check the generated transformation instructions are valid, the function
-    # will raise an error if the instructions are not valid
+    # will raise an error if the instructions are not valid.
     self._check_tensor_transformation_instructions_valid(tensor_trans_insts)

+    # Remove unnecessary [QUANTIZE_TENSOR, ADD_QUANTIZE] pairs for tensors with
+    # providers without same as input/output scale constraints.
+    self._eliminate_requantization_for_nonconstrained_provider(
+        tensor_trans_insts
+    )
+
     return tensor_trans_insts

-  def _check_tensor_transformation_instructions_valid(
-      self, instructions: qtyping.TensorTransformationInsts
-  ):
-    """Check if the tensor transformation instructions are valid.
+  def _split_instructions_by_tensor_duplication(
+      self,
+      instructions: qtyping.TensorTransformationInsts,
+  ) -> list[list[qtyping.TransformationInst]]:
+    """Split the instructions into subsets by tensor duplication.
+
+    Splits the instructions into subsets based on which tensor (original or one
+    of duplicated ones) they will be applied to.
+
+    The first subset is for the original tensor. The following subsets are for
+    the duplicated tensors. The order of instructions in each subset is
+    preserved.
+
+    Enforced constraints for each duplicated tensor's instructions subset:
+      1. The first instruction must be a `DUPLICATE_TENSOR` one.
+      2. No other `DUPLICATE_TENSOR` instructions can be present.
+
+    For the following instructions:
+    [
+        (transformation=DUPLICATE_TENSOR, consumers=[1, 2, 3]),
+        (transformation=DUPLICATE_TENSOR, consumers=[4]),
+        (transformation=T1, consumers=[1, 2]),
+        (transformation=T2, consumers=[3]),
+        (transformation=T3, consumers=[4]),
+        (transformation=T4, consumers=[5])
+    ]
+
+    `instruction_subsets` will be:
+    [
+        [(transformation=T4, consumers=[5])],
+        [
+            (transformation=DUPLICATE_TENSOR, consumers=[1, 2, 3]),
+            (transformation=T1, consumers=[1, 2]),
+            (transformation=T2, consumers=[3])
+        ],
+        [
+            (transformation=DUPLICATE_TENSOR, consumers=[4]),
+            (transformation=T3, consumers=[4])
+        ]
+    ],

     Args:
       instructions: Transformation instructions for a tensor.

+    Returns:
+      A list of subsets of transformation instructions, where the first subset
+      is for the original tensor, and the following subsets are for the
+      duplicated tensors.
+
     Raises:
-      ValueError: If the instructions are not valid.
+      ValueError: If DUPLICATE_TENSOR is found and it's not the first
+        transformation for its consumers.
+    """
+    original_tensor_subset_idx = 0
+    instruction_subsets = [[]]
+    consumer_to_subset_idx = {}
+    for instruction in instructions.instructions:
+      if instruction.transformation == _QuantTransformation.DUPLICATE_TENSOR:
+        instruction_subsets.append([instruction])
+        subset_idx = len(instruction_subsets) - 1
+        for consumer in instruction.consumers:
+          if consumer in consumer_to_subset_idx:
+            raise ValueError(
+                f"Tensor {instructions.tensor_name} : duplicate tensor should"
+                " be the first instruction for its consumers."
+            )
+          else:
+            consumer_to_subset_idx[consumer] = subset_idx
+      else:
+        first_consumer = instruction.consumers[0]
+        if first_consumer not in consumer_to_subset_idx:
+          consumer_to_subset_idx[first_consumer] = original_tensor_subset_idx
+        subset_idx = consumer_to_subset_idx[first_consumer]
+        instruction_subsets[subset_idx].append(instruction)
+
+    return instruction_subsets
+
+  def _check_subset_of_tensor_transformation_instructions_valid(
+      self,
+      instructions: Optional[list[qtyping.TransformationInst]],
+      tensor_name: str,
+  ):
+    """Check if a subset of tensor transformation instructions is valid.
+
+    Args:
+      instructions: A subset of transformation instructions for a tensor.
+      tensor_name: The name of the tensor.
+
+    Raises:
+      ValueError: If the subset of instructions are not valid.
     """
     is_tensor_unquantized = False
     is_tensor_quantized = False
-    is_operator_emulated = False
-    for instruction in instructions.instructions:
+    for instruction in instructions:
       transform_type = instruction.transformation
       if transform_type == qtyping.QuantTransformation.NO_QUANTIZE:
         is_tensor_unquantized = True
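
The new _split_instructions_by_tensor_duplication above groups a tensor's instructions by whether they target the original tensor or one of its duplicates. The sketch below reproduces the docstring example with stand-in instruction objects (the ValueError branch for misordered DUPLICATE_TENSOR instructions is omitted for brevity):

import dataclasses


@dataclasses.dataclass
class FakeInst:
  # Stand-in for qtyping.TransformationInst (only the fields used here).
  transformation: str
  consumers: list


insts = [
    FakeInst("DUPLICATE_TENSOR", [1, 2, 3]),
    FakeInst("DUPLICATE_TENSOR", [4]),
    FakeInst("T1", [1, 2]),
    FakeInst("T2", [3]),
    FakeInst("T3", [4]),
    FakeInst("T4", [5]),
]

subsets = [[]]            # subset 0 belongs to the original tensor
consumer_to_subset = {}   # consumer op id -> subset index
for inst in insts:
  if inst.transformation == "DUPLICATE_TENSOR":
    subsets.append([inst])
    for consumer in inst.consumers:
      consumer_to_subset[consumer] = len(subsets) - 1
  else:
    idx = consumer_to_subset.setdefault(inst.consumers[0], 0)
    subsets[idx].append(inst)

for subset in subsets:
  print([i.transformation for i in subset])
# ['T4']
# ['DUPLICATE_TENSOR', 'T1', 'T2']
# ['DUPLICATE_TENSOR', 'T3']
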
@@ -541,18 +793,33 @@ class TransformationInstructionsGenerator:
           or transform_type == qtyping.QuantTransformation.ADD_DEQUANTIZE
       ):
         is_tensor_quantized = True
-      elif transform_type == qtyping.QuantTransformation.EMULATED_SUBCHANNEL:
-        is_operator_emulated = True
     if is_tensor_unquantized and is_tensor_quantized:
       raise ValueError(
-          "Tensor %s can not be both quantized and unquantized"
-          % instructions.tensor_name
+          "Tensor %s can not be both quantized and unquantized" % tensor_name
       )
-    if is_operator_emulated and len(instructions.instructions) > 1:
-      raise ValueError(
-          "Tensor %s : op replacement transformation can not be combined with"
-          " other transformations."
-          % instructions.tensor_name
+
+  def _check_tensor_transformation_instructions_valid(
+      self,
+      instructions: qtyping.TensorTransformationInsts,
+  ):
+    """Check if the tensor transformation instructions are valid.
+
+    Args:
+      instructions: Transformation instructions for a tensor.
+
+    Raises:
+      ValueError: If the instructions are not valid.
+    """
+    # Split the instructions into subsets based on which tensor (original or one
+    # of duplicated ones) they will be applied to.
+    instruction_subsets = self._split_instructions_by_tensor_duplication(
+        instructions
+    )
+    # Check that each subset of instructions is valid.
+    for instruction_subset in instruction_subsets:
+      self._check_subset_of_tensor_transformation_instructions_valid(
+          instruction_subset,
+          instructions.tensor_name,
       )

   def quant_params_to_transformation_insts(
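
With the refactor above, each duplication subset is validated independently, so a duplicated copy of a tensor may be quantized while the original stays float, but no single subset may mix the two. A minimal sketch of that invariant, using plain strings in place of qtyping.QuantTransformation values:

def check_subset(transformations, tensor_name="tensor0"):
  # NO_QUANTIZE marks an unquantized use; QUANTIZE_TENSOR / ADD_DEQUANTIZE mark
  # quantized uses. A single subset must not mix the two.
  unquantized = any(t == "NO_QUANTIZE" for t in transformations)
  quantized = any(
      t in ("QUANTIZE_TENSOR", "ADD_DEQUANTIZE") for t in transformations
  )
  if unquantized and quantized:
    raise ValueError(
        "Tensor %s can not be both quantized and unquantized" % tensor_name
    )


# Valid overall: the original tensor stays float while its duplicate is
# quantized, because each duplicate is validated as its own subset.
check_subset(["NO_QUANTIZE"])
check_subset(["DUPLICATE_TENSOR", "QUANTIZE_TENSOR"])

# Invalid within one subset:
try:
  check_subset(["NO_QUANTIZE", "ADD_DEQUANTIZE"])
except ValueError as e:
  print(e)
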