ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. ai_edge_quantizer/algorithm_manager.py +224 -0
  2. ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
  3. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
  5. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  6. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
  7. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
  8. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
  13. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
  14. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
  15. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
  16. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
  17. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
  18. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
  19. ai_edge_quantizer/calibrator.py +58 -94
  20. ai_edge_quantizer/calibrator_test.py +5 -74
  21. ai_edge_quantizer/default_policy.py +108 -16
  22. ai_edge_quantizer/model_modifier.py +132 -8
  23. ai_edge_quantizer/model_modifier_test.py +81 -1
  24. ai_edge_quantizer/model_validator.py +38 -10
  25. ai_edge_quantizer/model_validator_test.py +2 -1
  26. ai_edge_quantizer/params_generator.py +230 -47
  27. ai_edge_quantizer/params_generator_test.py +366 -261
  28. ai_edge_quantizer/qtyping.py +92 -6
  29. ai_edge_quantizer/quantizer.py +167 -23
  30. ai_edge_quantizer/quantizer_test.py +288 -26
  31. ai_edge_quantizer/recipe.py +156 -21
  32. ai_edge_quantizer/recipe_manager.py +158 -1
  33. ai_edge_quantizer/recipe_manager_test.py +146 -32
  34. ai_edge_quantizer/recipe_test.py +93 -17
  35. ai_edge_quantizer/transformation_instruction_generator.py +313 -46
  36. ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
  37. ai_edge_quantizer/transformation_performer.py +112 -58
  38. ai_edge_quantizer/transformation_performer_test.py +176 -4
  39. ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
  40. ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
  41. ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
  42. ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
  43. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  44. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  45. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  46. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  47. ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
  48. ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
  49. ai_edge_quantizer/transformations/transformation_utils.py +157 -11
  50. ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
  51. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  52. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  53. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  54. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  55. ai_edge_quantizer/utils/test_utils.py +191 -58
  56. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
  57. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
  58. ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
  59. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  60. ai_edge_quantizer/utils/validation_utils.py +114 -4
  61. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  62. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
  63. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  64. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  65. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  66. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  67. ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
  68. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  69. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
@@ -20,7 +20,7 @@ from collections.abc import MutableMapping
20
20
  import copy
21
21
  import dataclasses
22
22
  import enum
23
- from typing import Any, Optional, Union, Callable
23
+ from typing import Any, Callable, Optional, Union
24
24
 
25
25
  import numpy as np
26
26
  from typing_extensions import TypeAlias
@@ -59,7 +59,31 @@ class TFLOperationName(str, enum.Enum):
59
59
  LOGISTIC = 'LOGISTIC'
60
60
  SLICE = 'SLICE'
61
61
  SUM = 'SUM'
62
+ SELECT = 'SELECT'
62
63
  SELECT_V2 = 'SELECT_V2'
64
+ DYNAMIC_UPDATE_SLICE = 'DYNAMIC_UPDATE_SLICE'
65
+ STABLEHLO_COMPOSITE = 'STABLEHLO_COMPOSITE'
66
+ PAD = 'PAD'
67
+ SQUARED_DIFFERENCE = 'SQUARED_DIFFERENCE'
68
+ MAX_POOL_2D = 'MAX_POOL_2D'
69
+ RESIZE_BILINEAR = 'RESIZE_BILINEAR'
70
+ RESIZE_NEAREST_NEIGHBOR = 'RESIZE_NEAREST_NEIGHBOR'
71
+ GATHER_ND = 'GATHER_ND'
72
+ PACK = 'PACK'
73
+ UNPACK = 'UNPACK'
74
+ DIV = 'DIV'
75
+ BROADCAST_TO = 'BROADCAST_TO'
76
+ SQRT = 'SQRT'
77
+ GATHER = 'GATHER'
78
+ HARD_SWISH = 'HARD_SWISH'
79
+ MAXIMUM = 'MAXIMUM'
80
+ PADV2 = 'PADV2'
81
+ REDUCE_MIN = 'REDUCE_MIN'
82
+ EQUAL = 'EQUAL'
83
+ NOT_EQUAL = 'NOT_EQUAL'
84
+ MIRROR_PAD = 'MIRROR_PAD'
85
+ SPACE_TO_DEPTH = 'SPACE_TO_DEPTH'
86
+ RELU = 'RELU'
63
87
 
64
88
 
