ai-edge-quantizer-nightly 0.4.0.dev20250828__py3-none-any.whl → 0.4.0.dev20250830__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -266,7 +266,11 @@ DEFAULT_JSON_POLICY = """
266
266
  }
267
267
  }
268
268
  """
269
- QUANTIZABLE_COMPOSITES = ["od" + "ml.npu_call", "od" + "ml.rms_norm"]
269
+ QUANTIZABLE_COMPOSITES = [
270
+ "od" + "ml.npu_call",
271
+ "od" + "ml.rms_norm",
272
+ "od" + "ml.l2_norm",
273
+ ]
270
274
 
271
275
 
272
276
  def _unroll_json_config(
@@ -23,10 +23,13 @@ from collections.abc import Iterator
23
23
  import dataclasses
24
24
  from typing import Optional
25
25
  from ai_edge_quantizer import qtyping
26
+ from ai_edge_quantizer.algorithms.utils import common_utils
27
+ from ai_edge_quantizer.utils import constrained_ops_utils
26
28
  from ai_edge_quantizer.utils import tfl_flatbuffer_utils
27
29
  from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
28
30
 
29
31
 
32
+ _OpQuantConstraint = common_utils.OpQuantConstraint
30
33
  _QuantTransformation = qtyping.QuantTransformation
31
34
 
32
35
 
@@ -165,6 +168,16 @@ class TransformationInstructionsGenerator:
165
168
  else:
166
169
  self.flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
167
170
  self._create_tensor_name_to_graph_info_map()
171
+ self._same_as_input_scale_ops = (
172
+ constrained_ops_utils.get_constrained_op_list(
173
+ _OpQuantConstraint.SAME_AS_INPUT_SCALE
174
+ )
175
+ )
176
+ self._same_as_output_scale_ops = (
177
+ constrained_ops_utils.get_constrained_op_list(
178
+ _OpQuantConstraint.SAME_AS_OUTPUT_SCALE
179
+ )
180
+ )
168
181
 
169
182
  @dataclasses.dataclass(frozen=True)
170
183
  class TensorGraphInfo:
@@ -506,6 +519,89 @@ class TransformationInstructionsGenerator:
506
519
  ):
507
520
  instructions.pop(i)
508
521
 
