ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,483 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Type hinting support for AI Edge Quantizer."""
17
+
18
+ import collections
19
+ from collections.abc import MutableMapping
20
+ import copy
21
+ import dataclasses
22
+ import enum
23
+ from typing import Any, Optional, Union
24
+
25
+ import numpy as np
26
+ from typing_extensions import TypeAlias
27
+
28
+
29
+ QSV: TypeAlias = MutableMapping[str, Any]
30
+
31
+
32
class TFLOperationName(str, enum.Enum):
  """String-valued names of the TF Lite operations handled by the quantizer.

  Every member is also a `str`, so it compares equal to its plain string value.
  `ALL_SUPPORTED` is the wildcard entry whose value is '*'.
  """

  ALL_SUPPORTED = '*'
  INPUT = 'INPUT'
  OUTPUT = 'OUTPUT'
  FULLY_CONNECTED = 'FULLY_CONNECTED'
  BATCH_MATMUL = 'BATCH_MATMUL'
  DEPTHWISE_CONV_2D = 'DEPTHWISE_CONV_2D'
  CONV_2D = 'CONV_2D'
  CONV_2D_TRANSPOSE = 'CONV_2D_TRANSPOSE'
  AVERAGE_POOL_2D = 'AVERAGE_POOL_2D'
  RESHAPE = 'RESHAPE'
  CUSTOM_OP = 'CUSTOM_OP'
  EMBEDDING_LOOKUP = 'EMBEDDING_LOOKUP'
  SOFTMAX = 'SOFTMAX'
  TANH = 'TANH'
  TRANSPOSE = 'TRANSPOSE'
  GELU = 'GELU'
  ADD = 'ADD'
  SUB = 'SUB'
  MUL = 'MUL'
  MEAN = 'MEAN'
  RSQRT = 'RSQRT'
  CONCATENATION = 'CONCATENATION'
  STRIDED_SLICE = 'STRIDED_SLICE'
  SPLIT = 'SPLIT'
  LOGISTIC = 'LOGISTIC'
  SLICE = 'SLICE'
  SUM = 'SUM'
  SELECT_V2 = 'SELECT_V2'
63
+
64
+
65
class QuantizeMode(enum.Enum):
  """Integer-valued quantization modes: CALIBRATE (2) and MATERIALIZE (3)."""

  CALIBRATE = 2
  MATERIALIZE = 3
68
+
69
+
70
class OpExecutionMode(str, enum.Enum):
  """Ways in which a quantized op can be executed."""

  WEIGHT_ONLY = 'WEIGHT_ONLY'
  # Dynamic range quantization.
  DRQ = 'DRQ'
  # Static range quantization.
  SRQ = 'SRQ'
76
+
77
+
78
class ComputePrecision(str, enum.Enum):
  """Precision in which an op's computation is carried out."""

  INTEGER = 'INTEGER'
  FLOAT = 'FLOAT'
83
+
84
+
85
class TensorDataType(str, enum.Enum):
  """Broad data type family of a tensor: integer or floating point."""

  INT = 'INT'
  FLOAT = 'FLOAT'
88
+
89
+
90
class QuantGranularity(str, enum.Enum):
  """Quantization granularity: per-tensor, per-channel or per-block."""

  TENSORWISE = 'TENSORWISE'
  CHANNELWISE = 'CHANNELWISE'
  BLOCKWISE = 'BLOCKWISE'
94
+
95
+
96
class QuantTransformation(enum.Enum):
  """Quantization-related operations that can be applied to a tensor."""

  # float_tensor -> float_tensor (nothing is done).
  NO_QUANTIZE = 0
  # float_tensor -> Quantize Op -> quantized_tensor (a quantize op is added).
  ADD_QUANTIZE = 1
  # quantized_tensor -> Dequantize Op -> float_tensor (a dequantize op is
  # added).
  ADD_DEQUANTIZE = 2
  # float_tensor -> quantized_tensor (the tensor itself is quantized).
  QUANTIZE_TENSOR = 3
  # Builds the pattern for emulated subchannel quantization; only the fully
  # connected op is supported.
  EMULATED_SUBCHANNEL = 4
