ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +224 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
- ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
- ai_edge_quantizer/calibrator.py +58 -94
- ai_edge_quantizer/calibrator_test.py +5 -74
- ai_edge_quantizer/default_policy.py +108 -16
- ai_edge_quantizer/model_modifier.py +132 -8
- ai_edge_quantizer/model_modifier_test.py +81 -1
- ai_edge_quantizer/model_validator.py +38 -10
- ai_edge_quantizer/model_validator_test.py +2 -1
- ai_edge_quantizer/params_generator.py +230 -47
- ai_edge_quantizer/params_generator_test.py +366 -261
- ai_edge_quantizer/qtyping.py +92 -6
- ai_edge_quantizer/quantizer.py +167 -23
- ai_edge_quantizer/quantizer_test.py +288 -26
- ai_edge_quantizer/recipe.py +156 -21
- ai_edge_quantizer/recipe_manager.py +158 -1
- ai_edge_quantizer/recipe_manager_test.py +146 -32
- ai_edge_quantizer/recipe_test.py +93 -17
- ai_edge_quantizer/transformation_instruction_generator.py +313 -46
- ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
- ai_edge_quantizer/transformation_performer.py +112 -58
- ai_edge_quantizer/transformation_performer_test.py +176 -4
- ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
- ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
- ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
- ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
- ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
- ai_edge_quantizer/transformations/transformation_utils.py +157 -11
- ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
- ai_edge_quantizer/utils/calibration_utils.py +263 -1
- ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
- ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
- ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
- ai_edge_quantizer/utils/test_utils.py +191 -58
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
- ai_edge_quantizer/utils/validation_utils.py +114 -4
- ai_edge_quantizer/utils/validation_utils_test.py +80 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
- ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
- ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
- ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
|
@@ -18,12 +18,55 @@
|
|
|
18
18
|
import inspect as _inspect
|
|
19
19
|
import os.path as _os_path
|
|
20
20
|
import sys as _sys
|
|
21
|
-
from typing import
|
|
21
|
+
from typing import Optional, Union
|
|
22
22
|
|
|
23
|
-
|
|
23
|
+
from absl.testing import parameterized
|
|
24
24
|
|
|
25
|
+
from ai_edge_quantizer import model_validator
|
|
26
|
+
from ai_edge_quantizer import qtyping
|
|
27
|
+
from ai_edge_quantizer import quantizer
|
|
25
28
|
from ai_edge_quantizer.utils import tfl_interpreter_utils
|
|
26
29
|
|
|
30
|
+
_ComputePrecision = qtyping.ComputePrecision
|
|
31
|
+
_OpName = qtyping.TFLOperationName
|
|
32
|
+
_TensorQuantConfig = qtyping.TensorQuantizationConfig
|
|
33
|
+
_OpQuantConfig = qtyping.OpQuantizationConfig
|
|
34
|
+
_AlgorithmName = quantizer.AlgorithmName
|
|
35
|
+
_Numeric = Union[int, float]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
DEFAULT_ACTIVATION_QUANT_SETTING = _TensorQuantConfig(
|
|
39
|
+
num_bits=8,
|
|
40
|
+
symmetric=False,
|
|
41
|
+
granularity=qtyping.QuantGranularity.TENSORWISE,
|
|
42
|
+
)
|
|
43
|
+
DEFAULT_WEIGHT_QUANT_SETTING = _TensorQuantConfig(
|
|
44
|
+
num_bits=8,
|
|
45
|
+
symmetric=True,
|
|
46
|
+
granularity=qtyping.QuantGranularity.CHANNELWISE,
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def get_static_activation_quant_setting(
|
|
51
|
+
num_bits: int, symmetric: bool
|
|
52
|
+
) -> _TensorQuantConfig:
|
|
53
|
+
return _TensorQuantConfig(
|
|
54
|
+
num_bits=num_bits,
|
|
55
|
+
symmetric=symmetric,
|
|
56
|
+
granularity=qtyping.QuantGranularity.TENSORWISE,
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def get_static_op_quant_config(
|
|
61
|
+
activation_config: _TensorQuantConfig = DEFAULT_ACTIVATION_QUANT_SETTING,
|
|
62
|
+
weight_config: _TensorQuantConfig = DEFAULT_WEIGHT_QUANT_SETTING,
|
|
63
|
+
) -> _OpQuantConfig:
|
|
64
|
+
return qtyping.OpQuantizationConfig(
|
|
65
|
+
activation_tensor_config=activation_config,
|
|
66
|
+
weight_tensor_config=weight_config,
|
|
67
|
+
compute_precision=_ComputePrecision.INTEGER,
|
|
68
|
+
)
|
|
69
|
+
|
|
27
70
|
|
|
28
71
|
def get_path_to_datafile(path):
|
|
29
72
|
"""Get the path to the specified file in the data dependencies.
|
|
@@ -46,62 +89,152 @@ def get_path_to_datafile(path):
|
|
|
46
89
|
return path
|
|
47
90
|
|
|
48
91
|
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
92
|
+
class BaseOpTestCase(parameterized.TestCase):
|
|
93
|
+
"""Base class for op-level tests."""
|
|
94
|
+
|
|
95
|
+
def quantize_and_validate(
|
|
96
|
+
self,
|
|
97
|
+
model_path: str,
|
|
98
|
+
algorithm_key: _AlgorithmName,
|
|
99
|
+
op_name: _OpName,
|
|
100
|
+
op_config: _OpQuantConfig,
|
|
101
|
+
num_validation_samples: int = 4,
|
|
102
|
+
num_calibration_samples: Optional[int] = None,
|
|
103
|
+
error_metric: str = 'mse',
|
|
104
|
+
min_max_range: Optional[tuple[_Numeric, _Numeric]] = None,
|
|
105
|
+
) -> model_validator.ComparisonResult:
|
|
106
|
+
"""Quantizes and validates the given model with the given configurations.
|
|
107
|
+
|
|
108
|
+
Args:
|
|
109
|
+
model_path: The path to the model to be quantized.
|
|
110
|
+
algorithm_key: The algorithm to be used for quantization.
|
|
111
|
+
op_name: The name of the operation to be quantized.
|
|
112
|
+
op_config: The configuration for the operation to be quantized.
|
|
113
|
+
num_validation_samples: The number of samples to use for validation.
|
|
114
|
+
num_calibration_samples: The number of samples to use for calibration. If
|
|
115
|
+
None then it will be set to num_validation_samples * 8.
|
|
116
|
+
error_metric: The error error_metric to use for validation.
|
|
117
|
+
min_max_range: The min and max of the input range.
|
|
118
|
+
|
|
119
|
+
Returns:
|
|
120
|
+
The comparison result of the validation.
|
|
121
|
+
"""
|
|
122
|
+
quantizer_instance = quantizer.Quantizer(model_path)
|
|
123
|
+
quantizer_instance.update_quantization_recipe(
|
|
124
|
+
algorithm_key=algorithm_key,
|
|
125
|
+
regex='.*',
|
|
126
|
+
operation_name=op_name,
|
|
127
|
+
op_config=op_config,
|
|
128
|
+
)
|
|
129
|
+
if quantizer_instance.need_calibration:
|
|
130
|
+
if num_calibration_samples is None:
|
|
131
|
+
num_calibration_samples = num_validation_samples * 8
|
|
132
|
+
calibration_data = tfl_interpreter_utils.create_random_normal_input_data(
|
|
133
|
+
quantizer_instance.float_model,
|
|
134
|
+
num_samples=num_calibration_samples,
|
|
135
|
+
min_max_range=min_max_range,
|
|
136
|
+
)
|
|
137
|
+
calibration_result = quantizer_instance.calibrate(calibration_data)
|
|
138
|
+
quantization_result = quantizer_instance.quantize(calibration_result)
|
|
139
|
+
else:
|
|
140
|
+
quantization_result = quantizer_instance.quantize()
|
|
141
|
+
test_data = tfl_interpreter_utils.create_random_normal_input_data(
|
|
142
|
+
quantization_result.quantized_model,
|
|
143
|
+
num_samples=num_validation_samples,
|
|
144
|
+
min_max_range=min_max_range,
|
|
145
|
+
)
|
|
146
|
+
return quantizer_instance.validate(test_data, error_metric)
|
|
147
|
+
|
|
148
|
+
def assert_model_size_reduction_above_min_pct(
|
|
149
|
+
self,
|
|
150
|
+
validation_result: model_validator.ComparisonResult,
|
|
151
|
+
min_pct: float,
|
|
152
|
+
):
|
|
153
|
+
"""Checks the model size reduction (percentage) against the given expectation."""
|
|
154
|
+
_, reduction_pct = validation_result.get_model_size_reduction()
|
|
155
|
+
self.assertGreater(reduction_pct, min_pct)
|
|
156
|
+
|
|
157
|
+
def assert_weights_errors_below_tolerance(
|
|
158
|
+
self,
|
|
159
|
+
validation_result: model_validator.ComparisonResult,
|
|
160
|
+
weight_tolerance: float,
|
|
161
|
+
):
|
|
162
|
+
"""Checks the weight tensors' numerical behavior against the given tolerance."""
|
|
163
|
+
self.assertNotEmpty(validation_result.available_signature_keys())
|
|
164
|
+
for signature_key in validation_result.available_signature_keys():
|
|
165
|
+
signature_result = validation_result.get_signature_comparison_result(
|
|
166
|
+
signature_key
|
|
167
|
+
)
|
|
168
|
+
for result in signature_result.constant_tensors.values():
|
|
169
|
+
self.assertLess(result, weight_tolerance)
|
|
170
|
+
|
|
171
|
+
def assert_output_errors_below_tolerance(
|
|
172
|
+
self,
|
|
173
|
+
validation_result: model_validator.ComparisonResult,
|
|
174
|
+
output_tolerance: float,
|
|
175
|
+
):
|
|
176
|
+
"""Checks the output tensor numerical behavior against the given tolerance."""
|
|
177
|
+
self.assertNotEmpty(validation_result.available_signature_keys())
|
|
178
|
+
for signature_key in validation_result.available_signature_keys():
|
|
179
|
+
signature_result = validation_result.get_signature_comparison_result(
|
|
180
|
+
signature_key
|
|
181
|
+
)
|
|
182
|
+
for result in signature_result.output_tensors.values():
|
|
183
|
+
self.assertLess(result, output_tolerance)
|
|
184
|
+
|
|
185
|
+
def assert_quantization_accuracy_and_size(
|
|
186
|
+
self,
|
|
187
|
+
algorithm_key: _AlgorithmName,
|
|
188
|
+
model_path: str,
|
|
189
|
+
op_name: _OpName,
|
|
190
|
+
op_config: _OpQuantConfig,
|
|
191
|
+
expected_model_size_reduction: float,
|
|
192
|
+
weight_tolerance: float = 1e-4,
|
|
193
|
+
output_tolerance: float = 1e-4,
|
|
194
|
+
min_max_range: Optional[tuple[_Numeric, _Numeric]] = None,
|
|
195
|
+
):
|
|
196
|
+
"""Check if the quantization is successful and the result is valid."""
|
|
197
|
+
validation_result = self.quantize_and_validate(
|
|
198
|
+
model_path=model_path,
|
|
199
|
+
algorithm_key=algorithm_key,
|
|
200
|
+
op_name=op_name,
|
|
201
|
+
op_config=op_config,
|
|
202
|
+
min_max_range=min_max_range,
|
|
203
|
+
)
|
|
204
|
+
with self.subTest(name='ModelSizeReduction'):
|
|
205
|
+
self.assert_model_size_reduction_above_min_pct(
|
|
206
|
+
validation_result, expected_model_size_reduction
|
|
207
|
+
)
|
|
208
|
+
with self.subTest(name='WeightsErrors'):
|
|
209
|
+
self.assert_weights_errors_below_tolerance(
|
|
210
|
+
validation_result, weight_tolerance
|
|
211
|
+
)
|
|
212
|
+
with self.subTest(name='OutputErrors'):
|
|
213
|
+
self.assert_output_errors_below_tolerance(
|
|
214
|
+
validation_result, output_tolerance
|
|
74
215
|
)
|
|
75
|
-
input_data[arg_name] = new_data
|
|
76
|
-
dataset.append(input_data)
|
|
77
|
-
return dataset
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
def create_random_normal_input_data(
|
|
81
|
-
tflite_model: Union[str, bytes],
|
|
82
|
-
num_samples: int = 4,
|
|
83
|
-
random_seed: int = 666,
|
|
84
|
-
) -> dict[str, list[dict[str, Any]]]:
|
|
85
|
-
"""create random dataset following random distribution for signature runner.
|
|
86
|
-
|
|
87
|
-
Args:
|
|
88
|
-
tflite_model: TFLite model path or bytearray
|
|
89
|
-
num_samples: number of input samples to be generated
|
|
90
|
-
random_seed: random seed to be used for function
|
|
91
216
|
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
217
|
+
def assert_quantization_accuracy(
|
|
218
|
+
self,
|
|
219
|
+
algorithm_key: _AlgorithmName,
|
|
220
|
+
model_path: str,
|
|
221
|
+
op_name: _OpName,
|
|
222
|
+
op_config: _OpQuantConfig,
|
|
223
|
+
num_validation_samples: int = 4,
|
|
224
|
+
num_calibration_samples: Optional[int] = None,
|
|
225
|
+
output_tolerance: float = 1e-4,
|
|
226
|
+
min_max_range: Optional[tuple[_Numeric, _Numeric]] = None,
|
|
227
|
+
):
|
|
228
|
+
"""Checks if the output errors after quantization are within the tolerance."""
|
|
229
|
+
validation_result = self.quantize_and_validate(
|
|
230
|
+
model_path=model_path,
|
|
231
|
+
algorithm_key=algorithm_key,
|
|
232
|
+
num_validation_samples=num_validation_samples,
|
|
233
|
+
num_calibration_samples=num_calibration_samples,
|
|
234
|
+
op_name=op_name,
|
|
235
|
+
op_config=op_config,
|
|
236
|
+
min_max_range=min_max_range,
|
|
237
|
+
)
|
|
238
|
+
self.assert_output_errors_below_tolerance(
|
|
239
|
+
validation_result, output_tolerance
|
|
106
240
|
)
|
|
107
|
-
return test_data
|
|
@@ -20,48 +20,66 @@ from typing import Any, Optional, Union
|
|
|
20
20
|
import immutabledict
|
|
21
21
|
import numpy as np
|
|
22
22
|
|
|
23
|
+
from ai_edge_litert.tools import flatbuffer_utils
|
|
23
24
|
from ai_edge_quantizer import qtyping
|
|
24
|
-
from ai_edge_litert import schema_py_generated # pylint:disable=g-direct-tensorflow-import
|
|
25
|
-
|
|
26
|
-
from tensorflow.python.platform import gfile # pylint: disable=g-direct-tensorflow-import
|
|
25
|
+
from ai_edge_litert import schema_py_generated as schema # pylint:disable=g-direct-tensorflow-import
|
|
26
|
+
import os # tensorflow.python.platform.gfile # pylint: disable=g-direct-tensorflow-import
|
|
27
27
|
|
|
28
28
|
_TFLOpName = qtyping.TFLOperationName
|
|
29
29
|
|
|
30
30
|
TFL_OP_NAME_TO_CODE = immutabledict.immutabledict({
|
|
31
|
-
_TFLOpName.FULLY_CONNECTED:
|
|
32
|
-
|
|
31
|
+
_TFLOpName.FULLY_CONNECTED: schema.BuiltinOperator.FULLY_CONNECTED,
|
|
32
|
+
_TFLOpName.BATCH_MATMUL: schema.BuiltinOperator.BATCH_MATMUL,
|
|
33
|
+
_TFLOpName.CONV_2D: schema.BuiltinOperator.CONV_2D,
|
|
34
|
+
_TFLOpName.DEPTHWISE_CONV_2D: schema.BuiltinOperator.DEPTHWISE_CONV_2D,
|
|
35
|
+
_TFLOpName.CONV_2D_TRANSPOSE: schema.BuiltinOperator.TRANSPOSE_CONV,
|
|
36
|
+
_TFLOpName.EMBEDDING_LOOKUP: schema.BuiltinOperator.EMBEDDING_LOOKUP,
|
|
37
|
+
_TFLOpName.SOFTMAX: schema.BuiltinOperator.SOFTMAX,
|
|
38
|
+
_TFLOpName.AVERAGE_POOL_2D: schema.BuiltinOperator.AVERAGE_POOL_2D,
|
|
39
|
+
_TFLOpName.RESHAPE: schema.BuiltinOperator.RESHAPE,
|
|
40
|
+
_TFLOpName.TANH: schema.BuiltinOperator.TANH,
|
|
41
|
+
_TFLOpName.TRANSPOSE: schema.BuiltinOperator.TRANSPOSE,
|
|
42
|
+
_TFLOpName.GELU: schema.BuiltinOperator.GELU,
|
|
43
|
+
_TFLOpName.ADD: schema.BuiltinOperator.ADD,
|
|
44
|
+
_TFLOpName.SUB: schema.BuiltinOperator.SUB,
|
|
45
|
+
_TFLOpName.MUL: schema.BuiltinOperator.MUL,
|
|
46
|
+
_TFLOpName.MEAN: schema.BuiltinOperator.MEAN,
|
|
47
|
+
_TFLOpName.RSQRT: schema.BuiltinOperator.RSQRT,
|
|
48
|
+
_TFLOpName.CONCATENATION: schema.BuiltinOperator.CONCATENATION,
|
|
49
|
+
_TFLOpName.STRIDED_SLICE: schema.BuiltinOperator.STRIDED_SLICE,
|
|
50
|
+
_TFLOpName.SPLIT: schema.BuiltinOperator.SPLIT,
|
|
51
|
+
_TFLOpName.LOGISTIC: schema.BuiltinOperator.LOGISTIC,
|
|
52
|
+
_TFLOpName.SLICE: schema.BuiltinOperator.SLICE,
|
|
53
|
+
_TFLOpName.SUM: schema.BuiltinOperator.SUM,
|
|
54
|
+
_TFLOpName.SELECT: schema.BuiltinOperator.SELECT,
|
|
55
|
+
_TFLOpName.SELECT_V2: schema.BuiltinOperator.SELECT_V2,
|
|
56
|
+
_TFLOpName.STABLEHLO_COMPOSITE: schema.BuiltinOperator.STABLEHLO_COMPOSITE,
|
|
57
|
+
_TFLOpName.DYNAMIC_UPDATE_SLICE: (
|
|
58
|
+
schema.BuiltinOperator.DYNAMIC_UPDATE_SLICE
|
|
33
59
|
),
|
|
34
|
-
_TFLOpName.
|
|
35
|
-
_TFLOpName.
|
|
36
|
-
_TFLOpName.
|
|
37
|
-
|
|
60
|
+
_TFLOpName.PAD: schema.BuiltinOperator.PAD,
|
|
61
|
+
_TFLOpName.SQUARED_DIFFERENCE: schema.BuiltinOperator.SQUARED_DIFFERENCE,
|
|
62
|
+
_TFLOpName.MAX_POOL_2D: schema.BuiltinOperator.MAX_POOL_2D,
|
|
63
|
+
_TFLOpName.RESIZE_BILINEAR: schema.BuiltinOperator.RESIZE_BILINEAR,
|
|
64
|
+
_TFLOpName.RESIZE_NEAREST_NEIGHBOR: (
|
|
65
|
+
schema.BuiltinOperator.RESIZE_NEAREST_NEIGHBOR
|
|
38
66
|
),
|
|
39
|
-
_TFLOpName.
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
_TFLOpName.
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
_TFLOpName.
|
|
46
|
-
_TFLOpName.
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
_TFLOpName.
|
|
50
|
-
_TFLOpName.
|
|
51
|
-
_TFLOpName.
|
|
52
|
-
_TFLOpName.
|
|
53
|
-
_TFLOpName.
|
|
54
|
-
_TFLOpName.
|
|
55
|
-
_TFLOpName.MUL: schema_py_generated.BuiltinOperator.MUL,
|
|
56
|
-
_TFLOpName.MEAN: schema_py_generated.BuiltinOperator.MEAN,
|
|
57
|
-
_TFLOpName.RSQRT: schema_py_generated.BuiltinOperator.RSQRT,
|
|
58
|
-
_TFLOpName.CONCATENATION: schema_py_generated.BuiltinOperator.CONCATENATION,
|
|
59
|
-
_TFLOpName.STRIDED_SLICE: schema_py_generated.BuiltinOperator.STRIDED_SLICE,
|
|
60
|
-
_TFLOpName.SPLIT: schema_py_generated.BuiltinOperator.SPLIT,
|
|
61
|
-
_TFLOpName.LOGISTIC: schema_py_generated.BuiltinOperator.LOGISTIC,
|
|
62
|
-
_TFLOpName.SLICE: schema_py_generated.BuiltinOperator.SLICE,
|
|
63
|
-
_TFLOpName.SUM: schema_py_generated.BuiltinOperator.SUM,
|
|
64
|
-
_TFLOpName.SELECT_V2: schema_py_generated.BuiltinOperator.SELECT_V2,
|
|
67
|
+
_TFLOpName.GATHER_ND: schema.BuiltinOperator.GATHER_ND,
|
|
68
|
+
_TFLOpName.PACK: schema.BuiltinOperator.PACK,
|
|
69
|
+
_TFLOpName.UNPACK: schema.BuiltinOperator.UNPACK,
|
|
70
|
+
_TFLOpName.DIV: schema.BuiltinOperator.DIV,
|
|
71
|
+
_TFLOpName.BROADCAST_TO: schema.BuiltinOperator.BROADCAST_TO,
|
|
72
|
+
_TFLOpName.SQRT: schema.BuiltinOperator.SQRT,
|
|
73
|
+
_TFLOpName.GATHER: schema.BuiltinOperator.GATHER,
|
|
74
|
+
_TFLOpName.HARD_SWISH: schema.BuiltinOperator.HARD_SWISH,
|
|
75
|
+
_TFLOpName.MAXIMUM: schema.BuiltinOperator.MAXIMUM,
|
|
76
|
+
_TFLOpName.PADV2: schema.BuiltinOperator.PADV2,
|
|
77
|
+
_TFLOpName.REDUCE_MIN: schema.BuiltinOperator.REDUCE_MIN,
|
|
78
|
+
_TFLOpName.EQUAL: schema.BuiltinOperator.EQUAL,
|
|
79
|
+
_TFLOpName.NOT_EQUAL: schema.BuiltinOperator.NOT_EQUAL,
|
|
80
|
+
_TFLOpName.MIRROR_PAD: schema.BuiltinOperator.MIRROR_PAD,
|
|
81
|
+
_TFLOpName.SPACE_TO_DEPTH: schema.BuiltinOperator.SPACE_TO_DEPTH,
|
|
82
|
+
_TFLOpName.RELU: schema.BuiltinOperator.RELU,
|
|
65
83
|
})
|
|
66
84
|
|
|
67
85
|
TFL_OP_CODE_TO_NAME = immutabledict.immutabledict(
|
|
@@ -78,6 +96,11 @@ TFL_OP_TO_WEIGHT_QUANTIZED_DIM = immutabledict.immutabledict({
|
|
|
78
96
|
_TFLOpName.CONV_2D_TRANSPOSE: 0,
|
|
79
97
|
})
|
|
80
98
|
|
|
99
|
+
TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM = immutabledict.immutabledict({
|
|
100
|
+
_TFLOpName.FULLY_CONNECTED: 1,
|
|
101
|
+
_TFLOpName.EMBEDDING_LOOKUP: 1,
|
|
102
|
+
})
|
|
103
|
+
|
|
81
104
|
NUM_TFL_DATATYPES = 18
|
|
82
105
|
TENSOR_CODE_TO_TYPE = {}
|
|
83
106
|
for dtype_code in range(NUM_TFL_DATATYPES):
|
|
@@ -87,7 +110,7 @@ TENSOR_TYPE_TO_CODE = immutabledict.immutabledict(
|
|
|
87
110
|
(reversed(item) for item in TENSOR_CODE_TO_TYPE.items())
|
|
88
111
|
)
|
|
89
112
|
|
|
90
|
-
# Expose functions in
|
|
113
|
+
# Expose functions in litert.python.tools.flatbuffer_utils
|
|
91
114
|
write_model = flatbuffer_utils.write_model
|
|
92
115
|
|
|
93
116
|
|
|
@@ -122,7 +145,7 @@ def get_model_content(tflite_path: str) -> bytes:
|
|
|
122
145
|
Returns:
|
|
123
146
|
The model bytes.
|
|
124
147
|
"""
|
|
125
|
-
with
|
|
148
|
+
with open(tflite_path, "rb") as tflite_file:
|
|
126
149
|
return tflite_file.read()
|
|
127
150
|
|
|
128
151
|
|
|
@@ -135,7 +158,7 @@ def get_model_buffer(tflite_path: str) -> bytearray:
|
|
|
135
158
|
Returns:
|
|
136
159
|
model_buffer: the model buffer.
|
|
137
160
|
"""
|
|
138
|
-
with
|
|
161
|
+
with open(tflite_path, "rb") as tflite_file:
|
|
139
162
|
return bytearray(tflite_file.read())
|
|
140
163
|
|
|
141
164
|
|
|
@@ -188,25 +211,18 @@ def parse_fc_bmm_conv_tensors(
|
|
|
188
211
|
return input_tensor, weight_tensor, bias_tensor, output_tensor
|
|
189
212
|
|
|
190
213
|
|
|
191
|
-
# flatbuffer_model has Any type since
|
|
214
|
+
# flatbuffer_model has Any type since litert.python.tools.flatbuffer_utils
|
|
192
215
|
# is not type annotated.
|
|
193
216
|
def buffer_to_tensors(flatbuffer_model: Any) -> dict[int, list[Any]]:
|
|
194
|
-
"""
|
|
195
|
-
|
|
196
|
-
Args:
|
|
197
|
-
flatbuffer_model: the flatbuffer_model.
|
|
198
|
-
|
|
199
|
-
Returns:
|
|
200
|
-
buffer_to_tensor_map: key as buffer index, value as list of tensors share
|
|
201
|
-
the buffer
|
|
202
|
-
"""
|
|
217
|
+
"""Returns a map from buffer id to tensors that use it."""
|
|
203
218
|
buffer_to_tensor_map = {}
|
|
204
219
|
for subgraph in flatbuffer_model.subgraphs:
|
|
205
220
|
for op in subgraph.operators:
|
|
206
221
|
for tensor in parse_op_tensors(op, subgraph.tensors):
|
|
207
222
|
if tensor.buffer not in buffer_to_tensor_map:
|
|
208
223
|
buffer_to_tensor_map[tensor.buffer] = []
|
|
209
|
-
buffer_to_tensor_map[tensor.buffer]
|
|
224
|
+
if tensor not in buffer_to_tensor_map[tensor.buffer]:
|
|
225
|
+
buffer_to_tensor_map[tensor.buffer].append(tensor)
|
|
210
226
|
return buffer_to_tensor_map
|
|
211
227
|
|
|
212
228
|
|
|
@@ -239,7 +255,8 @@ def get_tensor_data(tensor: Any, buffers: list[Any]) -> Optional[np.ndarray]:
|
|
|
239
255
|
data = np.frombuffer(
|
|
240
256
|
buffer_data, dtype=TENSOR_CODE_TO_TYPE[tensor.type].lower()
|
|
241
257
|
)
|
|
242
|
-
|
|
258
|
+
if tensor.shape is not None:
|
|
259
|
+
data = np.reshape(data, tensor.shape)
|
|
243
260
|
return data
|
|
244
261
|
|
|
245
262
|
|
|
@@ -315,3 +332,32 @@ def get_subgraph_input_output_operators(
|
|
|
315
332
|
op_key=qtyping.TFLOperationName.OUTPUT,
|
|
316
333
|
)
|
|
317
334
|
return [input_op, output_op]
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
def get_op_side_effect_subgraphs(
|
|
338
|
+
op: Union[schema.Operator, schema.OperatorT],
|
|
339
|
+
) -> list[int]:
|
|
340
|
+
"""Get indices of any subgraphs invoked as a side effect of the operator.
|
|
341
|
+
|
|
342
|
+
Args:
|
|
343
|
+
op: The operator object.
|
|
344
|
+
|
|
345
|
+
Returns:
|
|
346
|
+
A list of subgraph indices invoked by the operator. Empty if the operator
|
|
347
|
+
does not invoke any subgraphs.
|
|
348
|
+
"""
|
|
349
|
+
if opts := flatbuffer_utils.get_options_as(
|
|
350
|
+
op, schema.StableHLOCompositeOptionsT
|
|
351
|
+
):
|
|
352
|
+
return [opts.decompositionSubgraphIndex]
|
|
353
|
+
# Can add other nested ops here (control flow ops, etc).
|
|
354
|
+
return []
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
def get_op_name_by_index(
|
|
358
|
+
flatbuffer_model: Any, subgraph_id: int, op_index: int
|
|
359
|
+
) -> str:
|
|
360
|
+
"""Get the op name from the flatbuffer model."""
|
|
361
|
+
op = flatbuffer_model.subgraphs[subgraph_id].operators[op_index]
|
|
362
|
+
builtin_code = flatbuffer_model.operatorCodes[op.opcodeIndex].builtinCode
|
|
363
|
+
return TFL_OP_CODE_TO_NAME[builtin_code]
|
|
@@ -105,6 +105,26 @@ class FlatbufferUtilsTest(googletest.TestCase):
|
|
|
105
105
|
conv2d_filter_tensor = tensors[0]
|
|
106
106
|
self.assertEqual(tuple(conv2d_filter_tensor.shape), (8, 3, 3, 1))
|
|
107
107
|
|
|
108
|
+
def test_buffer_to_tensors_has_unique_values(self):
|
|
109
|
+
test_model_path = os.path.join(
|
|
110
|
+
TEST_DATA_PREFIX_PATH,
|
|
111
|
+
"constant_tensor_and_buffer_only_sharing_weight_fcs.tflite",
|
|
112
|
+
)
|
|
113
|
+
test_model = tfl_flatbuffer_utils.read_model(test_model_path)
|
|
114
|
+
buffer_to_tensor_map = tfl_flatbuffer_utils.buffer_to_tensors(test_model)
|
|
115
|
+
self.assertLen(buffer_to_tensor_map, 7)
|
|
116
|
+
# The following buffer is shared by two tensors, each shared by two FC ops.
|
|
117
|
+
# This is where before we had multiple enrties for the same tensor.
|
|
118
|
+
self.assertLen(buffer_to_tensor_map[2], 2)
|
|
119
|
+
got_tensor_names = [
|
|
120
|
+
tfl_flatbuffer_utils.get_tensor_name(tensor)
|
|
121
|
+
for tensor in buffer_to_tensor_map[2]
|
|
122
|
+
]
|
|
123
|
+
self.assertEqual(
|
|
124
|
+
got_tensor_names,
|
|
125
|
+
["arith.constant", "arith.constant1"],
|
|
126
|
+
)
|
|
127
|
+
|
|
108
128
|
def test_get_tensor_name(self):
|
|
109
129
|
subgraph0 = self._test_model.subgraphs[0]
|
|
110
130
|
subgraph_tensors = subgraph0.tensors
|