ai-edge-quantizer-nightly 0.0.1.dev20250210__py3-none-any.whl → 0.0.1.dev20250212__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +40 -61
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +637 -0
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +74 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +139 -533
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +0 -22
- ai_edge_quantizer/algorithms/utils/{min_max_quantize_utils.py → common_utils.py} +61 -175
- ai_edge_quantizer/algorithms/utils/{min_max_quantize_utils_test.py → common_utils_test.py} +20 -19
- ai_edge_quantizer/qtyping.py +12 -1
- {ai_edge_quantizer_nightly-0.0.1.dev20250210.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20250212.dist-info}/METADATA +1 -1
- {ai_edge_quantizer_nightly-0.0.1.dev20250210.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20250212.dist-info}/RECORD +13 -11
- {ai_edge_quantizer_nightly-0.0.1.dev20250210.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20250212.dist-info}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250210.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20250212.dist-info}/WHEEL +0 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250210.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20250212.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py
@@ -0,0 +1,637 @@
# Copyright 2024 The AI Edge Quantizer Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Quantization helpers common to all uniform quantization algorithms.

This file contains quantization helpers common to all uniform quantization
algorithms. The materialize_op functions require algorithm-specific logic to
produce the quantization parameters (e.g. scale, zero point) for each tensor,
which is encapsulated in get_tensor_quant_params_fn. Each algorithm is required
to implement the get_tensor_quant_params_fn with the
qtyping.GetTensorQuantParamsFuncSignature signature.
"""

from typing import Any
import numpy as np
from ai_edge_quantizer import qtyping
from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor
from ai_edge_quantizer.algorithms.utils import common_utils
from ai_edge_quantizer.utils import tfl_flatbuffer_utils

_TFLOpName = qtyping.TFLOperationName
_QuantTransformation = qtyping.QuantTransformation
_OpQuantConstraint = common_utils.OpQuantConstraint
_ComputePrecision = qtyping.ComputePrecision


def check_op_quantization_config(
    op_name: _TFLOpName,
    op_quant_config: qtyping.OpQuantizationConfig,
    config_check_policy: qtyping.ConfigCheckPolicyDict,
) -> None:
  """Checks the op quantization config.

  Args:
    op_name: The name of the op.
    op_quant_config: The quantization config for the op.
    config_check_policy: The policy to check the op quantization config.

  Raises:
    ValueError: If the op quantization config is invalid.
  """
  if op_quant_config.weight_tensor_config is None:
    raise ValueError(
        "Weight tensor quantization is required for min/max uniform"
        " quantization."
    )
  if op_quant_config.weight_tensor_config.dtype != qtyping.TensorDataType.INT:
    raise ValueError(
        "Weights need to have integer type for min/max uniform quantization. If"
        " you wish to perform float casting quantization (e.g., fp16 weight"
        " only), please set algorithm key as 'float_casting'."
    )

  if op_quant_config.min_weight_elements < 0:
    raise ValueError(
        f"min_weight_elements must be non-negative for op: {op_name} with"
        f" config: {op_quant_config}."
    )

  if op_quant_config.compute_precision in [
      _ComputePrecision.INTEGER,
      _ComputePrecision.FLOAT,
  ]:
    # Use policy-based mechanism to validate op.
    common_utils.check_if_valid_op_config(
        op_name, op_quant_config, config_check_policy
    )
  common_utils.check_subchannel_config(op_name, op_quant_config)


def materialize_input(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in the virtual input op."""
  return common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
  )


def materialize_output(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in the virtual output op."""
  return common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
  )


def materialize_add(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.add."""
  return common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
  )


def materialize_sub(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.sub."""
  return common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
  )


def materialize_mul(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.mul."""
  return common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
  )


def materialize_softmax_and_logistic(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.softmax and tfl.logistic."""
  # Hard code scales and zp values as they are hard coded in TFL kernels.
  # Softmax:
  # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/activations.cc#L548
  # Logistic:
  # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/activations.cc#L421
  output_activation_constraints = {
      8: qtyping.UniformQuantParams(
          num_bits=8,
          quantized_dimension=None,
          scale=np.array(1.0 / 256),
          zero_point=np.array(-128),
          symmetric=False,
      ),
      16: qtyping.UniformQuantParams(
          num_bits=16,
          quantized_dimension=None,
          scale=np.array(1.0 / 32768),
          zero_point=np.array(0),
      ),
  }

  return common_utils.materialize_op_with_output_activation_constraint(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      output_activation_constraints,
      get_tensor_quant_params_fn,
  )


def materialize_batch_matmul(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.batch_matmul."""
  return common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
  )


def materialize_embedding_lookup(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.embedding_lookup."""
  return common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
      inputs_to_ignore=[0],  # Lookup index does not need to be quantized.
  )


def materialize_reshape(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.reshape."""
  return common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
      inputs_to_ignore=[1],  # Shape tensor does not need to be quantized.
  )


def materialize_average_pool_2d(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.average_pool_2d."""
  return common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
  )


def _materialize_bias_for_conv_ops(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    op_tensor_params: list[qtyping.TensorTransformationParams],
    op_input_index: int = 0,
    op_weight_index: int = 1,
    op_bias_index: int = 2,
):
  """Materializes bias tensors in conv ops by updating `op_tensor_params`.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    op_tensor_params: Partially populated quantization configuration for the
      tensors associated with the op in the order of input, weight, output.
    op_input_index: Index for the input tensor in the op.
    op_weight_index: Index for the weight tensor in the op.
    op_bias_index: Index for the bias tensor in the op.
  """
  _, _, bias_tensor, _ = tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
      op_info.op,
      graph_info.subgraph_tensors,
      op_input_index,
      op_weight_index,
      op_bias_index,
  )
  if bias_tensor is not None:
    bias_quant_params = None
    # Fused bias needs to be quantized for SRQ.
    # Check if SRQ.
    if (
        op_info.op_quant_config.compute_precision == _ComputePrecision.INTEGER
        and op_info.op_quant_config.activation_tensor_config is not None
    ):
      bias_content = tfl_flatbuffer_utils.get_tensor_data(
          bias_tensor,
          graph_info.buffers,
      )
      bias_quant_params = (
          uniform_quantize_tensor.symmetric_quantize_bias_tensor(
              bias_content,
              op_tensor_params[op_input_index].consumers[0].parameters,
              op_tensor_params[op_weight_index].consumers[0].parameters,
          )
      )
    # We only quantize bias under SRQ. Setting is_constant=True for SRQ only
    # to avoid quantize bias for DRQ and weight-only cases.
    is_constant = (
        # Check if SRQ.
        op_info.op_quant_config.compute_precision == _ComputePrecision.INTEGER
        and op_info.op_quant_config.activation_tensor_config is not None
    )
    op_tensor_params[op_bias_index] = (
        common_utils.get_tensor_transformation_params(
            tfl_flatbuffer_utils.get_tensor_name(bias_tensor),
            op_info,
            is_inbounding_tensor=True,
            quant_params=bias_quant_params,
            is_constant=is_constant,
        )
    )


def _are_weights_too_small(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    weight_index: int,
) -> bool:
  """Checks if weights are too small to be quantized."""
  tensor = graph_info.subgraph_tensors[op_info.op.inputs[weight_index]]
  tensor_data = tfl_flatbuffer_utils.get_tensor_data(
      tensor,
      graph_info.buffers,
  )
  return (
      tensor_data is not None
      and np.size(tensor_data) < op_info.op_quant_config.min_weight_elements
  )


def materialize_slice(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.slice."""
  return common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
      inputs_to_ignore=[
          1,
          2,
      ],  # Begin and size indices do not need to be quantized.
  )


def materialize_select_v2(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.select_v2."""
  return common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
      inputs_to_ignore=[
          0,
      ],  # Condition tensor does not need to be quantized.
  )


def materialize_sum(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.sum."""
  return common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
      inputs_to_ignore=[1],  # Axis index does not need to be quantized.
  )


def materialize_fc_conv(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
    input_index: int = 0,
    weight_index: int = 1,
    bias_index: int = 2,
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in fully_connected, conv_2d and depthwise_conv_2d.

  Args:
    get_tensor_quant_params_fn: A function to get the quantization parameters
      for a tensor.
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: A map of tensor name to quantization parameters.
    input_index: Index for the input tensor in the op.
    weight_index: Index for the weight tensor in the op.
    bias_index: Index for the bias tensor in the op.

  Returns:
    Quantization configuration for the tensors associated with the op (e.g.,
    weights, bias).
  """
  ignored_inputs = [bias_index]  # Bias tensor is quantized separately.
  if _are_weights_too_small(op_info, graph_info, weight_index):
    ignored_inputs.append(weight_index)

  op_tensor_params = common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
      inputs_to_ignore=ignored_inputs,
  )

  _materialize_bias_for_conv_ops(
      op_info,
      graph_info,
      op_tensor_params,
      op_input_index=input_index,
      op_weight_index=weight_index,
      op_bias_index=bias_index,
  )

  return op_tensor_params


def materialize_conv2d_transpose(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.conv2d_transpose.

  Args:
    get_tensor_quant_params_fn: A function to get the quantization parameters
      for a tensor.
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: A map of tensor name to quantization parameters.

  Returns:
    Quantization configuration for the tensors associated with the op (e.g.,
    weights, bias).
  """
  ignored_shape_index = 0
  weight_index = 1
  input_index = 2
  bias_index = 3

  ignored_inputs = [
      ignored_shape_index,
      bias_index,  # Bias tensor is quantized separately.
  ]
  if _are_weights_too_small(op_info, graph_info, weight_index):
    ignored_inputs.append(weight_index)

  op_tensor_params = common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
      inputs_to_ignore=ignored_inputs,
  )
  if len(op_tensor_params) < 2:
    raise ValueError(
        "Materialize standard op should return at least two tensors for"
        " conv2d_transpose."
    )
  _materialize_bias_for_conv_ops(
      op_info,
      graph_info,
      op_tensor_params,
      op_input_index=input_index,
      op_weight_index=weight_index,
      op_bias_index=bias_index,
  )

  return op_tensor_params


def materialize_tanh(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.tanh."""
  # Hard code scales and zero point values as they are hard coded in:
  # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/ir/tfl_ops.td#L3430
  output_activation_constraints = {}
  for num_bits in [8, 16]:
    output_activation_constraints[num_bits] = qtyping.UniformQuantParams(
        num_bits=num_bits,
        quantized_dimension=None,
        scale=np.array(1.0 / (1 << (num_bits - 1))),
        zero_point=np.array(0),
        # Activation is always asymmetric for 8 bit and symmetric for 16 bits.
        symmetric=num_bits == 16,
    )
  return common_utils.materialize_op_with_output_activation_constraint(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      output_activation_constraints,
      get_tensor_quant_params_fn,
  )


def materialize_transpose(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.transpose."""
  return common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
      inputs_to_ignore=[1],  # Permutation tensor does not need to be quantized.
  )


def materialize_gelu(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.gelu."""
  return common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
  )


def materialize_strided_slice(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.strided_slice."""
  return common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
      inputs_to_ignore=[1, 2, 3],  # Ignore the begin, end, and strides tensors.
  )


def materialize_mean(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.mean."""
  return common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
      inputs_to_ignore=[1],  # Axis tensor does not need to be quantized.
  )


def materialize_rsqrt(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.rsqrt."""
  return common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
  )


def materialize_concatenation(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.concatenation."""
  return common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
  )


def materialize_split(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.split."""
  return common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
      inputs_to_ignore=[0],  # Split dimension does not need to be quantized.
  )
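
Every `materialize_*` helper above is algorithm-agnostic: the algorithm-specific piece is the injected `get_tensor_quant_params_fn`, which per the module docstring must match `qtyping.GetTensorQuantParamsFuncSignature`. As a rough illustration of the kind of logic such a function encapsulates, here is a self-contained min/max sketch. It is illustrative only, not part of the package: `symmetric_minmax_scale` is a made-up name, and the real signature and return type are defined by `qtyping`.

```python
# Illustrative sketch only -- not from ai-edge-quantizer. A real
# get_tensor_quant_params_fn must conform to
# qtyping.GetTensorQuantParamsFuncSignature; this only shows the min/max
# arithmetic such a function typically performs for symmetric quantization.
import numpy as np


def symmetric_minmax_scale(tensor_data: np.ndarray, num_bits: int = 8):
  """Derives (scale, zero_point) for symmetric uniform quantization."""
  qmax = 2 ** (num_bits - 1) - 1  # 127 for int8, 32767 for int16.
  max_abs = float(np.max(np.abs(tensor_data))) if tensor_data.size else 0.0
  scale = max_abs / qmax if max_abs > 0 else 1.0
  return np.array(scale), np.array(0)  # Symmetric: zero point is pinned at 0.


# Example: quantize a random weight matrix to int8 with the derived scale.
weights = np.random.uniform(-1.5, 2.0, size=(4, 8)).astype(np.float32)
scale, zero_point = symmetric_minmax_scale(weights)
quantized = np.clip(
    np.round(weights / scale) + zero_point, -128, 127
).astype(np.int8)
```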
ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py
@@ -0,0 +1,74 @@
# Copyright 2024 The AI Edge Quantizer Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os

from absl.testing import parameterized
import numpy as np

from tensorflow.python.platform import googletest
from ai_edge_quantizer import default_policy
from ai_edge_quantizer import qtyping
from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
from ai_edge_quantizer.utils import test_utils
from ai_edge_quantizer.utils import tfl_flatbuffer_utils

_TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("../../tests/models")
_TFLOpName = qtyping.TFLOperationName
_TensorQuantConfig = qtyping.TensorQuantizationConfig


class CommonQuantizeTest(parameterized.TestCase):
  """Tests for general quantize functions."""

  def setUp(self):
    super().setUp()
    np.random.seed(666)
    self._test_model_path = os.path.join(
        _TEST_DATA_PREFIX_PATH, "conv_fc_mnist.tflite"
    )
    self._test_model = tfl_flatbuffer_utils.read_model(self._test_model_path)
    # The test model has one subgraph for now.
    self._graph_info = qtyping.GraphInfo(
        subgraph_tensors=self._test_model.subgraphs[0].tensors,
        buffers=self._test_model.buffers,
    )
    self._tensor_name_to_qsv = {}

  def test_check_op_quantization_config_with_negative_min_weight_elements_raises_error(
      self,
  ):
    op_quant_config = qtyping.OpQuantizationConfig(
        weight_tensor_config=_TensorQuantConfig(
            num_bits=8,
            granularity=qtyping.QuantGranularity.CHANNELWISE,
        ),
        compute_precision=qtyping.ComputePrecision.INTEGER,  # DRQ.
        min_weight_elements=-1,
    )
    with self.assertRaisesWithPredicateMatch(
        ValueError,
        lambda err: "min_weight_elements must be non-negative" in str(err),
    ):
      common_quantize.check_op_quantization_config(
          _TFLOpName.FULLY_CONNECTED,
          op_quant_config,
          default_policy.DEFAULT_CONFIG_CHECK_POLICY,
      )


if __name__ == "__main__":
  googletest.main()
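
The test above pins down the failure mode for a negative `min_weight_elements`. For contrast, here is a sketch of the happy path (mine, not part of the package), reusing the same constructors the test imports; under the default policy this call should return without raising:

```python
# Sketch only: mirrors the config built in the test, but with a valid
# (non-negative) min_weight_elements, so no ValueError is expected.
from ai_edge_quantizer import default_policy
from ai_edge_quantizer import qtyping
from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize

valid_config = qtyping.OpQuantizationConfig(
    weight_tensor_config=qtyping.TensorQuantizationConfig(
        num_bits=8,
        granularity=qtyping.QuantGranularity.CHANNELWISE,
    ),
    compute_precision=qtyping.ComputePrecision.INTEGER,  # DRQ.
    min_weight_elements=1024,  # Weights smaller than this stay in float.
)
common_quantize.check_op_quantization_config(
    qtyping.TFLOperationName.FULLY_CONNECTED,
    valid_config,
    default_policy.DEFAULT_CONFIG_CHECK_POLICY,
)
```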