ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py
@@ -0,0 +1,273 @@
+ # Copyright 2024 The AI Edge Quantizer Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Performs float casting quantization."""
+
+ from typing import Any, Optional
+ import numpy as np
+ from ai_edge_quantizer import qtyping
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+
+ ALGORITHM_KEY = "float_casting"
+ _TFLOpName = qtyping.TFLOperationName
+ _QuantTransformation = qtyping.QuantTransformation
+
+ # Ops that support a weight quantization config (i.e., support weight-only).
+ SUPPORTED_WEIGHT_QUANT_OPS = frozenset([
+     _TFLOpName.FULLY_CONNECTED,
+     _TFLOpName.CONV_2D,
+     _TFLOpName.DEPTHWISE_CONV_2D,
+     _TFLOpName.CONV_2D_TRANSPOSE,
+     _TFLOpName.EMBEDDING_LOOKUP,
+ ])
+
+
+ def check_op_quantization_config(
+     op_name: _TFLOpName,
+     op_quant_config: qtyping.OpQuantizationConfig,
+     config_check_policy: Optional[qtyping.ConfigCheckPolicyDict] = None,
+ ) -> None:
+   """Checks if the op is valid for float casting quantization.
+
+   Args:
+     op_name: The name of the op.
+     op_quant_config: The quantization config for the op.
+     config_check_policy: The policy to check the quantization config.
+
+   Raises:
+     ValueError: If the op is not supported or the compute_precision is not
+       FLOAT.
+   """
+   # TODO: b/353780772 - Add config check policy for float casting quantization.
+   if config_check_policy:
+     raise ValueError(f"Config check isn't implemented yet for op: {op_name}.")
+
+   # Check that the op runs weight-only (compute stays in float).
+   if op_quant_config.compute_precision != qtyping.ComputePrecision.FLOAT:
+     raise ValueError(
+         "Currently, only Weight-Only is supported for float casting"
+         " quantization. Got unsupported compute precision:"
+         f" {op_quant_config.compute_precision} for op: {op_name}"
+     )
+   if op_quant_config.activation_tensor_config is not None:
+     raise ValueError(
+         "Activation tensor quantization is not supported for float casting"
+         " quantization."
+     )
+   if op_name not in SUPPORTED_WEIGHT_QUANT_OPS:
+     raise ValueError(
+         f"Unsupported op: {op_name} for float casting quantization."
+     )
+   if op_quant_config.weight_tensor_config is None:
+     raise ValueError(
+         "Weight tensor quantization config is required for float casting"
+         " quantization."
+     )
+   if (
+       op_quant_config.weight_tensor_config.num_bits != 16
+       or op_quant_config.weight_tensor_config.dtype
+       != qtyping.TensorDataType.FLOAT
+   ):
+     raise ValueError(
+         "Currently, float casting quantization requires num_bits to be 16"
+         " and dtype to be float; got"
+         f" {op_quant_config.weight_tensor_config.num_bits} and"
+         f" {op_quant_config.weight_tensor_config.dtype}."
+     )
+
+
+ def materialize_fc_conv(
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     _: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in fully_connected, conv_2d and depthwise_conv_2d ops.
+
+   This function is called by the quantization pipeline to materialize
+   quantization parameters for the weight tensor of the op.
+
+   Args:
+     op_info: Aggregated information about the op (e.g., quantization config).
+     graph_info: Graph information needed to perform quantization for the op.
+     _: A map of tensor name to quantization parameters (unused).
+
+   Returns:
+     Quantization configuration for the weight tensor of the op.
+
+   Raises:
+     ValueError: If the op is not supported or the compute precision is not
+       FLOAT.
+   """
+   input_tensor, weight_tensor, bias_tensor, output_tensor = (
+       tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
+           op_info.op, graph_info.subgraph_tensors
+       )
+   )
+   op_tensor_params = []
+   # Input tensor.
+   input_quant_params = _config_no_quantize_tensor(
+       op_info, input_tensor, is_inbounding_tensor=True
+   )
+   op_tensor_params.append(input_quant_params)
+   # Weight tensor.
+   weight_content = tfl_flatbuffer_utils.get_tensor_data(
+       weight_tensor,
+       graph_info.buffers,
+   )
+   quant_params = qtyping.NonLinearQuantParams(
+       num_bits=16, quantized_data=weight_content.astype(np.float16)  # pytype: disable=attribute-error
+   )
+   op2weight_params = qtyping.OpToTensorParams(
+       subgraph_op_id=op_info.subgraph_op_index,
+       parameters=quant_params,
+       transformations=[_QuantTransformation.ADD_DEQUANTIZE],
+   )
+   op_tensor_params.append(
+       qtyping.TensorTransformationParams(
+           tensor_name=tfl_flatbuffer_utils.get_tensor_name(weight_tensor),
+           consumers=[op2weight_params],
+       )
+   )
+   # Output tensor.
+   output_quant_params = _config_no_quantize_tensor(
+       op_info, output_tensor, is_inbounding_tensor=False
+   )
+   op_tensor_params.append(output_quant_params)
+   # Bias tensor.
+   if bias_tensor is not None:
+     bias_quant_params = _config_no_quantize_tensor(
+         op_info, bias_tensor, is_inbounding_tensor=True
+     )
+     op_tensor_params.append(bias_quant_params)
+   return op_tensor_params
+
+
+ def materialize_embedding_lookup(
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     _: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materializes tensors in embedding_lookup ops via materialize_fc_conv."""
+   return materialize_fc_conv(op_info, graph_info, _)
+
+
+ def materialize_conv2d_transpose(
+     op_info: qtyping.OpInfo,
+     graph_info: qtyping.GraphInfo,
+     _: dict[str, Any],
+ ) -> list[qtyping.TensorTransformationParams]:
+   """Materialize tensors in conv_2d_transpose ops.
+
+   This function is called by the quantization pipeline to materialize
+   quantization parameters for the weight tensor of the op.
+
+   Args:
+     op_info: Aggregated information about the op (e.g., quantization config).
+     graph_info: Graph information needed to perform quantization for the op.
+     _: A map of tensor name to quantization parameters (unused).
+
+   Returns:
+     Quantization configuration for the weight tensor of the op.
+
+   Raises:
+     ValueError: If the op is not supported or the compute precision is not
+       FLOAT.
+   """
+   input_tensor, weight_tensor, bias_tensor, output_tensor = (
+       tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
+           op_info.op,
+           graph_info.subgraph_tensors,
+           input_index=2,
+           weight_index=1,
+           bias_index=3,
+           output_index=0,
+       )
+   )
+   op_tensor_params = []
+   # Input tensor.
+   input_quant_params = _config_no_quantize_tensor(
+       op_info, input_tensor, is_inbounding_tensor=True
+   )
+   op_tensor_params.append(input_quant_params)
+   # Weight tensor.
+   weight_content = tfl_flatbuffer_utils.get_tensor_data(
+       weight_tensor,
+       graph_info.buffers,
+   )
+   quant_params = qtyping.NonLinearQuantParams(
+       num_bits=16, quantized_data=weight_content.astype(np.float16)  # pytype: disable=attribute-error
+   )
+   op2weight_params = qtyping.OpToTensorParams(
+       subgraph_op_id=op_info.subgraph_op_index,
+       parameters=quant_params,
+       transformations=[_QuantTransformation.ADD_DEQUANTIZE],
+   )
+   op_tensor_params.append(
+       qtyping.TensorTransformationParams(
+           tensor_name=tfl_flatbuffer_utils.get_tensor_name(weight_tensor),
+           consumers=[op2weight_params],
+       )
+   )
+   # Output tensor.
+   output_quant_params = _config_no_quantize_tensor(
+       op_info, output_tensor, is_inbounding_tensor=False
+   )
+   op_tensor_params.append(output_quant_params)
+   # Bias tensor.
+   if bias_tensor is not None:
+     bias_quant_params = _config_no_quantize_tensor(
+         op_info, bias_tensor, is_inbounding_tensor=True
+     )
+     op_tensor_params.append(bias_quant_params)
+   return op_tensor_params
+
+
+ def _config_no_quantize_tensor(
+     op_info: qtyping.OpInfo,
+     tensor: Any,
+     is_inbounding_tensor: bool,
+ ) -> qtyping.TensorTransformationParams:
+   """Configures a tensor to be not quantized.
+
+   Args:
+     op_info: Aggregated information about the op (e.g., quantization config).
+     tensor: The tensor to be configured.
+     is_inbounding_tensor: Whether the tensor is an inbounding tensor.
+
+   Returns:
+     TensorTransformationParams for the tensor.
+   """
+   tensor_name = tfl_flatbuffer_utils.get_tensor_name(tensor)
+   op2tensor_params = qtyping.OpToTensorParams(
+       subgraph_op_id=op_info.subgraph_op_index,
+       transformations=[_QuantTransformation.NO_QUANTIZE],
+   )
+   if is_inbounding_tensor:
+     return qtyping.TensorTransformationParams(
+         tensor_name=tensor_name,
+         consumers=[op2tensor_params],
+     )
+   return qtyping.TensorTransformationParams(
+       tensor_name=tensor_name, producer=op2tensor_params
+   )
+
+
+ def init_qsvs(*_) -> qtyping.QSV:
+   """Currently calibration free. Placeholder for AlgorithmManager."""
+   return {}
+
+
+ def calibrate(*_) -> dict[str, qtyping.QSV]:
+   """Currently calibration free. Placeholder for AlgorithmManager."""
+   return {}
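
For orientation after reading the diff: float casting quantization stores weight buffers as float16 and relies on an inserted DEQUANTIZE op (the ADD_DEQUANTIZE transformation above) to restore float32 at inference, so compute stays in float and no calibration statistics are needed. The following standalone sketch shows that numerical round trip with plain numpy; the weight shape and values are made up for illustration and nothing here uses the package's APIs.

    import numpy as np

    # Stand-in for a fully_connected weight tensor read from the flatbuffer.
    rng = np.random.default_rng(seed=0)
    weight_fp32 = rng.normal(size=(128, 256)).astype(np.float32)

    # "Quantize": store the buffer as float16, mirroring
    # quantized_data=weight_content.astype(np.float16) in the diff.
    weight_fp16 = weight_fp32.astype(np.float16)

    # "Dequantize": the inserted ADD_DEQUANTIZE op casts back to float32
    # at runtime.
    weight_restored = weight_fp16.astype(np.float32)

    print(f"buffer size: {weight_fp32.nbytes} -> {weight_fp16.nbytes} bytes")
    print(
        "max abs rounding error:"
        f" {np.abs(weight_fp32 - weight_restored).max():.3e}"
    )

This halves weight storage at the cost of float16 rounding error only, which is why init_qsvs and calibrate in the module are no-op placeholders.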