ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1067 @@
+ # Copyright 2024 The AI Edge Quantizer Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Utils for min/max based quantization."""
+
+ from collections.abc import Sequence
+ import dataclasses
+ import enum
+ from typing import Any, Optional
+ import numpy as np
+ from ai_edge_quantizer import qtyping
+ from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+ from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
+
+ _TFLOpName = qtyping.TFLOperationName
+ _QuantTransformation = qtyping.QuantTransformation
+ _IntType = uniform_quantize_tensor.IntType
+
+ _SUPPORTED_WEIGHT_ONLY_OPS = frozenset([
+     _TFLOpName.FULLY_CONNECTED,
+     _TFLOpName.CONV_2D,
+     _TFLOpName.BATCH_MATMUL,
+     _TFLOpName.EMBEDDING_LOOKUP,
+     _TFLOpName.DEPTHWISE_CONV_2D,
+     _TFLOpName.CONV_2D_TRANSPOSE,
+ ])
+
+ _SUPPORTED_DRQ_OPS = frozenset([
+     _TFLOpName.FULLY_CONNECTED,
+     _TFLOpName.CONV_2D,
+     _TFLOpName.BATCH_MATMUL,
+     _TFLOpName.EMBEDDING_LOOKUP,
+     _TFLOpName.DEPTHWISE_CONV_2D,
+     _TFLOpName.CONV_2D_TRANSPOSE,
+ ])
+ _SUPPORTED_SUBCHANNEL_OPS = frozenset([
+     _TFLOpName.FULLY_CONNECTED,
+ ])
+
+
+ def check_subchannel_config(
+     op_name: _TFLOpName, op_quant_config: qtyping.OpQuantizationConfig
+ ):
+   """Checks the op quantization config for subchannel quantization."""
+   if (
+       op_quant_config.weight_tensor_config is not None
+       and op_quant_config.weight_tensor_config.granularity
+       == qtyping.QuantGranularity.BLOCKWISE
+   ):
+     if op_name not in _SUPPORTED_SUBCHANNEL_OPS:
+       raise ValueError(f"Unsupported op for blockwise quantization: {op_name}.")
+     if op_quant_config.activation_tensor_config is not None:
+       raise ValueError(
+           "Blockwise quantization does not support activation tensor"
+           " quantization."
+       )
+     if not op_quant_config.weight_tensor_config.symmetric:
+       raise ValueError(
+           "Blockwise quantization does not support asymmetric weight"
+           " quantization."
+       )
+     if op_quant_config.weight_tensor_config.block_size <= 0:
+       raise ValueError(
+           "Blockwise quantization must have a positive block size."
+       )
+
+
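For illustration, here is a minimal sketch of how this check might be invoked. The keyword-argument constructor calls are assumptions inferred from the fields this module references (granularity, symmetric, block_size, weight_tensor_config), not the package's documented API:

    from ai_edge_quantizer import qtyping
    from ai_edge_quantizer.algorithms.utils import min_max_quantize_utils

    # Hypothetical blockwise weight config; field names mirror those used
    # above, but the exact constructor signatures are assumptions.
    weight_config = qtyping.TensorQuantizationConfig(
        num_bits=4,
        symmetric=True,
        granularity=qtyping.QuantGranularity.BLOCKWISE,
        block_size=32,
    )
    op_config = qtyping.OpQuantizationConfig(weight_tensor_config=weight_config)

    # FULLY_CONNECTED is in _SUPPORTED_SUBCHANNEL_OPS, so this passes; an
    # unsupported op, a set activation config, asymmetric weights, or a
    # non-positive block size would each raise ValueError.
    min_max_quantize_utils.check_subchannel_config(
        qtyping.TFLOperationName.FULLY_CONNECTED, op_config
    )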
+ def check_if_valid_op_config(
+     op_name: _TFLOpName,
+     op_quant_config: qtyping.OpQuantizationConfig,
+     config_check_policy: qtyping.ConfigCheckPolicyDict,
+ ) -> None:
+   """Check if the op quantization config is valid.
+
+   Args:
+     op_name: The name of the op.
+     op_quant_config: The quantization config for the op.
+     config_check_policy: The policy to check the op quantization config.
+
+   Raises:
+     ValueError: If the op quantization config is not valid.
+   """
+
+   check_passed = False
+   error_msg = ""
+   # Check if op_quant_config is present in config_check_policy.
+   if config_check_policy is None:
+     error_msg = "No policy was specified at all."
+   elif op_name not in config_check_policy.keys():
+     error_msg = (
+         f"No policy was specified for op: {op_name} with config:"
+         f" {op_quant_config}."
+     )
+   # The config_check_policy contains all possible valid configs, except for
+   # variations in the min_weight_elements field (it's set to 0 for all of
+   # them). min_weight_elements has to be ignored during the policy check here
+   # because it can be any non-negative integer, which means we can't list all
+   # possible values in the policy.
+   elif (
+       dataclasses.replace(op_quant_config, min_weight_elements=0)
+       not in config_check_policy[op_name]
+   ):
+     error_msg = (
+         f"Quantization config for op: {op_name} with config:"
+         f" {op_quant_config} was not found in the policy."
+     )
+   else:
+     check_passed = True
+
+   if not check_passed:
+     raise ValueError(
+         f"Unsupported op for {op_quant_config.compute_precision}: {op_name}."
+         f" Error: {error_msg}"
+     )
+
+
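The membership test above relies on dataclasses.replace to zero out min_weight_elements before comparing against the policy set. A self-contained sketch of that normalization pattern, using a stand-in dataclass rather than the real qtyping.OpQuantizationConfig:

    import dataclasses

    @dataclasses.dataclass(frozen=True)  # Frozen, so instances are hashable.
    class FakeOpConfig:  # Hypothetical stand-in for illustration only.
      num_bits: int
      min_weight_elements: int

    policy = {FakeOpConfig(num_bits=8, min_weight_elements=0)}
    candidate = FakeOpConfig(num_bits=8, min_weight_elements=1024)

    # Direct membership fails because min_weight_elements differs...
    assert candidate not in policy
    # ...but zeroing the free-form field first finds the match.
    assert dataclasses.replace(candidate, min_weight_elements=0) in policy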
+ class OpQuantConstraint(enum.Enum):
+   """Quantization constraint for an op."""
+
+   NO_CONSTRAIN = 0
+   # All tensors in the op have the same scale as the input tensor,
+   # e.g., transpose/reshape/split.
+   SAME_AS_INPUT_SCALE = 1
+   # All tensors in the op have the same scale as the output tensor,
+   # e.g., concatenate.
+   SAME_AS_OUTPUT_SCALE = 2
+
+
+ def init_tensor_min_max(
+     tensor: Any,
+     graph_info: qtyping.GraphInfo,
+     op_info: qtyping.OpInfo,
+ ):
+   """Initialize the min/max for a tensor."""
+   tensor_data = tfl_flatbuffer_utils.get_tensor_data(tensor, graph_info.buffers)
+   # Non-constant tensors get their min/max from calibration, so there are no
+   # initial values to compute here.
+   if tensor_data is None:
+     return {}
+   # Real min/max for constant tensors.
+   else:
+     quantized_dim = None
+     if (
+         op_info.op_quant_config.weight_tensor_config is not None
+         and op_info.op_quant_config.weight_tensor_config.granularity
+         == qtyping.QuantGranularity.BLOCKWISE
+     ):
+       # TODO(b/346612503): emulated subchannel only supports fully connected,
+       # so special handling is skipped here. Once we have a spec, we can
+       # change this.
+       block_size = op_info.op_quant_config.weight_tensor_config.block_size
+       # Assuming the tensor is 2D, which holds for FULLY_CONNECTED.
+       transposed_tensor_data = np.transpose(tensor_data, (1, 0))
+       if transposed_tensor_data.shape[0] % block_size:
+         raise ValueError(
+             f"Block size {block_size} does not divide channel dimension"
+             f" {transposed_tensor_data.shape[0]}."
+         )
+       reshaped_tensor_data = np.reshape(
+           transposed_tensor_data,
+           (
+               1,
+               int(transposed_tensor_data.shape[0] / block_size),
+               block_size,
+               transposed_tensor_data.shape[1],
+           ),
+       )
+       return {
+           "min": np.min(reshaped_tensor_data, axis=(0, 1, 2), keepdims=True),
+           "max": np.max(reshaped_tensor_data, axis=(0, 1, 2), keepdims=True),
+       }
+     if (
+         op_info.op_quant_config.weight_tensor_config is not None
+         and op_info.op_quant_config.weight_tensor_config.granularity
+         == qtyping.QuantGranularity.CHANNELWISE
+     ):
+       if op_info.op_name == _TFLOpName.BATCH_MATMUL:
+         quantized_dim = _get_bmm_weight_quantized_dim(
+             tensor_data, adj_y=op_info.op.builtinOptions.adjY
+         )
+       else:
+         quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM.get(
+             op_info.op_name, None
+         )
+     reduce_dims = _get_reduce_dims(quantized_dim, tensor.shape)
+     return {
+         "min": np.min(tensor_data, axis=reduce_dims, keepdims=True),
+         "max": np.max(tensor_data, axis=reduce_dims, keepdims=True),
+     }
+
+
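To make the granularity handling concrete, a small NumPy sketch of the channelwise case: per-output-channel min/max for a 2D weight quantized along dimension 0, mirroring what _get_reduce_dims plus np.min/np.max with keepdims=True compute:

    import numpy as np

    weight = np.array([[1.0, -4.0, 2.0],
                       [0.5, 3.0, -1.0]])  # Shape (2, 3).
    quantized_dim = 0  # Keep dim 0; reduce over every other dim.
    reduce_dims = tuple(i for i in range(weight.ndim) if i != quantized_dim)

    # keepdims=True preserves rank so the results broadcast against weight.
    w_min = np.min(weight, axis=reduce_dims, keepdims=True)  # [[-4.], [-1.]]
    w_max = np.max(weight, axis=reduce_dims, keepdims=True)  # [[2.], [3.]]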
+ def _get_tensor_transformation_params_wrapper(
+     tensor: Any,
+     is_inbounding_tensor: bool,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+     quant_params=None,
+ ) -> qtyping.TensorTransformationParams:
+   """Util to get tensor transformation params.
+
+   Args:
+     tensor: Tensor to be quantized.
+     is_inbounding_tensor: Whether the tensor is an inbounding tensor for the
+       op.
+     op_info: Aggregated information about the op (e.g., quantization config).
+     graph_info: Graph information needed to perform quantization for the op.
+     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+     quant_params: Quantization parameters for the tensor.
+
+   Returns:
+     Transformation parameters for the tensor.
+
+   Raises:
+     ValueError: If the tensor is not found in tensor_name_to_qsv.
+   """
+   tensor_name = tfl_flatbuffer_utils.get_tensor_name(tensor)
+   tensor_data = tfl_flatbuffer_utils.get_tensor_data(tensor, graph_info.buffers)
+   tensor_quant_config = op_info.op_quant_config.activation_tensor_config
+   is_constant = tensor_data is not None
+   # Use weight configuration if it is supported.
+   if is_constant and op_info.op_name in frozenset.union(
+       _SUPPORTED_WEIGHT_ONLY_OPS, _SUPPORTED_DRQ_OPS
+   ):
+     tensor_quant_config = op_info.op_quant_config.weight_tensor_config
+   # Get quant params.
+   if quant_params is None and tensor_quant_config is not None:
+     if tensor_name not in tensor_name_to_qsv:
+       if is_constant:
+         # We need min/max to calculate quantization parameters, which
+         # should be collected during the calibration process. However,
+         # weight-only and DRQ do not require calibration, thus it is
+         # possible that this information is missing here. In that case we
+         # collect min/max on the spot.
+         tensor_min_max = init_tensor_min_max(
+             tensor,
+             graph_info,
+             op_info,
+         )
+       else:
+         raise ValueError(
+             f"Tensor {tensor_name} not found in tensor_name_to_qsv. Check"
+             " if the correct calibration results are passed into the"
+             " ParamsGenerator."
+         )
+     else:
+       tensor_min_max = tensor_name_to_qsv[tensor_name]
+     quant_params = _get_tensor_quant_params(
+         op_info,
+         tensor_min_max,
+         tensor_quant_config,
+         tensor_content=tensor_data,
+     )
+   return get_tensor_transformation_params(
+       tensor_name,
+       op_info,
+       is_inbounding_tensor,
+       quant_params,
+       is_constant,
+   )
+
+
+ def _materialize_op_tensors(
+     op_tensor_params: list[qtyping.TensorTransformationParams],
+     op_tensors: Sequence[Any],
+     is_inbounding_tensor: bool,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+     quant_params=None,
+ ) -> None:
+   """Util to materialize op tensors. Appends the results to op_tensor_params.
+
+   Args:
+     op_tensor_params: Tensor transformation parameters for the op. Will be
+       modified to include new tensor parameters.
+     op_tensors: Tensors associated with the op.
+     is_inbounding_tensor: Whether the tensor is an inbounding tensor for the
+       op.
+     op_info: Aggregated information about the op (e.g., quantization config).
+     graph_info: Graph information needed to perform quantization for the op.
+     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+     quant_params: Quantization parameters for the tensor.
+   """
+   for tensor in op_tensors:
+     tensor_params = _get_tensor_transformation_params_wrapper(
+         tensor,
+         is_inbounding_tensor,
+         op_info,
+         graph_info,
+         tensor_name_to_qsv,
+         quant_params,
+     )
+     op_tensor_params.append(tensor_params)
+
+
+ def _get_single_tensor_params(
+     tensors: Sequence[Any],
+     is_inbounding_tensor: bool,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> qtyping.TensorTransformationParams:
+   """Util to get single tensor params.
+
+   Args:
+     tensors: A list of a single tensor.
+     is_inbounding_tensor: Whether the tensor is an inbounding tensor for the
+       op.
+     op_info: Aggregated information about the op (e.g., quantization config).
+     graph_info: Graph information needed to perform quantization for the op.
+     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+
+   Returns:
+     Transformation parameters for the tensor.
+
+   Raises:
+     ValueError: If the tensor list is not of size 1.
+   """
+   if len(tensors) != 1:
+     raise ValueError(
+         "Trying to get params for a single tensor but got a list of"
+         f" {len(tensors)} tensors."
+     )
+   return _get_tensor_transformation_params_wrapper(
+       tensors[0],
+       is_inbounding_tensor,
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+   )
+
+
+ def _materialize_standard_op_with_same_as_input_scale(
+     input_tensors: Sequence[Any],
+     output_tensors: Sequence[Any],
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in an op with the same-as-input-scale constraint.
+
+   Args:
+     input_tensors: Input tensors for the op.
+     output_tensors: Output tensors for the op.
+     op_info: Aggregated information about the op (e.g., quantization config).
+     graph_info: Graph information needed to perform quantization for the op.
+     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+
+   Returns:
+     Quantization configuration for the tensors associated with the op (e.g.,
+     weights, bias).
+   """
+   op_tensor_params = []
+   # Must be a single input to avoid ambiguity.
+   input_tensor_params = _get_single_tensor_params(
+       input_tensors,
+       is_inbounding_tensor=True,
+       op_info=op_info,
+       graph_info=graph_info,
+       tensor_name_to_qsv=tensor_name_to_qsv,
+   )
+   op_tensor_params.append(input_tensor_params)
+   # Use input quantization params for all output tensors.
+   _materialize_op_tensors(
+       op_tensor_params,
+       output_tensors,
+       is_inbounding_tensor=False,
+       op_info=op_info,
+       graph_info=graph_info,
+       tensor_name_to_qsv=tensor_name_to_qsv,
+       quant_params=input_tensor_params.consumers[0].parameters,
+   )
+   # Change output qsv to be the same as input qsv. This is safe since a TFL
+   # subgraph is acyclic.
+   input_tensor_qsv = tensor_name_to_qsv[input_tensor_params.tensor_name]
+   for output_tensor in output_tensors:
+     tensor_name_to_qsv[tfl_flatbuffer_utils.get_tensor_name(output_tensor)] = (
+         input_tensor_qsv
+     )
+
+   return op_tensor_params
+
+
+ def _materialize_standard_op_with_same_as_output_scale(
+     input_tensors: Sequence[Any],
+     output_tensors: Sequence[Any],
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in an op with the same-as-output-scale constraint.
+
+   Args:
+     input_tensors: Input tensors for the op.
+     output_tensors: Output tensors for the op.
+     op_info: Aggregated information about the op (e.g., quantization config).
+     graph_info: Graph information needed to perform quantization for the op.
+     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+
+   Returns:
+     Quantization configuration for the tensors associated with the op (e.g.,
+     weights, bias).
+   """
+   op_tensor_params = []
+   # Must be a single output to avoid ambiguity.
+   output_tensor_params = _get_single_tensor_params(
+       output_tensors,
+       is_inbounding_tensor=False,
+       op_info=op_info,
+       graph_info=graph_info,
+       tensor_name_to_qsv=tensor_name_to_qsv,
+   )
+   # Use output quantization params for all input tensors.
+   if output_tensor_params.producer is None:
+     quant_params = None
+   else:
+     quant_params = output_tensor_params.producer.parameters
+   _materialize_op_tensors(
+       op_tensor_params,
+       input_tensors,
+       is_inbounding_tensor=True,
+       op_info=op_info,
+       graph_info=graph_info,
+       tensor_name_to_qsv=tensor_name_to_qsv,
+       quant_params=quant_params,
+   )
+   op_tensor_params.append(output_tensor_params)
+
+   return op_tensor_params
+
+
+ def _materialize_standard_op_no_constraint(
+     input_tensors: Sequence[Any],
+     output_tensors: Sequence[Any],
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in an op with no constraint.
+
+   Args:
+     input_tensors: Input tensors for the op.
+     output_tensors: Output tensors for the op.
+     op_info: Aggregated information about the op (e.g., quantization config).
+     graph_info: Graph information needed to perform quantization for the op.
+     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+
+   Returns:
+     Quantization configuration for the tensors associated with the op (e.g.,
+     weights, bias).
+   """
+   op_tensor_params = []
+   _materialize_op_tensors(
+       op_tensor_params,
+       input_tensors,
+       is_inbounding_tensor=True,
+       op_info=op_info,
+       graph_info=graph_info,
+       tensor_name_to_qsv=tensor_name_to_qsv,
+   )
+   _materialize_op_tensors(
+       op_tensor_params,
+       output_tensors,
+       is_inbounding_tensor=False,
+       op_info=op_info,
+       graph_info=graph_info,
+       tensor_name_to_qsv=tensor_name_to_qsv,
+   )
+
+   return op_tensor_params
+
+
+ def _split_tensors_by_indices(
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     indices: Optional[Sequence[int]],
+     is_inbounding_tensor: bool,
+ ) -> tuple[list[Any], list[Any], list[int]]:
+   """Split tensors into two lists and return indices with -1 values removed.
+
+   This function splits the tensors into two lists based on the provided
+   indices.
+   * The first list contains tensors with indices in the provided indices list.
+   * The second list contains all remaining tensors.
+
+   Additionally, the function filters out any tensors with the index -1
+   (indicating a non-existing bias in FC and conv ops) and returns a new list
+   of indices excluding these values.
+
+   Args:
+     op_info: Aggregated information about the op (e.g., quantization config).
+     graph_info: Graph information needed to perform quantization for the op.
+     indices: Indices of tensors to use for the split.
+     is_inbounding_tensor: Whether the tensors are inbounding tensors for the
+       op.
+
+   Returns:
+     A tuple containing:
+       * A list of tensors with indices in the provided list.
+       * A list of tensors with indices not in the provided list.
+       * A new list of indices with -1 values removed.
+   """
+   indices = indices or []
+   updated_indices = []
+   updated_index = 0
+   selected_tensors, others = [], []
+   tensors = op_info.op.inputs if is_inbounding_tensor else op_info.op.outputs
+   for i, tensor_index in enumerate(tensors):
+     # Ignore non-existing tensors.
+     if tensor_index == -1:
+       continue
+     if i in indices:
+       updated_indices.append(updated_index)
+       selected_tensors.append(graph_info.subgraph_tensors[tensor_index])
+     else:
+       others.append(graph_info.subgraph_tensors[tensor_index])
+     updated_index += 1
+
+   return selected_tensors, others, updated_indices
+
+
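A worked example of the index bookkeeping above, using plain lists of tensor indices in place of the op/graph structures. Index -1 marks a missing optional tensor (such as an omitted bias), and the returned indices are re-based onto the filtered tensor list:

    # Op input tensor indices into the subgraph; position 2 is a missing bias.
    op_inputs = [10, 7, -1, 3]
    indices_to_select = [0, 2]  # Positions within op_inputs.

    selected, others, updated_indices = [], [], []
    updated_index = 0
    for i, tensor_index in enumerate(op_inputs):
      if tensor_index == -1:  # Non-existing tensors are skipped entirely.
        continue
      if i in indices_to_select:
        updated_indices.append(updated_index)
        selected.append(tensor_index)
      else:
        others.append(tensor_index)
      updated_index += 1

    assert selected == [10]        # Position 2 was -1, so it was dropped.
    assert others == [7, 3]
    assert updated_indices == [0]  # Re-based onto the surviving tensors.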
+ def _materialize_ignored_tensors(
+     tensors: Sequence[Any],
+     op_info: qtyping.OpInfo,
+     is_inbounding_tensor: bool,
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize ignored tensors.
+
+   Args:
+     tensors: Tensors to ignore.
+     op_info: Aggregated information about the op (e.g., quantization config).
+     is_inbounding_tensor: Whether the tensors are the inbounding tensors for
+       the op.
+
+   Returns:
+     A list of tensor transformation params for the ignored tensors.
+   """
+   op_ignored_tensor_params = []
+   for tensor in tensors:
+     tensor_name = tfl_flatbuffer_utils.get_tensor_name(tensor)
+     no_quant_tensor_params = qtyping.OpToTensorParams(
+         subgraph_op_id=op_info.subgraph_op_index,
+         transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
+     )
+     if is_inbounding_tensor:
+       tensor_params = qtyping.TensorTransformationParams(
+           tensor_name=tensor_name,
+           consumers=[no_quant_tensor_params],
+       )
+     else:
+       tensor_params = qtyping.TensorTransformationParams(
+           tensor_name=tensor_name,
+           producer=no_quant_tensor_params,
+       )
+     op_ignored_tensor_params.append(tensor_params)
+
+   return op_ignored_tensor_params
+
+
+ def _merge_materialized_tensors(
+     tensor_params: list[qtyping.TensorTransformationParams],
+     ignored_input_tensor_params: Sequence[qtyping.TensorTransformationParams],
+     ignored_output_tensor_params: Sequence[qtyping.TensorTransformationParams],
+     op_info: qtyping.OpInfo,
+     inputs_to_ignore: Sequence[int],
+     outputs_to_ignore: Sequence[int],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Merge materialized tensors.
+
+   Merges tensor transformation parameters for non-ignored and ignored
+   tensors. The resulting list keeps the original order of input and output
+   tensors in the op.
+
+   Args:
+     tensor_params: Tensor transformation params for non-ignored tensors in
+       the op.
+     ignored_input_tensor_params: Tensor transformation params for the ignored
+       input tensors.
+     ignored_output_tensor_params: Tensor transformation params for the
+       ignored output tensors.
+     op_info: Aggregated information about the op (e.g., quantization config).
+     inputs_to_ignore: Input tensor indices to ignore.
+     outputs_to_ignore: Output tensor indices to ignore.
+
+   Returns:
+     Full list of transformation params for the op.
+   """
+   if not inputs_to_ignore and not outputs_to_ignore:
+     return tensor_params
+
+   result_tensor_params = []
+   num_inputs = len([x for x in op_info.op.inputs if x != -1])
+   num_outputs = len([x for x in op_info.op.outputs if x != -1])
+
+   # Add input tensors.
+   if inputs_to_ignore:
+     input_idx, ignored_input_idx = 0, 0
+     for i in range(num_inputs):
+       if i in inputs_to_ignore:
+         result_tensor_params.append(
+             ignored_input_tensor_params[ignored_input_idx]
+         )
+         ignored_input_idx += 1
+       else:
+         result_tensor_params.append(tensor_params[input_idx])
+         input_idx += 1
+   else:
+     result_tensor_params.extend(tensor_params[:num_inputs])
+
+   # Add output tensors.
+   output_start_idx = num_inputs - len(inputs_to_ignore)
+   if outputs_to_ignore:
+     output_idx, ignored_output_idx = output_start_idx, 0
+     for i in range(num_outputs):
+       if i in outputs_to_ignore:
+         result_tensor_params.append(
+             ignored_output_tensor_params[ignored_output_idx]
+         )
+         ignored_output_idx += 1
+       else:
+         result_tensor_params.append(tensor_params[output_idx])
+         output_idx += 1
+   else:
+     result_tensor_params.extend(tensor_params[output_start_idx:])
+
+   return result_tensor_params
+
+
+ def _tensor_indices_with_dtype(
+     tensors: Sequence[int],
+     subgraph_tensors: Sequence[schema_py_generated.TensorT],
+     tensor_dtype_codes: Sequence[int],
+ ) -> list[int]:
+   """Get the indices of tensors with any of the given dtypes.
+
+   Args:
+     tensors: A list of tensor indices from the subgraph.
+     subgraph_tensors: A list of tensors in the subgraph.
+     tensor_dtype_codes: A list of tensor dtype codes.
+
+   Returns:
+     A list of indices of tensors with the given dtypes.
+   """
+   selected_indices = []
+   for i, tensor_index in enumerate(tensors):
+     tensor = subgraph_tensors[tensor_index]
+     if tensor.type in tensor_dtype_codes:
+       selected_indices.append(i)
+   return selected_indices
+
+
+ def _add_non_match_tensors_to_ignored_lists(
+     op: schema_py_generated.OperatorT,
+     subgraph_tensors: Sequence[schema_py_generated.TensorT],
+     dtypes_to_keep: Sequence[int],
+     inputs_to_ignore: Sequence[int],
+     outputs_to_ignore: Sequence[int],
+ ) -> tuple[list[int], list[int]]:
+   """Add tensor indices with dtypes other than the specified ones to the ignored lists.
+
+   Args:
+     op: The op to be processed.
+     subgraph_tensors: A list of tensors in the subgraph.
+     dtypes_to_keep: A list of tensor dtype codes that need to be kept (not in
+       the ignored lists).
+     inputs_to_ignore: Input tensor indices to ignore.
+     outputs_to_ignore: Output tensor indices to ignore.
+
+   Returns:
+     A tuple of updated inputs_to_ignore and outputs_to_ignore.
+   """
+   input_indices = set(range(len(op.inputs)))
+   inputs_to_keep = set(
+       _tensor_indices_with_dtype(op.inputs, subgraph_tensors, dtypes_to_keep)
+   )
+   inputs_to_keep -= set(inputs_to_ignore)  # Remove already ignored tensors.
+   inputs_to_ignore = list(input_indices - inputs_to_keep)
+
+   output_indices = set(range(len(op.outputs)))
+   outputs_to_keep = set(
+       _tensor_indices_with_dtype(op.outputs, subgraph_tensors, dtypes_to_keep)
+   )
+   outputs_to_keep -= set(outputs_to_ignore)  # Remove already ignored tensors.
+   outputs_to_ignore = list(output_indices - outputs_to_keep)
+   return inputs_to_ignore, outputs_to_ignore
+
+
+ def materialize_standard_op(
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+     constraint: OpQuantConstraint = OpQuantConstraint.NO_CONSTRAIN,
+     inputs_to_ignore: Optional[Sequence[int]] = None,
+     outputs_to_ignore: Optional[Sequence[int]] = None,
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Default materialization function for an op.
+
+   Use materialize_fc_conv as the entry point to materialize FULLY_CONNECTED,
+   CONV_2D, and DEPTHWISE_CONV_2D, as these ops may contain a fused bias.
+
+   Args:
+     op_info: Aggregated information about the op (e.g., quantization config).
+     graph_info: Graph information needed to perform quantization for the op.
+     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+     constraint: The constraint for materializing the op.
+     inputs_to_ignore: Input tensor indices to ignore.
+     outputs_to_ignore: Output tensor indices to ignore.
+
+   Returns:
+     Quantization configuration for the tensors associated with the op (e.g.,
+     weights, bias). The returned list has the structure:
+     [input_tensor_0_params, ..., input_tensor_n_params,
+     output_tensor_0_params, ..., output_tensor_m_params].
+   """
+   inputs_to_ignore = inputs_to_ignore or []
+   outputs_to_ignore = outputs_to_ignore or []
+   # Filter out non-fp32 tensors (e.g., int32 indices).
+   fp32_type_code = 0  # See schema_py_generated.py for type codes.
+   inputs_to_ignore, outputs_to_ignore = _add_non_match_tensors_to_ignored_lists(
+       op_info.op,
+       graph_info.subgraph_tensors,
+       [fp32_type_code],
+       inputs_to_ignore,
+       outputs_to_ignore,
+   )
+
+   # Process op inputs and outputs.
+   ignored_input_tensors, input_tensors, inputs_to_ignore = (
+       _split_tensors_by_indices(
+           op_info, graph_info, inputs_to_ignore, is_inbounding_tensor=True
+       )
+   )
+   ignored_output_tensors, output_tensors, outputs_to_ignore = (
+       _split_tensors_by_indices(
+           op_info, graph_info, outputs_to_ignore, is_inbounding_tensor=False
+       )
+   )
+   # Materialize op tensors.
+   if not input_tensors and not output_tensors:
+     tensor_params = []  # Every tensor is ignored.
+   elif constraint == OpQuantConstraint.SAME_AS_INPUT_SCALE:
+     tensor_params = _materialize_standard_op_with_same_as_input_scale(
+         input_tensors, output_tensors, op_info, graph_info, tensor_name_to_qsv
+     )
+   elif constraint == OpQuantConstraint.SAME_AS_OUTPUT_SCALE:
+     tensor_params = _materialize_standard_op_with_same_as_output_scale(
+         input_tensors, output_tensors, op_info, graph_info, tensor_name_to_qsv
+     )
+   else:
+     tensor_params = _materialize_standard_op_no_constraint(
+         input_tensors, output_tensors, op_info, graph_info, tensor_name_to_qsv
+     )
+
+   # Materialize ignored tensors.
+   ignored_input_tensor_params = _materialize_ignored_tensors(
+       ignored_input_tensors, op_info, is_inbounding_tensor=True
+   )
+   ignored_output_tensor_params = _materialize_ignored_tensors(
+       ignored_output_tensors, op_info, is_inbounding_tensor=False
+   )
+   # Combine all tensor params, keeping the original order.
+   return _merge_materialized_tensors(
+       tensor_params,
+       ignored_input_tensor_params,
+       ignored_output_tensor_params,
+       op_info,
+       inputs_to_ignore,
+       outputs_to_ignore,
+   )
+
+
+ def materialize_op_with_output_activation_constraint(
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+     output_activation_constraints: dict[int, qtyping.UniformQuantParams],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in an op with an output activation constraint.
+
+   The output activation constraints are used to explicitly set
+   quantization parameters for the output tensor when doing SRQ.
+
+   This function assumes that the op has only one output tensor.
+
+   Args:
+     op_info: Aggregated information about the op (e.g., quantization config).
+     graph_info: Graph information needed to perform quantization for the op.
+     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+     output_activation_constraints: A map of output activation num bits to
+       quantization parameters.
+
+   Returns:
+     Quantization configuration for the tensors associated with the op (e.g.,
+     weights, bias).
+
+   Raises:
+     ValueError: If the op has more than one output tensor, or if the output
+       activation constraints dictionary does not contain an entry for the
+       activation num bits specified in the op quantization config.
+   """
+   if len(op_info.op.outputs) != 1:
+     raise ValueError(
+         "Materialize op with output activation constraint only supports ops"
+         " with a single output tensor."
+     )
+
+   tensor_params = materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+   )
+   output_tensor_params = tensor_params[-1]
+
+   # Explicitly set quantization parameters for the output tensor when doing
+   # SRQ.
+   activation_tensor_config = op_info.op_quant_config.activation_tensor_config
+   if (
+       activation_tensor_config is not None
+       and output_tensor_params.producer is not None
+   ):
+     activation_num_bits = activation_tensor_config.num_bits
+     if activation_num_bits not in output_activation_constraints:
+       raise ValueError(
+           "Output activation constraints dictionary does not contain an"
+           f" entry for activation num bits {activation_num_bits}."
+       )
+     fixed_quant_params = output_activation_constraints[activation_num_bits]
+     op_tensor_params = qtyping.OpToTensorParams(
+         subgraph_op_id=output_tensor_params.producer.subgraph_op_id,
+         transformations=output_tensor_params.producer.transformations,
+         parameters=fixed_quant_params,
+     )
+     output_tensor_params.producer = op_tensor_params
+     # Update the tensor_name_to_qsv map using the output activation
+     # constraints.
+     min_val, max_val = _get_min_max_from_quant_params(
+         activation_num_bits,
+         activation_tensor_config.symmetric,
+         fixed_quant_params,
+     )
+     tensor_name_to_qsv[output_tensor_params.tensor_name]["min"] = min_val
+     tensor_name_to_qsv[output_tensor_params.tensor_name]["max"] = max_val
+
+   return tensor_params
+
+
+ def get_tensor_transformations(
+     op_quant_config: qtyping.OpQuantizationConfig,
+     is_inbounding_tensor: bool,
+     is_constant: bool,
+ ):
+   """Get the transformations for the tensor.
+
+   Args:
+     op_quant_config: The quantization config for the op.
+     is_inbounding_tensor: Whether the tensor is an inbounding tensor for the
+       op.
+     is_constant: Whether the tensor is a constant tensor.
+
+   Returns:
+     The transformations for the tensor.
+   """
+   transformations = []
+   # Check if SRQ.
+   if (
+       op_quant_config.compute_precision == qtyping.ComputePrecision.INTEGER
+       and op_quant_config.activation_tensor_config is not None
+   ):
+     if is_inbounding_tensor:
+       transformations = [_QuantTransformation.ADD_QUANTIZE]
+       if is_constant:
+         # Quantize the constant tensor directly to simplify downstream
+         # optimizations.
+         transformations = [_QuantTransformation.QUANTIZE_TENSOR]
+     else:
+       transformations = [_QuantTransformation.ADD_DEQUANTIZE]
+   # Check if DRQ.
+   elif (
+       op_quant_config.compute_precision == qtyping.ComputePrecision.INTEGER
+       and op_quant_config.activation_tensor_config is None
+   ):
+     if is_inbounding_tensor and is_constant:
+       transformations = [_QuantTransformation.QUANTIZE_TENSOR]
+     else:
+       transformations = [_QuantTransformation.NO_QUANTIZE]
+   elif (
+       op_quant_config.weight_tensor_config is not None
+       and op_quant_config.weight_tensor_config.granularity
+       == qtyping.QuantGranularity.BLOCKWISE
+       and is_constant
+   ):
+     transformations = [_QuantTransformation.EMULATED_SUBCHANNEL]
+   # Check if WEIGHT_ONLY.
+   elif (
+       op_quant_config.compute_precision == qtyping.ComputePrecision.FLOAT
+       and op_quant_config.explicit_dequantize
+   ):
+     if is_inbounding_tensor and is_constant:
+       # ADD_DEQUANTIZE always carries quantization parameters, so
+       # [ADD_DEQUANTIZE] is equivalent to the [QUANTIZE_TENSOR, ADD_DEQUANTIZE]
+       # downstream pattern: quantized_tensor -> dequantize op -> float_tensor.
+       transformations = [_QuantTransformation.ADD_DEQUANTIZE]
+     else:
+       transformations = [_QuantTransformation.NO_QUANTIZE]
+   else:
+     raise ValueError(
+         f"Unsupported compute precision: {op_quant_config.compute_precision}"
+     )
+   return transformations
+
+
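Restating the branch logic above as a lookup (SRQ = static-range quantization: integer compute with an activation config; DRQ = dynamic-range: integer compute without one; the blockwise and weight-only branches are only reached when neither integer branch matches). The dict below is purely a summary of the code, keyed by (mode, is_inbounding_tensor, is_constant):

    # Derived directly from get_tensor_transformations; for illustration only.
    EXPECTED_TRANSFORMATIONS = {
        ("SRQ", True, False): ["ADD_QUANTIZE"],
        ("SRQ", True, True): ["QUANTIZE_TENSOR"],
        ("SRQ", False, False): ["ADD_DEQUANTIZE"],
        ("DRQ", True, True): ["QUANTIZE_TENSOR"],
        ("DRQ", True, False): ["NO_QUANTIZE"],
        ("DRQ", False, False): ["NO_QUANTIZE"],
        ("BLOCKWISE_WEIGHT_ONLY", True, True): ["EMULATED_SUBCHANNEL"],
        ("WEIGHT_ONLY", True, True): ["ADD_DEQUANTIZE"],
        ("WEIGHT_ONLY", True, False): ["NO_QUANTIZE"],
        ("WEIGHT_ONLY", False, False): ["NO_QUANTIZE"],
    }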
+ def get_tensor_transformation_params(
+     tensor_name: str,
+     op_info: qtyping.OpInfo,
+     is_inbounding_tensor: bool,
+     quant_params: Optional[qtyping.UniformQuantParams] = None,
+     is_constant: bool = False,
+ ) -> qtyping.TensorTransformationParams:
+   """Transformation params for the op's tensor.
+
+   Args:
+     tensor_name: The name of the tensor.
+     op_info: Aggregated information about the op (e.g., quantization config).
+     is_inbounding_tensor: Whether the tensor is an inbounding tensor to the
+       op.
+     quant_params: The quantization parameters for the tensor.
+     is_constant: Whether the tensor is a constant tensor.
+
+   Returns:
+     The transformation params for the op's tensor.
+   """
+   transformations = get_tensor_transformations(
+       op_info.op_quant_config, is_inbounding_tensor, is_constant
+   )
+   op2tensor_params = qtyping.OpToTensorParams(
+       subgraph_op_id=op_info.subgraph_op_index,
+       parameters=quant_params,
+       transformations=transformations,
+   )
+   if is_inbounding_tensor:
+     return qtyping.TensorTransformationParams(
+         tensor_name=tensor_name,
+         consumers=[op2tensor_params],
+     )
+   return qtyping.TensorTransformationParams(
+       tensor_name=tensor_name,
+       producer=op2tensor_params,
+   )
+
+
+ def _get_tensor_quant_params(
+     op_info: qtyping.OpInfo,
+     tensor_min_max: dict[str, Any],
+     tensor_quant_config: qtyping.TensorQuantizationConfig,
+     tensor_content: Optional[np.ndarray] = None,
+ ) -> qtyping.UniformQuantParams:
+   """Get the quantization parameters for a tensor.
+
+   Args:
+     op_info: Aggregated information about the op (e.g., quantization config).
+     tensor_min_max: The min/max of the tensor.
+     tensor_quant_config: The quantization config for the tensor.
+     tensor_content: The content of the tensor.
+
+   Returns:
+     The quantization parameters for the tensor.
+   """
+   if "min" not in tensor_min_max or "max" not in tensor_min_max:
+     raise ValueError(
+         "min and max must be provided to produce tensor quantization"
+         " parameters. Check if the correct calibration results are passed"
+         " into the ParamsGenerator."
+     )
+   zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
+       tensor_min_max["min"],
+       tensor_min_max["max"],
+       tensor_quant_config.num_bits,
+       tensor_quant_config.symmetric,
+   )
+   quantized_dim = None
+   if tensor_quant_config.granularity == qtyping.QuantGranularity.CHANNELWISE:
+     if op_info.op_name == _TFLOpName.BATCH_MATMUL:
+       quantized_dim = _get_bmm_weight_quantized_dim(
+           tensor_content, adj_y=op_info.op.builtinOptions.adjY
+       )
+     else:
+       quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM[
+           op_info.op_name
+       ]
+   quant_params = qtyping.UniformQuantParams(
+       scale=scale,
+       zero_point=zp,
+       num_bits=tensor_quant_config.num_bits,
+       symmetric=tensor_quant_config.symmetric,
+       quantized_dimension=quantized_dim,
+   )
+   if tensor_content is None:
+     return quant_params
+   if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
+     quantized_vars = (
+         uniform_quantize_tensor.uniform_quantize_for_emulated_subchannel(
+             tensor_content, quant_params, tensor_quant_config.block_size
+         )
+     )
+   else:
+     quantized_vars = uniform_quantize_tensor.uniform_quantize(
+         tensor_content, quant_params
+     )
+   # Update with quantized values.
+   return qtyping.UniformQuantParams(
+       scale=scale,
+       zero_point=zp,
+       num_bits=tensor_quant_config.num_bits,
+       symmetric=tensor_quant_config.symmetric,
+       quantized_dimension=quantized_dim,
+       quantized_data=quantized_vars,
+   )
+
+
+ def _get_reduce_dims(
+     quantized_dim: Optional[int],
+     tensor_shape: list[int],
+ ) -> Optional[tuple[int, ...]]:
+   """Get the reduce dims of a tensor for the given quantized dimension."""
+   if quantized_dim is None:
+     return None
+   reduce_dims = []
+   for rank_idx in range(len(tensor_shape)):
+     if rank_idx != quantized_dim:
+       reduce_dims.append(rank_idx)
+   return tuple(reduce_dims)
+
+
+ def _get_bmm_weight_quantized_dim(
+     weight_tensor_data: np.ndarray, adj_y: bool
+ ) -> int:
+   """Get the quantized dimension for batch matmul."""
+   rank = len(weight_tensor_data.shape)
+   # If adj_y is true, the weight tensor is transposed.
+   if adj_y:
+     return rank - 2
+   return rank - 1
+
+
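For example, a BATCH_MATMUL weight of shape [batch, k, n] is quantized along the output-channel dimension n, which sits at rank - 1 normally and at rank - 2 when adj_y has transposed the operand to [batch, n, k]:

    import numpy as np

    weight = np.zeros((4, 128, 256))  # [batch, k, n]; rank 3.
    rank = weight.ndim
    assert rank - 1 == 2  # adj_y=False: n is the last dimension.
    assert rank - 2 == 1  # adj_y=True: weight arrives as [batch, n, k].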
+ def _get_min_max_from_quant_params(
+     num_bits: int,
+     symmetric: bool,
+     tensor_params: qtyping.UniformQuantParams,
+ ) -> tuple[float, float]:
+   """Recalculate min/max from tensor quantization params."""
+   q_min, q_max = uniform_quantize_tensor.get_quantized_range(
+       _IntType(num_bits, True)
+   )
+   float_min = uniform_quantize_tensor.uniform_dequantize(
+       np.array(q_min), tensor_params
+   )
+   float_max = uniform_quantize_tensor.uniform_dequantize(
+       np.array(q_max), tensor_params
+   )
+   # We use qmax values to compute scale for symmetric quantization (see
+   # uniform_quantize_tensor.tensor_zp_scale_from_min_max).
+   if symmetric:
+     float_min = -float_max
+   return (float_min, float_max)
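
As a numeric check of the symmetric round trip above, a plain-NumPy sketch (assuming only that uniform_quantize_tensor implements standard uniform affine quantization, with the symmetric scale derived from the max magnitude and qmax):

    import numpy as np

    num_bits, float_max = 8, 3.0
    q_min, q_max = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1  # -128, 127

    scale, zero_point = float_max / q_max, 0  # Symmetric: zero point is 0.

    # Dequantizing the integer endpoints recovers the float range.
    recovered_max = (q_max - zero_point) * scale  # Exactly 3.0.
    recovered_min = (q_min - zero_point) * scale  # About -3.024.
    # Symmetric ranges are then forced to be mirrored, matching the final
    # lines of _get_min_max_from_quant_params.
    recovered_min = -recovered_max
    assert np.isclose(recovered_max, 3.0)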