ai-edge-quantizer-nightly 0.4.0.dev20250828__py3-none-any.whl → 0.4.0.dev20250830__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/default_policy.py +5 -1
- ai_edge_quantizer/transformation_instruction_generator.py +102 -0
- ai_edge_quantizer/transformation_instruction_generator_test.py +163 -0
- ai_edge_quantizer/utils/calibration_utils.py +6 -85
- ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
- ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +9 -0
- {ai_edge_quantizer_nightly-0.4.0.dev20250828.dist-info → ai_edge_quantizer_nightly-0.4.0.dev20250830.dist-info}/METADATA +1 -1
- {ai_edge_quantizer_nightly-0.4.0.dev20250828.dist-info → ai_edge_quantizer_nightly-0.4.0.dev20250830.dist-info}/RECORD +12 -10
- {ai_edge_quantizer_nightly-0.4.0.dev20250828.dist-info → ai_edge_quantizer_nightly-0.4.0.dev20250830.dist-info}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.4.0.dev20250828.dist-info → ai_edge_quantizer_nightly-0.4.0.dev20250830.dist-info}/WHEEL +0 -0
- {ai_edge_quantizer_nightly-0.4.0.dev20250828.dist-info → ai_edge_quantizer_nightly-0.4.0.dev20250830.dist-info}/top_level.txt +0 -0
@@ -266,7 +266,11 @@ DEFAULT_JSON_POLICY = """
|
|
266
266
|
}
|
267
267
|
}
|
268
268
|
"""
|
269
|
-
QUANTIZABLE_COMPOSITES = [
|
269
|
+
QUANTIZABLE_COMPOSITES = [
|
270
|
+
"od" + "ml.npu_call",
|
271
|
+
"od" + "ml.rms_norm",
|
272
|
+
"od" + "ml.l2_norm",
|
273
|
+
]
|
270
274
|
|
271
275
|
|
272
276
|
def _unroll_json_config(
|
@@ -23,10 +23,13 @@ from collections.abc import Iterator
|
|
23
23
|
import dataclasses
|
24
24
|
from typing import Optional
|
25
25
|
from ai_edge_quantizer import qtyping
|
26
|
+
from ai_edge_quantizer.algorithms.utils import common_utils
|
27
|
+
from ai_edge_quantizer.utils import constrained_ops_utils
|
26
28
|
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
|
27
29
|
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
|
28
30
|
|
29
31
|
|
32
|
+
_OpQuantConstraint = common_utils.OpQuantConstraint
|
30
33
|
_QuantTransformation = qtyping.QuantTransformation
|
31
34
|
|
32
35
|
|
@@ -165,6 +168,16 @@ class TransformationInstructionsGenerator:
|
|
165
168
|
else:
|
166
169
|
self.flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
|
167
170
|
self._create_tensor_name_to_graph_info_map()
|
171
|
+
self._same_as_input_scale_ops = (
|
172
|
+
constrained_ops_utils.get_constrained_op_list(
|
173
|
+
_OpQuantConstraint.SAME_AS_INPUT_SCALE
|
174
|
+
)
|
175
|
+
)
|
176
|
+
self._same_as_output_scale_ops = (
|
177
|
+
constrained_ops_utils.get_constrained_op_list(
|
178
|
+
_OpQuantConstraint.SAME_AS_OUTPUT_SCALE
|
179
|
+
)
|
180
|
+
)
|
168
181
|
|
169
182
|
@dataclasses.dataclass(frozen=True)
|
170
183
|
class TensorGraphInfo:
|
@@ -506,6 +519,89 @@ class TransformationInstructionsGenerator:
|
|
506
519
|
):
|
507
520
|
instructions.pop(i)
|
508
521
|
|
522
|
+
def _is_valid_quantize_requantize_pair(
    self,
    instr_0: qtyping.TransformationInst,
    instr_1: qtyping.TransformationInst,
) -> bool:
  """Tells whether two instructions are a QUANTIZE_TENSOR + ADD_QUANTIZE pair.

  Args:
    instr_0: Candidate quantize instruction; must be QUANTIZE_TENSOR.
    instr_1: Candidate requantize instruction; must be ADD_QUANTIZE.

  Returns:
    True when the transformations match the expected kinds and both
    instructions target the same consumers.
  """
  if instr_0.transformation != _QuantTransformation.QUANTIZE_TENSOR:
    return False
  if instr_1.transformation != _QuantTransformation.ADD_QUANTIZE:
    return False
  # Fusion is only sound when the requantization applies to exactly the
  # same consumers as the original quantization.
  return instr_0.consumers == instr_1.consumers
|
533
|
+
|
534
|
+
def _is_op_constrained(
    self, subgraph_id: int, op_index: int
) -> bool:
  """Reports whether an op carries a scale-sharing quantization constraint.

  Args:
    subgraph_id: Index of the subgraph holding the op.
    op_index: Index of the op inside that subgraph.

  Returns:
    True if the op's name appears in the same-as-input-scale list or the
    same-as-output-scale list built at construction time.
  """
  name = tfl_flatbuffer_utils.get_op_name_by_index(
      self.flatbuffer_model, subgraph_id, op_index
  )
  constrained_lists = (
      self._same_as_input_scale_ops,
      self._same_as_output_scale_ops,
  )
  return any(name in op_list for op_list in constrained_lists)
|
545
|
+
|
546
|
+
def _are_quant_params_compatible(
    self,
    params_0: qtyping.UniformQuantParams,
    params_1: qtyping.UniformQuantParams,
) -> bool:
  """Checks if quant params are the same except for the scale and zero point.

  Every dataclass field of `qtyping.UniformQuantParams` other than `scale`
  and `zero_point` must compare equal. Array-valued fields are compared with
  `np.array_equal`. A plain `!=` on multi-element arrays yields an array
  whose truth value is ambiguous, so `if` on it would raise ValueError.

  Args:
    params_0: Quantization parameters of the first instruction.
    params_1: Quantization parameters of the second instruction.

  Returns:
    True if all fields other than `scale` and `zero_point` match.
  """
  import numpy as np  # pylint: disable=g-import-not-at-top

  ignore_set = {"scale", "zero_point"}
  for field_info in dataclasses.fields(qtyping.UniformQuantParams):
    field_name = field_info.name
    if field_name in ignore_set:
      continue
    value_0 = getattr(params_0, field_name)
    value_1 = getattr(params_1, field_name)
    if isinstance(value_0, np.ndarray) or isinstance(value_1, np.ndarray):
      # array_equal also handles array-vs-None (returns False).
      if not np.array_equal(value_0, value_1):
        return False
    elif value_0 != value_1:
      return False
  return True
|
560
|
+
|
561
|
+
def _eliminate_requantization_for_nonconstrained_provider(
    self, tensor_trans_insts: qtyping.TensorTransformationInsts
) -> None:
  """Removes requantization for tensors with a non-constrained provider.

  When a tensor's instructions are exactly [QUANTIZE_TENSOR, ADD_QUANTIZE],
  both target the same consumers, their quant params differ only in scale
  and zero point, and the tensor's provider op has neither a same-as-input
  nor a same-as-output scale constraint, the two are fused. The scale and
  zero point of the ADD_QUANTIZE instruction are copied into the
  QUANTIZE_TENSOR instruction, and ADD_QUANTIZE is removed. The instruction
  list is modified in place.

  Args:
    tensor_trans_insts: Transformation instructions for a tensor.
  """
  insts = tensor_trans_insts.instructions
  if insts is None or len(insts) != 2:
    return

  quantize_inst, requantize_inst = insts
  quantize_params = quantize_inst.parameters
  requantize_params = requantize_inst.parameters
  producer = quantize_inst.producer

  if not isinstance(quantize_params, qtyping.UniformQuantParams):
    return
  if not isinstance(requantize_params, qtyping.UniformQuantParams):
    return
  if not self._is_valid_quantize_requantize_pair(
      quantize_inst, requantize_inst
  ):
    return
  if not self._are_quant_params_compatible(
      quantize_params, requantize_params
  ):
    return
  # Tensors with no producer are skipped so that subgraph inputs connected
  # to the main subgraph (e.g. while_body) are never fused.
  if producer == -1:
    return
  # Changing the output scale of a constrained producer would break its
  # constraint.
  if self._is_op_constrained(tensor_trans_insts.subgraph_id, producer):
    return

  # Fold the requantization parameters into the quantize instruction.
  quantize_inst.parameters = dataclasses.replace(
      quantize_params,
      scale=requantize_params.scale,
      zero_point=requantize_params.zero_point,
  )
  # Drop the now-redundant ADD_QUANTIZE instruction.
  del insts[1]
|
604
|
+
|
509
605
|
def _quant_params_to_transformation_insts(
|
510
606
|
self,
|
511
607
|
param: qtyping.TensorTransformationParams,
|
@@ -578,6 +674,12 @@ class TransformationInstructionsGenerator:
|
|
578
674
|
# will raise an error if the instructions are not valid.
|
579
675
|
self._check_tensor_transformation_instructions_valid(tensor_trans_insts)
|
580
676
|
|
677
|
+
# Remove unnecessary [QUANTIZE_TENSOR, ADD_QUANTIZE] pairs for tensors with
|
678
|
+
# providers without same as input/output scale constraints.
|
679
|
+
self._eliminate_requantization_for_nonconstrained_provider(
|
680
|
+
tensor_trans_insts
|
681
|
+
)
|
682
|
+
|
581
683
|
return tensor_trans_insts
|
582
684
|
|
583
685
|
def _split_instructions_by_tensor_duplication(
|
@@ -15,7 +15,9 @@
|
|
15
15
|
|
16
16
|
"""Tests for instruction_generator."""
|
17
17
|
|
18
|
+
from collections.abc import Sequence
|
18
19
|
import os
|
20
|
+
from typing import Optional
|
19
21
|
|
20
22
|
import numpy as np
|
21
23
|
|
@@ -1337,5 +1339,166 @@ class InstructionGeneratorTest(parameterized.TestCase):
|
|
1337
1339
|
)
|
1338
1340
|
|
1339
1341
|
|
1342
|
+
class EliminateUnnecessaryRequantizationTest(parameterized.TestCase):
  """Tests for `_eliminate_requantization_for_nonconstrained_provider`."""

  def setUp(self):
    super().setUp()
    self.ins_gen = instruction_generator.TransformationInstructionsGenerator(
        os.path.join(TEST_DATA_PREFIX_PATH, "tests/models/conv_fc_mnist.tflite")
    )

  def _get_test_instruction(
      self,
      transformation: qtyping.QuantTransformation,
      producer: int = -1,
      consumers: Optional[Sequence[int]] = None,
      qparams: Optional[qtyping.UniformQuantParams] = None,
  ) -> qtyping.TransformationInst:
    """Builds one instruction; defaults to 8-bit params, scale 1, zp 0."""
    if consumers is None:
      consumers = []
    if qparams is None:
      qparams = qtyping.UniformQuantParams(
          num_bits=8,
          quantized_dimension=None,
          scale=np.array([1]),
          zero_point=np.array([0]),
      )
    return qtyping.TransformationInst(
        transformation=transformation,
        producer=producer,
        consumers=consumers,
        parameters=qparams,
        # Dummy values below.
        tensor_id=0,
    )

  def _create_test_insts(
      self, instructions: list[qtyping.TransformationInst]
  ) -> qtyping.TensorTransformationInsts:
    """Wraps `instructions` as the instructions of a tensor in subgraph 0."""
    return qtyping.TensorTransformationInsts(
        tensor_name="test_tensor", subgraph_id=0, instructions=instructions
    )

  def test_no_fusion_when_too_few_instructions(self):
    tensor_insts = self._create_test_insts([
        self._get_test_instruction(_QTransf.QUANTIZE_TENSOR),
    ])
    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
        tensor_insts
    )
    self.assertLen(tensor_insts.instructions, 1)

  def test_no_fusion_when_too_many_instructions(self):
    tensor_insts = self._create_test_insts([
        self._get_test_instruction(_QTransf.QUANTIZE_TENSOR),
        self._get_test_instruction(_QTransf.ADD_QUANTIZE),
        self._get_test_instruction(_QTransf.ADD_DEQUANTIZE),
    ])
    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
        tensor_insts
    )
    self.assertLen(tensor_insts.instructions, 3)

  def test_no_fusion_when_invalid_transformation_pair(self):
    tensor_insts = self._create_test_insts([
        self._get_test_instruction(_QTransf.ADD_DEQUANTIZE),
        self._get_test_instruction(_QTransf.ADD_QUANTIZE),
    ])
    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
        tensor_insts
    )
    self.assertLen(tensor_insts.instructions, 2)

  def test_no_fusion_when_consumers_mismatch(self):
    tensor_insts = self._create_test_insts([
        self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, consumers=[0]),
        self._get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[1]),
    ])
    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
        tensor_insts
    )
    self.assertLen(tensor_insts.instructions, 2)

  def test_no_fusion_when_no_producer(self):
    producer = -1
    tensor_insts = self._create_test_insts([
        self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, producer),
        self._get_test_instruction(_QTransf.ADD_QUANTIZE, producer),
    ])
    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
        tensor_insts
    )
    self.assertLen(tensor_insts.instructions, 2)

  def test_no_fusion_when_quant_params_are_incompatible(self):
    params_8_bits = qtyping.UniformQuantParams(
        8, None, np.array([1]), np.array([0])
    )
    params_16_bits = qtyping.UniformQuantParams(
        16, None, np.array([1]), np.array([0])
    )
    tensor_insts = self._create_test_insts([
        self._get_test_instruction(
            _QTransf.QUANTIZE_TENSOR, qparams=params_8_bits
        ),
        self._get_test_instruction(
            _QTransf.ADD_QUANTIZE, qparams=params_16_bits
        ),
    ])
    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
        tensor_insts
    )
    self.assertLen(tensor_insts.instructions, 2)

  def test_no_fusion_when_producer_constrained(self):
    # Reshape op (op index 2) has same as input scale constraint.
    tensor_insts = self._create_test_insts([
        self._get_test_instruction(_QTransf.QUANTIZE_TENSOR, producer=2),
        self._get_test_instruction(_QTransf.ADD_QUANTIZE, producer=2),
    ])
    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
        tensor_insts
    )
    self.assertLen(tensor_insts.instructions, 2)

  def test_fusion_succeeds(self):
    producer = 0
    consumers = [1]
    params_0 = qtyping.UniformQuantParams(
        num_bits=8,
        quantized_dimension=None,
        scale=np.array([1]),
        zero_point=np.array([0]),
    )
    params_1 = qtyping.UniformQuantParams(
        num_bits=8,
        quantized_dimension=None,
        scale=np.array([2]),
        zero_point=np.array([1]),
    )
    inst_0 = self._get_test_instruction(
        _QTransf.QUANTIZE_TENSOR, producer, consumers, params_0
    )
    inst_1 = self._get_test_instruction(
        _QTransf.ADD_QUANTIZE, producer, consumers, params_1
    )
    tensor_insts = self._create_test_insts([inst_0, inst_1])
    self.ins_gen._eliminate_requantization_for_nonconstrained_provider(
        tensor_insts
    )

    self.assertLen(tensor_insts.instructions, 1)
    result_inst = tensor_insts.instructions[0]
    self.assertEqual(result_inst.transformation, _QTransf.QUANTIZE_TENSOR)
    self.assertEqual(result_inst.consumers, consumers)

    result_params = result_inst.parameters
    # Explicitly narrow the type for pytype.
    if not isinstance(result_params, qtyping.UniformQuantParams):
      self.fail("Fused instruction parameters are not UniformQuantParams")

    # Use numpy's array assertion: assertEqual on ndarrays only works for
    # single-element arrays (it relies on the truth value of `==`).
    np.testing.assert_array_equal(result_params.scale, params_1.scale)
    np.testing.assert_array_equal(
        result_params.zero_point, params_1.zero_point
    )
