ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/__init__.py +19 -0
- ai_edge_quantizer/algorithm_manager.py +167 -0
- ai_edge_quantizer/algorithm_manager_api.py +271 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
- ai_edge_quantizer/algorithms/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
- ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
- ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
- ai_edge_quantizer/calibrator.py +288 -0
- ai_edge_quantizer/calibrator_test.py +297 -0
- ai_edge_quantizer/conftest.py +22 -0
- ai_edge_quantizer/default_policy.py +310 -0
- ai_edge_quantizer/model_modifier.py +176 -0
- ai_edge_quantizer/model_modifier_test.py +130 -0
- ai_edge_quantizer/model_validator.py +357 -0
- ai_edge_quantizer/model_validator_test.py +354 -0
- ai_edge_quantizer/params_generator.py +361 -0
- ai_edge_quantizer/params_generator_test.py +1041 -0
- ai_edge_quantizer/qtyping.py +483 -0
- ai_edge_quantizer/quantizer.py +372 -0
- ai_edge_quantizer/quantizer_test.py +532 -0
- ai_edge_quantizer/recipe.py +67 -0
- ai_edge_quantizer/recipe_manager.py +245 -0
- ai_edge_quantizer/recipe_manager_test.py +815 -0
- ai_edge_quantizer/recipe_test.py +97 -0
- ai_edge_quantizer/transformation_instruction_generator.py +584 -0
- ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
- ai_edge_quantizer/transformation_performer.py +278 -0
- ai_edge_quantizer/transformation_performer_test.py +344 -0
- ai_edge_quantizer/transformations/__init__.py +15 -0
- ai_edge_quantizer/transformations/dequant_insert.py +87 -0
- ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
- ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
- ai_edge_quantizer/transformations/quant_insert.py +100 -0
- ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
- ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
- ai_edge_quantizer/transformations/transformation_utils.py +132 -0
- ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
- ai_edge_quantizer/utils/__init__.py +15 -0
- ai_edge_quantizer/utils/calibration_utils.py +86 -0
- ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
- ai_edge_quantizer/utils/test_utils.py +107 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
- ai_edge_quantizer/utils/validation_utils.py +125 -0
- ai_edge_quantizer/utils/validation_utils_test.py +87 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,19 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Init file for the AI Edge quantizer package."""
|
17
|
+
|
18
|
+
# pylint: disable=unused-import
|
19
|
+
from .quantizer import Quantizer
|
@@ -0,0 +1,167 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Quantizer Algorithm Manager Interface."""
|
17
|
+
|
18
|
+
import enum
|
19
|
+
from ai_edge_quantizer import algorithm_manager_api
|
20
|
+
from ai_edge_quantizer import default_policy
|
21
|
+
from ai_edge_quantizer import qtyping
|
22
|
+
from ai_edge_quantizer.algorithms.nonlinear_quantize import float_casting
|
23
|
+
from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
|
24
|
+
|
25
|
+
# Shorthand for the TFL op-name type used by the registrations in this module.
_TFLOpName = qtyping.TFLOperationName

# The AlgorithmManagerApi instance backing this module. Its bound methods are
# re-exported below so callers can use them as plain module-level functions.
_alg_manager_instance = algorithm_manager_api.AlgorithmManagerApi()

# Expose instance functions.
# -- Registration.
register_quantized_op = _alg_manager_instance.register_quantized_op
register_op_quant_config_validation_func = (
    _alg_manager_instance.register_op_quant_config_validation_func
)
# Exported with a `_func` suffix; the underlying method is
# `register_config_check_policy`.
register_config_check_policy_func = (
    _alg_manager_instance.register_config_check_policy
)
# -- Lookups.
get_quantization_func = _alg_manager_instance.get_quantization_func
get_init_qsv_func = _alg_manager_instance.get_init_qsv_func
get_supported_ops = _alg_manager_instance.get_supported_ops
# -- Queries and validation.
is_op_registered = _alg_manager_instance.is_op_registered
is_algorithm_registered = _alg_manager_instance.is_algorithm_registered
check_op_quantization_config = (
    _alg_manager_instance.check_op_quantization_config
)
|
45
|
+
|
46
|
+
|
47
|
+
# Quantization algorithms.
|
48
|
+
class AlgorithmName(str, enum.Enum):
  """Keys identifying the quantization algorithms known to this module.

  Because the enum mixes in `str`, each member compares (and hashes) equal to
  its raw string value, so members can be used anywhere a plain algorithm key
  string is expected. Only MIN_MAX_UNIFORM_QUANT and FLOAT_CASTING have
  functions registered in this module; NO_QUANTIZE has none registered here.
  """

  NO_QUANTIZE = "no_quantize"
  # The remaining keys come from the algorithm implementations themselves.
  MIN_MAX_UNIFORM_QUANT = naive_min_max_quantize.ALGORITHM_KEY
  FLOAT_CASTING = float_casting.ALGORITHM_KEY
