ai_edge_quantizer_nightly-0.0.1.dev20250115-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/__init__.py +19 -0
- ai_edge_quantizer/algorithm_manager.py +167 -0
- ai_edge_quantizer/algorithm_manager_api.py +271 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
- ai_edge_quantizer/algorithms/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
- ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
- ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
- ai_edge_quantizer/calibrator.py +288 -0
- ai_edge_quantizer/calibrator_test.py +297 -0
- ai_edge_quantizer/conftest.py +22 -0
- ai_edge_quantizer/default_policy.py +310 -0
- ai_edge_quantizer/model_modifier.py +176 -0
- ai_edge_quantizer/model_modifier_test.py +130 -0
- ai_edge_quantizer/model_validator.py +357 -0
- ai_edge_quantizer/model_validator_test.py +354 -0
- ai_edge_quantizer/params_generator.py +361 -0
- ai_edge_quantizer/params_generator_test.py +1041 -0
- ai_edge_quantizer/qtyping.py +483 -0
- ai_edge_quantizer/quantizer.py +372 -0
- ai_edge_quantizer/quantizer_test.py +532 -0
- ai_edge_quantizer/recipe.py +67 -0
- ai_edge_quantizer/recipe_manager.py +245 -0
- ai_edge_quantizer/recipe_manager_test.py +815 -0
- ai_edge_quantizer/recipe_test.py +97 -0
- ai_edge_quantizer/transformation_instruction_generator.py +584 -0
- ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
- ai_edge_quantizer/transformation_performer.py +278 -0
- ai_edge_quantizer/transformation_performer_test.py +344 -0
- ai_edge_quantizer/transformations/__init__.py +15 -0
- ai_edge_quantizer/transformations/dequant_insert.py +87 -0
- ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
- ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
- ai_edge_quantizer/transformations/quant_insert.py +100 -0
- ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
- ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
- ai_edge_quantizer/transformations/transformation_utils.py +132 -0
- ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
- ai_edge_quantizer/utils/__init__.py +15 -0
- ai_edge_quantizer/utils/calibration_utils.py +86 -0
- ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
- ai_edge_quantizer/utils/test_utils.py +107 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
- ai_edge_quantizer/utils/validation_utils.py +125 -0
- ai_edge_quantizer/utils/validation_utils_test.py +87 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py
@@ -0,0 +1,512 @@
# Copyright 2024 The AI Edge Quantizer Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import collections
from absl.testing import parameterized
from tensorflow.python.platform import googletest
from ai_edge_quantizer import default_policy
from ai_edge_quantizer import qtyping
from ai_edge_quantizer.algorithms.utils import min_max_quantize_utils

_ComputePrecision = qtyping.ComputePrecision
_QuantTransformation = qtyping.QuantTransformation
_TFLOpName = qtyping.TFLOperationName
_OpQuantConfig = qtyping.OpQuantizationConfig
_TensorQuantConfig = qtyping.TensorQuantizationConfig
_DEFAULT_CONFIG_CHECK_POLICY = default_policy.DEFAULT_CONFIG_CHECK_POLICY


