ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,666 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Performs naive min/max uniform quantization."""
17
+
18
+ from typing import Any, Optional
19
+ import numpy as np
20
+ from ai_edge_quantizer import qtyping
21
+ from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor
22
+ from ai_edge_quantizer.algorithms.utils import min_max_quantize_utils as utils
23
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
24
+
25
# Identifier of this algorithm, i.e. the "algorithm key" a recipe uses to
# select min/max uniform quantization.
ALGORITHM_KEY = "min_max_uniform_quantize"

# Short module-private aliases for names used throughout this module.
_ComputePrecision = qtyping.ComputePrecision
_OpQuantConstraint = utils.OpQuantConstraint
_QuantTransformation = qtyping.QuantTransformation
_TFLOpName = qtyping.TFLOperationName
30
+
31
+
32
def check_op_quantization_config(
    op_name: _TFLOpName,
    op_quant_config: qtyping.OpQuantizationConfig,
    config_check_policy: qtyping.ConfigCheckPolicyDict,
) -> None:
  """Validates an op quantization config for min/max uniform quantization.

  Args:
    op_name: Name of the op the config applies to.
    op_quant_config: Quantization config requested for the op.
    config_check_policy: Policy consulted (via
      `utils.check_if_valid_op_config`) when the compute precision is INTEGER
      or FLOAT.

  Raises:
    ValueError: If the config has no weight tensor config, the weight dtype is
      not INT, or `min_weight_elements` is negative. Exceptions raised by the
      delegated `utils.check_if_valid_op_config` and
      `utils.check_subchannel_config` propagate unchanged.
  """
  weight_config = op_quant_config.weight_tensor_config
  if weight_config is None:
    raise ValueError(
        "Weight tensor quantization is required for min/max uniform"
        " quantization."
    )
  # Float weight types belong to the float-casting algorithm, not this one.
  if weight_config.dtype != qtyping.TensorDataType.INT:
    raise ValueError(
        "Weights need to have integer type for min/max uniform quantization. If"
        " you wish to perform float casting quantization (e.g., fp16 weight"
        " only), please set algorithm key as 'float_casting'."
    )
  if op_quant_config.min_weight_elements < 0:
    raise ValueError(
        f"min_weight_elements must be non-negative for op: {op_name} with"
        f" config: {op_quant_config}."
    )

  # Only INTEGER / FLOAT compute precisions are validated against the policy.
  policy_checked_precisions = (
      _ComputePrecision.INTEGER,
      _ComputePrecision.FLOAT,
  )
  if op_quant_config.compute_precision in policy_checked_precisions:
    utils.check_if_valid_op_config(
        op_name, op_quant_config, config_check_policy
    )
  utils.check_subchannel_config(op_name, op_quant_config)
74
+
75
+
76
def materialize_input(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes quantization params for the tensors of the virtual input op.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).

  Returns:
    Transformation params for the op's tensors, as produced by
    `utils.materialize_standard_op`.
  """
  tensor_params = utils.materialize_standard_op(
      op_info, graph_info, tensor_name_to_qsv
  )
  return tensor_params
87
+
88
+
89
def materialize_output(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes quantization params for the tensors of the virtual output op.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).

  Returns:
    Transformation params for the op's tensors, as produced by
    `utils.materialize_standard_op`.
  """
  tensor_params = utils.materialize_standard_op(
      op_info, graph_info, tensor_name_to_qsv
  )
  return tensor_params
100
+
101
+
102
def materialize_add(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes quantization params for the tensors of tfl.add.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).

  Returns:
    Transformation params for the op's tensors, as produced by
    `utils.materialize_standard_op`.
  """
  tensor_params = utils.materialize_standard_op(
      op_info, graph_info, tensor_name_to_qsv
  )
  return tensor_params
113
+
114
+
115
def materialize_sub(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes quantization params for the tensors of tfl.sub.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).

  Returns:
    Transformation params for the op's tensors, as produced by
    `utils.materialize_standard_op`.
  """
  tensor_params = utils.materialize_standard_op(
      op_info, graph_info, tensor_name_to_qsv
  )
  return tensor_params
126
+
127
+
128
def materialize_mul(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes quantization params for the tensors of tfl.mul.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).

  Returns:
    Transformation params for the op's tensors, as produced by
    `utils.materialize_standard_op`.
  """
  tensor_params = utils.materialize_standard_op(
      op_info, graph_info, tensor_name_to_qsv
  )
  return tensor_params