65
89
  class QuantizeMode(enum.Enum):
@@ -90,7 +114,11 @@ class TensorDataType(str, enum.Enum):
90
114
  class QuantGranularity(str, enum.Enum):
91
115
  TENSORWISE = 'TENSORWISE'
92
116
  CHANNELWISE = 'CHANNELWISE'
93
- BLOCKWISE = 'BLOCKWISE'
117
+ # Blockwise quantization with various block sizes.
118
+ BLOCKWISE_32 = 'BLOCKWISE_32'
119
+ BLOCKWISE_64 = 'BLOCKWISE_64'
120
+ BLOCKWISE_128 = 'BLOCKWISE_128'
121
+ BLOCKWISE_256 = 'BLOCKWISE_256'
94
122
 
95
123
 
96
124
  class QuantTransformation(enum.Enum):
@@ -104,9 +132,18 @@ class QuantTransformation(enum.Enum):
104
132
  ADD_DEQUANTIZE = 2
105
133
  # Quantize the float tensor: float_tensor -> quantized_tensor.
106
134
  QUANTIZE_TENSOR = 3
107
- # Create pattern for emulated subchannel quantization, only support fully
108
- # connected op.
135
+ # (Deprecated) Create pattern for emulated subchannel quantization,
136
+ # only support fully connected op.
109
137
  EMULATED_SUBCHANNEL = 4
138
+ # Duplicate the buffer.
139
+ DUPLICATE_BUFFER = 5
140
+ # Duplicate the tensor.
141
+ DUPLICATE_TENSOR = 6
142
+ # Insert the aeq.hadamard_rotation op.
143
+ INSERT_HADAMARD_ROTATION = 7
144
+ # Insert decomposed Hadamard rotation ops. This expresses the Hadamard
145
+ # rotation as matrix multiplication with Hadamard matrices.
146
+ INSERT_DECOMPOSED_HADAMARD_ROTATION = 8
110
147
 
111
148
 
112
149
  @dataclasses.dataclass(frozen=True)
@@ -122,8 +159,35 @@ class UniformQuantParams:
122
159
  quantized_data: The quantized data.
123
160
  block_size: The block size for blockwise quantization, block_size=0 meaning
124
161
  no blockwise quantization.
162
+ hadamard: The Hadamard rotation parameters, if set.
125
163
  """
126
164
 
165
+ class HadamardRotationParams:
166
+ """Parameters for the Hadamard rotation.
167
+
168
+ Attributes:
169
+ random_binary_vector: The random binary vector for the Hadamard rotation.
170
+ TODO(b/415392354): Randomization is an experimental feature that's
171
+ currently not implemented yet hence this is always 1. We will add
172
+ support or remove in the future.
173
+ hadamard_size: The size of the Hadamard matrix.
174
+ """
175
+
176
+ random_binary_vector: np.ndarray
177
+ hadamard_size: int
178
+
179
+ def __init__(self, random_binary_vector: np.ndarray, hadamard_size: int):
180
+ self.random_binary_vector = random_binary_vector
181
+ self.hadamard_size = hadamard_size
182
+
183
+ def __eq__(self, other):
184
+ if other.__class__ is not self.__class__:
185
+ return NotImplemented
186
+ return (
187
+ np.array_equal(self.random_binary_vector, other.random_binary_vector)
188
+ and self.hadamard_size == other.hadamard_size
189
+ )
190
+
127
191
  num_bits: int
128
192
  quantized_dimension: Optional[int]
129
193
  scale: np.ndarray
@@ -131,6 +195,7 @@ class UniformQuantParams:
131
195
  symmetric: bool = True
132
196
  quantized_data: Optional[np.ndarray] = None
133
197
  block_size: int = 0
198
+ hadamard: Optional[HadamardRotationParams] = None
134
199
 
135
200
  @classmethod
136
201
  def from_tfl_tensor_details(cls, tensor_detail) -> 'UniformQuantParams':
@@ -161,6 +226,7 @@ class UniformQuantParams:
161
226
  scale=quant_params['scales'],
162
227
  zero_point=quant_params['zero_points'],
163
228
  symmetric=symmetric,
229
+ block_size=quant_params['block_size'],
164
230
  )
165
231
 
166
232
  def __eq__(self, other):
@@ -174,6 +240,7 @@ class UniformQuantParams:
174
240
  and self.symmetric == other.symmetric
175
241
  and _compare_array_or_none(self.quantized_data, other.quantized_data)
176
242
  and self.block_size == other.block_size
243
+ and self.hadamard == other.hadamard
177
244
  )
178
245
 
179
246
 
@@ -249,14 +316,13 @@ class TensorQuantizationConfig:
249
316
  granularity: Whether to perform per-tensor, per-channel or per-block
250
317
  quantization.
251
318
  dtype: The data type of the tensor.
252
- block_size: The block size for blockwise quantization, ignored otherwise.
319
+ algorithm_key: The algorithm key to use for quantization.
253
320
  """
