ai-edge-quantizer-nightly 0.0.1.dev20250211__py3-none-any.whl → 0.0.1.dev20250212__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +40 -61
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +637 -0
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +74 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +139 -533
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +0 -22
- ai_edge_quantizer/algorithms/utils/{min_max_quantize_utils.py → common_utils.py} +61 -175
- ai_edge_quantizer/algorithms/utils/{min_max_quantize_utils_test.py → common_utils_test.py} +20 -19
- ai_edge_quantizer/qtyping.py +12 -1
- {ai_edge_quantizer_nightly-0.0.1.dev20250211.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20250212.dist-info}/METADATA +1 -1
- {ai_edge_quantizer_nightly-0.0.1.dev20250211.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20250212.dist-info}/RECORD +13 -11
- {ai_edge_quantizer_nightly-0.0.1.dev20250211.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20250212.dist-info}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250211.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20250212.dist-info}/WHEEL +0 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250211.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20250212.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py CHANGED
@@ -19,7 +19,6 @@ from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python.platform import googletest
-from ai_edge_quantizer import default_policy
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
 from ai_edge_quantizer.utils import test_utils
@@ -158,27 +157,6 @@ class NaiveMinMaxQuantizeTest(parameterized.TestCase):
     self.assertNotIn("arith.constant1", op_qsvs)
     self.assertNotIn("arith.constant2", op_qsvs)
 
-  def test_check_op_quantization_config_with_negative_min_weight_elements_raises_error(
-      self,
-  ):
-    op_quant_config = qtyping.OpQuantizationConfig(
-        weight_tensor_config=_TensorQuantConfig(
-            num_bits=8,
-            granularity=qtyping.QuantGranularity.CHANNELWISE,
-        ),
-        compute_precision=qtyping.ComputePrecision.INTEGER,  # DRQ.
-        min_weight_elements=-1,
-    )
-    with self.assertRaisesWithPredicateMatch(
-        ValueError,
-        lambda err: "min_weight_elements must be non-negative" in str(err),
-    ):
-      naive_min_max_quantize.check_op_quantization_config(
-          _TFLOpName.FULLY_CONNECTED,
-          op_quant_config,
-          default_policy.DEFAULT_CONFIG_CHECK_POLICY,
-      )
-
 
 if __name__ == "__main__":
   googletest.main()
ai_edge_quantizer/algorithms/utils/{min_max_quantize_utils.py → common_utils.py} RENAMED
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
-"""
+"""Common utils for uniform quantization algorithms."""
 
 from collections.abc import Sequence
 import dataclasses
@@ -29,16 +29,8 @@ _TFLOpName = qtyping.TFLOperationName
 _QuantTransformation = qtyping.QuantTransformation
 _IntType = uniform_quantize_tensor.IntType
 
-_SUPPORTED_WEIGHT_ONLY_OPS = frozenset([
-    _TFLOpName.FULLY_CONNECTED,
-    _TFLOpName.CONV_2D,
-    _TFLOpName.BATCH_MATMUL,
-    _TFLOpName.EMBEDDING_LOOKUP,
-    _TFLOpName.DEPTHWISE_CONV_2D,
-    _TFLOpName.CONV_2D_TRANSPOSE,
-])
 
-_SUPPORTED_DRQ_OPS = frozenset([
+_DRQ_OR_WEIGHT_ONLY_OPS = frozenset([
     _TFLOpName.FULLY_CONNECTED,
     _TFLOpName.CONV_2D,
     _TFLOpName.BATCH_MATMUL,
@@ -46,6 +38,7 @@ _SUPPORTED_DRQ_OPS = frozenset([
     _TFLOpName.DEPTHWISE_CONV_2D,
     _TFLOpName.CONV_2D_TRANSPOSE,
 ])
+
 _SUPPORTED_SUBCHANNEL_OPS = frozenset([
     _TFLOpName.FULLY_CONNECTED,
 ])
@@ -139,73 +132,13 @@ class OpQuantConstraint(enum.Enum):
   SAME_AS_OUTPUT_SCALE = 2
 
 
-def init_tensor_min_max(
-    tensor: Any,
-    graph_info: qtyping.GraphInfo,
-    op_info: qtyping.OpInfo,
-):
-  """Initialize the min/max for a tensor."""
-  tensor_data = tfl_flatbuffer_utils.get_tensor_data(tensor, graph_info.buffers)
-  # Initial values for non-constant tensors.
-  if tensor_data is None:
-    return {}
-  # Real min/max for constant tensors.
-  else:
-    quantized_dim = None
-    if (
-        op_info.op_quant_config.weight_tensor_config is not None
-        and op_info.op_quant_config.weight_tensor_config.granularity
-        == qtyping.QuantGranularity.BLOCKWISE
-    ):
-      # TODO(b/346612503): emulate subchannel only supports fully connected,
-      # will skip special handling. Once we have a spec, we can change this.
-      block_size = op_info.op_quant_config.weight_tensor_config.block_size
-      # assuming tensor is 2D, which is correct for FULLY_CONNECTED
-      transposed_tensor_data = np.transpose(tensor_data, (1, 0))
-      if transposed_tensor_data.shape[0] % block_size:
-        raise ValueError(
-            f"Block size {block_size} does not divide channel dimension"
-            f" {transposed_tensor_data.shape[0]}."
-        )
-      reshaped_tensor_data = np.reshape(
-          transposed_tensor_data,
-          (
-              1,
-              int(transposed_tensor_data.shape[0] / block_size),
-              block_size,
-              transposed_tensor_data.shape[1],
-          ),
-      )
-      return {
-          "min": np.min(reshaped_tensor_data, axis=(0, 1, 2), keepdims=True),
-          "max": np.max(reshaped_tensor_data, axis=(0, 1, 2), keepdims=True),
-      }
-    if (
-        op_info.op_quant_config.weight_tensor_config is not None
-        and op_info.op_quant_config.weight_tensor_config.granularity
-        == qtyping.QuantGranularity.CHANNELWISE
-    ):
-      if op_info.op_name == _TFLOpName.BATCH_MATMUL:
-        quantized_dim = _get_bmm_weight_quantized_dim(
-            tensor_data, adj_y=op_info.op.builtinOptions.adjY
-        )
-      else:
-        quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM.get(
-            op_info.op_name, None
-        )
-    reduce_dims = _get_reduce_dims(quantized_dim, tensor.shape)
-    return {
-        "min": np.min(tensor_data, axis=reduce_dims, keepdims=True),
-        "max": np.max(tensor_data, axis=reduce_dims, keepdims=True),
-    }
-
-
 def _get_tensor_transformation_params_wrapper(
     tensor: Any,
     is_inbounding_tensor: bool,
     op_info: qtyping.OpInfo,
     graph_info: qtyping.GraphInfo,
     tensor_name_to_qsv: dict[str, Any],
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
     quant_params=None,
 ) -> qtyping.TensorTransformationParams:
   """Util to get tensor transformation params.
@@ -216,6 +149,8 @@ def _get_tensor_transformation_params_wrapper(
     op_info: Aggregated information about the op (e.g., quantization config).
     graph_info: Graph information needed to perform quantization for the op.
     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
     quant_params: Quantization parameters for the tensor.
 
   Returns:
@@ -229,37 +164,15 @@ def _get_tensor_transformation_params_wrapper(
   tensor_quant_config = op_info.op_quant_config.activation_tensor_config
   is_constant = tensor_data is not None
   # Use weight configuration if it is supported.
-  if is_constant and op_info.op_name in
-      _SUPPORTED_WEIGHT_ONLY_OPS, _SUPPORTED_DRQ_OPS
-  ):
+  if is_constant and op_info.op_name in _DRQ_OR_WEIGHT_ONLY_OPS:
     tensor_quant_config = op_info.op_quant_config.weight_tensor_config
   # Get quant params.
   if quant_params is None and tensor_quant_config is not None:
-    if tensor_name not in tensor_name_to_qsv:
-      if is_constant:
-        # We need min/max to calculate quantization parameters, which
-        # should be collected during the calibration process. However,
-        # weight-only and DRQ do not require calibration, thus it is
-        # possible that this information is missing here. In that case we
-        # collect min/max on the spot.
-        tensor_min_max = init_tensor_min_max(
-            tensor,
-            graph_info,
-            op_info,
-        )
-      else:
-        raise ValueError(
-            f"Tensor {tensor_name} not found in tensor_name_to_qsv. Check"
-            " if the correct calibration results are passed into the"
-            " ParamsGenerator."
-        )
-    else:
-      tensor_min_max = tensor_name_to_qsv[tensor_name]
-    quant_params = _get_tensor_quant_params(
+    quant_params = get_tensor_quant_params_fn(
         op_info,
-        tensor_min_max,
         tensor_quant_config,
-        tensor_content=tensor_data,
+        tensor_data,
+        tensor_name_to_qsv.get(tensor_name),
     )
   return get_tensor_transformation_params(
       tensor_name,
@@ -277,6 +190,7 @@ def _materialize_op_tensors(
     op_info: qtyping.OpInfo,
     graph_info: qtyping.GraphInfo,
     tensor_name_to_qsv: dict[str, Any],
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
     quant_params=None,
 ) -> None:
   """Util to materialize op tensors. Appends the results to op_tensor_params.
@@ -289,6 +203,8 @@ def _materialize_op_tensors(
     op_info: Aggregated information about the op (e.g., quantization config).
     graph_info: Graph information needed to perform quantization for the op.
     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
     quant_params: Quantization parameters for the tensor.
   """
   for tensor in op_tensors:
@@ -298,7 +214,8 @@ def _materialize_op_tensors(
         op_info,
         graph_info,
         tensor_name_to_qsv,
-        quant_params,
+        get_tensor_quant_params_fn=get_tensor_quant_params_fn,
+        quant_params=quant_params,
     )
     op_tensor_params.append(tensor_params)
 
@@ -309,6 +226,7 @@ def _get_single_tensor_params(
     op_info: qtyping.OpInfo,
     graph_info: qtyping.GraphInfo,
     tensor_name_to_qsv: dict[str, Any],
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
 ) -> qtyping.TensorTransformationParams:
   """Util to get single tensor params.
 
@@ -318,6 +236,8 @@ def _get_single_tensor_params(
     op_info: Aggregated information about the op (e.g., quantization config).
     graph_info: Graph information needed to perform quantization for the op.
     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
 
   Returns:
     Transformation parameters for the tensor.
@@ -336,6 +256,7 @@ def _get_single_tensor_params(
       op_info,
       graph_info,
       tensor_name_to_qsv,
+      get_tensor_quant_params_fn=get_tensor_quant_params_fn,
   )
 
 
@@ -345,6 +266,7 @@ def _materialize_standard_op_with_same_as_input_scale(
     op_info: qtyping.OpInfo,
     graph_info: qtyping.GraphInfo,
     tensor_name_to_qsv: dict[str, Any],
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
 ) -> list[qtyping.TensorTransformationParams]:
   """Materialize tensors in an op with same as input scale constraint.
 
@@ -354,6 +276,8 @@ def _materialize_standard_op_with_same_as_input_scale(
     op_info: Aggregated information about the op (e.g., quantization config).
     graph_info: Graph information needed to perform quantization for the op.
     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
 
   Returns:
     Quantization configuration for the tensors associated with the op (e.g.,
@@ -367,6 +291,7 @@ def _materialize_standard_op_with_same_as_input_scale(
       op_info=op_info,
       graph_info=graph_info,
       tensor_name_to_qsv=tensor_name_to_qsv,
+      get_tensor_quant_params_fn=get_tensor_quant_params_fn,
   )
   op_tensor_params.append(input_tensor_params)
   # Use input quantization params for all output tensors.
@@ -377,6 +302,7 @@ def _materialize_standard_op_with_same_as_input_scale(
       op_info=op_info,
      graph_info=graph_info,
       tensor_name_to_qsv=tensor_name_to_qsv,
+      get_tensor_quant_params_fn=get_tensor_quant_params_fn,
       quant_params=input_tensor_params.consumers[0].parameters,
   )
   # Change output qsv to be the same as input qsv. This is safe since TFL
@@ -396,6 +322,7 @@ def _materialize_standard_op_with_same_as_output_scale(
     op_info: qtyping.OpInfo,
     graph_info: qtyping.GraphInfo,
     tensor_name_to_qsv: dict[str, Any],
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
 ) -> list[qtyping.TensorTransformationParams]:
   """Materialize tensors in an op with same as output scale constraint.
 
@@ -405,6 +332,8 @@ def _materialize_standard_op_with_same_as_output_scale(
     op_info: Aggregated information about the op (e.g., quantization config).
     graph_info: Graph information needed to perform quantization for the op.
     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
 
   Returns:
     Quantization configuration for the tensors associated with the op (e.g.,
@@ -418,6 +347,7 @@ def _materialize_standard_op_with_same_as_output_scale(
       op_info=op_info,
       graph_info=graph_info,
       tensor_name_to_qsv=tensor_name_to_qsv,
+      get_tensor_quant_params_fn=get_tensor_quant_params_fn,
   )
   # Use output quantization params for all input tensors.
   if output_tensor_params.producer is None:
@@ -431,6 +361,7 @@ def _materialize_standard_op_with_same_as_output_scale(
      op_info=op_info,
       graph_info=graph_info,
       tensor_name_to_qsv=tensor_name_to_qsv,
+      get_tensor_quant_params_fn=get_tensor_quant_params_fn,
       quant_params=quant_params,
   )
   op_tensor_params.append(output_tensor_params)
@@ -444,6 +375,7 @@ def _materialize_standard_op_no_constraint(
     op_info: qtyping.OpInfo,
     graph_info: qtyping.GraphInfo,
     tensor_name_to_qsv: dict[str, Any],
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
 ) -> list[qtyping.TensorTransformationParams]:
   """Materialize tensors in an op with no constraint.
 
@@ -453,6 +385,8 @@ def _materialize_standard_op_no_constraint(
     op_info: Aggregated information about the op (e.g., quantization config).
     graph_info: Graph information needed to perform quantization for the op.
     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
 
   Returns:
     Quantization configuration for the tensors associated with the op (e.g.,
@@ -466,6 +400,7 @@ def _materialize_standard_op_no_constraint(
      op_info=op_info,
       graph_info=graph_info,
       tensor_name_to_qsv=tensor_name_to_qsv,
+      get_tensor_quant_params_fn=get_tensor_quant_params_fn,
   )
   _materialize_op_tensors(
       op_tensor_params,
@@ -474,6 +409,7 @@ def _materialize_standard_op_no_constraint(
      op_info=op_info,
       graph_info=graph_info,
       tensor_name_to_qsv=tensor_name_to_qsv,
+      get_tensor_quant_params_fn=get_tensor_quant_params_fn,
   )
 
   return op_tensor_params
@@ -696,6 +632,7 @@ def materialize_standard_op(
     op_info: qtyping.OpInfo,
     graph_info: qtyping.GraphInfo,
     tensor_name_to_qsv: dict[str, Any],
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
     constraint: OpQuantConstraint = OpQuantConstraint.NO_CONSTRAIN,
     inputs_to_ignore: Optional[Sequence[int]] = None,
     outputs_to_ignore: Optional[Sequence[int]] = None,
@@ -709,6 +646,8 @@ def materialize_standard_op(
     op_info: Aggregated information about the op (e.g., quantization config).
     graph_info: Graph information needed to perform quantization for the op.
     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
     constraint: The constraint for materializing the op.
     inputs_to_ignore: Input tensor indices to ignore.
     outputs_to_ignore: Output tensor indices to ignore.
@@ -747,15 +686,30 @@ def materialize_standard_op(
     tensor_params = []  # Every tensor is ignored.
   elif constraint == OpQuantConstraint.SAME_AS_INPUT_SCALE:
     tensor_params = _materialize_standard_op_with_same_as_input_scale(
-        input_tensors, output_tensors, op_info, graph_info, tensor_name_to_qsv
+        input_tensors,
+        output_tensors,
+        op_info,
+        graph_info,
+        tensor_name_to_qsv,
+        get_tensor_quant_params_fn,
     )
   elif constraint == OpQuantConstraint.SAME_AS_OUTPUT_SCALE:
     tensor_params = _materialize_standard_op_with_same_as_output_scale(
-        input_tensors, output_tensors, op_info, graph_info, tensor_name_to_qsv
+        input_tensors,
+        output_tensors,
+        op_info,
+        graph_info,
+        tensor_name_to_qsv,
+        get_tensor_quant_params_fn,
     )
   else:
     tensor_params = _materialize_standard_op_no_constraint(
-        input_tensors, output_tensors, op_info, graph_info, tensor_name_to_qsv
+        input_tensors,
+        output_tensors,
+        op_info,
+        graph_info,
+        tensor_name_to_qsv,
+        get_tensor_quant_params_fn,
     )
 
   # Materialize ignored tensors.
@@ -781,6 +735,7 @@ def materialize_op_with_output_activation_constraint(
     graph_info: qtyping.GraphInfo,
     tensor_name_to_qsv: dict[str, Any],
     output_activation_constraints: dict[int, qtyping.UniformQuantParams],
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
 ) -> list[qtyping.TensorTransformationParams]:
   """Materialize tensors in an op with output activation constraint.
 
@@ -795,6 +750,8 @@ def materialize_op_with_output_activation_constraint(
     tensor_name_to_qsv: A map of tensor name to quantization parameters.
     output_activation_constraints: A map of output activation num bits to
       quantization parameters.
+    get_tensor_quant_params_fn: Function to get quantization parameters for the
+      tensor.
 
   Returns:
     Quantization configuration for the tensors associated with the op (e.g.,
@@ -812,9 +769,7 @@ def materialize_op_with_output_activation_constraint(
   )
 
   tensor_params = materialize_standard_op(
-      op_info,
-      graph_info,
-      tensor_name_to_qsv,
+      op_info, graph_info, tensor_name_to_qsv, get_tensor_quant_params_fn
   )
   output_tensor_params = tensor_params[-1]
 
@@ -951,76 +906,7 @@ def get_tensor_transformation_params(
   )
 
 
-def _get_tensor_quant_params(
-    op_info: qtyping.OpInfo,
-    tensor_min_max: dict[str, Any],
-    tensor_quant_config: qtyping.TensorQuantizationConfig,
-    tensor_content: Optional[np.ndarray] = None,
-) -> qtyping.UniformQuantParams:
-  """Get the quantization parameters for a tensor.
-
-  Args:
-    op_info: aggregated information about the op (e.g., quantization config).
-    tensor_min_max: the min/max of the tensor.
-    tensor_quant_config: the quantization config for the tensor.
-    tensor_content: the content of the tensor.
-
-  Returns:
-    The quantization parameters for the tensor.
-  """
-  if "min" not in tensor_min_max or "max" not in tensor_min_max:
-    raise ValueError(
-        "min and max must be provided to produce tensor quantization"
-        " parameters. Check if the correct calibration results are passed into"
-        " the ParamsGenerator."
-    )
-  zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
-      tensor_min_max["min"],
-      tensor_min_max["max"],
-      tensor_quant_config.num_bits,
-      tensor_quant_config.symmetric,
-  )
-  quantized_dim = None
-  if tensor_quant_config.granularity == qtyping.QuantGranularity.CHANNELWISE:
-    if op_info.op_name == _TFLOpName.BATCH_MATMUL:
-      quantized_dim = _get_bmm_weight_quantized_dim(
-          tensor_content, adj_y=op_info.op.builtinOptions.adjY
-      )
-    else:
-      quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM[
-          op_info.op_name
-      ]
-  quant_params = qtyping.UniformQuantParams(
-      scale=scale,
-      zero_point=zp,
-      num_bits=tensor_quant_config.num_bits,
-      symmetric=tensor_quant_config.symmetric,
-      quantized_dimension=quantized_dim,
-  )
-  if tensor_content is None:
-    return quant_params
-  if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
-    quantized_vars = (
-        uniform_quantize_tensor.uniform_quantize_for_emulated_subchannel(
-            tensor_content, quant_params, tensor_quant_config.block_size
-        )
-    )
-  else:
-    quantized_vars = uniform_quantize_tensor.uniform_quantize(
-        tensor_content, quant_params
-    )
-  # Update with quantized values.
-  return qtyping.UniformQuantParams(
-      scale=scale,
-      zero_point=zp,
-      num_bits=tensor_quant_config.num_bits,
-      symmetric=tensor_quant_config.symmetric,
-      quantized_dimension=quantized_dim,
-      quantized_data=quantized_vars,
-  )
-
-
-def _get_reduce_dims(
+def get_reduce_dims(
     quantized_dim: Optional[int],
     tensor_shape: list[int],
 ) -> Optional[tuple[int, ...]]:
@@ -1034,7 +920,7 @@ def _get_reduce_dims(
   return tuple(reduce_dims)
 
 
-def _get_bmm_weight_quantized_dim(
+def get_bmm_weight_quantized_dim(
     weight_tensor_data: np.ndarray, adj_y: bool
 ) -> int:
   """Get the quantized dimension for batch matmul."""
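Taken together, the `common_utils.py` changes above turn `materialize_standard_op` into a generic driver: it no longer looks up min/max or computes quantization parameters itself, but receives a `get_tensor_quant_params_fn` callback from the calling algorithm and threads it down to every tensor it materializes. The sketch below only illustrates the new call shape; the per-op function name is hypothetical and is not code from this package.

```python
from typing import Any

from ai_edge_quantizer import qtyping
from ai_edge_quantizer.algorithms.utils import common_utils


def _materialize_example_op(  # Hypothetical per-op materializer, for illustration only.
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  # After this refactor the algorithm-specific callback must be forwarded
  # explicitly; common_utils no longer falls back to its own helper.
  return common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
      constraint=common_utils.OpQuantConstraint.SAME_AS_INPUT_SCALE,
  )
```

The expected shape of the callback itself is the `GetTensorQuantParamsFuncSignature` alias added to `qtyping.py` in the hunk further down.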
ai_edge_quantizer/algorithms/utils/{min_max_quantize_utils_test.py → common_utils_test.py} RENAMED
@@ -18,7 +18,7 @@ from absl.testing import parameterized
 from tensorflow.python.platform import googletest
 from ai_edge_quantizer import default_policy
 from ai_edge_quantizer import qtyping
-from ai_edge_quantizer.algorithms.utils import min_max_quantize_utils
+from ai_edge_quantizer.algorithms.utils import common_utils
 
 _ComputePrecision = qtyping.ComputePrecision
 _QuantTransformation = qtyping.QuantTransformation
@@ -52,7 +52,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
         compute_precision=compute_precision,
         explicit_dequantize=explicit_dequantize,
     )
-    transformations = min_max_quantize_utils.get_tensor_transformations(
+    transformations = common_utils.get_tensor_transformations(
         op_quant_config, is_inbounding_tensor, is_constant
     )
     # Check if WEIGHT_ONLY.
@@ -120,7 +120,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
     ):
-      min_max_quantize_utils.check_if_valid_op_config(
+      common_utils.check_if_valid_op_config(
          op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
       )
 
@@ -146,7 +146,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
         ),
         compute_precision=_ComputePrecision.INTEGER,  # DRQ.
     )
-    min_max_quantize_utils.check_if_valid_op_config(
+    common_utils.check_if_valid_op_config(
         op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
     )
 
@@ -166,7 +166,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
     ):
-      min_max_quantize_utils.check_if_valid_op_config(
+      common_utils.check_if_valid_op_config(
          op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
       )
 
@@ -186,7 +186,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
     ):
-      min_max_quantize_utils.check_if_valid_op_config(
+      common_utils.check_if_valid_op_config(
          op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
       )
 
@@ -207,7 +207,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
     ):
-      min_max_quantize_utils.check_if_valid_op_config(
+      common_utils.check_if_valid_op_config(
          op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
       )
 
@@ -220,7 +220,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
         compute_precision=_ComputePrecision.INTEGER,  # DRQ.
         min_weight_elements=100,
     )
-    min_max_quantize_utils.check_if_valid_op_config(
+    common_utils.check_if_valid_op_config(
         _TFLOpName.CONV_2D, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
     )
 
@@ -255,7 +255,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
         ),
         compute_precision=_ComputePrecision.INTEGER,  # SRQ.
     )
-    min_max_quantize_utils.check_if_valid_op_config(
+    common_utils.check_if_valid_op_config(
         op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
     )
 
@@ -275,7 +275,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
     ):
-      min_max_quantize_utils.check_if_valid_op_config(
+      common_utils.check_if_valid_op_config(
          _TFLOpName.CUSTOM_OP, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
       )
 
@@ -297,7 +297,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
     ):
-      min_max_quantize_utils.check_if_valid_op_config(
+      common_utils.check_if_valid_op_config(
          _TFLOpName.FULLY_CONNECTED,
          op_quant_config,
          _DEFAULT_CONFIG_CHECK_POLICY,
@@ -321,7 +321,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
     ):
-      min_max_quantize_utils.check_if_valid_op_config(
+      common_utils.check_if_valid_op_config(
          _TFLOpName.FULLY_CONNECTED,
          op_quant_config,
          _DEFAULT_CONFIG_CHECK_POLICY,
@@ -345,7 +345,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
     ):
-      min_max_quantize_utils.check_if_valid_op_config(
+      common_utils.check_if_valid_op_config(
          _TFLOpName.FULLY_CONNECTED,
          op_quant_config,
          _DEFAULT_CONFIG_CHECK_POLICY,
@@ -368,7 +368,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
     with self.assertRaisesWithPredicateMatch(
         ValueError, lambda err: error_message in str(err)
     ):
-      min_max_quantize_utils.check_if_valid_op_config(
+      common_utils.check_if_valid_op_config(
          _TFLOpName.FULLY_CONNECTED,
          op_quant_config,
          _DEFAULT_CONFIG_CHECK_POLICY,
@@ -425,7 +425,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
         compute_precision == _ComputePrecision.INTEGER
         and op_quant_config.activation_tensor_config is None
     ):
-      min_max_quantize_utils.check_if_valid_op_config(
+      common_utils.check_if_valid_op_config(
          op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
       )
     # Check if WEIGHT_ONLY.
@@ -439,7 +439,7 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
         compute_precision == _ComputePrecision.INTEGER
         and op_quant_config.activation_tensor_config is not None
     ):
-      min_max_quantize_utils.check_if_valid_op_config(
+      common_utils.check_if_valid_op_config(
          op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
       )
 
@@ -477,11 +477,11 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
 
     with self.assertRaises(ValueError):
       if is_drq:
-        min_max_quantize_utils.check_if_valid_op_config(
+        common_utils.check_if_valid_op_config(
            op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
         )
       elif not is_drq:
-        min_max_quantize_utils.check_if_valid_op_config(
+        common_utils.check_if_valid_op_config(
            op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
         )
 
@@ -500,11 +500,12 @@ class MinMaxQuantizeUtilsTest(parameterized.TestCase):
     with self.assertRaisesRegex(
         ValueError, "only supports ops with a single output tensor"
     ):
-      min_max_quantize_utils.materialize_op_with_output_activation_constraint(
+      common_utils.materialize_op_with_output_activation_constraint(
          op_info=mock_op_info,
          graph_info=qtyping.GraphInfo([], []),
          tensor_name_to_qsv={},
          output_activation_constraints={},
+         get_tensor_quant_params_fn=lambda *args: [],
       )
 
 
ai_edge_quantizer/qtyping.py CHANGED
@@ -20,7 +20,7 @@ from collections.abc import MutableMapping
 import copy
 import dataclasses
 import enum
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, Callable
 
 import numpy as np
 from typing_extensions import TypeAlias
@@ -485,3 +485,14 @@ class IOOperator:
   inputs: list[int]
   outputs: list[int]
   op_key: TFLOperationName
+
+# The function signature for `get_tensor_quant_params_fn`.
+GetTensorQuantParamsFuncSignature = Callable[
+    [
+        OpInfo,  # op_info
+        TensorQuantizationConfig,  # tensor_quant_config
+        Optional[np.ndarray],  # tensor_data
+        Optional[dict[str, Any]],  # tensor qsv
+    ],
+    UniformQuantParams,
+]
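A callable satisfying the new `GetTensorQuantParamsFuncSignature` alias takes the four positional arguments annotated above and returns a `UniformQuantParams`. The sketch below is illustrative only: it is assembled from helpers visible elsewhere in this diff (`tensor_zp_scale_from_min_max` and the `UniformQuantParams` fields) and is not the implementation that `naive_min_max_quantize.py` actually registers; the `uniform_quantize_tensor` import path is an assumption.

```python
from typing import Any, Optional

import numpy as np

from ai_edge_quantizer import qtyping
# Assumed module path, inferred from the modules referenced in this diff.
from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor


def get_tensor_quant_params_sketch(
    op_info: qtyping.OpInfo,
    tensor_quant_config: qtyping.TensorQuantizationConfig,
    tensor_data: Optional[np.ndarray] = None,
    tensor_qsv: Optional[dict[str, Any]] = None,
) -> qtyping.UniformQuantParams:
  """Sketch of a callback matching qtyping.GetTensorQuantParamsFuncSignature."""
  del op_info  # A real implementation would use this to pick the quantized dimension.
  # Prefer calibration statistics; fall back to the constant tensor's own range,
  # mirroring what the weight-only/DRQ path did before this refactor.
  if tensor_qsv is not None and "min" in tensor_qsv and "max" in tensor_qsv:
    min_val, max_val = tensor_qsv["min"], tensor_qsv["max"]
  elif tensor_data is not None:
    min_val = np.min(tensor_data, keepdims=True)
    max_val = np.max(tensor_data, keepdims=True)
  else:
    raise ValueError("Either calibration qsv or tensor data is required.")
  zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
      min_val, max_val, tensor_quant_config.num_bits, tensor_quant_config.symmetric
  )
  return qtyping.UniformQuantParams(
      scale=scale,
      zero_point=zp,
      num_bits=tensor_quant_config.num_bits,
      symmetric=tensor_quant_config.symmetric,
      quantized_dimension=None,
  )
```

Because the alias is a typing aid with no runtime enforcement, the updated test above can pass a bare `lambda *args: []` as a placeholder where the callback's return value is never inspected.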