ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. ai_edge_quantizer/algorithm_manager.py +224 -0
  2. ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
  3. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
  5. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  6. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
  7. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
  8. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
  13. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
  14. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
  15. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
  16. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
  17. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
  18. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
  19. ai_edge_quantizer/calibrator.py +58 -94
  20. ai_edge_quantizer/calibrator_test.py +5 -74
  21. ai_edge_quantizer/default_policy.py +108 -16
  22. ai_edge_quantizer/model_modifier.py +132 -8
  23. ai_edge_quantizer/model_modifier_test.py +81 -1
  24. ai_edge_quantizer/model_validator.py +38 -10
  25. ai_edge_quantizer/model_validator_test.py +2 -1
  26. ai_edge_quantizer/params_generator.py +230 -47
  27. ai_edge_quantizer/params_generator_test.py +366 -261
  28. ai_edge_quantizer/qtyping.py +92 -6
  29. ai_edge_quantizer/quantizer.py +167 -23
  30. ai_edge_quantizer/quantizer_test.py +288 -26
  31. ai_edge_quantizer/recipe.py +156 -21
  32. ai_edge_quantizer/recipe_manager.py +158 -1
  33. ai_edge_quantizer/recipe_manager_test.py +146 -32
  34. ai_edge_quantizer/recipe_test.py +93 -17
  35. ai_edge_quantizer/transformation_instruction_generator.py +313 -46
  36. ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
  37. ai_edge_quantizer/transformation_performer.py +112 -58
  38. ai_edge_quantizer/transformation_performer_test.py +176 -4
  39. ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
  40. ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
  41. ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
  42. ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
  43. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  44. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  45. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  46. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  47. ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
  48. ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
  49. ai_edge_quantizer/transformations/transformation_utils.py +157 -11
  50. ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
  51. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  52. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  53. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  54. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  55. ai_edge_quantizer/utils/test_utils.py +191 -58
  56. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
  57. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
  58. ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
  59. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  60. ai_edge_quantizer/utils/validation_utils.py +114 -4
  61. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  62. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
  63. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  64. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  65. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  66. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  67. ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
  68. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  69. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
@@ -82,7 +82,6 @@ class RecipeManager:
         str, list[OpQuantizationRecipe]
     ] = collections.OrderedDict()
 
-  # TODO: b/335254997 - Check if an op quantization config is supported.
   def add_quantization_config(
       self,
       regex: str,
@@ -109,6 +108,11 @@ class RecipeManager:
         configuration will be used.
       algorithm_key: Algorithm key to be applied.
     """
+    try:
+      algorithm_manager.AlgorithmName(algorithm_key)
+    except ValueError as e:
+      raise ValueError(f'Unsupported algorithm key: {algorithm_key}.') from e
+
     if op_config is None:
      op_config = _OpQuantizationConfig()
 
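With this check, an algorithm key that is not a registered AlgorithmName now fails at recipe-construction time instead of surfacing later during quantization. A minimal sketch of the new behavior, assuming a default-constructed RecipeManager and that qtyping.TFLOperationName is the enum the tests alias as _TFLOpName; 'AWQ' stands in for any unregistered key:

    from ai_edge_quantizer import qtyping
    from ai_edge_quantizer import recipe_manager

    rm = recipe_manager.RecipeManager()
    try:
      rm.add_quantization_config(
          regex='.*',
          operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
          algorithm_key='AWQ',  # not a registered AlgorithmName
      )
    except ValueError as e:
      print(e)  # Unsupported algorithm key: AWQ.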
@@ -243,3 +247,156 @@ class RecipeManager:
     ):
       return True
     return False
+
+  def add_dynamic_config(
+      self,
+      regex: str,
+      operation_name: _TFLOpName,
+      num_bits: int,
+      granularity: qtyping.QuantGranularity = qtyping.QuantGranularity.CHANNELWISE,
+      algorithm_key: str = algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+  ):
+    """Adds a dynamic quantization configuration to the recipe.
+
+    During dynamic quantization, activations are not processed by AEQ and
+    remain in float format. The runtime kernel is expected to quantize these
+    activations on-the-fly, as indicated by compute_precision=Integer and
+    explicit_dequantize=False.
+
+    The model quality may suffer due to the on-the-fly quantization. If quality
+    is a concern, consider using weight-only
+    quantization.
+
+    Args:
+      regex: Regular expression for layer name matching.
+      operation_name: Target TFLite operation.
+      num_bits: Number of bits for quantization.
+      granularity: Granularity of quantization.
+      algorithm_key: Algorithm key to be applied.
+    """
+    weight_config = qtyping.TensorQuantizationConfig(
+        num_bits=num_bits,
+        symmetric=True,  # LiteRT kernels only support symmetric quantized
+        # weights.
+        granularity=granularity,
+    )
+    self.add_quantization_config(
+        regex,
+        operation_name,
+        op_config=_OpQuantizationConfig(
+            weight_tensor_config=weight_config,
+            compute_precision=qtyping.ComputePrecision.INTEGER,
+            explicit_dequantize=False,
+        ),
+        algorithm_key=algorithm_key,
+    )
+
+  def add_weight_only_config(
+      self,
+      regex: str,
+      operation_name: _TFLOpName,
+      num_bits: int,
+      granularity: qtyping.QuantGranularity = qtyping.QuantGranularity.CHANNELWISE,
+      algorithm_key: str = algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+  ):
+    """Adds a weight-only quantization configuration to the recipe.
+
+    In weight-only quantization, weights are quantized, but the actual operation
+    (op) computation remains in float. The quantized weight is explicitly
+    dequantized before being fed into the op. This is achieved by inserting a
+    dequantize op between the quantized weight and the consuming op. To enable
+    this, compute_precision is set to Float and explicit_dequantize to
+    True.
+
+    Weight-only quantization is useful for reducing model size but may
+    not decrease latency due to float computation. However, the quantized model
+    generally has better quality than other quantization options (e.g., dynamic
+    range quantization) due to no loss of precision on activations. If latency
+    is a concern, consider using dynamic quantization.
+
+    Args:
+      regex: Regular expression for layer name matching.
+      operation_name: Target TFLite operation.
+      num_bits: Number of bits for quantization.
+      granularity: Granularity of quantization.
+      algorithm_key: Algorithm key to be applied.
+    """
+    # Default to integer quantization but allow float quantization for
+    # FLOAT_CASTING algorithm. This is to support weight-only quantization with
+    # fp16 weights.
+    weight_dtype = qtyping.TensorDataType.INT
+    if algorithm_key == AlgorithmName.FLOAT_CASTING:
+      weight_dtype = qtyping.TensorDataType.FLOAT
+
+    weight_config = qtyping.TensorQuantizationConfig(
+        num_bits=num_bits,
+        symmetric=True,  # TFL kernels only support symmetric quantized weights.
+        granularity=granularity,
+        dtype=weight_dtype,
+    )
+    self.add_quantization_config(
+        regex,
+        operation_name,
+        op_config=_OpQuantizationConfig(
+            weight_tensor_config=weight_config,
+            compute_precision=qtyping.ComputePrecision.FLOAT,
+            explicit_dequantize=True,
+        ),
+        algorithm_key=algorithm_key,
+    )
+
+  def add_static_config(
+      self,
+      regex: str,
+      operation_name: _TFLOpName,
+      activation_num_bits: int,
+      weight_num_bits: int,
+      weight_granularity: qtyping.QuantGranularity = qtyping.QuantGranularity.CHANNELWISE,
+      algorithm_key: str = algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+  ):
+    """Adds a static range quantization configuration to the recipe.
+
+    In static quantization, both weights and activations are quantized. This
+    requires a calibration step to determine the quantization parameters (e.g.,
+    min/max ranges) for activations. The quantized model uses integer arithmetic
+    for computations, which can lead to significant latency reductions.
+
+    However, calibration is needed to determine the quantization parameters for
+    activations, which requires sample data and may lead to quality loss. If
+    there is no hardware requirement for full integer quantization, consider
+    using dynamic quantization for simplicity.
+
+    Args:
+      regex: Regular expression for layer name matching.
+      operation_name: Target TFLite operation.
+      activation_num_bits: Number of bits for activation quantization.
+      weight_num_bits: Number of bits for weight quantization.
+      weight_granularity: Granularity of weight quantization.
+      algorithm_key: Algorithm key to be applied.
+    """
+    if activation_num_bits not in [16, 8]:
+      raise ValueError(
+          'Activation quantization is only supported for 16 or 8 bits.'
+      )
+    # INT16 is symmetric and INT8 is asymmetric due to LiteRT kernel
+    # limitations.
+    activation_symmetric = activation_num_bits == 16
+    activation_config = qtyping.TensorQuantizationConfig(
+        num_bits=activation_num_bits, symmetric=activation_symmetric
+    )
+    weight_config = qtyping.TensorQuantizationConfig(
+        num_bits=weight_num_bits,
+        symmetric=True,  # TFL kernels only support symmetric quantized weights.
+        granularity=weight_granularity,
+    )
+    self.add_quantization_config(
+        regex,
+        operation_name,
+        op_config=_OpQuantizationConfig(
+            activation_tensor_config=activation_config,
+            weight_tensor_config=weight_config,
+            compute_precision=qtyping.ComputePrecision.INTEGER,
+            explicit_dequantize=False,
+        ),
+        algorithm_key=algorithm_key,
+    )
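The three helpers above are thin wrappers around add_quantization_config that pin the compute_precision / explicit_dequantize combination each mode needs. A minimal usage sketch, assuming a default-constructed RecipeManager and that qtyping.TFLOperationName is the enum the tests alias as _TFLOpName (in practice a recipe would normally pick one mode per op pattern):

    from ai_edge_quantizer import qtyping
    from ai_edge_quantizer import recipe_manager

    rm = recipe_manager.RecipeManager()

    # Dynamic: int8 weights; the runtime kernel quantizes activations on the fly.
    rm.add_dynamic_config(
        regex='.*',
        operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
        num_bits=8,
    )

    # Weight-only: int4 weights plus an explicit dequantize, so compute stays float.
    rm.add_weight_only_config(
        regex='.*/Dense/.*',
        operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
        num_bits=4,
    )

    # Static: int8 activations (asymmetric) and int8 weights; calibration required.
    rm.add_static_config(
        regex='.*/Conv/.*',
        operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
        activation_num_bits=8,
        weight_num_bits=8,
    )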
@@ -29,19 +29,6 @@ _AlgorithmName = recipe_manager.AlgorithmName
 _QuantGranularity = qtyping.QuantGranularity
 
 
-# Sample functions for test cases.
-def _sample_init_qsvs(*_, **__):
-  return 1.0, dict()
-
-
-def _sample_calibration_func(*_, **__):
-  return 2.0, dict()
-
-
-def _sample_materialize_func(*_, **__):
-  return 3.0, dict()
-
-
 def _sample_check_op_config_func(op_name, op_config, _):
   if (
       op_config.weight_tensor_config is not None
@@ -67,6 +54,16 @@ def _add_default_int8xint8_integer_recipe(recipe_manager_object):
 
 # register some currently unsupported ops for testing purposes
 def _register_testing_op(algorithm_key, tfl_op):
+  # Sample functions for test cases.
+  def _sample_init_qsvs(*_, **__):
+    return {'name': dict()}
+
+  def _sample_calibration_func(*_, **__):
+    return {'name2': dict()}
+
+  def _sample_materialize_func(*_, **__):
+    return []
+
   algorithm_manager.register_op_quant_config_validation_func(
       algorithm_key, _sample_check_op_config_func
   )
@@ -244,19 +241,6 @@ class ConfiguratorTest(parameterized.TestCase, googletest.TestCase):
             compute_precision=_ComputePrecision.INTEGER,  # DRQ.
         ),
     )
-    # Add unregistered algorithm
-    with self.assertRaisesWithPredicateMatch(
-        ValueError, lambda err: error_message in str(err)
-    ):
-      self._recipe_manager.add_quantization_config(
-          regex='.*/Dense/.*',
-          operation_name=_TFLOpName.FULLY_CONNECTED,
-          algorithm_key='AWQ',
-          op_config=qtyping.OpQuantizationConfig(
-              weight_tensor_config=_TensorQuantConfig(num_bits=8),
-              compute_precision=_ComputePrecision.INTEGER,  # DRQ.
-          ),
-      )
 
   def test_add_unsupported_num_bits_raise_error(self):
     test_op_name = _TFLOpName.FULLY_CONNECTED
@@ -296,6 +280,142 @@ class ConfiguratorTest(parameterized.TestCase, googletest.TestCase):
     # DRQ check.
     self.assertEqual(op_config.compute_precision, _ComputePrecision.INTEGER)
 
+  def test_add_unsupported_algorithm_key_raise_error(self):
+    error_message = 'Unsupported algorithm key'
+    with self.assertRaisesWithPredicateMatch(
+        ValueError, lambda err: error_message in str(err)
+    ):
+      self._recipe_manager.add_quantization_config(
+          regex='.*/Dense/.*',
+          operation_name=_TFLOpName.FULLY_CONNECTED,
+          algorithm_key='decomposed_hadamard',
+          op_config=qtyping.OpQuantizationConfig(
+              weight_tensor_config=_TensorQuantConfig(num_bits=8),
+          ),
+      )
+    with self.assertRaisesWithPredicateMatch(
+        ValueError, lambda err: error_message in str(err)
+    ):
+      self._recipe_manager.add_quantization_config(
+          regex='.*/Dense/.*',
+          operation_name=_TFLOpName.ALL_SUPPORTED,
+          algorithm_key='decomposed_hadamard',
+          op_config=qtyping.OpQuantizationConfig(
+              weight_tensor_config=_TensorQuantConfig(num_bits=8),
+          ),
+      )
+
+  def test_add_dynamic_config(self):
+    self._recipe_manager.add_dynamic_config(
+        regex='.*/Dense/.*',
+        operation_name=_TFLOpName.FULLY_CONNECTED,
+        num_bits=8,
+    )
+    alg_key, op_config = self._recipe_manager.get_quantization_configs(
+        _TFLOpName.FULLY_CONNECTED, 'model/Dense/op'
+    )
+    self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
+    self.assertEqual(op_config.compute_precision, _ComputePrecision.INTEGER)
+    self.assertFalse(op_config.explicit_dequantize)
+    self.assertIsNone(op_config.activation_tensor_config)
+    weight_tensor_config = op_config.weight_tensor_config
+    self.assertIsNotNone(weight_tensor_config)
+    self.assertEqual(weight_tensor_config.num_bits, 8)
+    self.assertTrue(weight_tensor_config.symmetric)
+    self.assertEqual(
+        weight_tensor_config.granularity,
+        _QuantGranularity.CHANNELWISE,
+    )
+
+  @parameterized.parameters(4, 8)
+  def test_add_weight_only_config_int(self, num_bits):
+    self._recipe_manager.add_weight_only_config(
+        regex='.*/Dense/.*',
+        operation_name=_TFLOpName.FULLY_CONNECTED,
+        num_bits=num_bits,
+    )
+    alg_key, op_config = self._recipe_manager.get_quantization_configs(
+        _TFLOpName.FULLY_CONNECTED, 'model/Dense/op'
+    )
+    self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
+    self.assertEqual(op_config.compute_precision, _ComputePrecision.FLOAT)
+    self.assertTrue(op_config.explicit_dequantize)
+    self.assertIsNone(op_config.activation_tensor_config)
+    weight_tensor_config = op_config.weight_tensor_config
+    self.assertIsNotNone(weight_tensor_config)
+    self.assertEqual(weight_tensor_config.num_bits, num_bits)
+    self.assertTrue(weight_tensor_config.symmetric)
+    self.assertEqual(
+        weight_tensor_config.granularity,
+        _QuantGranularity.CHANNELWISE,
+    )
+    self.assertEqual(weight_tensor_config.dtype, _TensorDataType.INT)
+
+  def test_add_weight_only_config_fp16(self):
+    self._recipe_manager.add_weight_only_config(
+        regex='.*/Dense2/.*',
+        operation_name=_TFLOpName.FULLY_CONNECTED,
+        num_bits=16,
+        algorithm_key=_AlgorithmName.FLOAT_CASTING,
+    )
+    alg_key, op_config = self._recipe_manager.get_quantization_configs(
+        _TFLOpName.FULLY_CONNECTED, 'model/Dense2/op'
+    )
+    self.assertEqual(alg_key, _AlgorithmName.FLOAT_CASTING)
+    self.assertEqual(op_config.compute_precision, _ComputePrecision.FLOAT)
+    self.assertTrue(op_config.explicit_dequantize)
+    self.assertIsNone(op_config.activation_tensor_config)
+    weight_tensor_config = op_config.weight_tensor_config
+    self.assertIsNotNone(weight_tensor_config)
+    self.assertEqual(weight_tensor_config.num_bits, 16)
+    self.assertTrue(weight_tensor_config.symmetric)
+    self.assertEqual(
+        weight_tensor_config.granularity,
+        _QuantGranularity.CHANNELWISE,
+    )
+    self.assertEqual(weight_tensor_config.dtype, _TensorDataType.FLOAT)
+
+  def test_add_weight_only_config_fp8_raise_error(self):
+    error_message = (
+        'float casting quantization config requires number of bits to be set'
+        ' as 16'
+    )
+    with self.assertRaisesWithPredicateMatch(
+        ValueError, lambda err: error_message in str(err)
+    ):
+      self._recipe_manager.add_weight_only_config(
+          regex='.*/Dense2/.*',
+          operation_name=_TFLOpName.FULLY_CONNECTED,
+          num_bits=8,
+          algorithm_key=_AlgorithmName.FLOAT_CASTING,
+      )
+
+  def test_add_static_config(self):
+    self._recipe_manager.add_static_config(
+        regex='.*/Dense/.*',
+        operation_name=_TFLOpName.FULLY_CONNECTED,
+        activation_num_bits=8,
+        weight_num_bits=4,
+    )
+    alg_key, op_config = self._recipe_manager.get_quantization_configs(
+        _TFLOpName.FULLY_CONNECTED, 'model/Dense/op'
+    )
+    self.assertEqual(alg_key, _AlgorithmName.MIN_MAX_UNIFORM_QUANT)
+    self.assertEqual(op_config.compute_precision, _ComputePrecision.INTEGER)
+    self.assertFalse(op_config.explicit_dequantize)
+    activation_tensor_config = op_config.activation_tensor_config
+    self.assertIsNotNone(activation_tensor_config)
+    self.assertEqual(activation_tensor_config.num_bits, 8)
+    self.assertFalse(activation_tensor_config.symmetric)
+    weight_tensor_config = op_config.weight_tensor_config
+    self.assertIsNotNone(weight_tensor_config)
+    self.assertEqual(weight_tensor_config.num_bits, 4)
+    self.assertTrue(weight_tensor_config.symmetric)
+    self.assertEqual(
+        weight_tensor_config.granularity,
+        _QuantGranularity.CHANNELWISE,
+    )
+
   def test_set_full_integer_quantization_config(self):
     _add_default_int8xint8_integer_recipe(self._recipe_manager)
     # Full integer setting is global
@@ -461,14 +581,12 @@ class ConfiguratorTest(parameterized.TestCase, googletest.TestCase):
             'symmetric': False,
             'granularity': _QuantGranularity.TENSORWISE,
             'dtype': 'INT',
-            'block_size': 0,
         },
         'weight_tensor_config': {
             'num_bits': 8,
             'symmetric': True,
             'granularity': _QuantGranularity.TENSORWISE,
             'dtype': 'INT',
-            'block_size': 0,
         },
         # WEIGHT_ONLY.
         'compute_precision': _ComputePrecision.INTEGER,
@@ -487,7 +605,6 @@ class ConfiguratorTest(parameterized.TestCase, googletest.TestCase):
             'num_bits': 8,
             'symmetric': True,
             'granularity': _QuantGranularity.TENSORWISE,
-            'block_size': 0,
         },
         # WEIGHT_ONLY.
         'compute_precision': _ComputePrecision.FLOAT,
@@ -506,7 +623,6 @@ class ConfiguratorTest(parameterized.TestCase, googletest.TestCase):
             'num_bits': 4,
             'symmetric': True,
             'granularity': _QuantGranularity.TENSORWISE,
-            'block_size': 0,
         },
         # WEIGHT_ONLY.
         'compute_precision': _ComputePrecision.FLOAT,
@@ -525,7 +641,6 @@ class ConfiguratorTest(parameterized.TestCase, googletest.TestCase):
             'num_bits': 6,
             'symmetric': True,
             'granularity': _QuantGranularity.TENSORWISE,
-            'block_size': 0,
         },
         # WEIGHT_ONLY.
         'compute_precision': _ComputePrecision.FLOAT,
@@ -544,7 +659,6 @@ class ConfiguratorTest(parameterized.TestCase, googletest.TestCase):
             'num_bits': 3,
             'symmetric': True,
             'granularity': _QuantGranularity.TENSORWISE,
-            'block_size': 0,
         },
         # WEIGHT_ONLY.
         'compute_precision': _ComputePrecision.FLOAT,
@@ -14,6 +14,7 @@
 # ==============================================================================
 
 import os
+import unittest  # pylint: disable=unused-import, required for OSS.
 
 from absl.testing import parameterized
 
@@ -21,6 +22,7 @@ from tensorflow.python.platform import googletest
 from ai_edge_quantizer import quantizer
 from ai_edge_quantizer import recipe
 from ai_edge_quantizer.utils import test_utils
+from ai_edge_quantizer.utils import tfl_interpreter_utils
 
 
 _TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('')
@@ -30,21 +32,67 @@ class RecipeTest(parameterized.TestCase):
 
   def setUp(self):
     super().setUp()
-    self._test_model_path = os.path.join(
+    # Weights has < 1024 elements so legacy recipe will not quantize it.
+    self._small_model_path = os.path.join(
         _TEST_DATA_PREFIX_PATH,
         'tests/models/single_conv2d_transpose_bias.tflite',
     )
+    self._test_model_path = os.path.join(
+        _TEST_DATA_PREFIX_PATH,
+        'tests/models/conv_fc_mnist.tflite',
+    )
 
-  def _quantize_with_recipe_func(self, recipe_func):
-    qt = quantizer.Quantizer(self._test_model_path)
+  def _quantize_with_recipe_func(self, recipe_func, test_model_path):
+    qt = quantizer.Quantizer(test_model_path)
     qt.load_quantization_recipe(recipe_func())
     self.assertIsNone(qt._result.quantized_model)
-    quant_result = qt.quantize()
-    self.assertIsNotNone(quant_result.quantized_model)
-    return quant_result
-
+    if qt.need_calibration:
+      calibration_data = tfl_interpreter_utils.create_random_normal_input_data(
+          qt.float_model,
+          num_samples=1,
+      )
+      calibration_result = qt.calibrate(calibration_data)
+      quantization_result = qt.quantize(calibration_result)
+    else:
+      quantization_result = qt.quantize()
+    self.assertIsNotNone(quantization_result.quantized_model)
+    return quantization_result
+
+  @unittest.skip('skipping due to b/438971945')
   def test_quantization_from_dynamic_wi8_afp32_func_succeeds(self):
-    quant_result = self._quantize_with_recipe_func(recipe.dynamic_wi8_afp32)
+    quant_result = self._quantize_with_recipe_func(
+        recipe.dynamic_wi8_afp32, self._test_model_path
+    )
+    self.assertLess(
+        len(quant_result.quantized_model),
+        os.path.getsize(self._test_model_path),
+    )
+
+  @unittest.skip('skipping due to b/438971945')
+  def test_quantization_from_dynamic_wi4_afp32_func_succeeds(self):
+    quant_result = self._quantize_with_recipe_func(
+        recipe.dynamic_wi4_afp32, self._test_model_path
+    )
+    self.assertLess(
+        len(quant_result.quantized_model),
+        os.path.getsize(self._test_model_path),
+    )
+
+  @unittest.skip('skipping due to b/438971945')
+  def test_quantization_from_weight_only_wi8_afp32_func_succeeds(self):
+    quant_result = self._quantize_with_recipe_func(
+        recipe.weight_only_wi8_afp32, self._test_model_path
+    )
+    self.assertLess(
+        len(quant_result.quantized_model),
+        os.path.getsize(self._test_model_path),
+    )
+
+  @unittest.skip('skipping due to b/438971945')
+  def test_quantization_from_weight_only_wi4_afp32_func_succeeds(self):
+    quant_result = self._quantize_with_recipe_func(
+        recipe.weight_only_wi4_afp32, self._test_model_path
+    )
     self.assertLess(
         len(quant_result.quantized_model),
         os.path.getsize(self._test_model_path),
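The reworked helper above reflects that static recipes now need a calibrate-then-quantize pass while dynamic and weight-only recipes do not. A standalone sketch of the same flow outside the test harness, using only the APIs exercised in this file; the model path and the final write are illustrative, and writing quantized_model directly assumes it holds the serialized flatbuffer bytes (the size comparisons in these tests suggest it does):

    from ai_edge_quantizer import quantizer
    from ai_edge_quantizer import recipe
    from ai_edge_quantizer.utils import tfl_interpreter_utils

    qt = quantizer.Quantizer('model.tflite')  # hypothetical input model
    qt.load_quantization_recipe(recipe.static_wi8_ai8())

    if qt.need_calibration:
      # Static recipes need activation ranges; random normal inputs are enough
      # for a smoke test, real use should calibrate on representative samples.
      calibration_data = tfl_interpreter_utils.create_random_normal_input_data(
          qt.float_model, num_samples=1
      )
      calibration_result = qt.calibrate(calibration_data)
      result = qt.quantize(calibration_result)
    else:
      result = qt.quantize()

    with open('model_quantized.tflite', 'wb') as f:
      f.write(result.quantized_model)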
@@ -52,11 +100,12 @@ class RecipeTest(parameterized.TestCase):
 
   def test_quantization_from_dynamic_legacy_wi8_afp32_func_succeeds(self):
     quant_result = self._quantize_with_recipe_func(
-        recipe.dynamic_legacy_wi8_afp32
+        recipe.dynamic_legacy_wi8_afp32,
+        self._small_model_path,
     )
     self.assertLen(
         quant_result.quantized_model,
-        os.path.getsize(self._test_model_path),
+        os.path.getsize(self._small_model_path),
     )
 
   @parameterized.named_parameters(
@@ -65,28 +114,55 @@ class RecipeTest(parameterized.TestCase):
           recipe_json_path='recipes/dynamic_wi8_afp32_recipe.json',
          recipe_func=recipe.dynamic_wi8_afp32,
       ),
+      dict(
+          testcase_name='weight_only_wi8_afp32',
+          recipe_json_path='recipes/default_af32w8float_recipe.json',
+          recipe_func=recipe.weight_only_wi8_afp32,
+      ),
+      dict(
+          testcase_name='weight_only_wi4_afp32',
+          recipe_json_path='recipes/default_af32w4float_recipe.json',
+          recipe_func=recipe.weight_only_wi4_afp32,
+      ),
       dict(
           testcase_name='dynamic_legacy_wi8_afp32',
          recipe_json_path='recipes/dynamic_legacy_wi8_afp32_recipe.json',
          recipe_func=recipe.dynamic_legacy_wi8_afp32,
       ),
+      dict(
+          testcase_name='a8w8',
+          recipe_json_path='recipes/default_a8w8_recipe.json',
+          recipe_func=recipe.static_wi8_ai8,
+      ),
+      dict(
+          testcase_name='a16w8',
+          recipe_json_path='recipes/default_a16w8_recipe.json',
+          recipe_func=recipe.static_wi8_ai16,
+      ),
   )
+  @unittest.skip('skipping due to b/438971945')
   def test_recipe_func_and_json_matches(self, recipe_json_path, recipe_func):
     # Quantize with recipe from function in recipe module.
-    quant_result_from_func = self._quantize_with_recipe_func(recipe_func)
+    quant_result_from_func = self._quantize_with_recipe_func(
+        recipe_func, self._test_model_path
+    )
 
     # Quantize with recipe from json file.
     qt_json = quantizer.Quantizer(self._test_model_path)
     json_recipe_path = os.path.join(_TEST_DATA_PREFIX_PATH, recipe_json_path)
     qt_json.load_quantization_recipe(json_recipe_path)
-    quant_result_from_json = qt_json.quantize()
+    if qt_json.need_calibration:
+      calibration_data = tfl_interpreter_utils.create_random_normal_input_data(
+          qt_json.float_model,
+          num_samples=1,
+      )
+      calibration_result = qt_json.calibrate(calibration_data)
+      quant_result_from_json = qt_json.quantize(calibration_result)
+    else:
+      quant_result_from_json = qt_json.quantize()
     self.assertIsNotNone(quant_result_from_json.quantized_model)
 
-    # Check if the recipes and quantized models match.
-    self.assertEqual(
-        quant_result_from_func.recipe,
-        quant_result_from_json.recipe,
-    )
+    # Check if the quantized models match.
     self.assertEqual(
         len(quant_result_from_func.quantized_model),
         len(quant_result_from_json.quantized_model),