|
52
|
+
|
53
|
+
|
54
|
+
# Register MIN_MAX_UNIFORM_QUANT algorithm.
register_op_quant_config_validation_func(
    AlgorithmName.MIN_MAX_UNIFORM_QUANT,
    naive_min_max_quantize.check_op_quantization_config,
)

# Register a config check policy for MIN_MAX_UNIFORM_QUANT algorithm.
register_config_check_policy_func(
    AlgorithmName.MIN_MAX_UNIFORM_QUANT,
    default_policy.DEFAULT_CONFIG_CHECK_POLICY,
)

# TFL op -> materialize function for MIN_MAX_UNIFORM_QUANT. Each op sits next to
# its function (rather than in two positionally zipped tuples), so adding or
# removing an op cannot shift the pairing of every op after it, and nothing is
# silently dropped when the two sides get out of sync. Insertion order is the
# registration order.
_MIN_MAX_OP_MATERIALIZE_FUNCS = {
    _TFLOpName.INPUT: naive_min_max_quantize.materialize_input,
    _TFLOpName.OUTPUT: naive_min_max_quantize.materialize_output,
    _TFLOpName.FULLY_CONNECTED: naive_min_max_quantize.materialize_fc_conv,
    _TFLOpName.BATCH_MATMUL: naive_min_max_quantize.materialize_batch_matmul,
    _TFLOpName.CONV_2D: naive_min_max_quantize.materialize_fc_conv,
    _TFLOpName.DEPTHWISE_CONV_2D: naive_min_max_quantize.materialize_fc_conv,
    _TFLOpName.CONV_2D_TRANSPOSE: (
        naive_min_max_quantize.materialize_conv2d_transpose
    ),
    _TFLOpName.RESHAPE: naive_min_max_quantize.materialize_reshape,
    _TFLOpName.AVERAGE_POOL_2D: (
        naive_min_max_quantize.materialize_average_pool_2d
    ),
    _TFLOpName.EMBEDDING_LOOKUP: (
        naive_min_max_quantize.materialize_embedding_lookup
    ),
    _TFLOpName.SOFTMAX: naive_min_max_quantize.materialize_softmax_and_logistic,
    _TFLOpName.TANH: naive_min_max_quantize.materialize_tanh,
    _TFLOpName.TRANSPOSE: naive_min_max_quantize.materialize_transpose,
    _TFLOpName.GELU: naive_min_max_quantize.materialize_gelu,
    _TFLOpName.ADD: naive_min_max_quantize.materialize_add,
    _TFLOpName.SUB: naive_min_max_quantize.materialize_sub,
    _TFLOpName.MUL: naive_min_max_quantize.materialize_mul,
    _TFLOpName.MEAN: naive_min_max_quantize.materialize_mean,
    _TFLOpName.RSQRT: naive_min_max_quantize.materialize_rsqrt,
    _TFLOpName.CONCATENATION: naive_min_max_quantize.materialize_concatenation,
    _TFLOpName.STRIDED_SLICE: naive_min_max_quantize.materialize_strided_slice,
    _TFLOpName.SPLIT: naive_min_max_quantize.materialize_split,
    # Sigmoid shares the softmax materialization.
    _TFLOpName.LOGISTIC: (
        naive_min_max_quantize.materialize_softmax_and_logistic
    ),
    _TFLOpName.SLICE: naive_min_max_quantize.materialize_slice,
    _TFLOpName.SUM: naive_min_max_quantize.materialize_sum,
    _TFLOpName.SELECT_V2: naive_min_max_quantize.materialize_select_v2,
}