139
+
140
+
141
def materialize_softmax_and_logistic(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes quantization params for tfl.softmax and tfl.logistic.

  The output quantization is fixed per bit width (8 and 16). The fixed values
  are passed to `utils.materialize_op_with_output_activation_constraint`.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).

  Returns:
    Transformation params for the op's tensors.
  """
  # These scales/zero points mirror the values hard-coded in the TFL kernels.
  # Softmax:
  # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/activations.cc#L548
  # Logistic:
  # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/activations.cc#L421
  int8_output_params = qtyping.UniformQuantParams(
      num_bits=8,
      quantized_dimension=None,
      scale=np.array(1.0 / 256),
      zero_point=np.array(-128),
      symmetric=False,
  )
  # `symmetric` is intentionally left at the UniformQuantParams default here.
  int16_output_params = qtyping.UniformQuantParams(
      num_bits=16,
      quantized_dimension=None,
      scale=np.array(1.0 / 32768),
      zero_point=np.array(0),
  )
  return utils.materialize_op_with_output_activation_constraint(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      {8: int8_output_params, 16: int16_output_params},
  )
174
+
175
+
176
def materialize_batch_matmul(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes quantization params for the tensors of tfl.batch_matmul.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).

  Returns:
    Transformation params for the op's tensors, as produced by
    `utils.materialize_standard_op`.
  """
  tensor_params = utils.materialize_standard_op(
      op_info, graph_info, tensor_name_to_qsv
  )
  return tensor_params
187
+
188
+
189
def materialize_embedding_lookup(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes quantization params for the tensors of tfl.embedding_lookup.

  Input 0 (the lookup indices) is skipped.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).

  Returns:
    Transformation params for the op's tensors, as produced by
    `utils.materialize_standard_op`.
  """
  lookup_index_input = 0  # Indices are never quantized.
  return utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      inputs_to_ignore=[lookup_index_input],
  )
201
+
202
+
203
def materialize_reshape(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes quantization params for the tensors of tfl.reshape.

  Uses the `SAME_AS_INPUT_SCALE` constraint and skips the shape operand
  (input 1).

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).

  Returns:
    Transformation params for the op's tensors, as produced by
    `utils.materialize_standard_op`.
  """
  shape_input = 1  # The target shape tensor is not quantized.
  return utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      inputs_to_ignore=[shape_input],
      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
  )