254
321
 
255
322
  num_bits: int
256
323
  symmetric: bool = True
257
324
  granularity: QuantGranularity = QuantGranularity.TENSORWISE
258
325
  dtype: TensorDataType = TensorDataType.INT
259
- block_size: int = 0
260
326
 
261
327
  def to_dict(self) -> dict[str, Any]:
262
328
  """Converts ActivationQuantizationConfig to dict."""
@@ -274,9 +340,28 @@ class TensorQuantizationConfig:
274
340
  def from_dict(cls, params: dict[str, Any]) -> 'TensorQuantizationConfig':
275
341
  """Converts a given dict to TensorQuantizationConfig."""
276
342
  params_copy = copy.deepcopy(params)
343
+ # Process block_size config from legacy recipe.
344
+ params_copy = _process_block_size(params_copy)
277
345
  return cls(**params_copy)
278
346
 
279
347
 
348
+ def _process_block_size(params: dict[str, Any]) -> dict[str, Any]:
349
+ """Processes block size in the params."""
350
+ block_size = params.pop('block_size', 0)
351
+ if block_size > 0:
352
+ if block_size == 32:
353
+ params['granularity'] = QuantGranularity.BLOCKWISE_32
354
+ elif block_size == 64:
355
+ params['granularity'] = QuantGranularity.BLOCKWISE_64
356
+ elif block_size == 128:
357
+ params['granularity'] = QuantGranularity.BLOCKWISE_128
358
+ elif block_size == 256:
359
+ params['granularity'] = QuantGranularity.BLOCKWISE_256
360
+ else:
361
+ raise ValueError(f'Unsupported block size: {block_size}')
362
+ return params
363
+
364
+
280
365
  @dataclasses.dataclass(frozen=True)
281
366
  class OpQuantizationConfig:
282
367
  """Configuration class to control the quantization process behavior.
@@ -486,6 +571,7 @@ class IOOperator:
486
571
  outputs: list[int]
487
572
  op_key: TFLOperationName
488
573
 
574
+
489
575
  # The function signature for `get_tensor_quant_params_fn`.
490
576
  GetTensorQuantParamsFuncSignature = Callable[
491
577
  [
@@ -18,8 +18,10 @@
18
18
  from collections.abc import Iterable
19
19
  import dataclasses
20
20
  import json
21
+ import logging
21
22
  import os
22
23
  from typing import Any, Optional, Union
24
+
23
25
  from ai_edge_quantizer import algorithm_manager
24
26
  from ai_edge_quantizer import calibrator
25
27
  from ai_edge_quantizer import default_policy
@@ -28,11 +30,11 @@ from ai_edge_quantizer import model_validator
28
30
  from ai_edge_quantizer import params_generator
29
31
  from ai_edge_quantizer import qtyping
30
32
  from ai_edge_quantizer import recipe_manager
31
- from ai_edge_quantizer.utils import test_utils
32
33
  from ai_edge_quantizer.utils import tfl_flatbuffer_utils
33
34
  from ai_edge_quantizer.utils import tfl_interpreter_utils
34
35
  from ai_edge_quantizer.utils import validation_utils
35
- from tensorflow.python.platform import gfile # pylint: disable=g-direct-tensorflow-import
36
+ import os # tensorflow.python.platform.gfile # pylint: disable=g-direct-tensorflow-import
37
+
36
38
 
37
39
  # Expose algorithm names to users.
38
40
  AlgorithmName = algorithm_manager.AlgorithmName
@@ -58,50 +60,62 @@ class QuantizationResult:
58
60
  recipe: _QuantRecipe
59
61
  quantized_model: Optional[bytearray]
60
62
 
61
- def save(self, save_folder: str, model_name: str) -> None:
63
+ def save(
64
+ self, save_folder: str, model_name: str, overwrite: bool = False
65
+ ) -> None:
62
66
  """Saves the quantized model and the quantization recipe.
