ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/algorithm_manager.py +158 -0
  2. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  3. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  5. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
  6. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  7. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  8. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
  12. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
  13. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
  14. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
  15. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
  16. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
  17. ai_edge_quantizer/calibrator.py +11 -60
  18. ai_edge_quantizer/calibrator_test.py +4 -73
  19. ai_edge_quantizer/default_policy.py +61 -26
  20. ai_edge_quantizer/model_modifier.py +97 -7
  21. ai_edge_quantizer/model_modifier_test.py +81 -1
  22. ai_edge_quantizer/model_validator.py +31 -8
  23. ai_edge_quantizer/params_generator.py +17 -10
  24. ai_edge_quantizer/params_generator_test.py +2 -7
  25. ai_edge_quantizer/qtyping.py +86 -6
  26. ai_edge_quantizer/quantizer.py +166 -21
  27. ai_edge_quantizer/quantizer_test.py +284 -16
  28. ai_edge_quantizer/recipe.py +154 -42
  29. ai_edge_quantizer/recipe_manager.py +158 -1
  30. ai_edge_quantizer/recipe_manager_test.py +146 -32
  31. ai_edge_quantizer/recipe_test.py +93 -17
  32. ai_edge_quantizer/transformation_instruction_generator.py +118 -13
  33. ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
  34. ai_edge_quantizer/transformation_performer.py +55 -25
  35. ai_edge_quantizer/transformation_performer_test.py +127 -5
  36. ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
  37. ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
  38. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  39. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  40. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  41. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  42. ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
  43. ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
  44. ai_edge_quantizer/transformations/transformation_utils.py +129 -6
  45. ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
  46. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  47. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  48. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  49. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  50. ai_edge_quantizer/utils/test_utils.py +75 -2
  51. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
  52. ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
  53. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  54. ai_edge_quantizer/utils/validation_utils.py +114 -4
  55. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  56. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
  57. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  58. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  59. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  60. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  61. ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
  62. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  63. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py

@@ -15,8 +15,11 @@

 """Tests for tensor_utils."""

+import dataclasses
+
 from absl.testing import parameterized
 import numpy as np
+
 from tensorflow.python.platform import googletest
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor
@@ -123,6 +126,14 @@ class TensorUtilsTest(parameterized.TestCase):
           False,
           [-24, 10, 19, 127],
       ),
+      (
+          [-16.0, 1.3, 2.4, 16.0],
+          [0.12598425],
+          [0],
+          8,
+          True,
+          [-127, 10, 19, 127],  # int8 symmetric is narrow range, -127 to 127
+      ),
       (
           [-3.0, 1.3, 2.4, 16.0],
           [1.2666667],
@@ -137,7 +148,7 @@ class TensorUtilsTest(parameterized.TestCase):
           [-6],
           4,
           True,
-          [-7, -5, -4, 7],
+          [-8, -5, -4, 7],  # int4 symmetric is not narrow range, -8 to 7
       ),
   )
   def test_uniform_quantize(
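
Note: the two comments added above pin down the clamp ranges: int8 symmetric quantization is narrow range (values clamp to [-127, 127], so -128 is never produced), while int4 symmetric keeps the full [-8, 7] range. A minimal toy quantizer, not the package's uniform_quantize implementation, reproduces both expected rows from the test parameters:

import numpy as np

def toy_quantize(x, scale, zero_point, num_bits, narrow_range):
  # Uniform quantization: q = round(x / scale) + zp, clamped to the int range.
  q_max = 2 ** (num_bits - 1) - 1                 # 127 for int8, 7 for int4
  q_min = -q_max if narrow_range else -q_max - 1  # -127 vs. -128 for int8
  return np.clip(np.round(x / scale) + zero_point, q_min, q_max).astype(int)

print(toy_quantize(np.array([-16.0, 1.3, 2.4, 16.0]), 0.12598425, 0, 8, True))
# [-127   10   19  127]
print(toy_quantize(np.array([-3.0, 1.3, 2.4, 16.0]), 1.2666667, -6, 4, False))
# [-8 -5 -4  7]  (with narrow_range=True this clamps to the old expectation, [-7, -5, -4, 7])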
@@ -160,7 +171,9 @@ class TensorUtilsTest(parameterized.TestCase):
   def test_uniform_quantize_wrong_shape(self):
     tensor = [-3.0, 1.3, 2.4, 16.0]

-    error_message = "scale and zero_point must have the same shape."
+    error_message = (
+        "Ranks of scales (3) and zps (2) must be the same as the tensor rank"
+    )
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
     ):
@@ -190,6 +203,28 @@ class TensorUtilsTest(parameterized.TestCase):
       ),
   )

+  def test_uniform_quantize_quant_dim_not_divisible_by_block_size_raise(self):
+    tensor = np.random.rand(34, 2)
+    error_message = (
+        "Quantized dimension 34 in tensor shape (34, 2) is not divisible by"
+        " block size 32."
+    )
+    with self.assertRaisesWithPredicateMatch(
+        ValueError, lambda err: error_message in str(err)
+    ):
+      uniform_quantize_tensor.uniform_quantize(
+          np.array(tensor),
+          qtyping.UniformQuantParams(
+              quantized_dimension=0,
+              block_size=32,
+              num_bits=4,
+              scale=np.array([1.2666667]),
+              zero_point=np.array([-6]),
+              symmetric=True,
+          ),
+          is_blockwise_quant=True,
+      )
+
   @parameterized.parameters(
       (
           8,
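
Note: the error message spells out the precondition this new test exercises: with blockwise quantization, each block of block_size elements along the quantized dimension shares one scale, so that dimension must split evenly into blocks. A hedged sketch of the check (the real one lives inside uniform_quantize_tensor.uniform_quantize):

def check_block_divisibility(shape, quantized_dimension, block_size):
  # 34 rows cannot be covered by whole blocks of 32, hence the ValueError.
  dim = shape[quantized_dimension]
  if dim % block_size != 0:
    raise ValueError(
        f"Quantized dimension {dim} in tensor shape {shape} is not"
        f" divisible by block size {block_size}."
    )

check_block_divisibility((64, 2), 0, 32)  # OK: two whole blocks of 32
check_block_divisibility((34, 2), 0, 32)  # raises ValueError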
@@ -233,7 +268,9 @@ class TensorUtilsTest(parameterized.TestCase):
   def test_uniform_dequantize_wrong_shape(self):
     tensor = [-3.0, 1.3, 2.4, 16.0]

-    error_message = "scale and zero_point must have the same shape."
+    error_message = (
+        "Ranks of scales (3) and zps (2) must be the same as the tensor rank"
+    )
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
     ):
@@ -263,8 +300,35 @@ class TensorUtilsTest(parameterized.TestCase):
       ),
   )

+  def test_uniform_dequantize_blockwise(self):
+    quantized_tensor = np.array([[-8, -5, -4, 7], [-4, 7, -8, -5]])
+    expected_output_tensor = np.array([
+        [-10.1333336, -6.3333335, -5.0666668, 8.8666669],
+        [-5.0666668, 8.8666669, -10.1333336, -6.3333335],
+    ])
+    quant_params = qtyping.UniformQuantParams(
+        # b/443830202:
+        quantized_dimension=0,
+        num_bits=4,
+        scale=np.array([[[1.2666667, 1.2666667], [1.2666667, 1.2666667]]]),
+        zero_point=np.array([[0]]),
+        symmetric=True,
+        block_size=2,
+    )
+
+    dequantized_tensor = uniform_quantize_tensor.uniform_dequantize(
+        np.array(quantized_tensor), quant_params
+    )
+
+    self.assertSequenceAlmostEqual(
+        expected_output_tensor.flatten(), dequantized_tensor.flatten(), places=4
+    )
+
   @parameterized.parameters(
-      (8, 8, True, True), (8, 4, False, True), (16, 8, True, False)
+      (8, 8, True, True),
+      (8, 4, False, True),
+      (16, 8, True, False),
+      (16, 8, True, True),
   )
   def test_quantize_bias_tensor(
       self,
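
Note: the expected values in test_uniform_dequantize_blockwise are plain per-block arithmetic: every value in a block is multiplied by that block's scale (here all scales are 1.2666667, so -8 -> -10.1333336, 7 -> 8.8666669, and so on). A toy sketch, independent of the package's uniform_dequantize and its scale-shape conventions:

import numpy as np

def toy_dequantize_blockwise(q, block_scales, block_size, axis):
  # Expand one scale per block into one scale per element, then multiply.
  scales = np.repeat(block_scales, block_size, axis=axis)
  return q.astype(np.float64) * scales

q = np.array([[-8, -5, -4, 7], [-4, 7, -8, -5]])
block_scales = np.full((2, 2), 1.2666667)  # one scale per 2-element block
print(toy_dequantize_blockwise(q, block_scales, block_size=2, axis=1))
# [[-10.1333336  -6.3333335  -5.0666668   8.8666669]
#  [ -5.0666668   8.8666669 -10.1333336  -6.3333335]]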
@@ -322,6 +386,26 @@ class TensorUtilsTest(parameterized.TestCase):
     self.assertSequenceAlmostEqual(
         list(dequantized_bias.flatten()), list(bias_tensor_data), places=5
     )
+
+    if activation_num_bits == 16:
+      # Check if it is safe to cast int64 bias to int32. We save the int32
+      # quantized bias as int64 if the input tensor is quantized to 16 bits.
+      # This is to assume the matmul is using int64 accumulator (safe from
+      # overflow). For accelerators with int32 accumulator, it is safe to cast
+      # int64 back to int32.
+      quantized_bias = bias_quant_config.quantized_data
+      self.assertIsNotNone(quantized_bias)
+      self.assertEqual(quantized_bias.dtype, np.int64)
+      self.assertSequenceEqual(
+          list(quantized_bias.flatten()),
+          list(quantized_bias.astype(np.int32).flatten()),
+      )
+
+      bias_quant_config = dataclasses.replace(
+          bias_quant_config,
+          num_bits=32,
+      )
+
     expected_quantized_data = uniform_quantize_tensor.uniform_quantize(
         bias_tensor_data, bias_quant_config
     )
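
Note: the assertSequenceEqual above is a lossless-cast check: an int64 bias survives a round trip through int32 exactly when every value fits in int32. The same predicate as a standalone sketch (fits_in_int32 is a hypothetical helper, not a package API):

import numpy as np

def fits_in_int32(bias):
  # int64 -> int32 -> int64 is the identity iff no value overflows int32.
  return np.array_equal(bias, bias.astype(np.int32).astype(np.int64))

print(fits_in_int32(np.array([1_000, -2_000_000_000], dtype=np.int64)))  # True
print(fits_in_int32(np.array([3_000_000_000], dtype=np.int64)))          # False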
@@ -330,13 +414,44 @@ class TensorUtilsTest(parameterized.TestCase):
         list(bias_quant_config.quantized_data.flatten()),  # pytype: disable=attribute-error
     )

+  def test_quantize_bias_tensor_raises_error_for_large_quantization_error(self):
+    input_quant_config = qtyping.UniformQuantParams(
+        scale=np.array([0.1]),
+        zero_point=np.array([10]),
+        num_bits=8,
+        symmetric=False,
+        quantized_dimension=None,
+    )
+    weight_quant_config = qtyping.UniformQuantParams(
+        scale=np.array([0.1]),
+        zero_point=np.array([-1]),
+        num_bits=8,
+        symmetric=True,
+        quantized_dimension=None,
+    )
+    # This will result in quantized bias of 3e9, which is larger than int32 max.
+    bias_tensor_data = np.array([3e7])
+    with self.assertRaisesRegex(
+        ValueError,
+        "Quantization error is too large for bias tensor quantization.",
+    ):
+      uniform_quantize_tensor.symmetric_quantize_bias_tensor(
+          bias_tensor_data,
+          input_quant_config,
+          weight_quant_config,
+      )
+
   @parameterized.parameters((8, True), (16, False))
   def test_tensor_zp_scale_from_min_max(self, num_bits, symmetric):
     min_val = np.min(self._test_data, keepdims=True)
     max_val = np.max(self._test_data, keepdims=True)

     zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
-        min_val, max_val, num_bits, symmetric
+        min_val,
+        max_val,
+        num_bits,
+        symmetric,
+        qtyping.QuantGranularity.TENSORWISE,
     )
     self.assertEqual(zp.shape, scale.shape)
     max_q = 2**num_bits / 2 - 1
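
Note: the magic numbers in the new overflow test check out under the usual integer-matmul convention that the bias scale is the product of the input and weight scales (assumed here; the hunk does not show the formula): 3e7 / (0.1 * 0.1) = 3e9, which exceeds the int32 maximum of 2,147,483,647, so the quantizer has to give up rather than silently overflow.

import numpy as np

input_scale, weight_scale = 0.1, 0.1
bias_scale = input_scale * weight_scale          # 0.01
quantized_bias = 3e7 / bias_scale                # 3e9
print(quantized_bias > np.iinfo(np.int32).max)   # True -> ValueError raised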
@@ -364,7 +479,12 @@ class TensorUtilsTest(parameterized.TestCase):
     max_val = np.array([[5.0]])
     clipping_values = np.array([4.0])
     zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
-        min_val, max_val, num_bits, symmetric, clipping_values
+        min_val,
+        max_val,
+        num_bits,
+        symmetric,
+        qtyping.QuantGranularity.TENSORWISE,
+        clipping_values,
     )
     expected_scale = clipping_values / quantized_bound

ai_edge_quantizer/algorithms/utils/common_utils.py

@@ -41,6 +41,7 @@ _DRQ_OR_WEIGHT_ONLY_OPS = frozenset([

 _SUPPORTED_SUBCHANNEL_OPS = frozenset([
     _TFLOpName.FULLY_CONNECTED,
+    _TFLOpName.EMBEDDING_LOOKUP,
 ])


@@ -50,8 +51,9 @@ def check_subchannel_config(
   """Checks the op quantization config for subchannel quantization."""
   if (
       op_quant_config.weight_tensor_config is not None
-      and op_quant_config.weight_tensor_config.granularity
-      == qtyping.QuantGranularity.BLOCKWISE
+      and uniform_quantize_tensor.is_blockwise(
+          op_quant_config.weight_tensor_config.granularity
+      )
   ):
     if op_name not in _SUPPORTED_SUBCHANNEL_OPS:
       raise ValueError(f"Unsupported op for blockwise quantization: {op_name}.")
@@ -65,10 +67,6 @@
           "Blockwise quantization does not support for asymmetric weight"
           " quantization."
       )
-    if op_quant_config.weight_tensor_config.block_size <= 0:
-      raise ValueError(
-          "Blockwise quantization must have a non-zero block size."
-      )


 def check_if_valid_op_config(
@@ -259,6 +257,60 @@ def _get_single_tensor_params(
   )


+def _materialize_tensors_with_quantized_data_update(
+    op_tensor_params: list[qtyping.TensorTransformationParams],
+    tensors: Sequence[Any],
+    quant_params: Optional[qtyping.UniformQuantParams],
+    is_inbounding_tensor: bool,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+) -> None:
+  """Materialize a list of tensors with `quantized_data` updated when needed.
+
+  Args:
+    op_tensor_params: Tensor transformation parameters for the op. Will be
+      modified to include new tensor parameters.
+    tensors: Tensors to be materialized.
+    quant_params: The quantization parameters to be used for materialization.
+    is_inbounding_tensor: Whether the tensor is an inbounding tensor for the op.
+    op_info: Aggregated information about the op (e.g., quantization config).
+    graph_info: Graph information needed to perform quantization for the op.
+    tensor_name_to_qsv: A map of tensor name to quantization parameters.
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
+  """
+  if quant_params is not None and quant_params.quantized_data is not None:
+    quant_params = dataclasses.replace(quant_params, quantized_data=None)
+
+  for tensor in tensors:
+    tensor_data = tfl_flatbuffer_utils.get_tensor_data(
+        tensor, graph_info.buffers
+    )
+    if quant_params is None or tensor_data is None:
+      tensor_quant_params = quant_params
+    else:
+      # Constant tensors require updating `quantized_data`.
+      quantized_data = uniform_quantize_tensor.uniform_quantize(
+          tensor_data, quant_params
+      )
+      tensor_quant_params = dataclasses.replace(
+          quant_params,
+          quantized_data=quantized_data,
+      )
+    _materialize_op_tensors(
+        op_tensor_params,
+        [tensor],
+        is_inbounding_tensor=is_inbounding_tensor,
+        op_info=op_info,
+        graph_info=graph_info,
+        tensor_name_to_qsv=tensor_name_to_qsv,
+        get_tensor_quant_params_fn=get_tensor_quant_params_fn,
+        quant_params=tensor_quant_params,
+    )
+
+
 def _materialize_standard_op_with_same_as_input_scale(
     input_tensors: Sequence[Any],
     output_tensors: Sequence[Any],
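
Note: the heart of this new helper is the dataclasses.replace pattern: one shared set of quant params is fanned out to several tensors, any stale quantized_data is dropped first, and only constant tensors (those with data) get a freshly quantized payload. A stripped-down sketch with a stand-in dataclass (QuantParams below is hypothetical, not qtyping.UniformQuantParams):

import dataclasses
from typing import Optional
import numpy as np

@dataclasses.dataclass(frozen=True)
class QuantParams:  # stand-in for qtyping.UniformQuantParams
  scale: np.ndarray
  zero_point: np.ndarray
  quantized_data: Optional[np.ndarray] = None

def params_for_tensor(shared, tensor_data):
  # Never carry another tensor's quantized_data over.
  shared = dataclasses.replace(shared, quantized_data=None)
  if tensor_data is None:  # activation tensor: nothing to pre-quantize
    return shared
  q = np.round(tensor_data / shared.scale) + shared.zero_point
  return dataclasses.replace(shared, quantized_data=q.astype(np.int8))

shared = QuantParams(scale=np.array([0.5]), zero_point=np.array([0]))
print(params_for_tensor(shared, np.array([1.0, -2.0])).quantized_data)  # [ 2 -4]
print(params_for_tensor(shared, None).quantized_data)                   # None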
@@ -294,23 +346,48 @@
   )
   op_tensor_params.append(input_tensor_params)
   # Use input quantization params for all output tensors.
-  _materialize_op_tensors(
+  input_quant_params = input_tensor_params.consumers[0].parameters
+  if not isinstance(input_quant_params, qtyping.UniformQuantParams):
+    raise ValueError(
+        "_materialize_standard_op_with_same_as_input_scale only supports"
+        f" UniformQuantParams. For tensor {input_tensor_params.tensor_name},"
+        f" got {type(input_quant_params)}"
+    )
+  _materialize_tensors_with_quantized_data_update(
       op_tensor_params,
       output_tensors,
+      input_quant_params,
       is_inbounding_tensor=False,
       op_info=op_info,
       graph_info=graph_info,
       tensor_name_to_qsv=tensor_name_to_qsv,
       get_tensor_quant_params_fn=get_tensor_quant_params_fn,
-      quant_params=input_tensor_params.consumers[0].parameters,
   )
+
   # Change output qsv to be the same as input qsv. This is safe since TFL
   # subgraph is acyclic.
-  input_tensor_qsv = tensor_name_to_qsv[input_tensor_params.tensor_name]
-  for output_tensor in output_tensors:
-    tensor_name_to_qsv[tfl_flatbuffer_utils.get_tensor_name(output_tensor)] = (
-        input_tensor_qsv
+  input_tensor_qsv = tensor_name_to_qsv.get(
+      input_tensor_params.tensor_name, None
+  )
+  if input_tensor_qsv is None:
+    input_tensor_data = tfl_flatbuffer_utils.get_tensor_data(
+        input_tensors[0], graph_info.buffers
     )
+    # If the input tensor is a constant tensor without qsv, compute qsv from
+    # its quant params.
+    if input_tensor_data is None:
+      # If the only input to an op that needs to match input to
+      # output has no qsv and is not a constant tensor, then this is an error.
+      raise ValueError(
+          "Input tensor qsv is None for tensor"
+          f" {input_tensor_params.tensor_name}."
+      )
+    min_val, max_val = _get_min_max_from_quant_params(input_quant_params)
+    input_tensor_qsv = {"min": min_val, "max": max_val}
+  for output_tensor in output_tensors:
+    tensor_name_to_qsv[
+        tfl_flatbuffer_utils.get_tensor_name(output_tensor)
+    ] = input_tensor_qsv

   return op_tensor_params

@@ -350,19 +427,26 @@
   )
   # Use output quantization params for all input tensors.
   if output_tensor_params.producer is None:
-    quant_params = None
+    output_quant_params = None
   else:
-    quant_params = output_tensor_params.producer.parameters
-  _materialize_op_tensors(
+    output_quant_params = output_tensor_params.producer.parameters
+    if not isinstance(output_quant_params, qtyping.UniformQuantParams):
+      raise ValueError(
+          "_materialize_standard_op_with_same_as_output_scale only supports"
+          f" UniformQuantParams. For tensor {output_tensor_params.tensor_name},"
+          f" got {type(output_quant_params)}"
+      )
+  _materialize_tensors_with_quantized_data_update(
       op_tensor_params,
       input_tensors,
+      output_quant_params,
       is_inbounding_tensor=True,
       op_info=op_info,
       graph_info=graph_info,
       tensor_name_to_qsv=tensor_name_to_qsv,
       get_tensor_quant_params_fn=get_tensor_quant_params_fn,
-      quant_params=quant_params,
   )
+
   op_tensor_params.append(output_tensor_params)

   return op_tensor_params
@@ -627,6 +711,26 @@
   return inputs_to_ignore, outputs_to_ignore


+def _get_min_max_from_quant_params(
+    quant_params: qtyping.UniformQuantParams,
+) -> tuple[np.ndarray, np.ndarray]:
+  """Recalculate min/max from tensor quantization params."""
+  q_min, q_max = uniform_quantize_tensor.get_quantized_range(
+      _IntType(quant_params.num_bits, True)
+  )
+  float_min = uniform_quantize_tensor.uniform_dequantize(
+      np.array(q_min), quant_params
+  )
+  float_max = uniform_quantize_tensor.uniform_dequantize(
+      np.array(q_max), quant_params
+  )
+  # We use qmax values to compute scale for symmetric quantization (see
+  # uniform_quantize_tensor.tensor_zp_scale_from_min_max).
+  if quant_params.symmetric:
+    float_min = -float_max
+  return float_min, float_max
+
+
 def materialize_standard_op(
     op_info: qtyping.OpInfo,
     graph_info: qtyping.GraphInfo,
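
Note: the relocated helper recovers a float range by dequantizing the integer endpoints, and the symmetric branch mirrors max because the scale was derived from q_max alone. The same computation with the plain formula (q - zp) * scale instead of the package's uniform_dequantize:

def min_max_from_quant_params(scale, zero_point, num_bits, symmetric):
  q_min, q_max = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
  float_min = (q_min - zero_point) * scale  # dequantize q_min
  float_max = (q_max - zero_point) * scale  # dequantize q_max
  if symmetric:
    float_min = -float_max                  # scale came from q_max only
  return float_min, float_max

print(min_max_from_quant_params(0.1, 10, 8, symmetric=False))  # ~(-13.8, 11.7)
print(min_max_from_quant_params(0.1, 0, 8, symmetric=True))    # ~(-12.7, 12.7)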
@@ -793,8 +897,6 @@ def materialize_op_with_output_activation_constraint(
   output_tensor_params.producer = op_tensor_params
   # Update the tensor_name_to_qsv map using the output activation constraints.
   min_val, max_val = _get_min_max_from_quant_params(
-      activation_num_bits,
-      activation_tensor_config.symmetric,
       fixed_quant_params,
   )
   tensor_name_to_qsv[output_tensor_params.tensor_name]["min"] = min_val
@@ -841,13 +943,6 @@ def get_tensor_transformations(
       transformations = [_QuantTransformation.QUANTIZE_TENSOR]
     else:
       transformations = [_QuantTransformation.NO_QUANTIZE]
-  elif (
-      op_quant_config.weight_tensor_config is not None
-      and op_quant_config.weight_tensor_config.granularity
-      == qtyping.QuantGranularity.BLOCKWISE
-      and is_constant
-  ):
-    transformations = [_QuantTransformation.EMULATED_SUBCHANNEL]
   # Check if WEIGHT_ONLY.
   elif (
       op_quant_config.compute_precision == qtyping.ComputePrecision.FLOAT
@@ -905,23 +1000,36 @@ def get_tensor_transformation_params(
   )


-def get_weight_quantized_dim(op_info: qtyping.OpInfo, tensor_data: np.ndarray):
+def get_weight_quantized_dim(
+    op_info: qtyping.OpInfo,
+    tensor_data: np.ndarray,
+    granularity: qtyping.QuantGranularity,
+):
   """Get the quantized dimension for the weight tensor.

   Args:
     op_info: Aggregated information about the op (e.g., quantization config).
     tensor_data: The weight tensor data.
+    granularity: The granularity of the weight tensor.

   Returns:
     The quantized dimension for the weight tensor.
   """
-  if op_info.op_name == _TFLOpName.BATCH_MATMUL:
-    quantized_dim = get_bmm_weight_quantized_dim(
-        tensor_data, adj_y=op_info.op.builtinOptions.adjY
-    )
-  else:
-    quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM.get(
-        op_info.op_name, None
+  quantized_dim = None
+  if granularity == qtyping.QuantGranularity.CHANNELWISE:
+    if op_info.op_name == _TFLOpName.BATCH_MATMUL:
+      quantized_dim = get_bmm_weight_quantized_dim(
+          tensor_data, adj_y=op_info.op.builtinOptions.adjY
+      )
+    else:
+      quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM.get(
+          op_info.op_name, None
+      )
+  elif uniform_quantize_tensor.is_blockwise(granularity):
+    quantized_dim = (
+        tfl_flatbuffer_utils.TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM[
+            op_info.op_name
+        ]
     )
   return quantized_dim

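
Note: after this refactor the quantized dimension is a function of granularity, not just of the op: channelwise goes through the per-op table (or the batch-matmul special case), blockwise uses a separate table, and tensorwise needs no dimension at all. A toy dispatch mirroring that shape (the 0/1 dimension values are illustrative assumptions, not copied from tfl_flatbuffer_utils):

import enum

class Granularity(enum.Enum):  # stand-in for qtyping.QuantGranularity
  TENSORWISE = 1
  CHANNELWISE = 2
  BLOCKWISE = 3

CHANNELWISE_DIM = {"FULLY_CONNECTED": 0}                       # illustrative
BLOCKWISE_DIM = {"FULLY_CONNECTED": 1, "EMBEDDING_LOOKUP": 1}  # illustrative

def weight_quantized_dim(op_name, granularity):
  if granularity is Granularity.CHANNELWISE:
    return CHANNELWISE_DIM.get(op_name)  # None when the op has no entry
  if granularity is Granularity.BLOCKWISE:
    return BLOCKWISE_DIM[op_name]        # KeyError for unsupported ops
  return None                            # tensorwise: no per-axis dimension

print(weight_quantized_dim("FULLY_CONNECTED", Granularity.CHANNELWISE))  # 0
print(weight_quantized_dim("FULLY_CONNECTED", Granularity.TENSORWISE))   # None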
@@ -951,23 +1059,4 @@ def get_bmm_weight_quantized_dim(
   return rank - 1


-def _get_min_max_from_quant_params(
-    num_bits: int,
-    symmetric: bool,
-    tensor_params: qtyping.UniformQuantParams,
-) -> tuple[float, float]:
-  """Recalculate min/max from tensor quantization params."""
-  q_min, q_max = uniform_quantize_tensor.get_quantized_range(
-      _IntType(num_bits, True)
-  )
-  float_min = uniform_quantize_tensor.uniform_dequantize(
-      np.array(q_min), tensor_params
-  )
-  float_max = uniform_quantize_tensor.uniform_dequantize(
-      np.array(q_max), tensor_params
-  )
-  # We use qmax values to compute scale for symmetric quantization (see
-  # uniform_quantize_tensor.tensor_zp_scale_from_min_max).
-  if symmetric:
-    float_min = -float_max
-  return (float_min, float_max)
+
ai_edge_quantizer/calibrator.py

@@ -46,11 +46,6 @@ class Calibrator:
   ):
     self._flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)

-    if not tfl_flatbuffer_utils.is_float_model(self._flatbuffer_model):
-      raise ValueError(
-          "The input model for calibration is not a float model. Please check"
-          " the model (e.g., if it is already quantized)."
-      )
     self._tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
         float_tflite, use_xnnpack=True, num_threads=num_threads
     )
@@ -98,9 +93,7 @@ class Calibrator:
       qsv_update_func: The function to update the QSVs.
     """
     op_codes = self._flatbuffer_model.operatorCodes
-    if not self._model_qsvs:
-      self._initialize_model_qsvs(model_recipe_manager)
-    else:
+    if self._model_qsvs:
       logging.warning(
           "Calibrator contains non-empty model qsvs, and the current"
           " calibration process will start on top of this state (i.e., update"
@@ -140,10 +133,15 @@ class Calibrator:
       graph_info = qtyping.GraphInfo(
           subgraph.tensors, self._flatbuffer_model.buffers
       )
-      # Add input/output operators to the subgraph.
-      subgraph.operators += (
-          tfl_flatbuffer_utils.get_subgraph_input_output_operators(subgraph)
-      )
+      # Add input/output operators if they are not in the subgraph.
+      if not any(
+          isinstance(op, qtyping.IOOperator) for op in subgraph.operators
+      ):
+        subgraph.operators += (
+            tfl_flatbuffer_utils.get_subgraph_input_output_operators(
+                subgraph
+            )
+        )
       for op in subgraph.operators:
         if isinstance(op, qtyping.IOOperator):
           op_key = op.op_key
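
Note: the guard makes the insertion idempotent, which matters now that calibrate() can run repeatedly on the same Calibrator (see the qsv warning above). Distilled to a pattern, under the assumption that IOOperator instances only ever come from this insertion:

def add_io_operators_once(operators, make_io_operators, is_io_operator):
  # Append synthetic input/output operators only on the first pass so a
  # second calibration run does not duplicate them.
  if not any(is_io_operator(op) for op in operators):
    operators += make_io_operators()
  return operators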
@@ -160,7 +158,7 @@ class Calibrator:
           )
           if algorithm_name == algorithm_manager.AlgorithmName.NO_QUANTIZE:
             continue
-          if policy.is_conditionally_unquantized(op):
+          if policy.is_non_quantizable_composite_op(op):
             continue

           # Step2.2: query algorithm_manager to get/call the related
@@ -258,50 +256,3 @@ class Calibrator:
       output_tensor = subgraph_tensors[output_tensor_idx]
       scope += tfl_flatbuffer_utils.get_tensor_name(output_tensor)
     return scope
-
-  # TODO: b/354224138 - Remove code duplication between calibrate and
-  # _initialize_model_qsvs.
-  def _initialize_model_qsvs(
-      self, model_recipe_manager: recipe_manager.RecipeManager
-  ) -> None:
-    """Initialize the model qsvs.
-
-    Args:
-      model_recipe_manager: A RecipeManager object that contains the
-        quantization recipe.
-    """
-    op_codes = self._flatbuffer_model.operatorCodes
-    for subgraph in self._flatbuffer_model.subgraphs:
-      graph_info = qtyping.GraphInfo(
-          subgraph.tensors, self._flatbuffer_model.buffers
-      )
-      for subgraph_op_id, op in enumerate(subgraph.operators):
-        op_code = op_codes[op.opcodeIndex].builtinCode
-        if op_code not in tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME:
-          continue
-        op_key = tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME[op_code]
-        # Step1: query the quantization_recipe to get op quantization
-        # settings.
-        op_scope = self._get_op_scope(op, subgraph.tensors)
-        algorithm_name, op_quant_config = (
-            model_recipe_manager.get_quantization_configs(op_key, op_scope)
-        )
-        if algorithm_name == algorithm_manager.AlgorithmName.NO_QUANTIZE:
-          continue
-        # Step2: query algorithm_manager to get/call the related qsv init
-        # function.
-        qsv_init_func = algorithm_manager.get_init_qsv_func(
-            algorithm_name, op_key
-        )
-        op_info = qtyping.OpInfo(op, op_key, subgraph_op_id, op_quant_config)
-        # Ignore the input tensors where any dimension of the shape is 0.
-        inputs_to_ignore = [
-            opr_idx
-            for opr_idx, tensor_idx in enumerate(op.inputs)
-            if not np.all(graph_info.subgraph_tensors[tensor_idx].shape)
-        ]
-        op_qsvs = qsv_init_func(op_info, graph_info, inputs_to_ignore)
-        # Step3: initialize tensor qsvs.
-        for tensor_name, qsv in op_qsvs.items():
-          if tensor_name not in self._model_qsvs:
-            self._model_qsvs[tensor_name] = qsv