216
+
217
+
218
def materialize_average_pool_2d(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes quantization params for the tensors of tfl.average_pool_2d.

  Uses the `SAME_AS_INPUT_SCALE` constraint.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).

  Returns:
    Transformation params for the op's tensors, as produced by
    `utils.materialize_standard_op`.
  """
  scale_constraint = _OpQuantConstraint.SAME_AS_INPUT_SCALE
  return utils.materialize_standard_op(
      op_info, graph_info, tensor_name_to_qsv, constraint=scale_constraint
  )
230
+
231
+
232
def _materialize_bias_for_conv_ops(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    op_tensor_params: list[qtyping.TensorTransformationParams],
    op_input_index: int = 0,
    op_weight_index: int = 1,
    op_bias_index: int = 2,
):
  """Materializes the bias tensor of a conv-like op into `op_tensor_params`.

  `op_tensor_params[op_bias_index]` is overwritten in place. The bias is only
  quantized under static-range quantization (SRQ): INTEGER compute precision
  with an activation tensor config. Otherwise it is recorded with no
  quantization params and `is_constant=False`. If the op has no bias tensor,
  `op_tensor_params` is left untouched.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    op_tensor_params: Partially populated transformation params for the op's
      tensors. They are indexed directly by `op_input_index`,
      `op_weight_index` and `op_bias_index`, so this presumably follows the
      op's input order (TODO confirm against `utils.materialize_standard_op`).
    op_input_index: Index for the input tensor in the op.
    op_weight_index: Index for the weight tensor in the op.
    op_bias_index: Index for the bias tensor in the op.
  """
  _, _, bias_tensor, _ = tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
      op_info.op,
      graph_info.subgraph_tensors,
      op_input_index,
      op_weight_index,
      op_bias_index,
  )
  if bias_tensor is None:
    return

  op_quant_config = op_info.op_quant_config
  # Evaluate the SRQ predicate once. It both gates bias quantization and marks
  # the bias as constant, and the two uses must never disagree. DRQ and
  # weight-only configs leave the bias unquantized.
  is_srq = (
      op_quant_config.compute_precision == _ComputePrecision.INTEGER
      and op_quant_config.activation_tensor_config is not None
  )

  bias_quant_params = None
  if is_srq:
    bias_content = tfl_flatbuffer_utils.get_tensor_data(
        bias_tensor,
        graph_info.buffers,
    )
    # The bias scale is derived from the already-materialized input and
    # weight params (their first consumer's parameters).
    bias_quant_params = uniform_quantize_tensor.symmetric_quantize_bias_tensor(
        bias_content,
        op_tensor_params[op_input_index].consumers[0].parameters,
        op_tensor_params[op_weight_index].consumers[0].parameters,
    )

  op_tensor_params[op_bias_index] = utils.get_tensor_transformation_params(
      tfl_flatbuffer_utils.get_tensor_name(bias_tensor),
      op_info,
      is_inbounding_tensor=True,
      quant_params=bias_quant_params,
      is_constant=is_srq,
  )
291
+
292
+
293
def _are_weights_too_small(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    weight_index: int,
) -> bool:
  """Returns whether the op's weight has too few elements to quantize.

  The weight counts as too small when its constant data has fewer elements
  than `op_info.op_quant_config.min_weight_elements`. If
  `get_tensor_data` returns None for the weight, the result is False.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    weight_index: Position of the weight among the op's inputs.
  """
  weight_tensor = graph_info.subgraph_tensors[op_info.op.inputs[weight_index]]
  weight_data = tfl_flatbuffer_utils.get_tensor_data(
      weight_tensor, graph_info.buffers
  )
  if weight_data is None:
    return False
  min_elements = op_info.op_quant_config.min_weight_elements
  return np.size(weight_data) < min_elements
308
+
309
+
310
def materialize_slice(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes quantization params for the tensors of tfl.slice.

  Uses the `SAME_AS_INPUT_SCALE` constraint and skips the begin (input 1) and
  size (input 2) operands.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).

  Returns:
    Transformation params for the op's tensors, as produced by
    `utils.materialize_standard_op`.
  """
  begin_input, size_input = 1, 2  # Slice bounds are not quantized.
  return utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      inputs_to_ignore=[begin_input, size_input],
      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
  )
326
+
327
+
328
def materialize_select_v2(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes quantization params for the tensors of tfl.select_v2.

  Uses the `SAME_AS_OUTPUT_SCALE` constraint and skips the condition operand
  (input 0).

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).

  Returns:
    Transformation params for the op's tensors, as produced by
    `utils.materialize_standard_op`.
  """
  condition_input = 0  # The condition tensor is not quantized.
  return utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      inputs_to_ignore=[condition_input],
      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
  )
343
+
344
+
345
def materialize_sum(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes quantization params for the tensors of tfl.sum.

  Uses the `SAME_AS_INPUT_SCALE` constraint and skips the axis operand
  (input 1).

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).

  Returns:
    Transformation params for the op's tensors, as produced by
    `utils.materialize_standard_op`.
  """
  axis_input = 1  # The reduction axis is not quantized.
  return utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      inputs_to_ignore=[axis_input],
      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
  )
358
+
359
+
360
def materialize_fc_conv(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
    input_index: int = 0,
    weight_index: int = 1,
    bias_index: int = 2,
) -> list[qtyping.TensorTransformationParams]:
  """Materializes params for fully_connected, conv_2d and depthwise_conv_2d.

  The bias is excluded from the standard materialization and handled
  afterwards by `_materialize_bias_for_conv_ops`. Weights that
  `_are_weights_too_small` reports as too small are excluded as well.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).
    input_index: Index for the input tensor in the op.
    weight_index: Index for the weight tensor in the op.
    bias_index: Index for the bias tensor in the op.

  Returns:
    Transformation params for the tensors associated with the op (e.g.,
    weights, bias).
  """
  # The bias goes through a dedicated path below.
  inputs_to_ignore = [bias_index]
  if _are_weights_too_small(op_info, graph_info, weight_index):
    inputs_to_ignore.append(weight_index)

  tensor_params = utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      inputs_to_ignore=inputs_to_ignore,
  )
  _materialize_bias_for_conv_ops(
      op_info,
      graph_info,
      tensor_params,
      op_input_index=input_index,
      op_weight_index=weight_index,
      op_bias_index=bias_index,
  )
  return tensor_params
