ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,310 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Default quantization policy."""
17
+
18
+ import collections
19
+ import copy
20
+ import json
21
+ from typing import Any
22
+ from ai_edge_quantizer import qtyping
23
+
24
+ _TFLOpName = qtyping.TFLOperationName
25
+ _OpQuantizationConfig = qtyping.OpQuantizationConfig
26
+ _TensorQuantizationConfig = qtyping.TensorQuantizationConfig
27
+ _ComputePrecision = qtyping.ComputePrecision
28
+ _Granularity = qtyping.QuantGranularity
29
+ _INT = qtyping.TensorDataType.INT
30
+
31
+ # Default config check policy in JSON format. It has two keys: "configs" and
32
+ # "ops_per_config". "configs" is a dictionary mapping from config name to config
33
+ # content. "ops_per_config" is a dictionary mapping from config name to a list
34
+ # of op names.
35
+ DEFAULT_JSON_POLICY = """
36
+ {
37
+ "configs": {
38
+ "dynamic_wi8_afp32": {
39
+ "weight_tensor_config": {
40
+ "num_bits": 8,
41
+ "symmetric": [true],
42
+ "granularity": ["CHANNELWISE", "TENSORWISE"],
43
+ "dtype": "INT"
44
+ },
45
+ "explicit_dequantize": false,
46
+ "compute_precision": "INTEGER"
47
+ },
48
+ "dynamic_wi4_afp32": {
49
+ "weight_tensor_config": {
50
+ "num_bits": 4,
51
+ "symmetric": [true],
52
+ "granularity": ["CHANNELWISE", "TENSORWISE"],
53
+ "dtype": "INT"
54
+ },
55
+ "explicit_dequantize": false,
56
+ "compute_precision": "INTEGER"
57
+ },
58
+ "static_wi8_ai16": {
59
+ "activation_tensor_config": {
60
+ "num_bits": 16,
61
+ "symmetric": [true],
62
+ "granularity": ["TENSORWISE"],
63
+ "dtype": "INT"
64
+ },
65
+ "weight_tensor_config": {
66
+ "num_bits": 8,
67
+ "symmetric": [true],
68
+ "granularity": ["CHANNELWISE", "TENSORWISE"],
69
+ "dtype": "INT"
70
+ },
71
+ "explicit_dequantize": false,
72
+ "compute_precision": "INTEGER"
73
+ },
74
+ "static_wi4_ai16": {
75
+ "activation_tensor_config": {
76
+ "num_bits": 16,
77
+ "symmetric": [true],
78
+ "granularity": ["TENSORWISE"],
79
+ "dtype": "INT"
80
+ },
81
+ "weight_tensor_config": {
82
+ "num_bits": 4,
83
+ "symmetric": [true],
84
+ "granularity": ["CHANNELWISE", "TENSORWISE"],
85
+ "dtype": "INT"
86
+ },
87
+ "explicit_dequantize": false,
88
+ "compute_precision": "INTEGER"
89
+ },
90
+ "static_wi8_ai8": {
91
+ "activation_tensor_config": {
92
+ "num_bits": 8,
93
+ "symmetric": [true, false],
94
+ "granularity": ["TENSORWISE"],
95
+ "dtype": "INT"
96
+ },
97
+ "weight_tensor_config": {
98
+ "num_bits": 8,
99
+ "symmetric": [true],
100
+ "granularity": ["CHANNELWISE", "TENSORWISE"],
101
+ "dtype": "INT"
102
+ },
103
+ "explicit_dequantize": false,
104
+ "compute_precision": "INTEGER"
105
+ },
106
+ "static_wi4_ai8": {
107
+ "activation_tensor_config": {
108
+ "num_bits": 8,
109
+ "symmetric": [true, false],
110
+ "granularity": ["TENSORWISE"],
111
+ "dtype": "INT"
112
+ },
113
+ "weight_tensor_config": {
114
+ "num_bits": 4,
115
+ "symmetric": [true],
116
+ "granularity": ["CHANNELWISE", "TENSORWISE"],
117
+ "dtype": "INT"
118
+ },
119
+ "explicit_dequantize": false,
120
+ "compute_precision": "INTEGER"
121
+ },
122
+ "weightonly_wi8_afp32": {
123
+ "weight_tensor_config": {
124
+ "num_bits": 8,
125
+ "symmetric": [true, false],
126
+ "granularity": ["CHANNELWISE", "TENSORWISE"],
127
+ "dtype": "INT"
128
+ },
129
+ "explicit_dequantize": true,
130
+ "compute_precision": "FLOAT"
131
+ },
132
+ "weightonly_wi4_afp32": {
133
+ "weight_tensor_config": {
134
+ "num_bits": 4,
135
+ "symmetric": [true, false],
136
+ "granularity": ["CHANNELWISE", "TENSORWISE"],
137
+ "dtype": "INT"
138
+ },
139
+ "explicit_dequantize": true,
140
+ "compute_precision": "FLOAT"
141
+ }
142
+ },
143
+ "ops_per_config": {
144
+ "static_wi8_ai16": [
145
+ "ADD",
146
+ "AVERAGE_POOL_2D",
147
+ "BATCH_MATMUL",
148
+ "CONCATENATION",
149
+ "CONV_2D",
150
+ "CONV_2D_TRANSPOSE",
151
+ "DEPTHWISE_CONV_2D",
152
+ "FULLY_CONNECTED",
153
+ "GELU",
154
+ "LOGISTIC",
155
+ "MEAN",
156
+ "MUL",
157
+ "RESHAPE",
158
+ "RSQRT",
159
+ "SOFTMAX",
160
+ "SPLIT",
161
+ "STRIDED_SLICE",
162
+ "SUB",
163
+ "TANH",
164
+ "TRANSPOSE",
165
+ "INPUT",
166
+ "OUTPUT",
167
+ "SLICE",
168
+ "EMBEDDING_LOOKUP",
169
+ "SUM",
170
+ "SELECT_V2"
171
+ ],
172
+ "static_wi8_ai8": [
173
+ "ADD",
174
+ "AVERAGE_POOL_2D",
175
+ "BATCH_MATMUL",
176
+ "CONCATENATION",
177
+ "FULLY_CONNECTED",
178
+ "CONV_2D",
179
+ "CONV_2D_TRANSPOSE",
180
+ "DEPTHWISE_CONV_2D",
181
+ "GELU",
182
+ "LOGISTIC",
183
+ "MEAN",
184
+ "MUL",
185
+ "RESHAPE",
186
+ "RSQRT",
187
+ "SOFTMAX",
188
+ "SPLIT",
189
+ "STRIDED_SLICE",
190
+ "SUB",
191
+ "TANH",
192
+ "TRANSPOSE",
193
+ "INPUT",
194
+ "OUTPUT",
195
+ "SLICE",
196
+ "EMBEDDING_LOOKUP",
197
+ "SUM",
198
+ "SELECT_V2"
199
+ ],
200
+ "static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"],
201
+ "static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"],
202
+ "dynamic_wi8_afp32": [
203
+ "BATCH_MATMUL",
204
+ "CONV_2D",
205
+ "CONV_2D_TRANSPOSE",
206
+ "DEPTHWISE_CONV_2D",
207
+ "EMBEDDING_LOOKUP",
208
+ "FULLY_CONNECTED"
209
+ ],
210
+ "dynamic_wi4_afp32": ["FULLY_CONNECTED", "EMBEDDING_LOOKUP", "CONV_2D"],
211
+ "weightonly_wi8_afp32": [
212
+ "BATCH_MATMUL",
213
+ "CONV_2D",
214
+ "CONV_2D_TRANSPOSE",
215
+ "DEPTHWISE_CONV_2D",
216
+ "EMBEDDING_LOOKUP",
217
+ "FULLY_CONNECTED"
218
+ ],
219
+ "weightonly_wi4_afp32": ["BATCH_MATMUL", "FULLY_CONNECTED", "EMBEDDING_LOOKUP", "CONV_2D"]
220
+ }
221
+ }
222
+ """
223
+
224
+
225
def _unroll_tensor_configs(
    tensor_config: dict[str, Any],
) -> list[_TensorQuantizationConfig]:
  """Expands one JSON tensor config into every symmetric/granularity combo.

  Args:
    tensor_config: JSON tensor config whose "symmetric" and "granularity"
      entries are lists of allowed values, while "num_bits" and "dtype" are
      single values.

  Returns:
    One tensor quantization config per (symmetric, granularity) pair, ordered
    with "symmetric" as the outer loop and "granularity" as the inner loop.
  """
  return [
      _TensorQuantizationConfig.from_dict({
          "num_bits": tensor_config["num_bits"],
          "symmetric": symmetric,
          "granularity": granularity,
          "dtype": tensor_config["dtype"],
      })
      for symmetric in tensor_config["symmetric"]
      for granularity in tensor_config["granularity"]
  ]


def _unroll_json_config(
    json_config: dict[str, Any],
) -> list[_OpQuantizationConfig]:
  """Unrolls the config.

  A JSON config lists the allowed "symmetric" and "granularity" values of its
  tensor configs. This produces one op quantization config for every
  combination of an unrolled weight config with an unrolled activation config
  (or for every unrolled weight config alone when the JSON config has no
  "activation_tensor_config").

  Args:
    json_config: JSON config to be unrolled. Must contain
      "weight_tensor_config", "compute_precision" and "explicit_dequantize";
      "activation_tensor_config" is optional.

  Returns:
    quant_configs: A list of unrolled configs, weight combinations in the outer
      loop and activation combinations in the inner loop.

  Raises:
    KeyError: If a required key is missing from `json_config`.
  """
  # Unroll activation configs first.
  activation_configs = []
  if "activation_tensor_config" in json_config:
    activation_configs = _unroll_tensor_configs(
        json_config["activation_tensor_config"]
    )

  # Then unroll weight configs and turn them into quantization configs.
  quant_configs = []
  for weight_config in _unroll_tensor_configs(
      json_config["weight_tensor_config"]
  ):
    if activation_configs:
      # The weight config is built once and paired with every activation
      # config; activation configs are likewise reused across weight configs.
      for activation_config in activation_configs:
        quant_configs.append(
            _OpQuantizationConfig(
                activation_tensor_config=activation_config,
                weight_tensor_config=weight_config,
                compute_precision=json_config["compute_precision"],
                explicit_dequantize=json_config["explicit_dequantize"],
            )
        )
    else:
      quant_configs.append(
          _OpQuantizationConfig(
              weight_tensor_config=weight_config,
              compute_precision=json_config["compute_precision"],
              explicit_dequantize=json_config["explicit_dequantize"],
          )
      )

  return quant_configs
287
+
288
+
289
def update_default_config_policy(
    raw_json_policy: str,
) -> collections.OrderedDict[_TFLOpName, list[_OpQuantizationConfig]]:
  """Builds a config check policy from its JSON description.

  Despite the name, nothing is modified in place: the JSON text is parsed and a
  freshly built policy is returned.

  Args:
    raw_json_policy: JSON text with two keys. "configs" maps a config name to
      its config content and "ops_per_config" maps a config name to the list of
      op names that the config applies to.

  Returns:
    An ordered dictionary mapping each op name to the list of quantization
    configs allowed for it. When an op appears under several config names, the
    configs of later names are placed before those of earlier names.

  Raises:
    KeyError: If a config name in "ops_per_config" is missing from "configs".
    ValueError: If the text is not valid JSON (`json.JSONDecodeError`), or
      presumably if an op name is not a valid `_TFLOpName` value.
  """
  json_policy_content = json.loads(raw_json_policy)

  # Unroll the json config and add the configs to the policy.
  policy = collections.OrderedDict()
  for config_name, ops in json_policy_content["ops_per_config"].items():
    unrolled_configs = _unroll_json_config(
        json_policy_content["configs"][config_name]
    )

    for op in ops:
      op_name = _TFLOpName(op)
      # Each op gets its own copy so the per-op lists never share objects.
      quant_configs = copy.deepcopy(unrolled_configs)
      # Look up with the same `_TFLOpName` key that is stored below; testing
      # the raw string only works if the enum compares equal to its value.
      if op_name in policy:
        quant_configs += policy[op_name]
      policy[op_name] = quant_configs

  return policy
308
+
309
+
310
+ DEFAULT_CONFIG_CHECK_POLICY = update_default_config_policy(DEFAULT_JSON_POLICY)
@@ -0,0 +1,176 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Model Modifier class that produce the final quantized TFlite model."""
17
+
18
+ import copy
19
+
20
+ import numpy as np
21
+
22
+ from ai_edge_quantizer import qtyping
23
+ from ai_edge_quantizer import transformation_instruction_generator
24
+ from ai_edge_quantizer import transformation_performer
25
+ from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
26
+ from tensorflow.lite.tools import flatbuffer_utils # pylint: disable=g-direct-tensorflow-import
27
+
28
+
29
class ModelModifier:
  """Model Modifier class that produces the final quantized TFLite model."""

  def __init__(self, float_tflite: bytes):
    """Constructor.

    Args:
      float_tflite: the serialized content of the original TFLite model. It is
        passed to `flatbuffer_utils.read_model_from_bytearray`, so it must be
        the model bytes rather than a file path.
    """

    self._model_content = float_tflite

    # Constant data of every model buffer, indexed by buffer position. It is
    # rebuilt by `_process_constant_map` on each `modify_model` call.
    self._constant_map = []
    self._transformation_instruction_generator = (
        transformation_instruction_generator.TransformationInstructionsGenerator()
    )
    self._transformation_performer = (
        transformation_performer.TransformationPerformer()
    )

  def modify_model(
      self, params: dict[str, qtyping.TensorTransformationParams]
  ) -> bytearray:
    """Modify the model.

    The stored model content is parsed into a fresh object on every call, so
    the original content held by this instance is left untouched.

    Args:
      params: a dictionary with tensor name and a list of tensor transformation
        params

    Returns:
      a byte buffer that represents the serialized tflite model
    """
    quantized_model = copy.deepcopy(
        flatbuffer_utils.read_model_from_bytearray(self._model_content)
    )

    instruction_generator = self._transformation_instruction_generator
    instructions = instruction_generator.quant_params_to_transformation_insts(
        params, quantized_model
    )

    self._transformation_performer.transform_graph(
        instructions, quantized_model
    )
    constant_buffer_size = self._process_constant_map(quantized_model)
    # we leave 64MB for the model architecture.
    if constant_buffer_size > 2**31 - 2**26:
      return self._serialize_large_model(quantized_model)
    else:
      return self._serialize_small_model(quantized_model)

  def _process_constant_map(
      self, quantized_model: schema_py_generated.ModelT
  ) -> int:
    """Process the constant map after all transformations are applied.

    If the resulting model is > 2GB then we would need to serialize constants
    separately, as such, we collect all the constant buffers using this
    function.

    Args:
      quantized_model: a quantized TFlite ModelT

    Returns:
      an integer representing the total size of the constant buffers
    """
    # Start from an empty map: entries are looked up by buffer index during
    # serialization, so leftovers from an earlier call would misalign them.
    self._constant_map = []
    buffer_size = 0
    for buffer in quantized_model.buffers:
      buffer_data = buffer.data
      if isinstance(buffer_data, np.ndarray):
        # Convert once and reuse the result for both the map and the size.
        buffer_data = buffer_data.tobytes()
      self._constant_map.append(buffer_data)
      if buffer_data is not None:
        buffer_size += len(buffer_data)
    return buffer_size

  def _pad_bytearray(self, bytearr: bytearray):
    """Pad the bytearray in place to a multiple of 16 bytes."""
    padding_size = -len(bytearr) % 16
    if padding_size:
      bytearr.extend(b'\0' * padding_size)

  # TODO: b/333797307 - support > 2GB output model
  def _serialize_large_model(
      self, quantized_model: schema_py_generated.ModelT
  ) -> bytearray:
    """serialize models > 2GB.

    The constants are stripped from the model, the remaining model is
    serialized, and every constant is appended after it at a 16-byte aligned
    position that is recorded in the matching buffer's offset and size.
    `_process_constant_map` must have been called on `quantized_model` first.

    Args:
      quantized_model: a quantized TFlite ModelT

    Returns:
      a byte buffer that represents the serialized tflite model
    """
    # remove all the constant from the model.
    for buffer in quantized_model.buffers:
      if buffer.data is not None:
        buffer.data = None
        # Placeholder values; presumably non-zero so that the fields occupy
        # space in the trial serialization below - TODO confirm.
        buffer.offset = 1
        buffer.size = 1

    # calculate the correct buffer size and offset. Only the length of the
    # stripped model is needed, so the offsets are derived arithmetically
    # instead of concatenating every constant into a temporary buffer.
    next_offset = len(
        flatbuffer_utils.convert_object_to_bytearray(quantized_model)
    )
    next_offset += -next_offset % 16
    for buffer_idx, buffer in enumerate(quantized_model.buffers):
      buffer_data = self._constant_map[buffer_idx]
      if buffer_data is None:
        continue
      buffer.offset = next_offset
      buffer.size = len(buffer_data)
      next_offset += len(buffer_data)
      # Every constant starts on a 16-byte boundary.
      next_offset += -next_offset % 16

    # build new tflite file with correct buffer offset
    model_bytearray = bytearray(
        flatbuffer_utils.convert_object_to_bytearray(quantized_model)
    )
    self._pad_bytearray(model_bytearray)
    for buffer_idx, _ in enumerate(quantized_model.buffers):
      buffer_data = self._constant_map[buffer_idx]
      if buffer_data is None:
        continue
      model_bytearray += buffer_data
      self._pad_bytearray(model_bytearray)
    return model_bytearray

  def _serialize_small_model(
      self, quantized_model: schema_py_generated.ModelT
  ) -> bytearray:
    """serialize models < 2GB.

    Args:
      quantized_model: a quantized TFlite ModelT

    Returns:
      a byte buffer that represents the serialized tflite model
    """
    model_bytearray = flatbuffer_utils.convert_object_to_bytearray(
        quantized_model
    )
    return model_bytearray
