ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,19 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Init file for the AI Edge quantizer package."""
17
+
18
+ # pylint: disable=unused-import
19
+ from .quantizer import Quantizer
@@ -0,0 +1,167 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Quantizer Algorithm Manager Interface."""
17
+
18
+ import enum
19
+ from ai_edge_quantizer import algorithm_manager_api
20
+ from ai_edge_quantizer import default_policy
21
+ from ai_edge_quantizer import qtyping
22
+ from ai_edge_quantizer.algorithms.nonlinear_quantize import float_casting
23
+ from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
24
+
25
# Short alias for the TFLite operation-name enum used in the registrations.
_TFLOpName = qtyping.TFLOperationName

# One module-level manager backs every function exported below: each export is
# a bound method of this instance, so all importers share the same registries.
_alg_manager_instance = algorithm_manager_api.AlgorithmManagerApi()

# Expose instance functions.
# -- Registration entry points.
register_quantized_op = _alg_manager_instance.register_quantized_op
register_op_quant_config_validation_func = (
    _alg_manager_instance.register_op_quant_config_validation_func
)
# Exported with a `_func` suffix; the underlying method is named
# `register_config_check_policy`.
register_config_check_policy_func = (
    _alg_manager_instance.register_config_check_policy
)
# -- Lookup and validation entry points.
is_algorithm_registered = _alg_manager_instance.is_algorithm_registered
is_op_registered = _alg_manager_instance.is_op_registered
get_supported_ops = _alg_manager_instance.get_supported_ops
get_init_qsv_func = _alg_manager_instance.get_init_qsv_func
get_quantization_func = _alg_manager_instance.get_quantization_func
check_op_quantization_config = (
    _alg_manager_instance.check_op_quantization_config
)
45
+
46
+
47
# Quantization algorithms.
class AlgorithmName(str, enum.Enum):
  """Keys of the quantization algorithms known to this algorithm manager.

  The enum mixes in `str`, so each member is also a string equal to its value
  and can be passed wherever a plain algorithm key string is accepted (e.g.
  `register_quantized_op`, `get_quantization_func`).
  """

  # No ops or validation functions are registered for this key in this module;
  # by its name it presumably means "leave the op unquantized" — confirm with
  # the callers that consume it.
  NO_QUANTIZE = "no_quantize"
  # Keys are taken from the algorithm modules so the two never disagree.
  MIN_MAX_UNIFORM_QUANT = naive_min_max_quantize.ALGORITHM_KEY
  FLOAT_CASTING = float_casting.ALGORITHM_KEY
52
+
53
+
54
# Register MIN_MAX_UNIFORM_QUANT algorithm.
register_op_quant_config_validation_func(
    AlgorithmName.MIN_MAX_UNIFORM_QUANT,
    naive_min_max_quantize.check_op_quantization_config,
)

# Register a config check policy for MIN_MAX_UNIFORM_QUANT algorithm.
register_config_check_policy_func(
    AlgorithmName.MIN_MAX_UNIFORM_QUANT,
    default_policy.DEFAULT_CONFIG_CHECK_POLICY,
)

