ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl
This diff shows the contents of a publicly available package version as released to its registry. It is provided for informational purposes only and reflects the package as it appears in the public registry; since this is the package's first release, every file appears as an addition (+N -0).
- ai_edge_quantizer/__init__.py +19 -0
- ai_edge_quantizer/algorithm_manager.py +167 -0
- ai_edge_quantizer/algorithm_manager_api.py +271 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
- ai_edge_quantizer/algorithms/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
- ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
- ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
- ai_edge_quantizer/calibrator.py +288 -0
- ai_edge_quantizer/calibrator_test.py +297 -0
- ai_edge_quantizer/conftest.py +22 -0
- ai_edge_quantizer/default_policy.py +310 -0
- ai_edge_quantizer/model_modifier.py +176 -0
- ai_edge_quantizer/model_modifier_test.py +130 -0
- ai_edge_quantizer/model_validator.py +357 -0
- ai_edge_quantizer/model_validator_test.py +354 -0
- ai_edge_quantizer/params_generator.py +361 -0
- ai_edge_quantizer/params_generator_test.py +1041 -0
- ai_edge_quantizer/qtyping.py +483 -0
- ai_edge_quantizer/quantizer.py +372 -0
- ai_edge_quantizer/quantizer_test.py +532 -0
- ai_edge_quantizer/recipe.py +67 -0
- ai_edge_quantizer/recipe_manager.py +245 -0
- ai_edge_quantizer/recipe_manager_test.py +815 -0
- ai_edge_quantizer/recipe_test.py +97 -0
- ai_edge_quantizer/transformation_instruction_generator.py +584 -0
- ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
- ai_edge_quantizer/transformation_performer.py +278 -0
- ai_edge_quantizer/transformation_performer_test.py +344 -0
- ai_edge_quantizer/transformations/__init__.py +15 -0
- ai_edge_quantizer/transformations/dequant_insert.py +87 -0
- ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
- ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
- ai_edge_quantizer/transformations/quant_insert.py +100 -0
- ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
- ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
- ai_edge_quantizer/transformations/transformation_utils.py +132 -0
- ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
- ai_edge_quantizer/utils/__init__.py +15 -0
- ai_edge_quantizer/utils/calibration_utils.py +86 -0
- ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
- ai_edge_quantizer/utils/test_utils.py +107 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
- ai_edge_quantizer/utils/validation_utils.py +125 -0
- ai_edge_quantizer/utils/validation_utils_test.py +87 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
--- /dev/null
+++ ai_edge_quantizer/default_policy.py
@@ -0,0 +1,310 @@
+# Copyright 2024 The AI Edge Quantizer Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Default quantization policy."""
+
+import collections
+import copy
+import json
+from typing import Any
+from ai_edge_quantizer import qtyping
+
+_TFLOpName = qtyping.TFLOperationName
+_OpQuantizationConfig = qtyping.OpQuantizationConfig
+_TensorQuantizationConfig = qtyping.TensorQuantizationConfig
+_ComputePrecision = qtyping.ComputePrecision
+_Granularity = qtyping.QuantGranularity
+_INT = qtyping.TensorDataType.INT
+
+# Default config check policy in JSON format. It has two keys: "configs" and
+# "ops_per_config". "configs" is a dictionary mapping from config name to config
+# content. "ops_per_config" is a dictionary mapping from config name to a list
+# of op names.
+DEFAULT_JSON_POLICY = """
+{
+  "configs": {
+    "dynamic_wi8_afp32": {
+      "weight_tensor_config": {
+        "num_bits": 8,
+        "symmetric": [true],
+        "granularity": ["CHANNELWISE", "TENSORWISE"],
+        "dtype": "INT"
+      },
+      "explicit_dequantize": false,
+      "compute_precision": "INTEGER"
+    },
+    "dynamic_wi4_afp32": {
+      "weight_tensor_config": {
+        "num_bits": 4,
+        "symmetric": [true],
+        "granularity": ["CHANNELWISE", "TENSORWISE"],
+        "dtype": "INT"
+      },
+      "explicit_dequantize": false,
+      "compute_precision": "INTEGER"
+    },
+    "static_wi8_ai16": {
+      "activation_tensor_config": {
+        "num_bits": 16,
+        "symmetric": [true],
+        "granularity": ["TENSORWISE"],
+        "dtype": "INT"
+      },
+      "weight_tensor_config": {
+        "num_bits": 8,
+        "symmetric": [true],
+        "granularity": ["CHANNELWISE", "TENSORWISE"],
+        "dtype": "INT"
+      },
+      "explicit_dequantize": false,
+      "compute_precision": "INTEGER"
+    },
+    "static_wi4_ai16": {
+      "activation_tensor_config": {
+        "num_bits": 16,
+        "symmetric": [true],
+        "granularity": ["TENSORWISE"],
+        "dtype": "INT"
+      },
+      "weight_tensor_config": {
+        "num_bits": 4,
+        "symmetric": [true],
+        "granularity": ["CHANNELWISE", "TENSORWISE"],
+        "dtype": "INT"
+      },
+      "explicit_dequantize": false,
+      "compute_precision": "INTEGER"
+    },
+    "static_wi8_ai8": {
+      "activation_tensor_config": {
+        "num_bits": 8,
+        "symmetric": [true, false],
+        "granularity": ["TENSORWISE"],
+        "dtype": "INT"
+      },
+      "weight_tensor_config": {
+        "num_bits": 8,
+        "symmetric": [true],
+        "granularity": ["CHANNELWISE", "TENSORWISE"],
+        "dtype": "INT"
+      },
+      "explicit_dequantize": false,
+      "compute_precision": "INTEGER"
+    },
+    "static_wi4_ai8": {
+      "activation_tensor_config": {
+        "num_bits": 8,
+        "symmetric": [true, false],
+        "granularity": ["TENSORWISE"],
+        "dtype": "INT"
+      },
+      "weight_tensor_config": {
+        "num_bits": 4,
+        "symmetric": [true],
+        "granularity": ["CHANNELWISE", "TENSORWISE"],
+        "dtype": "INT"
+      },
+      "explicit_dequantize": false,
+      "compute_precision": "INTEGER"
+    },
+    "weightonly_wi8_afp32": {
+      "weight_tensor_config": {
+        "num_bits": 8,
+        "symmetric": [true, false],
+        "granularity": ["CHANNELWISE", "TENSORWISE"],
+        "dtype": "INT"
+      },
+      "explicit_dequantize": true,
+      "compute_precision": "FLOAT"
+    },
+    "weightonly_wi4_afp32": {
+      "weight_tensor_config": {
+        "num_bits": 4,
+        "symmetric": [true, false],
+        "granularity": ["CHANNELWISE", "TENSORWISE"],
+        "dtype": "INT"
+      },
+      "explicit_dequantize": true,
+      "compute_precision": "FLOAT"
+    }
+  },
+  "ops_per_config": {
+    "static_wi8_ai16": [
+      "ADD",
+      "AVERAGE_POOL_2D",
+      "BATCH_MATMUL",
+      "CONCATENATION",
+      "CONV_2D",
+      "CONV_2D_TRANSPOSE",
+      "DEPTHWISE_CONV_2D",
+      "FULLY_CONNECTED",
+      "GELU",
+      "LOGISTIC",
+      "MEAN",
+      "MUL",
+      "RESHAPE",
+      "RSQRT",
+      "SOFTMAX",
+      "SPLIT",
+      "STRIDED_SLICE",
+      "SUB",
+      "TANH",
+      "TRANSPOSE",
+      "INPUT",
+      "OUTPUT",
+      "SLICE",
+      "EMBEDDING_LOOKUP",
+      "SUM",
+      "SELECT_V2"
+    ],
+    "static_wi8_ai8": [
+      "ADD",
+      "AVERAGE_POOL_2D",
+      "BATCH_MATMUL",
+      "CONCATENATION",
+      "FULLY_CONNECTED",
+      "CONV_2D",
+      "CONV_2D_TRANSPOSE",
+      "DEPTHWISE_CONV_2D",
+      "GELU",
+      "LOGISTIC",
+      "MEAN",
+      "MUL",
+      "RESHAPE",
+      "RSQRT",
+      "SOFTMAX",
+      "SPLIT",
+      "STRIDED_SLICE",
+      "SUB",
+      "TANH",
+      "TRANSPOSE",
+      "INPUT",
+      "OUTPUT",
+      "SLICE",
+      "EMBEDDING_LOOKUP",
+      "SUM",
+      "SELECT_V2"
+    ],
+    "static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"],
+    "static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"],
+    "dynamic_wi8_afp32": [
+      "BATCH_MATMUL",
+      "CONV_2D",
+      "CONV_2D_TRANSPOSE",
+      "DEPTHWISE_CONV_2D",
+      "EMBEDDING_LOOKUP",
+      "FULLY_CONNECTED"
+    ],
+    "dynamic_wi4_afp32": ["FULLY_CONNECTED", "EMBEDDING_LOOKUP", "CONV_2D"],
+    "weightonly_wi8_afp32": [
+      "BATCH_MATMUL",
+      "CONV_2D",
+      "CONV_2D_TRANSPOSE",
+      "DEPTHWISE_CONV_2D",
+      "EMBEDDING_LOOKUP",
+      "FULLY_CONNECTED"
+    ],
+    "weightonly_wi4_afp32": ["BATCH_MATMUL", "FULLY_CONNECTED", "EMBEDDING_LOOKUP", "CONV_2D"]
+  }
+}
+"""
+
+
+def _unroll_json_config(
+    json_config: dict[str, Any],
+) -> list[_OpQuantizationConfig]:
+  """Unrolls the config.
+
+  Args:
+    json_config: JSON config to be unrolled.
+
+  Returns:
+    quant_configs: A list of unrolled configs.
+  """
+
+  # Unroll activation configs first.
+  activation_configs = []
+  if "activation_tensor_config" in json_config:
+    for symmetric in json_config["activation_tensor_config"]["symmetric"]:
+      for granularity in json_config["activation_tensor_config"]["granularity"]:
+        tensor_config = {
+            "num_bits": json_config["activation_tensor_config"]["num_bits"],
+            "symmetric": symmetric,
+            "granularity": granularity,
+            "dtype": json_config["activation_tensor_config"]["dtype"],
+        }
+        activation_configs.append(
+            qtyping.TensorQuantizationConfig.from_dict(tensor_config)
+        )
+
+  # Then unroll weight configs and turn them into quantization configs.
+  quant_configs = []
+  for symmetric in json_config["weight_tensor_config"]["symmetric"]:
+    for granularity in json_config["weight_tensor_config"]["granularity"]:
+      tensor_config = {
+          "num_bits": json_config["weight_tensor_config"]["num_bits"],
+          "symmetric": symmetric,
+          "granularity": granularity,
+          "dtype": json_config["weight_tensor_config"]["dtype"],
+      }
+
+      if activation_configs:
+        for activation_config in activation_configs:
+          quant_configs.append(
+              qtyping.OpQuantizationConfig(
+                  activation_tensor_config=activation_config,
+                  weight_tensor_config=qtyping.TensorQuantizationConfig.from_dict(
+                      tensor_config
+                  ),
+                  compute_precision=json_config["compute_precision"],
+                  explicit_dequantize=json_config["explicit_dequantize"],
+              )
+          )
+      else:
+        quant_configs.append(
+            qtyping.OpQuantizationConfig(
+                weight_tensor_config=qtyping.TensorQuantizationConfig.from_dict(
+                    tensor_config
+                ),
+                compute_precision=json_config["compute_precision"],
+                explicit_dequantize=json_config["explicit_dequantize"],
+            )
+        )
+
+  return quant_configs
+
+
+def update_default_config_policy(raw_json_policy: str):
+  """Updates the default config check policy."""
+  json_policy_content = json.loads(raw_json_policy)
+
+  # Unroll the json config and add the configs to the policy.
+  policy = collections.OrderedDict()
+  for json_policy_config in json_policy_content["ops_per_config"]:
+    unrolled_configs = _unroll_json_config(
+        json_policy_content["configs"][json_policy_config]
+    )
+
+    for op in json_policy_content["ops_per_config"][json_policy_config]:
+      op_name = _TFLOpName(op)
+      quant_configs = copy.deepcopy(unrolled_configs)
+      if op in policy.keys():
+        quant_configs += policy[op_name]
+      policy[op_name] = quant_configs
+
+  return policy
+
+
+DEFAULT_CONFIG_CHECK_POLICY = update_default_config_policy(DEFAULT_JSON_POLICY)
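`_unroll_json_config` expands each compact policy entry into the cross product of its list-valued `symmetric` and `granularity` fields, so a single JSON entry can authorize several concrete `OpQuantizationConfig` variants. A minimal sketch of that expansion, using plain dicts instead of the package's `qtyping` dataclasses (the `unroll_tensor_config` helper is illustrative, not a package API):

```python
import itertools
import json

from ai_edge_quantizer import default_policy


def unroll_tensor_config(tensor_cfg: dict) -> list[dict]:
  # Cross product of the list-valued fields, mirroring _unroll_json_config.
  return [
      {
          "num_bits": tensor_cfg["num_bits"],
          "symmetric": symmetric,
          "granularity": granularity,
          "dtype": tensor_cfg["dtype"],
      }
      for symmetric, granularity in itertools.product(
          tensor_cfg["symmetric"], tensor_cfg["granularity"]
      )
  ]


policy = json.loads(default_policy.DEFAULT_JSON_POLICY)
weight_cfg = policy["configs"]["static_wi8_ai8"]["weight_tensor_config"]
print(unroll_tensor_config(weight_cfg))
# Two variants: symmetric INT8 CHANNELWISE and symmetric INT8 TENSORWISE.
```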
--- /dev/null
+++ ai_edge_quantizer/model_modifier.py
@@ -0,0 +1,176 @@
+# Copyright 2024 The AI Edge Quantizer Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Model Modifier class that produces the final quantized TFLite model."""
+
+import copy
+
+import numpy as np
+
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer import transformation_instruction_generator
+from ai_edge_quantizer import transformation_performer
+from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
+from tensorflow.lite.tools import flatbuffer_utils  # pylint: disable=g-direct-tensorflow-import
+
+
+class ModelModifier:
+  """Model Modifier class that produces the final quantized TFLite model."""
+
+  def __init__(self, float_tflite: bytes):
+    """Constructor.
+
+    Args:
+      float_tflite: The original TFLite model as a bytearray or a file path.
+    """
+
+    self._model_content = float_tflite
+
+    self._constant_map = []
+    self._transformation_instruction_generator = (
+        transformation_instruction_generator.TransformationInstructionsGenerator()
+    )
+    self._transformation_performer = (
+        transformation_performer.TransformationPerformer()
+    )
+
+  def modify_model(
+      self, params: dict[str, qtyping.TensorTransformationParams]
+  ) -> bytearray:
+    """Modifies the model.
+
+    Args:
+      params: A dictionary mapping each tensor name to a list of tensor
+        transformation params.
+
+    Returns:
+      A byte buffer that represents the serialized TFLite model.
+    """
+    quantized_model = copy.deepcopy(
+        flatbuffer_utils.read_model_from_bytearray(self._model_content)
+    )
+
+    instructions = self._transformation_instruction_generator.quant_params_to_transformation_insts(
+        params, quantized_model
+    )
+
+    self._transformation_performer.transform_graph(
+        instructions, quantized_model
+    )
+    constant_buffer_size = self._process_constant_map(quantized_model)
+    # We leave 64MB for the model architecture.
+    if constant_buffer_size > 2**31 - 2**26:
+      return self._serialize_large_model(quantized_model)
+    else:
+      return self._serialize_small_model(quantized_model)
+
+  def _process_constant_map(
+      self, quantized_model: schema_py_generated.ModelT
+  ) -> int:
+    """Processes the constant map after all transformations are applied.
+
+    If the resulting model is > 2GB, the constants must be serialized
+    separately, so this function collects all the constant buffers.
+
+    Args:
+      quantized_model: A quantized TFLite ModelT.
+
+    Returns:
+      An integer representing the total size of the constant buffers.
+    """
+    buffer_size = 0
+    for buffer in quantized_model.buffers:
+      if buffer.data is None:
+        self._constant_map.append(buffer.data)
+      elif isinstance(buffer.data, np.ndarray):
+        self._constant_map.append(buffer.data.tobytes())
+        buffer_size += len(buffer.data.tobytes())
+      else:
+        self._constant_map.append(buffer.data)
+        buffer_size += len(buffer.data)
+    return buffer_size
+
+  def _pad_bytearray(self, bytearr: bytearray):
+    """Pads the bytearray to a multiple of 16 bytes."""
+    remainder = len(bytearr) % 16
+    if remainder != 0:
+      padding_size = 16 - remainder
+      bytearr.extend(b'\0' * padding_size)
+
+  # TODO: b/333797307 - support > 2GB output model
+  def _serialize_large_model(
+      self, quantized_model: schema_py_generated.ModelT
+  ) -> bytearray:
+    """Serializes models > 2GB.
+
+    Args:
+      quantized_model: A quantized TFLite ModelT.
+
+    Returns:
+      A byte buffer that represents the serialized TFLite model.
+    """
+    # TODO: b/338244867 - we can have more efficient way to calculate the
+    # buffer offsets.
+
+    # Remove all the constants from the model.
+    for buffer in quantized_model.buffers:
+      if buffer.data is not None:
+        buffer.data = None
+        buffer.offset = 1
+        buffer.size = 1
+    dummy_bytearray = bytearray(
+        flatbuffer_utils.convert_object_to_bytearray(quantized_model)
+    )
+    # Calculate the correct buffer sizes and offsets.
+    self._pad_bytearray(dummy_bytearray)
+    for buffer_idx, buffer in enumerate(quantized_model.buffers):
+      buffer_data = self._constant_map[buffer_idx]
+      if buffer_data is None:
+        continue
+      buffer.offset = len(dummy_bytearray)
+      buffer.size = len(buffer_data)
+      dummy_bytearray += buffer_data
+      self._pad_bytearray(dummy_bytearray)
+    del dummy_bytearray
+
+    # Build the new TFLite file with correct buffer offsets.
+    model_bytearray = bytearray(
+        flatbuffer_utils.convert_object_to_bytearray(quantized_model)
+    )
+    self._pad_bytearray(model_bytearray)
+    for buffer_idx, _ in enumerate(quantized_model.buffers):
+      buffer_data = self._constant_map[buffer_idx]
+      if buffer_data is None:
+        continue
+      model_bytearray += buffer_data
+      self._pad_bytearray(model_bytearray)
+    return model_bytearray
+
+  def _serialize_small_model(
+      self, quantized_model: schema_py_generated.ModelT
+  ) -> bytearray:
+    """Serializes models < 2GB.
+
+    Args:
+      quantized_model: A quantized TFLite ModelT.
+
+    Returns:
+      A byte buffer that represents the serialized TFLite model.
+    """
+    model_bytearray = flatbuffer_utils.convert_object_to_bytearray(
+        quantized_model
+    )
+    return model_bytearray
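For models whose constant buffers exceed the ~2GB flatbuffer limit, `_serialize_large_model` makes two serialization passes: a dummy pass with all constants stripped to learn each buffer's final position, then a real pass that appends the constants after the flatbuffer at 16-byte-aligned offsets recorded in `buffer.offset`/`buffer.size`. A self-contained sketch of just the alignment arithmetic (`pad16` and `append_constants` are illustrative stand-ins, not package APIs):

```python
def pad16(buf: bytearray) -> None:
  """Pads the buffer to a 16-byte boundary, matching _pad_bytearray."""
  remainder = len(buf) % 16
  if remainder:
    buf.extend(b"\0" * (16 - remainder))


def append_constants(model_bytes: bytearray, constants: list):
  """Appends constants 16-byte aligned; returns (offset, size) per constant."""
  placements = []
  pad16(model_bytes)
  for data in constants:
    if data is None:  # Non-constant buffer: nothing to place.
      placements.append(None)
      continue
    placements.append((len(model_bytes), len(data)))
    model_bytes += data
    pad16(model_bytes)
  return placements


blob = bytearray(b"flatbuffer-architecture-part")  # 28 bytes, padded to 32.
print(append_constants(blob, [None, b"\x01" * 10, b"\x02" * 20]))
# -> [None, (32, 10), (48, 20)]
```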
--- /dev/null
+++ ai_edge_quantizer/model_modifier_test.py
@@ -0,0 +1,130 @@
+# Copyright 2024 The AI Edge Quantizer Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for model_modifier."""
+
+import os
+import tracemalloc
+from tensorflow.python.platform import googletest
+from absl.testing import parameterized
+from ai_edge_quantizer import model_modifier
+from ai_edge_quantizer import params_generator
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer import recipe_manager
+from ai_edge_quantizer.utils import test_utils
+from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+from tensorflow.lite.tools import flatbuffer_utils  # pylint: disable=g-direct-tensorflow-import
+
+TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('.')
+
+
+class ModelModifierTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self._model_path = os.path.join(
+        TEST_DATA_PREFIX_PATH, 'tests/models/conv_fc_mnist.tflite'
+    )
+
+    self._model_content: bytes = tfl_flatbuffer_utils.get_model_content(
+        self._model_path
+    )
+    self._model_modifier = model_modifier.ModelModifier(self._model_content)
+    self._global_recipe = [
+        {
+            'regex': '.*',
+            'operation': 'FULLY_CONNECTED',
+            'algorithm_key': 'min_max_uniform_quantize',
+            'op_config': {
+                'weight_tensor_config': {
+                    'dtype': qtyping.TensorDataType.INT,
+                    'num_bits': 8,
+                    'symmetric': False,
+                    'granularity': qtyping.QuantGranularity.CHANNELWISE,
+                    'block_size': 0,
+                },
+                # Equivalent to WEIGHT_ONLY.
+                'compute_precision': qtyping.ComputePrecision.FLOAT,
+                'explicit_dequantize': True,
+            },
+        },
+    ]
+
+  def test_process_constant_map_succeeds(self):
+    model_bytearray = flatbuffer_utils.read_model_from_bytearray(
+        self._model_content
+    )
+    constant_size = self._model_modifier._process_constant_map(model_bytearray)
+    self.assertEqual(constant_size, 202540)
+
+  def test_modify_model_succeeds_with_recipe(self):
+    recipe_manager_instance = recipe_manager.RecipeManager()
+    params_generator_instance = params_generator.ParamsGenerator(
+        self._model_path
+    )
+
+    recipe_manager_instance.load_quantization_recipe(self._global_recipe)
+    tensor_quantization_params = (
+        params_generator_instance.generate_quantization_parameters(
+            recipe_manager_instance
+        )
+    )
+    new_model_binary = self._model_modifier.modify_model(
+        tensor_quantization_params
+    )
+    flatbuffer_utils.convert_bytearray_to_object(new_model_binary)
+    self.assertLess(new_model_binary, self._model_content)
+
+  def test_modify_model_preserves_original_model(self):
+    recipe_manager_instance = recipe_manager.RecipeManager()
+    params_generator_instance = params_generator.ParamsGenerator(
+        self._model_path
+    )
+
+    recipe_manager_instance.load_quantization_recipe(self._global_recipe)
+    tensor_quantization_params = (
+        params_generator_instance.generate_quantization_parameters(
+            recipe_manager_instance
+        )
+    )
+    self.assertEqual(self._model_modifier._model_content, self._model_content)
+    self._model_modifier.modify_model(tensor_quantization_params)
+    self.assertEqual(self._model_modifier._model_content, self._model_content)
+
+  def test_modify_model_peak_memory_usage_in_acceptable_range(self):
+    """Test ModifyModel peak memory usage."""
+
+    recipe_manager_instance = recipe_manager.RecipeManager()
+    params_generator_instance = params_generator.ParamsGenerator(
+        self._model_path
+    )
+
+    recipe_manager_instance.load_quantization_recipe(self._global_recipe)
+    tensor_quantization_params = (
+        params_generator_instance.generate_quantization_parameters(
+            recipe_manager_instance
+        )
+    )
+
+    tracemalloc.start()
+    self._model_modifier.modify_model(tensor_quantization_params)
+    _, mem_peak = tracemalloc.get_traced_memory()
+
+    loosen_mem_use_factor = 4.5
+    self.assertLess(mem_peak / len(self._model_content), loosen_mem_use_factor)
+
+
+if __name__ == '__main__':
+  googletest.main()
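The last test bounds peak allocation during `modify_model` at 4.5x the input model size using `tracemalloc`. That measurement pattern is reusable outside the test suite; a minimal sketch with a stand-in workload (nothing below is a package API):

```python
import tracemalloc


def peak_memory_ratio(fn, baseline_bytes: int) -> float:
  """Runs fn and returns its traced peak allocation relative to a baseline."""
  tracemalloc.start()
  try:
    fn()
    _, peak = tracemalloc.get_traced_memory()
  finally:
    tracemalloc.stop()
  return peak / baseline_bytes


# Example with a stand-in workload that allocates ~1 MiB.
ratio = peak_memory_ratio(lambda: bytearray(2**20), baseline_bytes=2**20)
assert ratio >= 1.0  # The workload itself accounts for at least the baseline.
```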