ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl
- ai_edge_quantizer/__init__.py +19 -0
- ai_edge_quantizer/algorithm_manager.py +167 -0
- ai_edge_quantizer/algorithm_manager_api.py +271 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
- ai_edge_quantizer/algorithms/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
- ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
- ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
- ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
- ai_edge_quantizer/calibrator.py +288 -0
- ai_edge_quantizer/calibrator_test.py +297 -0
- ai_edge_quantizer/conftest.py +22 -0
- ai_edge_quantizer/default_policy.py +310 -0
- ai_edge_quantizer/model_modifier.py +176 -0
- ai_edge_quantizer/model_modifier_test.py +130 -0
- ai_edge_quantizer/model_validator.py +357 -0
- ai_edge_quantizer/model_validator_test.py +354 -0
- ai_edge_quantizer/params_generator.py +361 -0
- ai_edge_quantizer/params_generator_test.py +1041 -0
- ai_edge_quantizer/qtyping.py +483 -0
- ai_edge_quantizer/quantizer.py +372 -0
- ai_edge_quantizer/quantizer_test.py +532 -0
- ai_edge_quantizer/recipe.py +67 -0
- ai_edge_quantizer/recipe_manager.py +245 -0
- ai_edge_quantizer/recipe_manager_test.py +815 -0
- ai_edge_quantizer/recipe_test.py +97 -0
- ai_edge_quantizer/transformation_instruction_generator.py +584 -0
- ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
- ai_edge_quantizer/transformation_performer.py +278 -0
- ai_edge_quantizer/transformation_performer_test.py +344 -0
- ai_edge_quantizer/transformations/__init__.py +15 -0
- ai_edge_quantizer/transformations/dequant_insert.py +87 -0
- ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
- ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
- ai_edge_quantizer/transformations/quant_insert.py +100 -0
- ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
- ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
- ai_edge_quantizer/transformations/transformation_utils.py +132 -0
- ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
- ai_edge_quantizer/utils/__init__.py +15 -0
- ai_edge_quantizer/utils/calibration_utils.py +86 -0
- ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
- ai_edge_quantizer/utils/test_utils.py +107 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
- ai_edge_quantizer/utils/validation_utils.py +125 -0
- ai_edge_quantizer/utils/validation_utils_test.py +87 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
- ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
ai_edge_quantizer/transformations/dequant_insert.py
@@ -0,0 +1,87 @@
# Copyright 2024 The AI Edge Quantizer Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Apply dequantization transformations to the given op/tensor.

Inserts a dequantize node after the given tensor to enable float execution of
the tensor's consumers.
"""

from ai_edge_quantizer import qtyping
from ai_edge_quantizer.transformations import quantize_tensor
from ai_edge_quantizer.transformations import transformation_utils
from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import


def insert_dequant(
    transformation_input: transformation_utils.TransformationInput,
) -> qtyping.TransformationInfo:
  """Inserts a dequant op after the given tensor in the subgraph.

  Args:
    transformation_input: input structure that contains all information needed
      for the transformation.

  Returns:
    TransformationInfo:
      op_id: the index where the dequant op is added.
      num_ops_added: the total number of ops inserted by this operation, which
        is 1.
  """
  dequant_op_code_idx = transformation_utils.add_op_code(
      schema_py_generated.BuiltinOperator.DEQUANTIZE,
      transformation_input.op_codes,
  )
  # Create the output tensor for the dequant op.
  tensor = transformation_input.subgraph.tensors[transformation_input.tensor_id]
  new_tensor_id = transformation_utils.add_new_activation_tensor(
      tensor.name + b'_dequant',
      tensor.shape,
      schema_py_generated.TensorType.FLOAT32,
      transformation_input.subgraph,
  )

  # Create the dequantize op.
  dequant_op = schema_py_generated.OperatorT()
  dequant_op.opcodeIndex = dequant_op_code_idx
  dequant_op.outputs = [new_tensor_id]
  dequant_op.inputs = [transformation_input.tensor_id]

  # Quantize the source tensor.
  quantize_tensor.quantize_tensor(transformation_input)

  # Rewire the original consumers of the tensor to read from the dequant
  # output, and find the first consumer of the new tensor.
  first_consumer_id = min(transformation_input.consumers)
  for consumer_id in transformation_input.consumers:
    op = transformation_input.subgraph.operators[consumer_id]
    for input_idx in range(len(op.inputs)):
      if op.inputs[input_idx] == transformation_input.tensor_id:
        op.inputs[input_idx] = new_tensor_id

  # If the tensor is also a graph output, update that as well.
  for output_idx, output in enumerate(transformation_input.subgraph.outputs):
    if output == transformation_input.tensor_id:
      transformation_input.subgraph.outputs[output_idx] = new_tensor_id

  # Add the dequant op to the subgraph op list. It must be inserted right
  # before its first consumer; when the tensor only feeds a graph output, the
  # dequant op must still land after the producer.
  op_id = max(transformation_input.producer + 1, first_consumer_id)
  transformation_input.subgraph.operators.insert(op_id, dequant_op)
  return qtyping.TransformationInfo(
      op_id=op_id, num_ops_added=1, output_tensor_id=new_tensor_id
  )
ai_edge_quantizer/transformations/dequant_insert_test.py
@@ -0,0 +1,304 @@
# Copyright 2024 The AI Edge Quantizer Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for various transformations used by the quantizer."""

