ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/algorithm_manager.py +158 -0
  2. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  3. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  5. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
  6. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  7. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  8. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
  12. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
  13. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
  14. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
  15. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
  16. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
  17. ai_edge_quantizer/calibrator.py +11 -60
  18. ai_edge_quantizer/calibrator_test.py +4 -73
  19. ai_edge_quantizer/default_policy.py +61 -26
  20. ai_edge_quantizer/model_modifier.py +97 -7
  21. ai_edge_quantizer/model_modifier_test.py +81 -1
  22. ai_edge_quantizer/model_validator.py +31 -8
  23. ai_edge_quantizer/params_generator.py +17 -10
  24. ai_edge_quantizer/params_generator_test.py +2 -7
  25. ai_edge_quantizer/qtyping.py +86 -6
  26. ai_edge_quantizer/quantizer.py +166 -21
  27. ai_edge_quantizer/quantizer_test.py +284 -16
  28. ai_edge_quantizer/recipe.py +154 -42
  29. ai_edge_quantizer/recipe_manager.py +158 -1
  30. ai_edge_quantizer/recipe_manager_test.py +146 -32
  31. ai_edge_quantizer/recipe_test.py +93 -17
  32. ai_edge_quantizer/transformation_instruction_generator.py +118 -13
  33. ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
  34. ai_edge_quantizer/transformation_performer.py +55 -25
  35. ai_edge_quantizer/transformation_performer_test.py +127 -5
  36. ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
  37. ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
  38. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  39. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  40. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  41. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  42. ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
  43. ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
  44. ai_edge_quantizer/transformations/transformation_utils.py +129 -6
  45. ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
  46. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  47. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  48. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  49. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  50. ai_edge_quantizer/utils/test_utils.py +75 -2
  51. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
  52. ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
  53. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  54. ai_edge_quantizer/utils/validation_utils.py +114 -4
  55. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  56. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
  57. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  58. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  59. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  60. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  61. ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
  62. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  63. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/params_generator_test.py
@@ -1135,16 +1135,11 @@ class ParamsGeneratorAlreadyQuantizedModelTest(googletest.TestCase):
     )
     _ = params_generator.ParamsGenerator(test_model_path)
 
-  def test_check_is_float_model_raises_error_when_model_is_quantized(self):
+  def test_check_is_quantized_model_succeeds_when_model_is_quantized(self):
     test_model_path = os.path.join(
         TEST_DATA_PREFIX_PATH, 'tests/models/mnist_quantized.tflite'
     )
-    with self.assertRaisesRegex(
-        ValueError,
-        'The input model for quantization parameters generation is not a float'
-        ' model.',
-    ):
-      _ = params_generator.ParamsGenerator(test_model_path)
+    _ = params_generator.ParamsGenerator(test_model_path)
 
 
 if __name__ == '__main__':
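This test change reflects a behavior change in ParamsGenerator: an already-quantized input model is now accepted instead of raising ValueError. A minimal sketch, assuming the test model path above:

  from ai_edge_quantizer import params_generator

  # Previously this raised ValueError for non-float models; as of this
  # release it succeeds on a quantized model.
  pg = params_generator.ParamsGenerator('tests/models/mnist_quantized.tflite')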
ai_edge_quantizer/qtyping.py
@@ -20,7 +20,7 @@ from collections.abc import MutableMapping
 import copy
 import dataclasses
 import enum
-from typing import Any, Optional, Union, Callable
+from typing import Any, Callable, Optional, Union
 
 import numpy as np
 from typing_extensions import TypeAlias
@@ -59,9 +59,31 @@ class TFLOperationName(str, enum.Enum):
   LOGISTIC = 'LOGISTIC'
   SLICE = 'SLICE'
   SUM = 'SUM'
+  SELECT = 'SELECT'
   SELECT_V2 = 'SELECT_V2'
   DYNAMIC_UPDATE_SLICE = 'DYNAMIC_UPDATE_SLICE'
   STABLEHLO_COMPOSITE = 'STABLEHLO_COMPOSITE'
+  PAD = 'PAD'
+  SQUARED_DIFFERENCE = 'SQUARED_DIFFERENCE'
+  MAX_POOL_2D = 'MAX_POOL_2D'
+  RESIZE_BILINEAR = 'RESIZE_BILINEAR'
+  RESIZE_NEAREST_NEIGHBOR = 'RESIZE_NEAREST_NEIGHBOR'
+  GATHER_ND = 'GATHER_ND'
+  PACK = 'PACK'
+  UNPACK = 'UNPACK'
+  DIV = 'DIV'
+  BROADCAST_TO = 'BROADCAST_TO'
+  SQRT = 'SQRT'
+  GATHER = 'GATHER'
+  HARD_SWISH = 'HARD_SWISH'
+  MAXIMUM = 'MAXIMUM'
+  PADV2 = 'PADV2'
+  REDUCE_MIN = 'REDUCE_MIN'
+  EQUAL = 'EQUAL'
+  NOT_EQUAL = 'NOT_EQUAL'
+  MIRROR_PAD = 'MIRROR_PAD'
+  SPACE_TO_DEPTH = 'SPACE_TO_DEPTH'
+  RELU = 'RELU'
 
 
 class QuantizeMode(enum.Enum):
@@ -92,7 +114,11 @@ class TensorDataType(str, enum.Enum):
 class QuantGranularity(str, enum.Enum):
   TENSORWISE = 'TENSORWISE'
   CHANNELWISE = 'CHANNELWISE'
-  BLOCKWISE = 'BLOCKWISE'
+  # Blockwise quantization with various block sizes.
+  BLOCKWISE_32 = 'BLOCKWISE_32'
+  BLOCKWISE_64 = 'BLOCKWISE_64'
+  BLOCKWISE_128 = 'BLOCKWISE_128'
+  BLOCKWISE_256 = 'BLOCKWISE_256'
 
 
 class QuantTransformation(enum.Enum):
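With BLOCKWISE split into fixed sizes, the block size is now carried by the granularity value itself rather than by the separate block_size field that a later hunk removes from TensorQuantizationConfig. A minimal sketch of a new-style config:

  from ai_edge_quantizer import qtyping

  # Block size is encoded in the granularity enum member.
  config = qtyping.TensorQuantizationConfig(
      num_bits=4,
      granularity=qtyping.QuantGranularity.BLOCKWISE_128,
  )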
@@ -106,13 +132,18 @@ class QuantTransformation(enum.Enum):
   ADD_DEQUANTIZE = 2
   # Quantize the float tensor: float_tensor -> quantized_tensor.
   QUANTIZE_TENSOR = 3
-  # Create pattern for emulated subchannel quantization, only support fully
-  # connected op.
+  # (Deprecated) Create pattern for emulated subchannel quantization,
+  # only support fully connected op.
   EMULATED_SUBCHANNEL = 4
   # Duplicate the buffer.
   DUPLICATE_BUFFER = 5
   # Duplicate the tensor.
   DUPLICATE_TENSOR = 6
+  # Insert the aeq.hadamard_rotation op.
+  INSERT_HADAMARD_ROTATION = 7
+  # Insert decomposed Hadamard rotation ops. This expresses the Hadamard
+  # rotation as matrix multiplication with Hadamard matrices.
+  INSERT_DECOMPOSED_HADAMARD_ROTATION = 8
 
 
 @dataclasses.dataclass(frozen=True)
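For context on the two new transformations: a Hadamard rotation multiplies a tensor by a scaled Hadamard matrix, which is orthogonal, so a matching rotation on the other operand cancels it while spreading weight outliers across elements before quantization. A self-contained illustration of that identity (not code from the package; sylvester_hadamard is a local helper):

  import numpy as np

  def sylvester_hadamard(n: int) -> np.ndarray:
    """Builds an n x n Hadamard matrix (n must be a power of two)."""
    h = np.array([[1.0]])
    while h.shape[0] < n:
      h = np.block([[h, h], [h, -h]])
    return h

  # H / sqrt(n) is orthogonal, so rotating the weight and applying the
  # transposed rotation to the activation leaves the dot product unchanged.
  n = 4
  h = sylvester_hadamard(n) / np.sqrt(n)
  w = np.random.randn(n)
  x = np.random.randn(n)
  assert np.allclose((w @ h) @ (h.T @ x), w @ x)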
@@ -128,8 +159,35 @@ class UniformQuantParams:
     quantized_data: The quantized data.
     block_size: The block size for blockwise quantization, block_size=0 meaning
       no blockwise quantization.
+    hadamard: The Hadamard rotation parameters, if set.
   """
 
+  class HadamardRotationParams:
+    """Parameters for the Hadamard rotation.
+
+    Attributes:
+      random_binary_vector: The random binary vector for the Hadamard rotation.
+        TODO(b/415392354): Randomization is an experimental feature that's
+        currently not implemented yet hence this is always 1. We will add
+        support or remove in the future.
+      hadamard_size: The size of the Hadamard matrix.
+    """
+
+    random_binary_vector: np.ndarray
+    hadamard_size: int
+
+    def __init__(self, random_binary_vector: np.ndarray, hadamard_size: int):
+      self.random_binary_vector = random_binary_vector
+      self.hadamard_size = hadamard_size
+
+    def __eq__(self, other):
+      if other.__class__ is not self.__class__:
+        return NotImplemented
+      return (
+          np.array_equal(self.random_binary_vector, other.random_binary_vector)
+          and self.hadamard_size == other.hadamard_size
+      )
+
   num_bits: int
   quantized_dimension: Optional[int]
   scale: np.ndarray
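A small usage sketch of the new nested class; the custom __eq__ gives value semantics despite the numpy field, which a default comparison would not provide:

  import numpy as np
  from ai_edge_quantizer import qtyping

  # Per the TODO above, randomization is not implemented yet, so the
  # vector is all ones in practice.
  a = qtyping.UniformQuantParams.HadamardRotationParams(
      random_binary_vector=np.ones(4), hadamard_size=4
  )
  b = qtyping.UniformQuantParams.HadamardRotationParams(
      random_binary_vector=np.ones(4), hadamard_size=4
  )
  assert a == b  # compared element-wise via np.array_equal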
@@ -137,6 +195,7 @@ class UniformQuantParams:
   symmetric: bool = True
   quantized_data: Optional[np.ndarray] = None
   block_size: int = 0
+  hadamard: Optional[HadamardRotationParams] = None
 
   @classmethod
   def from_tfl_tensor_details(cls, tensor_detail) -> 'UniformQuantParams':
@@ -167,6 +226,7 @@ class UniformQuantParams:
         scale=quant_params['scales'],
         zero_point=quant_params['zero_points'],
         symmetric=symmetric,
+        block_size=quant_params['block_size'],
     )
 
   def __eq__(self, other):
@@ -180,6 +240,7 @@ class UniformQuantParams:
         and self.symmetric == other.symmetric
         and _compare_array_or_none(self.quantized_data, other.quantized_data)
         and self.block_size == other.block_size
+        and self.hadamard == other.hadamard
     )
 
@@ -255,14 +316,13 @@ class TensorQuantizationConfig:
     granularity: Whether to perform per-tensor, per-channel or per-block
       quantization.
     dtype: The data type of the tensor.
-    block_size: The block size for blockwise quantization, ignored otherwise.
+    algorithm_key: The algorithm key to use for quantization.
   """
 
   num_bits: int
   symmetric: bool = True
   granularity: QuantGranularity = QuantGranularity.TENSORWISE
   dtype: TensorDataType = TensorDataType.INT
-  block_size: int = 0
 
   def to_dict(self) -> dict[str, Any]:
     """Converts ActivationQuantizationConfig to dict."""
@@ -280,9 +340,28 @@ class TensorQuantizationConfig:
   def from_dict(cls, params: dict[str, Any]) -> 'TensorQuantizationConfig':
     """Converts a given dict to TensorQuantizationConfig."""
     params_copy = copy.deepcopy(params)
+    # Process block_size config from legacy recipe.
+    params_copy = _process_block_size(params_copy)
     return cls(**params_copy)
 
 
+def _process_block_size(params: dict[str, Any]) -> dict[str, Any]:
+  """Processes block size in the params."""
+  block_size = params.pop('block_size', 0)
+  if block_size > 0:
+    if block_size == 32:
+      params['granularity'] = QuantGranularity.BLOCKWISE_32
+    elif block_size == 64:
+      params['granularity'] = QuantGranularity.BLOCKWISE_64
+    elif block_size == 128:
+      params['granularity'] = QuantGranularity.BLOCKWISE_128
+    elif block_size == 256:
+      params['granularity'] = QuantGranularity.BLOCKWISE_256
+    else:
+      raise ValueError(f'Unsupported block size: {block_size}')
+  return params
+
+
 @dataclasses.dataclass(frozen=True)
 class OpQuantizationConfig:
   """Configuration class to control the quantization process behavior.
@@ -492,6 +571,7 @@ class IOOperator:
   outputs: list[int]
   op_key: TFLOperationName
 
+
 # The function signature for `get_tensor_quant_params_fn`.
 GetTensorQuantParamsFuncSignature = Callable[
     [
ai_edge_quantizer/quantizer.py
@@ -18,8 +18,10 @@
 from collections.abc import Iterable
 import dataclasses
 import json
+import logging
 import os
 from typing import Any, Optional, Union
+
 from ai_edge_quantizer import algorithm_manager
 from ai_edge_quantizer import calibrator
 from ai_edge_quantizer import default_policy
@@ -31,7 +33,8 @@ from ai_edge_quantizer import recipe_manager
 from ai_edge_quantizer.utils import tfl_flatbuffer_utils
 from ai_edge_quantizer.utils import tfl_interpreter_utils
 from ai_edge_quantizer.utils import validation_utils
-from tensorflow.python.platform import gfile  # pylint: disable=g-direct-tensorflow-import
+import os  # tensorflow.python.platform.gfile  # pylint: disable=g-direct-tensorflow-import
+
 
 # Expose algorithm names to users.
 AlgorithmName = algorithm_manager.AlgorithmName
@@ -57,50 +60,62 @@ class QuantizationResult:
   recipe: _QuantRecipe
   quantized_model: Optional[bytearray]
 
-  def save(self, save_folder: str, model_name: str) -> None:
+  def save(
+      self, save_folder: str, model_name: str, overwrite: bool = False
+  ) -> None:
     """Saves the quantized model and the quantization recipe.
 
     Args:
       save_folder: Path to the folder to save the quantized model and the
         quantization recipe.
       model_name: Name of the model.
+      overwrite: Whether to overwrite the model if it already exists.
 
     Raises:
       RuntimeError: If no quantized model is available.
-      FileExistsError: If the model already exists in the folder.
     """
-    if self.quantized_model is None:
-      raise RuntimeError(
-          'No quantized model to save. Make sure .quantize() is called.'
-      )
+    if not os.path.exists(save_folder):
+      os.makedirs(save_folder)
+
     model_save_path = os.path.join(save_folder, f'{model_name}.tflite')
-    if gfile.Exists(model_save_path):
-      raise FileExistsError(
-          f'The model {model_save_path} already exists in the folder.'
-      )
-    with gfile.GFile(model_save_path, 'wb') as output_file_handle:
-      output_file_handle.write(self.quantized_model)
+    self.export_model(model_save_path, overwrite)
 
-    recipe = json.dumps(self.recipe)
     recipe_save_path = os.path.join(save_folder, model_name + '_recipe.json')
-    with gfile.GFile(recipe_save_path, 'w') as output_file_handle:
+    recipe = json.dumps(self.recipe)
+    with open(recipe_save_path, 'w') as output_file_handle:
       output_file_handle.write(recipe)
 
-  def export_model(self, filepath: str) -> None:
+  def export_model(self, filepath: str, overwrite: bool = False) -> None:
     """Exports the quantized model to a .tflite flatbuffer.
 
     Args:
       filepath: Path (including file name) that the exported model should be
         serialized to.
+      overwrite: Whether to overwrite the model if it already exists.
 
     Raises:
      RuntimeError: If no quantized model is available.
+      ValueError: If the model already exists in the folder and overwrite is
+        False.
     """
     if self.quantized_model is None:
       raise RuntimeError(
           'No quantized model to save. Make sure .quantize() is called.'
       )
-    with gfile.GFile(filepath, 'wb') as output_file_handle:
+    if os.path.exists(filepath):
+      if overwrite:
+        logging.warning(
+            'The model %s already exists in the folder. Overwriting the model'
+            ' since overwrite=True.',
+            filepath,
+        )
+      else:
+        raise ValueError(
+            f'The model {filepath} already exists in the folder. Please'
+            ' consider change the model name or specify overwrite=True to'
+            ' overwrite the model if needed.'
+        )
+    with open(filepath, 'wb') as output_file_handle:
       output_file_handle.write(self.quantized_model)
 
 
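A usage sketch of the reworked save path (the configured quantizer instance and the paths are assumptions, not shown in this hunk). save() now creates the folder if needed and delegates the overwrite check to export_model():

  result = quantizer.quantize()  # assumes a configured Quantizer
  # Writes my_model.tflite and my_model_recipe.json under /tmp/quantized.
  result.save('/tmp/quantized', 'my_model')
  # Saving again under the same name now raises ValueError unless
  # overwrite=True is passed.
  result.save('/tmp/quantized', 'my_model', overwrite=True)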
@@ -111,12 +126,16 @@ class Quantizer:
     float_model: TFLite model file path or bytearray.
     quantization_recipe: Quantization recipe .json filepath or in loaded json
       format.
+    previous_quantized_model: Optional previously quantized TFLite model file
+      path or bytearray. This is useful for validating a quantized model
+      without quantizing it again.
   """
 
   def __init__(
       self,
       float_model: Union[str, bytearray],
       quantization_recipe: Optional[Union[str, _QuantRecipe]] = None,
+      previous_quantized_model: Optional[Union[str, bytearray]] = None,
   ):
     """Initializes the quantizer.
 
@@ -124,6 +143,9 @@ class Quantizer:
       float_model: Path to the float tflite model.
       quantization_recipe: Quantization recipe in .json filepath or loaded json
         format.
+      previous_quantized_model: Path to an optional previously quantized tflite
+        model. This is useful for validating a quantized model without
+        quantizing it again.
     """
     # Use `float model` as bytes for memory efficiency.
     self.float_model: bytes = (
@@ -131,6 +153,14 @@ class Quantizer:
         if isinstance(float_model, str)
         else float_model
     )
+    if previous_quantized_model is not None:
+      self.previous_quantized_model: bytes = (
+          tfl_flatbuffer_utils.get_model_content(previous_quantized_model)
+          if isinstance(previous_quantized_model, str)
+          else previous_quantized_model
+      )
+    else:
+      self.previous_quantized_model = None
 
     self._recipe_manager: recipe_manager.RecipeManager = (
         recipe_manager.RecipeManager()
@@ -138,6 +168,7 @@ class Quantizer:
     if quantization_recipe is not None:
       self.load_quantization_recipe(quantization_recipe)
     self._result: QuantizationResult = QuantizationResult([{}], None)
+    self._quantize_called = False
 
   def load_quantization_recipe(self, recipe: Union[str, _QuantRecipe]) -> None:
     """Loads a quantization recipe.
@@ -148,7 +179,7 @@ class Quantizer:
       recipe: Quantization recipe in json format.
     """
     if isinstance(recipe, str):
-      with gfile.Open(recipe) as json_file:
+      with open(recipe) as json_file:
         recipe = json.load(json_file)
     self._recipe_manager.load_quantization_recipe(recipe)
 
@@ -160,7 +191,7 @@ class Quantizer:
     Args:
       filename: Config policy filename.
     """
-    with gfile.Open(filename, 'r') as f:
+    with open(filename, 'r') as f:
       policy = default_policy.update_default_config_policy(f.read())
 
     # Register the policy for MIN_MAX_UNIFORM_QUANT algorithm.
@@ -206,6 +237,109 @@ class Quantizer:
         regex, operation_name, op_config, algorithm_key
     )
 
+  def add_dynamic_config(
+      self,
+      regex: str,
+      operation_name: _TFLOpName,
+      num_bits: int,
+      granularity: qtyping.QuantGranularity = qtyping.QuantGranularity.CHANNELWISE,
+      algorithm_key: str = algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+  ):
+    """Adds a dynamic quantization configuration to the recipe.
+
+    During dynamic quantization, activations are not processed by AEQ and
+    remain in float format. The runtime kernel is expected to quantize these
+    activations on-the-fly, as indicated by compute_precision=Integer and
+    explicit_dequantize=False.
+
+    The model quality may suffer due to the on-the-fly quantization. If
+    quality is a concern, consider using weight-only quantization.
+
+    Args:
+      regex: Regular expression for layer name (op's output tensor name)
+        matching.
+      operation_name: Target TFLite operation.
+      num_bits: Number of bits for quantization.
+      granularity: Granularity of quantization.
+      algorithm_key: Algorithm key to be applied.
+    """
+    self._recipe_manager.add_dynamic_config(
+        regex, operation_name, num_bits, granularity, algorithm_key
+    )
+
+  def add_weight_only_config(
+      self,
+      regex: str,
+      operation_name: _TFLOpName,
+      num_bits: int,
+      granularity: qtyping.QuantGranularity = qtyping.QuantGranularity.CHANNELWISE,
+      algorithm_key: str = algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+  ):
+    """Adds a weight only quantization configuration to the recipe.
+
+    In weight-only quantization, weights are quantized, but the actual
+    operation (op) computation remains in float. The quantized weight is
+    explicitly dequantized before being fed into the op. This is achieved by
+    inserting a dequantize op between the quantized weight and the consuming
+    op. To enable this, both compute_precision will be set to Float and
+    explicit_dequantize to True.
+
+    Weight-only quantization is useful for reducing model size but may not
+    decrease latency due to float computation. However, quantized model
+    generally has better quality than other quantization options (e.g.,
+    dynamic range quantization) due to no loss of precision on activations.
+    If latency is a concern, consider using dynamic quantization.
+
+    Args:
+      regex: Regular expression for layer name matching.
+      operation_name: Target TFLite operation.
+      num_bits: Number of bits for quantization.
+      granularity: Granularity of quantization.
+      algorithm_key: Algorithm key to be applied.
+    """
+    self._recipe_manager.add_weight_only_config(
+        regex, operation_name, num_bits, granularity, algorithm_key
+    )
+
+  def add_static_config(
+      self,
+      regex: str,
+      operation_name: _TFLOpName,
+      activation_num_bits: int,
+      weight_num_bits: int,
+      weight_granularity: qtyping.QuantGranularity = qtyping.QuantGranularity.CHANNELWISE,
+      algorithm_key: str = algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT,
+  ):
+    """Adds a static quantization configuration to the recipe.
+
+    In static quantization, both weights and activations are quantized. This
+    requires a calibration step to determine the quantization parameters
+    (e.g., min/max ranges) for activations. The quantized model uses integer
+    arithmetic for computations, which can lead to significant latency
+    reductions.
+
+    However, calibration is needed to determine the quantization parameters
+    for activations, which requires sample data and may lead to quality loss.
+    If there is no hardware requirement for full integer quantization,
+    consider using dynamic quantization for simplicity.
+
+    Args:
+      regex: Regular expression for layer name matching.
+      operation_name: Target TFLite operation.
+      activation_num_bits: Number of bits for activation quantization.
+      weight_num_bits: Number of bits for weight quantization.
+      weight_granularity: Granularity of weight quantization.
+      algorithm_key: Algorithm key to be applied.
+    """
+    self._recipe_manager.add_static_config(
+        regex,
+        operation_name,
+        activation_num_bits,
+        weight_num_bits,
+        weight_granularity,
+        algorithm_key,
+    )
+
   @property
   def need_calibration(self) -> bool:
     """Checks if the current recipe needs calibration."""
@@ -281,7 +415,7 @@ class Quantizer:
     Raises:
       RuntimeError: If quantization recipe is empty.
     """
-
+    self._quantize_called = True
     if calibration_result is not None:
       self._ensure_model_qsv_sufficient(calibration_result)
 
@@ -300,6 +434,7 @@ class Quantizer:
       error_metrics: str = 'mse',
       use_xnnpack: bool = True,
       num_threads: int = 16,
+      validate_output_tensors_only: bool = False,
   ) -> model_validator.ComparisonResult:
     """Numerical validation of the quantized model for a model signature.
 
@@ -318,6 +453,8 @@ class Quantizer:
       error_metrics: Error metrics to be used for comparison.
       use_xnnpack: Whether to use the xnnpack library for validation.
       num_threads: Number of threads to use for validation.
+      validate_output_tensors_only: If True, only compare output tensors.
+        Otherwise, compare all tensors.
 
     Returns:
       The comparison result.
@@ -327,14 +464,22 @@ class Quantizer:
       test_data = tfl_interpreter_utils.create_random_normal_input_data(
           self.float_model, num_samples=1
       )
+    if self._quantize_called:
+      quantized_model = self._result.quantized_model
+    else:
+      quantized_model = self.previous_quantized_model
+
+    if quantized_model is None:
+      raise ValueError('No quantized model available to validate.')
     return model_validator.compare_model(
         self.float_model,
-        self._result.quantized_model,
+        quantized_model,
         test_data,
         error_metrics,
         validation_utils.get_validation_func(error_metrics),
         use_xnnpack=use_xnnpack,
         num_threads=num_threads,
+        validate_output_tensors_only=validate_output_tensors_only,
     )
 
   def _get_quantization_params(
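Putting the validation changes together, a hedged sketch (file names are hypothetical; the method is shown as validate(), the public validation entry point this hunk belongs to): a previously quantized model can now be validated without re-running quantization, and comparison can be limited to output tensors:

  from ai_edge_quantizer import quantizer as quantizer_lib

  qt = quantizer_lib.Quantizer(
      float_model='model_float.tflite',
      previous_quantized_model='model_quantized.tflite',
  )
  # quantize() was never called, so validation falls back to the
  # previously quantized model; with neither available it raises ValueError.
  result = qt.validate(validate_output_tensors_only=True)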