403
+
404
+
405
def materialize_conv2d_transpose(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes quantization params for the tensors of tfl.conv2d_transpose.

  Operand layout used here: 0 = output shape (never quantized), 1 = weights,
  2 = input, 3 = bias. The bias is handled by `_materialize_bias_for_conv_ops`
  rather than by the standard materialization.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).

  Returns:
    Transformation params for the tensors associated with the op (e.g.,
    weights, bias).

  Raises:
    ValueError: If `utils.materialize_standard_op` returns fewer than two
      entries.
  """
  ignored_shape_index, weight_index, input_index, bias_index = 0, 1, 2, 3

  inputs_to_ignore = [ignored_shape_index, bias_index]
  if _are_weights_too_small(op_info, graph_info, weight_index):
    inputs_to_ignore.append(weight_index)

  tensor_params = utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      inputs_to_ignore=inputs_to_ignore,
  )
  if len(tensor_params) < 2:
    raise ValueError(
        "Materialize standard op should return at least two tensors for"
        " conv2d_transpose."
    )

  _materialize_bias_for_conv_ops(
      op_info,
      graph_info,
      tensor_params,
      op_input_index=input_index,
      op_weight_index=weight_index,
      op_bias_index=bias_index,
  )
  return tensor_params
454
+
455
+
456
def materialize_tanh(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes quantization params for the tensors of tfl.tanh.

  The output is constrained, for both 8 and 16 bits, to scale
  1 / 2**(num_bits - 1) and zero point 0.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).

  Returns:
    Transformation params for the op's tensors.
  """
  # Mirrors the scale/zero point hard-coded for tanh in:
  # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/ir/tfl_ops.td#L3430
  output_activation_constraints = {
      bits: qtyping.UniformQuantParams(
          num_bits=bits,
          quantized_dimension=None,
          scale=np.array(1.0 / (1 << (bits - 1))),
          zero_point=np.array(0),
          # Asymmetric for 8 bits, symmetric for 16 bits.
          symmetric=bits == 16,
      )
      for bits in (8, 16)
  }
  return utils.materialize_op_with_output_activation_constraint(
      op_info, graph_info, tensor_name_to_qsv, output_activation_constraints
  )
477
+
478
+
479
def materialize_transpose(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes quantization params for the tensors of tfl.transpose.

  Uses the `SAME_AS_INPUT_SCALE` constraint and skips the permutation operand
  (input 1).

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).

  Returns:
    Transformation params for the op's tensors, as produced by
    `utils.materialize_standard_op`.
  """
  permutation_input = 1  # The permutation tensor is not quantized.
  return utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      inputs_to_ignore=[permutation_input],
      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
  )
492
+
493
+
494
def materialize_gelu(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes quantization params for the tensors of tfl.gelu.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).

  Returns:
    Transformation params for the op's tensors, as produced by
    `utils.materialize_standard_op`.
  """
  tensor_params = utils.materialize_standard_op(
      op_info, graph_info, tensor_name_to_qsv
  )
  return tensor_params
505
+
506
+
507
def materialize_strided_slice(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes quantization params for the tensors of tfl.strided_slice.

  Uses the `SAME_AS_INPUT_SCALE` constraint and skips the begin, end and
  strides operands (inputs 1, 2 and 3).

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).

  Returns:
    Transformation params for the op's tensors, as produced by
    `utils.materialize_standard_op`.
  """
  begin_input, end_input, strides_input = 1, 2, 3
  return utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      inputs_to_ignore=[begin_input, end_input, strides_input],
      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
  )
520
+
521
+
522
def materialize_mean(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes quantization params for the tensors of tfl.mean.

  Skips the axis operand (input 1). No scale constraint is applied.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).

  Returns:
    Transformation params for the op's tensors, as produced by
    `utils.materialize_standard_op`.
  """
  axis_input = 1  # The reduction axis is not quantized.
  return utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      inputs_to_ignore=[axis_input],
  )
534
+
535
+
536
def materialize_rsqrt(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes quantization params for the tensors of tfl.rsqrt.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).

  Returns:
    Transformation params for the op's tensors, as produced by
    `utils.materialize_standard_op`.
  """
  tensor_params = utils.materialize_standard_op(
      op_info, graph_info, tensor_name_to_qsv
  )
  return tensor_params
547
+
548
+
549
def materialize_concatenation(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes quantization params for the tensors of tfl.concatenation.

  Uses the `SAME_AS_OUTPUT_SCALE` constraint.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).

  Returns:
    Transformation params for the op's tensors, as produced by
    `utils.materialize_standard_op`.
  """
  scale_constraint = _OpQuantConstraint.SAME_AS_OUTPUT_SCALE
  return utils.materialize_standard_op(
      op_info, graph_info, tensor_name_to_qsv, constraint=scale_constraint
  )