import os
import numpy as np
from tensorflow.python.platform import googletest
from ai_edge_quantizer import qtyping
from ai_edge_quantizer.transformations import dequant_insert
from ai_edge_quantizer.transformations import transformation_utils
from ai_edge_quantizer.utils import test_utils
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import

TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("..")


class DequantInsertTest(googletest.TestCase):

  def setUp(self):
    super().setUp()
    self._orig_test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, "tests/models/insert_dequant_test.tflite"
    )
    self._model = tfl_flatbuffer_utils.read_model(self._orig_test_model_path)

  def test_dequant_insert_constant(self):
    """Tests dequant insert on a constant tensor."""
    subgraph = self._model.subgraphs[0]
    model = self._model
    dequant_opcode = schema_py_generated.BuiltinOperator.DEQUANTIZE
    # Insert dequant on the constant before the add node.
    dequant_insert.insert_dequant(
        transformation_utils.TransformationInput(
            7,
            model.operatorCodes,
            model.buffers,
            subgraph,
            -1,
            [4],
            qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
        )
    )

    # Check the dequant op code is added to the model.
    self.assertEqual(
        model.operatorCodes[len(model.operatorCodes) - 1].builtinCode,
        dequant_opcode,
    )

    # Check the new tensor is correctly created.
    self.assertIn(b"_dequant", subgraph.tensors[9].name)
    self.assertEqual(
        subgraph.tensors[9].type, schema_py_generated.TensorType.FLOAT32
    )
    self.assertEqual(
        subgraph.tensors[7].type, schema_py_generated.TensorType.INT8
    )
    # Check the consumer has the correct inputs.
    self.assertEqual(subgraph.operators[5].inputs[0], 6)
    self.assertEqual(subgraph.operators[5].inputs[1], 9)

    # Check the inserted node has the correct input/output.
    self.assertEqual(subgraph.operators[4].outputs[0], 9)
    self.assertEqual(subgraph.operators[4].inputs[0], 7)
    # Check the inserted node is the dequant node.
    self.assertEqual(
        subgraph.operators[4].opcodeIndex, len(model.operatorCodes) - 1
    )

  def test_dequant_insert_activation(self):
    """Tests dequant insert on activation tensors."""
    subgraph = self._model.subgraphs[0]
    model = self._model
    dequant_opcode = schema_py_generated.BuiltinOperator.DEQUANTIZE
    # Insert dequant on the output of a conv node.
    dequant_insert.insert_dequant(
        transformation_utils.TransformationInput(
            4,
            model.operatorCodes,
            model.buffers,
            subgraph,
            1,
            [3],
            qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
        )
    )

    # Check the dequant op code is added to the model.
    self.assertEqual(
        model.operatorCodes[len(model.operatorCodes) - 1].builtinCode,
        dequant_opcode,
    )

    # Check the new tensor is correctly created.
    self.assertIn(b"_dequant", subgraph.tensors[9].name)
    self.assertEqual(
        subgraph.tensors[9].type, schema_py_generated.TensorType.FLOAT32
    )
    # Check the original source tensor is updated.
    self.assertEqual(
        subgraph.tensors[4].type, schema_py_generated.TensorType.INT8
    )

    # Check the consumer has the correct inputs.
    self.assertEqual(subgraph.operators[4].inputs[0], 9)
    self.assertEqual(subgraph.operators[4].inputs[1], 5)

    # Check the inserted node has the correct input/output.
    self.assertEqual(subgraph.operators[3].outputs[0], 9)
    self.assertEqual(subgraph.operators[3].inputs[0], 4)
    # Check the inserted node is the dequant node.
    self.assertEqual(
        subgraph.operators[3].opcodeIndex, len(model.operatorCodes) - 1
    )

  def test_dequant_insert_constant_multiple_consumers(self):
    """Tests dequant insert on a constant tensor with multiple consumers."""
    subgraph = self._model.subgraphs[0]
    model = self._model
    dequant_opcode = schema_py_generated.BuiltinOperator.DEQUANTIZE
    # Insert dequant on the input of a conv node.
    post_trans_info = dequant_insert.insert_dequant(
        transformation_utils.TransformationInput(
            2,
            model.operatorCodes,
            model.buffers,
            subgraph,
            -1,
            [1, 2],
            qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
        )
    )
    self.assertEqual(post_trans_info.op_id, 1)
    self.assertEqual(post_trans_info.num_ops_added, 1)

    # Check the dequant op code is added to the model.
    self.assertEqual(
        model.operatorCodes[len(model.operatorCodes) - 1].builtinCode,
        dequant_opcode,
    )

    # Check the new tensor is correctly created.
    self.assertIn(b"_dequant", subgraph.tensors[9].name)
    self.assertEqual(
        subgraph.tensors[9].type, schema_py_generated.TensorType.FLOAT32
    )
    # Check the original source tensor has the correct type.
    self.assertEqual(
        subgraph.tensors[2].type, schema_py_generated.TensorType.INT8
    )

    # Check the inserted node has the correct input/output.
    self.assertEqual(subgraph.operators[1].outputs[0], 9)
    self.assertEqual(subgraph.operators[1].inputs[0], 2)
    # Check the inserted node is the dequant node.
    self.assertEqual(
        subgraph.operators[1].opcodeIndex, len(model.operatorCodes) - 1
    )

    # Check the consumers have the correct input.
    self.assertEqual(subgraph.operators[2].inputs[1], 9)
    self.assertEqual(subgraph.operators[3].inputs[1], 9)

  def test_dequant_insert_activation_multiple_consumers(self):
    """Tests dequant insert on an activation tensor with multiple consumers."""
    subgraph = self._model.subgraphs[0]
    model = self._model
    dequant_opcode = schema_py_generated.BuiltinOperator.DEQUANTIZE
    # Insert dequant on the output of a conv node.
    dequant_insert.insert_dequant(
        transformation_utils.TransformationInput(
            1,
            model.operatorCodes,
            model.buffers,
            subgraph,
            0,
            [1, 2],
            qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
        )
    )

    # Check the dequant op code is added to the model.
    self.assertEqual(
        model.operatorCodes[len(model.operatorCodes) - 1].builtinCode,
        dequant_opcode,
    )

    # Check the new tensor is correctly created.
    self.assertIn(b"_dequant", subgraph.tensors[9].name)
    self.assertEqual(
        subgraph.tensors[9].type, schema_py_generated.TensorType.FLOAT32
    )
    # Check the original source tensor is updated.
    self.assertEqual(
        subgraph.tensors[1].type, schema_py_generated.TensorType.INT8
    )

    # Check the inserted node has the correct input/output.
    self.assertEqual(subgraph.operators[1].outputs[0], 9)
    self.assertEqual(subgraph.operators[1].inputs[0], 1)
    # Check the inserted node is the dequant node.
    self.assertEqual(
        subgraph.operators[1].opcodeIndex, len(model.operatorCodes) - 1
    )

    # Check the consumers have the correct input.
    self.assertEqual(subgraph.operators[2].inputs[0], 9)
    self.assertEqual(subgraph.operators[3].inputs[0], 9)

  def test_dequant_insert_activation_multiple_consumers_select(self):
    """Tests dequant insert on a tensor with multiple consumers, inserting for only one of them."""
    subgraph = self._model.subgraphs[0]
    model = self._model
    dequant_opcode = schema_py_generated.BuiltinOperator.DEQUANTIZE
    # Insert dequant on the output of a conv node.
    dequant_insert.insert_dequant(
        transformation_utils.TransformationInput(
            1,
            model.operatorCodes,
            model.buffers,
            subgraph,
            0,
            [1],
            qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
        )
    )

    # Check the dequant op code is added to the model.
    self.assertEqual(
        model.operatorCodes[len(model.operatorCodes) - 1].builtinCode,
        dequant_opcode,
    )

    # Check the new tensor is correctly created.
    self.assertIn(b"_dequant", subgraph.tensors[9].name)
    self.assertEqual(
        subgraph.tensors[9].type, schema_py_generated.TensorType.FLOAT32
    )
    # Check the original source tensor is updated.
    self.assertEqual(
        subgraph.tensors[1].type, schema_py_generated.TensorType.INT8
    )

    # Check the inserted node has the correct input/output.
    self.assertEqual(subgraph.operators[1].outputs[0], 9)
    self.assertEqual(subgraph.operators[1].inputs[0], 1)
    # Check the inserted node is the dequant node.
    self.assertEqual(
        subgraph.operators[1].opcodeIndex, len(model.operatorCodes) - 1
    )

    # Check only the selected consumer reads the dequant output.
    self.assertEqual(subgraph.operators[2].inputs[0], 9)
    self.assertEqual(subgraph.operators[3].inputs[0], 1)

  def test_dequant_insert_on_graph_output(self):
    """Tests dequant insert on a graph output."""
    subgraph = self._model.subgraphs[0]
    model = self._model
    dequant_opcode = schema_py_generated.BuiltinOperator.DEQUANTIZE
    # Insert dequant on the graph output.
    dequant_insert.insert_dequant(
        transformation_utils.TransformationInput(
            8,
            model.operatorCodes,
            model.buffers,
            subgraph,
            4,
            [-1],
            qtyping.UniformQuantParams(8, None, np.array([1]), np.array([0])),
        )
    )

    # Check the dequant op code is added to the model.
    self.assertEqual(
        model.operatorCodes[len(model.operatorCodes) - 1].builtinCode,
        dequant_opcode,
    )

    # Check the inserted node is the dequant node.
    self.assertEqual(
        subgraph.operators[5].opcodeIndex, len(model.operatorCodes) - 1
    )

    # Check the graph output is updated.
    self.assertEqual(subgraph.outputs[0], 9)


if __name__ == "__main__":
  googletest.main()