63
67
 
64
68
  Args:
65
69
  save_folder: Path to the folder to save the quantized model and the
66
70
  quantization recipe.
67
71
  model_name: Name of the model.
72
+ overwrite: Whether to overwrite the model if it already exists.
68
73
 
69
74
  Raises:
70
75
  RuntimeError: If no quantized model is available.
71
- FileExistsError: If the model already exists in the folder.
72
76
  """
73
- if self.quantized_model is None:
74
- raise RuntimeError(
75
- 'No quantized model to save. Make sure .quantize() is called.'
76
- )
77
+ if not os.path.exists(save_folder):
78
+ os.makedirs(save_folder)
79
+
77
80
  model_save_path = os.path.join(save_folder, f'{model_name}.tflite')
78
- if gfile.Exists(model_save_path):
79
- raise FileExistsError(
80
- f'The model {model_save_path} already exists in the folder.'
81
- )
82
- with gfile.GFile(model_save_path, 'wb') as output_file_handle:
83
- output_file_handle.write(self.quantized_model)
81
+ self.export_model(model_save_path, overwrite)
84
82
 
85
- recipe = json.dumps(self.recipe)
86
83
  recipe_save_path = os.path.join(save_folder, model_name + '_recipe.json')
87
- with gfile.GFile(recipe_save_path, 'w') as output_file_handle:
84
+ recipe = json.dumps(self.recipe)
85
+ with open(recipe_save_path, 'w') as output_file_handle:
88
86
  output_file_handle.write(recipe)
89
87
 
90
- def export_model(self, filepath: str) -> None:
88
+ def export_model(self, filepath: str, overwrite: bool = False) -> None:
91
89
  """Exports the quantized model to a .tflite flatbuffer.
92
90
 
93
91
  Args:
94
92
  filepath: Path (including file name) that the exported model should be
95
93
  serialized to.
94
+ overwrite: Whether to overwrite the model if it already exists.
96
95
 
97
96
  Raises:
98
97
  RuntimeError: If no quantized model is available.
98
+ ValueError: If the model already exists in the folder and overwrite is
99
+ False.
99
100
  """
100
101
  if self.quantized_model is None:
101
102
  raise RuntimeError(
102
103
  'No quantized model to save. Make sure .quantize() is called.'
103
104
  )
104
- with gfile.GFile(filepath, 'wb') as output_file_handle:
105
+ if os.path.exists(filepath):
106
+ if overwrite:
107
+ logging.warning(
108
+ 'The model %s already exists in the folder. Overwriting the model'
109
+ ' since overwrite=True.',
110
+ filepath,
111
+ )
112
+ else:
113
+ raise ValueError(
114
+ f'The model {filepath} already exists in the folder. Please'
115
+ ' consider change the model name or specify overwrite=True to'
116
+ ' overwrite the model if needed.'
117
+ )
118
+ with open(filepath, 'wb') as output_file_handle:
105
119
  output_file_handle.write(self.quantized_model)
106
120
 
107
121
 
@@ -112,12 +126,16 @@ class Quantizer:
112
126
  float_model: TFLite model file path or bytearray.
113
127
  quantization_recipe: Quantization recipe .json filepath or in loaded json
114
128
  format.
129
+ previous_quantized_model: Optional previously quantized TFLite model file
130
+ path or bytearray. This is useful for validating a quantized model
131
+ without quantizing it again.
115
132
  """
116
133
 
117
134
  def __init__(
118
135
  self,
119
136
  float_model: Union[str, bytearray],
120
137
  quantization_recipe: Optional[Union[str, _QuantRecipe]] = None,
138
+ previous_quantized_model: Optional[Union[str, bytearray]] = None,
121
139
  ):
122
140
  """Initializes the quantizer.
123
141
 
@@ -125,6 +143,9 @@ class Quantizer:
125
143
  float_model: Path to the float tflite model.
126
144
  quantization_recipe: Quantization recipe in .json filepath or loaded json
127
145
  format.
146
+ previous_quantized_model: Path to an optional previously quantized tflite
147
+ model. This is useful for validating a quantized model without
148
+ quantizing it again.
128
149
  """
129
150
  # Use `float model` as bytes for memory efficiency.
130
151
  self.float_model: bytes = (
@@ -132,6 +153,14 @@ class Quantizer:
132
153
  if isinstance(float_model, str)
133
154
  else float_model
134
155
  )
