ai-edge-quantizer-nightly 0.3.0.dev20250609__py3-none-any.whl → 0.3.0.dev20250611__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +8 -0
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +31 -0
- ai_edge_quantizer/default_policy.py +5 -2
- ai_edge_quantizer/qtyping.py +2 -0
- ai_edge_quantizer/utils/test_utils.py +31 -1
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +2 -0
- {ai_edge_quantizer_nightly-0.3.0.dev20250609.dist-info → ai_edge_quantizer_nightly-0.3.0.dev20250611.dist-info}/METADATA +1 -1
- {ai_edge_quantizer_nightly-0.3.0.dev20250609.dist-info → ai_edge_quantizer_nightly-0.3.0.dev20250611.dist-info}/RECORD +11 -11
- {ai_edge_quantizer_nightly-0.3.0.dev20250609.dist-info → ai_edge_quantizer_nightly-0.3.0.dev20250611.dist-info}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.3.0.dev20250609.dist-info → ai_edge_quantizer_nightly-0.3.0.dev20250611.dist-info}/WHEEL +0 -0
- {ai_edge_quantizer_nightly-0.3.0.dev20250609.dist-info → ai_edge_quantizer_nightly-0.3.0.dev20250611.dist-info}/top_level.txt +0 -0
@@ -108,6 +108,10 @@ MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT = {
|
|
108
108
|
),
|
109
109
|
_TFLOpName.STABLEHLO_COMPOSITE: common_quantize.materialize_composite,
|
110
110
|
_TFLOpName.PAD: common_quantize.materialize_pad,
|
111
|
+
_TFLOpName.SQUARED_DIFFERENCE: (
|
112
|
+
common_quantize.materialize_squared_difference
|
113
|
+
),
|
114
|
+
_TFLOpName.MAX_POOL_2D: common_quantize.materialize_max_pool_2d,
|
111
115
|
}
|
112
116
|
for op_name, materialize_func in MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT.items():
|
113
117
|
register_quantized_op(
|
@@ -242,6 +246,10 @@ _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
|
|
242
246
|
),
|
243
247
|
_TFLOpName.STABLEHLO_COMPOSITE: common_quantize.materialize_composite,
|
244
248
|
_TFLOpName.PAD: common_quantize.materialize_pad,
|
249
|
+
_TFLOpName.SQUARED_DIFFERENCE: (
|
250
|
+
common_quantize.materialize_squared_difference
|
251
|
+
),
|
252
|
+
_TFLOpName.MAX_POOL_2D: common_quantize.materialize_max_pool_2d,
|
245
253
|
})
|
246
254
|
|
247
255
|
for op_name, materialize_func in _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT.items():
|
@@ -697,6 +697,37 @@ def materialize_pad(
|
|
697
697
|
)
|
698
698
|
|
699
699
|
|
700
|
+
def materialize_squared_difference(
|
701
|
+
get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
|
702
|
+
op_info: qtyping.OpInfo,
|
703
|
+
graph_info: qtyping.GraphInfo,
|
704
|
+
tensor_name_to_qsv: dict[str, Any],
|
705
|
+
) -> list[qtyping.TensorTransformationParams]:
|
706
|
+
"""Materialize tensors in tfl.squared_difference."""
|
707
|
+
return common_utils.materialize_standard_op(
|
708
|
+
op_info,
|
709
|
+
graph_info,
|
710
|
+
tensor_name_to_qsv,
|
711
|
+
get_tensor_quant_params_fn,
|
712
|
+
)
|
713
|
+
|
714
|
+
|
715
|
+
def materialize_max_pool_2d(
|
716
|
+
get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
|
717
|
+
op_info: qtyping.OpInfo,
|
718
|
+
graph_info: qtyping.GraphInfo,
|
719
|
+
tensor_name_to_qsv: dict[str, Any],
|
720
|
+
) -> list[qtyping.TensorTransformationParams]:
|
721
|
+
"""Materialize tensors in tfl.max_pool_2d."""
|
722
|
+
return common_utils.materialize_standard_op(
|
723
|
+
op_info,
|
724
|
+
graph_info,
|
725
|
+
tensor_name_to_qsv,
|
726
|
+
get_tensor_quant_params_fn,
|
727
|
+
constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
|
728
|
+
)
|
729
|
+
|
730
|
+
|
700
731
|
def _get_tensor_shape_for_blockwise(
|
701
732
|
tensor_shape: Sequence[int], quantized_dim: int, block_size: int
|
702
733
|
) -> list[int]:
|
@@ -184,7 +184,8 @@ DEFAULT_JSON_POLICY = """
|
|
184
184
|
"DYNAMIC_UPDATE_SLICE",
|
185
185
|
"SELECT_V2",
|
186
186
|
"STABLEHLO_COMPOSITE",
|
187
|
-
"PAD"
|
187
|
+
"PAD",
|
188
|
+
"MAX_POOL_2D"
|
188
189
|
],
|
189
190
|
"static_wi8_ai8": [
|
190
191
|
"ADD",
|
@@ -216,7 +217,9 @@ DEFAULT_JSON_POLICY = """
|
|
216
217
|
"DYNAMIC_UPDATE_SLICE",
|
217
218
|
"SELECT_V2",
|
218
219
|
"STABLEHLO_COMPOSITE",
|
219
|
-
"PAD"
|
220
|
+
"PAD",
|
221
|
+
"SQUARED_DIFFERENCE",
|
222
|
+
"MAX_POOL_2D"
|
220
223
|
],
|
221
224
|
"static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
|
222
225
|
"static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
|
ai_edge_quantizer/qtyping.py
CHANGED
@@ -63,6 +63,8 @@ class TFLOperationName(str, enum.Enum):
|
|
63
63
|
DYNAMIC_UPDATE_SLICE = 'DYNAMIC_UPDATE_SLICE'
|
64
64
|
STABLEHLO_COMPOSITE = 'STABLEHLO_COMPOSITE'
|
65
65
|
PAD = 'PAD'
|
66
|
+
SQUARED_DIFFERENCE = 'SQUARED_DIFFERENCE'
|
67
|
+
MAX_POOL_2D = 'MAX_POOL_2D'
|
66
68
|
|
67
69
|
|
68
70
|
class QuantizeMode(enum.Enum):
|
@@ -18,6 +18,7 @@
|
|
18
18
|
import inspect as _inspect
|
19
19
|
import os.path as _os_path
|
20
20
|
import sys as _sys
|
21
|
+
from typing import Union
|
21
22
|
|
22
23
|
from absl.testing import parameterized
|
23
24
|
|
@@ -97,6 +98,7 @@ class BaseOpTestCase(parameterized.TestCase):
|
|
97
98
|
op_name: _OpName,
|
98
99
|
op_config: _OpQuantConfig,
|
99
100
|
num_validation_samples: int = 4,
|
101
|
+
num_calibration_samples: Union[int, None] = None,
|
100
102
|
error_metric: str = 'mse',
|
101
103
|
) -> model_validator.ComparisonResult:
|
102
104
|
"""Quantizes and validates the given model with the given configurations.
|
@@ -107,6 +109,8 @@ class BaseOpTestCase(parameterized.TestCase):
|
|
107
109
|
op_name: The name of the operation to be quantized.
|
108
110
|
op_config: The configuration for the operation to be quantized.
|
109
111
|
num_validation_samples: The number of samples to use for validation.
|
112
|
+
num_calibration_samples: The number of samples to use for calibration. If
|
113
|
+
None then it will be set to num_validation_samples * 8.
|
110
114
|
error_metric: The error error_metric to use for validation.
|
111
115
|
|
112
116
|
Returns:
|
@@ -120,8 +124,11 @@ class BaseOpTestCase(parameterized.TestCase):
|
|
120
124
|
op_config=op_config,
|
121
125
|
)
|
122
126
|
if quantizer_instance.need_calibration:
|
127
|
+
if num_calibration_samples is None:
|
128
|
+
num_calibration_samples = num_validation_samples * 8
|
123
129
|
calibration_data = tfl_interpreter_utils.create_random_normal_input_data(
|
124
|
-
quantizer_instance.float_model,
|
130
|
+
quantizer_instance.float_model,
|
131
|
+
num_samples=num_calibration_samples,
|
125
132
|
)
|
126
133
|
calibration_result = quantizer_instance.calibrate(calibration_data)
|
127
134
|
quantization_result = quantizer_instance.quantize(calibration_result)
|
@@ -198,3 +205,26 @@ class BaseOpTestCase(parameterized.TestCase):
|
|
198
205
|
self.assert_output_errors_below_tolerance(
|
199
206
|
validation_result, output_tolerance
|
200
207
|
)
|
208
|
+
|
209
|
+
def assert_quantization_accuracy(
|
210
|
+
self,
|
211
|
+
algorithm_key: _AlgorithmName,
|
212
|
+
model_path: str,
|
213
|
+
op_name: _OpName,
|
214
|
+
op_config: _OpQuantConfig,
|
215
|
+
num_validation_samples: int = 4,
|
216
|
+
num_calibration_samples: Union[int, None] = None,
|
217
|
+
output_tolerance: float = 1e-4,
|
218
|
+
):
|
219
|
+
"""Check if the output errors after quantization are within the tolerance."""
|
220
|
+
validation_result = self.quantize_and_validate(
|
221
|
+
model_path=model_path,
|
222
|
+
algorithm_key=algorithm_key,
|
223
|
+
num_validation_samples=num_validation_samples,
|
224
|
+
num_calibration_samples=num_calibration_samples,
|
225
|
+
op_name=op_name,
|
226
|
+
op_config=op_config,
|
227
|
+
)
|
228
|
+
self.assert_output_errors_below_tolerance(
|
229
|
+
validation_result, output_tolerance
|
230
|
+
)
|
@@ -57,6 +57,8 @@ TFL_OP_NAME_TO_CODE = immutabledict.immutabledict({
|
|
57
57
|
schema.BuiltinOperator.DYNAMIC_UPDATE_SLICE
|
58
58
|
),
|
59
59
|
_TFLOpName.PAD: schema.BuiltinOperator.PAD,
|
60
|
+
_TFLOpName.SQUARED_DIFFERENCE: schema.BuiltinOperator.SQUARED_DIFFERENCE,
|
61
|
+
_TFLOpName.MAX_POOL_2D: schema.BuiltinOperator.MAX_POOL_2D,
|
60
62
|
})
|
61
63
|
|
62
64
|
TFL_OP_CODE_TO_NAME = immutabledict.immutabledict(
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-quantizer-nightly
|
3
|
-
Version: 0.3.0.dev20250609
|
3
|
+
Version: 0.3.0.dev20250611
|
4
4
|
Summary: A quantizer for advanced developers to quantize converted AI Edge models.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
|
@@ -1,18 +1,18 @@
|
|
1
1
|
ai_edge_quantizer/__init__.py,sha256=4pFSkukSwahYyzwqia0yPRyz8TnFQfGRthVJhYpMWas,793
|
2
|
-
ai_edge_quantizer/algorithm_manager.py,sha256=
|
2
|
+
ai_edge_quantizer/algorithm_manager.py,sha256=lfCazb2b0Q4L3of0cTWkF5lMr3AD6LWW1ekmFoEGB_4,12062
|
3
3
|
ai_edge_quantizer/algorithm_manager_api.py,sha256=u903TG0s1uIDhJqfeJne3CFl8A93phZrwgV2-hwdcXU,9247
|
4
4
|
ai_edge_quantizer/algorithm_manager_api_test.py,sha256=w6bSONvXkX6bzXAGc0-7b6gNDt9oz9ieq97KP8Sg_JU,7666
|
5
5
|
ai_edge_quantizer/calibrator.py,sha256=-_jX_KkfIepkQAwxxDrZjvPO1JsoSjHXVy1DPc1iFjM,12068
|
6
6
|
ai_edge_quantizer/calibrator_test.py,sha256=C_oWOaRugPKYX74jF-eRFH-k6nGOdA8I9_uPiocaOuE,11900
|
7
7
|
ai_edge_quantizer/conftest.py,sha256=SxCz-5LlRD_lQm4hQc4c6IGG7DS8d7IyEWY9gnscPN0,794
|
8
|
-
ai_edge_quantizer/default_policy.py,sha256=
|
8
|
+
ai_edge_quantizer/default_policy.py,sha256=nKtghUjTQ8QS9CgLRwQb3iB2eZOyQv0FqyISlcgzSH4,11195
|
9
9
|
ai_edge_quantizer/model_modifier.py,sha256=teGa8I6kGvn6TQY6Xv53YFIc_pQEhNvM9Zb4bvhezyw,7110
|
10
10
|
ai_edge_quantizer/model_modifier_test.py,sha256=cJd04SLOG-fQZZNZPcisoBLx3cLtWEwGqUBbLb-pif4,4751
|
11
11
|
ai_edge_quantizer/model_validator.py,sha256=Hj0_5o-Oa3dSlJ3ryVjRhvsyelHNyek1GrtG9buMczg,13153
|
12
12
|
ai_edge_quantizer/model_validator_test.py,sha256=EeqOP_mrZsnZ3rug756s0ryDDqd2KgIDld5Lm_gDuWY,13020
|
13
13
|
ai_edge_quantizer/params_generator.py,sha256=j1BV2cGFLlQmUY6aoW5uglYqf77b9ytN8oZ1gh6o0mM,20096
|
14
14
|
ai_edge_quantizer/params_generator_test.py,sha256=RDYoRZDJfEZRtjlTAU2kZ_4t3JHOqEHxfJX9V4ETAhg,40597
|
15
|
-
ai_edge_quantizer/qtyping.py,sha256=
|
15
|
+
ai_edge_quantizer/qtyping.py,sha256=0Dwz6LHQG8LhZMhVAo_h6ieZ_gcfkJl2yJcsGf17YYs,16527
|
16
16
|
ai_edge_quantizer/quantizer.py,sha256=g3DMqFMrMpt9jQttCE0WcdNbMtk0JZnmN5MmCHrNdyM,13202
|
17
17
|
ai_edge_quantizer/quantizer_test.py,sha256=K_HBA56JkFI3HL8VLWCqGEfC0ISh5ldMKoNyBdGRAJg,20368
|
18
18
|
ai_edge_quantizer/recipe.py,sha256=FR0uJceumZrnle2VRSOQZ1uXup4S1cTYKRH-N53mWRo,2919
|
@@ -28,7 +28,7 @@ ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py,sha256=lpq1g2ayg3lCP
|
|
28
28
|
ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py,sha256=Bs9CK7wZAw6jNaZ8xEtbwO2vM34VYXNZSMVWvxJo9nw,9297
|
29
29
|
ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py,sha256=EqIHGEZ1LgUrTN7zf880RuAzEv3Qy7kgh5ivObJGHSo,22646
|
30
30
|
ai_edge_quantizer/algorithms/uniform_quantize/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
|
31
|
-
ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=
|
31
|
+
ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=ofDoiZhOKjF7Tm-v0a4xsLSvytjfvMALXLDcuwcKNK0,29634
|
32
32
|
ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py,sha256=GGf_n3wIeg3GB_eGsmyNJ0fTcxgpeMMbugTMRONK6TQ,3553
|
33
33
|
ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py,sha256=BDdn_uBZakfHyzdMJPKadsOqxqyC-s6W2ZzFH99L4fE,8652
|
34
34
|
ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py,sha256=sT5eX5TLZEHTtPfnSkCPDlS0sQxlTFWbCsbvOuj--yY,8889
|
@@ -63,15 +63,15 @@ ai_edge_quantizer/transformations/transformation_utils_test.py,sha256=MWgq29t7rv
|
|
63
63
|
ai_edge_quantizer/utils/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
|
64
64
|
ai_edge_quantizer/utils/calibration_utils.py,sha256=1Fj9MIO6aLZIRgyd4axvZN4S_O64nB_-Miu1WP664js,2536
|
65
65
|
ai_edge_quantizer/utils/calibration_utils_test.py,sha256=Z-AcdTieesWFKyKBb08ZXm4Mgu6cvJ4bg2-MJ7hLD10,2856
|
66
|
-
ai_edge_quantizer/utils/test_utils.py,sha256=
|
67
|
-
ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=
|
66
|
+
ai_edge_quantizer/utils/test_utils.py,sha256=Y2pdMvn1k4gmqDo3noJfzx3fJcDHX_1hcsP6oiIz65Y,8240
|
67
|
+
ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=Yy1u53FzRBFx-fr1TqoycWMZwAlAl0b2IB4MmGV1xJA,10758
|
68
68
|
ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py,sha256=K1SbK8q92qYVtiVj0I0GtugsPTkpIpEKv9zakvFV_Sc,8555
|
69
69
|
ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=EtOv6cpKM_F0uv2bWuSXylYmTeXT6zUc182pw4sdYSI,13889
|
70
70
|
ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=6fjkM-rycZ95L4yfvlr0TN6RlrhfPzxNUYrZaYO_F0A,12013
|
71
71
|
ai_edge_quantizer/utils/validation_utils.py,sha256=oYw33Sg547AqtGw-choPUJmp9SAKkV46J_ddqSsum2Q,3950
|
72
72
|
ai_edge_quantizer/utils/validation_utils_test.py,sha256=V_qNDikPD4OPB-siOLQCWNVWTAu87h2IgNYt7teFd-o,2934
|
73
|
-
ai_edge_quantizer_nightly-0.3.0.dev20250609.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
74
|
-
ai_edge_quantizer_nightly-0.3.0.dev20250609.dist-info/METADATA,sha256=
|
75
|
-
ai_edge_quantizer_nightly-0.3.0.dev20250609.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
76
|
-
ai_edge_quantizer_nightly-0.3.0.dev20250609.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
|
77
|
-
ai_edge_quantizer_nightly-0.3.0.dev20250609.dist-info/RECORD,,
|
73
|
+
ai_edge_quantizer_nightly-0.3.0.dev20250611.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
74
|
+
ai_edge_quantizer_nightly-0.3.0.dev20250611.dist-info/METADATA,sha256=FPK-WqVTMEz-w5yycBejT4oRBxMY4fiYH-AAL6Pf4-w,1528
|
75
|
+
ai_edge_quantizer_nightly-0.3.0.dev20250611.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
76
|
+
ai_edge_quantizer_nightly-0.3.0.dev20250611.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
|
77
|
+
ai_edge_quantizer_nightly-0.3.0.dev20250611.dist-info/RECORD,,
|
File without changes
|
File without changes
|