110
+
111
+
112
@dataclasses.dataclass(frozen=True)
class UniformQuantParams:
  """Parameters for uniform quantization.

  Attributes:
    num_bits: Number of bits to quantize to (e.g. 8 for int8).
    quantized_dimension: The dimension to quantize.
    scale: The scale of the quantization.
    zero_point: The zero point of the quantization.
    symmetric: Whether the quantization is symmetric (force zero_point to be 0).
    quantized_data: The quantized data.
  """

  num_bits: int
  quantized_dimension: Optional[int]
  scale: np.ndarray
  zero_point: np.ndarray
  symmetric: bool = True
  quantized_data: Optional[np.ndarray] = None

  @classmethod
  def from_tfl_tensor_details(cls, tensor_detail) -> 'UniformQuantParams':
    """Creates UniformQuantParams from TFLite tensor details.

    Args:
      tensor_detail: The tensor details from TFLite. Must provide a 'dtype'
        entry and a 'quantization_parameters' entry holding 'scales',
        'zero_points' and 'quantized_dimension'.

    Returns:
      UniformQuantParams whose `symmetric` flag is True exactly when every zero
      point is 0 (including the case of no zero points at all).

    Raises:
      ValueError: If the dtype is not one of int8, int16, int32 or int64.
    """
    quant_params = tensor_detail['quantization_parameters']
    data_type = tensor_detail['dtype']
    supported_widths = (
        (np.int8, 8),
        (np.int16, 16),
        (np.int32, 32),
        (np.int64, 64),
    )
    # `==` is used (rather than a dict lookup) so that both numpy scalar types
    # and `np.dtype` instances are recognized.
    for np_type, bit_width in supported_widths:
      if data_type == np_type:
        num_bits = bit_width
        break
    else:
      raise ValueError(f'Unsupported data type: {data_type}')
    # `np.any` works for lists and arrays of any rank and cannot overflow, and
    # `not` turns the result into a plain Python bool.
    symmetric = not np.any(quant_params['zero_points'])
    return cls(
        quantized_dimension=quant_params['quantized_dimension'],
        num_bits=num_bits,
        scale=quant_params['scales'],
        zero_point=quant_params['zero_points'],
        symmetric=symmetric,
    )

  def __eq__(self, other):
    """Compares field by field, using array equality for the array fields."""
    if other.__class__ is not self.__class__:
      return NotImplemented
    return (
        self.num_bits == other.num_bits
        and self.quantized_dimension == other.quantized_dimension
        and np.array_equal(self.scale, other.scale)
        and np.array_equal(self.zero_point, other.zero_point)
        and self.symmetric == other.symmetric
        and _compare_array_or_none(self.quantized_data, other.quantized_data)
    )
174
+
175
+
176
@dataclasses.dataclass(frozen=True)
class NonLinearQuantParams:
  """Parameters describing a nonlinear quantization.

  At present this is only used for fp16 quantization.

  Attributes:
    num_bits: Number of bits to quantize to (e.g. 16 for fp16).
    quantized_data: The quantized data.
    data_type: The data type of the tensor.
  """

  num_bits: int
  quantized_data: Optional[np.ndarray]
  data_type: TensorDataType = TensorDataType.FLOAT

  def __eq__(self, other):
    """Compares field by field, treating `quantized_data` as array-or-None."""
    if self.__class__ is not other.__class__:
      return NotImplemented
    if not self.num_bits == other.num_bits:
      return False
    if not self.data_type == other.data_type:
      return False
    return _compare_array_or_none(self.quantized_data, other.quantized_data)
200
+
201
+
202
@dataclasses.dataclass(frozen=True)
class OpToTensorParams:
  """Parameters that one associated op contributes for a tensor.

  Attributes:
    subgraph_op_id: The position of the op in the subgraph.
    transformations: The transformations to be applied to the tensor.
    parameters: The quantization parameters for the tensor, or None.
  """

  subgraph_op_id: int
  transformations: list[QuantTransformation]
  parameters: Optional[Union[UniformQuantParams, NonLinearQuantParams]] = None
215
+
216
+
217
@dataclasses.dataclass
class TensorTransformationParams:
  """Transformation info collected for a single tensor.

  In a .tflite model every tensor is
    * produced by one source op (producer), unless it is a constant tensor or
      a model input, and
    * consumed by one or many destination ops (consumers), unless it is a model
      output.

  Since users configure quantization per op through `OpQuantizationConfig`, a
  tensor receives transformation parameters both from its source op and from
  its destination ops.

  Attributes:
    tensor_name: The name of the tensor.
    producer: Parameters authored by the source op, if any.
    consumers: Parameters authored by the destination ops, if any.
  """

  tensor_name: str
  producer: Optional[OpToTensorParams] = None
  consumers: Optional[list[OpToTensorParams]] = None