# TFLite op -> materialize function for MIN_MAX_UNIFORM_QUANT.
# Each op sits next to its function instead of living in two parallel tuples
# combined with `zip`: `zip` silently truncates to the shorter input, so an
# entry added to only one tuple would drop or misalign registrations without
# any error. Insertion order matches the previous registration order.
_MIN_MAX_UNIFORM_QUANT_MATERIALIZE_FUNCS = {
    _TFLOpName.INPUT: naive_min_max_quantize.materialize_input,
    _TFLOpName.OUTPUT: naive_min_max_quantize.materialize_output,
    _TFLOpName.FULLY_CONNECTED: naive_min_max_quantize.materialize_fc_conv,
    _TFLOpName.BATCH_MATMUL: naive_min_max_quantize.materialize_batch_matmul,
    _TFLOpName.CONV_2D: naive_min_max_quantize.materialize_fc_conv,
    _TFLOpName.DEPTHWISE_CONV_2D: naive_min_max_quantize.materialize_fc_conv,
    _TFLOpName.CONV_2D_TRANSPOSE: (
        naive_min_max_quantize.materialize_conv2d_transpose
    ),
    _TFLOpName.RESHAPE: naive_min_max_quantize.materialize_reshape,
    _TFLOpName.AVERAGE_POOL_2D: (
        naive_min_max_quantize.materialize_average_pool_2d
    ),
    _TFLOpName.EMBEDDING_LOOKUP: (
        naive_min_max_quantize.materialize_embedding_lookup
    ),
    _TFLOpName.SOFTMAX: naive_min_max_quantize.materialize_softmax_and_logistic,
    _TFLOpName.TANH: naive_min_max_quantize.materialize_tanh,
    _TFLOpName.TRANSPOSE: naive_min_max_quantize.materialize_transpose,
    _TFLOpName.GELU: naive_min_max_quantize.materialize_gelu,
    _TFLOpName.ADD: naive_min_max_quantize.materialize_add,
    _TFLOpName.SUB: naive_min_max_quantize.materialize_sub,
    _TFLOpName.MUL: naive_min_max_quantize.materialize_mul,
    _TFLOpName.MEAN: naive_min_max_quantize.materialize_mean,
    _TFLOpName.RSQRT: naive_min_max_quantize.materialize_rsqrt,
    _TFLOpName.CONCATENATION: naive_min_max_quantize.materialize_concatenation,
    _TFLOpName.STRIDED_SLICE: naive_min_max_quantize.materialize_strided_slice,
    _TFLOpName.SPLIT: naive_min_max_quantize.materialize_split,
    # LOGISTIC is Sigmoid; it shares the softmax materialization function.
    _TFLOpName.LOGISTIC: (
        naive_min_max_quantize.materialize_softmax_and_logistic
    ),
    _TFLOpName.SLICE: naive_min_max_quantize.materialize_slice,
    _TFLOpName.SUM: naive_min_max_quantize.materialize_sum,
    _TFLOpName.SELECT_V2: naive_min_max_quantize.materialize_select_v2,
}

# Every op shares the same QSV init and calibration functions; only the
# materialize function differs per op.
for op_name, materialize_func in (
    _MIN_MAX_UNIFORM_QUANT_MATERIALIZE_FUNCS.items()
):
  register_quantized_op(
      AlgorithmName.MIN_MAX_UNIFORM_QUANT,
      op_name,
      naive_min_max_quantize.init_qsvs,
      calibration_func=naive_min_max_quantize.min_max_calibrate,
      materialize_func=materialize_func,
  )
132
+
133
# Register FLOAT_CASTING algorithm.
register_op_quant_config_validation_func(
    AlgorithmName.FLOAT_CASTING,
    float_casting.check_op_quantization_config,
)

# Register a config check policy for FLOAT_CASTING algorithm.
# TODO: b/353780772 - Replace an empty policy for FLOAT_CASTING algorithm.
register_config_check_policy_func(
    AlgorithmName.FLOAT_CASTING, qtyping.ConfigCheckPolicyDict()
)

# TFLite op -> materialize function for FLOAT_CASTING. A single mapping
# replaces two parallel tuples fed to `zip`, which would silently truncate (and
# so drop registrations) if the tuples ever fell out of sync.
_FLOAT_CASTING_MATERIALIZE_FUNCS = {
    _TFLOpName.FULLY_CONNECTED: float_casting.materialize_fc_conv,
    _TFLOpName.CONV_2D: float_casting.materialize_fc_conv,
    _TFLOpName.DEPTHWISE_CONV_2D: float_casting.materialize_fc_conv,
    _TFLOpName.CONV_2D_TRANSPOSE: float_casting.materialize_conv2d_transpose,
    _TFLOpName.EMBEDDING_LOOKUP: float_casting.materialize_embedding_lookup,
}