|
1501
|
+
|
1502
|
+
|
1340
1503
|
if __name__ == "__main__":
|
1341
1504
|
googletest.main()
|
@@ -20,13 +20,11 @@ from typing import Any, Union
|
|
20
20
|
|
21
21
|
import numpy as np
|
22
22
|
|
23
|
-
from ai_edge_quantizer import algorithm_manager
|
24
23
|
from ai_edge_quantizer import qtyping
|
25
|
-
from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
|
26
24
|
from ai_edge_quantizer.algorithms.utils import common_utils
|
25
|
+
from ai_edge_quantizer.utils import constrained_ops_utils
|
27
26
|
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
|
28
27
|
from ai_edge_quantizer.utils import tfl_interpreter_utils
|
29
|
-
from ai_edge_litert import schema_py_generated as schema_fb # pylint: disable=g-direct-tensorflow-import
|
30
28
|
from tensorflow.lite.tools import flatbuffer_utils # pylint: disable=g-direct-tensorflow-import
|
31
29
|
|
32
30
|
|
@@ -133,7 +131,11 @@ class CalibrationQsvAlignmentUtils:
|
|
133
131
|
"""
|
134
132
|
|
135
133
|
def __init__(self, model_path: str):
|
136
|
-
self._same_as_input_scale_ops =
|
134
|
+
self._same_as_input_scale_ops = (
|
135
|
+
constrained_ops_utils.get_constrained_op_list(
|
136
|
+
_OpQuantConstraint.SAME_AS_INPUT_SCALE
|
137
|
+
)
|
138
|
+
)
|
137
139
|
|
138
140
|
tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(model_path)
|
139
141
|
self._flatbuffer_object = tfl_flatbuffer_utils.read_model(model_path)
|
@@ -146,87 +148,6 @@ class CalibrationQsvAlignmentUtils:
|
|
146
148
|
signature_runner = tfl_interpreter.get_signature_runner(signature_key)
|
147
149
|
self._signature_runners[signature_key] = signature_runner
|
148
150
|
|
149
|
-
# Make a list of `SAME_AS_INPUT_SCALE` operators. This is used to identify
|
150
|
-
# the operators that need to be constrained to the same scale as the input.
|
151
|
-
self._build_same_as_input_scale_op_list()
|
152
|
-
|
153
|
-
def _build_same_as_input_scale_op_list(self, verbose: bool = False):
|
154
|
-
"""Constructs a list of SAME_AS_INPUT_SCALE operators.
|
155
|
-
|
156
|
-
This is achieved by invoking all materialization functions and extracting
|
157
|
-
the constraint argument, using monkey patching to redirect logic to wrapper
|
158
|
-
functions.
|
159
|
-
|
160
|
-
Args:
|
161
|
-
verbose: Flag to enable verbose output.
|
162
|
-
"""
|
163
|
-
|
164
|
-
def materialize_standard_op_wrapper(
|
165
|
-
op_info: qtyping.OpInfo,
|
166
|
-
*_args,
|
167
|
-
constraint: _OpQuantConstraint = _OpQuantConstraint.NO_CONSTRAIN,
|
168
|
-
**_kwargs,
|
169
|
-
) -> list[qtyping.TensorTransformationParams]:
|
170
|
-
if constraint == _OpQuantConstraint.SAME_AS_INPUT_SCALE:
|
171
|
-
self._same_as_input_scale_ops.append(op_info.op_name)
|
172
|
-
# Return dummy values to avoid exceptions.
|
173
|
-
dummy_value = [qtyping.TensorTransformationParams("")] * 2
|
174
|
-
return dummy_value
|
175
|
-
|
176
|
-
# Dummy implementation of the `_are_weights_too_small` function to support
|
177
|
-
# `materialize_standard_op_wrapper` above.
|
178
|
-
def are_weights_too_small_wrapper(*_args, **_kwargs) -> bool:
|
179
|
-
return False
|
180
|
-
|
181
|
-
# Dummy implementation of the `_materialize_bias_for_conv_ops` function to
|
182
|
-
# support `materialize_standard_op_wrapper` above.
|
183
|
-
def materialize_bias_for_conv_ops_wrapper(*_args, **_kwargs):
|
184
|
-
return
|
185
|
-
|
186
|
-
# Do monkey patch to intercept the `materialize_standard_op` function to
|
187
|
-
# support `materialize_standard_op_wrapper` above.
|
188
|
-
original_materialize_standard_op = common_utils.materialize_standard_op
|
189
|
-
original_are_weights_too_small = common_quantize._are_weights_too_small # pylint: disable=protected-access
|
190
|
-
original_materialize_bias_for_conv_ops = (
|
191
|
-
common_quantize._materialize_bias_for_conv_ops # pylint: disable=protected-access
|
192
|
-
)
|
193
|
-
common_utils.materialize_standard_op = materialize_standard_op_wrapper
|
194
|
-
common_quantize._are_weights_too_small = are_weights_too_small_wrapper # pylint: disable=protected-access
|
195
|
-
common_quantize._materialize_bias_for_conv_ops = ( # pylint: disable=protected-access
|
196
|
-
materialize_bias_for_conv_ops_wrapper
|
197
|
-
)
|
198
|
-
minmax_func_dict = algorithm_manager.MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT
|
199
|
-
|
200
|
-
# Loop over all available materialization functions to build up a list of
|
201
|
-
# `SAME_AS_INPUT_SCALE` constrained ops.
|
202
|
-
for op, materialize_fn in minmax_func_dict.items():
|
203
|
-
# Create a dummy op info to trigger the materialization.
|
204
|
-
mock_op = schema_fb.OperatorT()
|
205
|
-
mock_op.inputs = [0]
|
206
|
-
mock_op.outputs = [0]
|
207
|
-
op_info = qtyping.OpInfo(
|
208
|
-
op=mock_op,
|
209
|
-
op_name=op,
|
210
|
-
subgraph_op_index=0,
|
211
|
-
op_quant_config=qtyping.OpQuantizationConfig(),
|
212
|
-
)
|
213
|
-
materialize_fn(
|
214
|
-
get_tensor_quant_params_fn=None,
|
215
|
-
op_info=op_info,
|
216
|
-
graph_info=None,
|
217
|
-
tensor_name_to_qsv=None,
|
218
|
-
)
|
219
|
-
|
220
|
-
if verbose:
|
221
|
-
print(f" Constrained op list: {self._same_as_input_scale_ops}")
|
222
|
-
|
223
|
-
# Restore the original functions.
|
224
|
-
common_utils.materialize_standard_op = original_materialize_standard_op
|
225
|
-
common_quantize._are_weights_too_small = original_are_weights_too_small # pylint: disable=protected-access
|
226
|
-
common_quantize._materialize_bias_for_conv_ops = ( # pylint: disable=protected-access
|
227
|
-
original_materialize_bias_for_conv_ops
|
228
|
-
)
|
229
|
-
|
230
151
|
def _search_tensor_by_signature_name(
|
231
152
|
self, signature_key: str, signature_input_output_name: str, verbose=False
|
232
153
|
) -> list[str]:
|
@@ -0,0 +1,111 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Utils for handling operators with quantization constraints."""
|
17
|
+
|
18
|
+
from ai_edge_quantizer import algorithm_manager
|
19
|
+
from ai_edge_quantizer import qtyping
|
20
|
+
from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
|
21
|
+
from ai_edge_quantizer.algorithms.utils import common_utils
|
22
|
+
from ai_edge_litert import schema_py_generated as schema_fb # pylint: disable=g-direct-tensorflow-import
|
23
|
+
|
24
|
+
|
25
|
+
_OpQuantConstraint = common_utils.OpQuantConstraint
|
26
|
+
|
27
|
+
|
28
|
+
def get_constrained_op_list(
    quant_constraint: _OpQuantConstraint, verbose: bool = False
) -> list[str]:
  """Constructs and returns a list of constrained operators.

  Every function in `algorithm_manager.MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT`
  is invoked with a dummy op. During the sweep, three module-level functions
  are temporarily monkey patched so that the `constraint` argument passed to
  `materialize_standard_op` is recorded instead of doing real work:
  `common_utils.materialize_standard_op`,
  `common_quantize._are_weights_too_small` and
  `common_quantize._materialize_bias_for_conv_ops`. The originals are
  restored in a `finally` clause, so module state is not left patched even if
  a materialization function raises.

  NOTE: Because module-level functions are patched, this is not safe to call
  concurrently with other code that uses those functions.

  Args:
    quant_constraint: The quantization constraint to filter operators by.
    verbose: Flag to enable verbose output.

  Returns:
    A list containing operators with the specified constraint.
  """
  constrained_ops = []

  def materialize_standard_op_wrapper(
      op_info: qtyping.OpInfo,
      *_args,
      constraint: _OpQuantConstraint = _OpQuantConstraint.NO_CONSTRAIN,
      **_kwargs,
  ) -> list[qtyping.TensorTransformationParams]:
    if constraint == quant_constraint:
      constrained_ops.append(op_info.op_name)
    # Return dummy values to avoid exceptions.
    dummy_value = [qtyping.TensorTransformationParams("")] * 2
    return dummy_value

  # Dummy implementation of the `_are_weights_too_small` function to support
  # `materialize_standard_op_wrapper` above.
  def are_weights_too_small_wrapper(*_args, **_kwargs) -> bool:
    return False

  # Dummy implementation of the `_materialize_bias_for_conv_ops` function to
  # support `materialize_standard_op_wrapper` above.
  def materialize_bias_for_conv_ops_wrapper(*_args, **_kwargs):
    return

  # Save the originals before patching so they can always be restored.
  original_materialize_standard_op = common_utils.materialize_standard_op
  original_are_weights_too_small = common_quantize._are_weights_too_small  # pylint: disable=protected-access
  original_materialize_bias_for_conv_ops = (
      common_quantize._materialize_bias_for_conv_ops  # pylint: disable=protected-access
  )
  try:
    # Do monkey patch to intercept the `materialize_standard_op` function to
    # support `materialize_standard_op_wrapper` above.
    common_utils.materialize_standard_op = materialize_standard_op_wrapper
    common_quantize._are_weights_too_small = are_weights_too_small_wrapper  # pylint: disable=protected-access
    common_quantize._materialize_bias_for_conv_ops = (  # pylint: disable=protected-access
        materialize_bias_for_conv_ops_wrapper
    )
    minmax_func_dict = algorithm_manager.MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT

    # Loop over all available materialization functions to build up a list of
    # ops with the given constraint.
    for op, materialize_fn in minmax_func_dict.items():
      # Create a dummy op info to trigger the materialization.
      mock_op = schema_fb.OperatorT()
      mock_op.inputs = [0]
      mock_op.outputs = [0]
      op_info = qtyping.OpInfo(
          op=mock_op,
          op_name=op,
          subgraph_op_index=0,
          op_quant_config=qtyping.OpQuantizationConfig(),
      )
      materialize_fn(
          get_tensor_quant_params_fn=None,
          op_info=op_info,
          graph_info=None,
          tensor_name_to_qsv=None,
      )
  finally:
    # Restore the original functions, even if a materialization raised.
    common_utils.materialize_standard_op = original_materialize_standard_op
    common_quantize._are_weights_too_small = original_are_weights_too_small  # pylint: disable=protected-access
    common_quantize._materialize_bias_for_conv_ops = (  # pylint: disable=protected-access
        original_materialize_bias_for_conv_ops
    )

  if verbose:
    print(f" {quant_constraint} op list: {constrained_ops}")

  return constrained_ops