235
+
236
+
237
@dataclasses.dataclass(frozen=True)
class TensorQuantizationConfig:
  """Quantization configuration for a tensor.

  Attributes:
    num_bits: Number of bits to quantize to (e.g. 8 for int8).
    symmetric: Whether to perform symmetric or asymmetric quantization. In the
      symmetric quantization mode, the zero point is always 0.
    granularity: Whether to perform per-tensor, per-channel or per-block
      quantization.
    dtype: The data type of the tensor.
    block_size: The block size for blockwise quantization, ignored otherwise.
  """

  num_bits: int
  symmetric: bool = True
  granularity: QuantGranularity = QuantGranularity.TENSORWISE
  dtype: TensorDataType = TensorDataType.INT
  block_size: int = 0

  def to_dict(self) -> dict[str, Any]:
    """Converts TensorQuantizationConfig to dict.

    Returns:
      The fields as a dict, leaving out entries whose value is None or an empty
      dict.
    """

    def _without_unset(pairs):
      kept = {}
      for key, value in pairs:
        if value is None:
          continue
        if isinstance(value, dict) and not value:
          continue
        kept[key] = value
      return kept

    return dataclasses.asdict(self, dict_factory=_without_unset)

  @classmethod
  def from_dict(cls, params: dict[str, Any]) -> 'TensorQuantizationConfig':
    """Builds a TensorQuantizationConfig from a deep copy of the given dict."""
    return cls(**copy.deepcopy(params))
274
+
275
+
276
@dataclasses.dataclass(frozen=True)
class OpQuantizationConfig:
  """Configuration class to control the quantization process behavior.

  Default to float activations and weights.

  Attributes:
    activation_tensor_config: The quantization configuration for activation
      tensors in the op (i.e., runtime tensors).
    weight_tensor_config: The quantization configuration for weight tensor in
      the op.
    compute_precision: The precision of the compute operation.
    explicit_dequantize: Whether to add explicit dequantize op if compute
      precision is FLOAT, but weight is quantized.
    skip_checks: Whether to skip op quantization config checks. For advanced
      users only. If set, the quantizer will ignore all op configuration checks
      and forcefully quantize this op according to the user instructions even if
      it's not supported in the TFLite runtime.
    min_weight_elements: The minimum number of elements in the weight tensor to
      be quantized.
  """

  activation_tensor_config: Optional[TensorQuantizationConfig] = None
  # Bias tensor quantization is deduced from activation/weight config.
  # e.g., int8A X int8W => int32 bias.
  weight_tensor_config: Optional[TensorQuantizationConfig] = None
  compute_precision: ComputePrecision = ComputePrecision.FLOAT
  # TODO: b/359647578 - Set default to True.
  explicit_dequantize: bool = False
  skip_checks: bool = False
  min_weight_elements: int = 0

  def __post_init__(self):
    """Validates the activation/weight/compute-precision combination.

    Nothing is checked unless both tensor configs are set.

    Raises:
      ValueError: If activations are integer but weights are float, or if both
        are integer while `compute_precision` is not INTEGER.
    """
    activation = self.activation_tensor_config
    weight = self.weight_tensor_config
    if activation is None or weight is None:
      return
    # Both checks below only concern integer activations.
    if activation.dtype != TensorDataType.INT:
      return
    if weight.dtype == TensorDataType.FLOAT:
      raise ValueError(
          'An op can not be set to have integer activation but float weights!'
      )
    # SRQ compliance check for the config.
    if (
        weight.dtype == TensorDataType.INT
        and self.compute_precision != ComputePrecision.INTEGER
    ):
      raise ValueError(
          'Op execution mode must be SRQ (static range quantization) if both'
          ' activation and weight tensors are quantized!'
      )

  def to_dict(self) -> dict[str, Any]:
    """Converts OpQuantizationConfig to dict.

    Returns:
      The fields as a (nested) dict, leaving out entries whose value is None or
      an empty dict.
    """
    return dataclasses.asdict(
        self,
        dict_factory=lambda x: {  # pylint: disable=g-long-lambda
            k: v
            for (k, v) in x
            # Skip None and empty dict values.
            if v is not None and not (isinstance(v, dict) and not v)
        },
    )

  @classmethod
  def from_dict(cls, params: dict[str, Any]) -> 'OpQuantizationConfig':
    """Converts a given dict to OpQuantizationConfig.

    Either tensor config may be absent or None, which mirrors `to_dict`
    dropping None values; such a field is left for the dataclass default.

    Args:
      params: Field values keyed by field name. Tensor configs are given as
        dicts accepted by `TensorQuantizationConfig.from_dict`.

    Returns:
      The OpQuantizationConfig built from a deep copy of `params`.
    """
    params_copy = copy.deepcopy(params)
    for key in ('weight_tensor_config', 'activation_tensor_config'):
      tensor_params = params_copy.get(key)
      if tensor_params is not None:
        params_copy[key] = TensorQuantizationConfig.from_dict(tensor_params)
    return cls(**params_copy)