# TODO: b/335008966 - increase test coverage.
class MinMaxQuantizeUtilsTest(parameterized.TestCase):

  @parameterized.product(
      test_case=[
          # Tuple holds computation precision, whether to use SRQ and whether
          # to use explicit dequantize.
          (_ComputePrecision.FLOAT, False, True),  # WEIGHT_ONLY.
          (_ComputePrecision.INTEGER, False, False),  # DRQ.
          (_ComputePrecision.INTEGER, True, False),  # SRQ.
      ],
      is_inbounding_tensor=[True, False],
      is_constant=[True, False],
  )
  def test_get_tensor_transformations(
      self, test_case, is_inbounding_tensor, is_constant
  ):
    compute_precision, is_srq, explicit_dequantize = test_case
    weight_tensor_config = _TensorQuantConfig(num_bits=8)
    op_quant_config = qtyping.OpQuantizationConfig(
        activation_tensor_config=weight_tensor_config if is_srq else None,
        compute_precision=compute_precision,
        explicit_dequantize=explicit_dequantize,
    )
    transformations = min_max_quantize_utils.get_tensor_transformations(
        op_quant_config, is_inbounding_tensor, is_constant
    )
    # Check if WEIGHT_ONLY.
    if (
        compute_precision == _ComputePrecision.FLOAT
        and op_quant_config.explicit_dequantize
    ):
      if is_inbounding_tensor and is_constant:
        self.assertSequenceEqual(
            transformations,
            [
                _QuantTransformation.ADD_DEQUANTIZE,
            ],
        )
      else:
        self.assertSequenceEqual(
            transformations,
            [_QuantTransformation.NO_QUANTIZE],
        )

    # Check if DRQ.
    if compute_precision == _ComputePrecision.INTEGER and not is_srq:
      if is_inbounding_tensor and is_constant:
        self.assertSequenceEqual(
            transformations, [_QuantTransformation.QUANTIZE_TENSOR]
        )
      else:
        self.assertSequenceEqual(
            transformations,
            [_QuantTransformation.NO_QUANTIZE],
        )

    # Check if SRQ.
    if compute_precision == _ComputePrecision.INTEGER and is_srq:
      if is_inbounding_tensor:
        if is_constant:
          self.assertSequenceEqual(
              transformations, [_QuantTransformation.QUANTIZE_TENSOR]
          )
        else:
          self.assertSequenceEqual(
              transformations, [_QuantTransformation.ADD_QUANTIZE]
          )
      else:
        self.assertSequenceEqual(
            transformations, [_QuantTransformation.ADD_DEQUANTIZE]
        )

  @parameterized.parameters((_TFLOpName.FULLY_CONNECTED), (_TFLOpName.CONV_2D))
  def test_check_weight_only_config_succeeds(self, op_name):
    self.assertIn(op_name, _DEFAULT_CONFIG_CHECK_POLICY.keys())

  @parameterized.parameters((_TFLOpName.RESHAPE), (_TFLOpName.AVERAGE_POOL_2D))
  def test_check_weight_only_config_raises_when_invalid_config(self, op_name):
    op_quant_config = _OpQuantConfig(
        weight_tensor_config=_TensorQuantConfig(
            num_bits=8,
        ),
        compute_precision=_ComputePrecision.FLOAT,
    )
    error_message = (
        f"Quantization config for op: {op_name} with config:"
        f" {op_quant_config} was not found in the policy."
    )
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: error_message in str(err)
    ):
      min_max_quantize_utils.check_if_valid_op_config(
          op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
      )

  @parameterized.product(
      op_name=(_TFLOpName.FULLY_CONNECTED, _TFLOpName.CONV_2D),
      weight_num_bits=(4, 8),
      granularity=(
          qtyping.QuantGranularity.TENSORWISE,
          qtyping.QuantGranularity.CHANNELWISE,
      ),
  )
  def test_check_drq_config_succeeds(
      self, op_name, weight_num_bits, granularity
  ):
    # TODO: b/353365054 - Remove this check after int4 DRQ is supported for
    # conv2d.
    if op_name == _TFLOpName.CONV_2D and weight_num_bits == 4:
      return
    op_quant_config = _OpQuantConfig(
        weight_tensor_config=_TensorQuantConfig(
            num_bits=weight_num_bits,
            granularity=granularity,
        ),
        compute_precision=_ComputePrecision.INTEGER,  # DRQ.
    )
    min_max_quantize_utils.check_if_valid_op_config(
        op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
    )

  @parameterized.parameters((_TFLOpName.RESHAPE), (_TFLOpName.AVERAGE_POOL_2D))
  def test_check_drq_config_unsupported_op_raise_error(self, op_name):
    op_quant_config = _OpQuantConfig(
        weight_tensor_config=_TensorQuantConfig(
            num_bits=8,
            granularity=qtyping.QuantGranularity.CHANNELWISE,
        ),
        compute_precision=_ComputePrecision.INTEGER,  # DRQ.
    )
    error_message = (
        f"Quantization config for op: {op_name} with config:"
        f" {op_quant_config} was not found in the policy."
    )
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: error_message in str(err)
    ):
      min_max_quantize_utils.check_if_valid_op_config(
          op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
      )

  @parameterized.parameters((_TFLOpName.FULLY_CONNECTED), (_TFLOpName.CONV_2D))
  def test_check_drq_config_wrong_bits_raise_error(self, op_name):
    op_quant_config = _OpQuantConfig(
        weight_tensor_config=_TensorQuantConfig(
            num_bits=2,
            granularity=qtyping.QuantGranularity.TENSORWISE,
        ),
        compute_precision=_ComputePrecision.INTEGER,  # DRQ.
    )
    error_message = (
        f"Quantization config for op: {op_name} with config:"
        f" {op_quant_config} was not found in the policy."
    )
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: error_message in str(err)
    ):
      min_max_quantize_utils.check_if_valid_op_config(
          op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
      )

  @parameterized.parameters((_TFLOpName.FULLY_CONNECTED), (_TFLOpName.CONV_2D))
  def test_check_drq_config_asymmetric_weights_raise_error(self, op_name):
    op_quant_config = _OpQuantConfig(
        weight_tensor_config=_TensorQuantConfig(
            num_bits=8,
            symmetric=False,
            granularity=qtyping.QuantGranularity.TENSORWISE,
        ),
        compute_precision=_ComputePrecision.INTEGER,  # DRQ.
    )
    error_message = (
        f"Quantization config for op: {op_name} with config:"
        f" {op_quant_config} was not found in the policy."
    )
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: error_message in str(err)
    ):
      min_max_quantize_utils.check_if_valid_op_config(
          op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
      )

  def test_check_drq_config_with_non_default_min_weight_elements_succeeds(self):
    op_quant_config = _OpQuantConfig(
        weight_tensor_config=_TensorQuantConfig(
            num_bits=8,
            granularity=qtyping.QuantGranularity.CHANNELWISE,
        ),
        compute_precision=_ComputePrecision.INTEGER,  # DRQ.
        min_weight_elements=100,
    )
    min_max_quantize_utils.check_if_valid_op_config(
        _TFLOpName.CONV_2D, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
    )

  @parameterized.product(
      op_name=(_TFLOpName.FULLY_CONNECTED, _TFLOpName.CONV_2D),
      act_num_bits=(8, 16),
      weight_num_bits=(4, 8),
      granularity=(
          qtyping.QuantGranularity.TENSORWISE,
          qtyping.QuantGranularity.CHANNELWISE,
      ),
      symmetric_act=(True, False),
  )
  def test_check_srq_config_succeeds(
      self,
      op_name,
      act_num_bits,
      weight_num_bits,
      granularity,
      symmetric_act,
  ):
    # Asym int16 activation is not supported.
    if not symmetric_act and act_num_bits == 16:
      return
    op_quant_config = _OpQuantConfig(
        activation_tensor_config=_TensorQuantConfig(
            num_bits=act_num_bits, symmetric=symmetric_act
        ),
        weight_tensor_config=_TensorQuantConfig(
            num_bits=weight_num_bits,
            granularity=granularity,
        ),
        compute_precision=_ComputePrecision.INTEGER,  # SRQ.
    )
    min_max_quantize_utils.check_if_valid_op_config(
        op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
    )

  def test_check_srq_config_unsupported_op_raise_error(self):
    op_quant_config = _OpQuantConfig(
        activation_tensor_config=_TensorQuantConfig(num_bits=8, symmetric=True),
        weight_tensor_config=_TensorQuantConfig(
            num_bits=8,
            granularity=qtyping.QuantGranularity.CHANNELWISE,
        ),
        compute_precision=_ComputePrecision.INTEGER,  # SRQ.
    )
    error_message = (
        f"Unsupported op for {op_quant_config.compute_precision}:"
        f" {_TFLOpName.CUSTOM_OP}"
    )
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: error_message in str(err)
    ):
      min_max_quantize_utils.check_if_valid_op_config(
          _TFLOpName.CUSTOM_OP, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
      )

  def test_check_srq_config_wrong_act_bits_config_raise_error(self):
    op_quant_config = _OpQuantConfig(
        activation_tensor_config=_TensorQuantConfig(
            num_bits=14, symmetric=True
        ),
        weight_tensor_config=_TensorQuantConfig(
            num_bits=8,
            granularity=qtyping.QuantGranularity.CHANNELWISE,
        ),
        compute_precision=_ComputePrecision.INTEGER,  # SRQ.
    )
    error_message = (
        f"Quantization config for op: {_TFLOpName.FULLY_CONNECTED} with config:"
        f" {op_quant_config} was not found in the policy."
    )
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: error_message in str(err)
    ):
      min_max_quantize_utils.check_if_valid_op_config(
          _TFLOpName.FULLY_CONNECTED,
          op_quant_config,
          _DEFAULT_CONFIG_CHECK_POLICY,
      )

  def test_check_srq_config_asym_int16_act_raise_error(self):
    op_quant_config = _OpQuantConfig(
        activation_tensor_config=_TensorQuantConfig(
            num_bits=16, symmetric=False
        ),
        weight_tensor_config=_TensorQuantConfig(
            num_bits=8,
            granularity=qtyping.QuantGranularity.CHANNELWISE,
        ),
        compute_precision=_ComputePrecision.INTEGER,  # SRQ.
    )
    error_message = (
        f"Quantization config for op: {_TFLOpName.FULLY_CONNECTED} with config:"
        f" {op_quant_config} was not found in the policy."
    )
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: error_message in str(err)
    ):
      min_max_quantize_utils.check_if_valid_op_config(
          _TFLOpName.FULLY_CONNECTED,
          op_quant_config,
          _DEFAULT_CONFIG_CHECK_POLICY,
      )

  def test_check_srq_config_wrong_weight_bits_raise_error(self):
    op_quant_config = _OpQuantConfig(
        activation_tensor_config=_TensorQuantConfig(
            num_bits=16, symmetric=True
        ),
        weight_tensor_config=_TensorQuantConfig(
            num_bits=2,
            granularity=qtyping.QuantGranularity.CHANNELWISE,
        ),
        compute_precision=_ComputePrecision.INTEGER,  # SRQ.
    )
    error_message = (
        f"Quantization config for op: {_TFLOpName.FULLY_CONNECTED} with config:"
        f" {op_quant_config} was not found in the policy."
    )
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: error_message in str(err)
    ):
      min_max_quantize_utils.check_if_valid_op_config(
          _TFLOpName.FULLY_CONNECTED,
          op_quant_config,
          _DEFAULT_CONFIG_CHECK_POLICY,
      )

  def test_check_srq_config_asym_weight_raise_error(self):
    op_quant_config = _OpQuantConfig(
        activation_tensor_config=_TensorQuantConfig(num_bits=8, symmetric=True),
        weight_tensor_config=_TensorQuantConfig(
            num_bits=8,
            symmetric=False,
            granularity=qtyping.QuantGranularity.CHANNELWISE,
        ),
        compute_precision=_ComputePrecision.INTEGER,  # SRQ.
    )
    error_message = (
        f"Quantization config for op: {_TFLOpName.FULLY_CONNECTED} with config:"
        f" {op_quant_config} was not found in the policy."
    )
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: error_message in str(err)
    ):
      min_max_quantize_utils.check_if_valid_op_config(
          _TFLOpName.FULLY_CONNECTED,
          op_quant_config,
          _DEFAULT_CONFIG_CHECK_POLICY,
      )

  @parameterized.product(
      op_name=[
          _TFLOpName.FULLY_CONNECTED,
          _TFLOpName.CONV_2D,
      ],
      activation_tensor_config=[
          None,
          _TensorQuantConfig(num_bits=8, symmetric=False),
          _TensorQuantConfig(num_bits=16, symmetric=True),
      ],
      compute_precision=[
          _ComputePrecision.FLOAT,
          _ComputePrecision.INTEGER,
      ],
  )
  def test_check_supported_int4_config_succeeds(
      self, op_name, activation_tensor_config, compute_precision
  ):
    # Exclude invalid SRQ config.
    if (
        activation_tensor_config is not None
        and compute_precision != _ComputePrecision.INTEGER
    ) or (
        activation_tensor_config is None
        and compute_precision == _ComputePrecision.FLOAT
    ):
      return
    # TODO: b/353365054 - Remove this check after int4 DRQ is supported for
    # conv2d.
    if (
        # Check if DRQ and CONV_2D.
        compute_precision == _ComputePrecision.INTEGER
        and activation_tensor_config is None
        and op_name == _TFLOpName.CONV_2D
    ):
      return
    op_quant_config = _OpQuantConfig(
        activation_tensor_config=activation_tensor_config,
        weight_tensor_config=_TensorQuantConfig(
            num_bits=4,
            symmetric=True,
            granularity=qtyping.QuantGranularity.CHANNELWISE,
        ),
        compute_precision=compute_precision,
    )
    # Raise error if the config is not supported.
    # Check if DRQ.
    if (
        compute_precision == _ComputePrecision.INTEGER
        and op_quant_config.activation_tensor_config is None
    ):
      min_max_quantize_utils.check_if_valid_op_config(
          op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
      )
    # Check if WEIGHT_ONLY.
    elif (
        compute_precision == _ComputePrecision.FLOAT
        and op_quant_config.explicit_dequantize
    ):
      self.assertIn(op_name, _DEFAULT_CONFIG_CHECK_POLICY.keys())
    # Check if SRQ.
    if (
        compute_precision == _ComputePrecision.INTEGER
        and op_quant_config.activation_tensor_config is not None
    ):
      min_max_quantize_utils.check_if_valid_op_config(
          op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
      )

  @parameterized.product(
      op_name=[_TFLOpName.BATCH_MATMUL],
      activation_tensor_config=[
          None,
          _TensorQuantConfig(num_bits=8, symmetric=False),
          _TensorQuantConfig(num_bits=16, symmetric=True),
      ],
      test_case=[
          # Tuple holds compute precision and whether to use drq.
          (_ComputePrecision.INTEGER, True),
          (_ComputePrecision.INTEGER, False),
      ],
  )
  def test_check_unsupported_int4_config_raise_error(
      self, op_name, activation_tensor_config, test_case
  ):
    compute_precision, is_drq = test_case
    # Exclude invalid SRQ config.
    if (activation_tensor_config is not None and is_drq) or (
        activation_tensor_config is None and not is_drq
    ):
      return
    op_quant_config = _OpQuantConfig(
        activation_tensor_config=activation_tensor_config,
        weight_tensor_config=_TensorQuantConfig(
            num_bits=4,
            symmetric=True,
            granularity=qtyping.QuantGranularity.CHANNELWISE,
        ),
        compute_precision=compute_precision,
    )

    with self.assertRaises(ValueError):
      if is_drq:
        min_max_quantize_utils.check_if_valid_op_config(
            op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
        )
      elif not is_drq:
        min_max_quantize_utils.check_if_valid_op_config(
            op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
        )

  def test_materialize_op_with_output_activation_constraint_fails_for_multiple_output_op(
      self,
  ):
    # Create a mock op with multiple outputs.
    MockOp = collections.namedtuple("MockOp", ["outputs"])
    mock_op_info = qtyping.OpInfo(
        op=MockOp(outputs=[1, 2]),
        op_name=_TFLOpName.SOFTMAX,
        subgraph_op_index=0,
        op_quant_config=_OpQuantConfig(),
    )

    with self.assertRaisesRegex(
        ValueError, "only supports ops with a single output tensor"
    ):
      min_max_quantize_utils.materialize_op_with_output_activation_constraint(
          op_info=mock_op_info,
          graph_info=qtyping.GraphInfo([], []),
          tensor_name_to_qsv={},
          output_activation_constraints={},
      )


if __name__ == "__main__":
  googletest.main()