ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl
This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the package versions as they appear in their public registries.
- ai_edge_quantizer/__init__.py +19 -0
- ai_edge_quantizer/algorithm_manager.py +167 -0
- ai_edge_quantizer/algorithm_manager_api.py +271 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
- ai_edge_quantizer/algorithms/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
- ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
- ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
- ai_edge_quantizer/calibrator.py +288 -0
- ai_edge_quantizer/calibrator_test.py +297 -0
- ai_edge_quantizer/conftest.py +22 -0
- ai_edge_quantizer/default_policy.py +310 -0
- ai_edge_quantizer/model_modifier.py +176 -0
- ai_edge_quantizer/model_modifier_test.py +130 -0
- ai_edge_quantizer/model_validator.py +357 -0
- ai_edge_quantizer/model_validator_test.py +354 -0
- ai_edge_quantizer/params_generator.py +361 -0
- ai_edge_quantizer/params_generator_test.py +1041 -0
- ai_edge_quantizer/qtyping.py +483 -0
- ai_edge_quantizer/quantizer.py +372 -0
- ai_edge_quantizer/quantizer_test.py +532 -0
- ai_edge_quantizer/recipe.py +67 -0
- ai_edge_quantizer/recipe_manager.py +245 -0
- ai_edge_quantizer/recipe_manager_test.py +815 -0
- ai_edge_quantizer/recipe_test.py +97 -0
- ai_edge_quantizer/transformation_instruction_generator.py +584 -0
- ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
- ai_edge_quantizer/transformation_performer.py +278 -0
- ai_edge_quantizer/transformation_performer_test.py +344 -0
- ai_edge_quantizer/transformations/__init__.py +15 -0
- ai_edge_quantizer/transformations/dequant_insert.py +87 -0
- ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
- ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
- ai_edge_quantizer/transformations/quant_insert.py +100 -0
- ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
- ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
- ai_edge_quantizer/transformations/transformation_utils.py +132 -0
- ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
- ai_edge_quantizer/utils/__init__.py +15 -0
- ai_edge_quantizer/utils/calibration_utils.py +86 -0
- ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
- ai_edge_quantizer/utils/test_utils.py +107 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
- ai_edge_quantizer/utils/validation_utils.py +125 -0
- ai_edge_quantizer/utils/validation_utils_test.py +87 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py
@@ -0,0 +1,273 @@
+# Copyright 2024 The AI Edge Quantizer Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Performs float casting quantization."""
+
+from typing import Any, Optional
+import numpy as np
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+
+ALGORITHM_KEY = "float_casting"
+_TFLOpName = qtyping.TFLOperationName
+_QuantTransformation = qtyping.QuantTransformation
+
+# Ops that support weight quantization config (e.g., support Weight-only).
+SUPPORTED_WEIGHT_QUANT_OPS = frozenset([
+    _TFLOpName.FULLY_CONNECTED,
+    _TFLOpName.CONV_2D,
+    _TFLOpName.DEPTHWISE_CONV_2D,
+    _TFLOpName.CONV_2D_TRANSPOSE,
+    _TFLOpName.EMBEDDING_LOOKUP,
+])
+
+
+def check_op_quantization_config(
+    op_name: _TFLOpName,
+    op_quant_config: qtyping.OpQuantizationConfig,
+    config_check_policy: Optional[qtyping.ConfigCheckPolicyDict] = None,
+) -> None:
+  """Checks if the op is valid for float casting quantization.
+
+  Args:
+    op_name: The name of the op.
+    op_quant_config: The quantization config for the op.
+    config_check_policy: The policy to check the quantization config.
+
+  Raises:
+    ValueError: If the op is not supported or the compute_precision is not
+      FLOAT.
+  """
+  # TODO: b/353780772 - Add config check policy for float casting quantization.
+  if config_check_policy is not None and config_check_policy:
+    raise ValueError(f"Config check isn't implemented yet for op: {op_name}.")
+
+  # Check if WEIGHT_ONLY.
+  if op_quant_config.compute_precision != qtyping.ComputePrecision.FLOAT:
+    raise ValueError(
+        "Currently, only Weight-Only is supported for float casting"
+        " quantization. Got unsupported execution mode:"
+        f" {op_quant_config.compute_precision} for op: {op_name}"
+    )
+  if op_quant_config.activation_tensor_config is not None:
+    raise ValueError(
+        "Activation tensor quantization is not supported for float casting"
+        " quantization."
+    )
+  if op_name not in SUPPORTED_WEIGHT_QUANT_OPS:
+    raise ValueError(
+        f"Unsupported op: {op_name} for float casting quantization."
+    )
+  if op_quant_config.weight_tensor_config is None:
+    raise ValueError(
+        "Weight tensor quantization config is required for float casting"
+        " quantization."
+    )
+  if (
+      op_quant_config.weight_tensor_config.num_bits != 16
+      or op_quant_config.weight_tensor_config.dtype
+      != qtyping.TensorDataType.FLOAT
+  ):
+    raise ValueError(
+        "Currently, float casting quantization config requires number of bits"
+        " to be set as 16, dtype as float, got"
+        f" {op_quant_config.weight_tensor_config.num_bits} and"
+        f" {op_quant_config.weight_tensor_config.dtype} ."
+    )
+
+
+def materialize_fc_conv(
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    _: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in fully_connected, conv_2d and depthwise_conv_2d ops.
+
+  This function is called by the quantization pipeline to materialize
+  quantization parameters for the weight tensor of the op.
+
+  Args:
+    op_info: Aggregated information about the op (e.g., quantization config).
+    graph_info: Graph information needed to perform quantization for the op.
+    _: A map of tensor name to quantization parameters (unused).
+
+  Returns:
+    Quantization configuration for the weight tensor of the op.
+
+  Raises:
+    ValueError: If the op is not supported or the compute precision is not
+      FLOAT.
+  """
+  input_tensor, weight_tensor, bias_tensor, output_tensor = (
+      tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
+          op_info.op, graph_info.subgraph_tensors
+      )
+  )
+  op_tensor_params = []
+  # Input tensor.
+  input_quant_params = _config_no_quantize_tensor(
+      op_info, input_tensor, is_inbounding_tensor=True
+  )
+  op_tensor_params.append(input_quant_params)
+  # Weight tensor.
+  weight_content = tfl_flatbuffer_utils.get_tensor_data(
+      weight_tensor,
+      graph_info.buffers,
+  )
+  quant_params = qtyping.NonLinearQuantParams(
+      num_bits=16, quantized_data=weight_content.astype(np.float16)  # pytype: disable=attribute-error
+  )
+  op2weight_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      parameters=quant_params,
+      transformations=[_QuantTransformation.ADD_DEQUANTIZE],
+  )
+  op_tensor_params.append(
+      qtyping.TensorTransformationParams(
+          tensor_name=tfl_flatbuffer_utils.get_tensor_name(weight_tensor),
+          consumers=[op2weight_params],
+      )
+  )
+  # Output tensor.
+  output_quant_params = _config_no_quantize_tensor(
+      op_info, output_tensor, is_inbounding_tensor=False
+  )
+  op_tensor_params.append(output_quant_params)
+  # Bias tensor.
+  if bias_tensor is not None:
+    bias_quant_params = _config_no_quantize_tensor(
+        op_info, bias_tensor, is_inbounding_tensor=True
+    )
+    op_tensor_params.append(bias_quant_params)
+  return op_tensor_params
+
+
+def materialize_embedding_lookup(
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    _: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  return materialize_fc_conv(op_info, graph_info, _)
+
+
+def materialize_conv2d_transpose(
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    _: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in conv_2d_transpose ops.
+
+  This function is called by the quantization pipeline to materialize
+  quantization parameters for the weight tensor of the op.
+
+  Args:
+    op_info: Aggregated information about the op (e.g., quantization config).
+    graph_info: Graph information needed to perform quantization for the op.
+    _: A map of tensor name to quantization parameters (unused).
+
+  Returns:
+    Quantization configuration for the weight tensor of the op.
+
+  Raises:
+    ValueError: If the op is not supported or the compute precision is not
+      FLOAT.
+  """
+  input_tensor, weight_tensor, bias_tensor, output_tensor = (
+      tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
+          op_info.op,
+          graph_info.subgraph_tensors,
+          input_index=2,
+          weight_index=1,
+          bias_index=3,
+          output_index=0,
+      )
+  )
+  op_tensor_params = []
+  # Input tensor.
+  input_quant_params = _config_no_quantize_tensor(
+      op_info, input_tensor, is_inbounding_tensor=True
+  )
+  op_tensor_params.append(input_quant_params)
+  # Weight tensor.
+  weight_content = tfl_flatbuffer_utils.get_tensor_data(
+      weight_tensor,
+      graph_info.buffers,
+  )
+  quant_params = qtyping.NonLinearQuantParams(
+      num_bits=16, quantized_data=weight_content.astype(np.float16)  # pytype: disable=attribute-error
+  )
+  op2weight_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      parameters=quant_params,
+      transformations=[_QuantTransformation.ADD_DEQUANTIZE],
+  )
+  op_tensor_params.append(
+      qtyping.TensorTransformationParams(
+          tensor_name=tfl_flatbuffer_utils.get_tensor_name(weight_tensor),
+          consumers=[op2weight_params],
+      )
+  )
+  # Output tensor.
+  output_quant_params = _config_no_quantize_tensor(
+      op_info, output_tensor, is_inbounding_tensor=False
+  )
+  op_tensor_params.append(output_quant_params)
+  # Bias tensor.
+  if bias_tensor is not None:
+    bias_quant_params = _config_no_quantize_tensor(
+        op_info, bias_tensor, is_inbounding_tensor=True
+    )
+    op_tensor_params.append(bias_quant_params)
+  return op_tensor_params
+
+
+def _config_no_quantize_tensor(
+    op_info: qtyping.OpInfo,
+    tensor: Any,
+    is_inbounding_tensor: bool,
+) -> qtyping.TensorTransformationParams:
+  """Configures a tensor to be not quantized.
+
+  Args:
+    op_info: Aggregated information about the op (e.g., quantization config).
+    tensor: The tensor to be configured.
+    is_inbounding_tensor: Whether the tensor is an inbounding tensor.
+
+  Returns:
+    TensorTransformationParams for the tensor.
+  """
+  tensor_name = tfl_flatbuffer_utils.get_tensor_name(tensor)
+  op2tensor_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      transformations=[_QuantTransformation.NO_QUANTIZE],
+  )
+  if is_inbounding_tensor:
+    return qtyping.TensorTransformationParams(
+        tensor_name=tensor_name,
+        consumers=[op2tensor_params],
+    )
+  return qtyping.TensorTransformationParams(
+      tensor_name=tensor_name, producer=op2tensor_params
+  )
+
+
+def init_qsvs(*_) -> qtyping.QSV:
+  """Currently calibration free. Placeholder for AlgorithmManager."""
+  return {}
+
+
+def calibrate(*_) -> dict[str, qtyping.QSV]:
+  """Currently calibration free. Placeholder for AlgorithmManager."""
+  return {}
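
Note on mechanics: materialize_fc_conv stores each weight tensor as float16 (NonLinearQuantParams with num_bits=16) and tags it with ADD_DEQUANTIZE, so the modified model carries fp16 weights plus a dequantize op that restores fp32 for the consuming op. A minimal numpy sketch of that round trip (an editor's illustration, not code from the package):

import numpy as np

rng = np.random.default_rng(0)
weights_fp32 = rng.standard_normal((128, 256)).astype(np.float32)

# Storage side: the fp16 payload that materialize_fc_conv places in
# NonLinearQuantParams.quantized_data.
weights_fp16 = weights_fp32.astype(np.float16)

# Runtime side: the ADD_DEQUANTIZE transformation casts back to fp32, so the
# consuming fully_connected/conv op keeps its float32 interface.
dequantized = weights_fp16.astype(np.float32)

print(weights_fp32.nbytes, weights_fp16.nbytes)  # 131072 65536: storage halves
# fp16 keeps ~11 significant bits, so unit-scale weights round-trip with a
# relative error of at most about 2**-11 in fp16's normal range.
print(np.max(np.abs(weights_fp32 - dequantized)))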
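
A hedged usage sketch for check_op_quantization_config. The constructor keywords below (num_bits, dtype, weight_tensor_config, compute_precision) are assumptions inferred from the attribute reads in the diff, not a verified signature; treat this as a sketch of the intended fp16 weight-only config rather than the package's documented API:

from ai_edge_quantizer import qtyping
from ai_edge_quantizer.algorithms.nonlinear_quantize import float_casting

# Assumed dataclass keywords; see the note above.
op_config = qtyping.OpQuantizationConfig(
    weight_tensor_config=qtyping.TensorQuantizationConfig(
        num_bits=16, dtype=qtyping.TensorDataType.FLOAT
    ),
    compute_precision=qtyping.ComputePrecision.FLOAT,
)

# Accepted: FULLY_CONNECTED is in SUPPORTED_WEIGHT_QUANT_OPS and the config is
# fp16 weight-only, so the check returns None instead of raising ValueError.
float_casting.check_op_quantization_config(
    qtyping.TFLOperationName.FULLY_CONNECTED, op_config
)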
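
Finally, a sketch of where the algorithm plugs into the package's top-level flow, going only by this wheel's module layout (quantizer.py, recipe_manager.py) and this module's ALGORITHM_KEY. The method names and parameters below (Quantizer, update_quantization_recipe, quantize) are assumptions, not a documented contract, and the model path is hypothetical:

from ai_edge_quantizer import qtyping, quantizer

qt = quantizer.Quantizer("model.tflite")  # hypothetical float32 source model
qt.update_quantization_recipe(            # assumed recipe-registration method
    regex=".*",                           # match every op instance by scope name
    operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
    op_config=op_config,                  # the fp16 weight-only config above
    algorithm_key="float_casting",        # ALGORITHM_KEY exported by this module
)
quantization_result = qt.quantize()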