# All ops share the same QSV init and calibration functions.
for op_name, materialize_func in _FLOAT_CASTING_MATERIALIZE_FUNCS.items():
  register_quantized_op(
      AlgorithmName.FLOAT_CASTING,
      op_name,
      float_casting.init_qsvs,
      calibration_func=float_casting.calibrate,
      materialize_func=materialize_func,
  )
@@ -0,0 +1,271 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """The Python API for Algorithm Manager of Quantizer."""
17
+
18
+ from collections.abc import Callable
19
+ import dataclasses
20
+ import functools
21
+ from typing import Any, Optional
22
+ from ai_edge_quantizer import qtyping
23
+
24
+
25
@dataclasses.dataclass
class QuantizedOperationInfo:
  """Stores all quantization functions for a given op.

  Attributes:
    tfl_op_key: TFLite op the functions below were registered for.
    init_qsv_func: Function that initializes the op's Quantization Statistics
      Variables (QSVs); returned by `AlgorithmManagerApi.get_init_qsv_func`.
    calibration_func: Function returned by
      `AlgorithmManagerApi.get_quantization_func` for `QuantizeMode.CALIBRATE`.
    materialize_func: Function returned by
      `AlgorithmManagerApi.get_quantization_func` for
      `QuantizeMode.MATERIALIZE`.
  """

  tfl_op_key: qtyping.TFLOperationName
  init_qsv_func: Callable[..., Any]
  calibration_func: Callable[..., Any]
  materialize_func: Callable[..., Any]
33
+
34
+
35
@dataclasses.dataclass
class QuantizationAlgorithmInfo:
  """Quantized ops registered for one quantization algorithm.

  Attributes:
    quantization_algorithm: Algorithm key this entry is registered under.
    quantized_ops: Mapping from TFLite op to its registered quantization
      functions.
  """

  quantization_algorithm: str
  quantized_ops: dict[qtyping.TFLOperationName, QuantizedOperationInfo]
