ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,97 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import os
17
+
18
+ from absl.testing import parameterized
19
+
20
+ from tensorflow.python.platform import googletest
21
+ from ai_edge_quantizer import quantizer
22
+ from ai_edge_quantizer import recipe
23
+ from ai_edge_quantizer.utils import test_utils
24
+
25
+
26
+ _TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('')
27
+
28
+
29
class RecipeTest(parameterized.TestCase):
  """Checks the predefined recipes exposed by the `recipe` module."""

  def setUp(self):
    super().setUp()
    model_relative_path = 'tests/models/single_conv2d_transpose_bias.tflite'
    self._test_model_path = os.path.join(
        _TEST_DATA_PREFIX_PATH, model_relative_path
    )

  def _quantize_with_recipe_func(self, recipe_func):
    """Quantizes the test model with the recipe returned by `recipe_func`.

    Args:
      recipe_func: zero-argument callable returning a quantization recipe.

    Returns:
      The quantization result produced by `Quantizer.quantize()`.
    """
    model_quantizer = quantizer.Quantizer(self._test_model_path)
    model_quantizer.load_quantization_recipe(recipe_func())
    # Loading a recipe alone must not have produced a quantized model yet.
    self.assertIsNone(model_quantizer._result.quantized_model)
    result = model_quantizer.quantize()
    self.assertIsNotNone(result.quantized_model)
    return result

  def test_quantization_from_dynamic_wi8_afp32_func_succeeds(self):
    result = self._quantize_with_recipe_func(recipe.dynamic_wi8_afp32)
    original_size = os.path.getsize(self._test_model_path)
    self.assertLess(len(result.quantized_model), original_size)

  def test_quantization_from_dynamic_legacy_wi8_afp32_func_succeeds(self):
    result = self._quantize_with_recipe_func(recipe.dynamic_legacy_wi8_afp32)
    original_size = os.path.getsize(self._test_model_path)
    self.assertLen(result.quantized_model, original_size)

  @parameterized.named_parameters(
      {
          'testcase_name': 'dynamic_wi8_afp32',
          'recipe_json_path': 'recipes/dynamic_wi8_afp32_recipe.json',
          'recipe_func': recipe.dynamic_wi8_afp32,
      },
      {
          'testcase_name': 'dynamic_legacy_wi8_afp32',
          'recipe_json_path': 'recipes/dynamic_legacy_wi8_afp32_recipe.json',
          'recipe_func': recipe.dynamic_legacy_wi8_afp32,
      },
  )
  def test_recipe_func_and_json_matches(self, recipe_json_path, recipe_func):
    # Reference result: recipe built by the function in the recipe module.
    result_from_func = self._quantize_with_recipe_func(recipe_func)

    # Same model, but the recipe is loaded from its json file.
    json_quantizer = quantizer.Quantizer(self._test_model_path)
    json_quantizer.load_quantization_recipe(
        os.path.join(_TEST_DATA_PREFIX_PATH, recipe_json_path)
    )
    result_from_json = json_quantizer.quantize()
    self.assertIsNotNone(result_from_json.quantized_model)

    # Both routes must yield the same recipe and equally sized models.
    self.assertEqual(result_from_func.recipe, result_from_json.recipe)
    self.assertEqual(
        len(result_from_func.quantized_model),
        len(result_from_json.quantized_model),
    )