|
@@ -0,0 +1,50 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from tensorflow.python.platform import googletest
|
17
|
+
from absl.testing import parameterized
|
18
|
+
from ai_edge_quantizer.algorithms.utils import common_utils
|
19
|
+
from ai_edge_quantizer.utils import constrained_ops_utils
|
20
|
+
|
21
|
+
|
22
|
+
_OpQuantConstraint = common_utils.OpQuantConstraint
|
23
|
+
|
24
|
+
|
25
|
+
class ConstrainedOpsUtilsTest(parameterized.TestCase):
  """Tests for `constrained_ops_utils.get_constrained_op_list`."""

  @parameterized.named_parameters(
      dict(
          testcase_name="same_as_input_scale",
          constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
          expected_num_ops=14,
      ),
      dict(
          testcase_name="same_as_output_scale",
          constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
          expected_num_ops=6,
      ),
      dict(
          testcase_name="no_constrain",
          constraint=_OpQuantConstraint.NO_CONSTRAIN,
          expected_num_ops=22,
      ),
  )
  def test_get_constrained_op_list(self, constraint, expected_num_ops):
    constrained_ops = constrained_ops_utils.get_constrained_op_list(constraint)
    self.assertLen(constrained_ops, expected_num_ops)

  def test_get_constrained_op_list_restores_patched_function(self):
    # The sweep monkey patches `common_utils.materialize_standard_op`; it must
    # be the original function again once the call returns.
    original_fn = common_utils.materialize_standard_op
    constrained_ops_utils.get_constrained_op_list(
        _OpQuantConstraint.SAME_AS_INPUT_SCALE
    )
    self.assertIs(common_utils.materialize_standard_op, original_fn)
