ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/__init__.py +19 -0
- ai_edge_quantizer/algorithm_manager.py +167 -0
- ai_edge_quantizer/algorithm_manager_api.py +271 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
- ai_edge_quantizer/algorithms/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
- ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
- ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
- ai_edge_quantizer/calibrator.py +288 -0
- ai_edge_quantizer/calibrator_test.py +297 -0
- ai_edge_quantizer/conftest.py +22 -0
- ai_edge_quantizer/default_policy.py +310 -0
- ai_edge_quantizer/model_modifier.py +176 -0
- ai_edge_quantizer/model_modifier_test.py +130 -0
- ai_edge_quantizer/model_validator.py +357 -0
- ai_edge_quantizer/model_validator_test.py +354 -0
- ai_edge_quantizer/params_generator.py +361 -0
- ai_edge_quantizer/params_generator_test.py +1041 -0
- ai_edge_quantizer/qtyping.py +483 -0
- ai_edge_quantizer/quantizer.py +372 -0
- ai_edge_quantizer/quantizer_test.py +532 -0
- ai_edge_quantizer/recipe.py +67 -0
- ai_edge_quantizer/recipe_manager.py +245 -0
- ai_edge_quantizer/recipe_manager_test.py +815 -0
- ai_edge_quantizer/recipe_test.py +97 -0
- ai_edge_quantizer/transformation_instruction_generator.py +584 -0
- ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
- ai_edge_quantizer/transformation_performer.py +278 -0
- ai_edge_quantizer/transformation_performer_test.py +344 -0
- ai_edge_quantizer/transformations/__init__.py +15 -0
- ai_edge_quantizer/transformations/dequant_insert.py +87 -0
- ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
- ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
- ai_edge_quantizer/transformations/quant_insert.py +100 -0
- ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
- ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
- ai_edge_quantizer/transformations/transformation_utils.py +132 -0
- ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
- ai_edge_quantizer/utils/__init__.py +15 -0
- ai_edge_quantizer/utils/calibration_utils.py +86 -0
- ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
- ai_edge_quantizer/utils/test_utils.py +107 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
- ai_edge_quantizer/utils/validation_utils.py +125 -0
- ai_edge_quantizer/utils/validation_utils_test.py +87 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,483 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Type hinting support for AI Edge Quantizer."""
|
17
|
+
|
18
|
+
import collections
|
19
|
+
from collections.abc import MutableMapping
|
20
|
+
import copy
|
21
|
+
import dataclasses
|
22
|
+
import enum
|
23
|
+
from typing import Any, Optional, Union
|
24
|
+
|
25
|
+
import numpy as np
|
26
|
+
from typing_extensions import TypeAlias
|
27
|
+
|
28
|
+
|
29
|
+
QSV: TypeAlias = MutableMapping[str, Any]
|
30
|
+
|
31
|
+
|
32
|
+
class TFLOperationName(str, enum.Enum):
  """Names of the TF Lite operations known to the quantizer.

  Members also subclass `str`, so each one compares equal to its plain string
  value (e.g. `TFLOperationName.ADD == 'ADD'`).
  """

  # NOTE(review): '*' presumably acts as a wildcard for every supported op --
  # confirm against the recipe handling code.
  ALL_SUPPORTED = '*'

  # Keys used for subgraph inputs and outputs (see `IOOperator.op_key`).
  INPUT = 'INPUT'
  OUTPUT = 'OUTPUT'

  # Regular operations.
  FULLY_CONNECTED = 'FULLY_CONNECTED'
  BATCH_MATMUL = 'BATCH_MATMUL'
  DEPTHWISE_CONV_2D = 'DEPTHWISE_CONV_2D'
  CONV_2D = 'CONV_2D'
  CONV_2D_TRANSPOSE = 'CONV_2D_TRANSPOSE'
  AVERAGE_POOL_2D = 'AVERAGE_POOL_2D'
  RESHAPE = 'RESHAPE'
  CUSTOM_OP = 'CUSTOM_OP'
  EMBEDDING_LOOKUP = 'EMBEDDING_LOOKUP'
  SOFTMAX = 'SOFTMAX'
  TANH = 'TANH'
  TRANSPOSE = 'TRANSPOSE'
  GELU = 'GELU'
  ADD = 'ADD'
  SUB = 'SUB'
  MUL = 'MUL'
  MEAN = 'MEAN'
  RSQRT = 'RSQRT'
  CONCATENATION = 'CONCATENATION'
  STRIDED_SLICE = 'STRIDED_SLICE'
  SPLIT = 'SPLIT'
  LOGISTIC = 'LOGISTIC'
  SLICE = 'SLICE'
  SUM = 'SUM'
  SELECT_V2 = 'SELECT_V2'
|
63
|
+
|
64
|
+
|
65
|
+
class QuantizeMode(enum.Enum):
  """Modes of the quantization process: calibrate or materialize."""

  # The numbering starts at 2; the values are kept exactly as declared.
  CALIBRATE = 2
  MATERIALIZE = 3
|
68
|
+
|
69
|
+
|
70
|
+
class OpExecutionMode(str, enum.Enum):
  """Describes how an op is executed.

  Members also subclass `str` and compare equal to their string value.
  """

  WEIGHT_ONLY = 'WEIGHT_ONLY'
  # Dynamic range quantization.
  DRQ = 'DRQ'
  # Static range quantization.
  SRQ = 'SRQ'
|
76
|
+
|
77
|
+
|
78
|
+
class ComputePrecision(str, enum.Enum):
  """Precision in which an op performs its computation.

  Members also subclass `str` and compare equal to their string value.
  """

  INTEGER = 'INTEGER'
  FLOAT = 'FLOAT'
|
83
|
+
|
84
|
+
|
85
|
+
class TensorDataType(str, enum.Enum):
  """Data type family of a tensor: integer or floating point.

  Members also subclass `str` and compare equal to their string value.
  """

  INT = 'INT'
  FLOAT = 'FLOAT'
|
88
|
+
|
89
|
+
|
90
|
+
class QuantGranularity(str, enum.Enum):
  """Granularity of quantization: per tensor, per channel or per block.

  Members also subclass `str` and compare equal to their string value.
  """

  TENSORWISE = 'TENSORWISE'
  CHANNELWISE = 'CHANNELWISE'
  BLOCKWISE = 'BLOCKWISE'
|
94
|
+
|
95
|
+
|
96
|
+
class QuantTransformation(enum.Enum):
  """Quantization-related operations that can be applied to a tensor."""

  # float_tensor -> float_tensor (nothing is done).
  NO_QUANTIZE = 0
  # float_tensor -> Quantize Op -> quantized_tensor (a quantize op is added).
  ADD_QUANTIZE = 1
  # quantized_tensor -> Dequantize Op -> float_tensor (a dequantize op is
  # added).
  ADD_DEQUANTIZE = 2
  # float_tensor -> quantized_tensor (the tensor itself is quantized).
  QUANTIZE_TENSOR = 3
  # Builds the pattern for emulated subchannel quantization; only the fully
  # connected op is supported.
  EMULATED_SUBCHANNEL = 4
|
110
|
+
|
111
|
+
|
112
|
+
@dataclasses.dataclass(frozen=True)
class UniformQuantParams:
  """Parameters for uniform quantization.

  Equality is value based: `scale`, `zero_point` and `quantized_data` are
  compared with `np.array_equal` instead of the dataclass-generated tuple
  comparison, which does not work for multi-element arrays.

  Attributes:
    num_bits: Number of bits to quantize to (e.g. 8 for int8).
    quantized_dimension: The dimension to quantize.
    scale: The scale of the quantization.
    zero_point: The zero point of the quantization.
    symmetric: Whether the quantization is symmetric (force zero_point to be 0).
    quantized_data: The quantized data.
  """

  num_bits: int
  quantized_dimension: Optional[int]
  scale: np.ndarray
  zero_point: np.ndarray
  symmetric: bool = True
  quantized_data: Optional[np.ndarray] = None

  @classmethod
  def from_tfl_tensor_details(cls, tensor_detail) -> 'UniformQuantParams':
    """Creates UniformQuantParams from TFLite tensor details.

    Args:
      tensor_detail: The tensor details from TFLite. The entries read are
        'dtype' and 'quantization_parameters'; the latter must provide
        'scales', 'zero_points' and 'quantized_dimension'.

    Returns:
      UniformQuantParams with `symmetric` set to True exactly when every zero
      point is 0. `quantized_data` is left at its default (None).

    Raises:
      ValueError: If the tensor dtype is not int8, int16, int32 or int64.
    """
    quant_params = tensor_detail['quantization_parameters']
    data_type = tensor_detail['dtype']

    # Walk the table with `==` (not a dict lookup) so that np.dtype instances
    # match as well as the NumPy scalar types themselves.
    bits_by_dtype = ((np.int8, 8), (np.int16, 16), (np.int32, 32), (np.int64, 64))
    for np_type, bits in bits_by_dtype:
      if data_type == np_type:
        num_bits = bits
        break
    else:
      raise ValueError(f'Unsupported data type: {data_type}')

    zero_points = quant_params['zero_points']
    # np.any avoids the builtin sum()/abs() round trip: it accepts lists and
    # 0-d arrays, cannot overflow, and yields a plain Python bool.
    symmetric = not np.any(zero_points)
    return cls(
        quantized_dimension=quant_params['quantized_dimension'],
        num_bits=num_bits,
        scale=quant_params['scales'],
        zero_point=zero_points,
        symmetric=symmetric,
    )

  def __eq__(self, other):
    """Compares two instances field by field, arrays element-wise."""
    if other.__class__ is not self.__class__:
      return NotImplemented
    return (
        self.num_bits == other.num_bits
        and self.quantized_dimension == other.quantized_dimension
        and np.array_equal(self.scale, other.scale)
        and np.array_equal(self.zero_point, other.zero_point)
        and self.symmetric == other.symmetric
        and _compare_array_or_none(self.quantized_data, other.quantized_data)
    )
|
174
|
+
|
175
|
+
|
176
|
+
@dataclasses.dataclass(frozen=True)
class NonLinearQuantParams:
  """Parameters for nonlinear quantization.

  Currently only used for fp16 quantization. Equality is value based:
  `quantized_data` is compared element-wise (or as None on both sides).

  Attributes:
    num_bits: Number of bits to quantize to (e.g. 16 for fp16).
    quantized_data: The quantized data.
    data_type: The data type of the tensor.
  """

  num_bits: int
  quantized_data: Optional[np.ndarray]
  data_type: TensorDataType = TensorDataType.FLOAT

  def __eq__(self, other):
    """Compares two instances field by field, the array element-wise."""
    if self.__class__ is not other.__class__:
      return NotImplemented
    # Cheap scalar fields first; the array comparison only runs if they match.
    if not self.num_bits == other.num_bits:
      return False
    if not self.data_type == other.data_type:
      return False
    return _compare_array_or_none(self.quantized_data, other.quantized_data)
|
200
|
+
|
201
|
+
|
202
|
+
@dataclasses.dataclass(frozen=True)
class OpToTensorParams:
  """Parameters an op contributes to one of its tensors.

  Attributes:
    subgraph_op_id: Position of the authoring op within the subgraph.
    transformations: Transformations to apply to the tensor.
    parameters: Quantization parameters for the tensor; defaults to None.
  """

  subgraph_op_id: int
  transformations: list[QuantTransformation]
  parameters: Union[None, UniformQuantParams, NonLinearQuantParams] = None
|
215
|
+
|
216
|
+
|
217
|
+
@dataclasses.dataclass
class TensorTransformationParams:
  """Transformation info collected for a single tensor.

  Every tensor in a .tflite model is:
    * produced by one source op (the producer), unless it is a constant tensor
      or a model input;
    * consumed by one or many destination ops (the consumers), unless it is a
      model output.

  Users configure quantization at op level through `OpQuantizationConfig`, so
  a tensor receives transformation parameters both from its source op and
  from its destination ops.

  Attributes:
    tensor_name: Name of the tensor.
    producer: Parameters from the source op; defaults to None.
    consumers: Parameters from the destination ops; defaults to None.
  """

  tensor_name: str
  producer: Optional[OpToTensorParams] = None
  consumers: Optional[list[OpToTensorParams]] = None
|
235
|
+
|
236
|
+
|
237
|
+
@dataclasses.dataclass(frozen=True)
class TensorQuantizationConfig:
  """Quantization configuration for a tensor.

  Attributes:
    num_bits: Number of bits to quantize to (e.g. 8 for int8).
    symmetric: Whether to quantize symmetrically or asymmetrically. In the
      symmetric mode the zero point is always 0.
    granularity: Per-tensor, per-channel or per-block quantization.
    dtype: The data type of the tensor.
    block_size: Block size for blockwise quantization; ignored otherwise.
  """

  num_bits: int
  symmetric: bool = True
  granularity: QuantGranularity = QuantGranularity.TENSORWISE
  dtype: TensorDataType = TensorDataType.INT
  block_size: int = 0

  def to_dict(self) -> dict[str, Any]:
    """Converts TensorQuantizationConfig to dict.

    Returns:
      A dict of the fields, leaving out any whose value is None or an empty
      dict.
    """

    def _drop_unset(pairs):
      kept = {}
      for key, value in pairs:
        if value is None:
          continue
        if isinstance(value, dict) and not value:
          continue
        kept[key] = value
      return kept

    return dataclasses.asdict(self, dict_factory=_drop_unset)

  @classmethod
  def from_dict(cls, params: dict[str, Any]) -> 'TensorQuantizationConfig':
    """Converts a given dict to TensorQuantizationConfig.

    Args:
      params: Field names mapped to values. The dict is deep-copied first, so
        the caller's object is never shared with the new instance.

    Returns:
      The TensorQuantizationConfig built from `params`.
    """
    return cls(**copy.deepcopy(params))
|
274
|
+
|
275
|
+
|
276
|
+
@dataclasses.dataclass(frozen=True)
class OpQuantizationConfig:
  """Configuration class to control the quantization process behavior.

  Default to float activations and weights.

  Attributes:
    activation_tensor_config: The quantization configuration for activation
      tensors in the op (i.e., runtime tensors).
    weight_tensor_config: The quantization configuration for weight tensor in
      the op.
    compute_precision: The precision of the compute operation.
    explicit_dequantize: Whether to add explicit dequantize op if compute
      precision is FLOAT, but weight is quantized.
    skip_checks: Whether to skip op quantization config checks. For advanced
      users only. If set, the quantizer will ignore all op configuration checks
      and forcefully quantize this op according to the user instructions even if
      it's not supported in the TFLite runtime.
    min_weight_elements: The minimum number of elements in the weight tensor to
      be quantized.
  """

  activation_tensor_config: Optional[TensorQuantizationConfig] = None
  # Bias tensor quantization is deduced from activation/weight config.
  # e.g., int8A X int8W => int32 bias.
  weight_tensor_config: Optional[TensorQuantizationConfig] = None
  compute_precision: ComputePrecision = ComputePrecision.FLOAT
  # TODO: b/359647578 - Set default to True.
  explicit_dequantize: bool = False
  skip_checks: bool = False
  min_weight_elements: int = 0

  def __post_init__(self):
    """Validates the activation/weight/compute-precision combination.

    Nothing is checked unless both tensor configs are provided.

    Raises:
      ValueError: If activations are INT while weights are FLOAT, or if both
        are INT while `compute_precision` is not INTEGER.
    """
    activation_config = self.activation_tensor_config
    weight_config = self.weight_tensor_config
    if activation_config is None or weight_config is None:
      return
    # Make sure the setting is valid.
    if (
        activation_config.dtype == TensorDataType.INT
        and weight_config.dtype == TensorDataType.FLOAT
    ):
      raise ValueError(
          'An op can not be set to have integer activation but float weights!'
      )
    # SRQ compliance check for the config. The message still talks about the
    # "execution mode", but the condition tested here is compute_precision.
    if (
        activation_config.dtype == TensorDataType.INT
        and weight_config.dtype == TensorDataType.INT
        and self.compute_precision != ComputePrecision.INTEGER
    ):
      raise ValueError(
          'Op execution mode must be SRQ (static range quantization) if both'
          ' activation and weight tensors are quantized!'
      )

  def to_dict(self) -> dict[str, Any]:
    """Converts OpQuantizationConfig to dict.

    Returns:
      A dict of the fields, with nested tensor configs converted to dicts by
      the same rule. Entries whose value is None or an empty dict are left
      out.
    """
    return dataclasses.asdict(
        self,
        dict_factory=lambda x: {  # pylint: disable=g-long-lambda
            k: v
            for (k, v) in x
            # Skip None and empty dict values.
            if v is not None and not (isinstance(v, dict) and not v)
        },
    )

  @classmethod
  def from_dict(cls, params: dict[str, Any]) -> 'OpQuantizationConfig':
    """Converts a given dict to OpQuantizationConfig.

    Both tensor config entries are optional. `to_dict` leaves out fields that
    are None, so a missing (or None) entry must fall back to the field
    default; otherwise `from_dict(cfg.to_dict())` would fail for any config
    without a weight tensor config.

    Args:
      params: Field names mapped to values; tensor configs are given as nested
        dicts. The dict is deep-copied and never modified.

    Returns:
      The OpQuantizationConfig built from `params`.
    """
    params_copy = copy.deepcopy(params)
    for key in ('weight_tensor_config', 'activation_tensor_config'):
      tensor_config = params_copy.get(key)
      if tensor_config is not None:
        params_copy[key] = TensorQuantizationConfig.from_dict(tensor_config)
    return cls(**params_copy)
|
359
|
+
|
360
|
+
|
361
|
+
@dataclasses.dataclass(frozen=True)
class GraphInfo:
  """Graph-level information needed to quantize an op.

  Attributes:
    subgraph_tensors: The tensors of the subgraph.
    buffers: The buffers of the subgraph.
  """

  subgraph_tensors: list[Any]
  buffers: list[Any]
|
372
|
+
|
373
|
+
|
374
|
+
@dataclasses.dataclass(frozen=True)
class OpInfo:
  """Op-level information needed to quantize an op.

  Attributes:
    op: The op to be quantized.
    op_name: Name of the op.
    subgraph_op_index: Position of the op within the subgraph.
    op_quant_config: Quantization configuration for the op.
  """

  op: Any
  op_name: TFLOperationName
  # Position of the op in the subgraph.
  subgraph_op_index: int
  op_quant_config: OpQuantizationConfig
|
389
|
+
|
390
|
+
|
391
|
+
# Data classes used by model modifier.
|
392
|
+
|
393
|
+
|
394
|
+
# TODO: b/335530570 - This needs to support more than one parameters.
|
395
|
+
@dataclasses.dataclass
class TransformationInst:
  """A single transformation instruction for a tensor.

  Attributes:
    transformation: The transformation to apply to the tensor.
    tensor_id: Id of the tensor.
    producer: Id of the producer op (may be None).
    consumers: Ids of the consumer ops.
    parameters: Quantization parameters for the tensor; defaults to None.
  """

  transformation: QuantTransformation
  tensor_id: int
  producer: Optional[int]
  consumers: list[int]
  parameters: Union[None, UniformQuantParams, NonLinearQuantParams] = None
|
412
|
+
|
413
|
+
|
414
|
+
@dataclasses.dataclass
class TensorTransformationInsts:
  """All transformation instructions for one tensor.

  Attributes:
    tensor_name: Name of the tensor.
    subgraph_id: Id of the subgraph.
    instructions: Transformation instructions for the tensor (may be None).
  """

  tensor_name: str
  subgraph_id: int
  instructions: Optional[list[TransformationInst]]
|
427
|
+
|
428
|
+
|
429
|
+
@dataclasses.dataclass(frozen=True)
class TransformationInfo:
  """Result information of a transformation on an op.

  Attributes:
    op_id: Id at which the op replacement/insertion begins.
    num_ops_added: Number of ops added during the transformation.
    output_tensor_id: Id of the output tensor.
  """

  op_id: int
  num_ops_added: int
  output_tensor_id: int
|
442
|
+
|
443
|
+
|
444
|
+
# The policy used to check op quantization configs is represented as an
# ordered dict from op name to a list of `OpQuantizationConfig`. Normally the
# policy is loaded from a json file.
ConfigCheckPolicyDict = collections.OrderedDict[
    TFLOperationName, list[OpQuantizationConfig]
]
|
449
|
+
|
450
|
+
|
451
|
+
def _compare_array_or_none(
    obj1: Optional[np.ndarray], obj2: Optional[np.ndarray]
):
  """Checks two optional arrays for equality.

  Args:
    obj1: First array, or None.
    obj2: Second array, or None.

  Returns:
    True when both arguments are None, False when exactly one of them is None,
    and otherwise the result of `np.array_equal` on the two arrays.
  """
  if obj1 is None or obj2 is None:
    # At least one side is None: they only match if the other one is too.
    return obj1 is obj2
  return np.array_equal(obj1, obj2)
|
469
|
+
|
470
|
+
|
471
|
+
@dataclasses.dataclass(frozen=True)
class IOOperator:
  """Represents the input or the output of a subgraph as an op.

  Attributes:
    inputs: Input tensor ids of the op.
    outputs: Output tensor ids of the op.
    op_key: Op key of the op (input or output).
  """

  inputs: list[int]
  outputs: list[int]
  op_key: TFLOperationName
|