ai_edge_quantizer_nightly-0.0.1.dev20250115-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
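For orientation, the sketch below shows how the modules listed above are typically driven end to end. It is a minimal example assuming the public API in quantizer.py and recipe.py matches the package's README; the names Quantizer, load_quantization_recipe, quantize, export_model, and the dynamic_wi8_afp32 recipe preset are assumptions from that README, not verified against this exact nightly build.

    # Minimal usage sketch, assuming the public API matches the package README.
    # `dynamic_wi8_afp32` (int8 weights, float32 activations) is an assumed
    # preset name; substitute whatever recipe.py exposes in this build.
    from ai_edge_quantizer import quantizer
    from ai_edge_quantizer import recipe

    qt = quantizer.Quantizer("/path/to/float_model.tflite")
    qt.load_quantization_recipe(recipe.dynamic_wi8_afp32())
    result = qt.quantize()  # Runs the recipe and returns the quantized model.
    result.export_model("/path/to/quantized_model.tflite")

The diff shown below is the new test file from item 16 of the list, ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py (512 added lines).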
ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py
@@ -0,0 +1,512 @@
+ # Copyright 2024 The AI Edge Quantizer Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import collections
+ from absl.testing import parameterized
+ from tensorflow.python.platform import googletest
+ from ai_edge_quantizer import default_policy
+ from ai_edge_quantizer import qtyping
+ from ai_edge_quantizer.algorithms.utils import min_max_quantize_utils
+
+ _ComputePrecision = qtyping.ComputePrecision
+ _QuantTransformation = qtyping.QuantTransformation
+ _TFLOpName = qtyping.TFLOperationName
+ _OpQuantConfig = qtyping.OpQuantizationConfig
+ _TensorQuantConfig = qtyping.TensorQuantizationConfig
+ _DEFAULT_CONFIG_CHECK_POLICY = default_policy.DEFAULT_CONFIG_CHECK_POLICY
+
+
+ # TODO: b/335008966 - increase test coverage.
+ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
+
+   @parameterized.product(
+       test_case=[
+           # Tuple holds computation precision, whether to use SRQ and whether
+           # to use explicit dequantize.
+           (_ComputePrecision.FLOAT, False, True),  # WEIGHT_ONLY.
+           (_ComputePrecision.INTEGER, False, False),  # DRQ.
+           (_ComputePrecision.INTEGER, True, False),  # SRQ.
+       ],
+       is_inbounding_tensor=[True, False],
+       is_constant=[True, False],
+   )
+   def test_get_tensor_transformations(
+       self, test_case, is_inbounding_tensor, is_constant
+   ):
+     compute_precision, is_srq, explicit_dequantize = test_case
+     weight_tensor_config = _TensorQuantConfig(num_bits=8)
+     op_quant_config = qtyping.OpQuantizationConfig(
+         activation_tensor_config=weight_tensor_config if is_srq else None,
+         compute_precision=compute_precision,
+         explicit_dequantize=explicit_dequantize,
+     )
+     transformations = min_max_quantize_utils.get_tensor_transformations(
+         op_quant_config, is_inbounding_tensor, is_constant
+     )
+     # Check if WEIGHT_ONLY.
+     if (
+         compute_precision == _ComputePrecision.FLOAT
+         and op_quant_config.explicit_dequantize
+     ):
+       if is_inbounding_tensor and is_constant:
+         self.assertSequenceEqual(
+             transformations,
+             [
+                 _QuantTransformation.ADD_DEQUANTIZE,
+             ],
+         )
+       else:
+         self.assertSequenceEqual(
+             transformations,
+             [_QuantTransformation.NO_QUANTIZE],
+         )
+
+     # Check if DRQ.
+     if compute_precision == _ComputePrecision.INTEGER and not is_srq:
+       if is_inbounding_tensor and is_constant:
+         self.assertSequenceEqual(
+             transformations, [_QuantTransformation.QUANTIZE_TENSOR]
+         )
+       else:
+         self.assertSequenceEqual(
+             transformations,
+             [_QuantTransformation.NO_QUANTIZE],
+         )
+
+     # Check if SRQ.
+     if compute_precision == _ComputePrecision.INTEGER and is_srq:
+       if is_inbounding_tensor:
+         if is_constant:
+           self.assertSequenceEqual(
+               transformations, [_QuantTransformation.QUANTIZE_TENSOR]
+           )
+         else:
+           self.assertSequenceEqual(
+               transformations, [_QuantTransformation.ADD_QUANTIZE]
+           )
+       else:
+         self.assertSequenceEqual(
+             transformations, [_QuantTransformation.ADD_DEQUANTIZE]
+         )
+
+   @parameterized.parameters((_TFLOpName.FULLY_CONNECTED), (_TFLOpName.CONV_2D))
+   def test_check_weight_only_config_succeeds(self, op_name):
+     self.assertIn(op_name, _DEFAULT_CONFIG_CHECK_POLICY.keys())
+
+   @parameterized.parameters((_TFLOpName.RESHAPE), (_TFLOpName.AVERAGE_POOL_2D))
+   def test_check_weight_only_config_raises_when_invalid_config(self, op_name):
+     op_quant_config = _OpQuantConfig(
+         weight_tensor_config=_TensorQuantConfig(
+             num_bits=8,
+         ),
+         compute_precision=_ComputePrecision.FLOAT,
+     )
+     error_message = (
+         f"Quantization config for op: {op_name} with config:"
+         f" {op_quant_config} was not found in the policy."
+     )
+     with self.assertRaisesWithPredicateMatch(
+         ValueError, lambda err: error_message in str(err)
+     ):
+       min_max_quantize_utils.check_if_valid_op_config(
+           op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
+       )
+
+   @parameterized.product(
+       op_name=(_TFLOpName.FULLY_CONNECTED, _TFLOpName.CONV_2D),
+       weight_num_bits=(4, 8),
+       granularity=(
+           qtyping.QuantGranularity.TENSORWISE,
+           qtyping.QuantGranularity.CHANNELWISE,
+       ),
+   )
+   def test_check_drq_config_succeeds(
+       self, op_name, weight_num_bits, granularity
+   ):
+     # TODO: b/353365054 - Remove this check after int4 DRQ is supported for
+     # conv2d.
+     if op_name == _TFLOpName.CONV_2D and weight_num_bits == 4:
+       return
+     op_quant_config = _OpQuantConfig(
+         weight_tensor_config=_TensorQuantConfig(
+             num_bits=weight_num_bits,
+             granularity=granularity,
+         ),
+         compute_precision=_ComputePrecision.INTEGER,  # DRQ.
+     )
+     min_max_quantize_utils.check_if_valid_op_config(
+         op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
+     )
+
+   @parameterized.parameters((_TFLOpName.RESHAPE), (_TFLOpName.AVERAGE_POOL_2D))
+   def test_check_drq_config_unsupported_op_raise_error(self, op_name):
+     op_quant_config = _OpQuantConfig(
+         weight_tensor_config=_TensorQuantConfig(
+             num_bits=8,
+             granularity=qtyping.QuantGranularity.CHANNELWISE,
+         ),
+         compute_precision=_ComputePrecision.INTEGER,  # DRQ.
+     )
+     error_message = (
+         f"Quantization config for op: {op_name} with config:"
+         f" {op_quant_config} was not found in the policy."
+     )
+     with self.assertRaisesWithPredicateMatch(
+         ValueError, lambda err: error_message in str(err)
+     ):
+       min_max_quantize_utils.check_if_valid_op_config(
+           op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
+       )
+
+   @parameterized.parameters((_TFLOpName.FULLY_CONNECTED), (_TFLOpName.CONV_2D))
+   def test_check_drq_config_wrong_bits_raise_error(self, op_name):
+     op_quant_config = _OpQuantConfig(
+         weight_tensor_config=_TensorQuantConfig(
+             num_bits=2,
+             granularity=qtyping.QuantGranularity.TENSORWISE,
+         ),
+         compute_precision=_ComputePrecision.INTEGER,  # DRQ.
+     )
+     error_message = (
+         f"Quantization config for op: {op_name} with config:"
+         f" {op_quant_config} was not found in the policy."
+     )
+     with self.assertRaisesWithPredicateMatch(
+         ValueError, lambda err: error_message in str(err)
+     ):
+       min_max_quantize_utils.check_if_valid_op_config(
+           op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
+       )
+
+   @parameterized.parameters((_TFLOpName.FULLY_CONNECTED), (_TFLOpName.CONV_2D))
+   def test_check_drq_config_asymmetric_weights_raise_error(self, op_name):
+     op_quant_config = _OpQuantConfig(
+         weight_tensor_config=_TensorQuantConfig(
+             num_bits=8,
+             symmetric=False,
+             granularity=qtyping.QuantGranularity.TENSORWISE,
+         ),
+         compute_precision=_ComputePrecision.INTEGER,  # DRQ.
+     )
+     error_message = (
+         f"Quantization config for op: {op_name} with config:"
+         f" {op_quant_config} was not found in the policy."
+     )
+     with self.assertRaisesWithPredicateMatch(
+         ValueError, lambda err: error_message in str(err)
+     ):
+       min_max_quantize_utils.check_if_valid_op_config(
+           op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
+       )
+
+   def test_check_drq_config_with_non_default_min_weight_elements_succeeds(self):
+     op_quant_config = _OpQuantConfig(
+         weight_tensor_config=_TensorQuantConfig(
+             num_bits=8,
+             granularity=qtyping.QuantGranularity.CHANNELWISE,
+         ),
+         compute_precision=_ComputePrecision.INTEGER,  # DRQ.
+         min_weight_elements=100,
+     )
+     min_max_quantize_utils.check_if_valid_op_config(
+         _TFLOpName.CONV_2D, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
+     )
+
+   @parameterized.product(
+       op_name=(_TFLOpName.FULLY_CONNECTED, _TFLOpName.CONV_2D),
+       act_num_bits=(8, 16),
+       weight_num_bits=(4, 8),
+       granularity=(
+           qtyping.QuantGranularity.TENSORWISE,
+           qtyping.QuantGranularity.CHANNELWISE,
+       ),
+       symmetric_act=(True, False),
+   )
+   def test_check_srq_config_succeeds(
+       self,
+       op_name,
+       act_num_bits,
+       weight_num_bits,
+       granularity,
+       symmetric_act,
+   ):
+     # Asym int16 activation is not supported.
+     if not symmetric_act and act_num_bits == 16:
+       return
+     op_quant_config = _OpQuantConfig(
+         activation_tensor_config=_TensorQuantConfig(
+             num_bits=act_num_bits, symmetric=symmetric_act
+         ),
+         weight_tensor_config=_TensorQuantConfig(
+             num_bits=weight_num_bits,
+             granularity=granularity,
+         ),
+         compute_precision=_ComputePrecision.INTEGER,  # SRQ.
+     )
+     min_max_quantize_utils.check_if_valid_op_config(
+         op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
+     )
+
+   def test_check_srq_config_unsupported_op_raise_error(self):
+     op_quant_config = _OpQuantConfig(
+         activation_tensor_config=_TensorQuantConfig(num_bits=8, symmetric=True),
+         weight_tensor_config=_TensorQuantConfig(
+             num_bits=8,
+             granularity=qtyping.QuantGranularity.CHANNELWISE,
+         ),
+         compute_precision=_ComputePrecision.INTEGER,  # SRQ.
+     )
+     error_message = (
+         f"Unsupported op for {op_quant_config.compute_precision}:"
+         f" {_TFLOpName.CUSTOM_OP}"
+     )
+     with self.assertRaisesWithPredicateMatch(
+         ValueError, lambda err: error_message in str(err)
+     ):
+       min_max_quantize_utils.check_if_valid_op_config(
+           _TFLOpName.CUSTOM_OP, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
+       )
+
+   def test_check_srq_config_wrong_act_bits_config_raise_error(self):
+     op_quant_config = _OpQuantConfig(
+         activation_tensor_config=_TensorQuantConfig(
+             num_bits=14, symmetric=True
+         ),
+         weight_tensor_config=_TensorQuantConfig(
+             num_bits=8,
+             granularity=qtyping.QuantGranularity.CHANNELWISE,
+         ),
+         compute_precision=_ComputePrecision.INTEGER,  # SRQ.
+     )
+     error_message = (
+         f"Quantization config for op: {_TFLOpName.FULLY_CONNECTED} with config:"
+         f" {op_quant_config} was not found in the policy."
+     )
+     with self.assertRaisesWithPredicateMatch(
+         ValueError, lambda err: error_message in str(err)
+     ):
+       min_max_quantize_utils.check_if_valid_op_config(
+           _TFLOpName.FULLY_CONNECTED,
+           op_quant_config,
+           _DEFAULT_CONFIG_CHECK_POLICY,
+       )
+
+   def test_check_srq_config_asym_int16_act_raise_error(self):
+     op_quant_config = _OpQuantConfig(
+         activation_tensor_config=_TensorQuantConfig(
+             num_bits=16, symmetric=False
+         ),
+         weight_tensor_config=_TensorQuantConfig(
+             num_bits=8,
+             granularity=qtyping.QuantGranularity.CHANNELWISE,
+         ),
+         compute_precision=_ComputePrecision.INTEGER,  # SRQ.
+     )
+     error_message = (
+         f"Quantization config for op: {_TFLOpName.FULLY_CONNECTED} with config:"
+         f" {op_quant_config} was not found in the policy."
+     )
+     with self.assertRaisesWithPredicateMatch(
+         ValueError, lambda err: error_message in str(err)
+     ):
+       min_max_quantize_utils.check_if_valid_op_config(
+           _TFLOpName.FULLY_CONNECTED,
+           op_quant_config,
+           _DEFAULT_CONFIG_CHECK_POLICY,
+       )
+
+   def test_check_srq_config_wrong_weight_bits_raise_error(self):
+     op_quant_config = _OpQuantConfig(
+         activation_tensor_config=_TensorQuantConfig(
+             num_bits=16, symmetric=True
+         ),
+         weight_tensor_config=_TensorQuantConfig(
+             num_bits=2,
+             granularity=qtyping.QuantGranularity.CHANNELWISE,
+         ),
+         compute_precision=_ComputePrecision.INTEGER,  # SRQ.
+     )
+     error_message = (
+         f"Quantization config for op: {_TFLOpName.FULLY_CONNECTED} with config:"
+         f" {op_quant_config} was not found in the policy."
+     )
+     with self.assertRaisesWithPredicateMatch(
+         ValueError, lambda err: error_message in str(err)
+     ):
+       min_max_quantize_utils.check_if_valid_op_config(
+           _TFLOpName.FULLY_CONNECTED,
+           op_quant_config,
+           _DEFAULT_CONFIG_CHECK_POLICY,
+       )
+
+   def test_check_srq_config_asym_weight_raise_error(self):
+     op_quant_config = _OpQuantConfig(
+         activation_tensor_config=_TensorQuantConfig(num_bits=8, symmetric=True),
+         weight_tensor_config=_TensorQuantConfig(
+             num_bits=8,
+             symmetric=False,
+             granularity=qtyping.QuantGranularity.CHANNELWISE,
+         ),
+         compute_precision=_ComputePrecision.INTEGER,  # SRQ.
+     )
+     error_message = (
+         f"Quantization config for op: {_TFLOpName.FULLY_CONNECTED} with config:"
+         f" {op_quant_config} was not found in the policy."
+     )
+     with self.assertRaisesWithPredicateMatch(
+         ValueError, lambda err: error_message in str(err)
+     ):
+       min_max_quantize_utils.check_if_valid_op_config(
+           _TFLOpName.FULLY_CONNECTED,
+           op_quant_config,
+           _DEFAULT_CONFIG_CHECK_POLICY,
+       )
+
+   @parameterized.product(
+       op_name=[
+           _TFLOpName.FULLY_CONNECTED,
+           _TFLOpName.CONV_2D,
+       ],
+       activation_tensor_config=[
+           None,
+           _TensorQuantConfig(num_bits=8, symmetric=False),
+           _TensorQuantConfig(num_bits=16, symmetric=True),
+       ],
+       compute_precision=[
+           _ComputePrecision.FLOAT,
+           _ComputePrecision.INTEGER,
+       ],
+   )
+   def test_check_supported_int4_config_succeeds(
+       self, op_name, activation_tensor_config, compute_precision
+   ):
+     # Exclude invalid SRQ config.
+     if (
+         activation_tensor_config is not None
+         and compute_precision != _ComputePrecision.INTEGER
+     ) or (
+         activation_tensor_config is None
+         and compute_precision == _ComputePrecision.FLOAT
+     ):
+       return
+     # TODO: b/353365054 - Remove this check after int4 DRQ is supported for
+     # conv2d.
+     if (
+         # Check if DRQ and CONV_2D.
+         compute_precision == _ComputePrecision.INTEGER
+         and activation_tensor_config is None
+         and op_name == _TFLOpName.CONV_2D
+     ):
+       return
+     op_quant_config = _OpQuantConfig(
+         activation_tensor_config=activation_tensor_config,
+         weight_tensor_config=_TensorQuantConfig(
+             num_bits=4,
+             symmetric=True,
+             granularity=qtyping.QuantGranularity.CHANNELWISE,
+         ),
+         compute_precision=compute_precision,
+     )
+     # Raise error if the config is not supported.
+     # Check if DRQ.
+     if (
+         compute_precision == _ComputePrecision.INTEGER
+         and op_quant_config.activation_tensor_config is None
+     ):
+       min_max_quantize_utils.check_if_valid_op_config(
+           op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
+       )
+     # Check if WEIGHT_ONLY.
+     elif (
+         compute_precision == _ComputePrecision.FLOAT
+         and op_quant_config.explicit_dequantize
+     ):
+       self.assertIn(op_name, _DEFAULT_CONFIG_CHECK_POLICY.keys())
+     # Check if SRQ.
+     if (
+         compute_precision == _ComputePrecision.INTEGER
+         and op_quant_config.activation_tensor_config is not None
+     ):
+       min_max_quantize_utils.check_if_valid_op_config(
+           op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
+       )
+
+   @parameterized.product(
+       op_name=[_TFLOpName.BATCH_MATMUL],
+       activation_tensor_config=[
+           None,
+           _TensorQuantConfig(num_bits=8, symmetric=False),
+           _TensorQuantConfig(num_bits=16, symmetric=True),
+       ],
+       test_case=[
+           # Tuple holds compute precision and whether to use drq.
+           (_ComputePrecision.INTEGER, True),
+           (_ComputePrecision.INTEGER, False),
+       ],
+   )
+   def test_check_unsupported_int4_config_raise_error(
+       self, op_name, activation_tensor_config, test_case
+   ):
+     compute_precision, is_drq = test_case
+     # Exclude invalid SRQ config.
+     if (activation_tensor_config is not None and is_drq) or (
+         activation_tensor_config is None and not is_drq
+     ):
+       return
+     op_quant_config = _OpQuantConfig(
+         activation_tensor_config=activation_tensor_config,
+         weight_tensor_config=_TensorQuantConfig(
+             num_bits=4,
+             symmetric=True,
+             granularity=qtyping.QuantGranularity.CHANNELWISE,
+         ),
+         compute_precision=compute_precision,
+     )
+
+     with self.assertRaises(ValueError):
+       if is_drq:
+         min_max_quantize_utils.check_if_valid_op_config(
+             op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
+         )
+       elif not is_drq:
+         min_max_quantize_utils.check_if_valid_op_config(
+             op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
+         )
+
+   def test_materialize_op_with_output_activation_constraint_fails_for_multiple_output_op(
+       self,
+   ):
+     # Create a mock op with multiple outputs.
+     MockOp = collections.namedtuple("MockOp", ["outputs"])
+     mock_op_info = qtyping.OpInfo(
+         op=MockOp(outputs=[1, 2]),
+         op_name=_TFLOpName.SOFTMAX,
+         subgraph_op_index=0,
+         op_quant_config=_OpQuantConfig(),
+     )
+
+     with self.assertRaisesRegex(
+         ValueError, "only supports ops with a single output tensor"
+     ):
+       min_max_quantize_utils.materialize_op_with_output_activation_constraint(
+           op_info=mock_op_info,
+           graph_info=qtyping.GraphInfo([], []),
+           tensor_name_to_qsv={},
+           output_activation_constraints={},
+       )
+
+
+ if __name__ == "__main__":
+   googletest.main()