ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +224 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
- ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
- ai_edge_quantizer/calibrator.py +58 -94
- ai_edge_quantizer/calibrator_test.py +5 -74
- ai_edge_quantizer/default_policy.py +108 -16
- ai_edge_quantizer/model_modifier.py +132 -8
- ai_edge_quantizer/model_modifier_test.py +81 -1
- ai_edge_quantizer/model_validator.py +38 -10
- ai_edge_quantizer/model_validator_test.py +2 -1
- ai_edge_quantizer/params_generator.py +230 -47
- ai_edge_quantizer/params_generator_test.py +366 -261
- ai_edge_quantizer/qtyping.py +92 -6
- ai_edge_quantizer/quantizer.py +167 -23
- ai_edge_quantizer/quantizer_test.py +288 -26
- ai_edge_quantizer/recipe.py +156 -21
- ai_edge_quantizer/recipe_manager.py +158 -1
- ai_edge_quantizer/recipe_manager_test.py +146 -32
- ai_edge_quantizer/recipe_test.py +93 -17
- ai_edge_quantizer/transformation_instruction_generator.py +313 -46
- ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
- ai_edge_quantizer/transformation_performer.py +112 -58
- ai_edge_quantizer/transformation_performer_test.py +176 -4
- ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
- ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
- ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
- ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
- ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
- ai_edge_quantizer/transformations/transformation_utils.py +157 -11
- ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
- ai_edge_quantizer/utils/calibration_utils.py +263 -1
- ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
- ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
- ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
- ai_edge_quantizer/utils/test_utils.py +191 -58
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
- ai_edge_quantizer/utils/validation_utils.py +114 -4
- ai_edge_quantizer/utils/validation_utils_test.py +80 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
- ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
- ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
- ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
|
@@ -82,7 +82,6 @@ class RecipeManager:
|
|
|
82
82
|
str, list[OpQuantizationRecipe]
|
|
83
83
|
] = collections.OrderedDict()
|
|
84
84
|
|
|
85
|
-
# TODO: b/335254997 - Check if an op quantization config is supported.
|
|
86
85
|
def add_quantization_config(
|
|
87
86
|
self,
|
|
88
87
|
regex: str,
|
|
@@ -109,6 +108,11 @@ class RecipeManager:
|
|
|
109
108
|
configuration will be used.
|
|
110
109
|
algorithm_key: Algorithm key to be applied.
|
|
111
110
|
"""
|
|
111
|
+
try:
|
|
112
|
+
algorithm_manager.AlgorithmName(algorithm_key)
|
|
113
|
+
except ValueError as e:
|
|
114
|
+
raise ValueError(f'Unsupported algorithm key: {algorithm_key}.') from e
|
|
115
|
+
|
|
112
116
|
if op_config is None:
|
|
113
117
|
op_config = _OpQuantizationConfig()
|
|
114
118
|
|
|
@@ -243,3 +247,156 @@ class RecipeManager:
|
|
|
243
247
|
):
|
|
244
248
|
return True
|
|
245
249
|
return False
|
|
250
|
+
|
|
251
|
+
def add_dynamic_config(
|
|
252
|
+
self,
|
|
253
|
+
regex: str,
|
|
254
|
+
operation_name: _TFLOpName,
|
|
255
|
+
num_bits: int,
|
|
256
|
+
granularity: qtyping.QuantGranularity = qtyping.QuantGranularity.CHANNELWISE,
|
|
257
|
+
algorithm_key: str = algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT,
|
|
258
|
+
):
|
|
259
|
+
"""Adds a dynamic quantization configuration to the recipe.
|
|
260
|
+
|
|
261
|
+
During dynamic quantization, activations are not processed by AEQ and
|
|
262
|
+
remain in float format. The runtime kernel is expected to quantize these
|
|
263
|
+
activations on-the-fly, as indicated by compute_precision=Integer and
|
|
264
|
+
explicit_dequantize=False.
|
|
265
|
+
|
|
266
|
+
The model quality may suffer due to the on-the-fly quantization. If quality
|
|
267
|
+
is a concern, consider using weight-only
|
|
268
|
+
quantization.
|
|
269
|
+
|
|
270
|
+
Args:
|
|
271
|
+
regex: Regular expression for layer name matching.
|
|
272
|
+
operation_name: Target TFLite operation.
|
|
273
|
+
num_bits: Number of bits for quantization.
|
|
274
|
+
granularity: Granularity of quantization.
|
|
275
|
+
algorithm_key: Algorithm key to be applied.
|
|
276
|
+
"""
|
|
277
|
+
weight_config = qtyping.TensorQuantizationConfig(
|
|
278
|
+
num_bits=num_bits,
|
|
279
|
+
symmetric=True, # LiteRT kernels only support symmetric quantized
|
|
280
|
+
# weights.
|
|
281
|
+
granularity=granularity,
|
|
282
|
+
)
|
|
283
|
+
self.add_quantization_config(
|
|
284
|
+
regex,
|
|
285
|
+
operation_name,
|
|
286
|
+
op_config=_OpQuantizationConfig(
|
|
287
|
+
weight_tensor_config=weight_config,
|
|
288
|
+
compute_precision=qtyping.ComputePrecision.INTEGER,
|
|
289
|
+
explicit_dequantize=False,
|
|
290
|
+
),
|
|
291
|
+
algorithm_key=algorithm_key,
|
|
292
|
+
)
|
|
293
|
+
|
|
294
|
+
def add_weight_only_config(
|
|
295
|
+
self,
|
|
296
|
+
regex: str,
|
|
297
|
+
operation_name: _TFLOpName,
|
|
298
|
+
num_bits: int,
|
|
299
|
+
granularity: qtyping.QuantGranularity = qtyping.QuantGranularity.CHANNELWISE,
|
|
300
|
+
algorithm_key: str = algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT,
|
|
301
|
+
):
|
|
302
|
+
"""Adds a weight only quantization configuration to the recipe.
|
|
303
|
+
|
|
304
|
+
In weight-only quantization, weights are quantized, but the actual operation
|
|
305
|
+
(op) computation remains in float. The quantized weight is explicitly
|
|
306
|
+
dequantized before being fed into the op. This is achieved by inserting a
|
|
307
|
+
dequantize op between the quantized weight and the consuming op. To enable
|
|
308
|
+
this, both compute_precision will be set to Float and explicit_dequantize to
|
|
309
|
+
True.
|
|
310
|
+
|
|
311
|
+
Weight-only quantization is useful for reducing model size but may
|
|
312
|
+
not decrease latency due to float computation. However, quantized model
|
|
313
|
+
generally has better quality than other quantization options (e.g., dynamic
|
|
314
|
+
range quantization) due to no loss of precision on activations. If latency
|
|
315
|
+
is a concern, consider using dynamic quantization.
|
|
316
|
+
|
|
317
|
+
Args:
|
|
318
|
+
regex: Regular expression for layer name matching.
|
|
319
|
+
operation_name: Target TFLite operation.
|
|
320
|
+
num_bits: Number of bits for quantization.
|
|
321
|
+
granularity: Granularity of quantization.
|
|
322
|
+
algorithm_key: Algorithm key to be applied.
|
|
323
|
+
"""
|
|
324
|
+
# Default to integer quantization but allow float quantization for
|
|
325
|
+
# FLOAT_CASTING algorithm. This is to support weight-only quantization with
|
|
326
|
+
# fp16 weights.
|
|
327
|
+
weight_dtype = qtyping.TensorDataType.INT
|
|
328
|
+
if algorithm_key == AlgorithmName.FLOAT_CASTING:
|
|
329
|
+
weight_dtype = qtyping.TensorDataType.FLOAT
|
|
330
|
+
|
|
331
|
+
weight_config = qtyping.TensorQuantizationConfig(
|
|
332
|
+
num_bits=num_bits,
|
|
333
|
+
symmetric=True, # TFL kernels only support symmetric quantized weights.
|
|
334
|
+
granularity=granularity,
|
|
335
|
+
dtype=weight_dtype,
|
|
336
|
+
)
|
|
337
|
+
self.add_quantization_config(
|
|
338
|
+
regex,
|
|
339
|
+
operation_name,
|
|
340
|
+
op_config=_OpQuantizationConfig(
|
|
341
|
+
weight_tensor_config=weight_config,
|
|
342
|
+
compute_precision=qtyping.ComputePrecision.FLOAT,
|
|
343
|
+
explicit_dequantize=True,
|
|
344
|
+
),
|
|
345
|
+
algorithm_key=algorithm_key,
|
|
346
|
+
)
|
|
347
|
+
|
|
348
|
+
def add_static_config(
|
|
349
|
+
self,
|
|
350
|
+
regex: str,
|
|
351
|
+
operation_name: _TFLOpName,
|
|
352
|
+
activation_num_bits: int,
|
|
353
|
+
weight_num_bits: int,
|
|
354
|
+
weight_granularity: qtyping.QuantGranularity = qtyping.QuantGranularity.CHANNELWISE,
|
|
355
|
+
algorithm_key: str = algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT,
|
|
356
|
+
):
|
|
357
|
+
"""Adds a static range quantization configuration to the recipe.
|
|
358
|
+
|
|
359
|
+
In static quantization, both weights and activations are quantized. This
|
|
360
|
+
requires a calibration step to determine the quantization parameters (e.g.,
|
|
361
|
+
min/max ranges) for activations. The quantized model uses integer arithmetic
|
|
362
|
+
for computations, which can lead to significant latency reductions.
|
|
363
|
+
|
|
364
|
+
However, calibration is needed to determine the quantization parameters for
|
|
365
|
+
activations, which requires sample data and may lead to quality loss. If
|
|
366
|
+
there is no hardware requirement for full integer quantization, consider
|
|
367
|
+
using dynamic quantization for simplicity.
|
|
368
|
+
|
|
369
|
+
Args:
|
|
370
|
+
regex: Regular expression for layer name matching.
|
|
371
|
+
operation_name: Target TFLite operation.
|
|
372
|
+
activation_num_bits: Number of bits for activation quantization.
|
|
373
|
+
weight_num_bits: Number of bits for weight quantization.
|
|
374
|
+
weight_granularity: Granularity of weight quantization.
|
|
375
|
+
algorithm_key: Algorithm key to be applied.
|
|
376
|
+
"""
|
|
377
|
+
if activation_num_bits not in [16, 8]:
|
|
378
|
+
raise ValueError(
|
|
379
|
+
'Activation quantization is only supported for 16 or 8 bits.'
|
|
380
|
+
)
|
|
381
|
+
# INT16 is symmetric and INT8 is asymmetric due to LiteRT kernel
|
|
382
|
+
# limitations.
|
|
383
|
+
activation_symmetric = activation_num_bits == 16
|
|
384
|
+
activation_config = qtyping.TensorQuantizationConfig(
|
|
385
|
+
num_bits=activation_num_bits, symmetric=activation_symmetric
|
|
386
|
+
)
|
|
387
|
+
weight_config = qtyping.TensorQuantizationConfig(
|
|
388
|
+
num_bits=weight_num_bits,
|
|
389
|
+
symmetric=True, # TFL kernels only support symmetric quantized weights.
|
|
390
|
+
granularity=weight_granularity,
|
|
391
|
+
)
|
|
392
|
+
self.add_quantization_config(
|
|
393
|
+
regex,
|
|
394
|
+
operation_name,
|
|
395
|
+
op_config=_OpQuantizationConfig(
|
|
396
|
+
activation_tensor_config=activation_config,
|
|
397
|
+
weight_tensor_config=weight_config,
|
|
398
|
+
compute_precision=qtyping.ComputePrecision.INTEGER,
|
|
399
|
+
explicit_dequantize=False,
|
|
400
|
+
),
|
|
401
|
+
algorithm_key=algorithm_key,
|
|
402
|
+
)
|
|
@@ -29,19 +29,6 @@ _AlgorithmName = recipe_manager.AlgorithmName
|
|
|
29
29
|
_QuantGranularity = qtyping.QuantGranularity
|
|
30
30
|
|
|
31
31
|
|
|
32
|
-
# Sample functions for test cases.
|
|
33
|
-
def _sample_init_qsvs(*_, **__):
|
|
34
|
-
return 1.0, dict()
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
def _sample_calibration_func(*_, **__):
|
|
38
|
-
return 2.0, dict()
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
def _sample_materialize_func(*_, **__):
|
|
42
|
-
return 3.0, dict()
|
|
43
|
-
|
|
44
|
-
|
|
45
32
|
def _sample_check_op_config_func(op_name, op_config, _):
|
|
46
33
|
if (
|
|
47
34
|
op_config.weight_tensor_config is not None
|
|
@@ -67,6 +54,16 @@ def _add_default_int8xint8_integer_recipe(recipe_manager_object):
|
|
|
67
54
|
|
|
68
55
|
# register some currently unsupported ops for testing purposes
|
|
69
56
|
def _register_testing_op(algorithm_key, tfl_op):
|
|
57
|
+
# Sample functions for test cases.
|
|
58
|
+
def _sample_init_qsvs(*_, **__):
|
|
59
|
+
return {'name': dict()}
|
|
60
|
+
|
|
61
|
+
def _sample_calibration_func(*_, **__):
|
|
62
|
+
return {'name2': dict()}
|
|
63
|
+
|
|
64
|
+
def _sample_materialize_func(*_, **__):
|
|
65
|
+
return []
|
|
66
|
+
|
|
70
67
|
algorithm_manager.register_op_quant_config_validation_func(
|
|
71
68
|
algorithm_key, _sample_check_op_config_func
|
|
72
69
|
)
|
|
@@ -244,19 +241,6 @@ class ConfiguratorTest(parameterized.TestCase, googletest.TestCase):
|
|
|
244
241
|
compute_precision=_ComputePrecision.INTEGER, # DRQ.
|
|
245
242
|
),
|
|
246
243
|
)
|
|
247
|
-
# Add unregistered algorithm
|
|
248
|
-
with self.assertRaisesWithPredicateMatch(
|
|
249
|
-
ValueError, lambda err: error_message in str(err)
|
|
250
|
-
):
|
|
251
|
-
self._recipe_manager.add_quantization_config(
|
|
252
|
-
regex='.*/Dense/.*',
|
|
253
|
-
operation_name=_TFLOpName.FULLY_CONNECTED,
|
|
254
|
-
algorithm_key='AWQ',
|
|
255
|
-
op_config=qtyping.OpQuantizationConfig(
|
|
256
|
-
weight_tensor_config=_TensorQuantConfig(num_bits=8),
|
|
257
|
-
compute_precision=_ComputePrecision.INTEGER, # DRQ.
|
|
258
|
-
),
|
|
259
|
-
)
|
|
260
244
|
|
|
261
245
|
def test_add_unsupported_num_bits_raise_error(self):
|
|
262
246
|
test_op_name = _TFLOpName.FULLY_CONNECTED
|
|
@@ -296,6 +280,142 @@ class ConfiguratorTest(parameterized.TestCase, googletest.TestCase):
|
|
|
296
280
|
# DRQ check.
|
|
297
281
|
self.assertEqual(op_config.compute_precision, _ComputePrecision.INTEGER)
|
|
298
282
|
|
|
283
|
+
def test_add_unsupported_algorithm_key_raise_error(self):
|
|
284
|
+
error_message = 'Unsupported algorithm key'
|
|
285
|
+
with self.assertRaisesWithPredicateMatch(
|
|
286
|
+
ValueError, lambda err: error_message in str(err)
|
|
287
|
+
):
|
|
288
|
+
self._recipe_manager.add_quantization_config(
|
|
289
|
+
regex='.*/Dense/.*',
|
|
290
|
+
operation_name=_TFLOpName.FULLY_CONNECTED,
|
|
291
|
+
algorithm_key='decomposed_hadamard',
|
|
292
|
+
op_config=qtyping.OpQuantizationConfig(
|
|
293
|
+
weight_tensor_config=_TensorQuantConfig(num_bits=8),
|
|
294
|
+
),
|
|
295
|
+
)
|
|
296
|
+
with self.assertRaisesWithPredicateMatch(
|
|
297
|
+
ValueError, lambda err: error_message in str(err)
|
|
298
|
+
):
|
|
299
|
+
self._recipe_manager.add_quantization_config(
|
|
300
|
+
regex='.*/Dense/.*',
|
|
301
|
+
operation_name=_TFLOpName.ALL_SUPPORTED,
|
|
302
|
+
algorithm_key='decomposed_hadamard',
|
|
303
|
+
op_config=qtyping.OpQuantizationConfig(
|
|
304
|
+
weight_tensor_config=_TensorQuantConfig(num_bits=8),
|
|
305
|
+
),
|
|
306
|
+
)
|
|
307
|
+
|
|
308
|
+
def test_add_dynamic_config(self):
|
|
309
|
+
self._recipe_manager.add_dynamic_config(
|
|
310
|
+
regex='.*/Dense/.*',
|
|
311
|
+
operation_name=_TFLOpName.FULLY_CONNECTED,
|
|
312
|
+
num_bits=8,
|
|
313
|
+
)
|
|
314
|
+
alg_key, op_config = self._recipe_manager.get_quantization_configs(
|
|
315
|
+
_TFLOpName.FULLY_CONNECTED, 'model/Dense/op'
|
|
316
|
+
)
|
|
317
|
+
self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
|
|
318
|
+
self.assertEqual(op_config.compute_precision, _ComputePrecision.INTEGER)
|
|
319
|
+
self.assertFalse(op_config.explicit_dequantize)
|
|
320
|
+
self.assertIsNone(op_config.activation_tensor_config)
|
|
321
|
+
weight_tensor_config = op_config.weight_tensor_config
|
|
322
|
+
self.assertIsNotNone(weight_tensor_config)
|
|
323
|
+
self.assertEqual(weight_tensor_config.num_bits, 8)
|
|
324
|
+
self.assertTrue(weight_tensor_config.symmetric)
|
|
325
|
+
self.assertEqual(
|
|
326
|
+
weight_tensor_config.granularity,
|
|
327
|
+
_QuantGranularity.CHANNELWISE,
|
|
328
|
+
)
|
|
329
|
+
|
|
330
|
+
@parameterized.parameters(4, 8)
|
|
331
|
+
def test_add_weight_only_config_int(self, num_bits):
|
|
332
|
+
self._recipe_manager.add_weight_only_config(
|
|
333
|
+
regex='.*/Dense/.*',
|
|
334
|
+
operation_name=_TFLOpName.FULLY_CONNECTED,
|
|
335
|
+
num_bits=num_bits,
|
|
336
|
+
)
|
|
337
|
+
alg_key, op_config = self._recipe_manager.get_quantization_configs(
|
|
338
|
+
_TFLOpName.FULLY_CONNECTED, 'model/Dense/op'
|
|
339
|
+
)
|
|
340
|
+
self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
|
|
341
|
+
self.assertEqual(op_config.compute_precision, _ComputePrecision.FLOAT)
|
|
342
|
+
self.assertTrue(op_config.explicit_dequantize)
|
|
343
|
+
self.assertIsNone(op_config.activation_tensor_config)
|
|
344
|
+
weight_tensor_config = op_config.weight_tensor_config
|
|
345
|
+
self.assertIsNotNone(weight_tensor_config)
|
|
346
|
+
self.assertEqual(weight_tensor_config.num_bits, num_bits)
|
|
347
|
+
self.assertTrue(weight_tensor_config.symmetric)
|
|
348
|
+
self.assertEqual(
|
|
349
|
+
weight_tensor_config.granularity,
|
|
350
|
+
_QuantGranularity.CHANNELWISE,
|
|
351
|
+
)
|
|
352
|
+
self.assertEqual(weight_tensor_config.dtype, _TensorDataType.INT)
|
|
353
|
+
|
|
354
|
+
def test_add_weight_only_config_fp16(self):
|
|
355
|
+
self._recipe_manager.add_weight_only_config(
|
|
356
|
+
regex='.*/Dense2/.*',
|
|
357
|
+
operation_name=_TFLOpName.FULLY_CONNECTED,
|
|
358
|
+
num_bits=16,
|
|
359
|
+
algorithm_key=_AlgorithmName.FLOAT_CASTING,
|
|
360
|
+
)
|
|
361
|
+
alg_key, op_config = self._recipe_manager.get_quantization_configs(
|
|
362
|
+
_TFLOpName.FULLY_CONNECTED, 'model/Dense2/op'
|
|
363
|
+
)
|
|
364
|
+
self.assertEqual(alg_key, _AlgorithmName.FLOAT_CASTING)
|
|
365
|
+
self.assertEqual(op_config.compute_precision, _ComputePrecision.FLOAT)
|
|
366
|
+
self.assertTrue(op_config.explicit_dequantize)
|
|
367
|
+
self.assertIsNone(op_config.activation_tensor_config)
|
|
368
|
+
weight_tensor_config = op_config.weight_tensor_config
|
|
369
|
+
self.assertIsNotNone(weight_tensor_config)
|
|
370
|
+
self.assertEqual(weight_tensor_config.num_bits, 16)
|
|
371
|
+
self.assertTrue(weight_tensor_config.symmetric)
|
|
372
|
+
self.assertEqual(
|
|
373
|
+
weight_tensor_config.granularity,
|
|
374
|
+
_QuantGranularity.CHANNELWISE,
|
|
375
|
+
)
|
|
376
|
+
self.assertEqual(weight_tensor_config.dtype, _TensorDataType.FLOAT)
|
|
377
|
+
|
|
378
|
+
def test_add_weight_only_config_fp8_raise_error(self):
|
|
379
|
+
error_message = (
|
|
380
|
+
'float casting quantization config requires number of bits to be set'
|
|
381
|
+
' as 16'
|
|
382
|
+
)
|
|
383
|
+
with self.assertRaisesWithPredicateMatch(
|
|
384
|
+
ValueError, lambda err: error_message in str(err)
|
|
385
|
+
):
|
|
386
|
+
self._recipe_manager.add_weight_only_config(
|
|
387
|
+
regex='.*/Dense2/.*',
|
|
388
|
+
operation_name=_TFLOpName.FULLY_CONNECTED,
|
|
389
|
+
num_bits=8,
|
|
390
|
+
algorithm_key=_AlgorithmName.FLOAT_CASTING,
|
|
391
|
+
)
|
|
392
|
+
|
|
393
|
+
def test_add_static_config(self):
|
|
394
|
+
self._recipe_manager.add_static_config(
|
|
395
|
+
regex='.*/Dense/.*',
|
|
396
|
+
operation_name=_TFLOpName.FULLY_CONNECTED,
|
|
397
|
+
activation_num_bits=8,
|
|
398
|
+
weight_num_bits=4,
|
|
399
|
+
)
|
|
400
|
+
alg_key, op_config = self._recipe_manager.get_quantization_configs(
|
|
401
|
+
_TFLOpName.FULLY_CONNECTED, 'model/Dense/op'
|
|
402
|
+
)
|
|
403
|
+
self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
|
|
404
|
+
self.assertEqual(op_config.compute_precision, _ComputePrecision.INTEGER)
|
|
405
|
+
self.assertFalse(op_config.explicit_dequantize)
|
|
406
|
+
activation_tensor_config = op_config.activation_tensor_config
|
|
407
|
+
self.assertIsNotNone(activation_tensor_config)
|
|
408
|
+
self.assertEqual(activation_tensor_config.num_bits, 8)
|
|
409
|
+
self.assertFalse(activation_tensor_config.symmetric)
|
|
410
|
+
weight_tensor_config = op_config.weight_tensor_config
|
|
411
|
+
self.assertIsNotNone(weight_tensor_config)
|
|
412
|
+
self.assertEqual(weight_tensor_config.num_bits, 4)
|
|
413
|
+
self.assertTrue(weight_tensor_config.symmetric)
|
|
414
|
+
self.assertEqual(
|
|
415
|
+
weight_tensor_config.granularity,
|
|
416
|
+
_QuantGranularity.CHANNELWISE,
|
|
417
|
+
)
|
|
418
|
+
|
|
299
419
|
def test_set_full_integer_quantization_config(self):
|
|
300
420
|
_add_default_int8xint8_integer_recipe(self._recipe_manager)
|
|
301
421
|
# Full integer setting is global
|
|
@@ -461,14 +581,12 @@ class ConfiguratorTest(parameterized.TestCase, googletest.TestCase):
|
|
|
461
581
|
'symmetric': False,
|
|
462
582
|
'granularity': _QuantGranularity.TENSORWISE,
|
|
463
583
|
'dtype': 'INT',
|
|
464
|
-
'block_size': 0,
|
|
465
584
|
},
|
|
466
585
|
'weight_tensor_config': {
|
|
467
586
|
'num_bits': 8,
|
|
468
587
|
'symmetric': True,
|
|
469
588
|
'granularity': _QuantGranularity.TENSORWISE,
|
|
470
589
|
'dtype': 'INT',
|
|
471
|
-
'block_size': 0,
|
|
472
590
|
},
|
|
473
591
|
# WEIGHT_ONLY.
|
|
474
592
|
'compute_precision': _ComputePrecision.INTEGER,
|
|
@@ -487,7 +605,6 @@ class ConfiguratorTest(parameterized.TestCase, googletest.TestCase):
|
|
|
487
605
|
'num_bits': 8,
|
|
488
606
|
'symmetric': True,
|
|
489
607
|
'granularity': _QuantGranularity.TENSORWISE,
|
|
490
|
-
'block_size': 0,
|
|
491
608
|
},
|
|
492
609
|
# WEIGHT_ONLY.
|
|
493
610
|
'compute_precision': _ComputePrecision.FLOAT,
|
|
@@ -506,7 +623,6 @@ class ConfiguratorTest(parameterized.TestCase, googletest.TestCase):
|
|
|
506
623
|
'num_bits': 4,
|
|
507
624
|
'symmetric': True,
|
|
508
625
|
'granularity': _QuantGranularity.TENSORWISE,
|
|
509
|
-
'block_size': 0,
|
|
510
626
|
},
|
|
511
627
|
# WEIGHT_ONLY.
|
|
512
628
|
'compute_precision': _ComputePrecision.FLOAT,
|
|
@@ -525,7 +641,6 @@ class ConfiguratorTest(parameterized.TestCase, googletest.TestCase):
|
|
|
525
641
|
'num_bits': 6,
|
|
526
642
|
'symmetric': True,
|
|
527
643
|
'granularity': _QuantGranularity.TENSORWISE,
|
|
528
|
-
'block_size': 0,
|
|
529
644
|
},
|
|
530
645
|
# WEIGHT_ONLY.
|
|
531
646
|
'compute_precision': _ComputePrecision.FLOAT,
|
|
@@ -544,7 +659,6 @@ class ConfiguratorTest(parameterized.TestCase, googletest.TestCase):
|
|
|
544
659
|
'num_bits': 3,
|
|
545
660
|
'symmetric': True,
|
|
546
661
|
'granularity': _QuantGranularity.TENSORWISE,
|
|
547
|
-
'block_size': 0,
|
|
548
662
|
},
|
|
549
663
|
# WEIGHT_ONLY.
|
|
550
664
|
'compute_precision': _ComputePrecision.FLOAT,
|
ai_edge_quantizer/recipe_test.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
|
+
import unittest # pylint: disable=unused-import, required for OSS.
|
|
17
18
|
|
|
18
19
|
from absl.testing import parameterized
|
|
19
20
|
|
|
@@ -21,6 +22,7 @@ from tensorflow.python.platform import googletest
|
|
|
21
22
|
from ai_edge_quantizer import quantizer
|
|
22
23
|
from ai_edge_quantizer import recipe
|
|
23
24
|
from ai_edge_quantizer.utils import test_utils
|
|
25
|
+
from ai_edge_quantizer.utils import tfl_interpreter_utils
|
|
24
26
|
|
|
25
27
|
|
|
26
28
|
_TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('')
|
|
@@ -30,21 +32,67 @@ class RecipeTest(parameterized.TestCase):
|
|
|
30
32
|
|
|
31
33
|
def setUp(self):
|
|
32
34
|
super().setUp()
|
|
33
|
-
|
|
35
|
+
# Weights has < 1024 elements so legacy recipe will not quantize it.
|
|
36
|
+
self._small_model_path = os.path.join(
|
|
34
37
|
_TEST_DATA_PREFIX_PATH,
|
|
35
38
|
'tests/models/single_conv2d_transpose_bias.tflite',
|
|
36
39
|
)
|
|
40
|
+
self._test_model_path = os.path.join(
|
|
41
|
+
_TEST_DATA_PREFIX_PATH,
|
|
42
|
+
'tests/models/conv_fc_mnist.tflite',
|
|
43
|
+
)
|
|
37
44
|
|
|
38
|
-
def _quantize_with_recipe_func(self, recipe_func):
|
|
39
|
-
qt = quantizer.Quantizer(
|
|
45
|
+
def _quantize_with_recipe_func(self, recipe_func, test_model_path):
|
|
46
|
+
qt = quantizer.Quantizer(test_model_path)
|
|
40
47
|
qt.load_quantization_recipe(recipe_func())
|
|
41
48
|
self.assertIsNone(qt._result.quantized_model)
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
49
|
+
if qt.need_calibration:
|
|
50
|
+
calibration_data = tfl_interpreter_utils.create_random_normal_input_data(
|
|
51
|
+
qt.float_model,
|
|
52
|
+
num_samples=1,
|
|
53
|
+
)
|
|
54
|
+
calibration_result = qt.calibrate(calibration_data)
|
|
55
|
+
quantization_result = qt.quantize(calibration_result)
|
|
56
|
+
else:
|
|
57
|
+
quantization_result = qt.quantize()
|
|
58
|
+
self.assertIsNotNone(quantization_result.quantized_model)
|
|
59
|
+
return quantization_result
|
|
60
|
+
|
|
61
|
+
@unittest.skip('skipping due to b/438971945')
|
|
46
62
|
def test_quantization_from_dynamic_wi8_afp32_func_succeeds(self):
|
|
47
|
-
quant_result = self._quantize_with_recipe_func(
|
|
63
|
+
quant_result = self._quantize_with_recipe_func(
|
|
64
|
+
recipe.dynamic_wi8_afp32, self._test_model_path
|
|
65
|
+
)
|
|
66
|
+
self.assertLess(
|
|
67
|
+
len(quant_result.quantized_model),
|
|
68
|
+
os.path.getsize(self._test_model_path),
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
@unittest.skip('skipping due to b/438971945')
|
|
72
|
+
def test_quantization_from_dynamic_wi4_afp32_func_succeeds(self):
|
|
73
|
+
quant_result = self._quantize_with_recipe_func(
|
|
74
|
+
recipe.dynamic_wi4_afp32, self._test_model_path
|
|
75
|
+
)
|
|
76
|
+
self.assertLess(
|
|
77
|
+
len(quant_result.quantized_model),
|
|
78
|
+
os.path.getsize(self._test_model_path),
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
@unittest.skip('skipping due to b/438971945')
|
|
82
|
+
def test_quantization_from_weight_only_wi8_afp32_func_succeeds(self):
|
|
83
|
+
quant_result = self._quantize_with_recipe_func(
|
|
84
|
+
recipe.weight_only_wi8_afp32, self._test_model_path
|
|
85
|
+
)
|
|
86
|
+
self.assertLess(
|
|
87
|
+
len(quant_result.quantized_model),
|
|
88
|
+
os.path.getsize(self._test_model_path),
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
@unittest.skip('skipping due to b/438971945')
|
|
92
|
+
def test_quantization_from_weight_only_wi4_afp32_func_succeeds(self):
|
|
93
|
+
quant_result = self._quantize_with_recipe_func(
|
|
94
|
+
recipe.weight_only_wi4_afp32, self._test_model_path
|
|
95
|
+
)
|
|
48
96
|
self.assertLess(
|
|
49
97
|
len(quant_result.quantized_model),
|
|
50
98
|
os.path.getsize(self._test_model_path),
|
|
@@ -52,11 +100,12 @@ class RecipeTest(parameterized.TestCase):
|
|
|
52
100
|
|
|
53
101
|
def test_quantization_from_dynamic_legacy_wi8_afp32_func_succeeds(self):
|
|
54
102
|
quant_result = self._quantize_with_recipe_func(
|
|
55
|
-
recipe.dynamic_legacy_wi8_afp32
|
|
103
|
+
recipe.dynamic_legacy_wi8_afp32,
|
|
104
|
+
self._small_model_path,
|
|
56
105
|
)
|
|
57
106
|
self.assertLen(
|
|
58
107
|
quant_result.quantized_model,
|
|
59
|
-
os.path.getsize(self.
|
|
108
|
+
os.path.getsize(self._small_model_path),
|
|
60
109
|
)
|
|
61
110
|
|
|
62
111
|
@parameterized.named_parameters(
|
|
@@ -65,28 +114,55 @@ class RecipeTest(parameterized.TestCase):
|
|
|
65
114
|
recipe_json_path='recipes/dynamic_wi8_afp32_recipe.json',
|
|
66
115
|
recipe_func=recipe.dynamic_wi8_afp32,
|
|
67
116
|
),
|
|
117
|
+
dict(
|
|
118
|
+
testcase_name='weight_only_wi8_afp32',
|
|
119
|
+
recipe_json_path='recipes/default_af32w8float_recipe.json',
|
|
120
|
+
recipe_func=recipe.weight_only_wi8_afp32,
|
|
121
|
+
),
|
|
122
|
+
dict(
|
|
123
|
+
testcase_name='weight_only_wi4_afp32',
|
|
124
|
+
recipe_json_path='recipes/default_af32w4float_recipe.json',
|
|
125
|
+
recipe_func=recipe.weight_only_wi4_afp32,
|
|
126
|
+
),
|
|
68
127
|
dict(
|
|
69
128
|
testcase_name='dynamic_legacy_wi8_afp32',
|
|
70
129
|
recipe_json_path='recipes/dynamic_legacy_wi8_afp32_recipe.json',
|
|
71
130
|
recipe_func=recipe.dynamic_legacy_wi8_afp32,
|
|
72
131
|
),
|
|
132
|
+
dict(
|
|
133
|
+
testcase_name='a8w8',
|
|
134
|
+
recipe_json_path='recipes/default_a8w8_recipe.json',
|
|
135
|
+
recipe_func=recipe.static_wi8_ai8,
|
|
136
|
+
),
|
|
137
|
+
dict(
|
|
138
|
+
testcase_name='a16w8',
|
|
139
|
+
recipe_json_path='recipes/default_a16w8_recipe.json',
|
|
140
|
+
recipe_func=recipe.static_wi8_ai16,
|
|
141
|
+
),
|
|
73
142
|
)
|
|
143
|
+
@unittest.skip('skipping due to b/438971945')
|
|
74
144
|
def test_recipe_func_and_json_matches(self, recipe_json_path, recipe_func):
|
|
75
145
|
# Quantize with recipe from function in recipe module.
|
|
76
|
-
quant_result_from_func = self._quantize_with_recipe_func(
|
|
146
|
+
quant_result_from_func = self._quantize_with_recipe_func(
|
|
147
|
+
recipe_func, self._test_model_path
|
|
148
|
+
)
|
|
77
149
|
|
|
78
150
|
# Quantize with recipe from json file.
|
|
79
151
|
qt_json = quantizer.Quantizer(self._test_model_path)
|
|
80
152
|
json_recipe_path = os.path.join(_TEST_DATA_PREFIX_PATH, recipe_json_path)
|
|
81
153
|
qt_json.load_quantization_recipe(json_recipe_path)
|
|
82
|
-
|
|
154
|
+
if qt_json.need_calibration:
|
|
155
|
+
calibration_data = tfl_interpreter_utils.create_random_normal_input_data(
|
|
156
|
+
qt_json.float_model,
|
|
157
|
+
num_samples=1,
|
|
158
|
+
)
|
|
159
|
+
calibration_result = qt_json.calibrate(calibration_data)
|
|
160
|
+
quant_result_from_json = qt_json.quantize(calibration_result)
|
|
161
|
+
else:
|
|
162
|
+
quant_result_from_json = qt_json.quantize()
|
|
83
163
|
self.assertIsNotNone(quant_result_from_json.quantized_model)
|
|
84
164
|
|
|
85
|
-
# Check if the
|
|
86
|
-
self.assertEqual(
|
|
87
|
-
quant_result_from_func.recipe,
|
|
88
|
-
quant_result_from_json.recipe,
|
|
89
|
-
)
|
|
165
|
+
# Check if the quantized models match.
|
|
90
166
|
self.assertEqual(
|
|
91
167
|
len(quant_result_from_func.quantized_model),
|
|
92
168
|
len(quant_result_from_json.quantized_model),
|