ai-edge-quantizer-nightly 0.0.1.dev20250115 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/__init__.py +19 -0
- ai_edge_quantizer/algorithm_manager.py +167 -0
- ai_edge_quantizer/algorithm_manager_api.py +271 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
- ai_edge_quantizer/algorithms/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
- ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
- ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
- ai_edge_quantizer/calibrator.py +288 -0
- ai_edge_quantizer/calibrator_test.py +297 -0
- ai_edge_quantizer/conftest.py +22 -0
- ai_edge_quantizer/default_policy.py +310 -0
- ai_edge_quantizer/model_modifier.py +176 -0
- ai_edge_quantizer/model_modifier_test.py +130 -0
- ai_edge_quantizer/model_validator.py +357 -0
- ai_edge_quantizer/model_validator_test.py +354 -0
- ai_edge_quantizer/params_generator.py +361 -0
- ai_edge_quantizer/params_generator_test.py +1041 -0
- ai_edge_quantizer/qtyping.py +483 -0
- ai_edge_quantizer/quantizer.py +372 -0
- ai_edge_quantizer/quantizer_test.py +532 -0
- ai_edge_quantizer/recipe.py +67 -0
- ai_edge_quantizer/recipe_manager.py +245 -0
- ai_edge_quantizer/recipe_manager_test.py +815 -0
- ai_edge_quantizer/recipe_test.py +97 -0
- ai_edge_quantizer/transformation_instruction_generator.py +584 -0
- ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
- ai_edge_quantizer/transformation_performer.py +278 -0
- ai_edge_quantizer/transformation_performer_test.py +344 -0
- ai_edge_quantizer/transformations/__init__.py +15 -0
- ai_edge_quantizer/transformations/dequant_insert.py +87 -0
- ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
- ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
- ai_edge_quantizer/transformations/quant_insert.py +100 -0
- ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
- ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
- ai_edge_quantizer/transformations/transformation_utils.py +132 -0
- ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
- ai_edge_quantizer/utils/__init__.py +15 -0
- ai_edge_quantizer/utils/calibration_utils.py +86 -0
- ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
- ai_edge_quantizer/utils/test_utils.py +107 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
- ai_edge_quantizer/utils/validation_utils.py +125 -0
- ai_edge_quantizer/utils/validation_utils_test.py +87 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
ai_edge_quantizer/recipe_test.py
@@ -0,0 +1,97 @@
+# Copyright 2024 The AI Edge Quantizer Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+
+from absl.testing import parameterized
+
+from tensorflow.python.platform import googletest
+from ai_edge_quantizer import quantizer
+from ai_edge_quantizer import recipe
+from ai_edge_quantizer.utils import test_utils
+
+
+_TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('')
+
+
+class RecipeTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self._test_model_path = os.path.join(
+        _TEST_DATA_PREFIX_PATH,
+        'tests/models/single_conv2d_transpose_bias.tflite',
+    )
+
+  def _quantize_with_recipe_func(self, recipe_func):
+    qt = quantizer.Quantizer(self._test_model_path)
+    qt.load_quantization_recipe(recipe_func())
+    self.assertIsNone(qt._result.quantized_model)
+    quant_result = qt.quantize()
+    self.assertIsNotNone(quant_result.quantized_model)
+    return quant_result
+
+  def test_quantization_from_dynamic_wi8_afp32_func_succeeds(self):
+    quant_result = self._quantize_with_recipe_func(recipe.dynamic_wi8_afp32)
+    self.assertLess(
+        len(quant_result.quantized_model),
+        os.path.getsize(self._test_model_path),
+    )
+
+  def test_quantization_from_dynamic_legacy_wi8_afp32_func_succeeds(self):
+    quant_result = self._quantize_with_recipe_func(
+        recipe.dynamic_legacy_wi8_afp32
+    )
+    self.assertLen(
+        quant_result.quantized_model,
+        os.path.getsize(self._test_model_path),
+    )
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='dynamic_wi8_afp32',
+          recipe_json_path='recipes/dynamic_wi8_afp32_recipe.json',
+          recipe_func=recipe.dynamic_wi8_afp32,
+      ),
+      dict(
+          testcase_name='dynamic_legacy_wi8_afp32',
+          recipe_json_path='recipes/dynamic_legacy_wi8_afp32_recipe.json',
+          recipe_func=recipe.dynamic_legacy_wi8_afp32,
+      ),
+  )
+  def test_recipe_func_and_json_matches(self, recipe_json_path, recipe_func):
+    # Quantize with recipe from function in recipe module.
+    quant_result_from_func = self._quantize_with_recipe_func(recipe_func)
+
+    # Quantize with recipe from json file.
+    qt_json = quantizer.Quantizer(self._test_model_path)
+    json_recipe_path = os.path.join(_TEST_DATA_PREFIX_PATH, recipe_json_path)
+    qt_json.load_quantization_recipe(json_recipe_path)
+    quant_result_from_json = qt_json.quantize()
+    self.assertIsNotNone(quant_result_from_json.quantized_model)
+
+    # Check if the recipes and quantized models match.
+    self.assertEqual(
+        quant_result_from_func.recipe,
+        quant_result_from_json.recipe,
+    )
+    self.assertEqual(
+        len(quant_result_from_func.quantized_model),
+        len(quant_result_from_json.quantized_model),
+    )
+
+
+if __name__ == '__main__':
+  googletest.main()
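
For orientation, the test above exercises the full public quantization flow: construct a Quantizer from a float TFLite model, load a recipe (either a recipe-module function or a JSON file), and call quantize(). Below is a minimal standalone sketch of that flow, assuming only the API visible in the test (Quantizer, load_quantization_recipe, quantize, and a result exposing .recipe and .quantized_model) and a hypothetical model path:

    import os

    from ai_edge_quantizer import quantizer
    from ai_edge_quantizer import recipe

    # Hypothetical input; any float TFLite model path works here.
    model_path = '/tmp/my_model.tflite'

    qt = quantizer.Quantizer(model_path)
    # Predefined recipe: int8 weights, float32 activations (dynamic quantization).
    qt.load_quantization_recipe(recipe.dynamic_wi8_afp32())
    result = qt.quantize()

    # The result carries the applied recipe and the quantized model bytes.
    print(result.recipe)
    print(len(result.quantized_model), 'bytes vs', os.path.getsize(model_path))
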
ai_edge_quantizer/transformation_instruction_generator.py
@@ -0,0 +1,584 @@
+# Copyright 2024 The AI Edge Quantizer Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Create transformation instructions for transformation_performer.
+
+Given quantization parameters, create a list of transformation instructions
+that can then be used by transformation_performer. Includes necessary
+optimizations.
+"""
+
+from collections.abc import Iterator
+import dataclasses
+from typing import Optional
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
+
+
+# When a tensor has no producer, we assign -1 to the producer field.
+# When a tensor is a graph output, we also include a -1 in the consumer list.
+def check_horizontal_optimization(
+    param1: qtyping.OpToTensorParams,
+    param2: qtyping.OpToTensorParams,
+    index: int,
+) -> bool:
+  """Check if horizontal optimization can be applied.
+
+  Check if two transformations at the same index (belonging to two different
+  OpToTensorParams) can be merged together.
+
+  Args:
+    param1: First parameters to be compared.
+    param2: Second parameters to be compared.
+    index: The index at which the transformations will be compared.
+
+  Returns:
+    True if the two transformations can be merged, False otherwise.
+  """
+  return (
+      param1.parameters == param2.parameters
+      and len(param1.transformations) > index
+      and len(param2.transformations) > index
+      and param1.transformations[index] == param2.transformations[index]
+  )
+
+
+def check_dq_q_elimination(
+    producer_inst: qtyping.TransformationInst,
+    consumer_inst: qtyping.TransformationInst,
+) -> bool:
+  """Check if a pair of dequantize & quantize transformations can be eliminated.
+
+  This can only happen when the dequantize & quantize have the same quant
+  parameters, the dequantize belongs to the producer, and the quantize belongs
+  to a consumer.
+
+  Args:
+    producer_inst: TransformationInst from producer.
+    consumer_inst: TransformationInst from consumer.
+
+  Returns:
+    True if dequantize & quantize can be eliminated, False otherwise.
+  """
+  is_dequantize_in_producer = (
+      producer_inst.transformation == qtyping.QuantTransformation.ADD_DEQUANTIZE
+  )
+  is_quantize_in_consumer = (
+      consumer_inst.transformation == qtyping.QuantTransformation.ADD_QUANTIZE
+  )
+  is_same_parameters = producer_inst.parameters == consumer_inst.parameters
+  return (
+      is_dequantize_in_producer
+      and is_quantize_in_consumer
+      and is_same_parameters
+  )
+
+
+def check_replace_dq_q_with_rq(
+    producer_inst: qtyping.TransformationInst,
+    consumer_inst: qtyping.TransformationInst,
+) -> bool:
+  """Check if a pair of dequantize & quantize can be replaced by a requantize.
+
+  This can only happen when the dequantize belongs to the producer and the
+  quantize belongs to a consumer.
+
+  Args:
+    producer_inst: TransformationInst from producer.
+    consumer_inst: TransformationInst from consumer.
+
+  Returns:
+    True if dequantize & quantize can be replaced, False otherwise.
+    Note that the case where DQ & Q can be eliminated is considered a false
+    case.
+  """
+  is_dequantize_in_producer = (
+      producer_inst.transformation == qtyping.QuantTransformation.ADD_DEQUANTIZE
+  )
+  is_quantize_in_consumer = (
+      consumer_inst.transformation == qtyping.QuantTransformation.ADD_QUANTIZE
+  )
+  is_same_parameters = producer_inst.parameters == consumer_inst.parameters
+
+  return (
+      is_dequantize_in_producer
+      and is_quantize_in_consumer
+      and not is_same_parameters
+  )
+
+
+def check_dq_no_quant_elimination(
+    producer_inst: qtyping.TransformationInst,
+    consumer_inst: qtyping.TransformationInst,
+) -> bool:
+  """Check if a pair of dequantize & no-quantize transformations can be eliminated.
+
+  This can only happen when the dequantize belongs to the producer and the
+  no-quantize belongs to a consumer.
+
+  Args:
+    producer_inst: TransformationInst from producer.
+    consumer_inst: TransformationInst from consumer.
+
+  Returns:
+    True if dequantize & no-quantize can be eliminated, False otherwise.
+  """
+  is_dequantize_in_producer = (
+      producer_inst.transformation == qtyping.QuantTransformation.ADD_DEQUANTIZE
+  )
+  is_no_quant_in_consumer = (
+      consumer_inst.transformation == qtyping.QuantTransformation.NO_QUANTIZE
+  )
+  return is_dequantize_in_producer and is_no_quant_in_consumer
+
+
+class TransformationInstructionsGenerator:
+  """Generates transformation instructions from tensor quant params."""
+
+  def __init__(self, float_tflite: Optional[str] = None):
+    """Constructor.
+
+    Args:
+      float_tflite: The original TFLite model as a bytearray or a file path.
+    """
+    if float_tflite is None:
+      self._tensor_name_to_graph_info: dict[
+          str, TransformationInstructionsGenerator.TensorGraphInfo
+      ] = {}
+      self.flatbuffer_model: schema_py_generated.ModelT = ()
+    else:
+      self.flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
+      self._create_tensor_name_to_graph_info_map()
+
+  @dataclasses.dataclass(frozen=True)
+  class TensorGraphInfo:
+    tensor_id: int
+    subgraph_id: int
+    producer: int
+    consumers: list[int]
+
+  def _tensor_info_generator(
+      self, subgraph_id: int, subgraph: schema_py_generated.SubGraphT
+  ) -> Iterator[tuple[str, TensorGraphInfo]]:
+    """Generator function for tensor info.
+
+    Args:
+      subgraph_id: Index of the given subgraph.
+      subgraph: Subgraph struct to generate tensor info on.
+
+    Yields:
+      A tuple of tensor_name and TensorGraphInfo.
+    """
+    for tensor_id, tensor in enumerate(subgraph.tensors):
+      consumers = [
+          op_id
+          for (op_id, op) in enumerate(subgraph.operators)
+          if tensor_id in op.inputs
+      ]
+      producer = -1
+      for op_id, op in enumerate(subgraph.operators):
+        if tensor_id in op.outputs:
+          producer = op_id
+          break
+      if tensor_id in subgraph.outputs:
+        consumers.insert(0, -1)
+      tensor_info = self.TensorGraphInfo(
+          tensor_id, subgraph_id, producer, consumers
+      )
+      tensor_name = tfl_flatbuffer_utils.get_tensor_name(tensor)
+      yield tensor_name, tensor_info
+
+  def _create_tensor_name_to_graph_info_map(self):
+    """Create a mapping between tensor name and tensor info."""
+    self._tensor_name_to_graph_info = {}
+    # TODO: b/333607428 - support graph input & output
+    for subgraph_id, subgraph in enumerate(self.flatbuffer_model.subgraphs):
+      for tensor_name, tensor_info in self._tensor_info_generator(
+          subgraph_id, subgraph
+      ):
+        self._tensor_name_to_graph_info[tensor_name] = tensor_info
+
+  def _group_consumer_transformations(
+      self, param: qtyping.TensorTransformationParams
+  ) -> list[list[set[int]]]:
+    """Group transformations between consumers into common groups.
+
+    Args:
+      param: TensorTransformationParams for a tensor.
+
+    Returns:
+      A list of lists of sets, where each set holds the indices of consumers
+      whose transformations can be merged horizontally.
+      E.g., for the following consumers:
+      [(1, [ADD_QUANTIZE, ADD_DEQUANTIZE], param1),
+       (2, [ADD_QUANTIZE], param2),
+       (3, [ADD_QUANTIZE], param1)]
+      this function returns:
+      [[{1, 2, 3}],
+       [{1, 3}, {2}],
+       [{1}]]
+
+      The depth-0 list is the initial state, since all consumers come from the
+      same producer.
+      At depth 1, the ADD_QUANTIZE in 1 & 3 can be merged, so they are in the
+      same group.
+      At depth 2, there is only one transformation (from 1), so there is only
+      one group containing 1.
+    """
+    if not param or not param.consumers:
+      return []
+
+    # consumer_groups contains indices of operations that can be horizontally
+    # optimized together. The outermost list is the depth of the
+    # transformation, and the inner list contains sets representing the
+    # consumer indices that can be grouped together at the given depth.
+    consumer_groups = [[set()]]
+    # The max number of transformations applied before a particular consumer.
+    longest_trans_chain = 0
+    for i, consumer_param in enumerate(param.consumers):
+      consumer_groups[0][0].add(i)
+      longest_trans_chain = max(
+          longest_trans_chain, len(consumer_param.transformations)
+      )
+
+    # Loop over transformations of the same depth.
+    for transformation_depth in range(longest_trans_chain):
+      next_depth_groups = []
+      for consumer_param_index, consumer_param in enumerate(param.consumers):
+        if len(consumer_param.transformations) > transformation_depth:
+          for current_depth_groups in consumer_groups[transformation_depth]:
+            if consumer_param_index in current_depth_groups:
+              # Whether the transformation of this particular edge has been
+              # assigned to a group.
+              trans_assigned = False
+              for new_group in next_depth_groups:
+                # Get an index in the existing group; any of them works since
+                # they have the same quantization.
+                index = next(iter(new_group))
+                if (
+                    index in current_depth_groups
+                    and check_horizontal_optimization(
+                        param.consumers[index],
+                        consumer_param,
+                        transformation_depth,
+                    )
+                ):
+                  new_group.add(consumer_param_index)
+                  trans_assigned = True
+                  break
+              if not trans_assigned:
+                next_depth_groups.append(set([consumer_param_index]))
+      consumer_groups.append(next_depth_groups)
+    return consumer_groups
+
+  def _produce_transformation_for_vertical_opt(
+      self,
+      consumer_group: list[list[set[int]]],
+      param: qtyping.TensorTransformationParams,
+  ) -> list[qtyping.TransformationInst]:
+    """Create a list of transformation rules available for vertical optimization.
+
+    A consumer transformation is available for vertical optimization iff it's
+    the first transformation for a given consumer.
+
+    This function relies on the consumer_group argument already being optimized
+    for horizontal transformations.
+
+    Args:
+      consumer_group: A list of grouped indices for consumer transformations.
+      param: A TensorTransformationParams for the tensor.
+
+    Returns:
+      A list of transformation rules available for vertical optimization.
+    """
+    tensor_info = self._tensor_name_to_graph_info[param.tensor_name]
+    transformations_available_for_vertical_optimization = []
+    # We start at index 1 because consumer_group[0] is the initial state and
+    # does not contain actual information.
+    if len(consumer_group) > 1:
+      for group in consumer_group[1]:
+        op_list = list(group)
+        op_idx_list = []
+        for index in op_list:
+          op_idx_list.append(param.consumers[index].subgraph_op_id)
+        transformations_available_for_vertical_optimization.append(
+            qtyping.TransformationInst(
+                param.consumers[op_list[0]].transformations[0],
+                tensor_info.tensor_id,
+                tensor_info.producer,
+                op_idx_list,
+                param.consumers[op_list[0]].parameters,
+            )
+        )
+    return transformations_available_for_vertical_optimization
+
+  def _produce_consumer_transformations_unavailable_for_vertical_opt(
+      self,
+      consumer_group: list[list[set[int]]],
+      param: qtyping.TensorTransformationParams,
+  ) -> list[qtyping.TransformationInst]:
+    """Produce a list of consumer transformations that can't be used for vertical optimization.
+
+    A consumer transformation is available for vertical optimization if and
+    only if it's the first transformation for a given consumer.
+
+    This function relies on the consumer_group argument already being optimized
+    for horizontal transformations.
+
+    Args:
+      consumer_group: A list of grouped indices for consumer transformations.
+      param: A TensorTransformationParams for the tensor.
+
+    Returns:
+      A list of transformation rules unavailable for vertical optimization.
+    """
+    tensor_info = self._tensor_name_to_graph_info[param.tensor_name]
+    other_consumer_transformations = []
+    for transformation_idx in range(2, len(consumer_group)):
+      for group in consumer_group[transformation_idx]:
+        op_list = list(group)
+        op_idx_list = []
+        if (
+            len(param.consumers[op_list[0]].transformations)
+            <= transformation_idx - 1
+        ):
+          continue
+        for index in op_list:
+          op_idx_list.append(param.consumers[index].subgraph_op_id)
+        other_consumer_transformations.append(
+            qtyping.TransformationInst(
+                param.consumers[op_list[0]].transformations[
+                    transformation_idx - 1
+                ],
+                tensor_info.tensor_id,
+                tensor_info.producer,
+                op_idx_list,
+                param.consumers[op_list[0]].parameters,
+            )
+        )
+    return other_consumer_transformations
+
+  def _apply_vertical_optimization(
+      self,
+      producer_trans_rule: qtyping.TransformationInst,
+      consumer_trans_rules: list[qtyping.TransformationInst],
+  ) -> list[qtyping.TransformationInst]:
+    """Apply vertical optimization.
+
+    There are two types of transformations we consider:
+    1. When DQ & Q have the same parameters, eliminate the operators and
+       quantize the tensor only.
+    2. When DQ & Q have different parameters, replace the DQ & Q with an
+       RQ op.
+
+    Vertical optimization can only happen between the last producer rule and
+    the first rule of each consumer.
+
+    Args:
+      producer_trans_rule: The last producer transformation rule.
+      consumer_trans_rules: A list of consumer transformation rules that are
+        available for vertical transformations.
+
+    Returns:
+      A list of transformations after vertical optimization has been applied;
+      note the producer transformation is included.
+    """
+    transformations = []
+    for trans_rule in consumer_trans_rules:
+      if check_dq_q_elimination(producer_trans_rule, trans_rule):
+        for consumer_id in trans_rule.consumers:
+          if consumer_id in producer_trans_rule.consumers:
+            producer_trans_rule.consumers.remove(consumer_id)
+        transformations.append(
+            qtyping.TransformationInst(
+                qtyping.QuantTransformation.QUANTIZE_TENSOR,
+                trans_rule.tensor_id,
+                trans_rule.producer,
+                trans_rule.consumers,
+                trans_rule.parameters,
+            )
+        )
+        continue
+      elif check_replace_dq_q_with_rq(producer_trans_rule, trans_rule):
+        for consumer_id in trans_rule.consumers:
+          producer_trans_rule.consumers.remove(consumer_id)
+        transformations.append(
+            qtyping.TransformationInst(
+                qtyping.QuantTransformation.QUANTIZE_TENSOR,
+                trans_rule.tensor_id,
+                trans_rule.producer,
+                trans_rule.consumers,
+                producer_trans_rule.parameters,
+            )
+        )
+        transformations.append(
+            qtyping.TransformationInst(
+                qtyping.QuantTransformation.ADD_QUANTIZE,
+                trans_rule.tensor_id,
+                trans_rule.producer,
+                trans_rule.consumers,
+                trans_rule.parameters,
+            )
+        )
+        continue
+      elif check_dq_no_quant_elimination(producer_trans_rule, trans_rule):
+        for consumer_id in trans_rule.consumers:
+          if consumer_id in producer_trans_rule.consumers:
+            producer_trans_rule.consumers.remove(consumer_id)
+        transformations.append(
+            qtyping.TransformationInst(
+                qtyping.QuantTransformation.ADD_DEQUANTIZE,
+                trans_rule.tensor_id,
+                trans_rule.producer,
+                trans_rule.consumers,
+                producer_trans_rule.parameters,
+            )
+        )
+        continue
+      else:
+        transformations.append(trans_rule)
+    if producer_trans_rule.consumers:
+      transformations.insert(0, producer_trans_rule)
+    return transformations
+
+  def _quant_params_to_transformation_insts(
+      self,
+      param: qtyping.TensorTransformationParams,
+  ) -> qtyping.TensorTransformationInsts:
+    """Converts a single tensor's quantization params to transformation instructions.
+
+    Args:
+      param: Quantization parameters of a tensor in the graph.
+
+    Returns:
+      A list of transformations to be applied to the same tensor.
+    """
+    # Set up the structure.
+    tensor_info = self._tensor_name_to_graph_info[param.tensor_name]
+    tensor_trans_insts = qtyping.TensorTransformationInsts(
+        param.tensor_name, tensor_info.subgraph_id, []
+    )
+
+    # Horizontal optimization.
+    consumer_group = self._group_consumer_transformations(param)
+    # At this point, starting from index 1 of consumer_group, we have sets
+    # that represent transformations that can be grouped together.
+    transformations_available_for_vertical_optimization = (
+        self._produce_transformation_for_vertical_opt(consumer_group, param)
+    )
+    other_consumer_transformations = (
+        self._produce_consumer_transformations_unavailable_for_vertical_opt(
+            consumer_group, param
+        )
+    )
+
+    transformations = []
+    # Add all producer rules.
+    producer_params = param.producer
+    if producer_params:
+      for transformation in producer_params.transformations:
+        transformations.append(
+            qtyping.TransformationInst(
+                transformation,
+                tensor_info.tensor_id,
+                tensor_info.producer,
+                tensor_info.consumers,
+                producer_params.parameters,
+            )
+        )
+
+    # Apply vertical optimization.
+    last_producer_rule_idx = len(transformations) - 1
+    if last_producer_rule_idx >= 0:
+      transformations += self._apply_vertical_optimization(
+          transformations.pop(),
+          transformations_available_for_vertical_optimization,
+      )
+    else:
+      transformations += transformations_available_for_vertical_optimization
+    # Add the other consumer rules.
+    transformations += other_consumer_transformations
+    tensor_trans_insts.instructions = transformations
+    # Check that the generated transformation instructions are valid; the
+    # function will raise an error if the instructions are not valid.
+    self._check_tensor_transformation_instructions_valid(tensor_trans_insts)
+
+    return tensor_trans_insts
+
+  def _check_tensor_transformation_instructions_valid(
+      self, instructions: qtyping.TensorTransformationInsts
+  ):
+    """Check if the tensor transformation instructions are valid.
+
+    Args:
+      instructions: Transformation instructions for a tensor.
+
+    Raises:
+      ValueError: If the instructions are not valid.
+    """
+    is_tensor_unquantized = False
+    is_tensor_quantized = False
+    is_operator_emulated = False
+    for instruction in instructions.instructions:
+      transform_type = instruction.transformation
+      if transform_type == qtyping.QuantTransformation.NO_QUANTIZE:
+        is_tensor_unquantized = True
+      elif (
+          transform_type == qtyping.QuantTransformation.QUANTIZE_TENSOR
+          or transform_type == qtyping.QuantTransformation.ADD_DEQUANTIZE
+      ):
+        is_tensor_quantized = True
+      elif transform_type == qtyping.QuantTransformation.EMULATED_SUBCHANNEL:
+        is_operator_emulated = True
+    if is_tensor_unquantized and is_tensor_quantized:
+      raise ValueError(
+          "Tensor %s can not be both quantized and unquantized"
+          % instructions.tensor_name
+      )
+    if is_operator_emulated and len(instructions.instructions) > 1:
+      raise ValueError(
+          "Tensor %s : op replacement transformation can not be combined with"
+          " other transformations."
+          % instructions.tensor_name
+      )
+
+  def quant_params_to_transformation_insts(
+      self,
+      params: dict[str, qtyping.TensorTransformationParams],
+      flatbuffer_model: Optional[schema_py_generated.ModelT] = None,
+  ) -> dict[str, qtyping.TensorTransformationInsts]:
+    """Converts quantization params to transformation instructions.
+
+    Args:
+      params: Quantization parameters generated by params_generator. The data
+        type is designed to be the same as the output of
+        generate_quantization_parameters.
+      flatbuffer_model: The flatbuffer model to be quantized.
+
+    Returns:
+      A dictionary with tensor name as key and transformation instructions as
+      value.
+    """
+    if flatbuffer_model is not None:
+      self.flatbuffer_model = flatbuffer_model
+      self._create_tensor_name_to_graph_info_map()
+
+    insts = {}
+    for tensor_name in params:
+      insts[tensor_name] = self._quant_params_to_transformation_insts(
+          params[tensor_name]
+      )
+    return insts
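
To place this module in the overall pipeline: params_generator produces per-tensor TensorTransformationParams, this generator collapses them into per-tensor instruction lists (applying the horizontal and vertical optimizations above), and transformation_performer then applies the instructions to the flatbuffer. A rough driver sketch, using only the API shown above and a hypothetical model path; the params dict is left empty here because in practice it comes from params_generator:

    from ai_edge_quantizer import qtyping
    from ai_edge_quantizer import transformation_instruction_generator as tig

    # Hypothetical float TFLite model path.
    generator = tig.TransformationInstructionsGenerator('/tmp/my_model.tflite')

    # Normally produced by params_generator
    # (tensor name -> qtyping.TensorTransformationParams).
    params: dict[str, qtyping.TensorTransformationParams] = {}

    instructions = generator.quant_params_to_transformation_insts(params)
    for tensor_name, tensor_insts in instructions.items():
      for inst in tensor_insts.instructions:
        print(tensor_name, inst.transformation, inst.producer, inst.consumers)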