94
+
95
+
96
+ if __name__ == '__main__':
97
+ googletest.main()
@@ -0,0 +1,584 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Create transformation instructions for transformation_performer.
17
+
18
+ Given quantization parameters, create a list of transformation instructions that
19
+ can then be used by transformation_performer. Includes necessary optimizations.
20
+ """
21
+
22
+ from collections.abc import Iterator
23
+ import dataclasses
24
+ from typing import Optional
25
+ from ai_edge_quantizer import qtyping
26
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
27
+ from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
28
+
29
+
30
+ # When a tensor has no producer, we'll assign -1 to the producer field
31
+ # When a tensor is a graph output, we'll also include a -1 in the consumer list
32
def check_horizontal_optimization(
    param1: qtyping.OpToTensorParams,
    param2: qtyping.OpToTensorParams,
    index: int,
) -> bool:
  """Tell whether two consumers can share the transformation at `index`.

  The transformations found at the same position of two different
  OpToTensorParams can be merged when both params carry equal quantization
  parameters and both define an equal transformation at that position.

  Args:
    param1: first parameters to be compared.
    param2: second parameters to be compared.
    index: position in the transformation lists to be compared.

  Returns:
    True if the two transformations can be merged, False otherwise.
  """
  same_quant_params = param1.parameters == param2.parameters
  if not same_quant_params:
    return False
  # Both sides must actually define a transformation at this position.
  if index >= len(param1.transformations):
    return False
  if index >= len(param2.transformations):
    return False
  return param1.transformations[index] == param2.transformations[index]
57
+
58
+
59
def check_dq_q_elimination(
    producer_inst: qtyping.TransformationInst,
    consumer_inst: qtyping.TransformationInst,
) -> bool:
  """Tell whether a dequantize/quantize pair cancels out.

  The pair can be eliminated only when the producer adds a dequantize, the
  consumer adds a quantize, and both use the same quantization parameters.

  Args:
    producer_inst: TransformationInst from producer.
    consumer_inst: TransformationInst from consumer.

  Returns:
    True if dequantize & quantize can be eliminated, False otherwise.
  """
  observed_pair = (producer_inst.transformation, consumer_inst.transformation)
  dq_then_q = (
      qtyping.QuantTransformation.ADD_DEQUANTIZE,
      qtyping.QuantTransformation.ADD_QUANTIZE,
  )
  params_match = producer_inst.parameters == consumer_inst.parameters
  return observed_pair == dq_then_q and params_match
88
+
89
+
90
def check_replace_dq_q_with_rq(
    producer_inst: qtyping.TransformationInst,
    consumer_inst: qtyping.TransformationInst,
) -> bool:
  """Tell whether a dequantize/quantize pair can become a requantize.

  This requires the producer to add a dequantize and the consumer to add a
  quantize, with quantization parameters that differ between the two.

  Args:
    producer_inst: TransformationInst from producer.
    consumer_inst: TransformationInst from consumer.

  Returns:
    True if dequantize & quantize can be replaced, False otherwise. A pair
    with identical parameters (which can be eliminated outright) is reported
    as False here.
  """
  observed_pair = (producer_inst.transformation, consumer_inst.transformation)
  dq_then_q = (
      qtyping.QuantTransformation.ADD_DEQUANTIZE,
      qtyping.QuantTransformation.ADD_QUANTIZE,
  )
  params_match = producer_inst.parameters == consumer_inst.parameters
  return observed_pair == dq_then_q and not params_match
121
+
122
+
123
def check_dq_no_quant_elimination(
    producer_inst: qtyping.TransformationInst,
    consumer_inst: qtyping.TransformationInst,
) -> bool:
  """Tell whether a dequantize/no-quantize pair can be eliminated.

  This is the case when the producer adds a dequantize and the consumer
  requests no quantization.

  Args:
    producer_inst: TransformationInst from producer.
    consumer_inst: TransformationInst from consumer.

  Returns:
    True if dequantize & no quantize can be eliminated, False otherwise.
  """
  observed_pair = (producer_inst.transformation, consumer_inst.transformation)
  dq_then_no_quant = (
      qtyping.QuantTransformation.ADD_DEQUANTIZE,
      qtyping.QuantTransformation.NO_QUANTIZE,
  )
  return observed_pair == dq_then_no_quant
146
+
147
+
148
+ class TransformationInstructionsGenerator:
149
+ """Generates transformation instructions from tensor quant params."""
150
+
151
def __init__(self, float_tflite: Optional[str] = None):
  """Constructor.

  When a model is given it is read right away and the tensor name lookup
  table is built from it; otherwise the generator starts out empty.

  Args:
    float_tflite: the original TFlite model in bytearray or file path.
  """
  if float_tflite is not None:
    self.flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
    self._create_tensor_name_to_graph_info_map()
    return

  # No model yet: start with an empty lookup table and a placeholder model.
  self._tensor_name_to_graph_info: dict[
      str, TransformationInstructionsGenerator.TensorGraphInfo
  ] = {}
  self.flatbuffer_model: schema_py_generated.ModelT = ()
165
+
166
@dataclasses.dataclass(frozen=True)
class TensorGraphInfo:
  """Position of a tensor in the model graph and the ops attached to it."""

  # Index of the tensor inside its subgraph.
  tensor_id: int
  # Index of the subgraph holding the tensor.
  subgraph_id: int
  # Index of the op producing the tensor, -1 when the tensor has no producer.
  producer: int
  # Indices of the ops consuming the tensor; a -1 entry marks a graph output.
  consumers: list[int]
172
+
173
def _tensor_info_generator(
    self, subgraph_id: int, subgraph: schema_py_generated.SubGraphT
) -> Iterator[tuple[str, TensorGraphInfo]]:
  """Generator function for tensor info.

  The operators are indexed once up front, so producing the info of every
  tensor costs a single pass over the operators instead of one pass per
  tensor.

  Args:
    subgraph_id: Index for the given subgraph.
    subgraph: Subgraph struct to generate tensor info on.

  Yields:
    A tuple of tensor_name and TensorGraphInfo. The producer is -1 when no
    operator outputs the tensor, and the consumer list starts with -1 when
    the tensor is a graph output.
  """

  def as_ids(vector) -> list[int]:
    # An absent vector is treated as empty rather than raising on iteration.
    return [] if vector is None else [int(item) for item in vector]

  consumers_by_tensor: dict[int, list[int]] = {}
  producer_by_tensor: dict[int, int] = {}
  operators = [] if subgraph.operators is None else subgraph.operators
  for op_id, op in enumerate(operators):
    # De-duplicate so an op reading the same tensor twice is listed once.
    for input_id in set(as_ids(op.inputs)):
      consumers_by_tensor.setdefault(input_id, []).append(op_id)
    # Only the first op that outputs a tensor is recorded as its producer.
    for output_id in as_ids(op.outputs):
      producer_by_tensor.setdefault(output_id, op_id)
  graph_outputs = set(as_ids(subgraph.outputs))

  tensors = [] if subgraph.tensors is None else subgraph.tensors
  for tensor_id, tensor in enumerate(tensors):
    consumers = consumers_by_tensor.get(tensor_id, [])
    if tensor_id in graph_outputs:
      consumers.insert(0, -1)
    tensor_info = self.TensorGraphInfo(
        tensor_id,
        subgraph_id,
        producer_by_tensor.get(tensor_id, -1),
        consumers,
    )
    yield tfl_flatbuffer_utils.get_tensor_name(tensor), tensor_info
203
+
204
def _create_tensor_name_to_graph_info_map(self):
  """Rebuild the tensor name -> TensorGraphInfo lookup from the model.

  Any previous mapping is discarded. When the same tensor name shows up more
  than once, the entry generated last wins.
  """
  name_to_info = {}
  self._tensor_name_to_graph_info = name_to_info
  # TODO: b/333607428 - support graph input & output
  for subgraph_id, subgraph in enumerate(self.flatbuffer_model.subgraphs):
    name_to_info.update(self._tensor_info_generator(subgraph_id, subgraph))
213
+
214
def _group_consumer_transformations(
    self, param: qtyping.TensorTransformationParams
) -> list[list[set[int]]]:
  """Group transformations between consumers into common groups.

  Args:
    param: TensorTransformationParams for a tensor.

  Returns:
    A list (indexed by depth) of lists of sets, where each set holds the
    indices of the consumers whose transformations can be merged
    horizontally at that depth. An empty list is returned when there is no
    param or no consumer.
    E.g:
      For the following consumers:
        [(1, [ADD_QUANTIZE, ADD_DEQUANTIZE], param1),
        (2, [ADD_QUANTIZE], param2),
        (3, [ADD_QUANTIZE], param1)]
      this function returns:
        [[{1, 2, 3}],
        [{1, 3}, {2}],
        [{1}]]

      Depth 0 is the initial state: every consumer comes from the same
      producer, so they all share one group.
      In depth 1, the ADD_QUANTIZE in 1 & 3 can be merged, so they are in the
      same group.
      In depth 2, only consumer 1 still has a transformation, so there is a
      single group holding 1.
  """
  if not param or not param.consumers:
    return []

  consumers = param.consumers
  # Depth 0: one group holding every consumer index.
  groups_by_depth = [[set(range(len(consumers)))]]
  # The max number of transformations applied before a particular consumer.
  deepest_chain = max(len(consumer.transformations) for consumer in consumers)

  for depth in range(deepest_chain):
    groups_at_next_depth = []
    for consumer_idx, consumer in enumerate(consumers):
      if len(consumer.transformations) <= depth:
        # This consumer has no transformation at the current depth.
        continue
      for parent_group in groups_by_depth[depth]:
        if consumer_idx not in parent_group:
          continue
        for candidate_group in groups_at_next_depth:
          # Any member can stand for the group since they all share the same
          # quantization.
          representative = next(iter(candidate_group))
          if representative in parent_group and check_horizontal_optimization(
              consumers[representative], consumer, depth
          ):
            candidate_group.add(consumer_idx)
            break
        else:
          # No existing group can take this consumer: open a new one.
          groups_at_next_depth.append({consumer_idx})
    groups_by_depth.append(groups_at_next_depth)
  return groups_by_depth
286
+
287
+ def _produce_transformation_for_vertical_opt(
288
+ self,
289
+ consumer_group: list[list[set[int]]],
290
+ param: qtyping.TensorTransformationParams,
291
+ ) -> list[qtyping.TransformationInst]:
292
+ """Create a list of transformation rules available for vertical optimization.
293
+
294
+ A consumer transformation is available to vertical transformation IFF it's
295
+ the first transformation for a given consumer.
296
+
297
+ This function relies on the consumer_group argument already being optimized
298
+ for horizontal transformations.
299
+
300
+ Args:
301
+ consumer_group: a list of grouped indices for consumer transformationns
302
+ param: a TensorTransformationParams for the tensor
303
+
304
+ Returns:
305
+ A list of transformation rules available for vertical optimization
306
+ """
307
+ tensor_info = self._tensor_name_to_graph_info[param.tensor_name]
308
+ transformations_available_for_vertical_optimization = []
309
+ # we start at 1 because consumer groups in index 0 is the inital state
310
+ # and does not contain actual information
311
+ if len(consumer_group) > 1:
312
+ for group in consumer_group[1]:
313
+ op_list = list(group)
314
+ op_idx_list = []
315
+ for index in op_list:
316
+ op_idx_list.append(param.consumers[index].subgraph_op_id)
317
+ transformations_available_for_vertical_optimization.append(
318
+ qtyping.TransformationInst(
319
+ param.consumers[op_list[0]].transformations[0],
320
+ tensor_info.tensor_id,
321
+ tensor_info.producer,
322
+ op_idx_list,
323
+ param.consumers[op_list[0]].parameters,
324
+ )
325
+ )
326
+ return transformations_available_for_vertical_optimization
327
+
328
def _produce_consumer_transformations_unavailable_for_vertical_opt(
    self,
    consumer_group: list[list[set[int]]],
    param: qtyping.TensorTransformationParams,
) -> list[qtyping.TransformationInst]:
  """Produce the consumer transformations unusable for vertical optimization.

  A consumer transformation is available to vertical optimization if and
  only if it is the first transformation for a given consumer; this function
  emits all the later ones.

  This function relies on the consumer_group argument already being
  optimized for horizontal transformations.

  Args:
    consumer_group: a list of grouped indices for consumer transformations.
    param: a TensorTransformationParams for the tensor.

  Returns:
    A list of transformation rules unavailable for vertical optimization.
  """
  tensor_info = self._tensor_name_to_graph_info[param.tensor_name]
  later_rules = []
  for depth in range(2, len(consumer_group)):
    # Groups at depth d describe the transformation stored at position d - 1.
    transformation_pos = depth - 1
    for group in consumer_group[depth]:
      member_indices = list(group)
      representative = param.consumers[member_indices[0]]
      if len(representative.transformations) <= transformation_pos:
        continue
      later_rules.append(
          qtyping.TransformationInst(
              representative.transformations[transformation_pos],
              tensor_info.tensor_id,
              tensor_info.producer,
              [param.consumers[i].subgraph_op_id for i in member_indices],
              representative.parameters,
          )
      )
  return later_rules
373
+
374
def _apply_vertical_optimization(
    self,
    producer_trans_rule: qtyping.TransformationInst,
    consumer_trans_rules: list[qtyping.TransformationInst],
) -> list[qtyping.TransformationInst]:
  """Apply vertical optimization.

  Three producer/consumer patterns are rewritten:
    1. DQ & Q with the same parameters: the operators are eliminated and only
       the tensor is quantized (QUANTIZE_TENSOR).
    2. DQ & Q with different parameters: the pair is replaced by
       QUANTIZE_TENSOR using the producer parameters followed by ADD_QUANTIZE
       using the consumer parameters.
    3. DQ & NO_QUANTIZE: replaced by an ADD_DEQUANTIZE using the producer
       parameters for the affected consumers.
  Any other consumer rule is kept untouched.

  Vertical optimization can only happen between the last producer rule and
  the consumer rules that come first for their consumers.

  Note that `producer_trans_rule.consumers` is modified in place: every
  consumer covered by a rewritten rule is removed from it.

  Args:
    producer_trans_rule: the last producer transformation rule.
    consumer_trans_rules: a list of consumer transformation rules that are
      available for vertical transformations.

  Returns:
    A list of transformations after vertical optimization has been applied.
    The producer transformation comes first, and is only included when it
    still has consumers left.
  """

  def detach_from_producer(consumer_ids: list[int]) -> None:
    # A consumer may be missing from the producer rule, so check membership
    # first instead of letting list.remove raise ValueError.
    for consumer_id in consumer_ids:
      if consumer_id in producer_trans_rule.consumers:
        producer_trans_rule.consumers.remove(consumer_id)

  def rewrite(transformation, consumer_rule, parameters):
    # Same tensor/producer/consumers as the consumer rule, new operation.
    return qtyping.TransformationInst(
        transformation,
        consumer_rule.tensor_id,
        consumer_rule.producer,
        consumer_rule.consumers,
        parameters,
    )

  transformations = []
  for trans_rule in consumer_trans_rules:
    if check_dq_q_elimination(producer_trans_rule, trans_rule):
      detach_from_producer(trans_rule.consumers)
      transformations.append(
          rewrite(
              qtyping.QuantTransformation.QUANTIZE_TENSOR,
              trans_rule,
              trans_rule.parameters,
          )
      )
    elif check_replace_dq_q_with_rq(producer_trans_rule, trans_rule):
      detach_from_producer(trans_rule.consumers)
      transformations.append(
          rewrite(
              qtyping.QuantTransformation.QUANTIZE_TENSOR,
              trans_rule,
              producer_trans_rule.parameters,
          )
      )
      transformations.append(
          rewrite(
              qtyping.QuantTransformation.ADD_QUANTIZE,
              trans_rule,
              trans_rule.parameters,
          )
      )
    elif check_dq_no_quant_elimination(producer_trans_rule, trans_rule):
      detach_from_producer(trans_rule.consumers)
      transformations.append(
          rewrite(
              qtyping.QuantTransformation.ADD_DEQUANTIZE,
              trans_rule,
              producer_trans_rule.parameters,
          )
      )
    else:
      transformations.append(trans_rule)

  # The producer rule is only needed while some consumer still relies on it.
  if producer_trans_rule.consumers:
    transformations.insert(0, producer_trans_rule)
  return transformations
456
+
457
def _quant_params_to_transformation_insts(
    self,
    param: qtyping.TensorTransformationParams,
) -> qtyping.TensorTransformationInsts:
  """Converts a single quantization params to transformation instructions.

  Args:
    param: quantization parameter of a tensor in the graph.

  Returns:
    The transformation instructions to be applied to the tensor: producer
    rules first (the last one vertically optimized against the first consumer
    rules), followed by the remaining consumer rules.

  Raises:
    KeyError: If the tensor name is not part of the loaded model.
    ValueError: If the generated instructions are not valid.
  """
  # Setup the structure.
  tensor_info = self._tensor_name_to_graph_info[param.tensor_name]
  tensor_trans_insts = qtyping.TensorTransformationInsts(
      param.tensor_name, tensor_info.subgraph_id, []
  )

  # Horizontal optimization: from index 1 on, consumer_group holds sets of
  # consumers whose transformations can be grouped together.
  consumer_group = self._group_consumer_transformations(param)
  transformations_available_for_vertical_optimization = (
      self._produce_transformation_for_vertical_opt(consumer_group, param)
  )
  other_consumer_transformations = (
      self._produce_consumer_transformations_unavailable_for_vertical_opt(
          consumer_group, param
      )
  )

  # Adding all producer rules. Vertical optimization removes entries from the
  # consumer list of the producer rule in place, so work on a copy: handing
  # out tensor_info.consumers itself would corrupt the cached graph info for
  # any later call on the same tensor. The producer rules of this call keep
  # sharing one list, exactly as before.
  transformations = []
  producer_params = param.producer
  if producer_params:
    producer_consumers = list(tensor_info.consumers)
    for transformation in producer_params.transformations:
      transformations.append(
          qtyping.TransformationInst(
              transformation,
              tensor_info.tensor_id,
              tensor_info.producer,
              producer_consumers,
              producer_params.parameters,
          )
      )

  # Apply vertical optimization between the last producer rule (if any) and
  # the first consumer rules.
  if transformations:
    last_producer_rule = transformations.pop()
    transformations += self._apply_vertical_optimization(
        last_producer_rule,
        transformations_available_for_vertical_optimization,
    )
  else:
    transformations += transformations_available_for_vertical_optimization
  # Adding other consumers rules.
  transformations += other_consumer_transformations
  tensor_trans_insts.instructions = transformations
  # Check the generated transformation instructions are valid, the function
  # will raise an error if the instructions are not valid.
  self._check_tensor_transformation_instructions_valid(tensor_trans_insts)

  return tensor_trans_insts
520
+
521
def _check_tensor_transformation_instructions_valid(
    self, instructions: qtyping.TensorTransformationInsts
):
  """Check if the tensor transformation instructions are valid.

  Two combinations are rejected: a tensor that is at once left unquantized
  (NO_QUANTIZE) and quantized (QUANTIZE_TENSOR or ADD_DEQUANTIZE), and an
  EMULATED_SUBCHANNEL instruction that is not the only instruction.

  Args:
    instructions: Transformation instructions for a tensor.

  Raises:
    ValueError: If the instructions are not valid.
  """
  quant = qtyping.QuantTransformation
  requested = [inst.transformation for inst in instructions.instructions]

  is_tensor_unquantized = quant.NO_QUANTIZE in requested
  is_tensor_quantized = (
      quant.QUANTIZE_TENSOR in requested or quant.ADD_DEQUANTIZE in requested
  )
  if is_tensor_unquantized and is_tensor_quantized:
    raise ValueError(
        "Tensor %s can not be both quantized and unquantized"
        % instructions.tensor_name
    )

  is_operator_emulated = quant.EMULATED_SUBCHANNEL in requested
  if is_operator_emulated and len(instructions.instructions) > 1:
    raise ValueError(
        "Tensor %s : op replacement transformation can not be combined with"
        " other transformations."
        % instructions.tensor_name
    )
557
+
558
def quant_params_to_transformation_insts(
    self,
    params: dict[str, qtyping.TensorTransformationParams],
    flatbuffer_model: Optional[schema_py_generated.ModelT] = None,
) -> dict[str, qtyping.TensorTransformationInsts]:
  """Converts quantization params to transformation instructions.

  Args:
    params: quantization parameters generated by params_generator. The data
      type is designed to be the same as the output of
      generate_quantization_parameters.
    flatbuffer_model: the flatbuffer model to be quantized. When given, it
      replaces the current model and the tensor lookup table is rebuilt.

  Returns:
    a dictionary with tensor name as key and transformation instructions as
    value
  """
  if flatbuffer_model is not None:
    self.flatbuffer_model = flatbuffer_model
    self._create_tensor_name_to_graph_info_map()

  return {
      tensor_name: self._quant_params_to_transformation_insts(tensor_params)
      for tensor_name, tensor_params in params.items()
  }