ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
Files changed (69)
  1. ai_edge_quantizer/algorithm_manager.py +224 -0
  2. ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
  3. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
  5. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  6. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
  7. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
  8. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
  13. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
  14. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
  15. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
  16. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
  17. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
  18. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
  19. ai_edge_quantizer/calibrator.py +58 -94
  20. ai_edge_quantizer/calibrator_test.py +5 -74
  21. ai_edge_quantizer/default_policy.py +108 -16
  22. ai_edge_quantizer/model_modifier.py +132 -8
  23. ai_edge_quantizer/model_modifier_test.py +81 -1
  24. ai_edge_quantizer/model_validator.py +38 -10
  25. ai_edge_quantizer/model_validator_test.py +2 -1
  26. ai_edge_quantizer/params_generator.py +230 -47
  27. ai_edge_quantizer/params_generator_test.py +366 -261
  28. ai_edge_quantizer/qtyping.py +92 -6
  29. ai_edge_quantizer/quantizer.py +167 -23
  30. ai_edge_quantizer/quantizer_test.py +288 -26
  31. ai_edge_quantizer/recipe.py +156 -21
  32. ai_edge_quantizer/recipe_manager.py +158 -1
  33. ai_edge_quantizer/recipe_manager_test.py +146 -32
  34. ai_edge_quantizer/recipe_test.py +93 -17
  35. ai_edge_quantizer/transformation_instruction_generator.py +313 -46
  36. ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
  37. ai_edge_quantizer/transformation_performer.py +112 -58
  38. ai_edge_quantizer/transformation_performer_test.py +176 -4
  39. ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
  40. ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
  41. ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
  42. ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
  43. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  44. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  45. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  46. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  47. ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
  48. ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
  49. ai_edge_quantizer/transformations/transformation_utils.py +157 -11
  50. ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
  51. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  52. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  53. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  54. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  55. ai_edge_quantizer/utils/test_utils.py +191 -58
  56. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
  57. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
  58. ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
  59. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  60. ai_edge_quantizer/utils/validation_utils.py +114 -4
  61. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  62. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
  63. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  64. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  65. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  66. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  67. ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
  68. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  69. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
@@ -15,10 +15,12 @@
 
 """Generate model tensor level quantization config."""
 
+from collections.abc import Sequence
 import copy
 from typing import Any, Optional, Union
 
 from ai_edge_quantizer import algorithm_manager
+from ai_edge_quantizer import default_policy as policy
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer import recipe_manager
 from ai_edge_quantizer.utils import tfl_flatbuffer_utils
@@ -33,12 +35,12 @@ class ParamsGenerator:
   def __init__(self, float_tflite: Union[str, bytes]):
     self.flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
 
-    if not tfl_flatbuffer_utils.is_float_model(self.flatbuffer_model):
-      raise ValueError(
-          'The input model for quantization parameters generation is not a'
-          ' float model. Please check the model (e.g., if it is already'
-          ' quantized).'
-      )
+    # if not tfl_flatbuffer_utils.is_float_model(self.flatbuffer_model):
+    #   raise ValueError(
+    #       'The input model for quantization parameters generation is not a'
+    #       ' float model. Please check the model (e.g., if it is already'
+    #       ' quantized).'
+    #   )
     self._check_tensor_names_are_unique()
     self.buffer_to_tensors: dict[int, list[Any]] = (
         tfl_flatbuffer_utils.buffer_to_tensors(self.flatbuffer_model)
@@ -73,8 +75,10 @@ class ParamsGenerator:
     if model_qsvs is None:
       model_qsvs = {}
 
+    skip_subgraphs = set()
     op_codes = self.flatbuffer_model.operatorCodes
-    for subgraph in self.flatbuffer_model.subgraphs:
+    for sg_ind, subgraph in enumerate(self.flatbuffer_model.subgraphs):
+
       graph_info = qtyping.GraphInfo(
           subgraph.tensors, self.flatbuffer_model.buffers
       )
@@ -103,10 +107,22 @@
         algorithm_name, op_quant_config = (
             model_recipe_manager.get_quantization_configs(op_key, op_scope)
         )
+
+        if sg_ind in skip_subgraphs or policy.is_non_quantizable_composite_op(
+            op
+        ):
+          algorithm_name = algorithm_manager.AlgorithmName.NO_QUANTIZE
+
         if algorithm_name == algorithm_manager.AlgorithmName.NO_QUANTIZE:
+          side_effect_subgraphs = (
+              tfl_flatbuffer_utils.get_op_side_effect_subgraphs(op)
+          )
+          skip_subgraphs.update(side_effect_subgraphs)
+
           op_quant_results = self._get_params_for_no_quant_op(
               subgraph_op_id, op, subgraph.tensors
           )
+
         else:
           op_info = qtyping.OpInfo(op, op_key, subgraph_op_id, op_quant_config)
           # Step2: query algorithm_manager to get/call the related function.
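The two additions above work together: any op that resolves to NO_QUANTIZE now also poisons the subgraphs it invokes (via skip_subgraphs), so the body of an unquantized control-flow or composite op is not quantized on its own. Below is a rough sketch of what such a side-effect-subgraph helper could look like; the real tfl_flatbuffer_utils.get_op_side_effect_subgraphs may well differ. The option-field names are taken from the TFLite flatbuffer schema (WhileOptions, IfOptions, StableHLOCompositeOptions), and the guard for option-less ops is an assumption:

def get_op_side_effect_subgraphs_sketch(op) -> list[int]:
  """Collects subgraph indices invoked by a control-flow/composite op."""
  opts = op.builtinOptions
  if opts is None:  # most ops carry no subgraph-referencing options
    return []
  indices = []
  for field in ('condSubgraphIndex', 'bodySubgraphIndex',
                'thenSubgraphIndex', 'elseSubgraphIndex',
                'decompositionSubgraphIndex'):
    idx = getattr(opts, field, None)
    if idx is not None and idx >= 0:
      indices.append(idx)
  return indices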
@@ -146,7 +162,7 @@ class ParamsGenerator:
       RuntimeError: If the tensors sharing the same buffer have different
         quantization settings.
     """
-    self._check_buffer_sharing()
+    self._check_and_fix_buffer_sharing()
 
   def _update_model_quant_results(
       self,
@@ -252,57 +268,224 @@
       tensor_params.append(output_tensor_params)
     return tensor_params
 
-  def _check_buffer_sharing(self) -> None:
-    """Check if tensors sharing the same buffer have the same quantization.
+  def _mark_tensors_requiring_buffer_duplication(
+      self, buffers_to_duplicate: Sequence[int]
+  ) -> None:
+    """Mark tensors that require buffer duplication.
+
+    Marking a tensor means adding a DUPLICATE_BUFFER transformation as the first
+    transformation to be applied for each consumer of the tensor. Need to do
+    that for each consumer to preserve a zero layer and not affect the
+    horizontal optimization later in the transformation instructions generator.
+
+    Marks all tensors within each of the provided buffers as requiring buffer
+    duplication, except for the last tensor. The order of tensors is assumed to
+    be the same during both the marking and transformation performer steps, as
+    determined by `self.buffer_to_tensors`. This allows the final tensor to
+    reuse the original buffer, as it is not marked for duplication.
+
+    Args:
+      buffers_to_duplicate: Indices of the buffers to duplicate.
+    """
+    for buffer_idx in buffers_to_duplicate:
+      for tensor in self.buffer_to_tensors[buffer_idx][:-1]:
+        tensor_name = tfl_flatbuffer_utils.get_tensor_name(tensor)
+        for consumer_params in self.model_quant_results[tensor_name].consumers:
+          consumer_params.transformations.insert(
+              0, _QuantTrans.DUPLICATE_BUFFER
+          )
+
+  def _mark_tensors_requiring_tensor_duplication(
+      self, tensor_names_to_duplicate
+  ) -> None:
+    """Mark tensors that require tensor duplication.
+
+    Marking a tensor means adding a DUPLICATE_TENSOR transformation as the first
+    transformation to be applied for each consumer of the tensor. Need to do
+    that for each consumer to preserve a zero layer and not affect the
+    horizontal optimization later in the transformation instructions generator.
+
+    Args:
+      tensor_names_to_duplicate: Names of tensors to duplicate.
+    """
+    for tensor_name in tensor_names_to_duplicate:
+      for consumer_params in self.model_quant_results[tensor_name].consumers:
+        consumer_params.transformations.insert(0, _QuantTrans.DUPLICATE_TENSOR)
+
+  def _check_buffer_sharing_for_tensor(self, tensor: Any) -> bool:
+    """Check buffer sharing for the tensor against itself.
+
+    Args:
+      tensor: The tensor to check.
+
+    Returns:
+      Whether the tensor has compatible quantization parameters.
+
+    Raises:
+      RuntimeError: If the tensor has incompatible quantization parameters
+        and the buffer is not constant.
+    """
+    tensor_params = self.model_quant_results.get(
+        tfl_flatbuffer_utils.get_tensor_name(tensor), None
+    )
+    if tensor_params is None:
+      return True
+
+    if _are_tensor_consumer_params_compatible(tensor_params):
+      return True
+    elif _is_constant_tensor(tensor, self.flatbuffer_model.buffers):
+      return False
+    else:
+      error_msg = (
+          f'The tensor {tensor.name} consumers do not have the same'
+          ' quantization parameters. Please modify your quantization recipe to'
+          ' make sure the two tensors have the same quantization settings.'
+      )
+      raise RuntimeError(error_msg)
+
+  def _check_buffer_sharing_for_self_compatible_tensors(
+      self, tensor1: Any, tensor2: Any
+  ) -> bool:
+    """Check a pair of self compatible tensors have the same quantization params.
+
+    Self compatible means that all tensor's consumers have the same quantization
+    parameters.
+
+    Args:
+      tensor1: The first tensor to check.
+      tensor2: The second tensor to check.
+
+    Returns:
+      Whether the tensors have compatible quantization parameters.
+
+    Raises:
+      RuntimeError: If the tensors have incompatible quantization parameters
+        and the buffer is not constant.
+    """
+    tensor1_params = self.model_quant_results.get(
+        tfl_flatbuffer_utils.get_tensor_name(tensor1), None
+    )
+    tensor2_params = self.model_quant_results.get(
+        tfl_flatbuffer_utils.get_tensor_name(tensor2), None
+    )
+
+    if tensor1_params is None or tensor2_params is None:
+      return True
+
+    if _are_self_compatible_tensors_compatible_to_each_other(
+        tensor1_params, tensor2_params
+    ):
+      return True
+    elif _is_constant_tensor(tensor1, self.flatbuffer_model.buffers):
+      return False
+    else:
+      error_msg = (
+          f'The tensors {tensor1.name} and {tensor2.name} do not have'
+          ' the same quantization parameters even though they share the'
+          ' same buffer. Please modify your quantization recipe to make'
+          ' sure the two tensors have the same quantization settings.'
+      )
+      raise RuntimeError(error_msg)
+
+  def _check_and_fix_buffer_sharing(self) -> None:
+    """Check and fix tensor/buffer sharing issues when possible.
+
+    This function checks if tensors sharing the same buffer have the same
+    quantization settings. If not, it fixes the issue where possible by marking
+    such tensors or buffers for duplication. Otherwise, it raises an error.
+
+    Possible cases that can be fixed by duplication:
+      1. A constant tensor receives different quantization parameters from its
+        consumers. In this case, the tensor is marked for duplication.
+      2. Two or more tensors share the same constant buffer and have different
+        quantization parameters. In this case, the buffer is marked for
+        duplication.
 
     Raises:
       RuntimeError: If the tensors sharing the same buffer have different
-        quantization settings.
+        quantization settings and it can't be resolved by duplicating the
+        buffer/tensor.
     """
-    for tensors in self.buffer_to_tensors.values():
-      if len(tensors) <= 1:
+    buffers_to_duplicate = []
+    tensor_names_to_duplicate = []
+    for buffer_idx, tensors in self.buffer_to_tensors.items():
+      # TODO: b/458797890 - Investigate if skipping buffer_idx == 0 is a
+      # correct fix, or if it just covers up a deeper issue. This is only
+      # required when statically quantizing models that have already been
+      # quantized dynamically.
+      if not tensors or buffer_idx == 0:
         continue
-      first_tensor = tensors[0]
-      first_tensor_params = self.model_quant_results[
-          tfl_flatbuffer_utils.get_tensor_name(first_tensor)
-      ]
-      for tensor in tensors[1:]:
-        tensor_params = self.model_quant_results[
-            tfl_flatbuffer_utils.get_tensor_name(tensor)
-        ]
-        if not _compatible_tensor_transformation_params(
-            first_tensor_params, tensor_params
-        ):
-          error_msg = (
-              f'The tensors {first_tensor.name} and {tensor.name} do not have'
-              ' the same quantization parameters even though they share the'
-              ' same buffer. Please modify your quantization recipe to make'
-              ' sure the two tensors have the same quantization settings.'
+      # Check if any of the tensors needs to be duplicated.
+      for tensor in tensors:
+        if not self._check_buffer_sharing_for_tensor(tensor):
+          tensor_names_to_duplicate.append(
+              tfl_flatbuffer_utils.get_tensor_name(tensor)
           )
-          raise RuntimeError(error_msg)
+      # Check if the buffer needs to be duplicated.
+      tensor_1 = tensors[0]
+      tensor_name_1 = tfl_flatbuffer_utils.get_tensor_name(tensor_1)
+      if tensor_name_1 in tensor_names_to_duplicate:
+        buffers_to_duplicate.append(buffer_idx)
+        continue
+      for tensor_2 in tensors[1:]:
+        tensor_name_2 = tfl_flatbuffer_utils.get_tensor_name(tensor_2)
+        if (
+            tensor_name_2 in tensor_names_to_duplicate
+            or not self._check_buffer_sharing_for_self_compatible_tensors(
+                tensor_1, tensor_2
+            )
+        ):
+          buffers_to_duplicate.append(buffer_idx)
+          break
+
+    # Fix the buffer sharing issues.
+    self._mark_tensors_requiring_buffer_duplication(buffers_to_duplicate)
+    self._mark_tensors_requiring_tensor_duplication(tensor_names_to_duplicate)
+
+
+def _are_tensor_consumer_params_compatible(
+    params: qtyping.TensorTransformationParams,
+) -> bool:
+  """Check if all tensor's consumers have the same quantization parameters."""
+  if params.consumers is None or len(params.consumers) < 2:
+    return True
+  consumer_1 = params.consumers[0]
+  for consumer in params.consumers[1:]:
+    if not _compatible_tensor_params(consumer, consumer_1):
+      return False
+  return True
 
 
-def _compatible_tensor_transformation_params(
+def _are_self_compatible_tensors_compatible_to_each_other(
     params1: qtyping.TensorTransformationParams,
     params2: qtyping.TensorTransformationParams,
 ) -> bool:
-  """Check if two tensor transformation params are compatible."""
+  """Check if two self compatible tensors are compatible to each other.
+
+  Self compatible means that all tensor's consumers have the same quantization
+  parameters.
+
+  Args:
+    params1: The first tensor transformation params.
+    params2: The second tensor transformation params.
+
+  Returns:
+    Whether the two tensors are compatible to each other.
+  """
+  # Check the producer.
   if params1.producer is None or params2.producer is None:
     if params1.producer != params2.producer:
       return False
   elif not _compatible_tensor_params(params1.producer, params2.producer):
     return False
+
+  # Check the consumers.
   if params1.consumers is None or params2.consumers is None:
     if params1.consumers != params2.consumers:
       return False
   else:
-    # Check all consumers within each params are compatible.
-    for params1_consumer in params1.consumers:
-      if not _compatible_tensor_params(params1_consumer, params1.consumers[0]):
-        return False
-    for params2_consumer in params2.consumers:
-      if not _compatible_tensor_params(params2_consumer, params2.consumers[0]):
-        return False
+    # Since all consumer params within each tensor are the same, it's enough to
+    # check only the first consumers.
     if not _compatible_tensor_params(
         params1.consumers[0], params2.consumers[0]
     ):
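To make the fix concrete, here is a toy rendering of the marking step with stand-in classes (the real code uses the qtyping dataclasses and the _QuantTrans enum; the names below are illustrative only). Prepending the duplication transformation to every consumer guarantees it runs before each consumer's own quantize transformation, so each consumer ends up operating on its own copy of the tensor or buffer:

from dataclasses import dataclass, field
from enum import Enum, auto

class QuantTrans(Enum):  # stand-in for the real transformation enum
  DUPLICATE_TENSOR = auto()
  QUANTIZE_TENSOR = auto()

@dataclass
class ConsumerParams:  # stand-in for the consumer-side params in qtyping
  transformations: list = field(default_factory=list)

# A constant tensor whose two consumers requested different scales: instead of
# raising, the generator now prepends DUPLICATE_TENSOR for every consumer.
consumers = [
    ConsumerParams([QuantTrans.QUANTIZE_TENSOR]),
    ConsumerParams([QuantTrans.QUANTIZE_TENSOR]),
]
for consumer in consumers:
  consumer.transformations.insert(0, QuantTrans.DUPLICATE_TENSOR)

# Duplication is now the first transformation applied on every consumer edge.
assert all(
    c.transformations[0] is QuantTrans.DUPLICATE_TENSOR for c in consumers
)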
@@ -330,6 +513,8 @@ def _compatible_tensor_params(
   float_source_transformations = [
       _QuantTrans.ADD_QUANTIZE,
       _QuantTrans.NO_QUANTIZE,
+      _QuantTrans.INSERT_HADAMARD_ROTATION,
+      _QuantTrans.INSERT_DECOMPOSED_HADAMARD_ROTATION,
   ]
   quantized_source_transformations = [
       _QuantTrans.QUANTIZE_TENSOR,
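The two new entries classify the Hadamard-rotation insertions as float-source transformations: like ADD_QUANTIZE, they leave the source tensor in float and insert an op after it. For intuition, rotating by an orthonormal Hadamard matrix preserves inner products while spreading outliers across elements, which shrinks the dynamic range a subsequent quantizer must cover. A minimal, self-contained illustration follows (standard Sylvester construction; this is not code from hadamard_rotation.py):

import numpy as np

def hadamard(n: int) -> np.ndarray:
  """Orthonormal Hadamard matrix via Sylvester doubling; n a power of two."""
  h = np.array([[1.0]])
  while h.shape[0] < n:
    h = np.block([[h, h], [h, -h]])
  return h / np.sqrt(h.shape[0])  # normalized so that h @ h.T == I

x = np.array([10.0, 0.1, -0.2, 0.05])     # one outlier dominates the range
rx = hadamard(4) @ x                      # rotation spreads it across elements
print(np.abs(x).max(), np.abs(rx).max())  # 10.0 vs ~5.1: tighter range to quantize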
@@ -337,14 +522,6 @@
   ]
   if _same_tensor_params_except_id(params1, params2):
     return True
-  if (
-      params1.transformations[0] != _QuantTrans.NO_QUANTIZE
-      and params2.transformations[0] != _QuantTrans.NO_QUANTIZE
-  ):
-    # NO_QUANTIZE has no parameters. So only if both params aren't NO_QUANTIZE
-    # do we expect the parameters to be the same.
-    if params1.parameters != params2.parameters:
-      return False
   # We only need to check the first transformation because transformations are
   # applied in order, and as long as the one that's immediately after the tensor
   # is the same, it's compatible.
@@ -356,6 +533,12 @@
   if (
       params1.transformations[0] in quantized_source_transformations
       and params2.transformations[0] in quantized_source_transformations
+      and params1.parameters == params2.parameters
   ):
     return True
   return False
+
+
+def _is_constant_tensor(tensor: Any, buffers: Sequence[Any]) -> bool:
+  """Check if the tensor is a constant tensor."""
+  return buffers[tensor.buffer].data is not None
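Together with the block removed in the previous hunk, this moves the parameter-equality requirement into the quantized-source branch only: consumers that keep the source tensor in float (ADD_QUANTIZE and the Hadamard insertions) may disagree on quantization parameters, because each inserts its own downstream op, while tensors quantized in place must agree exactly. A condensed restatement with stand-in types, not the real qtyping classes (the real quantized_source_transformations list has more entries than shown):

from dataclasses import dataclass
from typing import Any

@dataclass
class P:  # stand-in for a consumer's transformation params
  first_transformation: str
  parameters: Any

FLOAT_SOURCE = {
    'ADD_QUANTIZE', 'NO_QUANTIZE',
    'INSERT_HADAMARD_ROTATION', 'INSERT_DECOMPOSED_HADAMARD_ROTATION',
}
QUANT_SOURCE = {'QUANTIZE_TENSOR'}  # abridged

def compatible(p1: P, p2: P) -> bool:
  t1, t2 = p1.first_transformation, p2.first_transformation
  if t1 in FLOAT_SOURCE and t2 in FLOAT_SOURCE:
    return True  # each consumer re-quantizes the float source itself
  if t1 in QUANT_SOURCE and t2 in QUANT_SOURCE:
    return p1.parameters == p2.parameters  # the condition added above
  return False

assert compatible(P('ADD_QUANTIZE', 0.1), P('ADD_QUANTIZE', 0.2))
assert not compatible(P('QUANTIZE_TENSOR', 0.1), P('QUANTIZE_TENSOR', 0.2))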