ai-edge-quantizer-nightly 0.0.1.dev20250211__py3-none-any.whl → 0.0.1.dev20250212__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
@@ -19,7 +19,6 @@ from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python.platform import googletest
-from ai_edge_quantizer import default_policy
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
 from ai_edge_quantizer.utils import test_utils
@@ -158,27 +157,6 @@ class NaiveMinMaxQuantizeTest(parameterized.TestCase):
     self.assertNotIn("arith.constant1", op_qsvs)
     self.assertNotIn("arith.constant2", op_qsvs)
 
-  def test_check_op_quantization_config_with_negative_min_weight_elements_raises_error(
-      self,
-  ):
-    op_quant_config = qtyping.OpQuantizationConfig(
-        weight_tensor_config=_TensorQuantConfig(
-            num_bits=8,
-            granularity=qtyping.QuantGranularity.CHANNELWISE,
-        ),
-        compute_precision=qtyping.ComputePrecision.INTEGER,  # DRQ.
-        min_weight_elements=-1,
-    )
-    with self.assertRaisesWithPredicateMatch(
-        ValueError,
-        lambda err: "min_weight_elements must be non-negative" in str(err),
-    ):
-      naive_min_max_quantize.check_op_quantization_config(
-          _TFLOpName.FULLY_CONNECTED,
-          op_quant_config,
-          default_policy.DEFAULT_CONFIG_CHECK_POLICY,
-      )
-
 
 if __name__ == "__main__":
   googletest.main()
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
-"""Utils for min/max based quantization."""
+"""Common utils for uniform quantization algorithms."""
 
 from collections.abc import Sequence
 import dataclasses
@@ -29,16 +29,8 @@ _TFLOpName = qtyping.TFLOperationName
 _QuantTransformation = qtyping.QuantTransformation
 _IntType = uniform_quantize_tensor.IntType
 
-_SUPPORTED_WEIGHT_ONLY_OPS = frozenset([
-    _TFLOpName.FULLY_CONNECTED,
-    _TFLOpName.CONV_2D,
-    _TFLOpName.BATCH_MATMUL,
-    _TFLOpName.EMBEDDING_LOOKUP,
-    _TFLOpName.DEPTHWISE_CONV_2D,
-    _TFLOpName.CONV_2D_TRANSPOSE,
-])
 
-_SUPPORTED_DRQ_OPS = frozenset([
+_DRQ_OR_WEIGHT_ONLY_OPS = frozenset([
     _TFLOpName.FULLY_CONNECTED,
     _TFLOpName.CONV_2D,
     _TFLOpName.BATCH_MATMUL,
@@ -46,6 +38,7 @@ _SUPPORTED_DRQ_OPS = frozenset([
     _TFLOpName.DEPTHWISE_CONV_2D,
     _TFLOpName.CONV_2D_TRANSPOSE,
 ])
+
 _SUPPORTED_SUBCHANNEL_OPS = frozenset([
     _TFLOpName.FULLY_CONNECTED,
 ])
@@ -139,73 +132,13 @@ class OpQuantConstraint(enum.Enum):
   SAME_AS_OUTPUT_SCALE = 2
 
 
-def init_tensor_min_max(
-    tensor: Any,
-    graph_info: qtyping.GraphInfo,
-    op_info: qtyping.OpInfo,
-):
-  """Initialize the min/max for a tensor."""
-  tensor_data = tfl_flatbuffer_utils.get_tensor_data(tensor, graph_info.buffers)
-  # Initial values for non-constant tensors.
-  if tensor_data is None:
-    return {}
-  # Real min/max for constant tensors.
-  else:
-    quantized_dim = None
-    if (
-        op_info.op_quant_config.weight_tensor_config is not None
-        and op_info.op_quant_config.weight_tensor_config.granularity
-        == qtyping.QuantGranularity.BLOCKWISE
-    ):
-      # TODO(b/346612503): emulate subchannel only supports fully connected,
-      # will skip special handling. Once we have a spec, we can change this.
-      block_size = op_info.op_quant_config.weight_tensor_config.block_size
-      # assuming tensor is 2D, which is correct for FULLY_CONNECTED
-      transposed_tensor_data = np.transpose(tensor_data, (1, 0))
-      if transposed_tensor_data.shape[0] % block_size:
-        raise ValueError(
-            f"Block size {block_size} does not divide channel dimension"
-            f" {transposed_tensor_data.shape[0]}."
-        )
-      reshaped_tensor_data = np.reshape(
-          transposed_tensor_data,
-          (
-              1,
-              int(transposed_tensor_data.shape[0] / block_size),
-              block_size,
-              transposed_tensor_data.shape[1],
-          ),
-      )
-      return {
-          "min": np.min(reshaped_tensor_data, axis=(0, 1, 2), keepdims=True),
-          "max": np.max(reshaped_tensor_data, axis=(0, 1, 2), keepdims=True),
-      }
-    if (
-        op_info.op_quant_config.weight_tensor_config is not None
-        and op_info.op_quant_config.weight_tensor_config.granularity
-        == qtyping.QuantGranularity.CHANNELWISE
-    ):
-      if op_info.op_name == _TFLOpName.BATCH_MATMUL:
-        quantized_dim = _get_bmm_weight_quantized_dim(
-            tensor_data, adj_y=op_info.op.builtinOptions.adjY
-        )
-      else:
-        quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM.get(
-            op_info.op_name, None
-        )
-    reduce_dims = _get_reduce_dims(quantized_dim, tensor.shape)
-    return {
-        "min": np.min(tensor_data, axis=reduce_dims, keepdims=True),
-        "max": np.max(tensor_data, axis=reduce_dims, keepdims=True),
-    }
-
-
 def _get_tensor_transformation_params_wrapper(
     tensor: Any,
     is_inbounding_tensor: bool,
     op_info: qtyping.OpInfo,
     graph_info: qtyping.GraphInfo,
     tensor_name_to_qsv: dict[str, Any],
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
     quant_params=None,
 ) -> qtyping.TensorTransformationParams:
   """Util to get tensor transformation params.
@@ -216,6 +149,8 @@ def _get_tensor_transformation_params_wrapper(
     op_info: Aggregated information about the op (e.g., quantization config).
     graph_info: Graph information needed to perform quantization for the op.
     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
     quant_params: Quantization parameters for the tensor.
 
   Returns:
@@ -229,37 +164,15 @@ def _get_tensor_transformation_params_wrapper(
   tensor_quant_config = op_info.op_quant_config.activation_tensor_config
   is_constant = tensor_data is not None
   # Use weight configuration if it is supported.
-  if is_constant and op_info.op_name in frozenset.union(
-      _SUPPORTED_WEIGHT_ONLY_OPS, _SUPPORTED_DRQ_OPS
-  ):
+  if is_constant and op_info.op_name in _DRQ_OR_WEIGHT_ONLY_OPS:
     tensor_quant_config = op_info.op_quant_config.weight_tensor_config
   # Get quant params.
   if quant_params is None and tensor_quant_config is not None:
-    if tensor_name not in tensor_name_to_qsv:
-      if is_constant:
-        # We need min/max to calculate quantization parameters, which
-        # should be collected during the calibration process. However,
-        # weight-only and DRQ do not require calibration, thus it is
-        # possible that this information is missing here. In that case we
-        # collect min/max on the spot.
-        tensor_min_max = init_tensor_min_max(
-            tensor,
-            graph_info,
-            op_info,
-        )
-      else:
-        raise ValueError(
-            f"Tensor {tensor_name} not found in tensor_name_to_qsv. Check"
-            " if the correct calibration results are passed into the"
-            " ParamsGenerator."
-        )
-    else:
-      tensor_min_max = tensor_name_to_qsv[tensor_name]
-    quant_params = _get_tensor_quant_params(
+    quant_params = get_tensor_quant_params_fn(
         op_info,
-        tensor_min_max,
         tensor_quant_config,
-        tensor_content=tensor_data,
+        tensor_data,
+        tensor_name_to_qsv.get(tensor_name),
     )
   return get_tensor_transformation_params(
       tensor_name,
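
Note: the hunk above replaces the inlined min/max lookup (and the removed init_tensor_min_max fallback) with a caller-supplied callback. As a sketch, the positional contract the wrapper now relies on is the following; the annotations are illustrative, and the binding type alias, qtyping.GetTensorQuantParamsFuncSignature, is added to qtyping at the end of this diff:

    quant_params = get_tensor_quant_params_fn(
        op_info,                              # qtyping.OpInfo for the op being materialized
        tensor_quant_config,                  # qtyping.TensorQuantizationConfig for this tensor
        tensor_data,                          # np.ndarray for constant tensors, None for activations
        tensor_name_to_qsv.get(tensor_name),  # calibration min/max (qsv), or None when uncalibrated
    )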
@@ -277,6 +190,7 @@ def _materialize_op_tensors(
     op_info: qtyping.OpInfo,
     graph_info: qtyping.GraphInfo,
     tensor_name_to_qsv: dict[str, Any],
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
     quant_params=None,
 ) -> None:
   """Util to materialize op tensors. Appends the results to op_tensor_params.
@@ -289,6 +203,8 @@ def _materialize_op_tensors(
     op_info: Aggregated information about the op (e.g., quantization config).
     graph_info: Graph information needed to perform quantization for the op.
     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
     quant_params: Quantization parameters for the tensor.
   """
   for tensor in op_tensors:
@@ -298,7 +214,8 @@ def _materialize_op_tensors(
         op_info,
         graph_info,
         tensor_name_to_qsv,
-        quant_params,
+        get_tensor_quant_params_fn=get_tensor_quant_params_fn,
+        quant_params=quant_params,
     )
     op_tensor_params.append(tensor_params)
 
@@ -309,6 +226,7 @@ def _get_single_tensor_params(
     op_info: qtyping.OpInfo,
     graph_info: qtyping.GraphInfo,
     tensor_name_to_qsv: dict[str, Any],
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
 ) -> qtyping.TensorTransformationParams:
   """Util to get single tensor params.
 
@@ -318,6 +236,8 @@ def _get_single_tensor_params(
     op_info: Aggregated information about the op (e.g., quantization config).
     graph_info: Graph information needed to perform quantization for the op.
     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
 
   Returns:
     Transformation parameters for the tensor.
@@ -336,6 +256,7 @@ def _get_single_tensor_params(
       op_info,
       graph_info,
       tensor_name_to_qsv,
+      get_tensor_quant_params_fn=get_tensor_quant_params_fn,
   )
 
 
@@ -345,6 +266,7 @@ def _materialize_standard_op_with_same_as_input_scale(
     op_info: qtyping.OpInfo,
     graph_info: qtyping.GraphInfo,
     tensor_name_to_qsv: dict[str, Any],
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
 ) -> list[qtyping.TensorTransformationParams]:
   """Materialize tensors in an op with same as input scale constraint.
 
@@ -354,6 +276,8 @@ def _materialize_standard_op_with_same_as_input_scale(
     op_info: Aggregated information about the op (e.g., quantization config).
     graph_info: Graph information needed to perform quantization for the op.
     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
 
   Returns:
     Quantization configuration for the tensors associated with the op (e.g.,
@@ -367,6 +291,7 @@ def _materialize_standard_op_with_same_as_input_scale(
       op_info=op_info,
       graph_info=graph_info,
       tensor_name_to_qsv=tensor_name_to_qsv,
+      get_tensor_quant_params_fn=get_tensor_quant_params_fn,
   )
   op_tensor_params.append(input_tensor_params)
   # Use input quantization params for all output tensors.
@@ -377,6 +302,7 @@ def _materialize_standard_op_with_same_as_input_scale(
       op_info=op_info,
       graph_info=graph_info,
       tensor_name_to_qsv=tensor_name_to_qsv,
+      get_tensor_quant_params_fn=get_tensor_quant_params_fn,
       quant_params=input_tensor_params.consumers[0].parameters,
   )
   # Change output qsv to be the same as input qsv. This is safe since TFL
@@ -396,6 +322,7 @@ def _materialize_standard_op_with_same_as_output_scale(
     op_info: qtyping.OpInfo,
     graph_info: qtyping.GraphInfo,
     tensor_name_to_qsv: dict[str, Any],
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
 ) -> list[qtyping.TensorTransformationParams]:
   """Materialize tensors in an op with same as output scale constraint.
 
@@ -405,6 +332,8 @@ def _materialize_standard_op_with_same_as_output_scale(
     op_info: Aggregated information about the op (e.g., quantization config).
     graph_info: Graph information needed to perform quantization for the op.
     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
 
   Returns:
     Quantization configuration for the tensors associated with the op (e.g.,
@@ -418,6 +347,7 @@ def _materialize_standard_op_with_same_as_output_scale(
       op_info=op_info,
       graph_info=graph_info,
       tensor_name_to_qsv=tensor_name_to_qsv,
+      get_tensor_quant_params_fn=get_tensor_quant_params_fn,
   )
   # Use output quantization params for all input tensors.
   if output_tensor_params.producer is None:
@@ -431,6 +361,7 @@ def _materialize_standard_op_with_same_as_output_scale(
       op_info=op_info,
       graph_info=graph_info,
      tensor_name_to_qsv=tensor_name_to_qsv,
+      get_tensor_quant_params_fn=get_tensor_quant_params_fn,
       quant_params=quant_params,
   )
   op_tensor_params.append(output_tensor_params)
@@ -444,6 +375,7 @@ def _materialize_standard_op_no_constraint(
     op_info: qtyping.OpInfo,
     graph_info: qtyping.GraphInfo,
     tensor_name_to_qsv: dict[str, Any],
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
 ) -> list[qtyping.TensorTransformationParams]:
   """Materialize tensors in an op with no constraint.
 
@@ -453,6 +385,8 @@ def _materialize_standard_op_no_constraint(
     op_info: Aggregated information about the op (e.g., quantization config).
     graph_info: Graph information needed to perform quantization for the op.
     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
 
   Returns:
     Quantization configuration for the tensors associated with the op (e.g.,
@@ -466,6 +400,7 @@ def _materialize_standard_op_no_constraint(
      op_info=op_info,
      graph_info=graph_info,
      tensor_name_to_qsv=tensor_name_to_qsv,
+      get_tensor_quant_params_fn=get_tensor_quant_params_fn,
   )
   _materialize_op_tensors(
       op_tensor_params,
@@ -474,6 +409,7 @@ def _materialize_standard_op_no_constraint(
       op_info=op_info,
       graph_info=graph_info,
       tensor_name_to_qsv=tensor_name_to_qsv,
+      get_tensor_quant_params_fn=get_tensor_quant_params_fn,
   )
 
   return op_tensor_params
@@ -696,6 +632,7 @@ def materialize_standard_op(
     op_info: qtyping.OpInfo,
     graph_info: qtyping.GraphInfo,
     tensor_name_to_qsv: dict[str, Any],
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
     constraint: OpQuantConstraint = OpQuantConstraint.NO_CONSTRAIN,
     inputs_to_ignore: Optional[Sequence[int]] = None,
     outputs_to_ignore: Optional[Sequence[int]] = None,
@@ -709,6 +646,8 @@ def materialize_standard_op(
     op_info: Aggregated information about the op (e.g., quantization config).
     graph_info: Graph information needed to perform quantization for the op.
     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
     constraint: The constraint for materializing the op.
     inputs_to_ignore: Input tensor indices to ignore.
     outputs_to_ignore: Output tensor indices to ignore.
@@ -747,15 +686,30 @@ def materialize_standard_op(
     tensor_params = []  # Every tensor is ignored.
   elif constraint == OpQuantConstraint.SAME_AS_INPUT_SCALE:
     tensor_params = _materialize_standard_op_with_same_as_input_scale(
-        input_tensors, output_tensors, op_info, graph_info, tensor_name_to_qsv
+        input_tensors,
+        output_tensors,
+        op_info,
+        graph_info,
+        tensor_name_to_qsv,
+        get_tensor_quant_params_fn,
     )
   elif constraint == OpQuantConstraint.SAME_AS_OUTPUT_SCALE:
     tensor_params = _materialize_standard_op_with_same_as_output_scale(
-        input_tensors, output_tensors, op_info, graph_info, tensor_name_to_qsv
+        input_tensors,
+        output_tensors,
+        op_info,
+        graph_info,
+        tensor_name_to_qsv,
+        get_tensor_quant_params_fn,
     )
   else:
     tensor_params = _materialize_standard_op_no_constraint(
-        input_tensors, output_tensors, op_info, graph_info, tensor_name_to_qsv
+        input_tensors,
+        output_tensors,
+        op_info,
+        graph_info,
+        tensor_name_to_qsv,
+        get_tensor_quant_params_fn,
     )
 
   # Materialize ignored tensors.
@@ -781,6 +735,7 @@ def materialize_op_with_output_activation_constraint(
     graph_info: qtyping.GraphInfo,
     tensor_name_to_qsv: dict[str, Any],
     output_activation_constraints: dict[int, qtyping.UniformQuantParams],
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
 ) -> list[qtyping.TensorTransformationParams]:
   """Materialize tensors in an op with output activation constraint.
 
@@ -795,6 +750,8 @@ def materialize_op_with_output_activation_constraint(
     tensor_name_to_qsv: A map of tensor name to quantization parameters.
     output_activation_constraints: A map of output activation num bits to
       quantization parameters.
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
 
   Returns:
     Quantization configuration for the tensors associated with the op (e.g.,
@@ -812,9 +769,7 @@ def materialize_op_with_output_activation_constraint(
   )
 
   tensor_params = materialize_standard_op(
-      op_info,
-      graph_info,
-      tensor_name_to_qsv,
+      op_info, graph_info, tensor_name_to_qsv, get_tensor_quant_params_fn
   )
   output_tensor_params = tensor_params[-1]
 
@@ -951,76 +906,7 @@ def get_tensor_transformation_params(
   )
 
 
-def _get_tensor_quant_params(
-    op_info: qtyping.OpInfo,
-    tensor_min_max: dict[str, Any],
-    tensor_quant_config: qtyping.TensorQuantizationConfig,
-    tensor_content: Optional[np.ndarray] = None,
-) -> qtyping.UniformQuantParams:
-  """Get the quantization parameters for a tensor.
-
-  Args:
-    op_info: aggregated information about the op (e.g., quantization config).
-    tensor_min_max: the min/max of the tensor.
-    tensor_quant_config: the quantization config for the tensor.
-    tensor_content: the content of the tensor.
-
-  Returns:
-    The quantization parameters for the tensor.
-  """
-  if "min" not in tensor_min_max or "max" not in tensor_min_max:
-    raise ValueError(
-        "min and max must be provided to produce tensor quantization"
-        " parameters. Check if the correct calibration results are passed into"
-        " the ParamsGenerator."
-    )
-  zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
-      tensor_min_max["min"],
-      tensor_min_max["max"],
-      tensor_quant_config.num_bits,
-      tensor_quant_config.symmetric,
-  )
-  quantized_dim = None
-  if tensor_quant_config.granularity == qtyping.QuantGranularity.CHANNELWISE:
-    if op_info.op_name == _TFLOpName.BATCH_MATMUL:
-      quantized_dim = _get_bmm_weight_quantized_dim(
-          tensor_content, adj_y=op_info.op.builtinOptions.adjY
-      )
-    else:
-      quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM[
-          op_info.op_name
-      ]
-  quant_params = qtyping.UniformQuantParams(
-      scale=scale,
-      zero_point=zp,
-      num_bits=tensor_quant_config.num_bits,
-      symmetric=tensor_quant_config.symmetric,
-      quantized_dimension=quantized_dim,
-  )
-  if tensor_content is None:
-    return quant_params
-  if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
-    quantized_vars = (
-        uniform_quantize_tensor.uniform_quantize_for_emulated_subchannel(
-            tensor_content, quant_params, tensor_quant_config.block_size
-        )
-    )
-  else:
-    quantized_vars = uniform_quantize_tensor.uniform_quantize(
-        tensor_content, quant_params
-    )
-  # Update with quantized values.
-  return qtyping.UniformQuantParams(
-      scale=scale,
-      zero_point=zp,
-      num_bits=tensor_quant_config.num_bits,
-      symmetric=tensor_quant_config.symmetric,
-      quantized_dimension=quantized_dim,
-      quantized_data=quantized_vars,
-  )
-
-
-def _get_reduce_dims(
+def get_reduce_dims(
     quantized_dim: Optional[int],
     tensor_shape: list[int],
 ) -> Optional[tuple[int, ...]]:
@@ -1034,7 +920,7 @@ def _get_reduce_dims(
   return tuple(reduce_dims)
 
 
-def _get_bmm_weight_quantized_dim(
+def get_bmm_weight_quantized_dim(
     weight_tensor_data: np.ndarray, adj_y: bool
 ) -> int:
   """Get the quantized dimension for batch matmul."""
@@ -18,7 +18,7 @@ from absl.testing import parameterized
 from tensorflow.python.platform import googletest
 from ai_edge_quantizer import default_policy
 from ai_edge_quantizer import qtyping
-from ai_edge_quantizer.algorithms.utils import min_max_quantize_utils
+from ai_edge_quantizer.algorithms.utils import common_utils
 
 _ComputePrecision = qtyping.ComputePrecision
 _QuantTransformation = qtyping.QuantTransformation
@@ -52,7 +52,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
         compute_precision=compute_precision,
         explicit_dequantize=explicit_dequantize,
     )
-    transformations = min_max_quantize_utils.get_tensor_transformations(
+    transformations = common_utils.get_tensor_transformations(
         op_quant_config, is_inbounding_tensor, is_constant
     )
     # Check if WEIGHT_ONLY.
@@ -120,7 +120,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
     ):
-      min_max_quantize_utils.check_if_valid_op_config(
+      common_utils.check_if_valid_op_config(
           op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
       )
 
@@ -146,7 +146,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
         ),
         compute_precision=_ComputePrecision.INTEGER,  # DRQ.
     )
-    min_max_quantize_utils.check_if_valid_op_config(
+    common_utils.check_if_valid_op_config(
         op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
     )
 
@@ -166,7 +166,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
     ):
-      min_max_quantize_utils.check_if_valid_op_config(
+      common_utils.check_if_valid_op_config(
           op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
       )
 
@@ -186,7 +186,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
     ):
-      min_max_quantize_utils.check_if_valid_op_config(
+      common_utils.check_if_valid_op_config(
           op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
       )
 
@@ -207,7 +207,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
     ):
-      min_max_quantize_utils.check_if_valid_op_config(
+      common_utils.check_if_valid_op_config(
          op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
      )
 
@@ -220,7 +220,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
         compute_precision=_ComputePrecision.INTEGER,  # DRQ.
         min_weight_elements=100,
     )
-    min_max_quantize_utils.check_if_valid_op_config(
+    common_utils.check_if_valid_op_config(
         _TFLOpName.CONV_2D, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
     )
 
@@ -255,7 +255,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
         ),
         compute_precision=_ComputePrecision.INTEGER,  # SRQ.
     )
-    min_max_quantize_utils.check_if_valid_op_config(
+    common_utils.check_if_valid_op_config(
         op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
     )
 
@@ -275,7 +275,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
     ):
-      min_max_quantize_utils.check_if_valid_op_config(
+      common_utils.check_if_valid_op_config(
          _TFLOpName.CUSTOM_OP, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
      )
 
@@ -297,7 +297,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
     ):
-      min_max_quantize_utils.check_if_valid_op_config(
+      common_utils.check_if_valid_op_config(
          _TFLOpName.FULLY_CONNECTED,
          op_quant_config,
          _DEFAULT_CONFIG_CHECK_POLICY,
@@ -321,7 +321,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
     ):
-      min_max_quantize_utils.check_if_valid_op_config(
+      common_utils.check_if_valid_op_config(
          _TFLOpName.FULLY_CONNECTED,
          op_quant_config,
          _DEFAULT_CONFIG_CHECK_POLICY,
@@ -345,7 +345,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
     ):
-      min_max_quantize_utils.check_if_valid_op_config(
+      common_utils.check_if_valid_op_config(
          _TFLOpName.FULLY_CONNECTED,
          op_quant_config,
          _DEFAULT_CONFIG_CHECK_POLICY,
@@ -368,7 +368,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
     ):
-      min_max_quantize_utils.check_if_valid_op_config(
+      common_utils.check_if_valid_op_config(
          _TFLOpName.FULLY_CONNECTED,
          op_quant_config,
          _DEFAULT_CONFIG_CHECK_POLICY,
@@ -425,7 +425,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
         compute_precision == _ComputePrecision.INTEGER
         and op_quant_config.activation_tensor_config is None
     ):
-      min_max_quantize_utils.check_if_valid_op_config(
+      common_utils.check_if_valid_op_config(
          op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
      )
     # Check if WEIGHT_ONLY.
@@ -439,7 +439,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
         compute_precision == _ComputePrecision.INTEGER
         and op_quant_config.activation_tensor_config is not None
     ):
-      min_max_quantize_utils.check_if_valid_op_config(
+      common_utils.check_if_valid_op_config(
          op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
      )
 
@@ -477,11 +477,11 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
 
     with self.assertRaises(ValueError):
       if is_drq:
-        min_max_quantize_utils.check_if_valid_op_config(
+        common_utils.check_if_valid_op_config(
             op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
         )
       elif not is_drq:
-        min_max_quantize_utils.check_if_valid_op_config(
+        common_utils.check_if_valid_op_config(
             op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
         )
 
@@ -500,11 +500,12 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
     with self.assertRaisesRegex(
         ValueError, "only supports ops with a single output tensor"
     ):
-      min_max_quantize_utils.materialize_op_with_output_activation_constraint(
+      common_utils.materialize_op_with_output_activation_constraint(
          op_info=mock_op_info,
          graph_info=qtyping.GraphInfo([], []),
          tensor_name_to_qsv={},
          output_activation_constraints={},
+          get_tensor_quant_params_fn=lambda *args: [],
      )
 
 
@@ -20,7 +20,7 @@ from collections.abc import MutableMapping
 import copy
 import dataclasses
 import enum
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, Callable
 
 import numpy as np
 from typing_extensions import TypeAlias
@@ -485,3 +485,14 @@ class IOOperator:
   inputs: list[int]
   outputs: list[int]
   op_key: TFLOperationName
+
+# The function signature for `get_tensor_quant_params_fn`.
+GetTensorQuantParamsFuncSignature = Callable[
+    [
+        OpInfo,  # op_info
+        TensorQuantizationConfig,  # tensor_quant_config
+        Optional[np.ndarray],  # tensor_data
+        Optional[dict[str, Any]],  # tensor qsv
+    ],
+    UniformQuantParams,
+]
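
Taken together, these changes invert the dependency between the shared utilities and the individual algorithms: common_utils no longer computes quantization parameters itself (init_tensor_min_max and _get_tensor_quant_params are deleted above); instead, each algorithm passes a callable matching GetTensorQuantParamsFuncSignature into materialize_standard_op and its siblings. Below is a minimal sketch of a conforming callable, per-tensor and symmetric-only, which also omits materializing quantized_data for constants; the helper names and constructor fields come from code removed in this diff, while the import paths are assumptions:

    import numpy as np

    from ai_edge_quantizer import qtyping
    from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor


    def get_tensor_quant_params(
        op_info: qtyping.OpInfo,
        tensor_quant_config: qtyping.TensorQuantizationConfig,
        tensor_data=None,  # np.ndarray for constants, None for activations.
        tensor_qsv=None,  # Calibration stats: {"min": ..., "max": ...}.
    ) -> qtyping.UniformQuantParams:
      """Sketch of a callable matching GetTensorQuantParamsFuncSignature."""
      del op_info  # A real implementation would use it to pick the quantized dim.
      if tensor_qsv is None:
        # Weight-only/DRQ paths can reach here without calibration results;
        # fall back to min/max of the constant data, as the removed
        # init_tensor_min_max did.
        tensor_qsv = {
            "min": np.min(tensor_data, keepdims=True),
            "max": np.max(tensor_data, keepdims=True),
        }
      zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
          tensor_qsv["min"],
          tensor_qsv["max"],
          tensor_quant_config.num_bits,
          tensor_quant_config.symmetric,
      )
      return qtyping.UniformQuantParams(
          scale=scale,
          zero_point=zp,
          num_bits=tensor_quant_config.num_bits,
          symmetric=tensor_quant_config.symmetric,
          quantized_dimension=None,  # Per-tensor; channelwise would set an axis.
      )

An algorithm module would then thread this through the shared helpers, e.g. common_utils.materialize_standard_op(op_info, graph_info, tensor_name_to_qsv, get_tensor_quant_params), which is exactly the call shape the refactored hunks above adopt.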