522
+ def _is_valid_quantize_requantize_pair(
523
+ self,
524
+ instr_0: qtyping.TransformationInst,
525
+ instr_1: qtyping.TransformationInst,
526
+ ) -> bool:
527
+ """Checks if the two instructions form a valid quantize and requantize pair."""
528
+ return (
529
+ instr_0.transformation == _QuantTransformation.QUANTIZE_TENSOR
530
+ and instr_1.transformation == _QuantTransformation.ADD_QUANTIZE
531
+ and instr_0.consumers == instr_1.consumers
532
+ )
533
+
534
+ def _is_op_constrained(
535
+ self, subgraph_id: int, op_index: int
536
+ ) -> bool:
537
+ """Checks if the op has same as input or output scale constraints."""
538
+ op_name = tfl_flatbuffer_utils.get_op_name_by_index(
539
+ self.flatbuffer_model, subgraph_id, op_index
540
+ )
541
+ return (
542
+ op_name in self._same_as_input_scale_ops
543
+ or op_name in self._same_as_output_scale_ops
544
+ )
545
+
546
+ def _are_quant_params_compatible(
547
+ self,
548
+ params_0: qtyping.UniformQuantParams,
549
+ params_1: qtyping.UniformQuantParams,
550
+ ) -> bool:
551
+ """Checks if quant params are the same except for the scale and zero point."""
552
+ ignore_set = {"scale", "zero_point"}
553
+ for field_info in dataclasses.fields(qtyping.UniformQuantParams):
554
+ field_name = field_info.name
555
+ if field_name in ignore_set:
556
+ continue
557
+ if getattr(params_0, field_name) != getattr(params_1, field_name):
558
+ return False
559
+ return True
560
+
561
+ def _eliminate_requantization_for_nonconstrained_provider(
562
+ self, tensor_trans_insts: qtyping.TensorTransformationInsts
563
+ ) -> None:
564
+ """Removes requantization for tensors with a non-constrained provider.
565
+
566
+ Fuses [QUANTIZE_TENSOR, ADD_QUANTIZE] instructions when a tensor has a
567
+ provider op without same as input/output scale constraints. Quant params from
568
+ the second instruction are copied to the first one and ADD_QUANTIZE is
569
+ removed.
570
+
571
+ Args:
572
+ tensor_trans_insts: Transformation instructions for a tensor.
573
+ """
574
+ instructions = tensor_trans_insts.instructions
575
+ if instructions is None or len(instructions) != 2:
576
+ return
577
+
578
+ instr_0, instr_1 = instructions
579
+ params_0 = instr_0.parameters
580
+ params_1 = instr_1.parameters
581
+ producer_op_index = instr_0.producer
582
+ if (
583
+ not isinstance(params_0, qtyping.UniformQuantParams)
584
+ or not isinstance(params_1, qtyping.UniformQuantParams)
585
+ or not self._is_valid_quantize_requantize_pair(instr_0, instr_1)
586
+ or not self._are_quant_params_compatible(params_0, params_1)
587
+ # To avoid fusion when subgraph inputs are connected to the main subgraph
588
+ # (e.g. while_body), we skip all tensors with no producer.
589
+ or producer_op_index == -1
590
+ # Can't apply fusion to tensors with a constrained producer since that
591
+ # will break the constraint.
592
+ or self._is_op_constrained(
593
+ tensor_trans_insts.subgraph_id, producer_op_index
594
+ )
595
+ ):
596
+ return
597
+
598
+ # Fuse the quantize and requantize.
599
+ instr_0.parameters = dataclasses.replace(
600
+ params_0, scale=params_1.scale, zero_point=params_1.zero_point
601
+ )
602
+ # Remove the requantize instruction.
603
+ instructions.pop(1)
604
+
509
605
  def _quant_params_to_transformation_insts(
510
606
  self,
511
607
  param: qtyping.TensorTransformationParams,
@@ -578,6 +674,12 @@ class TransformationInstructionsGenerator:
578
674
  # will raise an error if the instructions are not valid.
579
675
  self._check_tensor_transformation_instructions_valid(tensor_trans_insts)
580
676
 
677
+ # Remove unnecessary [QUANTIZE_TENSOR, ADD_QUANTIZE] pairs for tensors with
678
+ # providers without same as input/output scale constraints.
679
+ self._eliminate_requantization_for_nonconstrained_provider(
680
+ tensor_trans_insts
681
+ )
682
+
581
683
  return tensor_trans_insts
582
684
 
583
685
  def _split_instructions_by_tensor_duplication(
@@ -15,7 +15,9 @@
15
15
 
16
16
  """Tests for instruction_generator."""
17
17
 
18
+ from collections.abc import Sequence
18
19
  import os
20
+ from typing import Optional
19
21
 
20
22
  import numpy as np
21
23
 
@@ -1337,5 +1339,166 @@ class InstructionGeneratorTest(parameterized.TestCase):
1337
1339
  )
1338
1340
 
1339
1341
 
1342
+ class EliminateUnnecessaryRequantizationTest(parameterized.TestCase):
1343
+
1344
+ def setUp(self):
1345
+ super().setUp()
1346
+ self.ins_gen = instruction_generator.TransformationInstructionsGenerator(
1347
+ os.path.join(TEST_DATA_PREFIX_PATH, "tests/models/conv_fc_mnist.tflite")
1348
+ )
1349
+
1350
+ def _get_test_instruction(
1351
+ self,
1352
+ transformation: qtyping.QuantTransformation,
1353
+ producer: int = -1,
1354
+ consumers: Optional[Sequence[int]] = None,
1355
+ qparams: Optional[qtyping.UniformQuantParams] = None,
1356
+ ) -> qtyping.TransformationInst:
1357
+ if consumers is None:
1358
+ consumers = []
1359
+ if qparams is None:
1360
+ qparams = qtyping.UniformQuantParams(
1361
+ num_bits=8,
1362
+ quantized_dimension=None,
1363
+ scale=np.array([1]),
1364
+ zero_point=np.array([0]),
1365
+ )
1366
+ return qtyping.TransformationInst(
1367
+ transformation=transformation,
1368
+ producer=producer,
1369
+ consumers=consumers,
1370
+ parameters=qparams,
1371
+ # Dummy values below.
1372
+ tensor_id=0,
1373
+ )
1374
+
1375
+ def _create_test_insts(
1376
+ self, instructions: list[qtyping.TransformationInst]
1377
+ ) -> qtyping.TensorTransformationInsts:
1378
+ return qtyping.TensorTransformationInsts(
1379
+ tensor_name="test_tensor", subgraph_id=0, instructions=instructions
1380
+ )
1381
+
1382
+ def test_no_fusion_when_too_few_instructions(self):
1383
+ tensor_insts = self._create_test_insts([
1384
+ self._get_test_instruction(_QTransf.QUANTIZE_TENSOR),
1385
+ ])
1386
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1387
+ tensor_insts
1388
+ )
1389
+ self.assertLen(tensor_insts.instructions, 1)
1390
+
1391
+ def test_no_fusion_when_too_many_instructions(self):
1392
+ tensor_insts = self._create_test_insts([
1393
+ self._get_test_instruction(_QTransf.QUANTIZE_TENSOR),
1394
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE),
1395
+ self._get_test_instruction(_QTransf.ADD_DEQUANTIZE),
1396
+ ])
1397
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1398
+ tensor_insts
1399
+ )
1400
+ self.assertLen(tensor_insts.instructions, 3)
1401
+
1402
+ def test_no_fusion_when_invalid_transformation_pair(self):
1403
+ tensor_insts = self._create_test_insts([
1404
+ self._get_test_instruction(_QTransf.ADD_DEQUANTIZE),
1405
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE),
1406
+ ])
1407
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1408
+ tensor_insts
1409
+ )
1410
+ self.assertLen(tensor_insts.instructions, 2)
1411
+
1412
+ def test_no_fusion_when_consumers_mismatch(self):
1413
+ tensor_insts = self._create_test_insts([
1414
+ self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, consumers=[0]),
1415
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[1]),
1416
+ ])
1417
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1418
+ tensor_insts
1419
+ )
1420
+ self.assertLen(tensor_insts.instructions, 2)
1421
+
1422
+ def test_no_fusion_when_no_producer(self):
1423
+ producer = -1
1424
+ tensor_insts = self._create_test_insts([
1425
+ self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, producer),
1426
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE, producer),
1427
+ ])
1428
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1429
+ tensor_insts
1430
+ )
1431
+ self.assertLen(tensor_insts.instructions, 2)
1432
+
1433
+ def test_no_fusion_when_quant_params_are_incompatible(self):
1434
+ params_8_bits = qtyping.UniformQuantParams(
1435
+ 8, None, np.array([1]), np.array([0])
1436
+ )
1437
+ params_16_bits = qtyping.UniformQuantParams(
1438
+ 16, None, np.array([1]), np.array([0])
1439
+ )
1440
+ tensor_insts = self._create_test_insts([
1441
+ self._get_test_instruction(
1442
+ _QTransf.QUANTIZE_TENSOR, qparams=params_8_bits
1443
+ ),
1444
+ self._get_test_instruction(
1445
+ _QTransf.ADD_QUANTIZE, qparams=params_16_bits
1446
+ ),
1447
+ ])
1448
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1449
+ tensor_insts
1450
+ )
1451
+ self.assertLen(tensor_insts.instructions, 2)
1452
+
1453
+ def test_no_fusion_when_producer_constrained(self):
1454
+ # Reshape op (op index 2) has same as input scale constraint.
1455
+ tensor_insts = self._create_test_insts([
1456
+ self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, producer=2),
1457
+ self._get_test_instruction(_QTransf.ADD_QUANTIZE, producer=2),
1458
+ ])
1459
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1460
+ tensor_insts
1461
+ )
1462
+ self.assertLen(tensor_insts.instructions, 2)
1463
+
1464
+ def test_fusion_succeeds(self):
1465
+ producer = 0
1466
+ consumers = [1]
1467
+ params_0 = qtyping.UniformQuantParams(
1468
+ num_bits=8,
1469
+ quantized_dimension=None,
1470
+ scale=np.array([1]),
1471
+ zero_point=np.array([0]),
1472
+ )
1473
+ params_1 = qtyping.UniformQuantParams(
1474
+ num_bits=8,
1475
+ quantized_dimension=None,
1476
+ scale=np.array([2]),
1477
+ zero_point=np.array([1]),
1478
+ )
1479
+ inst_0 = self._get_test_instruction(
1480
+ _QTransf.QUANTIZE_TENSOR, producer, consumers, params_0
1481
+ )
1482
+ inst_1 = self._get_test_instruction(
1483
+ _QTransf.ADD_QUANTIZE, producer, consumers, params_1
1484
+ )
1485
+ tensor_insts = self._create_test_insts([inst_0, inst_1])
1486
+ self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
1487
+ tensor_insts
1488
+ )
1489
+
1490
+ self.assertLen(tensor_insts.instructions, 1)
1491
+ result_inst = tensor_insts.instructions[0]
1492
+ self.assertEqual(result_inst.transformation, _QTransf.QUANTIZE_TENSOR)
1493
+
1494
+ result_params = result_inst.parameters
1495
+ # Explicitly narrow the type for pytype.
1496
+ if not isinstance(result_params, qtyping.UniformQuantParams):
1497
+ self.fail("Fused instruction parameters are not UniformQuantParams")
1498
+
1499
+ self.assertEqual(result_params.scale, params_1.scale)
1500
+ self.assertEqual(result_params.zero_point, params_1.zero_point)
1501
+
1502
+
1340
1503
  if __name__ == "__main__":