156
+ if previous_quantized_model is not None:
157
+ self.previous_quantized_model: bytes = (
158
+ tfl_flatbuffer_utils.get_model_content(previous_quantized_model)
159
+ if isinstance(previous_quantized_model, str)
160
+ else previous_quantized_model
161
+ )
162
+ else:
163
+ self.previous_quantized_model = None
135
164
 
136
165
  self._recipe_manager: recipe_manager.RecipeManager = (
137
166
  recipe_manager.RecipeManager()
@@ -139,6 +168,7 @@ class Quantizer:
139
168
  if quantization_recipe is not None:
140
169
  self.load_quantization_recipe(quantization_recipe)
141
170
  self._result: QuantizationResult = QuantizationResult([{}], None)
171
+ self._quantize_called = False
142
172
 
143
173
  def load_quantization_recipe(self, recipe: Union[str, _QuantRecipe]) -> None:
144
174
  """Loads a quantization recipe.
@@ -149,7 +179,7 @@ class Quantizer:
149
179
  recipe: Quantization recipe in json format.
150
180
  """
151
181
  if isinstance(recipe, str):
152
- with gfile.Open(recipe) as json_file:
182
+ with open(recipe) as json_file:
153
183
  recipe = json.load(json_file)
154
184
  self._recipe_manager.load_quantization_recipe(recipe)
155
185
 
@@ -161,7 +191,7 @@ class Quantizer:
161
191
  Args:
162
192
  filename: Config policy filename.
163
193
  """
164
- with gfile.Open(filename, 'r') as f:
194
+ with open(filename, 'r') as f:
165
195
  policy = default_policy.update_default_config_policy(f.read())
166
196
 
167
197
  # Register the policy for MIN_MAX_UNIFORM_QUANT algorithm.
@@ -207,6 +237,109 @@ class Quantizer:
207
237
  regex, operation_name, op_config, algorithm_key
208
238
  )
209
239
 
240
+ def add_dynamic_config(
241
+ self,
242
+ regex: str,
243
+ operation_name: _TFLOpName,
244
+ num_bits: int,
245
+ granularity: qtyping.QuantGranularity = qtyping.QuantGranularity.CHANNELWISE,
246
+ algorithm_key: str = algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT,
247
+ ):
248
+ """Adds a dynamic quantization configuration to the recipe.
249
+
250
+ During dynamic quantization, activations are not processed by AEQ and
251
+ remain in float format. The runtime kernel is expected to quantize these
252
+ activations on-the-fly, as indicated by compute_precision=Integer and
253
+ explicit_dequantize=False.
254
+
255
+ The model quality may suffer due to the on-the-fly quantization. If quality
256
+ is a concern, consider using weight-only
257
+ quantization.
258
+
259
+ Args:
260
+ regex: Regular expression for layer name (op's output tensor name)
261
+ matching.
262
+ operation_name: Target TFLite operation.
263
+ num_bits: Number of bits for quantization.
264
+ granularity: Granularity of quantization.
265
+ algorithm_key: Algorithm key to be applied.
266
+ """
267
+ self._recipe_manager.add_dynamic_config(
268
+ regex, operation_name, num_bits, granularity, algorithm_key
269
+ )
270
+
271
+ def add_weight_only_config(
272
+ self,
273
+ regex: str,
274
+ operation_name: _TFLOpName,
275
+ num_bits: int,
276
+ granularity: qtyping.QuantGranularity = qtyping.QuantGranularity.CHANNELWISE,
277
+ algorithm_key: str = algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT,
278
+ ):
279
+ """Adds a weight only quantization configuration to the recipe.
280
+
281
+ In weight-only quantization, weights are quantized, but the actual operation
282
+ (op) computation remains in float. The quantized weight is explicitly
283
+ dequantized before being fed into the op. This is achieved by inserting a
284
+ dequantize op between the quantized weight and the consuming op. To enable
285
+ this, both compute_precision will be set to Float and explicit_dequantize to
286
+ True.
287
+
288
+ Weight-only quantization is useful for reducing model size but may
289
+ not decrease latency due to float computation. However, quantized model
290
+ generally has better quality than other quantization options (e.g., dynamic
291
+ range quantization) due to no loss of precision on activations. If latency
292
+ is a concern, consider using dynamic quantization.
293
+
294
+ Args:
295
+ regex: Regular expression for layer name matching.
296
+ operation_name: Target TFLite operation.
297
+ num_bits: Number of bits for quantization.
298
+ granularity: Granularity of quantization.
299
+ algorithm_key: Algorithm key to be applied.
300
+ """
301
+ self._recipe_manager.add_weight_only_config(
302
+ regex, operation_name, num_bits, granularity, algorithm_key
303
+ )
304
+
305
+ def add_static_config(
306
+ self,
307
+ regex: str,
308
+ operation_name: _TFLOpName,
309
+ activation_num_bits: int,
310
+ weight_num_bits: int,
311
+ weight_granularity: qtyping.QuantGranularity = qtyping.QuantGranularity.CHANNELWISE,
312
+ algorithm_key: str = algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT,
313
+ ):
314
+ """Adds a static quantization configuration to the recipe.
315
+
316
+ In static quantization, both weights and activations are quantized. This
317
+ requires a calibration step to determine the quantization parameters (e.g.,
318
+ min/max ranges) for activations. The quantized model uses integer arithmetic
319
+ for computations, which can lead to significant latency reductions.
320
+
321
+ However, calibration is needed to determine the quantization parameters for
322
+ activations, which requires sample data and may lead to quality loss. If
323
+ there is no hardware requirement for full integer quantization, consider
324
+ using dynamic quantization for simplicity.
325
+
326
+ Args:
327
+ regex: Regular expression for layer name matching.
328
+ operation_name: Target TFLite operation.
329
+ activation_num_bits: Number of bits for activation quantization.
330
+ weight_num_bits: Number of bits for weight quantization.
331
+ weight_granularity: Granularity of weight quantization.
332
+ algorithm_key: Algorithm key to be applied.
333
+ """
334
+ self._recipe_manager.add_static_config(
335
+ regex,
336
+ operation_name,
337
+ activation_num_bits,
338
+ weight_num_bits,
339
+ weight_granularity,
340
+ algorithm_key,
341
+ )
342
+
210
343
  @property
211
344
  def need_calibration(self) -> bool:
212
345
  """Checks if the current recipe needs calibration."""
@@ -282,7 +415,7 @@ class Quantizer:
282
415
  Raises:
283
416
  RuntimeError: If quantization recipe is empty.
284
417
  """
285
-
418
+ self._quantize_called = True
286
419
  if calibration_result is not None:
287
420
  self._ensure_model_qsv_sufficient(calibration_result)
288
421
 
@@ -301,6 +434,7 @@ class Quantizer:
301
434
  error_metrics: str = 'mse',
302
435
  use_xnnpack: bool = True,
303
436
  num_threads: int = 16,
437
+ validate_output_tensors_only: bool = False,
304
438
  ) -> model_validator.ComparisonResult:
305
439
  """Numerical validation of the quantized model for a model signature.
306
440
 
@@ -319,23 +453,33 @@ class Quantizer:
319
453
  error_metrics: Error metrics to be used for comparison.
320
454
  use_xnnpack: Whether to use the xnnpack library for validation.
321
455
  num_threads: Number of threads to use for validation.
456
+ validate_output_tensors_only: If True, only compare output tensors.
457
+ Otherwise, compare all tensors.
322
458
 
323
459
  Returns:
324
460
  The comparison result.
325
461
  """
326
462
  if test_data is None:
327
463
  # Create test data for all signatures in the model.
328
- test_data = test_utils.create_random_normal_input_data(
464
+ test_data = tfl_interpreter_utils.create_random_normal_input_data(
329
465
  self.float_model, num_samples=1
330
466
  )
467
+ if self._quantize_called:
468
+ quantized_model = self._result.quantized_model
469
+ else:
470
+ quantized_model = self.previous_quantized_model
471
+
472
+ if quantized_model is None:
473
+ raise ValueError('No quantized model available to validate.')
331
474
  return model_validator.compare_model(
332
475
  self.float_model,
333
- self._result.quantized_model,
476
+ quantized_model,
334
477
  test_data,
335
478
  error_metrics,
336
479
  validation_utils.get_validation_func(error_metrics),
337
480
  use_xnnpack=use_xnnpack,
338
481
  num_threads=num_threads,
482
+ validate_output_tensors_only=validate_output_tensors_only,
339
483
  )
340
484
 
341
485
  def _get_quantization_params(