# Every op shares the same QSV init and calibration functions; only the
# materialize function differs per op.
for op_name, materialize_func in _MIN_MAX_OP_MATERIALIZE_FUNCS.items():
  register_quantized_op(
      AlgorithmName.MIN_MAX_UNIFORM_QUANT,
      op_name,
      naive_min_max_quantize.init_qsvs,
      calibration_func=naive_min_max_quantize.min_max_calibrate,
      materialize_func=materialize_func,
  )
|
132
|
+
|
133
|
+
# Register FLOAT_CASTING algorithm.
register_op_quant_config_validation_func(
    AlgorithmName.FLOAT_CASTING,
    float_casting.check_op_quantization_config,
)

# Register a config check policy for FLOAT_CASTING algorithm.
# TODO: b/353780772 - Replace an empty policy for FLOAT_CASTING algorithm.
register_config_check_policy_func(
    AlgorithmName.FLOAT_CASTING, qtyping.ConfigCheckPolicyDict()
)

# TFL op -> materialize function for FLOAT_CASTING. Keeping each op beside its
# function (instead of two positionally zipped tuples) prevents misaligned or
# silently dropped registrations when the op list changes. Insertion order is
# the registration order.
_FLOAT_CASTING_OP_MATERIALIZE_FUNCS = {
    _TFLOpName.FULLY_CONNECTED: float_casting.materialize_fc_conv,
    _TFLOpName.CONV_2D: float_casting.materialize_fc_conv,
    _TFLOpName.DEPTHWISE_CONV_2D: float_casting.materialize_fc_conv,
    _TFLOpName.CONV_2D_TRANSPOSE: float_casting.materialize_conv2d_transpose,
    _TFLOpName.EMBEDDING_LOOKUP: float_casting.materialize_embedding_lookup,
}

# Every op shares the same QSV init and calibration functions; only the
# materialize function differs per op.
for op_name, materialize_func in _FLOAT_CASTING_OP_MATERIALIZE_FUNCS.items():
  register_quantized_op(
      AlgorithmName.FLOAT_CASTING,
      op_name,
      float_casting.init_qsvs,
      calibration_func=float_casting.calibrate,
      materialize_func=materialize_func,
  )
