ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
ai_edge_quantizer/transformations/emulated_subchannel.py
@@ -0,0 +1,363 @@
+ # Copyright 2024 The AI Edge Quantizer Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Transformation pattern for emulated subchannel quantization."""
+
+ from typing import cast
+ import numpy as np
+ from ai_edge_quantizer import qtyping
+ from ai_edge_quantizer.transformations import quantize_tensor
+ from ai_edge_quantizer.transformations import transformation_utils
+ from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
+
+
+ def emulated_subchannel(
+     transformation_input: transformation_utils.TransformationInput,
+ ) -> qtyping.TransformationInfo:
+   """Emulated subchannel quantization for the fully_connected op.
+
+   The input tensor must also be the weight tensor of the fully_connected op.
+
+   After the transformation, the fully_connected op is replaced by:
+   reshape -> batch_matmul -> mul -> sum -> add (if a bias is present) ->
+   activation (if a fused activation function exists; only RELU is supported
+   for now)
+
+   Args:
+     transformation_input: input structure that contains all information needed
+       for the transformation.
+
+   Returns:
+     The transformation info.
+   """
+   # Only apply to a single fully_connected op.
+   if len(transformation_input.consumers) > 1:
+     raise ValueError('Emulated Subchannel transformation only supports one op')
+   if isinstance(
+       transformation_input.quant_params, qtyping.NonLinearQuantParams
+   ):
+     raise ValueError(
+         'Emulated Subchannel transformation only supports uniform quantization'
+     )
+   if (
+       transformation_input.op_codes[
+           transformation_input.subgraph.operators[
+               transformation_input.consumers[0]
+           ].opcodeIndex
+       ].builtinCode
+       != schema_py_generated.BuiltinOperator.FULLY_CONNECTED
+   ):
+     raise ValueError(
+         'Emulated Subchannel transformation only supports fully_connected op'
+     )
+   if transformation_input.producer != -1:
+     raise ValueError(
+         'Emulated Subchannel transformation only supports constant tensors'
+     )
+
+   # Insert all the necessary op codes into the model.
+   reshape_op_code_idx = transformation_utils.add_op_code(
+       schema_py_generated.BuiltinOperator.RESHAPE, transformation_input.op_codes
+   )
+   bmm_op_code_idx = transformation_utils.add_op_code(
+       schema_py_generated.BuiltinOperator.BATCH_MATMUL,
+       transformation_input.op_codes,
+   )
+   mul_op_code_idx = transformation_utils.add_op_code(
+       schema_py_generated.BuiltinOperator.MUL, transformation_input.op_codes
+   )
+   sum_op_code_idx = transformation_utils.add_op_code(
+       schema_py_generated.BuiltinOperator.SUM, transformation_input.op_codes
+   )
+
+   original_fc_op_idx = transformation_input.consumers[0]
+   if cast(
+       schema_py_generated.FullyConnectedOptionsT,
+       transformation_input.subgraph.operators[
+           original_fc_op_idx
+       ].builtinOptions,
+   ).fusedActivationFunction not in (
+       schema_py_generated.ActivationFunctionType.NONE,
+       schema_py_generated.ActivationFunctionType.RELU,
+   ):
+     raise ValueError(
+         'Emulated Subchannel transformation only supports'
+         ' fusedActivationFunction NONE and RELU for now'
+     )
+
+   weight_tensor = transformation_input.subgraph.tensors[
+       transformation_input.tensor_id
+   ]
+   weight_tensor.type = quantize_tensor.quant_params_to_tflite_type(
+       transformation_input.quant_params.num_bits
+   )
+
+   # Modify the weight tensor with the correct quantization parameters.
+   transformation_input.buffers[weight_tensor.buffer].data = np.frombuffer(
+       cast(
+           np.ndarray, transformation_input.quant_params.quantized_data
+       ).tobytes(),
+       dtype=np.uint8,
+   )
+   weight_tensor.shape = cast(
+       np.ndarray, transformation_input.quant_params.quantized_data
+   ).shape
+   weight_tensor.quantization.scale = np.ones(shape=[1], dtype=np.float32)
+   weight_tensor.quantization.zeroPoint = np.zeros(
+       shape=[1], dtype=np.int64
+   ).flatten()
+
+   # Assuming the zero point is 0, so no need to add a zero point tensor.
+   for val in transformation_input.quant_params.zero_point.flatten():
+     if val != 0:
+       raise ValueError(
+           'Emulated Subchannel transformation only supports zero point 0 for now'
+       )
+
+   scale_tensor_id = transformation_utils.add_new_constant_tensor(
+       weight_tensor.name + b'_scale',
+       transformation_input.quant_params.scale,
+       schema_py_generated.TensorType.FLOAT32,
+       transformation_input.subgraph,
+       transformation_input.buffers,
+   )
+
+   # For the fully_connected op, the reduce axis is always 1.
+   reduce_axes_data = np.array([1], dtype=np.int32)
+   reduce_axes_tensor_id = transformation_utils.add_new_constant_tensor(
+       weight_tensor.name + b'_reduce_axes',
+       reduce_axes_data,
+       schema_py_generated.TensorType.INT32,
+       transformation_input.subgraph,
+       transformation_input.buffers,
+   )
+
+   # Find the input and output tensors of the fully_connected op.
+   activation_input_id = transformation_input.subgraph.operators[
+       transformation_input.consumers[0]
+   ].inputs[0]
+   activation_output_id = transformation_input.subgraph.operators[
+       transformation_input.consumers[0]
+   ].outputs[0]
+   activation_input = transformation_input.subgraph.tensors[activation_input_id]
+   activation_output = transformation_input.subgraph.tensors[
+       activation_output_id
+   ]
+
+   if len(activation_input.shape) != 3:
+     raise ValueError(
+         'Emulated Subchannel transformation only supports 3D input tensors'
+     )
+   bmm_input_shape = [
+       activation_input.shape[0] * activation_input.shape[1],
+       weight_tensor.shape[1],
+       1,
+       weight_tensor.shape[2],
+   ]
+   intermediate_tensor_shape = [
+       activation_input.shape[0] * activation_input.shape[1],
+       weight_tensor.shape[1],
+       1,
+       weight_tensor.shape[3],
+   ]
+   sum_output_shape = [
+       activation_input.shape[0] * activation_input.shape[1],
+       1,
+       1,
+       weight_tensor.shape[3],
+   ]
+
+   # Create constant tensors for the reshape ops.
+   reshape1_shape_id = transformation_utils.add_new_constant_tensor(
+       activation_output.name + b'_reshape_op1_shape',
+       np.array(bmm_input_shape, dtype=np.int32),
+       schema_py_generated.TensorType.INT32,
+       transformation_input.subgraph,
+       transformation_input.buffers,
+   )
+   reshape2_shape_id = transformation_utils.add_new_constant_tensor(
+       activation_output.name + b'_reshape_op2_shape',
+       np.array(activation_output.shape, dtype=np.int32),
+       schema_py_generated.TensorType.INT32,
+       transformation_input.subgraph,
+       transformation_input.buffers,
+   )
+
+   # Create all intermediate tensors.
+   bmm_input_id = transformation_utils.add_new_activation_tensor(
+       activation_output.name + b'_bmm_input',
+       bmm_input_shape,
+       schema_py_generated.TensorType.FLOAT32,
+       transformation_input.subgraph,
+   )
+   mul_input_id = transformation_utils.add_new_activation_tensor(
+       activation_output.name + b'_mul_input',
+       intermediate_tensor_shape,
+       schema_py_generated.TensorType.FLOAT32,
+       transformation_input.subgraph,
+   )
+   sum_input_id = transformation_utils.add_new_activation_tensor(
+       activation_output.name + b'_reduce_sum_input',
+       intermediate_tensor_shape,
+       schema_py_generated.TensorType.FLOAT32,
+       transformation_input.subgraph,
+   )
+   reshape_op2_input_id = transformation_utils.add_new_activation_tensor(
+       activation_output.name + b'_reshape_op2_input',
+       sum_output_shape,
+       schema_py_generated.TensorType.FLOAT32,
+       transformation_input.subgraph,
+   )
+
+   # reshape
+   reshape_op1 = schema_py_generated.OperatorT()
+   reshape_op1.opcodeIndex = reshape_op_code_idx
+   reshape_op1_option = schema_py_generated.ReshapeOptionsT()
+   reshape_op1_option.newShape = bmm_input_shape
+   reshape_op1.inputs = [activation_input_id, reshape1_shape_id]
+   reshape_op1.outputs = [bmm_input_id]
+   reshape_op1.builtinOptionsType = (
+       schema_py_generated.BuiltinOptions.ReshapeOptions
+   )  # reshape option index
+   reshape_op1.builtinOptions = reshape_op1_option
+
+   # batch_matmul
+   bmm_op = schema_py_generated.OperatorT()
+   bmm_op.opcodeIndex = bmm_op_code_idx
+   bmm_op.inputs = [bmm_input_id, transformation_input.tensor_id]
+   bmm_op.outputs = [mul_input_id]
+   bmm_op.builtinOptionsType = (
+       schema_py_generated.BuiltinOptions.BatchMatMulOptions
+   )
+   bmm_op.builtinOptions = schema_py_generated.BatchMatMulOptionsT()
+
+   # mul
+   mul_op = schema_py_generated.OperatorT()
+   mul_op.opcodeIndex = mul_op_code_idx
+   mul_option = schema_py_generated.MulOptionsT()
+   mul_option.fusedActivationFunction = (
+       schema_py_generated.ActivationFunctionType.NONE
+   )
+   mul_op.inputs = [mul_input_id, scale_tensor_id]
+   mul_op.outputs = [sum_input_id]
+   mul_op.builtinOptionsType = schema_py_generated.BuiltinOptions.MulOptions
+   mul_op.builtinOptions = mul_option
+
+   # sum
+   sum_op = schema_py_generated.OperatorT()
+   sum_op.opcodeIndex = sum_op_code_idx
+   sum_op.inputs = [sum_input_id, reduce_axes_tensor_id]
+   sum_op.outputs = [reshape_op2_input_id]
+   sum_op.builtinOptionsType = schema_py_generated.BuiltinOptions.ReducerOptions
+   sum_op.builtinOptions = schema_py_generated.ReducerOptionsT()
+   sum_op.builtinOptions.keepDims = True
+
+   # reshape
+   reshape_op2 = schema_py_generated.OperatorT()
+   reshape_op2.opcodeIndex = reshape_op_code_idx
+   reshape_op2_option = schema_py_generated.ReshapeOptionsT()
+   reshape_op2_option.newShape = activation_output.shape
+   reshape_op2.inputs = [reshape_op2_input_id, reshape2_shape_id]
+   reshape_op2.outputs = [activation_output_id]
+   reshape_op2.builtinOptionsType = (
+       schema_py_generated.BuiltinOptions.ReshapeOptions
+   )
+   reshape_op2.builtinOptions = reshape_op2_option
+
+   transformation_input.subgraph.operators.insert(
+       original_fc_op_idx, reshape_op1
+   )
+   transformation_input.subgraph.operators.insert(original_fc_op_idx + 1, bmm_op)
+   transformation_input.subgraph.operators.insert(original_fc_op_idx + 2, mul_op)
+   transformation_input.subgraph.operators.insert(original_fc_op_idx + 3, sum_op)
+   transformation_input.subgraph.operators.insert(
+       original_fc_op_idx + 4, reshape_op2
+   )
+   ops_added = 5
+   last_op = reshape_op2
+
+   # If there is a bias tensor (the third input to the original fc op),
+   # we need an add op to process it. The current fc op id needs to be
+   # recalculated because we added operators in front of it.
+   current_fc_op_id = original_fc_op_idx + ops_added
+   if (
+       len(transformation_input.subgraph.operators[current_fc_op_id].inputs) > 2
+       and transformation_input.subgraph.operators[current_fc_op_id].inputs[2]
+       != -1
+   ):
+     add_op_code_idx = transformation_utils.add_op_code(
+         schema_py_generated.BuiltinOperator.ADD, transformation_input.op_codes
+     )
+     reshape_op2_output_id = transformation_utils.add_new_activation_tensor(
+         activation_output.name + b'_reshape_op2_output',
+         activation_output.shape,
+         schema_py_generated.TensorType.FLOAT32,
+         transformation_input.subgraph,
+     )
+     last_op.outputs = [reshape_op2_output_id]
+     add_op = schema_py_generated.OperatorT()
+     add_op.opcodeIndex = add_op_code_idx
+     add_option = schema_py_generated.AddOptionsT()
+     add_op.builtinOptionsType = schema_py_generated.BuiltinOptions.AddOptions
+     add_op.builtinOptions = add_option
+     add_op.inputs = [
+         reshape_op2_output_id,
+         transformation_input.subgraph.operators[
+             original_fc_op_idx + ops_added
+         ].inputs[2],
+     ]
+     add_op.outputs = [activation_output_id]
+     transformation_input.subgraph.operators.insert(
+         original_fc_op_idx + ops_added, add_op
+     )
+     ops_added += 1
+     last_op = add_op
+
+   # If the fused activation function is RELU, we need to add a relu op.
+   # The current fc op id needs to be recalculated because we added operators
+   # in front of it.
+   fc_fused_activation_function = cast(
+       schema_py_generated.FullyConnectedOptionsT,
+       transformation_input.subgraph.operators[
+           original_fc_op_idx + ops_added
+       ].builtinOptions,
+   ).fusedActivationFunction
+   if (
+       fc_fused_activation_function
+       == schema_py_generated.ActivationFunctionType.RELU
+   ):
+     activation_output.name += b'_relu'
+     relu_input_id = transformation_utils.add_new_activation_tensor(
+         activation_output.name + b'_relu_input',
+         activation_output.shape,
+         schema_py_generated.TensorType.FLOAT32,
+         transformation_input.subgraph,
+     )
+     last_op.outputs = [relu_input_id]
+     relu_op = schema_py_generated.OperatorT()
+     relu_op.opcodeIndex = transformation_utils.add_op_code(
+         schema_py_generated.BuiltinOperator.RELU, transformation_input.op_codes
+     )
+     relu_op.inputs = [relu_input_id]
+     relu_op.outputs = [activation_output_id]
+     transformation_input.subgraph.operators.insert(
+         original_fc_op_idx + ops_added, relu_op
+     )
+     ops_added += 1
+     last_op = relu_op
+
+   # Delete the original fully_connected op; the net number of ops added is
+   # therefore ops_added - 1.
+   del transformation_input.subgraph.operators[original_fc_op_idx + ops_added]
+   return qtyping.TransformationInfo(
+       original_fc_op_idx, ops_added - 1, activation_output_id
+   )
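
The decomposition above is easiest to see numerically. Below is a minimal NumPy sketch (illustrative only, not part of the package; the shapes and variable names are assumptions inferred from the code above) of the arithmetic the transformation encodes: the quantized weight is laid out as [1, num_blocks, block_size, out_features] and the per-block scale as [1, num_blocks, 1, out_features], so multiplying by the scale after the batch matmul and summing over the block axis reproduces a float fully_connected.

```python
import numpy as np

# Illustrative shapes; these names are assumptions, not package API.
batch, seq, num_blocks, block_size, out_features = 2, 3, 4, 8, 5
rng = np.random.default_rng(0)

x = rng.normal(size=(batch, seq, num_blocks * block_size)).astype(np.float32)
q_weight = rng.integers(
    -127, 128, size=(1, num_blocks, block_size, out_features)
).astype(np.int8)
scale = rng.uniform(
    0.01, 0.1, size=(1, num_blocks, 1, out_features)
).astype(np.float32)

# Reference: dequantize the whole weight, then run a plain matmul.
w_float = (q_weight.astype(np.float32) * scale).reshape(
    num_blocks * block_size, out_features
)
ref = x @ w_float

# Emulated path: reshape -> batch_matmul -> mul -> sum -> reshape.
bmm_in = x.reshape(batch * seq, num_blocks, 1, block_size)
bmm_out = bmm_in @ q_weight.astype(np.float32)  # [B*T, num_blocks, 1, out]
# Per-block dequantization; the scale is constant along the contracted
# block_size axis, so multiplying after the matmul is exact.
scaled = bmm_out * scale
summed = scaled.sum(axis=1, keepdims=True)  # reduce over the block axis
out = summed.reshape(batch, seq, out_features)

np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-4)
```

This is also why the transformation insists on zero point 0: a nonzero zero point would add a cross term that the mul/sum pair alone cannot absorb.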
ai_edge_quantizer/transformations/emulated_subchannel_test.py
@@ -0,0 +1,212 @@
+ # Copyright 2024 The AI Edge Quantizer Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Tests for emulated_subchannel."""
+
+ import os
+ import numpy as np
+ from tensorflow.python.platform import googletest
+ from ai_edge_quantizer import qtyping
+ from ai_edge_quantizer.transformations import emulated_subchannel
+ from ai_edge_quantizer.transformations import transformation_utils
+ from ai_edge_quantizer.utils import test_utils
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+ from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
+
+ TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("..")
+
+
+ class EmulatedSubchannelTest(googletest.TestCase):
+   """Tests for emulated_subchannel."""
+
+   def setUp(self):
+     super().setUp()
+     self.params = qtyping.UniformQuantParams(
+         num_bits=8,
+         quantized_dimension=None,
+         scale=np.ones([1, 1, 1, 4], dtype=np.float32),
+         zero_point=np.zeros([1, 1, 1, 4], dtype=np.int64),
+         symmetric=True,
+         quantized_data=np.ones([1, 4, 2, 4], dtype=np.int8),
+     )
+
+   def test_emulate_subchannel_without_bias_succeeds(self):
+     """Tests the emulated_subchannel function on a model without bias."""
+     self._model_path = os.path.join(
+         TEST_DATA_PREFIX_PATH, "tests/models/single_fc_no_bias.tflite"
+     )
+     self._model = tfl_flatbuffer_utils.read_model(self._model_path)
+     subgraph = self._model.subgraphs[0]
+     model = self._model
+     ret = emulated_subchannel.emulated_subchannel(
+         transformation_utils.TransformationInput(
+             tensor_id=1,
+             op_codes=model.operatorCodes,
+             buffers=model.buffers,
+             subgraph=subgraph,
+             producer=-1,
+             consumers=[0],
+             quant_params=self.params,
+         )
+     )
+     self.assertEqual(ret.op_id, 0)
+     self.assertEqual(ret.num_ops_added, 4)
+     self.assertEqual(ret.output_tensor_id, 2)
+     self.assertEqual(
+         model.operatorCodes[subgraph.operators[0].opcodeIndex].builtinCode,
+         schema_py_generated.BuiltinOperator.RESHAPE,
+     )
+     self.assertEqual(
+         model.operatorCodes[subgraph.operators[1].opcodeIndex].builtinCode,
+         schema_py_generated.BuiltinOperator.BATCH_MATMUL,
+     )
+     self.assertEqual(
+         model.operatorCodes[subgraph.operators[2].opcodeIndex].builtinCode,
+         schema_py_generated.BuiltinOperator.MUL,
+     )
+     self.assertEqual(
+         model.operatorCodes[subgraph.operators[3].opcodeIndex].builtinCode,
+         schema_py_generated.BuiltinOperator.SUM,
+     )
+     self.assertEqual(
+         model.operatorCodes[subgraph.operators[4].opcodeIndex].builtinCode,
+         schema_py_generated.BuiltinOperator.RESHAPE,
+     )
+     self.assertEqual(
+         subgraph.tensors[subgraph.operators[2].inputs[1]].name,
+         b"arith.constant_scale",
+     )
+     self.assertListEqual(
+         np.frombuffer(
+             model.buffers[
+                 subgraph.tensors[subgraph.operators[2].inputs[1]].buffer
+             ].data,
+             dtype=np.float32,
+         ).tolist(),
+         np.ones([1, 1, 1, 4]).flatten().tolist(),
+     )
+
+   def test_emulate_subchannel_with_bias_succeeds(self):
+     """Tests the emulated_subchannel function on a model with bias."""
+     self._model_path = os.path.join(
+         TEST_DATA_PREFIX_PATH, "tests/models/single_fc_bias.tflite"
+     )
+     self._model = tfl_flatbuffer_utils.read_model(self._model_path)
+     subgraph = self._model.subgraphs[0]
+     model = self._model
+     ret = emulated_subchannel.emulated_subchannel(
+         transformation_utils.TransformationInput(
+             tensor_id=1,
+             op_codes=model.operatorCodes,
+             buffers=model.buffers,
+             subgraph=subgraph,
+             producer=-1,
+             consumers=[0],
+             quant_params=self.params,
+         )
+     )
+     self.assertEqual(ret.op_id, 0)
+     self.assertEqual(ret.num_ops_added, 5)
+     self.assertEqual(ret.output_tensor_id, 3)
+     self.assertEqual(
+         model.operatorCodes[subgraph.operators[0].opcodeIndex].builtinCode,
+         schema_py_generated.BuiltinOperator.RESHAPE,
+     )
+     self.assertEqual(
+         model.operatorCodes[subgraph.operators[1].opcodeIndex].builtinCode,
+         schema_py_generated.BuiltinOperator.BATCH_MATMUL,
+     )
+     self.assertEqual(
+         model.operatorCodes[subgraph.operators[2].opcodeIndex].builtinCode,
+         schema_py_generated.BuiltinOperator.MUL,
+     )
+     self.assertEqual(
+         model.operatorCodes[subgraph.operators[3].opcodeIndex].builtinCode,
+         schema_py_generated.BuiltinOperator.SUM,
+     )
+     self.assertEqual(
+         model.operatorCodes[subgraph.operators[4].opcodeIndex].builtinCode,
+         schema_py_generated.BuiltinOperator.RESHAPE,
+     )
+     self.assertEqual(
+         model.operatorCodes[subgraph.operators[5].opcodeIndex].builtinCode,
+         schema_py_generated.BuiltinOperator.ADD,
+     )
+     self.assertEqual(
+         subgraph.tensors[subgraph.operators[2].inputs[1]].name,
+         b"arith.constant_scale",
+     )
+     self.assertListEqual(
+         np.frombuffer(
+             model.buffers[
+                 subgraph.tensors[subgraph.operators[2].inputs[1]].buffer
+             ].data,
+             dtype=np.float32,
+         ).tolist(),
+         np.ones([1, 1, 1, 4]).flatten().tolist(),
+     )
+
+   def test_emulated_subchannel_with_fused_relu_succeeds(self):
+     """Tests the emulated_subchannel function with fused relu."""
+     self._model_path = os.path.join(
+         TEST_DATA_PREFIX_PATH, "tests/models/single_fc_bias_relu.tflite"
+     )
+     self._model = tfl_flatbuffer_utils.read_model(self._model_path)
+     subgraph = self._model.subgraphs[0]
+     model = self._model
+     ret = emulated_subchannel.emulated_subchannel(
+         transformation_utils.TransformationInput(
+             tensor_id=1,
+             op_codes=model.operatorCodes,
+             buffers=model.buffers,
+             subgraph=subgraph,
+             producer=-1,
+             consumers=[0],
+             quant_params=self.params,
+         )
+     )
+     self.assertEqual(ret.op_id, 0)
+     self.assertEqual(ret.num_ops_added, 6)
+     self.assertEqual(ret.output_tensor_id, 3)
+     self.assertEqual(
+         model.operatorCodes[subgraph.operators[6].opcodeIndex].builtinCode,
+         schema_py_generated.BuiltinOperator.RELU,
+     )
+
+   def test_emulated_subchannel_raises_when_unsupported_activation(self):
+     """Tests the emulated_subchannel function with an unsupported activation."""
+     self._model_path = os.path.join(
+         TEST_DATA_PREFIX_PATH, "tests/models/single_fc_bias_relu6.tflite"
+     )
+     self._model = tfl_flatbuffer_utils.read_model(self._model_path)
+     subgraph = self._model.subgraphs[0]
+     model = self._model
+     with self.assertRaises(ValueError):
+       emulated_subchannel.emulated_subchannel(
+           transformation_utils.TransformationInput(
+               tensor_id=1,
+               op_codes=model.operatorCodes,
+               buffers=model.buffers,
+               subgraph=subgraph,
+               producer=-1,
+               consumers=[0],
+               quant_params=self.params,
+           )
+       )
+
+
+ if __name__ == "__main__":
+   googletest.main()
ai_edge_quantizer/transformations/quant_insert.py
@@ -0,0 +1,100 @@
+ # Copyright 2024 The AI Edge Quantizer Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Apply quantization transformations to the given op/tensor.
+
+ Inserts a quantize node after the given tensor to enable quantized execution
+ of the tensor's consumer.
+ """
+
+ from ai_edge_quantizer import qtyping
+ from ai_edge_quantizer.transformations import quantize_tensor
+ from ai_edge_quantizer.transformations import transformation_utils
+ from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
+
+
+ def insert_quant(
+     transformation_input: transformation_utils.TransformationInput,
+ ) -> qtyping.TransformationInfo:
+   """Insert a quant op after the given tensor in the subgraph.
+
+   Args:
+     transformation_input: input structure that contains all information needed
+       for the transformation.
+
+   Returns:
+     TransformationInfo:
+       op_id: the index where the quant op is added.
+       num_ops_added: the total number of ops inserted by this operation, which
+         is 1.
+   """
+   quant_op_code_idx = transformation_utils.add_op_code(
+       schema_py_generated.BuiltinOperator.QUANTIZE,
+       transformation_input.op_codes,
+   )
+
+   # Create the output tensor for the quantize op.
+   tensor = transformation_input.subgraph.tensors[transformation_input.tensor_id]
+   new_tensor_id = transformation_utils.add_new_activation_tensor(
+       tensor.name + b'_quantized',
+       tensor.shape,
+       schema_py_generated.TensorType.FLOAT32,
+       transformation_input.subgraph,
+   )
+
+   # Quantize the output tensor.
+   ## We need a new transformation input because we don't want to modify the
+   ## original input.
+   quantize_tensor.quantize_tensor(
+       transformation_utils.TransformationInput(
+           new_tensor_id,
+           transformation_input.op_codes,
+           transformation_input.buffers,
+           transformation_input.subgraph,
+           transformation_input.producer,
+           transformation_input.consumers,
+           transformation_input.quant_params,
+       )
+   )
+
+   # Create the quantize op.
+   quant_op = schema_py_generated.OperatorT()
+   quant_op.opcodeIndex = quant_op_code_idx
+   quant_op.outputs = [new_tensor_id]
+   quant_op.inputs = [transformation_input.tensor_id]
+
+   # Update the original consumers of the tensor to read the quant op's output,
+   # and find the first consumer of the new tensor.
+   first_consumer_id = min(transformation_input.consumers)
+   for consumer_id in transformation_input.consumers:
+     op = transformation_input.subgraph.operators[consumer_id]
+     for input_idx in range(len(op.inputs)):
+       if op.inputs[input_idx] == transformation_input.tensor_id:
+         op.inputs[input_idx] = new_tensor_id
+
+   # If the tensor is also an output of the graph, we need to update that as
+   # well.
+   for output_idx, output in enumerate(transformation_input.subgraph.outputs):
+     if output == transformation_input.tensor_id:
+       transformation_input.subgraph.outputs[output_idx] = new_tensor_id
+
+   # Add the quant op into the subgraph op list; it must be inserted right
+   # before its first consumer. When the tensor also feeds the graph output,
+   # we need to ensure the quant op is inserted after the producer.
+   op_id = max(transformation_input.producer + 1, first_consumer_id)
+   transformation_input.subgraph.operators.insert(op_id, quant_op)
+   return qtyping.TransformationInfo(
+       op_id=op_id, num_ops_added=1, output_tensor_id=new_tensor_id
+   )
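
For intuition, here is a tiny self-contained sketch (toy structures with assumed names, not the package API) of the rewiring insert_quant performs: every consumer of the original tensor is redirected to the new quantized tensor, and the quantize op is placed at max(producer + 1, first_consumer) so that it executes after the tensor's producer but before all of its consumers.

```python
from dataclasses import dataclass, field


@dataclass
class ToyOp:
  """A stand-in for schema_py_generated.OperatorT (illustrative only)."""
  name: str
  inputs: list = field(default_factory=list)
  outputs: list = field(default_factory=list)


def insert_quant_toy(ops, tensor_id, new_tensor_id, producer, consumers):
  # Redirect every consumer of tensor_id to read the quantized tensor.
  for consumer_id in consumers:
    op = ops[consumer_id]
    op.inputs = [new_tensor_id if t == tensor_id else t for t in op.inputs]
  quant_op = ToyOp('quantize', inputs=[tensor_id], outputs=[new_tensor_id])
  # After the producer, before the first consumer.
  op_id = max(producer + 1, min(consumers))
  ops.insert(op_id, quant_op)
  return op_id


# Tensor 0 is produced by 'conv' and consumed by both 'add' and 'mul'.
ops = [
    ToyOp('conv', outputs=[0]),
    ToyOp('add', inputs=[0]),
    ToyOp('mul', inputs=[0]),
]
op_id = insert_quant_toy(
    ops, tensor_id=0, new_tensor_id=1, producer=0, consumers=[1, 2]
)
assert op_id == 1
assert [op.name for op in ops] == ['conv', 'quantize', 'add', 'mul']
assert ops[2].inputs == [1] and ops[3].inputs == [1]
```

The consumers are rewired before the insertion, so their positions in the op list can shift without invalidating anything; only the insertion index itself has to respect the producer/consumer ordering.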