@@ -0,0 +1,130 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Tests for model_modifier."""
17
+
18
+ import os
19
+ import tracemalloc
20
+ from tensorflow.python.platform import googletest
21
+ from absl.testing import parameterized
22
+ from ai_edge_quantizer import model_modifier
23
+ from ai_edge_quantizer import params_generator
24
+ from ai_edge_quantizer import qtyping
25
+ from ai_edge_quantizer import recipe_manager
26
+ from ai_edge_quantizer.utils import test_utils
27
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
28
+ from tensorflow.lite.tools import flatbuffer_utils # pylint: disable=g-direct-tensorflow-import
29
+
30
+ TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('.')
31
+
32
+
33
class ModelModifierTest(parameterized.TestCase):
  """Tests for `model_modifier.ModelModifier` on the conv_fc_mnist model."""

  def setUp(self):
    super().setUp()
    self._model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, 'tests/models/conv_fc_mnist.tflite'
    )

    self._model_content: bytes = tfl_flatbuffer_utils.get_model_content(
        self._model_path
    )
    self._model_modifier = model_modifier.ModelModifier(self._model_content)
    self._global_recipe = [
        {
            'regex': '.*',
            'operation': 'FULLY_CONNECTED',
            'algorithm_key': 'min_max_uniform_quantize',
            'op_config': {
                'weight_tensor_config': {
                    'dtype': qtyping.TensorDataType.INT,
                    'num_bits': 8,
                    'symmetric': False,
                    'granularity': qtyping.QuantGranularity.CHANNELWISE,
                    'block_size': 0,
                },
                # Equivalent to WEIGHT_ONLY.
                'compute_precision': qtyping.ComputePrecision.FLOAT,
                'explicit_dequantize': True,
            },
        },
    ]

  def _generate_quantization_params(self):
    """Generates the tensor quantization params for `self._global_recipe`."""
    recipe_manager_instance = recipe_manager.RecipeManager()
    params_generator_instance = params_generator.ParamsGenerator(
        self._model_path
    )

    recipe_manager_instance.load_quantization_recipe(self._global_recipe)
    return params_generator_instance.generate_quantization_parameters(
        recipe_manager_instance
    )

  def test_process_constant_map_succeeds(self):
    model = flatbuffer_utils.read_model_from_bytearray(self._model_content)
    constant_size = self._model_modifier._process_constant_map(model)
    self.assertEqual(constant_size, 202540)

  def test_modify_model_succeeds_with_recipe(self):
    tensor_quantization_params = self._generate_quantization_params()
    new_model_binary = self._model_modifier.modify_model(
        tensor_quantization_params
    )
    # The output must still parse as a model.
    flatbuffer_utils.convert_bytearray_to_object(new_model_binary)
    # Compare sizes: comparing the byte strings themselves would be a
    # lexicographic comparison that says nothing about the model size.
    self.assertLess(len(new_model_binary), len(self._model_content))

  def test_modify_model_preserves_original_model(self):
    tensor_quantization_params = self._generate_quantization_params()
    self.assertEqual(self._model_modifier._model_content, self._model_content)
    self._model_modifier.modify_model(tensor_quantization_params)
    self.assertEqual(self._model_modifier._model_content, self._model_content)

  def test_modify_model_peak_memory_usage_in_acceptable_range(self):
    """Test ModifyModel peak memory usage."""
    tensor_quantization_params = self._generate_quantization_params()

    tracemalloc.start()
    try:
      self._model_modifier.modify_model(tensor_quantization_params)
      _, mem_peak = tracemalloc.get_traced_memory()
    finally:
      # Always stop tracing so later tests do not run with tracing enabled.
      tracemalloc.stop()

    loosen_mem_use_factor = 4.5
    self.assertLess(mem_peak / len(self._model_content), loosen_mem_use_factor)
127
+
128
+
129
+ if __name__ == '__main__':
130
+ googletest.main()