ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. ai_edge_quantizer/algorithm_manager.py +224 -0
  2. ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
  3. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
  5. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  6. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
  7. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
  8. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
  13. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
  14. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
  15. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
  16. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
  17. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
  18. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
  19. ai_edge_quantizer/calibrator.py +58 -94
  20. ai_edge_quantizer/calibrator_test.py +5 -74
  21. ai_edge_quantizer/default_policy.py +108 -16
  22. ai_edge_quantizer/model_modifier.py +132 -8
  23. ai_edge_quantizer/model_modifier_test.py +81 -1
  24. ai_edge_quantizer/model_validator.py +38 -10
  25. ai_edge_quantizer/model_validator_test.py +2 -1
  26. ai_edge_quantizer/params_generator.py +230 -47
  27. ai_edge_quantizer/params_generator_test.py +366 -261
  28. ai_edge_quantizer/qtyping.py +92 -6
  29. ai_edge_quantizer/quantizer.py +167 -23
  30. ai_edge_quantizer/quantizer_test.py +288 -26
  31. ai_edge_quantizer/recipe.py +156 -21
  32. ai_edge_quantizer/recipe_manager.py +158 -1
  33. ai_edge_quantizer/recipe_manager_test.py +146 -32
  34. ai_edge_quantizer/recipe_test.py +93 -17
  35. ai_edge_quantizer/transformation_instruction_generator.py +313 -46
  36. ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
  37. ai_edge_quantizer/transformation_performer.py +112 -58
  38. ai_edge_quantizer/transformation_performer_test.py +176 -4
  39. ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
  40. ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
  41. ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
  42. ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
  43. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  44. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  45. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  46. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  47. ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
  48. ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
  49. ai_edge_quantizer/transformations/transformation_utils.py +157 -11
  50. ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
  51. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  52. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  53. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  54. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  55. ai_edge_quantizer/utils/test_utils.py +191 -58
  56. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
  57. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
  58. ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
  59. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  60. ai_edge_quantizer/utils/validation_utils.py +114 -4
  61. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  62. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
  63. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  64. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  65. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  66. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  67. ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
  68. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  69. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/algorithms/utils/common_utils.py
@@ -41,6 +41,7 @@ _DRQ_OR_WEIGHT_ONLY_OPS = frozenset([
 
 _SUPPORTED_SUBCHANNEL_OPS = frozenset([
     _TFLOpName.FULLY_CONNECTED,
+    _TFLOpName.EMBEDDING_LOOKUP,
 ])
 
 
@@ -50,8 +51,9 @@ def check_subchannel_config(
   """Checks the op quantization config for subchannel quantization."""
   if (
       op_quant_config.weight_tensor_config is not None
-      and op_quant_config.weight_tensor_config.granularity
-      == qtyping.QuantGranularity.BLOCKWISE
+      and uniform_quantize_tensor.is_blockwise(
+          op_quant_config.weight_tensor_config.granularity
+      )
   ):
     if op_name not in _SUPPORTED_SUBCHANNEL_OPS:
       raise ValueError(f"Unsupported op for blockwise quantization: {op_name}.")
@@ -65,10 +67,6 @@ def check_subchannel_config(
           "Blockwise quantization does not support for asymmetric weight"
           " quantization."
       )
-    if op_quant_config.weight_tensor_config.block_size <= 0:
-      raise ValueError(
-          "Blockwise quantization must have a non-zero block size."
-      )
 
 
 def check_if_valid_op_config(
@@ -86,7 +84,6 @@ def check_if_valid_op_config(
   Raises:
     ValueError: If the op quantization config is not valid.
   """
-
   check_passed = False
   error_msg = ""
   # Check if find op_config in policy config_check_policy.
@@ -260,6 +257,60 @@ def _get_single_tensor_params(
   )
 
 
+def _materialize_tensors_with_quantized_data_update(
+    op_tensor_params: list[qtyping.TensorTransformationParams],
+    tensors: Sequence[Any],
+    quant_params: Optional[qtyping.UniformQuantParams],
+    is_inbounding_tensor: bool,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+) -> None:
+  """Materialize a list of tensors with `quantized_data` updated when needed.
+
+  Args:
+    op_tensor_params: Tensor transformation parameters for the op. Will be
+      modified to include new tensor parameters.
+    tensors: Tensors to be materialized.
+    quant_params: The quantization parameters to be used for materialization.
+    is_inbounding_tensor: Whether the tensor is an inbounding tensor for the op.
+    op_info: Aggregated information about the op (e.g., quantization config).
+    graph_info: Graph information needed to perform quantization for the op.
+    tensor_name_to_qsv: A map of tensor name to quantization parameters.
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
+  """
+  if quant_params is not None and quant_params.quantized_data is not None:
+    quant_params = dataclasses.replace(quant_params, quantized_data=None)
+
+  for tensor in tensors:
+    tensor_data = tfl_flatbuffer_utils.get_tensor_data(
+        tensor, graph_info.buffers
+    )
+    if quant_params is None or tensor_data is None:
+      tensor_quant_params = quant_params
+    else:
+      # Constant tensors require updating `quantized_data`.
+      quantized_data = uniform_quantize_tensor.uniform_quantize(
+          tensor_data, quant_params
+      )
+      tensor_quant_params = dataclasses.replace(
+          quant_params,
+          quantized_data=quantized_data,
+      )
+    _materialize_op_tensors(
+        op_tensor_params,
+        [tensor],
+        is_inbounding_tensor=is_inbounding_tensor,
+        op_info=op_info,
+        graph_info=graph_info,
+        tensor_name_to_qsv=tensor_name_to_qsv,
+        get_tensor_quant_params_fn=get_tensor_quant_params_fn,
+        quant_params=tensor_quant_params,
+    )
+
+
 def _materialize_standard_op_with_same_as_input_scale(
     input_tensors: Sequence[Any],
     output_tensors: Sequence[Any],
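The new helper first clears any `quantized_data` carried on the shared `quant_params` (it belongs to a different tensor), then recomputes it per constant tensor via `uniform_quantize_tensor.uniform_quantize`. That function's body is outside this diff; as a reference point, affine uniform quantization follows the standard formula q = clamp(round(x / scale) + zero_point, q_min, q_max). A minimal numpy sketch under that assumption, not the library's actual implementation:

    import numpy as np

    def uniform_quantize_sketch(
        x: np.ndarray,
        scale: float,
        zero_point: int,
        q_min: int = -128,
        q_max: int = 127,
    ) -> np.ndarray:
      """Standard affine quantization; a stand-in for the library call."""
      q = np.round(x / scale) + zero_point
      return np.clip(q, q_min, q_max).astype(np.int8)

    # Example: quantizing a small constant weight tensor to int8.
    weights = np.array([-0.5, 0.0, 0.25, 1.2], dtype=np.float32)
    print(uniform_quantize_sketch(weights, scale=0.01, zero_point=0))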
@@ -295,23 +346,48 @@ def _materialize_standard_op_with_same_as_input_scale(
   )
   op_tensor_params.append(input_tensor_params)
   # Use input quantization params for all output tensors.
-  _materialize_op_tensors(
+  input_quant_params = input_tensor_params.consumers[0].parameters
+  if not isinstance(input_quant_params, qtyping.UniformQuantParams):
+    raise ValueError(
+        "_materialize_standard_op_with_same_as_input_scale only supports"
+        f" UniformQuantParams. For tensor {input_tensor_params.tensor_name},"
+        f" got {type(input_quant_params)}"
+    )
+  _materialize_tensors_with_quantized_data_update(
       op_tensor_params,
       output_tensors,
+      input_quant_params,
       is_inbounding_tensor=False,
       op_info=op_info,
       graph_info=graph_info,
       tensor_name_to_qsv=tensor_name_to_qsv,
       get_tensor_quant_params_fn=get_tensor_quant_params_fn,
-      quant_params=input_tensor_params.consumers[0].parameters,
   )
+
   # Change output qsv to be the same as input qsv. This is safe since TFL
   # subgraph is acyclic.
-  input_tensor_qsv = tensor_name_to_qsv[input_tensor_params.tensor_name]
-  for output_tensor in output_tensors:
-    tensor_name_to_qsv[tfl_flatbuffer_utils.get_tensor_name(output_tensor)] = (
-        input_tensor_qsv
+  input_tensor_qsv = tensor_name_to_qsv.get(
+      input_tensor_params.tensor_name, None
+  )
+  if input_tensor_qsv is None:
+    input_tensor_data = tfl_flatbuffer_utils.get_tensor_data(
+        input_tensors[0], graph_info.buffers
     )
+    # If the input tensor is a constant tensor without qsv, compute qsv from
+    # its quant params.
+    if input_tensor_data is None:
+      # If the only input to an op that needs to match input to
+      # output has no qsv and is not a constant tensor, then this is an error.
+      raise ValueError(
+          "Input tensor qsv is None for tensor"
+          f" {input_tensor_params.tensor_name}."
+      )
+    min_val, max_val = _get_min_max_from_quant_params(input_quant_params)
+    input_tensor_qsv = {"min": min_val, "max": max_val}
+  for output_tensor in output_tensors:
+    tensor_name_to_qsv[
+        tfl_flatbuffer_utils.get_tensor_name(output_tensor)
+    ] = input_tensor_qsv
 
   return op_tensor_params
 
@@ -351,19 +427,26 @@ def _materialize_standard_op_with_same_as_output_scale(
   )
   # Use output quantization params for all input tensors.
   if output_tensor_params.producer is None:
-    quant_params = None
+    output_quant_params = None
   else:
-    quant_params = output_tensor_params.producer.parameters
-  _materialize_op_tensors(
+    output_quant_params = output_tensor_params.producer.parameters
+    if not isinstance(output_quant_params, qtyping.UniformQuantParams):
+      raise ValueError(
+          "_materialize_standard_op_with_same_as_output_scale only supports"
+          f" UniformQuantParams. For tensor {output_tensor_params.tensor_name},"
+          f" got {type(output_quant_params)}"
+      )
+  _materialize_tensors_with_quantized_data_update(
       op_tensor_params,
       input_tensors,
+      output_quant_params,
       is_inbounding_tensor=True,
       op_info=op_info,
       graph_info=graph_info,
       tensor_name_to_qsv=tensor_name_to_qsv,
       get_tensor_quant_params_fn=get_tensor_quant_params_fn,
-      quant_params=quant_params,
   )
+
   op_tensor_params.append(output_tensor_params)
 
   return op_tensor_params
@@ -628,6 +711,26 @@ def _add_non_match_tensors_to_ignored_lists(
   return inputs_to_ignore, outputs_to_ignore
 
 
+def _get_min_max_from_quant_params(
+    quant_params: qtyping.UniformQuantParams,
+) -> tuple[np.ndarray, np.ndarray]:
+  """Recalculate min/max from tensor quantization params."""
+  q_min, q_max = uniform_quantize_tensor.get_quantized_range(
+      _IntType(quant_params.num_bits, True)
+  )
+  float_min = uniform_quantize_tensor.uniform_dequantize(
+      np.array(q_min), quant_params
+  )
+  float_max = uniform_quantize_tensor.uniform_dequantize(
+      np.array(q_max), quant_params
+  )
+  # We use qmax values to compute scale for symmetric quantization (see
+  # uniform_quantize_tensor.tensor_zp_scale_from_min_max).
+  if quant_params.symmetric:
+    float_min = -float_max
+  return float_min, float_max
+
+
 def materialize_standard_op(
     op_info: qtyping.OpInfo,
     graph_info: qtyping.GraphInfo,
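The relocated `_get_min_max_from_quant_params` now reads `num_bits` and `symmetric` from the `UniformQuantParams` themselves instead of taking them as separate arguments (the old three-argument version is deleted further down). It inverts quantization at the integer range limits: float_min = dequantize(q_min) and float_max = dequantize(q_max), with dequantize(q) = scale * (q - zero_point). A worked numeric check of that arithmetic:

    import numpy as np

    scale, zero_point = 0.1, 0
    q_min, q_max = -128, 127  # 8-bit signed range
    float_max = scale * (np.array(q_max) - zero_point)  # 12.7
    float_min = scale * (np.array(q_min) - zero_point)  # -12.8
    # Symmetric scales are derived from qmax (see
    # tensor_zp_scale_from_min_max), so mirror the max instead:
    symmetric = True
    if symmetric:
      float_min = -float_max  # -12.7, not -12.8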
@@ -794,8 +897,6 @@ def materialize_op_with_output_activation_constraint(
   output_tensor_params.producer = op_tensor_params
   # Update the tensor_name_to_qsv map using the output activation constraints.
   min_val, max_val = _get_min_max_from_quant_params(
-      activation_num_bits,
-      activation_tensor_config.symmetric,
       fixed_quant_params,
   )
   tensor_name_to_qsv[output_tensor_params.tensor_name]["min"] = min_val
@@ -842,13 +943,6 @@ def get_tensor_transformations(
       transformations = [_QuantTransformation.QUANTIZE_TENSOR]
     else:
       transformations = [_QuantTransformation.NO_QUANTIZE]
-  elif (
-      op_quant_config.weight_tensor_config is not None
-      and op_quant_config.weight_tensor_config.granularity
-      == qtyping.QuantGranularity.BLOCKWISE
-      and is_constant
-  ):
-    transformations = [_QuantTransformation.EMULATED_SUBCHANNEL]
   # Check if WEIGHT_ONLY.
   elif (
       op_quant_config.compute_precision == qtyping.ComputePrecision.FLOAT
@@ -906,23 +1000,36 @@ def get_tensor_transformation_params(
   )
 
 
-def get_weight_quantized_dim(op_info: qtyping.OpInfo, tensor_data: np.ndarray):
+def get_weight_quantized_dim(
+    op_info: qtyping.OpInfo,
+    tensor_data: np.ndarray,
+    granularity: qtyping.QuantGranularity,
+):
   """Get the quantized dimension for the weight tensor.
 
   Args:
     op_info: Aggregated information about the op (e.g., quantization config).
     tensor_data: The weight tensor data.
+    granularity: The granularity of the weight tensor.
 
   Returns:
     The quantized dimension for the weight tensor.
   """
-  if op_info.op_name == _TFLOpName.BATCH_MATMUL:
-    quantized_dim = get_bmm_weight_quantized_dim(
-        tensor_data, adj_y=op_info.op.builtinOptions.adjY
-    )
-  else:
-    quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM.get(
-        op_info.op_name, None
+  quantized_dim = None
+  if granularity == qtyping.QuantGranularity.CHANNELWISE:
+    if op_info.op_name == _TFLOpName.BATCH_MATMUL:
+      quantized_dim = get_bmm_weight_quantized_dim(
+          tensor_data, adj_y=op_info.op.builtinOptions.adjY
+      )
+    else:
+      quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM.get(
+          op_info.op_name, None
+      )
+  elif uniform_quantize_tensor.is_blockwise(granularity):
+    quantized_dim = (
+        tfl_flatbuffer_utils.TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM[
+            op_info.op_name
+        ]
     )
   return quantized_dim
 
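With the new `granularity` argument, the helper returns a per-channel dimension only for CHANNELWISE weights, a dimension from a dedicated table for blockwise weights (a keyed lookup, so a missing op raises), and `None` otherwise. The per-op tables live in `tfl_flatbuffer_utils` and are not shown in this diff; a hedged illustration with hypothetical table contents, using the TFLite convention that a FULLY_CONNECTED weight of shape (output_channels, input_depth) is quantized per output channel:

    import numpy as np

    # Hypothetical stand-ins for the lookup tables in tfl_flatbuffer_utils.
    TFL_OP_TO_WEIGHT_QUANTIZED_DIM = {"FULLY_CONNECTED": 0}
    TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM = {"FULLY_CONNECTED": 1}

    weights = np.zeros((16, 64), dtype=np.float32)  # (out channels, in depth)
    channel_dim = TFL_OP_TO_WEIGHT_QUANTIZED_DIM["FULLY_CONNECTED"]
    print(weights.shape[channel_dim])  # 16 channelwise scales
    block_dim = TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM["FULLY_CONNECTED"]
    print(weights.shape[block_dim] // 32)  # 2 blocks per row at block size 32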
@@ -952,23 +1059,4 @@ def get_bmm_weight_quantized_dim(
   return rank - 1
 
 
-def _get_min_max_from_quant_params(
-    num_bits: int,
-    symmetric: bool,
-    tensor_params: qtyping.UniformQuantParams,
-) -> tuple[float, float]:
-  """Recalculate min/max from tensor quantization params."""
-  q_min, q_max = uniform_quantize_tensor.get_quantized_range(
-      _IntType(num_bits, True)
-  )
-  float_min = uniform_quantize_tensor.uniform_dequantize(
-      np.array(q_min), tensor_params
-  )
-  float_max = uniform_quantize_tensor.uniform_dequantize(
-      np.array(q_max), tensor_params
-  )
-  # We use qmax values to compute scale for symmetric quantization (see
-  # uniform_quantize_tensor.tensor_zp_scale_from_min_max).
-  if symmetric:
-    float_min = -float_max
-  return (float_min, float_max)
+
ai_edge_quantizer/calibrator.py
@@ -23,6 +23,7 @@ from absl import logging
 import numpy as np
 
 from ai_edge_quantizer import algorithm_manager
+from ai_edge_quantizer import default_policy as policy
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer import recipe_manager
 from ai_edge_quantizer.utils import calibration_utils
@@ -45,11 +46,6 @@ class Calibrator:
   ):
     self._flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
 
-    if not tfl_flatbuffer_utils.is_float_model(self._flatbuffer_model):
-      raise ValueError(
-          "The input model for calibration is not a float model. Please check"
-          " the model (e.g., if it is already quantized)."
-      )
     self._tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
         float_tflite, use_xnnpack=True, num_threads=num_threads
     )
@@ -97,9 +93,7 @@ class Calibrator:
       qsv_update_func: The function to update the QSVs.
     """
    op_codes = self._flatbuffer_model.operatorCodes
-    if not self._model_qsvs:
-      self._initialize_model_qsvs(model_recipe_manager)
-    else:
+    if self._model_qsvs:
       logging.warning(
           "Calibrator contains non-empty model qsvs, and the current"
           " calibration process will start on top of this state (i.e., update"
@@ -124,50 +118,67 @@ class Calibrator:
         )
         if cache_output:
           self._cached_output.append(signature_output)
-        self._tensor_content_map.update(
-            tfl_interpreter_utils.get_tensor_name_to_content_map(
-                self._tfl_interpreter, subgraph_idx
-            )
-        )
 
       # Step2: go through each op in subgraph to update quantization
       # statistic values.
-      subgraph = self._flatbuffer_model.subgraphs[subgraph_idx]
-      graph_info = qtyping.GraphInfo(
-          subgraph.tensors, self._flatbuffer_model.buffers
-      )
-      # Add input/output operators to the subgraph.
-      subgraph.operators += (
-          tfl_flatbuffer_utils.get_subgraph_input_output_operators(subgraph)
-      )
-      for op in subgraph.operators:
-        if isinstance(op, qtyping.IOOperator):
-          op_key = op.op_key
-        else:
-          op_code = op_codes[op.opcodeIndex].builtinCode
-          if op_code not in tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME:
-            continue
-          op_key = tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME[op_code]
-        # Step2.1: query the quantization_recipe to get op quantization
-        # settings.
-        op_scope = self._get_op_scope(op, subgraph.tensors)
-        algorithm_name, _ = model_recipe_manager.get_quantization_configs(
-            op_key, op_scope
-        )
-        if algorithm_name == algorithm_manager.AlgorithmName.NO_QUANTIZE:
-          continue
-        # Step2.2: query algorithm_manager to get/call the related calibration
-        # function.
-        calibrate_func = algorithm_manager.get_quantization_func(
-            algorithm_name, op_key, qtyping.QuantizeMode.CALIBRATE
+      subgraphs_inds = [subgraph_idx]
+      while subgraphs_inds:
+        subgraph_ind = subgraphs_inds.pop()
+        self._tensor_content_map.update(
+            tfl_interpreter_utils.get_tensor_name_to_content_map(
+                self._tfl_interpreter, subgraph_ind
+            )
        )
-        op_qsvs = calibrate_func(op, graph_info, self._tensor_content_map)
-        # Step3: Update tensor qsvs with the new values. Ignore the tensor
-        # names that are already updated in this round of calibration.
-        op_updated_tensor_name = self._update_qsvs(
-            op_qsvs, updated_tensor_names, qsv_update_func
+        subgraph = self._flatbuffer_model.subgraphs[subgraph_ind]
+        graph_info = qtyping.GraphInfo(
+            subgraph.tensors, self._flatbuffer_model.buffers
         )
-        updated_tensor_names.update(op_updated_tensor_name)
+        # Add input/output operators if they are not in the subgraph.
+        if not any(
+            isinstance(op, qtyping.IOOperator) for op in subgraph.operators
+        ):
+          subgraph.operators += (
+              tfl_flatbuffer_utils.get_subgraph_input_output_operators(
+                  subgraph
+              )
+          )
+        for op in subgraph.operators:
+          if isinstance(op, qtyping.IOOperator):
+            op_key = op.op_key
+          else:
+            op_code = op_codes[op.opcodeIndex].builtinCode
+            if op_code not in tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME:
+              continue
+            op_key = tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME[op_code]
+          # Step2.1: query the quantization_recipe to get op quantization
+          # settings.
+          op_scope = self._get_op_scope(op, subgraph.tensors)
+          algorithm_name, _ = model_recipe_manager.get_quantization_configs(
+              op_key, op_scope
+          )
+          if algorithm_name == algorithm_manager.AlgorithmName.NO_QUANTIZE:
+            continue
+          if policy.is_non_quantizable_composite_op(op):
+            continue
+
+          # Step2.2: query algorithm_manager to get/call the related
+          # calibration function.
+          calibrate_func = algorithm_manager.get_quantization_func(
+              algorithm_name, op_key, qtyping.QuantizeMode.CALIBRATE
+          )
+          op_qsvs = calibrate_func(op, graph_info, self._tensor_content_map)
+          # Step3: Update tensor qsvs with the new values. Ignore the tensor
+          # names that are already updated in this round of calibration.
+          op_updated_tensor_name = self._update_qsvs(
+              op_qsvs, updated_tensor_names, qsv_update_func
+          )
+          updated_tensor_names.update(op_updated_tensor_name)
+
+          # Step4: Invoke any subgraphs invoked as a side effect of the op.
+          subgraphs_inds.extend(
+              tfl_flatbuffer_utils.get_op_side_effect_subgraphs(op)
+          )
+
       # Reset interpreter after one round of calibration.
       self._tfl_interpreter.reset_all_variables()
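The calibration loop is restructured from a single pass over one subgraph into a worklist: `get_op_side_effect_subgraphs` (new in `tfl_flatbuffer_utils`) reports the subgraphs an op invokes, for example the body subgraphs of control-flow ops, and those indices are pushed so their ops are calibrated in the same round; the newly imported `default_policy` is consulted to skip composite ops that must stay unquantized. A minimal sketch of the traversal pattern only, with hypothetical `visit_op` and `side_effects` stand-ins:

    from typing import Callable, Sequence

    def calibrate_subgraphs(
        root_idx: int,
        subgraph_ops: Sequence[Sequence[object]],
        side_effects: Callable[[object], list[int]],
        visit_op: Callable[[object], None],
    ) -> None:
      """Depth-first worklist over subgraphs reached via op side effects."""
      stack = [root_idx]
      while stack:
        idx = stack.pop()
        for op in subgraph_ops[idx]:
          visit_op(op)  # e.g., run the op's calibration function
          # Queue subgraphs this op invokes (like the diff, this assumes
          # subgraph invocation is acyclic).
          stack.extend(side_effects(op))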
@@ -245,50 +256,3 @@ class Calibrator:
       output_tensor = subgraph_tensors[output_tensor_idx]
       scope += tfl_flatbuffer_utils.get_tensor_name(output_tensor)
     return scope
-
-  # TODO: b/354224138 - Remove code duplication between calibrate and
-  # _initialize_model_qsvs.
-  def _initialize_model_qsvs(
-      self, model_recipe_manager: recipe_manager.RecipeManager
-  ) -> None:
-    """Initialize the model qsvs.
-
-    Args:
-      model_recipe_manager: A RecipeManager object that contains the
-        quantization recipe.
-    """
-    op_codes = self._flatbuffer_model.operatorCodes
-    for subgraph in self._flatbuffer_model.subgraphs:
-      graph_info = qtyping.GraphInfo(
-          subgraph.tensors, self._flatbuffer_model.buffers
-      )
-      for subgraph_op_id, op in enumerate(subgraph.operators):
-        op_code = op_codes[op.opcodeIndex].builtinCode
-        if op_code not in tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME:
-          continue
-        op_key = tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME[op_code]
-        # Step1: query the quantization_recipe to get op quantization
-        # settings.
-        op_scope = self._get_op_scope(op, subgraph.tensors)
-        algorithm_name, op_quant_config = (
-            model_recipe_manager.get_quantization_configs(op_key, op_scope)
-        )
-        if algorithm_name == algorithm_manager.AlgorithmName.NO_QUANTIZE:
-          continue
-        # Step2: query algorithm_manager to get/call the related qsv init
-        # function.
-        qsv_init_func = algorithm_manager.get_init_qsv_func(
-            algorithm_name, op_key
-        )
-        op_info = qtyping.OpInfo(op, op_key, subgraph_op_id, op_quant_config)
-        # Ignore the input tensors where any dimension of the shape is 0.
-        inputs_to_ignore = [
-            opr_idx
-            for opr_idx, tensor_idx in enumerate(op.inputs)
-            if not np.all(graph_info.subgraph_tensors[tensor_idx].shape)
-        ]
-        op_qsvs = qsv_init_func(op_info, graph_info, inputs_to_ignore)
-        # Step3: initialize tensor qsvs.
-        for tensor_name, qsv in op_qsvs.items():
-          if tensor_name not in self._model_qsvs:
-            self._model_qsvs[tensor_name] = qsv
ai_edge_quantizer/calibrator_test.py
@@ -103,58 +103,6 @@ class CalibratorTest(googletest.TestCase):
     model_tensor_qsvs = self._calibrator.get_model_qsvs()
     self.assertEmpty(model_tensor_qsvs)
 
-  def test_calibrator_initialize_qsv(self):
-    _add_default_int8xint8_integer_recipe(self._recipe_manager)
-    # Overwrite the single op to fc
-    self._recipe_manager.add_quantization_config(
-        regex=".*Stateful.*",
-        operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
-        algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
-        op_config=qtyping.OpQuantizationConfig(
-            weight_tensor_config=_TENSOR_QUANT_CONFIG(
-                num_bits=4,
-                granularity=qtyping.QuantGranularity.CHANNELWISE,
-            ),
-            compute_precision=_ComputePrecision.INTEGER,
-        ),
-    )
-    self._calibrator._initialize_model_qsvs(self._recipe_manager)
-    model_tensor_qsvs = self._calibrator.get_model_qsvs()
-
-    self.assertLen(model_tensor_qsvs, 4)
-    self.assertIn("serving_default_input_1:0", model_tensor_qsvs)  # input
-    input_qsv = model_tensor_qsvs["serving_default_input_1:0"]
-    self.assertEmpty(input_qsv)
-
-    self.assertIn("sequential/dense/MatMul", model_tensor_qsvs)  # weight
-    weight_tensor_qsv = model_tensor_qsvs["sequential/dense/MatMul"]
-    mins_maxs_shape = (16, 1)
-    self.assertTupleEqual(weight_tensor_qsv["min"].shape, mins_maxs_shape)
-    self.assertAlmostEqual(weight_tensor_qsv["min"][0][0], -0.40436327)
-    self.assertTupleEqual(weight_tensor_qsv["max"].shape, mins_maxs_shape)
-    self.assertAlmostEqual(weight_tensor_qsv["max"][0][0], 0.46138108)
-
-    self.assertIn(
-        "sequential/dense/BiasAdd/ReadVariableOp", model_tensor_qsvs
-    )  # bias
-    bias_tensor_qsv = model_tensor_qsvs[
-        "sequential/dense/BiasAdd/ReadVariableOp"
-    ]
-    mins_maxs_shape = (16,)
-    self.assertTupleEqual(bias_tensor_qsv["min"].shape, mins_maxs_shape)
-    self.assertAlmostEqual(bias_tensor_qsv["min"][0], -0.26978338)
-    self.assertTupleEqual(bias_tensor_qsv["max"].shape, mins_maxs_shape)
-    # Here bias min/max will be the same as each element is a scalar
-    # Bias will be quantized with input_scale * weight_scale.
-    self.assertSequenceEqual(
-        list(bias_tensor_qsv["max"].flatten()),
-        list(bias_tensor_qsv["min"].flatten()),
-    )
-
-    self.assertIn("StatefulPartitionedCall:0", model_tensor_qsvs)  # output
-    output_qsv = model_tensor_qsvs["StatefulPartitionedCall:0"]
-    self.assertEmpty(output_qsv)
-
   def test_calibrate_single_fc_success(self):
     _add_default_int8xint8_integer_recipe(self._recipe_manager)
     self._calibrator.calibrate(
@@ -162,7 +110,7 @@ class CalibratorTest(googletest.TestCase):
     )
     model_tensor_qsvs = self._calibrator.get_model_qsvs()
 
-    self.assertLen(model_tensor_qsvs, 4)
+    self.assertLen(model_tensor_qsvs, 2)
     self.assertIn("serving_default_input_1:0", model_tensor_qsvs)  # input
     input_qsv = model_tensor_qsvs["serving_default_input_1:0"]
     self.assertSequenceAlmostEqual(
@@ -171,19 +119,6 @@ class CalibratorTest(googletest.TestCase):
     self.assertSequenceAlmostEqual(
         input_qsv["max"].flatten(), [TEST_MAX_VAL], delta=1e-5
     )
-
-    self.assertIn("sequential/dense/MatMul", model_tensor_qsvs)  # weight
-    weight_qsv = model_tensor_qsvs["sequential/dense/MatMul"]
-    self.assertSequenceAlmostEqual(weight_qsv["min"].flatten(), [-0.49114203])
-    self.assertSequenceAlmostEqual(weight_qsv["max"].flatten(), [0.4903704])
-
-    self.assertIn(
-        "sequential/dense/BiasAdd/ReadVariableOp", model_tensor_qsvs
-    )  # bias
-    bias_qsv = model_tensor_qsvs["sequential/dense/BiasAdd/ReadVariableOp"]
-    self.assertSequenceAlmostEqual(bias_qsv["min"].flatten(), [-0.38401994])
-    self.assertSequenceAlmostEqual(bias_qsv["max"].flatten(), [0.31727126])
-
     self.assertIn("StatefulPartitionedCall:0", model_tensor_qsvs)  # output
     output_qsv = model_tensor_qsvs["StatefulPartitionedCall:0"]
     # Relu, only check the min
@@ -234,7 +169,7 @@ class CalibratorTest(googletest.TestCase):
     )
     test_calibrator = calibrator.Calibrator(test_model_path)
     _add_default_int8xint8_integer_recipe(self._recipe_manager)
-    calib_data = test_utils.create_random_normal_input_data(
+    calib_data = tfl_interpreter_utils.create_random_normal_input_data(
         test_model_path, num_samples=4
     )
     test_calibrator.calibrate(calib_data, self._recipe_manager)
@@ -249,15 +184,11 @@ class CalibratorAlreadyQuantizedModelTest(googletest.TestCase):
     )
     _ = calibrator.Calibrator(test_model_path)
 
-  def test_check_is_float_model_raises_error_when_model_is_quantized(self):
+  def test_check_is_quantized_model_succeeds_when_model_is_quantized(self):
     test_model_path = os.path.join(
         TEST_DATA_PREFIX_PATH, "tests/models/mnist_quantized.tflite"
     )
-    with self.assertRaisesRegex(
-        ValueError,
-        "The input model for calibration is not a float model.",
-    ):
-      _ = calibrator.Calibrator(test_model_path)
+    _ = calibrator.Calibrator(test_model_path)
 
 
 class CalibratorToyGemma2Test(googletest.TestCase):
@@ -302,7 +233,7 @@ class CalibratorToyGemma2Test(googletest.TestCase):
         self._toy_gemma2_calibration_dataset,
         model_recipe_manager=recipe_mngr,
     )
-    self.assertLen(calib.get_model_qsvs(), 282)
+    self.assertLen(calib.get_model_qsvs(), 202)
 
 
 if __name__ == "__main__":