1341
1504
  googletest.main()
@@ -20,13 +20,11 @@ from typing import Any, Union
20
20
 
21
21
  import numpy as np
22
22
 
23
- from ai_edge_quantizer import algorithm_manager
24
23
  from ai_edge_quantizer import qtyping
25
- from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
26
24
  from ai_edge_quantizer.algorithms.utils import common_utils
25
+ from ai_edge_quantizer.utils import constrained_ops_utils
27
26
  from ai_edge_quantizer.utils import tfl_flatbuffer_utils
28
27
  from ai_edge_quantizer.utils import tfl_interpreter_utils
29
- from ai_edge_litert import schema_py_generated as schema_fb # pylint: disable=g-direct-tensorflow-import
30
28
  from tensorflow.lite.tools import flatbuffer_utils # pylint: disable=g-direct-tensorflow-import
31
29
 
32
30
 
@@ -133,7 +131,11 @@ class CalibrationQsvAlignmentUtils:
133
131
  """
134
132
 
135
133
  def __init__(self, model_path: str):
136
- self._same_as_input_scale_ops = []
134
+ self._same_as_input_scale_ops = (
135
+ constrained_ops_utils.get_constrained_op_list(
136
+ _OpQuantConstraint.SAME_AS_INPUT_SCALE
137
+ )
138
+ )
137
139
 
138
140
  tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(model_path)
139
141
  self._flatbuffer_object = tfl_flatbuffer_utils.read_model(model_path)
@@ -146,87 +148,6 @@ class CalibrationQsvAlignmentUtils:
146
148
  signature_runner = tfl_interpreter.get_signature_runner(signature_key)
147
149
  self._signature_runners[signature_key] = signature_runner
148
150
 
149
- # Make a list of `SAME_AS_INPUT_SCALE` operators. This is used to identify
150
- # the operators that need to be constrained to the same scale as the input.
151
- self._build_same_as_input_scale_op_list()
152
-
153
- def _build_same_as_input_scale_op_list(self, verbose: bool = False):
154
- """Constructs a list of SAME_AS_INPUT_SCALE operators.
155
-
156
- This is achieved by invoking all materialization functions and extracting
157
- the constraint argument, using monkey patching to redirect logic to wrapper
158
- functions.
159
-
160
- Args:
161
- verbose: Flag to enable verbose output.
162
- """
163
-
164
- def materialize_standard_op_wrapper(
165
- op_info: qtyping.OpInfo,
166
- *_args,
167
- constraint: _OpQuantConstraint = _OpQuantConstraint.NO_CONSTRAIN,
168
- **_kwargs,
169
- ) -> list[qtyping.TensorTransformationParams]:
170
- if constraint == _OpQuantConstraint.SAME_AS_INPUT_SCALE:
171
- self._same_as_input_scale_ops.append(op_info.op_name)
172
- # Return dummy values to avoid exceptions.
173
- dummy_value = [qtyping.TensorTransformationParams("")] * 2
174
- return dummy_value
175
-
176
- # Dummy implementation of the `_are_weights_too_small` function to support
177
- # `materialize_standard_op_wrapper` above.
178
- def are_weights_too_small_wrapper(*_args, **_kwargs) -> bool:
179
- return False
180
-
181
- # Dummy implementation of the `_materialize_bias_for_conv_ops` function to
182
- # support `materialize_standard_op_wrapper` above.
183
- def materialize_bias_for_conv_ops_wrapper(*_args, **_kwargs):
184
- return
185
-
186
- # Do monkey patch to intercept the `materialize_standard_op` function to
187
- # support `materialize_standard_op_wrapper` above.
188
- original_materialize_standard_op = common_utils.materialize_standard_op
189
- original_are_weights_too_small = common_quantize._are_weights_too_small # pylint: disable=protected-access
190
- original_materialize_bias_for_conv_ops = (
191
- common_quantize._materialize_bias_for_conv_ops # pylint: disable=protected-access
192
- )
193
- common_utils.materialize_standard_op = materialize_standard_op_wrapper
194
- common_quantize._are_weights_too_small = are_weights_too_small_wrapper # pylint: disable=protected-access
195
- common_quantize._materialize_bias_for_conv_ops = ( # pylint: disable=protected-access
196
- materialize_bias_for_conv_ops_wrapper
197
- )
198
- minmax_func_dict = algorithm_manager.MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT
199
-
200
- # Loop over all available materialization functions to build up a list of
201
- # `SAME_AS_INPUT_SCALE` constrained ops.
202
- for op, materialize_fn in minmax_func_dict.items():
203
- # Create a dummy op info to trigger the materialization.
204
- mock_op = schema_fb.OperatorT()
205
- mock_op.inputs = [0]
206
- mock_op.outputs = [0]
207
- op_info = qtyping.OpInfo(
208
- op=mock_op,
209
- op_name=op,
210
- subgraph_op_index=0,
211
- op_quant_config=qtyping.OpQuantizationConfig(),
212
- )
213
- materialize_fn(
214
- get_tensor_quant_params_fn=None,
215
- op_info=op_info,
216
- graph_info=None,
217
- tensor_name_to_qsv=None,
218
- )
219
-
220
- if verbose:
221
- print(f" Constrained op list: {self._same_as_input_scale_ops}")
222
-
223
- # Restore the original functions.
224
- common_utils.materialize_standard_op = original_materialize_standard_op
225
- common_quantize._are_weights_too_small = original_are_weights_too_small # pylint: disable=protected-access
226
- common_quantize._materialize_bias_for_conv_ops = ( # pylint: disable=protected-access
227
- original_materialize_bias_for_conv_ops
228
- )
229
-
230
151
  def _search_tensor_by_signature_name(
231
152
  self, signature_key: str, signature_input_output_name: str, verbose=False
232
153
  ) -> list[str]:
@@ -0,0 +1,111 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Utils for handling operators with quantization constraints."""
17
+
18
+ from ai_edge_quantizer import algorithm_manager
19
+ from ai_edge_quantizer import qtyping
20
+ from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
21
+ from ai_edge_quantizer.algorithms.utils import common_utils
22
+ from ai_edge_litert import schema_py_generated as schema_fb # pylint: disable=g-direct-tensorflow-import
23
+
24
+
25
+ _OpQuantConstraint = common_utils.OpQuantConstraint
26
+
27
+
28
+ def get_constrained_op_list(
29
+ quant_constraint: _OpQuantConstraint, verbose: bool = False
30
+ ) -> list[str]:
31
+ """Constructs and returns a list of constrained operators.
32
+
33
+ This is achieved by invoking all materialization functions and extracting
34
+ the constraint argument, using monkey patching to redirect logic to wrapper
35
+ functions.
36
+
37
+ Args:
38
+ quant_constraint: The quantization constraint to filter operators by.
39
+ verbose: Flag to enable verbose output.
40
+
41
+ Returns:
42
+ A list containing operators with the specified constraint.
43
+ """
44
+ constrained_ops = []
45
+
46
+ def materialize_standard_op_wrapper(
47
+ op_info: qtyping.OpInfo,
48
+ *_args,
49
+ constraint: _OpQuantConstraint = _OpQuantConstraint.NO_CONSTRAIN,
50
+ **_kwargs,
51
+ ) -> list[qtyping.TensorTransformationParams]:
52
+ if constraint == quant_constraint:
53
+ constrained_ops.append(op_info.op_name)
54
+ # Return dummy values to avoid exceptions.
55
+ dummy_value = [qtyping.TensorTransformationParams("")] * 2
56
+ return dummy_value
57
+
58
+ # Dummy implementation of the `_are_weights_too_small` function to support
59
+ # `materialize_standard_op_wrapper` above.
60
+ def are_weights_too_small_wrapper(*_args, **_kwargs) -> bool:
61
+ return False
62
+
63
+ # Dummy implementation of the `_materialize_bias_for_conv_ops` function to
64
+ # support `materialize_standard_op_wrapper` above.
65
+ def materialize_bias_for_conv_ops_wrapper(*_args, **_kwargs):
66
+ return
67
+
68
+ # Do monkey patch to intercept the `materialize_standard_op` function to
69
+ # support `materialize_standard_op_wrapper` above.
70
+ original_materialize_standard_op = common_utils.materialize_standard_op
71
+ original_are_weights_too_small = common_quantize._are_weights_too_small # pylint: disable=protected-access
72
+ original_materialize_bias_for_conv_ops = (
73
+ common_quantize._materialize_bias_for_conv_ops # pylint: disable=protected-access
74
+ )
75
+ common_utils.materialize_standard_op = materialize_standard_op_wrapper
76
+ common_quantize._are_weights_too_small = are_weights_too_small_wrapper # pylint: disable=protected-access
77
+ common_quantize._materialize_bias_for_conv_ops = ( # pylint: disable=protected-access
78
+ materialize_bias_for_conv_ops_wrapper
79
+ )
80
+ minmax_func_dict = algorithm_manager.MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT
81
+
82
+ # Loop over all available materialization functions to build up a list of
83
+ # ops with the given constraint.
84
+ for op, materialize_fn in minmax_func_dict.items():
85
+ # Create a dummy op info to trigger the materialization.
86
+ mock_op = schema_fb.OperatorT()
87
+ mock_op.inputs = [0]
88
+ mock_op.outputs = [0]
89
+ op_info = qtyping.OpInfo(
90
+ op=mock_op,
91
+ op_name=op,
92
+ subgraph_op_index=0,
93
+ op_quant_config=qtyping.OpQuantizationConfig(),
94
+ )
95
+ materialize_fn(
96
+ get_tensor_quant_params_fn=None,
97
+ op_info=op_info,
98
+ graph_info=None,
99
+ tensor_name_to_qsv=None,
100
+ )
101
+
102
+ if verbose:
103
+ print(f" {quant_constraint} op list: {constrained_ops}")
104
+
105
+ # Restore the original functions.
106
+ common_utils.materialize_standard_op = original_materialize_standard_op
107
+ common_quantize._are_weights_too_small = original_are_weights_too_small # pylint: disable=protected-access
108
+ common_quantize._materialize_bias_for_conv_ops = ( # pylint: disable=protected-access
109
+ original_materialize_bias_for_conv_ops
110
+ )
111
+ return constrained_ops
@@ -0,0 +1,50 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from tensorflow.python.platform import googletest
17
+ from absl.testing import parameterized
18
+ from ai_edge_quantizer.algorithms.utils import common_utils
19
+ from ai_edge_quantizer.utils import constrained_ops_utils
20
+
21
+
22
+ _OpQuantConstraint = common_utils.OpQuantConstraint
23
+
24
+
25
+ class ConstrainedOpsUtilsTest(parameterized.TestCase):
26
+
27
+ @parameterized.named_parameters(
28
+ dict(
29
+ testcase_name="same_as_input_scale",
30
+ constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
31
+ expected_num_ops=14,
32
+ ),
33
+ dict(
34
+ testcase_name="same_as_output_scale",
35
+ constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
36
+ expected_num_ops=6,
37
+ ),
38
+ dict(
39
+ testcase_name="no_constrain",
40
+ constraint=_OpQuantConstraint.NO_CONSTRAIN,
41
+ expected_num_ops=22,
42
+ ),
43
+ )
44
+ def test_get_constrained_op_list(self, constraint, expected_num_ops):
45
+ constrained_ops = constrained_ops_utils.get_constrained_op_list(constraint)
46
+ self.assertLen(constrained_ops, expected_num_ops)
47
+
48
+
49
+ if __name__ == "__main__":
50
+ googletest.main()
@@ -342,3 +342,12 @@ def get_op_side_effect_subgraphs(
342
342
  return [opts.decompositionSubgraphIndex]
343
343
  # Can add other nested ops here (control flow ops, etc).
344
344
  return []
345
+
346
+
347
+ def get_op_name_by_index(
348
+ flatbuffer_model: Any, subgraph_id: int, op_index: int
349
+ ) -> str:
350
+ """Get the op name from the flatbuffer model."""
351
+ op = flatbuffer_model.subgraphs[subgraph_id].operators[op_index]
352
+ builtin_code = flatbuffer_model.operatorCodes[op.opcodeIndex].builtinCode
353
+ return TFL_OP_CODE_TO_NAME[builtin_code]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-quantizer-nightly
3
- Version: 0.4.0.dev20250828
3
+ Version: 0.4.0.dev20250830
4
4
  Summary: A quantizer for advanced developers to quantize converted AI Edge models.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
6
6
  Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
@@ -5,7 +5,7 @@ ai_edge_quantizer/algorithm_manager_api_test.py,sha256=w6bSONvXkX6bzXAGc0-7b6gND
5
5
  ai_edge_quantizer/calibrator.py,sha256=Sms7_AIHPH9G5xFaz5Ef3a5gPhxuIWQI8d2LUM8C96I,12071
6
6
  ai_edge_quantizer/calibrator_test.py,sha256=ZLzIMWB2FSFU4TOatDioYuwp_kLh8iSCefZ5_Q9FU7s,11900
7
7
  ai_edge_quantizer/conftest.py,sha256=SxCz-5LlRD_lQm4hQc4c6IGG7DS8d7IyEWY9gnscPN0,794
8
- ai_edge_quantizer/default_policy.py,sha256=LXEdwdr0SiCfWo6ZwbHQ8ykoqA40GV6fGAT1aofry3o,11556
8
+ ai_edge_quantizer/default_policy.py,sha256=G_JZtZaQAnrWyfCusDWXwO27iLysk27RS91GlS61m_Q,11592
9
9
  ai_edge_quantizer/model_modifier.py,sha256=teGa8I6kGvn6TQY6Xv53YFIc_pQEhNvM9Zb4bvhezyw,7110
10
10
  ai_edge_quantizer/model_modifier_test.py,sha256=cJd04SLOG-fQZZNZPcisoBLx3cLtWEwGqUBbLb-pif4,4751
11
11
  ai_edge_quantizer/model_validator.py,sha256=Hj0_5o-Oa3dSlJ3ryVjRhvsyelHNyek1GrtG9buMczg,13153
@@ -19,8 +19,8 @@ ai_edge_quantizer/recipe.py,sha256=MEkfQ2Sg3KAE9LAORHWcbjYNPg06EUbwc1d-VspQA2U,6
19
19
  ai_edge_quantizer/recipe_manager.py,sha256=6dgbE-IZfEetzXH3p3Qm_9eQutNDOpZnMpiaLTbP-ZQ,14744
20
20
  ai_edge_quantizer/recipe_manager_test.py,sha256=H-B75vwPN5ND-nUa3pOXizeHTv4mufPiC5cL_OlDIYU,34040
21
21
  ai_edge_quantizer/recipe_test.py,sha256=GKuo6N65wKLS2xwSpjd-BWWeVRpF1zc7Yt7phSMYSxA,5905
22
- ai_edge_quantizer/transformation_instruction_generator.py,sha256=iMGXy7_ufqgQRzu4drAfO31VGdze35peEFh1BMZlVHk,27714
23
- ai_edge_quantizer/transformation_instruction_generator_test.py,sha256=Zw3EOSnvzjuB4NWeo129eJZxK_EHno9oF9OtEQ-0dnM,48905
22
+ ai_edge_quantizer/transformation_instruction_generator.py,sha256=O0U2aZcB8aXQgOV8r9g1rGNzDUiuI5Ta53XnxZbVffE,31576
23
+ ai_edge_quantizer/transformation_instruction_generator_test.py,sha256=KW5-WoTTo9IqLEVnWxVC8ut8eWLi_91xfKgGqVQ9QDk,54635
24
24
  ai_edge_quantizer/transformation_performer.py,sha256=o4J6OUbI0dLoobVYjkOFw5Po3yH0gZJXrfuTIYais4o,13029
25
25
  ai_edge_quantizer/transformation_performer_test.py,sha256=xk6A3LStCyPclN51--9uO7XjSxNfZmpdfvrzOL0maNM,20349
26
26
  ai_edge_quantizer/algorithms/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
@@ -59,17 +59,19 @@ ai_edge_quantizer/transformations/quantize_tensor_test.py,sha256=mHLO3_MRt36A8-Z
59
59
  ai_edge_quantizer/transformations/transformation_utils.py,sha256=efJdAkA24wlg6Vj5NFO7_7MDuvQLSNn-l11Vs_JPktI,7123
60
60
  ai_edge_quantizer/transformations/transformation_utils_test.py,sha256=MWgq29t7rvxRQIfi4ny9IoODFCTcbpjnIwoCL40zDKk,8698
61
61
  ai_edge_quantizer/utils/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
62
- ai_edge_quantizer/utils/calibration_utils.py,sha256=e3dG7Nm94Ix0hkTWTWPUhEG6a8QR_cAM3PSwblfJV5g,15106
62
+ ai_edge_quantizer/utils/calibration_utils.py,sha256=iMf_bSCf-O86MzDt5D9hLKqbTydqLwirluaC6BJ9yHo,11553
63
63
  ai_edge_quantizer/utils/calibration_utils_test.py,sha256=4BlksXl7b4yptL8xPR67hmJCnjhN9V10a2PunzfHrUE,9372
64
+ ai_edge_quantizer/utils/constrained_ops_utils.py,sha256=EAITCf7Ku_PFZcw3K-wd-8hGbyuRd5W5UtNdGvalwAE,4478
65
+ ai_edge_quantizer/utils/constrained_ops_utils_test.py,sha256=6k_AqfB-NmiLkW5WwEV5NSuswFWky2sL0xBGmV6Fdwk,1756
64
66
  ai_edge_quantizer/utils/test_utils.py,sha256=a4Nk-wbeB09dFjTDZiA0K67d26j5DD0UDH_GIVmVG_4,8685
65
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=RL6oq6FzZj-xV0Zgh0UBn7-fOQaRXSxZ-PPG_LmtyUY,11384
67
+ ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=aNtL4dpWH5uGGGlaygnMDkh5llTstbgs5ZxO0JkH5VQ,11718
66
68
  ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py,sha256=K1SbK8q92qYVtiVj0I0GtugsPTkpIpEKv9zakvFV_Sc,8555
67
69
  ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=EoVjI_hplX_Rml3hfRsGmQOihexmizeJqt4SQcET9aA,14925
68
70
  ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=6fjkM-rycZ95L4yfvlr0TN6RlrhfPzxNUYrZaYO_F0A,12013
69
71
  ai_edge_quantizer/utils/validation_utils.py,sha256=oYw33Sg547AqtGw-choPUJmp9SAKkV46J_ddqSsum2Q,3950
70
72
  ai_edge_quantizer/utils/validation_utils_test.py,sha256=V_qNDikPD4OPB-siOLQCWNVWTAu87h2IgNYt7teFd-o,2934
71
- ai_edge_quantizer_nightly-0.4.0.dev20250828.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
72
- ai_edge_quantizer_nightly-0.4.0.dev20250828.dist-info/METADATA,sha256=08u7TZ16Y_SHk1eBL3q6pxP4U79rQEXyXbqUEBWuXFo,1535
73
- ai_edge_quantizer_nightly-0.4.0.dev20250828.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
74
- ai_edge_quantizer_nightly-0.4.0.dev20250828.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
75
- ai_edge_quantizer_nightly-0.4.0.dev20250828.dist-info/RECORD,,
73
+ ai_edge_quantizer_nightly-0.4.0.dev20250830.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
74
+ ai_edge_quantizer_nightly-0.4.0.dev20250830.dist-info/METADATA,sha256=B1PQw7561EH7WCLLBTwYIoXP8WH-0vJeTyPuQPAiX_M,1535
75
+ ai_edge_quantizer_nightly-0.4.0.dev20250830.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
76
+ ai_edge_quantizer_nightly-0.4.0.dev20250830.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
77
+ ai_edge_quantizer_nightly-0.4.0.dev20250830.dist-info/RECORD,,