561
+
562
+
563
def materialize_split(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materializes quantization params for the tensors of tfl.split.

  Uses the `SAME_AS_INPUT_SCALE` constraint and skips the split-dimension
  operand (input 0).

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Map from tensor name to its collected quantization
      statistics (QSV, e.g. min/max).

  Returns:
    Transformation params for the op's tensors, as produced by
    `utils.materialize_standard_op`.
  """
  split_dim_input = 0  # The split dimension is not quantized.
  return utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      inputs_to_ignore=[split_dim_input],
      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
  )
576
+
577
+
578
# TODO: b/333731147 - Use named tuple to store min/max.
def init_qsvs(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    inputs_to_ignore: Optional[list[int]] = None,
    outputs_to_ignore: Optional[list[int]] = None,
) -> dict[str, qtyping.QSV]:
  """Initializes the QSVs for the input and output tensors of an op.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    inputs_to_ignore: Positions (within the op's inputs) to skip.
    outputs_to_ignore: Positions (within the op's outputs) to skip.

  Returns:
    A dictionary mapping tensor name to the initial QSV produced by
    `utils.init_tensor_min_max` for that tensor.
  """
  inputs_to_ignore = inputs_to_ignore or []
  outputs_to_ignore = outputs_to_ignore or []

  op_qsvs = {}
  # Inputs first, then outputs. If a tensor appears in both, the later entry
  # overwrites the earlier one.
  for tensor_indices, ignored_positions in (
      (op_info.op.inputs, inputs_to_ignore),
      (op_info.op.outputs, outputs_to_ignore),
  ):
    for position, tensor_idx in enumerate(tensor_indices):
      # -1 denotes an omitted optional tensor in TFLite.
      if tensor_idx == -1 or position in ignored_positions:
        continue
      tensor = graph_info.subgraph_tensors[tensor_idx]
      tensor_name = tfl_flatbuffer_utils.get_tensor_name(tensor)
      op_qsvs[tensor_name] = utils.init_tensor_min_max(
          tensor,
          graph_info,
          op_info,
      )
  return op_qsvs
619
+
620
+
621
def min_max_calibrate(
    tfl_op: Any,
    graph_info: qtyping.GraphInfo,
    tensor_content_map: dict[str, np.ndarray],
    inputs_to_ignore: Optional[list[int]] = None,
    outputs_to_ignore: Optional[list[int]] = None,
) -> dict[str, qtyping.QSV]:
  """Collects min/max quantization statistics (QSVs) for an op's activations.

  Tensors with constant data (`get_tensor_data` is not None) are skipped. For
  every other input/output tensor, the global min and max of its content are
  recorded as arrays with kept dimensions (`keepdims=True`).

  Args:
    tfl_op: The tfl operation.
    graph_info: Graph information needed to perform quantization for the op.
    tensor_content_map: A map of tensor name to tensor content. It must contain
      every non-constant tensor that is not ignored, otherwise KeyError.
    inputs_to_ignore: Positions (within the op's inputs) to skip.
    outputs_to_ignore: Positions (within the op's outputs) to skip.

  Returns:
    A dictionary with key as tensor name and value as {"min": ..., "max": ...}.
  """
  inputs_to_ignore = inputs_to_ignore or []
  outputs_to_ignore = outputs_to_ignore or []

  op_qsvs = {}
  for tensor_indices, ignored_positions in (
      (tfl_op.inputs, inputs_to_ignore),
      (tfl_op.outputs, outputs_to_ignore),
  ):
    for position, tensor_idx in enumerate(tensor_indices):
      if tensor_idx == -1 or position in ignored_positions:
        continue
      tensor = graph_info.subgraph_tensors[tensor_idx]
      constant_data = tfl_flatbuffer_utils.get_tensor_data(
          tensor, graph_info.buffers
      )
      if constant_data is not None:
        continue  # Constant tensors are not calibrated.
      tensor_name = tfl_flatbuffer_utils.get_tensor_name(tensor)
      content = tensor_content_map[tensor_name]
      op_qsvs[tensor_name] = {
          "min": np.min(content, axis=None, keepdims=True),
          "max": np.max(content, axis=None, keepdims=True),
      }
  return op_qsvs