39
+
40
+
41
class AlgorithmManagerApi:
  """Quantizer API client to manage quantization configs and functions.

  Every registry is keyed by algorithm key:
    * `_algorithm_registry`: per-op QSV init / calibration / materialize
      functions (filled by `register_quantized_op`).
    * `_config_check_registry`: op quantization config validation functions
      (filled by `register_op_quant_config_validation_func`).
    * `_config_check_policy_registry`: policy passed to the validation function
      (filled by `register_config_check_policy`).
  """

  def __init__(self):
    self._algorithm_registry: dict[str, QuantizationAlgorithmInfo] = {}
    # Check if an op quantization config is supported for a given algorithm.
    self._config_check_registry: dict[str, Callable[..., Any]] = {}
    # Policy to check if an op quantization config is supported for a given
    # algorithm.
    self._config_check_policy_registry: dict[
        str, Optional[qtyping.ConfigCheckPolicyDict]
    ] = {}

  def register_op_quant_config_validation_func(
      self,
      algorithm_key: str,
      config_check_func: Callable[..., Any],
  ):
    """Register functions to check if an op quantization config is supported.

    `check_op_quantization_config` invokes the function as
    `config_check_func(tfl_op_name, op_quantization_config, policy)`.
    Registering again for the same key replaces the previous function.

    Args:
      algorithm_key: Quantization algorithm the function validates configs for.
      config_check_func: The validation function.
    """
    self._config_check_registry[algorithm_key] = config_check_func

  def register_quantized_op(
      self,
      algorithm_key: str,
      tfl_op_name: qtyping.TFLOperationName,
      init_qsv_func: Callable[..., Any],
      calibration_func: Callable[..., Any],
      materialize_func: Callable[..., Any],
  ):
    """Register functions to support a quantization operation.

    Registers what is needed to support the quantized version of `tfl_op_name`
    under the algorithm `algorithm_key`. The algorithm entry is created on
    first use; registering the same op again replaces its functions.

    Args:
      algorithm_key: Quantization algorithm keyword for which the quantized
        operation is for.
      tfl_op_name: TFLite op name.
      init_qsv_func: QSV init function to be called.
      calibration_func: Quantized operation to be called during calibration.
      materialize_func: Quantized operation to be called during materialization.
    """
    algorithm_info = self._algorithm_registry.setdefault(
        algorithm_key, QuantizationAlgorithmInfo(algorithm_key, {})
    )
    algorithm_info.quantized_ops[tfl_op_name] = QuantizedOperationInfo(
        tfl_op_name,
        init_qsv_func,
        calibration_func,
        materialize_func,
    )

  def is_op_registered(
      self,
      quantization_algorithm: str,
      tfl_op_name: qtyping.TFLOperationName,
  ) -> bool:
    """Checks whether `tfl_op_name` is registered for an algorithm.

    Args:
      quantization_algorithm: Target quantization algorithm.
      tfl_op_name: TFL operation name.

    Returns:
      True if the given op is registered for the given algorithm, false
      otherwise (including when the algorithm itself is not registered).
    """
    if not self.is_algorithm_registered(quantization_algorithm):
      return False
    return (
        tfl_op_name
        in self._algorithm_registry[quantization_algorithm].quantized_ops
    )

  def is_algorithm_registered(self, quantization_algorithm: str) -> bool:
    """Check if the given algorithm is registered.

    Args:
      quantization_algorithm: Target quantization algorithm.

    Returns:
      True if at least one op has been registered for the algorithm via
      `register_quantized_op`, false otherwise.
    """
    return quantization_algorithm in self._algorithm_registry

  def check_op_quantization_config(
      self,
      quantization_algorithm: str,
      tfl_op_name: qtyping.TFLOperationName,
      op_quantization_config: qtyping.OpQuantizationConfig,
  ) -> None:
    """Checks if the given op quantization config is valid.

    Returns immediately when `op_quantization_config.skip_checks` is set.
    Otherwise calls the validation function registered for the algorithm with
    `(tfl_op_name, op_quantization_config, policy)`, where `policy` is the
    config check policy registered for the same algorithm.

    Args:
      quantization_algorithm: Target quantization algorithm.
      tfl_op_name: TFL operation name.
      op_quantization_config: Op quantization config to be checked.

    Raises:
      ValueError: If the op is not registered for the algorithm (which also
        covers an unregistered algorithm), or if no validation function or no
        config check policy is registered for the algorithm. The validation
        function itself may raise for unsupported configs.
    """
    if op_quantization_config.skip_checks:
      return
    if not self.is_op_registered(quantization_algorithm, tfl_op_name):
      raise ValueError(
          f"Unsupported operation {tfl_op_name} for Algorithm:"
          f" {quantization_algorithm}."
      )
    if quantization_algorithm not in self._config_check_registry:
      raise ValueError(
          f"Config checking function for algorithm {quantization_algorithm} is"
          " not registered. Please use"
          " `register_op_quant_config_validation_func` to register the"
          " validation function."
      )
    # Previously a missing policy surfaced as a bare KeyError; report it like
    # the missing validation function above.
    if quantization_algorithm not in self._config_check_policy_registry:
      raise ValueError(
          f"Config check policy for algorithm {quantization_algorithm} is not"
          " registered. Please use `register_config_check_policy` to register"
          " the policy."
      )
    self._config_check_registry[quantization_algorithm](
        tfl_op_name,
        op_quantization_config,
        self._config_check_policy_registry[quantization_algorithm],
    )

  def get_supported_ops(self, alg_key: str) -> list[qtyping.TFLOperationName]:
    """Returns the list of supported ops for the given algorithm.

    Args:
      alg_key: Algorithm key.

    Returns:
      The TFLite ops registered for the algorithm, in registration order.

    Raises:
      ValueError: If `alg_key` is not registered.
    """
    if alg_key not in self._algorithm_registry:
      raise ValueError(f"Unregistered algorithm: {alg_key}")

    return list(self._algorithm_registry[alg_key].quantized_ops.keys())

  def _raise_if_op_unregistered(
      self,
      algorithm_key: str,
      tfl_op_name: qtyping.TFLOperationName,
  ) -> None:
    """Raises ValueError unless `tfl_op_name` is registered for the algorithm.

    If the algorithm itself is unregistered, building the message calls
    `get_supported_ops`, which raises "Unregistered algorithm: ..." first; that
    message therefore takes precedence over "Unsupported operation ...".
    """
    if not self.is_op_registered(algorithm_key, tfl_op_name):
      raise ValueError(
          f"Unsupported operation {tfl_op_name} for Algorithm: {algorithm_key}."
          f" Supported ops for algorithm {algorithm_key}:"
          f" {self.get_supported_ops(algorithm_key)}"
      )

  def get_quantization_func(
      self,
      algorithm_key: str,
      tfl_op_name: qtyping.TFLOperationName,
      quantize_mode: qtyping.QuantizeMode,
  ) -> Callable[..., Any]:
    """Gets the quantization function.

    Args:
      algorithm_key: Target quantization algorithm key (e.g.,
        AlgorithmName.MIN_MAX_UNIFORM_QUANT).
      tfl_op_name: TFLite op name.
      quantize_mode: Quantization mode to be used.

    Returns:
      The registered calibration function for `QuantizeMode.CALIBRATE`, or the
      registered materialize function for `QuantizeMode.MATERIALIZE`.

    Raises:
      ValueError: If the algorithm is not registered, the op is not registered
        for it, or `quantize_mode` is neither CALIBRATE nor MATERIALIZE.
    """
    self._raise_if_op_unregistered(algorithm_key, tfl_op_name)

    quantized_op_info = self._algorithm_registry[algorithm_key].quantized_ops
    quantized_func = self._get_target_func(
        quantized_op_info, tfl_op_name, quantize_mode
    )
    if quantized_func is None:
      raise ValueError(
          "Cannot retrieve appropriate quantization function for"
          f" {tfl_op_name} for algorithm {algorithm_key} under quantization"
          f" mode {quantize_mode}. Check if the op is registed in"
          " algorithm_manager."
      )

    return quantized_func

  def get_init_qsv_func(
      self,
      algorithm_key: str,
      tfl_op_name: qtyping.TFLOperationName,
  ) -> Callable[..., Any]:
    """Gets the initial Quantization Statistics Variable function for a given op.

    Args:
      algorithm_key: Quantization algorithm to search.
      tfl_op_name: Target TFL operation.

    Returns:
      The `init_qsv_func` registered for the op. It is returned as registered,
      not wrapped in `functools.partial`.

    Raises:
      ValueError: If the algorithm is not registered or the op is not
        registered for it.
    """
    self._raise_if_op_unregistered(algorithm_key, tfl_op_name)
    quantized_op_info = self._algorithm_registry[algorithm_key].quantized_ops
    return quantized_op_info[tfl_op_name].init_qsv_func

  def _get_target_func(
      self,
      quantized_op_info: dict[qtyping.TFLOperationName, QuantizedOperationInfo],
      tfl_op_name: qtyping.TFLOperationName,
      quantize_mode: qtyping.QuantizeMode,
  ) -> Optional[Callable[..., Any]]:
    """Gets the function for the given TFLite op and quantization mode.

    Returns None when `quantize_mode` is neither CALIBRATE nor MATERIALIZE.
    """
    if quantize_mode == qtyping.QuantizeMode.CALIBRATE:
      return quantized_op_info[tfl_op_name].calibration_func
    elif quantize_mode == qtyping.QuantizeMode.MATERIALIZE:
      return quantized_op_info[tfl_op_name].materialize_func
    return None

  # TODO: b/53780772 - Merge this function with
  # register_op_quant_config_validation_func after full transition to new check
  # mechanism.
  def register_config_check_policy(
      self,
      algorithm_key: str,
      config_check_policy: Optional[qtyping.ConfigCheckPolicyDict],
  ):
    """Registers a policy to check the op quantization config.

    The policy is stored as-is (replacing any previous one for the key) and is
    passed as the third argument to the algorithm's validation function by
    `check_op_quantization_config`.
    """
    self._config_check_policy_registry[algorithm_key] = config_check_policy
