ai-edge-quantizer-nightly 0.4.0.dev20251008__py3-none-any.whl → 0.5.0.dev20251121__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. ai_edge_quantizer/algorithm_manager.py +5 -0
  2. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +49 -25
  3. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +1 -1
  4. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +1 -2
  5. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +5 -3
  6. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +1 -1
  7. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +6 -11
  8. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +18 -14
  9. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +9 -5
  10. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +1 -2
  11. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +40 -13
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +5 -2
  13. ai_edge_quantizer/algorithms/utils/common_utils.py +46 -33
  14. ai_edge_quantizer/calibrator.py +1 -50
  15. ai_edge_quantizer/calibrator_test.py +2 -67
  16. ai_edge_quantizer/default_policy.py +9 -18
  17. ai_edge_quantizer/qtyping.py +25 -3
  18. ai_edge_quantizer/quantizer.py +25 -2
  19. ai_edge_quantizer/quantizer_test.py +56 -6
  20. ai_edge_quantizer/recipe_manager_test.py +0 -6
  21. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +8 -0
  22. ai_edge_quantizer/utils/constrained_ops_utils_test.py +1 -1
  23. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +1 -0
  24. ai_edge_quantizer/utils/validation_utils.py +80 -5
  25. ai_edge_quantizer/utils/validation_utils_test.py +56 -0
  26. {ai_edge_quantizer_nightly-0.4.0.dev20251008.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20251121.dist-info}/METADATA +11 -2
  27. {ai_edge_quantizer_nightly-0.4.0.dev20251008.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20251121.dist-info}/RECORD +30 -30
  28. {ai_edge_quantizer_nightly-0.4.0.dev20251008.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20251121.dist-info}/WHEEL +1 -1
  29. {ai_edge_quantizer_nightly-0.4.0.dev20251008.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20251121.dist-info/licenses}/LICENSE +0 -0
  30. {ai_edge_quantizer_nightly-0.4.0.dev20251008.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20251121.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/calibrator.py
@@ -98,9 +98,7 @@ class Calibrator:
       qsv_update_func: The function to update the QSVs.
     """
     op_codes = self._flatbuffer_model.operatorCodes
-    if not self._model_qsvs:
-      self._initialize_model_qsvs(model_recipe_manager)
-    else:
+    if self._model_qsvs:
       logging.warning(
           "Calibrator contains non-empty model qsvs, and the current"
           " calibration process will start on top of this state (i.e., update"
@@ -263,50 +261,3 @@ class Calibrator:
       output_tensor = subgraph_tensors[output_tensor_idx]
       scope += tfl_flatbuffer_utils.get_tensor_name(output_tensor)
     return scope
-
-  # TODO: b/354224138 - Remove code duplication between calibrate and
-  # _initialize_model_qsvs.
-  def _initialize_model_qsvs(
-      self, model_recipe_manager: recipe_manager.RecipeManager
-  ) -> None:
-    """Initialize the model qsvs.
-
-    Args:
-      model_recipe_manager: A RecipeManager object that contains the
-        quantization recipe.
-    """
-    op_codes = self._flatbuffer_model.operatorCodes
-    for subgraph in self._flatbuffer_model.subgraphs:
-      graph_info = qtyping.GraphInfo(
-          subgraph.tensors, self._flatbuffer_model.buffers
-      )
-      for subgraph_op_id, op in enumerate(subgraph.operators):
-        op_code = op_codes[op.opcodeIndex].builtinCode
-        if op_code not in tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME:
-          continue
-        op_key = tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME[op_code]
-        # Step1: query the quantization_recipe to get op quantization
-        # settings.
-        op_scope = self._get_op_scope(op, subgraph.tensors)
-        algorithm_name, op_quant_config = (
-            model_recipe_manager.get_quantization_configs(op_key, op_scope)
-        )
-        if algorithm_name == algorithm_manager.AlgorithmName.NO_QUANTIZE:
-          continue
-        # Step2: query algorithm_manager to get/call the related qsv init
-        # function.
-        qsv_init_func = algorithm_manager.get_init_qsv_func(
-            algorithm_name, op_key
-        )
-        op_info = qtyping.OpInfo(op, op_key, subgraph_op_id, op_quant_config)
-        # Ignore the input tensors where any dimension of the shape is 0.
-        inputs_to_ignore = [
-            opr_idx
-            for opr_idx, tensor_idx in enumerate(op.inputs)
-            if not np.all(graph_info.subgraph_tensors[tensor_idx].shape)
-        ]
-        op_qsvs = qsv_init_func(op_info, graph_info, inputs_to_ignore)
-        # Step3: initialize tensor qsvs.
-        for tensor_name, qsv in op_qsvs.items():
-          if tensor_name not in self._model_qsvs:
-            self._model_qsvs[tensor_name] = qsv
ai_edge_quantizer/calibrator_test.py
@@ -103,58 +103,6 @@ class CalibratorTest(googletest.TestCase):
     model_tensor_qsvs = self._calibrator.get_model_qsvs()
     self.assertEmpty(model_tensor_qsvs)
 
-  def test_calibrator_initialize_qsv(self):
-    _add_default_int8xint8_integer_recipe(self._recipe_manager)
-    # Overwrite the single op to fc
-    self._recipe_manager.add_quantization_config(
-        regex=".*Stateful.*",
-        operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
-        algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
-        op_config=qtyping.OpQuantizationConfig(
-            weight_tensor_config=_TENSOR_QUANT_CONFIG(
-                num_bits=4,
-                granularity=qtyping.QuantGranularity.CHANNELWISE,
-            ),
-            compute_precision=_ComputePrecision.INTEGER,
-        ),
-    )
-    self._calibrator._initialize_model_qsvs(self._recipe_manager)
-    model_tensor_qsvs = self._calibrator.get_model_qsvs()
-
-    self.assertLen(model_tensor_qsvs, 4)
-    self.assertIn("serving_default_input_1:0", model_tensor_qsvs)  # input
-    input_qsv = model_tensor_qsvs["serving_default_input_1:0"]
-    self.assertEmpty(input_qsv)
-
-    self.assertIn("sequential/dense/MatMul", model_tensor_qsvs)  # weight
-    weight_tensor_qsv = model_tensor_qsvs["sequential/dense/MatMul"]
-    mins_maxs_shape = (16, 1)
-    self.assertTupleEqual(weight_tensor_qsv["min"].shape, mins_maxs_shape)
-    self.assertAlmostEqual(weight_tensor_qsv["min"][0][0], -0.40436327)
-    self.assertTupleEqual(weight_tensor_qsv["max"].shape, mins_maxs_shape)
-    self.assertAlmostEqual(weight_tensor_qsv["max"][0][0], 0.46138108)
-
-    self.assertIn(
-        "sequential/dense/BiasAdd/ReadVariableOp", model_tensor_qsvs
-    )  # bias
-    bias_tensor_qsv = model_tensor_qsvs[
-        "sequential/dense/BiasAdd/ReadVariableOp"
-    ]
-    mins_maxs_shape = (16,)
-    self.assertTupleEqual(bias_tensor_qsv["min"].shape, mins_maxs_shape)
-    self.assertAlmostEqual(bias_tensor_qsv["min"][0], -0.26978338)
-    self.assertTupleEqual(bias_tensor_qsv["max"].shape, mins_maxs_shape)
-    # Here bias min/max will be the same as each element is a scalar
-    # Bias will be quantized with input_scale * weight_scale.
-    self.assertSequenceEqual(
-        list(bias_tensor_qsv["max"].flatten()),
-        list(bias_tensor_qsv["min"].flatten()),
-    )
-
-    self.assertIn("StatefulPartitionedCall:0", model_tensor_qsvs)  # output
-    output_qsv = model_tensor_qsvs["StatefulPartitionedCall:0"]
-    self.assertEmpty(output_qsv)
-
   def test_calibrate_single_fc_success(self):
     _add_default_int8xint8_integer_recipe(self._recipe_manager)
     self._calibrator.calibrate(
@@ -162,7 +110,7 @@ class CalibratorTest(googletest.TestCase):
     )
     model_tensor_qsvs = self._calibrator.get_model_qsvs()
 
-    self.assertLen(model_tensor_qsvs, 4)
+    self.assertLen(model_tensor_qsvs, 2)
     self.assertIn("serving_default_input_1:0", model_tensor_qsvs)  # input
     input_qsv = model_tensor_qsvs["serving_default_input_1:0"]
     self.assertSequenceAlmostEqual(
@@ -171,19 +119,6 @@ class CalibratorTest(googletest.TestCase):
     self.assertSequenceAlmostEqual(
         input_qsv["max"].flatten(), [TEST_MAX_VAL], delta=1e-5
     )
-
-    self.assertIn("sequential/dense/MatMul", model_tensor_qsvs)  # weight
-    weight_qsv = model_tensor_qsvs["sequential/dense/MatMul"]
-    self.assertSequenceAlmostEqual(weight_qsv["min"].flatten(), [-0.49114203])
-    self.assertSequenceAlmostEqual(weight_qsv["max"].flatten(), [0.4903704])
-
-    self.assertIn(
-        "sequential/dense/BiasAdd/ReadVariableOp", model_tensor_qsvs
-    )  # bias
-    bias_qsv = model_tensor_qsvs["sequential/dense/BiasAdd/ReadVariableOp"]
-    self.assertSequenceAlmostEqual(bias_qsv["min"].flatten(), [-0.38401994])
-    self.assertSequenceAlmostEqual(bias_qsv["max"].flatten(), [0.31727126])
-
     self.assertIn("StatefulPartitionedCall:0", model_tensor_qsvs)  # output
     output_qsv = model_tensor_qsvs["StatefulPartitionedCall:0"]
     # Relu, only check the min
@@ -302,7 +237,7 @@ class CalibratorToyGemma2Test(googletest.TestCase):
         self._toy_gemma2_calibration_dataset,
         model_recipe_manager=recipe_mngr,
     )
-    self.assertLen(calib.get_model_qsvs(), 290)
+    self.assertLen(calib.get_model_qsvs(), 202)
 
 
 if __name__ == "__main__":
ai_edge_quantizer/default_policy.py
@@ -61,9 +61,8 @@ DEFAULT_JSON_POLICY = """
       "weight_tensor_config": {
         "num_bits": 4,
         "symmetric": [true],
-        "granularity": ["BLOCKWISE"],
-        "dtype": "INT",
-        "block_size": [32, 64, 96, 128, 256]
+        "granularity": ["BLOCKWISE_32", "BLOCKWISE_64", "BLOCKWISE_128", "BLOCKWISE_256"],
+        "dtype": "INT"
       },
       "explicit_dequantize": false,
       "compute_precision": "INTEGER"
@@ -178,7 +177,6 @@ DEFAULT_JSON_POLICY = """
       "INPUT",
       "OUTPUT",
       "SLICE",
-      "EMBEDDING_LOOKUP",
       "SUM",
       "SELECT",
       "SELECT_V2",
@@ -226,7 +224,6 @@ DEFAULT_JSON_POLICY = """
       "INPUT",
       "OUTPUT",
       "SLICE",
-      "EMBEDDING_LOOKUP",
       "SUM",
       "SELECT",
       "SELECT_V2",
@@ -250,10 +247,11 @@ DEFAULT_JSON_POLICY = """
       "REDUCE_MIN",
       "EQUAL",
       "NOT_EQUAL",
-      "MIRROR_PAD"
+      "MIRROR_PAD",
+      "SPACE_TO_DEPTH"
     ],
-    "static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
-    "static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
+    "static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"],
+    "static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"],
     "dynamic_wi8_afp32": [
       "BATCH_MATMUL",
       "CONV_2D",
@@ -321,16 +319,9 @@ def _unroll_json_config(
           "granularity": granularity,
           "dtype": json_config["weight_tensor_config"]["dtype"],
       }
-      if "block_size" in json_config["weight_tensor_config"]:
-        for block_size in json_config["weight_tensor_config"]["block_size"]:
-          tensor_config["block_size"] = block_size
-          weight_configs.append(
-              qtyping.TensorQuantizationConfig.from_dict(tensor_config)
-          )
-      else:
-        weight_configs.append(
-            qtyping.TensorQuantizationConfig.from_dict(tensor_config)
-        )
+      weight_configs.append(
+          qtyping.TensorQuantizationConfig.from_dict(tensor_config)
+      )
 
   if activation_configs:
     for activation_config in activation_configs:
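Note on the policy change above: the default policy no longer lists a separate "block_size" array for 4-bit blockwise weights; the block size is now part of the granularity string (and block size 96 is no longer advertised). A minimal sketch of what a weight config entry looks like before and after, with illustrative field values rather than a complete recipe:

# Legacy weight config (pre-0.5.0 nightlies): block size carried in its own field.
legacy_weight_config = {
    "num_bits": 4,
    "symmetric": True,
    "granularity": "BLOCKWISE",
    "dtype": "INT",
    "block_size": 64,
}

# New-style weight config: block size folded into the granularity value.
new_weight_config = {
    "num_bits": 4,
    "symmetric": True,
    "granularity": "BLOCKWISE_64",
    "dtype": "INT",
}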
ai_edge_quantizer/qtyping.py
@@ -81,6 +81,7 @@ class TFLOperationName(str, enum.Enum):
   EQUAL = 'EQUAL'
   NOT_EQUAL = 'NOT_EQUAL'
   MIRROR_PAD = 'MIRROR_PAD'
+  SPACE_TO_DEPTH = 'SPACE_TO_DEPTH'
 
 
 class QuantizeMode(enum.Enum):
@@ -111,7 +112,11 @@ class TensorDataType(str, enum.Enum):
 class QuantGranularity(str, enum.Enum):
   TENSORWISE = 'TENSORWISE'
   CHANNELWISE = 'CHANNELWISE'
-  BLOCKWISE = 'BLOCKWISE'
+  # Blockwise quantization with various block sizes.
+  BLOCKWISE_32 = 'BLOCKWISE_32'
+  BLOCKWISE_64 = 'BLOCKWISE_64'
+  BLOCKWISE_128 = 'BLOCKWISE_128'
+  BLOCKWISE_256 = 'BLOCKWISE_256'
 
 
 class QuantTransformation(enum.Enum):
@@ -309,7 +314,6 @@ class TensorQuantizationConfig:
     granularity: Whether to perform per-tensor, per-channel or per-block
       quantization.
     dtype: The data type of the tensor.
-    block_size: The block size for blockwise quantization, ignored otherwise.
     algorithm_key: The algorithm key to use for quantization.
   """
 
@@ -317,7 +321,6 @@ class TensorQuantizationConfig:
   symmetric: bool = True
   granularity: QuantGranularity = QuantGranularity.TENSORWISE
   dtype: TensorDataType = TensorDataType.INT
-  block_size: int = 0
 
   def to_dict(self) -> dict[str, Any]:
     """Converts ActivationQuantizationConfig to dict."""
@@ -335,9 +338,28 @@ class TensorQuantizationConfig:
   def from_dict(cls, params: dict[str, Any]) -> 'TensorQuantizationConfig':
     """Converts a given dict to TensorQuantizationConfig."""
     params_copy = copy.deepcopy(params)
+    # Process block_size config from legacy recipe.
+    params_copy = _process_block_size(params_copy)
     return cls(**params_copy)
 
 
+def _process_block_size(params: dict[str, Any]) -> dict[str, Any]:
+  """Processes block size in the params."""
+  block_size = params.pop('block_size', 0)
+  if block_size > 0:
+    if block_size == 32:
+      params['granularity'] = QuantGranularity.BLOCKWISE_32
+    elif block_size == 64:
+      params['granularity'] = QuantGranularity.BLOCKWISE_64
+    elif block_size == 128:
+      params['granularity'] = QuantGranularity.BLOCKWISE_128
+    elif block_size == 256:
+      params['granularity'] = QuantGranularity.BLOCKWISE_256
+    else:
+      raise ValueError(f'Unsupported block size: {block_size}')
+  return params
+
+
 @dataclasses.dataclass(frozen=True)
 class OpQuantizationConfig:
   """Configuration class to control the quantization process behavior.
ai_edge_quantizer/quantizer.py
@@ -126,12 +126,16 @@ class Quantizer:
     float_model: TFLite model file path or bytearray.
     quantization_recipe: Quantization recipe .json filepath or in loaded json
       format.
+    previous_quantized_model: Optional previously quantized TFLite model file
+      path or bytearray. This is useful for validating a quantized model
+      without quantizing it again.
   """
 
   def __init__(
       self,
       float_model: Union[str, bytearray],
       quantization_recipe: Optional[Union[str, _QuantRecipe]] = None,
+      previous_quantized_model: Optional[Union[str, bytearray]] = None,
   ):
     """Initializes the quantizer.
 
@@ -139,6 +143,9 @@ class Quantizer:
       float_model: Path to the float tflite model.
       quantization_recipe: Quantization recipe in .json filepath or loaded json
         format.
+      previous_quantized_model: Path to an optional previously quantized tflite
+        model. This is useful for validating a quantized model without
+        quantizing it again.
     """
     # Use `float model` as bytes for memory efficiency.
     self.float_model: bytes = (
@@ -146,6 +153,14 @@ class Quantizer:
         if isinstance(float_model, str)
         else float_model
     )
+    if previous_quantized_model is not None:
+      self.previous_quantized_model: bytes = (
+          tfl_flatbuffer_utils.get_model_content(previous_quantized_model)
+          if isinstance(previous_quantized_model, str)
+          else previous_quantized_model
+      )
+    else:
+      self.previous_quantized_model = None
 
     self._recipe_manager: recipe_manager.RecipeManager = (
         recipe_manager.RecipeManager()
@@ -153,6 +168,7 @@ class Quantizer:
     if quantization_recipe is not None:
       self.load_quantization_recipe(quantization_recipe)
     self._result: QuantizationResult = QuantizationResult([{}], None)
+    self._quantize_called = False
 
   def load_quantization_recipe(self, recipe: Union[str, _QuantRecipe]) -> None:
     """Loads a quantization recipe.
@@ -399,7 +415,7 @@ class Quantizer:
     Raises:
       RuntimeError: If quantization recipe is empty.
     """
-
+    self._quantize_called = True
     if calibration_result is not None:
      self._ensure_model_qsv_sufficient(calibration_result)
 
@@ -445,9 +461,16 @@ class Quantizer:
       test_data = tfl_interpreter_utils.create_random_normal_input_data(
           self.float_model, num_samples=1
       )
+    if self._quantize_called:
+      quantized_model = self._result.quantized_model
+    else:
+      quantized_model = self.previous_quantized_model
+
+    if quantized_model is None:
+      raise ValueError('No quantized model available to validate.')
     return model_validator.compare_model(
         self.float_model,
-        self._result.quantized_model,
+        quantized_model,
         test_data,
         error_metrics,
         validation_utils.get_validation_func(error_metrics),
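Note on the quantizer change above: validate() can now run against a model quantized in an earlier session via the new previous_quantized_model constructor argument, instead of requiring quantize() in the same session. A minimal sketch of that flow; the file paths are placeholders, and validate() is called with its defaults as in the new test below:

from ai_edge_quantizer import quantizer

qt = quantizer.Quantizer(
    float_model="model_float.tflite",                   # placeholder path
    previous_quantized_model="model_quantized.tflite",  # placeholder path
)

# quantize() is never called here, so validate() falls back to
# previous_quantized_model; with neither available it raises ValueError.
validation_result = qt.validate()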
ai_edge_quantizer/quantizer_test.py
@@ -212,7 +212,7 @@ class QuantizerTest(parameterized.TestCase):
     # Calibrate with empty state.
     calib_data = _get_calibration_data()
     calibration_result = self._quantizer.calibrate(calib_data)
-    self.assertLen(calibration_result, 13)
+    self.assertLen(calibration_result, 7)
 
   @parameterized.parameters(
       'recipes/default_a8w8_recipe.json',
@@ -227,7 +227,7 @@ class QuantizerTest(parameterized.TestCase):
     updated_calibration_result = self._quantizer.calibrate(
         calib_data, previous_calibration_result=calibration_result
     )
-    self.assertLen(updated_calibration_result, 13)
+    self.assertLen(updated_calibration_result, 7)
     self.assertNotEqual(
         calibration_result['StatefulPartitionedCall:0'],
         updated_calibration_result['StatefulPartitionedCall:0'],
@@ -309,6 +309,44 @@ class QuantizerTest(parameterized.TestCase):
       saved_recipe = json.load(json_file)
     self.assertEqual(saved_recipe, self._test_recipe)
 
+  def test_saved_legacy_recipe_lacks_block_size(self):
+    model_name = 'test_model'
+    legacy_recipe_path = os.path.join(
+        TEST_DATA_PREFIX_PATH,
+        'recipes/dynamic_legacy_wi8_afp32_recipe.json',
+    )
+    self._quantizer.load_quantization_recipe(legacy_recipe_path)
+    result = self._quantizer.quantize()
+    result.save(self._tmp_save_path, model_name)
+    saved_recipe_path = os.path.join(
+        self._tmp_save_path, model_name + '_recipe.json'
+    )
+    with open(saved_recipe_path) as json_file:
+      saved_recipe = json.load(json_file)
+    with open(legacy_recipe_path) as json_file:
+      legacy_recipe = json.load(json_file)
+
+    self.assertNotEqual(saved_recipe, legacy_recipe)
+
+    # Verify that the default test recipe contains 'block_size'.
+    has_block_size = False
+    for config in legacy_recipe:
+      op_config = config.get('op_config')
+      if op_config:
+        weight_config = op_config.get('weight_tensor_config')
+        if weight_config and 'block_size' in weight_config:
+          has_block_size = True
+          break
+    self.assertTrue(has_block_size)
+
+    # Verify that the saved recipe does not have 'block_size'.
+    for config in saved_recipe:
+      op_config = config.get('op_config')
+      if op_config:
+        weight_config = op_config.get('weight_tensor_config')
+        if weight_config:
+          self.assertNotIn('block_size', weight_config)
+
   def test_save_no_quantize_raise_error(self):
     error_message = 'No quantized model to save.'
     with self.assertRaisesWithPredicateMatch(
@@ -337,6 +375,21 @@ class QuantizerTest(parameterized.TestCase):
         'sequential/dense_1/MatMul', validation_result.intermediate_tensors
     )
 
+  def test_validate_with_quantized_model_arg_succeeds(self):
+    self._quantizer.quantize()
+    quantized_model = self._quantizer._result.quantized_model
+    self.assertIsNotNone(quantized_model)
+
+    new_quantizer = quantizer.Quantizer(
+        self._test_model_path, previous_quantized_model=quantized_model
+    )
+    validation_result = new_quantizer.validate()
+    validation_result = validation_result.get_signature_comparison_result()
+    self.assertIsNotNone(validation_result)
+    self.assertIn(
+        'sequential/dense_1/MatMul', validation_result.intermediate_tensors
+    )
+
   def test_load_custom_policies_succeeds(self):
 
     test_op_config = qtyping.OpQuantizationConfig(
@@ -520,14 +573,12 @@ class QuantizerMultiSignatureModelTest(parameterized.TestCase):
             'symmetric': False,
            'granularity': 'TENSORWISE',
            'dtype': 'INT',
-            'block_size': 0,
         },
         'weight_tensor_config': {
            'num_bits': 8,
            'symmetric': True,
            'granularity': 'CHANNELWISE',
            'dtype': 'INT',
-            'block_size': 0,
         },
        'compute_precision': 'INTEGER',
        'explicit_dequantize': False,
@@ -548,8 +599,7 @@ class QuantizerMultiSignatureModelTest(parameterized.TestCase):
 
     # Quantize and expect an error about missing signature in calibration data.
     error_message = (
-        'Missing QSVs (min/max) for tensor multiply_x:0 in Signature'
-        " 'multiply'."
+        'MUL(index: 0) not found in tensor_name_to_qsv'
     )
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
ai_edge_quantizer/recipe_manager_test.py
@@ -569,14 +569,12 @@ class ConfiguratorTest(parameterized.TestCase, googletest.TestCase):
            'symmetric': False,
            'granularity': _QuantGranularity.TENSORWISE,
            'dtype': 'INT',
-            'block_size': 0,
         },
         'weight_tensor_config': {
            'num_bits': 8,
            'symmetric': True,
            'granularity': _QuantGranularity.TENSORWISE,
            'dtype': 'INT',
-            'block_size': 0,
         },
         # WEIGHT_ONLY.
         'compute_precision': _ComputePrecision.INTEGER,
@@ -595,7 +593,6 @@ class ConfiguratorTest(parameterized.TestCase, googletest.TestCase):
            'num_bits': 8,
            'symmetric': True,
            'granularity': _QuantGranularity.TENSORWISE,
-            'block_size': 0,
         },
         # WEIGHT_ONLY.
         'compute_precision': _ComputePrecision.FLOAT,
@@ -614,7 +611,6 @@ class ConfiguratorTest(parameterized.TestCase, googletest.TestCase):
            'num_bits': 4,
            'symmetric': True,
            'granularity': _QuantGranularity.TENSORWISE,
-            'block_size': 0,
         },
         # WEIGHT_ONLY.
         'compute_precision': _ComputePrecision.FLOAT,
@@ -633,7 +629,6 @@ class ConfiguratorTest(parameterized.TestCase, googletest.TestCase):
            'num_bits': 6,
            'symmetric': True,
            'granularity': _QuantGranularity.TENSORWISE,
-            'block_size': 0,
         },
         # WEIGHT_ONLY.
         'compute_precision': _ComputePrecision.FLOAT,
@@ -652,7 +647,6 @@ class ConfiguratorTest(parameterized.TestCase, googletest.TestCase):
            'num_bits': 3,
            'symmetric': True,
            'granularity': _QuantGranularity.TENSORWISE,
-            'block_size': 0,
         },
         # WEIGHT_ONLY.
         'compute_precision': _ComputePrecision.FLOAT,
ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py
@@ -220,6 +220,14 @@ def insert_decomposed_hadamard_rotation(
   fc_op.opcodeIndex = fc_op_code_idx
   fc_op.inputs = [prerotate_reshape_output_tensor_id, hadamard_matrix_tensor_id]
   fc_op.outputs = [fc_output_tensor_id]
+  fc_options = schema_py_generated.FullyConnectedOptionsT()
+  fc_options.fusedActivationFunction = (
+      schema_py_generated.ActivationFunctionType.NONE
+  )
+  fc_op.builtinOptionsType = (
+      schema_py_generated.BuiltinOptions.FullyConnectedOptions
+  )
+  fc_op.builtinOptions = fc_options
 
   # Insert x' = tfl.reshape(x', x.shape)
   post_reshape_op_code_idx = transformation_utils.add_op_code(
ai_edge_quantizer/utils/constrained_ops_utils_test.py
@@ -28,7 +28,7 @@ class ConstrainedOpsUtilsTest(parameterized.TestCase):
       dict(
           testcase_name="same_as_input_scale",
          constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
-          expected_num_ops=16,
+          expected_num_ops=17,
       ),
       dict(
          testcase_name="same_as_output_scale",
ai_edge_quantizer/utils/tfl_flatbuffer_utils.py
@@ -75,6 +75,7 @@ TFL_OP_NAME_TO_CODE = immutabledict.immutabledict({
     _TFLOpName.EQUAL: schema.BuiltinOperator.EQUAL,
     _TFLOpName.NOT_EQUAL: schema.BuiltinOperator.NOT_EQUAL,
     _TFLOpName.MIRROR_PAD: schema.BuiltinOperator.MIRROR_PAD,
+    _TFLOpName.SPACE_TO_DEPTH: schema.BuiltinOperator.SPACE_TO_DEPTH,
 })
 
 TFL_OP_CODE_TO_NAME = immutabledict.immutabledict(
ai_edge_quantizer/utils/validation_utils.py
@@ -32,7 +32,7 @@ def get_validation_func(
     a validation function
 
   Raises:
-    Value error if the function name is not supported
+    ValueError: if the function name is not supported
   """
   if func_name == "mse":
     return mean_squared_difference
@@ -40,6 +40,10 @@ def get_validation_func(
     return median_diff_ratio
   elif func_name == "cosine_similarity":
     return cosine_similarity
+  elif func_name == "kl_divergence":
+    return kl_divergence
+  elif func_name == "snr":
+    return signal_to_noise_ratio
   else:
     raise ValueError(f"Validation function {func_name} not supported")
 
@@ -60,7 +64,7 @@ def mean_squared_difference(
     a float value representing the MSD between data1 & 2
 
   Raises:
-    Value error if the two inputs don't have the same number of elements
+    ValueError: if the two inputs don't have the same number of elements
   """
   data1, data2 = _preprocess_same_size_arrays(data1, data2)
   # special handling for tensor of size 0
@@ -89,7 +93,7 @@ def median_diff_ratio(
     a float value representing the median diff ratio between data1 & 2
 
   Raises:
-    Value error if the two inputs don't have the same number of elements
+    ValueError: if the two inputs don't have the same number of elements
   """
   data1, data2 = _preprocess_same_size_arrays(data1, data2)
   # special handling for tensor of size 0
@@ -118,7 +122,7 @@ def cosine_similarity(
     a float value representing the cosine similarity between data1 & 2
 
   Raises:
-    Value error if the two inputs don't have the same number of elements
+    ValueError: if the two inputs don't have the same number of elements
   """
   data1, data2 = _preprocess_same_size_arrays(data1, data2)
   # special handling for tensor of size 0
@@ -134,6 +138,77 @@ def cosine_similarity(
   return np.dot(data1, data2) / (norm_data1 * norm_data2)
 
 
+def kl_divergence(
+    data1: np._typing.ArrayLike,
+    data2: np._typing.ArrayLike,
+    epsilon: float = 1e-9,
+) -> float:
+  """Calculates the KL divergence between data1 & data2.
+
+  KL(data2 || data1) = sum(data2 * log(data2 / data1)).
+  data2 is treated as the true distribution P, and data1 as the
+  approximated distribution Q.
+  Non-positive values in data1 and data2 are clipped to 0 before
+  KL divergence calculation. Epsilon is added to avoid log(0) and
+  division by zero.
+
+  Args:
+    data1: input data to be used for comparison (distribution Q)
+    data2: input data to be used for comparison (distribution P), data1 & 2 must
+      be of the same shape
+    epsilon: small value to avoid log(0) and division by zero.
+
+  Returns:
+    A float value representing the KL divergence between data1 & 2.
+
+  Raises:
+    ValueError: if the two inputs don't have the same number of elements.
+  """
+  data1, data2 = _preprocess_same_size_arrays(data1, data2)
+  # special handling for tensor of size 0
+  if data1.size == 0:
+    return float(0)
+
+  p = np.maximum(0, data2)
+  q = np.maximum(0, data1)
+
+  return float(np.sum(p * np.log((p + epsilon) / (q + epsilon))))
+
+
+def signal_to_noise_ratio(
+    noisy_signal: np._typing.ArrayLike,
+    signal: np._typing.ArrayLike,
+    epsilon: float = 1e-9,
+) -> float:
+  """Calculates the signal to noise ratio between noisy_signal & signal.
+
+  SNR = P_signal / P_noise, where signal is treated as the clean signal and
+  noisy_signal-signal is treated as the noise samples.
+  P_signal = mean(signal^2)
+  P_noise = mean((noisy_signal-signal)^2) = mse(noisy_signal, signal)
+
+  Args:
+    noisy_signal: Input data to be used for comparison (e.g. noisy signal).
+    signal: Input data to be used for comparison (e.g. clean signal),
+      noisy_signal & signal must be of the same shape.
+    epsilon: Small value to avoid division by zero.
+
+  Returns:
+    A float value representing the SNR between noisy_signal & signal.
+
+  Raises:
+    ValueError: If the two inputs don't have the same number of elements.
+  """
+  noisy_signal, signal = _preprocess_same_size_arrays(noisy_signal, signal)
+  if signal.size == 0:
+    return float(0)
+
+  mse = mean_squared_difference(noisy_signal, signal)
+  signal_power = float(np.square(signal).mean())
+  snr = signal_power / (mse + epsilon)
+  return snr
+
+
 def _preprocess_same_size_arrays(
     data1: np._typing.ArrayLike, data2: np._typing.ArrayLike
 ) -> Tuple[np.ndarray, np.ndarray]:
@@ -148,7 +223,7 @@ def _preprocess_same_size_arrays(
     a tuple of the preprocessed data1 & 2
 
   Raises:
-    Value error if the two inputs don't have the same number of elements
+    ValueError: if the two inputs don't have the same number of elements
   """
   data1 = np.array(data1, dtype=np.float32).flatten()
   data2 = np.array(data2, dtype=np.float32).flatten()
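Note on the validation_utils additions above: two new error metrics, kl_divergence and signal_to_noise_ratio, are exposed both directly and through get_validation_func (as "kl_divergence" and "snr"). A small usage sketch based on the signatures in the hunk; the sample arrays are illustrative:

import numpy as np
from ai_edge_quantizer.utils import validation_utils

reference = np.array([0.1, 0.2, 0.7], dtype=np.float32)  # e.g. float model output
approx = np.array([0.12, 0.18, 0.70], dtype=np.float32)  # e.g. quantized output

# KL(reference || approx): the second argument is treated as the true
# distribution P, the first as the approximation Q.
kl = validation_utils.kl_divergence(approx, reference)

# SNR: first argument is the noisy signal, second the clean reference.
snr = validation_utils.signal_to_noise_ratio(approx, reference)

# Both metrics are also reachable by name.
snr_fn = validation_utils.get_validation_func("snr")
assert snr_fn(approx, reference) == snr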