359
+
360
+
361
@dataclasses.dataclass(frozen=True)
class GraphInfo:
  """Graph-level information needed when quantizing an op.

  Attributes:
    subgraph_tensors: Tensors in the subgraph.
    buffers: Buffers in the subgraph.
  """

  subgraph_tensors: list[Any]
  buffers: list[Any]
372
+
373
+
374
@dataclasses.dataclass(frozen=True)
class OpInfo:
  """Op-level information needed when quantizing an op.

  Attributes:
    op: The op to be quantized.
    op_name: The name of the op.
    subgraph_op_index: The position of the op in the subgraph.
    op_quant_config: The quantization configuration for the op.
  """

  op: Any
  op_name: TFLOperationName
  subgraph_op_index: int
  op_quant_config: OpQuantizationConfig
389
+
390
+
391
+ # Data classes used by model modifier.
392
+
393
+
394
+ # TODO: b/335530570 - This needs to support more than one parameter.
395
@dataclasses.dataclass
class TransformationInst:
  """A single transformation instruction for a tensor.

  Attributes:
    transformation: The transformation to be applied to the tensor.
    tensor_id: The id of the tensor.
    producer: The id of the producer op, or None.
    consumers: The ids of the consumer ops.
    parameters: The quantization parameters for the tensor, or None.
  """

  transformation: QuantTransformation
  tensor_id: int
  producer: Optional[int]
  consumers: list[int]
  parameters: Optional[Union[UniformQuantParams, NonLinearQuantParams]] = None
412
+
413
+
414
@dataclasses.dataclass
class TensorTransformationInsts:
  """All transformation instructions attached to one tensor.

  Attributes:
    tensor_name: The name of the tensor.
    subgraph_id: The id of the subgraph.
    instructions: The transformation instructions for the tensor, or None.
  """

  tensor_name: str
  subgraph_id: int
  instructions: Optional[list[TransformationInst]]
427
+
428
+
429
@dataclasses.dataclass(frozen=True)
class TransformationInfo:
  """Summary of what a transformation did for an op.

  Attributes:
    op_id: The id where op replacement/insertion begins.
    num_ops_added: The number of ops added during the transformation.
    output_tensor_id: The id of the output tensor.
  """

  op_id: int
  num_ops_added: int
  output_tensor_id: int
442
+
443
+
444
# The policy used to check op quantization configs is represented as an
# ordered dict from op name to a list of op quantization configs.
# Normally the policy is loaded from a json file.
ConfigCheckPolicyDict: TypeAlias = collections.OrderedDict[
    TFLOperationName, list[OpQuantizationConfig]
]
449
+
450
+
451
+ def _compare_array_or_none(
452
+ obj1: Optional[np.ndarray], obj2: Optional[np.ndarray]
453
+ ):
454
+ """Compares two arrays or None.
455
+
456
+ Args:
457
+ obj1: The first object to compare.
458
+ obj2: The second object to compare.
459
+
460
+ Returns:
461
+ True if both objects are None or both objects are equal.
462
+ """
463
+ if obj1 is None and obj2 is None:
464
+ return True # Both None, so they're equal.
465
+ elif obj1 is None or obj2 is None:
466
+ return False # Only one is None, so they're different.
467
+ else:
468
+ return np.array_equal(obj1, obj2)
469
+
470
+
471
@dataclasses.dataclass(frozen=True)
class IOOperator:
  """Represents the input or the output of a subgraph as an operator.

  Attributes:
    inputs: The input tensor ids of the op.
    outputs: The output tensor ids of the op.
    op_key: The op key of the op (input or output).
  """

  inputs: list[int]
  outputs: list[int]
  op_key: TFLOperationName