@@ -0,0 +1,210 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Tests for algorithm_manager_api."""
17
+
18
+ from absl.testing import parameterized
19
+ from tensorflow.python.platform import googletest
20
+ from ai_edge_quantizer import algorithm_manager_api
21
+ from ai_edge_quantizer import qtyping
22
+
23
# Short alias for the TFLite operation-name enum used throughout these tests.
_TFLOpName = qtyping.TFLOperationName
24
+
25
+
26
# Sample functions for test cases.
def _sample_init_qsvs(*_, **__):
  """Stand-in QSV init function; ignores its arguments."""
  return 1.0, dict()


def _sample_calibration_func(*_, **__):
  """Stand-in calibration function; ignores its arguments."""
  return 2.0, dict()


def _sample_materialize_func(*_, **__):
  """Stand-in materialize function; ignores its arguments."""
  return 3.0, dict()


def _sample_check_op_config_func(_, op_config, config_check_policy=None):
  """Stand-in op quantization config validation function.

  `AlgorithmManagerApi.check_op_quantization_config` calls validation functions
  as `func(tfl_op_name, op_config, policy)`. The previous two-argument
  signature would raise TypeError under that call, so an optional (ignored)
  policy argument is accepted.

  Raises:
    ValueError: If the weight tensor config requests 17 bits.
  """
  del config_check_policy  # Unused.
  if op_config.weight_tensor_config.num_bits == 17:
    raise ValueError("Unsupported number of bits.")
42
+
43
+
44
class AlgorithmManagerApiTest(parameterized.TestCase):
  """Tests for `algorithm_manager_api.AlgorithmManagerApi`."""

  def setUp(self):
    super().setUp()
    # A fresh manager per test keeps registrations from leaking across tests.
    self._alg_manager = algorithm_manager_api.AlgorithmManagerApi()

  def _register_sample_op(self, algorithm_key, tfl_op_name):
    """Registers the sample init/calibration/materialize functions for an op."""
    self._alg_manager.register_quantized_op(
        algorithm_key=algorithm_key,
        tfl_op_name=tfl_op_name,
        init_qsv_func=_sample_init_qsvs,
        calibration_func=_sample_calibration_func,
        materialize_func=_sample_materialize_func,
    )

  def test_register_op_quant_config_validation_func_succeeds(self):
    self.assertEmpty(self._alg_manager._config_check_registry)
    test_algorithm_name = "test_algorithm"
    self._alg_manager.register_op_quant_config_validation_func(
        test_algorithm_name, _sample_check_op_config_func
    )
    self.assertIn(test_algorithm_name, self._alg_manager._config_check_registry)
    # The registry must hold exactly the callable that was registered.
    self.assertIs(
        self._alg_manager._config_check_registry[test_algorithm_name],
        _sample_check_op_config_func,
    )

  def test_register_quantized_op(self):
    self._register_sample_op("ptq", _TFLOpName.FULLY_CONNECTED)
    self._register_sample_op("gptq", _TFLOpName.CONV_2D)
    self.assertTrue(self._alg_manager.is_algorithm_registered("ptq"))
    self.assertTrue(self._alg_manager.is_algorithm_registered("gptq"))
    self.assertTrue(
        self._alg_manager.is_op_registered("ptq", _TFLOpName.FULLY_CONNECTED)
    )
    self.assertTrue(
        self._alg_manager.is_op_registered("gptq", _TFLOpName.CONV_2D)
    )
    self.assertFalse(
        self._alg_manager.is_op_registered("gptq", _TFLOpName.DEPTHWISE_CONV_2D)
    )

  def test_get_supported_ops(self):
    algorithm_key = "ptq"
    self._register_sample_op(algorithm_key, _TFLOpName.FULLY_CONNECTED)
    self._register_sample_op(algorithm_key, _TFLOpName.CONV_2D)
    supported_ops = self._alg_manager.get_supported_ops(algorithm_key)
    # Exactly the registered ops and nothing else (e.g. no DEPTHWISE_CONV_2D).
    self.assertCountEqual(
        supported_ops, [_TFLOpName.FULLY_CONNECTED, _TFLOpName.CONV_2D]
    )

  def test_get_quantization_func(self):
    algorithm_key = "ptq"
    tfl_op = _TFLOpName.FULLY_CONNECTED
    self._register_sample_op(algorithm_key, tfl_op)

    # Each mode must return exactly the callable registered for it.
    materialize_func = self._alg_manager.get_quantization_func(
        algorithm_key,
        tfl_op,
        qtyping.QuantizeMode.MATERIALIZE,
    )
    self.assertIs(materialize_func, _sample_materialize_func)
    calibration_func = self._alg_manager.get_quantization_func(
        algorithm_key,
        tfl_op,
        qtyping.QuantizeMode.CALIBRATE,
    )
    self.assertIs(calibration_func, _sample_calibration_func)

    # Query for unsupported operation.
    with self.assertRaisesRegex(ValueError, "Unsupported operation"):
      self._alg_manager.get_quantization_func(
          algorithm_key,
          _TFLOpName.BATCH_MATMUL,
          qtyping.QuantizeMode.MATERIALIZE,
      )

    # Query for unregistered algorithm.
    with self.assertRaisesRegex(ValueError, "Unregistered algorithm"):
      self._alg_manager.get_quantization_func(
          "gptq",
          tfl_op,
          qtyping.QuantizeMode.MATERIALIZE,
      )

  def test_get_init_qsv_func(self):
    algorithm_key = "ptq"
    tfl_op = _TFLOpName.FULLY_CONNECTED
    self._register_sample_op(algorithm_key, tfl_op)

    init_qsv_func = self._alg_manager.get_init_qsv_func(algorithm_key, tfl_op)
    self.assertIs(init_qsv_func, _sample_init_qsvs)

    # Query for unsupported operation.
    with self.assertRaisesRegex(ValueError, "Unsupported operation"):
      self._alg_manager.get_init_qsv_func(
          algorithm_key,
          _TFLOpName.BATCH_MATMUL,
      )

    # Query for unregistered algorithm.
    with self.assertRaisesRegex(ValueError, "Unregistered algorithm"):
      self._alg_manager.get_init_qsv_func(
          "gptq",
          tfl_op,
      )

  def test_register_config_check_policy_succeeds(self):
    self.assertEmpty(self._alg_manager._config_check_policy_registry)
    test_algorithm_name = "test_algorithm"
    test_config_check_policy = qtyping.ConfigCheckPolicyDict({
        _TFLOpName.FULLY_CONNECTED: {
            qtyping.OpQuantizationConfig(
                weight_tensor_config=qtyping.TensorQuantizationConfig(
                    num_bits=1
                )
            )
        }
    })
    self._alg_manager.register_config_check_policy(
        test_algorithm_name, test_config_check_policy
    )
    self.assertIn(
        test_algorithm_name, self._alg_manager._config_check_policy_registry
    )
    # The policy is stored as-is, so the exact object must come back.
    self.assertIs(
        self._alg_manager._config_check_policy_registry[test_algorithm_name],
        test_config_check_policy,
    )
207
+
208
+
209
if __name__ == "__main__":
  # Run this module's tests with the googletest runner imported from
  # tensorflow.python.platform.
  googletest.main()
@@ -0,0 +1,15 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
@@ -0,0 +1,15 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+