|
@@ -0,0 +1,271 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""The Python API for Algorithm Manager of Quantizer."""
|
17
|
+
|
18
|
+
from collections.abc import Callable
|
19
|
+
import dataclasses
|
20
|
+
import functools
|
21
|
+
from typing import Any, Optional
|
22
|
+
from ai_edge_quantizer import qtyping
|
23
|
+
|
24
|
+
|
25
|
+
@dataclasses.dataclass()
class QuantizedOperationInfo:
  """Bundles the functions registered for one TFL op under one algorithm.

  Attributes:
    tfl_op_key: The TFL operation these functions belong to.
    init_qsv_func: Produces the initial quantization statistics variables.
    calibration_func: Returned by the manager for the CALIBRATE mode.
    materialize_func: Returned by the manager for the MATERIALIZE mode.
  """

  # Field order defines the positional constructor signature; keep it stable.
  tfl_op_key: qtyping.TFLOperationName
  init_qsv_func: Callable[..., Any]
  calibration_func: Callable[..., Any]
  materialize_func: Callable[..., Any]
|
33
|
+
|
34
|
+
|
35
|
+
@dataclasses.dataclass()
class QuantizationAlgorithmInfo:
  """Everything registered for a single quantization algorithm.

  Attributes:
    quantization_algorithm: The algorithm key this entry is registered under.
    quantized_ops: Maps each supported TFL op to its registered functions.
  """

  # Field order defines the positional constructor signature; keep it stable.
  quantization_algorithm: str
  quantized_ops: dict[qtyping.TFLOperationName, QuantizedOperationInfo]
|
39
|
+
|
40
|
+
|
41
|
+
class AlgorithmManagerApi:
  """Quantizer API client to manage quantization configs and functions.

  Three registries are kept, each keyed by algorithm key:
    * `_algorithm_registry`: algorithm -> QuantizationAlgorithmInfo, holding
      the init/calibration/materialize functions of every supported TFL op.
    * `_config_check_registry`: algorithm -> function that validates an op
      quantization config for that algorithm.
    * `_config_check_policy_registry`: algorithm -> policy handed to that
      validation function.
  """

  def __init__(self):
    # Algorithm key -> QuantizationAlgorithmInfo.
    self._algorithm_registry = dict()
    # Check if an op quantization config is supported for a given algorithm.
    self._config_check_registry = dict()
    # Policy to check if an op quantization config is supported for a given
    # algorithm.
    self._config_check_policy_registry = dict()

  def register_op_quant_config_validation_func(
      self,
      algorithm_key: str,
      config_check_func: Callable[..., Any],
  ):
    """Register functions to check if an op quantization config is supported.

    Registering again under the same `algorithm_key` replaces the previous
    function.

    Args:
      algorithm_key: Quantization algorithm the function validates configs for.
      config_check_func: Invoked by `check_op_quantization_config` as
        `config_check_func(tfl_op_name, op_quantization_config,
        config_check_policy)`.
    """
    self._config_check_registry[algorithm_key] = config_check_func

  def register_quantized_op(
      self,
      algorithm_key: str,
      tfl_op_name: qtyping.TFLOperationName,
      init_qsv_func: Callable[..., Any],
      calibration_func: Callable[..., Any],
      materialize_func: Callable[..., Any],
  ):
    """Register functions to support a quantization operation.

    This function registers the relevant information to support the quantized
    version of the given TFL operation, for the algorithm specified by
    `algorithm_key`. Registering the same op twice for one algorithm replaces
    the earlier entry.

    Args:
      algorithm_key: Quantization algorithm keyword for which the quantized
        operation is for.
      tfl_op_name: TFLite op name.
      init_qsv_func: QSV init function to be called.
      calibration_func: Quantized operation to be called during calibration.
      materialize_func: Quantized operation to be called during materialization.
    """
    # Creates the algorithm entry on first registration.
    quantized_algorithm_info = self._algorithm_registry.setdefault(
        algorithm_key, QuantizationAlgorithmInfo(algorithm_key, dict())
    )

    quantized_algorithm_info.quantized_ops[tfl_op_name] = (
        QuantizedOperationInfo(
            tfl_op_name,
            init_qsv_func,
            calibration_func,
            materialize_func,
        )
    )

  def is_op_registered(
      self,
      quantization_algorithm: str,
      tfl_op_name: qtyping.TFLOperationName,
  ) -> bool:
    """Checks whether `tfl_op_name` is registered for `quantization_algorithm`.

    Args:
      quantization_algorithm: Target quantization algorithm.
      tfl_op_name: TFL operation name.

    Returns:
      True if the given op is registered for the given algorithm, false
      otherwise (including when the algorithm itself is not registered).
    """
    if not self.is_algorithm_registered(quantization_algorithm):
      return False

    return (
        tfl_op_name
        in self._algorithm_registry[quantization_algorithm].quantized_ops
    )

  def is_algorithm_registered(self, quantization_algorithm: str) -> bool:
    """Check if the given algorithm is registered.

    Args:
      quantization_algorithm: Target quantization algorithm.

    Returns:
      True if the given algorithm is registered, false otherwise.
    """
    return quantization_algorithm in self._algorithm_registry

  def check_op_quantization_config(
      self,
      quantization_algorithm: str,
      tfl_op_name: qtyping.TFLOperationName,
      op_quantization_config: qtyping.OpQuantizationConfig,
  ) -> None:
    """Checks if the given op quantization config is valid.

    Args:
      quantization_algorithm: Target quantization algorithm.
      tfl_op_name: TFL operation name.
      op_quantization_config: Op quantization config to be checked.

    Raises:
      ValueError: If the given op or algorithm is not registered, or if the
        algorithm has no config checking function or no config check policy
        registered. The registered checking function may raise its own errors
        for an invalid config.
    """
    # Configs that opt out of validation are accepted as-is.
    if op_quantization_config.skip_checks:
      return
    if not self.is_op_registered(quantization_algorithm, tfl_op_name):
      raise ValueError(
          f"Unsupported operation {tfl_op_name} for Algorithm:"
          f" {quantization_algorithm}."
      )
    if quantization_algorithm not in self._config_check_registry:
      raise ValueError(
          f"Config checking function for algorithm {quantization_algorithm} is"
          " not registered. Please use"
          " `register_op_quant_config_validation_func` to register the"
          " validation function."
      )
    # Guard explicitly so a missing policy raises the documented ValueError
    # with a remedy, instead of a bare KeyError from the dict lookup below.
    if quantization_algorithm not in self._config_check_policy_registry:
      raise ValueError(
          f"Config check policy for algorithm {quantization_algorithm} is not"
          " registered. Please use `register_config_check_policy` to register"
          " the policy."
      )
    self._config_check_registry[quantization_algorithm](
        tfl_op_name,
        op_quantization_config,
        self._config_check_policy_registry[quantization_algorithm],
    )

  def get_supported_ops(self, alg_key: str) -> list[qtyping.TFLOperationName]:
    """Returns the list of supported ops for the given algorithm.

    Args:
      alg_key: Algorithm key.

    Returns:
      The TFL operations registered for the algorithm, in the order they were
      first registered.

    Raises:
      ValueError: If the alg_key is not registered.
    """
    if alg_key not in self._algorithm_registry:
      raise ValueError(f"Unregistered algorithm: {alg_key}")

    return list(self._algorithm_registry[alg_key].quantized_ops.keys())

  def _raise_if_op_unregistered(
      self, algorithm_key: str, tfl_op_name: qtyping.TFLOperationName
  ) -> None:
    """Raises ValueError unless `tfl_op_name` is registered for the algorithm.

    If the algorithm itself is unregistered, building the message calls
    `get_supported_ops`, which raises its "Unregistered algorithm" ValueError
    instead.
    """
    if not self.is_op_registered(algorithm_key, tfl_op_name):
      raise ValueError(
          f"Unsupported operation {tfl_op_name} for Algorithm: {algorithm_key}."
          f" Supported ops for algorithm {algorithm_key}:"
          f" {self.get_supported_ops(algorithm_key)}"
      )

  def get_quantization_func(
      self,
      algorithm_key: str,
      tfl_op_name: qtyping.TFLOperationName,
      quantize_mode: qtyping.QuantizeMode,
  ) -> Callable[..., Any]:
    """Gets the quantization function.

    Args:
      algorithm_key: Target quantization algorithm key (e.g.,
        AlgorithmName.MIN_MAX_UNIFORM_QUANT).
      tfl_op_name: TFLite op name.
      quantize_mode: Quantization mode to be used.

    Returns:
      A quantized operation (function) corresponds to the requested algorithm
      for the TFL op.

    Raises:
      ValueError: If the algorithm or op is not registered, or no function is
        available for `quantize_mode`.
    """
    self._raise_if_op_unregistered(algorithm_key, tfl_op_name)

    quantized_op_info = self._algorithm_registry[algorithm_key].quantized_ops
    quantized_func = self._get_target_func(
        quantized_op_info, tfl_op_name, quantize_mode
    )
    if quantized_func is None:
      raise ValueError(
          "Cannot retrieve appropriate quantization function for"
          f" {tfl_op_name} for algorithm {algorithm_key} under quantization"
          f" mode {quantize_mode}. Check if the op is registered in"
          " algorithm_manager."
      )

    return quantized_func

  def get_init_qsv_func(
      self,
      algorithm_key: str,
      tfl_op_name: qtyping.TFLOperationName,
  ) -> Callable[..., Any]:
    """Gets the initial Quantization Statistics Variable function for a given op.

    Args:
      algorithm_key: Quantization algorithm to search.
      tfl_op_name: Target TFL operation.

    Returns:
      The QSV initialization function registered for the op (returned as-is,
      whatever callable was registered).

    Raises:
      ValueError: If the algorithm or op is not registered.
    """
    self._raise_if_op_unregistered(algorithm_key, tfl_op_name)
    quantized_op_info = self._algorithm_registry[algorithm_key].quantized_ops

    return quantized_op_info[tfl_op_name].init_qsv_func

  def _get_target_func(
      self,
      quantized_op_info: dict[qtyping.TFLOperationName, QuantizedOperationInfo],
      tfl_op_name: qtyping.TFLOperationName,
      quantize_mode: qtyping.QuantizeMode,
  ) -> Optional[Callable[..., Any]]:
    """Gets the function for the given TFL op and quantization mode.

    Returns:
      The calibration function for CALIBRATE, the materialize function for
      MATERIALIZE, and None for any other mode.
    """
    if quantize_mode == qtyping.QuantizeMode.CALIBRATE:
      return quantized_op_info[tfl_op_name].calibration_func
    elif quantize_mode == qtyping.QuantizeMode.MATERIALIZE:
      return quantized_op_info[tfl_op_name].materialize_func
    return None

  # TODO: b/53780772 - Merge this function with
  # register_op_quant_config_validation_func after full transition to new check
  # mechanism.
  def register_config_check_policy(
      self,
      algorithm_key: str,
      config_check_policy: Optional[qtyping.ConfigCheckPolicyDict],
  ):
    """Registers a policy to check the op quantization config.

    The policy is passed as the third argument to the algorithm's config
    checking function; registering again replaces the previous policy.
    """
    self._config_check_policy_registry[algorithm_key] = config_check_policy
|
@@ -0,0 +1,210 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Tests for algorithm_manager_api."""
|
17
|
+
|
18
|
+
from absl.testing import parameterized
|
19
|
+
from tensorflow.python.platform import googletest
|
20
|
+
from ai_edge_quantizer import algorithm_manager_api
|
21
|
+
from ai_edge_quantizer import qtyping
|
22
|
+
|
23
|
+
# Short alias for `qtyping.TFLOperationName`, used to name ops in these tests.
_TFLOpName = qtyping.TFLOperationName
|
24
|
+
|
25
|
+
|
26
|
+
# Sample functions for test cases.
|
27
|
+
def _sample_init_qsvs(*_, **__):
|
28
|
+
return 1.0, dict()
|
29
|
+
|
30
|
+
|
31
|
+
def _sample_calibration_func(*_, **__):
|
32
|
+
return 2.0, dict()
|
33
|
+
|
34
|
+
|
35
|
+
def _sample_materialize_func(*_, **__):
|
36
|
+
return 3.0, dict()
|
37
|
+
|
38
|
+
|
39
|
+
def _sample_check_op_config_func(_, op_config):
|
40
|
+
if op_config.weight_tensor_config.num_bits == 17:
|
41
|
+
raise ValueError("Unsupported number of bits.")
|
42
|
+
|
43
|
+
|
44
|
+
class AlgorithmManagerApiTest(parameterized.TestCase):
  """Unit tests for algorithm_manager_api.AlgorithmManagerApi."""

  def setUp(self):
    super().setUp()
    # A fresh manager per test keeps registrations from leaking across tests.
    self._alg_manager = algorithm_manager_api.AlgorithmManagerApi()

  def _register_sample_op(self, algorithm_key, tfl_op_name):
    """Registers the module's sample functions for `tfl_op_name`."""
    self._alg_manager.register_quantized_op(
        algorithm_key=algorithm_key,
        tfl_op_name=tfl_op_name,
        init_qsv_func=_sample_init_qsvs,
        calibration_func=_sample_calibration_func,
        materialize_func=_sample_materialize_func,
    )

  def _raises_value_error_containing(self, message):
    """Returns a context manager expecting a ValueError mentioning `message`."""
    return self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: message in str(err)
    )

  def test_register_op_quant_config_validation_func_succeeds(self):
    check_registry = self._alg_manager._config_check_registry
    self.assertEmpty(check_registry)
    algorithm_name = "test_algorithm"
    self._alg_manager.register_op_quant_config_validation_func(
        algorithm_name, _sample_check_op_config_func
    )
    self.assertIn(algorithm_name, check_registry)
    self.assertEqual(
        check_registry[algorithm_name], _sample_check_op_config_func
    )

  def test_register_quantized_op(self):
    self._register_sample_op("ptq", _TFLOpName.FULLY_CONNECTED)
    self._register_sample_op("gptq", _TFLOpName.CONV_2D)

    manager = self._alg_manager
    self.assertTrue(manager.is_algorithm_registered("ptq"))
    self.assertTrue(manager.is_algorithm_registered("gptq"))
    self.assertTrue(manager.is_op_registered("ptq", _TFLOpName.FULLY_CONNECTED))
    self.assertTrue(manager.is_op_registered("gptq", _TFLOpName.CONV_2D))
    self.assertFalse(
        manager.is_op_registered("gptq", _TFLOpName.DEPTHWISE_CONV_2D)
    )

  def test_get_supported_ops(self):
    algorithm_key = "ptq"
    for tfl_op in (_TFLOpName.FULLY_CONNECTED, _TFLOpName.CONV_2D):
      self._register_sample_op(algorithm_key, tfl_op)

    supported_ops = self._alg_manager.get_supported_ops(algorithm_key)
    self.assertIn(_TFLOpName.CONV_2D, supported_ops)
    self.assertIn(_TFLOpName.FULLY_CONNECTED, supported_ops)
    self.assertNotIn(_TFLOpName.DEPTHWISE_CONV_2D, supported_ops)

  def test_get_quantization_func(self):
    algorithm_key = "ptq"
    tfl_op = _TFLOpName.FULLY_CONNECTED
    self._register_sample_op(algorithm_key, tfl_op)
    get_func = self._alg_manager.get_quantization_func

    materialize_func = get_func(
        algorithm_key, tfl_op, qtyping.QuantizeMode.MATERIALIZE
    )
    self.assertEqual(_sample_materialize_func()[0], materialize_func()[0])
    calibration_func = get_func(
        algorithm_key, tfl_op, qtyping.QuantizeMode.CALIBRATE
    )
    self.assertEqual(_sample_calibration_func()[0], calibration_func()[0])

    # Query for unsupported operation.
    with self._raises_value_error_containing("Unsupported operation"):
      get_func(
          algorithm_key,
          _TFLOpName.BATCH_MATMUL,
          qtyping.QuantizeMode.MATERIALIZE,
      )

    # Query for unregistered algorithm.
    with self._raises_value_error_containing("Unregistered algorithm"):
      get_func("gptq", tfl_op, qtyping.QuantizeMode.MATERIALIZE)

  def test_get_init_qsv_func(self):
    algorithm_key = "ptq"
    tfl_op = _TFLOpName.FULLY_CONNECTED
    self._register_sample_op(algorithm_key, tfl_op)
    get_init = self._alg_manager.get_init_qsv_func

    init_qsv_func = get_init(algorithm_key, tfl_op)
    self.assertEqual(_sample_init_qsvs()[0], init_qsv_func()[0])

    # Query for unsupported operation.
    with self._raises_value_error_containing("Unsupported operation"):
      get_init(algorithm_key, _TFLOpName.BATCH_MATMUL)

    # Query for unregistered algorithm.
    with self._raises_value_error_containing("Unregistered algorithm"):
      get_init("gptq", tfl_op)

  def test_register_config_check_policy_succeeds(self):
    policy_registry = self._alg_manager._config_check_policy_registry
    self.assertEmpty(policy_registry)
    test_algorithm_name = "test_algorithm"
    fc_config = qtyping.OpQuantizationConfig(
        weight_tensor_config=qtyping.TensorQuantizationConfig(num_bits=1)
    )
    # The policy maps each op to a set of allowed configs.
    test_config_check_policy = qtyping.ConfigCheckPolicyDict(
        {_TFLOpName.FULLY_CONNECTED: {fc_config}}
    )
    self._alg_manager.register_config_check_policy(
        test_algorithm_name, test_config_check_policy
    )
    self.assertIn(test_algorithm_name, policy_registry)
    self.assertIsNotNone(policy_registry[test_algorithm_name])


if __name__ == "__main__":
  googletest.main()
|
@@ -0,0 +1,15 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
@@ -0,0 +1,15 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|