|
47
|
+
|
48
|
+
|
49
|
+
if __name__ == "__main__":
|
50
|
+
googletest.main()
|
@@ -342,3 +342,12 @@ def get_op_side_effect_subgraphs(
|
|
342
342
|
return [opts.decompositionSubgraphIndex]
|
343
343
|
# Can add other nested ops here (control flow ops, etc).
|
344
344
|
return []
|
345
|
+
|
346
|
+
|
347
|
+
def get_op_name_by_index(
    flatbuffer_model: Any, subgraph_id: int, op_index: int
) -> str:
  """Get the op name from the flatbuffer model.

  Args:
    flatbuffer_model: The flatbuffer model object (as returned by
      `read_model`).
    subgraph_id: Index of the subgraph containing the op.
    op_index: Index of the op within that subgraph.

  Returns:
    The name mapped to the op's builtin code in `TFL_OP_CODE_TO_NAME`.
  """
  op = flatbuffer_model.subgraphs[subgraph_id].operators[op_index]
  op_code = flatbuffer_model.operatorCodes[op.opcodeIndex]
  # Models written with an older schema populate only the int8
  # `deprecatedBuiltinCode` field and leave `builtinCode` at 0. Newer writers
  # set both, capping the deprecated field for codes that do not fit in int8.
  # Taking the max resolves the code in both cases, like TFLite's
  # GetBuiltinCode.
  builtin_code = max(op_code.builtinCode, op_code.deprecatedBuiltinCode)
  return TFL_OP_CODE_TO_NAME[builtin_code]
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-quantizer-nightly
|
3
|
-
Version: 0.4.0.dev20250828
|
3
|
+
Version: 0.4.0.dev20250830
|
4
4
|
Summary: A quantizer for advanced developers to quantize converted AI Edge models.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
|
@@ -5,7 +5,7 @@ ai_edge_quantizer/algorithm_manager_api_test.py,sha256=w6bSONvXkX6bzXAGc0-7b6gND
|
|
5
5
|
ai_edge_quantizer/calibrator.py,sha256=Sms7_AIHPH9G5xFaz5Ef3a5gPhxuIWQI8d2LUM8C96I,12071
|
6
6
|
ai_edge_quantizer/calibrator_test.py,sha256=ZLzIMWB2FSFU4TOatDioYuwp_kLh8iSCefZ5_Q9FU7s,11900
|
7
7
|
ai_edge_quantizer/conftest.py,sha256=SxCz-5LlRD_lQm4hQc4c6IGG7DS8d7IyEWY9gnscPN0,794
|
8
|
-
ai_edge_quantizer/default_policy.py,sha256=
|
8
|
+
ai_edge_quantizer/default_policy.py,sha256=G_JZtZaQAnrWyfCusDWXwO27iLysk27RS91GlS61m_Q,11592
|
9
9
|
ai_edge_quantizer/model_modifier.py,sha256=teGa8I6kGvn6TQY6Xv53YFIc_pQEhNvM9Zb4bvhezyw,7110
|
10
10
|
ai_edge_quantizer/model_modifier_test.py,sha256=cJd04SLOG-fQZZNZPcisoBLx3cLtWEwGqUBbLb-pif4,4751
|
11
11
|
ai_edge_quantizer/model_validator.py,sha256=Hj0_5o-Oa3dSlJ3ryVjRhvsyelHNyek1GrtG9buMczg,13153
|
@@ -19,8 +19,8 @@ ai_edge_quantizer/recipe.py,sha256=MEkfQ2Sg3KAE9LAORHWcbjYNPg06EUbwc1d-VspQA2U,6
|
|
19
19
|
ai_edge_quantizer/recipe_manager.py,sha256=6dgbE-IZfEetzXH3p3Qm_9eQutNDOpZnMpiaLTbP-ZQ,14744
|
20
20
|
ai_edge_quantizer/recipe_manager_test.py,sha256=H-B75vwPN5ND-nUa3pOXizeHTv4mufPiC5cL_OlDIYU,34040
|
21
21
|
ai_edge_quantizer/recipe_test.py,sha256=GKuo6N65wKLS2xwSpjd-BWWeVRpF1zc7Yt7phSMYSxA,5905
|
22
|
-
ai_edge_quantizer/transformation_instruction_generator.py,sha256=
|
23
|
-
ai_edge_quantizer/transformation_instruction_generator_test.py,sha256=
|
22
|
+
ai_edge_quantizer/transformation_instruction_generator.py,sha256=O0U2aZcB8aXQgOV8r9g1rGNzDUiuI5Ta53XnxZbVffE,31576
|
23
|
+
ai_edge_quantizer/transformation_instruction_generator_test.py,sha256=KW5-WoTTo9IqLEVnWxVC8ut8eWLi_91xfKgGqVQ9QDk,54635
|
24
24
|
ai_edge_quantizer/transformation_performer.py,sha256=o4J6OUbI0dLoobVYjkOFw5Po3yH0gZJXrfuTIYais4o,13029
|
25
25
|
ai_edge_quantizer/transformation_performer_test.py,sha256=xk6A3LStCyPclN51--9uO7XjSxNfZmpdfvrzOL0maNM,20349
|
26
26
|
ai_edge_quantizer/algorithms/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
|
@@ -59,17 +59,19 @@ ai_edge_quantizer/transformations/quantize_tensor_test.py,sha256=mHLO3_MRt36A8-Z
|
|
59
59
|
ai_edge_quantizer/transformations/transformation_utils.py,sha256=efJdAkA24wlg6Vj5NFO7_7MDuvQLSNn-l11Vs_JPktI,7123
|
60
60
|
ai_edge_quantizer/transformations/transformation_utils_test.py,sha256=MWgq29t7rvxRQIfi4ny9IoODFCTcbpjnIwoCL40zDKk,8698
|
61
61
|
ai_edge_quantizer/utils/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
|
62
|
-
ai_edge_quantizer/utils/calibration_utils.py,sha256=
|
62
|
+
ai_edge_quantizer/utils/calibration_utils.py,sha256=iMf_bSCf-O86MzDt5D9hLKqbTydqLwirluaC6BJ9yHo,11553
|
63
63
|
ai_edge_quantizer/utils/calibration_utils_test.py,sha256=4BlksXl7b4yptL8xPR67hmJCnjhN9V10a2PunzfHrUE,9372
|
64
|
+
ai_edge_quantizer/utils/constrained_ops_utils.py,sha256=EAITCf7Ku_PFZcw3K-wd-8hGbyuRd5W5UtNdGvalwAE,4478
|
65
|
+
ai_edge_quantizer/utils/constrained_ops_utils_test.py,sha256=6k_AqfB-NmiLkW5WwEV5NSuswFWky2sL0xBGmV6Fdwk,1756
|
64
66
|
ai_edge_quantizer/utils/test_utils.py,sha256=a4Nk-wbeB09dFjTDZiA0K67d26j5DD0UDH_GIVmVG_4,8685
|
65
|
-
ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=
|
67
|
+
ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=aNtL4dpWH5uGGGlaygnMDkh5llTstbgs5ZxO0JkH5VQ,11718
|
66
68
|
ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py,sha256=K1SbK8q92qYVtiVj0I0GtugsPTkpIpEKv9zakvFV_Sc,8555
|
67
69
|
ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=EoVjI_hplX_Rml3hfRsGmQOihexmizeJqt4SQcET9aA,14925
|
68
70
|
ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=6fjkM-rycZ95L4yfvlr0TN6RlrhfPzxNUYrZaYO_F0A,12013
|
69
71
|
ai_edge_quantizer/utils/validation_utils.py,sha256=oYw33Sg547AqtGw-choPUJmp9SAKkV46J_ddqSsum2Q,3950
|
70
72
|
ai_edge_quantizer/utils/validation_utils_test.py,sha256=V_qNDikPD4OPB-siOLQCWNVWTAu87h2IgNYt7teFd-o,2934
|
71
|
-
ai_edge_quantizer_nightly-0.4.0.
|
72
|
-
ai_edge_quantizer_nightly-0.4.0.
|
73
|
-
ai_edge_quantizer_nightly-0.4.0.
|
74
|
-
ai_edge_quantizer_nightly-0.4.0.
|
75
|
-
ai_edge_quantizer_nightly-0.4.0.
|
73
|
+
ai_edge_quantizer_nightly-0.4.0.dev20250830.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
74
|
+
ai_edge_quantizer_nightly-0.4.0.dev20250830.dist-info/METADATA,sha256=B1PQw7561EH7WCLLBTwYIoXP8WH-0vJeTyPuQPAiX_M,1535
|
75
|
+
ai_edge_quantizer_nightly-0.4.0.dev20250830.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
76
|
+
ai_edge_quantizer_nightly-0.4.0.dev20250830.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
|
77
|
+
ai_edge_quantizer_nightly-0.4.0.dev20250830.dist-info/RECORD,,
|
File without changes
|
File without changes
|