ai-edge-quantizer-nightly 0.0.1.dev20250211__py3-none-any.whl → 0.0.1.dev20250212__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,637 @@
+ # Copyright 2024 The AI Edge Quantizer Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Quantization helpers common to all uniform quantization algorithms.
+
+ This file contains quantization helpers common to all uniform quantization
+ algorithms. The materialize_op functions require algorithm-specific logic to
+ produce the quantization parameters (e.g. scale, zero point) for each tensor,
+ which is encapsulated in get_tensor_quant_params_fn. Each algorithm is required
+ to implement the get_tensor_quant_params_fn with the
+ qtyping.GetTensorQuantParamsFuncSignature signature.
+ """
+
+ from typing import Any
+ import numpy as np
+ from ai_edge_quantizer import qtyping
+ from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor
+ from ai_edge_quantizer.algorithms.utils import common_utils
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+
+ _TFLOpName = qtyping.TFLOperationName
+ _QuantTransformation = qtyping.QuantTransformation
+ _OpQuantConstraint = common_utils.OpQuantConstraint
+ _ComputePrecision = qtyping.ComputePrecision
+
+
+ def check_op_quantization_config(
+     op_name: _TFLOpName,
+     op_quant_config: qtyping.OpQuantizationConfig,
+     config_check_policy: qtyping.ConfigCheckPolicyDict,
+ ) -> None:
+   """Checks the op quantization config.
+
+   Args:
+     op_name: The name of the op.
+     op_quant_config: The quantization config for the op.
+     config_check_policy: The policy to check the op quantization config.
+
+   Raises:
+     ValueError: If the op quantization config is invalid.
+   """
+   if op_quant_config.weight_tensor_config is None:
+     raise ValueError(
+         "Weight tensor quantization is required for min/max uniform"
+         " quantization."
+     )
+   if op_quant_config.weight_tensor_config.dtype != qtyping.TensorDataType.INT:
+     raise ValueError(
+         "Weights need to have integer type for min/max uniform quantization. If"
+         " you wish to perform float casting quantization (e.g., fp16 weight"
+         " only), please set algorithm key as 'float_casting'."
+     )
+
+   if op_quant_config.min_weight_elements < 0:
+     raise ValueError(
+         f"min_weight_elements must be non-negative for op: {op_name} with"
+         f" config: {op_quant_config}."
+     )
+
+   if op_quant_config.compute_precision in [
+       _ComputePrecision.INTEGER,
+       _ComputePrecision.FLOAT,
+   ]:
+     # Use policy-based mechanism to validate op.
+     common_utils.check_if_valid_op_config(
+         op_name, op_quant_config, config_check_policy
+     )
+   common_utils.check_subchannel_config(op_name, op_quant_config)
+
+
+ def materialize_input(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in the virtual input op."""
+   return common_utils.materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       get_tensor_quant_params_fn,
+   )
+
+
+ def materialize_output(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in the virtual output op."""
+   return common_utils.materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       get_tensor_quant_params_fn,
+   )
+
+
+ def materialize_add(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in tfl.add."""
+   return common_utils.materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       get_tensor_quant_params_fn,
+   )
+
+
+ def materialize_sub(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in tfl.sub."""
+   return common_utils.materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       get_tensor_quant_params_fn,
+   )
+
+
+ def materialize_mul(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in tfl.mul."""
+   return common_utils.materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       get_tensor_quant_params_fn,
+   )
+
+
+ def materialize_softmax_and_logistic(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in tfl.softmax and tfl.logistic."""
+   # Hard code scales and zero point values as they are hard coded in TFL
+   # kernels.
+   # Softmax:
+   # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/activations.cc#L548
+   # Logistic:
+   # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/activations.cc#L421
+   output_activation_constraints = {
+       8: qtyping.UniformQuantParams(
+           num_bits=8,
+           quantized_dimension=None,
+           scale=np.array(1.0 / 256),
+           zero_point=np.array(-128),
+           symmetric=False,
+       ),
+       16: qtyping.UniformQuantParams(
+           num_bits=16,
+           quantized_dimension=None,
+           scale=np.array(1.0 / 32768),
+           zero_point=np.array(0),
+       ),
+   }
+
+   return common_utils.materialize_op_with_output_activation_constraint(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       output_activation_constraints,
+       get_tensor_quant_params_fn,
+   )
+
+
+ def materialize_batch_matmul(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in tfl.batch_matmul."""
+   return common_utils.materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       get_tensor_quant_params_fn,
+   )
+
+
+ def materialize_embedding_lookup(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in tfl.embedding_lookup."""
+   return common_utils.materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       get_tensor_quant_params_fn,
+       inputs_to_ignore=[0],  # Lookup index does not need to be quantized.
+   )
+
+
+ def materialize_reshape(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in tfl.reshape."""
+   return common_utils.materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       get_tensor_quant_params_fn,
+       constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+       inputs_to_ignore=[1],  # Shape tensor does not need to be quantized.
+   )
+
+
+ def materialize_average_pool_2d(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in tfl.average_pool_2d."""
+   return common_utils.materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       get_tensor_quant_params_fn,
+       constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+   )
+
+
+ def _materialize_bias_for_conv_ops(
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     op_tensor_params: list[qtyping.TensorTransformationParams],
+     op_input_index: int = 0,
+     op_weight_index: int = 1,
+     op_bias_index: int = 2,
+ ):
+   """Materializes bias tensors in conv ops by updating `op_tensor_params`.
+
+   Args:
+     op_info: Aggregated information about the op (e.g., quantization config).
+     graph_info: Graph information needed to perform quantization for the op.
+     op_tensor_params: Partially populated quantization configuration for the
+       tensors associated with the op in the order of input, weight, output.
+     op_input_index: Index for the input tensor in the op.
+     op_weight_index: Index for the weight tensor in the op.
+     op_bias_index: Index for the bias tensor in the op.
+   """
+   _, _, bias_tensor, _ = tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
+       op_info.op,
+       graph_info.subgraph_tensors,
+       op_input_index,
+       op_weight_index,
+       op_bias_index,
+   )
+   if bias_tensor is not None:
+     bias_quant_params = None
+     # Fused bias needs to be quantized for SRQ.
+     # Check if SRQ.
+     if (
+         op_info.op_quant_config.compute_precision == _ComputePrecision.INTEGER
+         and op_info.op_quant_config.activation_tensor_config is not None
+     ):
+       bias_content = tfl_flatbuffer_utils.get_tensor_data(
+           bias_tensor,
+           graph_info.buffers,
+       )
+       bias_quant_params = (
+           uniform_quantize_tensor.symmetric_quantize_bias_tensor(
+               bias_content,
+               op_tensor_params[op_input_index].consumers[0].parameters,
+               op_tensor_params[op_weight_index].consumers[0].parameters,
+           )
+       )
+     # We only quantize bias under SRQ. Setting is_constant=True for SRQ only
+     # to avoid quantizing bias for DRQ and weight-only cases.
+     is_constant = (
+         # Check if SRQ.
+         op_info.op_quant_config.compute_precision == _ComputePrecision.INTEGER
+         and op_info.op_quant_config.activation_tensor_config is not None
+     )
+     op_tensor_params[op_bias_index] = (
+         common_utils.get_tensor_transformation_params(
+             tfl_flatbuffer_utils.get_tensor_name(bias_tensor),
+             op_info,
+             is_inbounding_tensor=True,
+             quant_params=bias_quant_params,
+             is_constant=is_constant,
+         )
+     )
+
+
+ def _are_weights_too_small(
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     weight_index: int,
+ ) -> bool:
+   """Checks if weights are too small to be quantized."""
+   tensor = graph_info.subgraph_tensors[op_info.op.inputs[weight_index]]
+   tensor_data = tfl_flatbuffer_utils.get_tensor_data(
+       tensor,
+       graph_info.buffers,
+   )
+   return (
+       tensor_data is not None
+       and np.size(tensor_data) < op_info.op_quant_config.min_weight_elements
+   )
+
+
+ def materialize_slice(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in tfl.slice."""
+   return common_utils.materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       get_tensor_quant_params_fn,
+       constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+       inputs_to_ignore=[
+           1,
+           2,
+       ],  # Begin and size indices do not need to be quantized.
+   )
+
+
+ def materialize_select_v2(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in tfl.select_v2."""
+   return common_utils.materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       get_tensor_quant_params_fn,
+       constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
+       inputs_to_ignore=[
+           0,
+       ],  # Condition tensor does not need to be quantized.
+   )
+
+
+ def materialize_sum(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in tfl.sum."""
+   return common_utils.materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       get_tensor_quant_params_fn,
+       constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+       inputs_to_ignore=[1],  # Axis index does not need to be quantized.
+   )
+
+
+ def materialize_fc_conv(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+     input_index: int = 0,
+     weight_index: int = 1,
+     bias_index: int = 2,
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in fully_connected, conv_2d and depthwise_conv_2d.
+
+   Args:
+     get_tensor_quant_params_fn: A function to get the quantization parameters
+       for a tensor.
+     op_info: Aggregated information about the op (e.g., quantization config).
+     graph_info: Graph information needed to perform quantization for the op.
+     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+     input_index: Index for the input tensor in the op.
+     weight_index: Index for the weight tensor in the op.
+     bias_index: Index for the bias tensor in the op.
+
+   Returns:
+     Quantization configuration for the tensors associated with the op (e.g.,
+     weights, bias).
+   """
+   ignored_inputs = [bias_index]  # Bias tensor is quantized separately.
+   if _are_weights_too_small(op_info, graph_info, weight_index):
+     ignored_inputs.append(weight_index)
+
+   op_tensor_params = common_utils.materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       get_tensor_quant_params_fn,
+       inputs_to_ignore=ignored_inputs,
+   )
+
+   _materialize_bias_for_conv_ops(
+       op_info,
+       graph_info,
+       op_tensor_params,
+       op_input_index=input_index,
+       op_weight_index=weight_index,
+       op_bias_index=bias_index,
+   )
+
+   return op_tensor_params
+
+
+ def materialize_conv2d_transpose(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in tfl.conv2d_transpose.
+
+   Args:
+     get_tensor_quant_params_fn: A function to get the quantization parameters
+       for a tensor.
+     op_info: Aggregated information about the op (e.g., quantization config).
+     graph_info: Graph information needed to perform quantization for the op.
+     tensor_name_to_qsv: A map of tensor name to quantization parameters.
+
+   Returns:
+     Quantization configuration for the tensors associated with the op (e.g.,
+     weights, bias).
+   """
+   ignored_shape_index = 0
+   weight_index = 1
+   input_index = 2
+   bias_index = 3
+
+   ignored_inputs = [
+       ignored_shape_index,
+       bias_index,  # Bias tensor is quantized separately.
+   ]
+   if _are_weights_too_small(op_info, graph_info, weight_index):
+     ignored_inputs.append(weight_index)
+
+   op_tensor_params = common_utils.materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       get_tensor_quant_params_fn,
+       inputs_to_ignore=ignored_inputs,
+   )
+   if len(op_tensor_params) < 2:
+     raise ValueError(
+         "Materialize standard op should return at least two tensors for"
+         " conv2d_transpose."
+     )
+   _materialize_bias_for_conv_ops(
+       op_info,
+       graph_info,
+       op_tensor_params,
+       op_input_index=input_index,
+       op_weight_index=weight_index,
+       op_bias_index=bias_index,
+   )
+
+   return op_tensor_params
+
+
+ def materialize_tanh(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in tfl.tanh."""
+   # Hard code scales and zero point values as they are hard coded in:
+   # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/ir/tfl_ops.td#L3430
+   output_activation_constraints = {}
+   for num_bits in [8, 16]:
+     output_activation_constraints[num_bits] = qtyping.UniformQuantParams(
+         num_bits=num_bits,
+         quantized_dimension=None,
+         scale=np.array(1.0 / (1 << (num_bits - 1))),
+         zero_point=np.array(0),
+         # Activation is always asymmetric for 8 bits and symmetric for 16 bits.
+         symmetric=num_bits == 16,
+     )
+   return common_utils.materialize_op_with_output_activation_constraint(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       output_activation_constraints,
+       get_tensor_quant_params_fn,
+   )
+
+
+ def materialize_transpose(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in tfl.transpose."""
+   return common_utils.materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       get_tensor_quant_params_fn,
+       constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+       inputs_to_ignore=[1],  # Permutation tensor does not need to be quantized.
+   )
+
+
+ def materialize_gelu(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in tfl.gelu."""
+   return common_utils.materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       get_tensor_quant_params_fn,
+   )
+
+
+ def materialize_strided_slice(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in tfl.strided_slice."""
+   return common_utils.materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       get_tensor_quant_params_fn,
+       constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+       inputs_to_ignore=[1, 2, 3],  # Ignore the begin, end, and strides tensors.
+   )
+
+
+ def materialize_mean(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in tfl.mean."""
+   return common_utils.materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       get_tensor_quant_params_fn,
+       inputs_to_ignore=[1],  # Axis tensor does not need to be quantized.
+   )
+
+
+ def materialize_rsqrt(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in tfl.rsqrt."""
+   return common_utils.materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       get_tensor_quant_params_fn,
+   )
+
+
+ def materialize_concatenation(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in tfl.concatenation."""
+   return common_utils.materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       get_tensor_quant_params_fn,
+       constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
+   )
+
+
+ def materialize_split(
+     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     tensor_name_to_qsv: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in tfl.split."""
+   return common_utils.materialize_standard_op(
+       op_info,
+       graph_info,
+       tensor_name_to_qsv,
+       get_tensor_quant_params_fn,
+       constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+       inputs_to_ignore=[0],  # Split dimension does not need to be quantized.
+   )
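
The hunk above adds the module that the test file below imports as ai_edge_quantizer.algorithms.uniform_quantize.common_quantize. Per its docstring, the algorithm-specific step (computing each tensor's scale and zero point) lives in the get_tensor_quant_params_fn that every algorithm must supply. The exact qtyping.GetTensorQuantParamsFuncSignature contract is not part of this diff, so the following is only a plain-NumPy sketch of the min/max uniform-quantization arithmetic such a function produces; the function name, defaults, and eps guard are illustrative assumptions, not the package API.

    import numpy as np

    def min_max_quant_params(tensor, num_bits=8, symmetric=True):
      """Illustrative min/max uniform quant params (a sketch, not the package API)."""
      qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
      t_min = min(float(tensor.min()), 0.0)  # The range must include zero.
      t_max = max(float(tensor.max()), 0.0)
      eps = np.finfo(np.float32).eps  # Guard against constant-zero tensors.
      if symmetric:
        # Zero point pinned at 0; scale covers the larger magnitude.
        scale = max(max(abs(t_min), abs(t_max)) / qmax, eps)
        zero_point = 0
      else:
        # Map [t_min, t_max] onto [qmin, qmax].
        scale = max((t_max - t_min) / (qmax - qmin), eps)
        zero_point = qmin - int(round(t_min / scale))
      return scale, zero_point

    # int8 symmetric parameters for a random "weight" tensor.
    weights = np.random.uniform(-0.5, 1.0, size=(4, 8)).astype(np.float32)
    print(min_max_quant_params(weights))

For the fused bias handled by _materialize_bias_for_conv_ops, the usual TFLite convention (which the symmetric_quantize_bias_tensor call follows under SRQ) is a symmetric int32 bias whose scale equals input_scale * weight_scale with zero point 0.
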
@@ -0,0 +1,74 @@
+ # Copyright 2024 The AI Edge Quantizer Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import os
+
+ from absl.testing import parameterized
+ import numpy as np
+
+ from tensorflow.python.platform import googletest
+ from ai_edge_quantizer import default_policy
+ from ai_edge_quantizer import qtyping
+ from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
+ from ai_edge_quantizer.utils import test_utils
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+
+ _TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("../../tests/models")
+ _TFLOpName = qtyping.TFLOperationName
+ _TensorQuantConfig = qtyping.TensorQuantizationConfig
+
+
+ class CommonQuantizeTest(parameterized.TestCase):
+   """Tests for general quantize functions."""
+
+   def setUp(self):
+     super().setUp()
+     np.random.seed(666)
+     self._test_model_path = os.path.join(
+         _TEST_DATA_PREFIX_PATH, "conv_fc_mnist.tflite"
+     )
+     self._test_model = tfl_flatbuffer_utils.read_model(self._test_model_path)
+     # The test model has one subgraph for now.
+     self._graph_info = qtyping.GraphInfo(
+         subgraph_tensors=self._test_model.subgraphs[0].tensors,
+         buffers=self._test_model.buffers,
+     )
+     self._tensor_name_to_qsv = {}
+
+   def test_check_op_quantization_config_with_negative_min_weight_elements_raises_error(
+       self,
+   ):
+     op_quant_config = qtyping.OpQuantizationConfig(
+         weight_tensor_config=_TensorQuantConfig(
+             num_bits=8,
+             granularity=qtyping.QuantGranularity.CHANNELWISE,
+         ),
+         compute_precision=qtyping.ComputePrecision.INTEGER,  # DRQ.
+         min_weight_elements=-1,
+     )
+     with self.assertRaisesWithPredicateMatch(
+         ValueError,
+         lambda err: "min_weight_elements must be non-negative" in str(err),
+     ):
+       common_quantize.check_op_quantization_config(
+           _TFLOpName.FULLY_CONNECTED,
+           op_quant_config,
+           default_policy.DEFAULT_CONFIG_CHECK_POLICY,
+       )
+
+
+ if __name__ == "__main__":
+   googletest.main()
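
Mirroring the test above, a minimal standalone sketch of the failure path (assuming the nightly wheel from this diff is installed; the names below are taken from the two files added here):

    from ai_edge_quantizer import default_policy, qtyping
    from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize

    op_quant_config = qtyping.OpQuantizationConfig(
        weight_tensor_config=qtyping.TensorQuantizationConfig(
            num_bits=8,
            granularity=qtyping.QuantGranularity.CHANNELWISE,
        ),
        compute_precision=qtyping.ComputePrecision.INTEGER,  # DRQ.
        min_weight_elements=-1,  # Invalid on purpose.
    )
    try:
      common_quantize.check_op_quantization_config(
          qtyping.TFLOperationName.FULLY_CONNECTED,
          op_quant_config,
          default_policy.DEFAULT_CONFIG_CHECK_POLICY,
      )
    except ValueError as err:
      print(err)  # "min_weight_elements must be non-negative for op: ..."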