ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl
This diff compares the contents of two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +224 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
- ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
- ai_edge_quantizer/calibrator.py +58 -94
- ai_edge_quantizer/calibrator_test.py +5 -74
- ai_edge_quantizer/default_policy.py +108 -16
- ai_edge_quantizer/model_modifier.py +132 -8
- ai_edge_quantizer/model_modifier_test.py +81 -1
- ai_edge_quantizer/model_validator.py +38 -10
- ai_edge_quantizer/model_validator_test.py +2 -1
- ai_edge_quantizer/params_generator.py +230 -47
- ai_edge_quantizer/params_generator_test.py +366 -261
- ai_edge_quantizer/qtyping.py +92 -6
- ai_edge_quantizer/quantizer.py +167 -23
- ai_edge_quantizer/quantizer_test.py +288 -26
- ai_edge_quantizer/recipe.py +156 -21
- ai_edge_quantizer/recipe_manager.py +158 -1
- ai_edge_quantizer/recipe_manager_test.py +146 -32
- ai_edge_quantizer/recipe_test.py +93 -17
- ai_edge_quantizer/transformation_instruction_generator.py +313 -46
- ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
- ai_edge_quantizer/transformation_performer.py +112 -58
- ai_edge_quantizer/transformation_performer_test.py +176 -4
- ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
- ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
- ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
- ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
- ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
- ai_edge_quantizer/transformations/transformation_utils.py +157 -11
- ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
- ai_edge_quantizer/utils/calibration_utils.py +263 -1
- ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
- ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
- ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
- ai_edge_quantizer/utils/test_utils.py +191 -58
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
- ai_edge_quantizer/utils/validation_utils.py +114 -4
- ai_edge_quantizer/utils/validation_utils_test.py +80 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
- ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
- ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
- ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
"""quantize a given tensor."""
|
|
17
17
|
|
|
18
18
|
from typing import Optional, cast
|
|
19
|
+
import ml_dtypes
|
|
19
20
|
import numpy as np
|
|
20
21
|
from ai_edge_quantizer import qtyping
|
|
21
22
|
from ai_edge_quantizer.transformations import transformation_utils
|
|
@@ -67,29 +68,6 @@ def nonlinear_quant_params_to_tflite_type(
|
|
|
67
68
|
raise ValueError(f"Unsupported nonlinear params: {bitwidth}")
|
|
68
69
|
|
|
69
70
|
|
|
70
|
-
def _pack_data(bitwidth: int, flattened_data: np.ndarray) -> np.ndarray:
|
|
71
|
-
"""Pack the data to the corresponding bit width.
|
|
72
|
-
|
|
73
|
-
Currently only support 4 bits. If no packing is needed, the original data is
|
|
74
|
-
returned.
|
|
75
|
-
|
|
76
|
-
Args:
|
|
77
|
-
bitwidth: Bit width from NonLinearQuantParams.
|
|
78
|
-
flattened_data: The data to be packed.
|
|
79
|
-
|
|
80
|
-
Returns:
|
|
81
|
-
Packed data.
|
|
82
|
-
"""
|
|
83
|
-
if bitwidth == 4:
|
|
84
|
-
even_data = flattened_data[::2] & 0x0F
|
|
85
|
-
odd_data = np.left_shift(flattened_data[1::2], 4).astype(np.uint8)
|
|
86
|
-
if odd_data.shape[0] == even_data.shape[0] - 1:
|
|
87
|
-
odd_data = np.pad(odd_data, (0, 1), constant_values=0)
|
|
88
|
-
return np.bitwise_or(even_data, odd_data)
|
|
89
|
-
else:
|
|
90
|
-
return flattened_data
|
|
91
|
-
|
|
92
|
-
|
|
93
71
|
def _perform_channelwise_quantization(
|
|
94
72
|
transformation_input: transformation_utils.TransformationInput,
|
|
95
73
|
) -> schema_py_generated.QuantizationParametersT():
|
|
@@ -142,26 +120,25 @@ def _perform_blockwise_quantization(
|
|
|
142
120
|
)
|
|
143
121
|
tensor = transformation_input.subgraph.tensors[transformation_input.tensor_id]
|
|
144
122
|
blockwise_details = schema_py_generated.BlockwiseQuantizationT()
|
|
123
|
+
# Downcast and round the scale to fp16 with 7 bit mantissa.
|
|
145
124
|
scale_tensor_id = transformation_utils.add_new_constant_tensor(
|
|
146
|
-
tensor.name + b"
|
|
147
|
-
transformation_input.quant_params.scale
|
|
125
|
+
tensor.name + b"_scales",
|
|
126
|
+
transformation_input.quant_params.scale.astype(ml_dtypes.bfloat16).astype(
|
|
127
|
+
np.float16
|
|
128
|
+
),
|
|
148
129
|
schema_py_generated.TensorType.FLOAT16,
|
|
149
130
|
transformation_input.subgraph,
|
|
150
131
|
transformation_input.buffers,
|
|
151
132
|
)
|
|
152
|
-
blockwise_details.
|
|
133
|
+
blockwise_details.scales = scale_tensor_id
|
|
134
|
+
# Blockwise quantization does not support zero point yet, so this points to
|
|
135
|
+
# a -1 buffer index.
|
|
136
|
+
# TODO: b/404909258 - Add optional zero point to blockwise quantization.
|
|
137
|
+
blockwise_details.zeroPoints = -1
|
|
153
138
|
blockwise_details.blockSize = transformation_input.quant_params.block_size
|
|
154
|
-
# blockwise quantization allows optional zero point.
|
|
155
|
-
if transformation_input.quant_params.zero_point is not None:
|
|
156
|
-
zero_point_tensor_id = transformation_utils.add_new_constant_tensor(
|
|
157
|
-
tensor.name + b"_zero_point",
|
|
158
|
-
transformation_input.quant_params.zero_point,
|
|
159
|
-
schema_py_generated.TensorType.INT32,
|
|
160
|
-
transformation_input.subgraph,
|
|
161
|
-
transformation_input.buffers,
|
|
162
|
-
)
|
|
163
|
-
blockwise_details.zeroPoint = zero_point_tensor_id
|
|
164
139
|
flatbuffer_quantization.details = blockwise_details
|
|
140
|
+
# TODO: b/443830202 - Hardcoding to 0 for now.
|
|
141
|
+
flatbuffer_quantization.quantizedDimension = 0
|
|
165
142
|
return flatbuffer_quantization
|
|
166
143
|
|
|
167
144
|
|
|
@@ -185,14 +162,17 @@ def quantize_tensor(
|
|
|
185
162
|
# is not provided.
|
|
186
163
|
if tensor.buffer:
|
|
187
164
|
if transformation_input.quant_params.quantized_data is not None:
|
|
188
|
-
transformation_input.buffers[tensor.buffer].data =
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
165
|
+
transformation_input.buffers[tensor.buffer].data = (
|
|
166
|
+
transformation_utils.pack_data(
|
|
167
|
+
transformation_input.quant_params.num_bits,
|
|
168
|
+
np.frombuffer(
|
|
169
|
+
cast(
|
|
170
|
+
np.ndarray,
|
|
171
|
+
transformation_input.quant_params.quantized_data,
|
|
172
|
+
).tobytes(),
|
|
173
|
+
dtype=np.uint8,
|
|
174
|
+
).flatten(),
|
|
175
|
+
)
|
|
196
176
|
)
|
|
197
177
|
|
|
198
178
|
if isinstance(transformation_input.quant_params, qtyping.UniformQuantParams):
|
|
@@ -168,8 +168,9 @@ class QuantizeTensorTest(parameterized.TestCase):
|
|
|
168
168
|
)
|
|
169
169
|
self.assertEqual(quant_param.details.blockSize, 32)
|
|
170
170
|
# Check if the scale and zero point tensors are inserted correctly.
|
|
171
|
-
self.assertEqual(quant_param.details.
|
|
172
|
-
|
|
171
|
+
self.assertEqual(quant_param.details.scales, 9)
|
|
172
|
+
# So far we don't have zero point in blockwise quantization.
|
|
173
|
+
self.assertEqual(quant_param.details.zeroPoints, -1)
|
|
173
174
|
|
|
174
175
|
def test_int4_constant_packed_correctly(self):
|
|
175
176
|
subgraph = self._model.subgraphs[0]
|
|
@@ -15,8 +15,9 @@
|
|
|
15
15
|
|
|
16
16
|
"""Utility functions for graph transformations."""
|
|
17
17
|
|
|
18
|
+
import copy
|
|
18
19
|
import dataclasses
|
|
19
|
-
from typing import Union
|
|
20
|
+
from typing import Optional, Union
|
|
20
21
|
|
|
21
22
|
import numpy as np
|
|
22
23
|
|
|
@@ -51,30 +52,94 @@ class TransformationInput:
|
|
|
51
52
|
def add_op_code(
|
|
52
53
|
op_code: schema_py_generated.OperatorCodeT,
|
|
53
54
|
model_op_codes: list[schema_py_generated.OperatorCodeT],
|
|
55
|
+
custom_op_name: Optional[str] = None,
|
|
54
56
|
) -> int:
|
|
55
57
|
"""Add an op code into a model if it's not present.
|
|
56
58
|
|
|
57
59
|
Args:
|
|
58
60
|
op_code: The op code to be added.
|
|
59
61
|
model_op_codes: The op codes of the model.
|
|
62
|
+
custom_op_name: The custom string of the op code. If None, the op code will
|
|
63
|
+
be added as a builtin op code.
|
|
60
64
|
|
|
61
65
|
Returns:
|
|
62
66
|
The index of the op code in the model.
|
|
63
67
|
"""
|
|
68
|
+
if (
|
|
69
|
+
op_code == schema_py_generated.BuiltinOperator.CUSTOM
|
|
70
|
+
and custom_op_name is None
|
|
71
|
+
):
|
|
72
|
+
raise ValueError('Custom string is required for custom op code.')
|
|
73
|
+
|
|
64
74
|
for i, model_op_code in enumerate(model_op_codes):
|
|
75
|
+
# If the model already has the op code, just return the index.
|
|
65
76
|
if model_op_code.builtinCode == op_code:
|
|
66
|
-
|
|
77
|
+
if custom_op_name is not None:
|
|
78
|
+
if model_op_code.customCode == custom_op_name:
|
|
79
|
+
return i
|
|
80
|
+
else:
|
|
81
|
+
# Built-in op
|
|
82
|
+
return i
|
|
83
|
+
|
|
67
84
|
model_op_codes.append(schema_py_generated.OperatorCodeT())
|
|
68
85
|
model_op_codes[-1].builtinCode = op_code
|
|
86
|
+
if custom_op_name is not None:
|
|
87
|
+
model_op_codes[-1].customCode = custom_op_name
|
|
69
88
|
return len(model_op_codes) - 1
|
|
70
89
|
|
|
71
90
|
|
|
91
|
+
def get_constant_buffer(
|
|
92
|
+
data: np.ndarray,
|
|
93
|
+
buffers: list[schema_py_generated.BufferT],
|
|
94
|
+
force_duplicate_buffer: bool = False,
|
|
95
|
+
) -> int:
|
|
96
|
+
"""Get the index of the constant buffer that contains the given data.
|
|
97
|
+
|
|
98
|
+
creating new buffer if provided data is not found in buffers list.
|
|
99
|
+
|
|
100
|
+
Args:
|
|
101
|
+
data: The data of the new tensor.
|
|
102
|
+
buffers: The buffers of the model.
|
|
103
|
+
force_duplicate_buffer: Whether to add a new buffer even if the same buffer
|
|
104
|
+
already exists.
|
|
105
|
+
|
|
106
|
+
Returns:
|
|
107
|
+
The index of the new buffer in the model.
|
|
108
|
+
"""
|
|
109
|
+
|
|
110
|
+
if isinstance(data, np.ndarray):
|
|
111
|
+
# in the case where the data is passed from quantization_params.
|
|
112
|
+
new_data = np.frombuffer(data.tobytes(), dtype=np.uint8).flatten()
|
|
113
|
+
elif isinstance(data, bytes):
|
|
114
|
+
# in the case where the data is coming from duplicating buffers, we need to
|
|
115
|
+
# make a copy of the data to avoid having two buffers pointing to the same
|
|
116
|
+
# data.
|
|
117
|
+
new_data = copy.deepcopy(data)
|
|
118
|
+
else:
|
|
119
|
+
raise ValueError('data passed in must be either np.ndarray or bytes.')
|
|
120
|
+
# TODO: b/417811116 - we should make this more efficient.
|
|
121
|
+
if not force_duplicate_buffer:
|
|
122
|
+
for index, buffer in enumerate(buffers):
|
|
123
|
+
if np.array_equal(buffer.data, new_data):
|
|
124
|
+
return index
|
|
125
|
+
new_buffer = schema_py_generated.BufferT()
|
|
126
|
+
new_buffer.data = new_data
|
|
127
|
+
new_buffer.offset = 0
|
|
128
|
+
new_buffer.size = 0
|
|
129
|
+
new_buffer_id = len(buffers)
|
|
130
|
+
buffers.append(new_buffer)
|
|
131
|
+
|
|
132
|
+
return new_buffer_id
|
|
133
|
+
|
|
134
|
+
|
|
72
135
|
def add_new_constant_tensor(
|
|
73
136
|
tensor_name: str,
|
|
74
137
|
data: np.ndarray,
|
|
75
138
|
tensor_type: schema_py_generated.TensorType,
|
|
76
139
|
subgraph: schema_py_generated.SubGraphT,
|
|
77
140
|
buffers: list[schema_py_generated.BufferT],
|
|
141
|
+
tensor_shape: Optional[list[int]] = None,
|
|
142
|
+
force_duplicate_buffer: bool = False,
|
|
78
143
|
) -> int:
|
|
79
144
|
"""Add a new constant tensor to the model.
|
|
80
145
|
|
|
@@ -84,20 +149,21 @@ def add_new_constant_tensor(
|
|
|
84
149
|
tensor_type: The type of the new tensor.
|
|
85
150
|
subgraph: The subgraph where the new tensor is added.
|
|
86
151
|
buffers: The buffers of the model.
|
|
152
|
+
tensor_shape: The shape of the new tensor. If not provided, the shape of the
|
|
153
|
+
data will be used.
|
|
154
|
+
force_duplicate_buffer: Whether to add a new buffer even if the same buffer
|
|
155
|
+
already exists.
|
|
87
156
|
|
|
88
157
|
Returns:
|
|
89
158
|
The index of the new tensor in the subgraph.
|
|
90
159
|
"""
|
|
91
|
-
|
|
92
|
-
tensor_buffer.data = np.frombuffer(data.tobytes(), dtype=np.uint8).flatten()
|
|
93
|
-
tensor_buffer.offset = 0
|
|
94
|
-
tensor_buffer.size = 0
|
|
95
|
-
tensor_buffer_id = len(buffers)
|
|
96
|
-
buffers.append(tensor_buffer)
|
|
160
|
+
new_buffer_id = get_constant_buffer(data, buffers, force_duplicate_buffer)
|
|
97
161
|
|
|
98
162
|
new_tensor = schema_py_generated.TensorT()
|
|
99
|
-
|
|
100
|
-
|
|
163
|
+
if tensor_shape is None:
|
|
164
|
+
tensor_shape = data.shape
|
|
165
|
+
new_tensor.shape = tensor_shape
|
|
166
|
+
new_tensor.buffer = new_buffer_id
|
|
101
167
|
new_tensor.type = tensor_type
|
|
102
168
|
new_tensor.name = tensor_name
|
|
103
169
|
new_tensor_id = len(subgraph.tensors)
|
|
@@ -123,10 +189,90 @@ def add_new_activation_tensor(
|
|
|
123
189
|
The index of the new tensor in the subgraph.
|
|
124
190
|
"""
|
|
125
191
|
new_tensor = schema_py_generated.TensorT()
|
|
126
|
-
|
|
192
|
+
# If there's a dynamic shape, we need to read from the shapeSignature field
|
|
193
|
+
# instead of shape. Shape should contain just 1 for the dynamic dimension but
|
|
194
|
+
# shapeSignature should contain the true shape.
|
|
195
|
+
if -1 in shape:
|
|
196
|
+
new_tensor.shapeSignature = shape
|
|
197
|
+
new_tensor.shape = [1 if i == -1 else i for i in shape]
|
|
198
|
+
else:
|
|
199
|
+
new_tensor.shape = shape
|
|
127
200
|
new_tensor.type = tensor_type
|
|
128
201
|
new_tensor.name = tensor_name
|
|
129
202
|
new_tensor.buffer = 0
|
|
130
203
|
new_tensor_id = len(subgraph.tensors)
|
|
131
204
|
subgraph.tensors.append(new_tensor)
|
|
132
205
|
return new_tensor_id
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def raise_deprecated_error(_: TransformationInput):
|
|
209
|
+
raise NotImplementedError(
|
|
210
|
+
'This transformation is deprecated. Please contact AI Edge Quantizer team'
|
|
211
|
+
' if you see this error.'
|
|
212
|
+
)
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def pack_data(bitwidth: int, flattened_data: np.ndarray) -> np.ndarray:
|
|
216
|
+
"""Pack the data to the corresponding bit width.
|
|
217
|
+
|
|
218
|
+
Currently only support 4 bits. If no packing is needed, the original data is
|
|
219
|
+
returned.
|
|
220
|
+
|
|
221
|
+
Args:
|
|
222
|
+
bitwidth: Bit width from NonLinearQuantParams.
|
|
223
|
+
flattened_data: The data to be packed.
|
|
224
|
+
|
|
225
|
+
Returns:
|
|
226
|
+
Packed data.
|
|
227
|
+
"""
|
|
228
|
+
if bitwidth == 4:
|
|
229
|
+
even_data = flattened_data[::2] & 0x0F
|
|
230
|
+
odd_data = np.left_shift(flattened_data[1::2], 4).astype(np.uint8)
|
|
231
|
+
if odd_data.shape[0] == even_data.shape[0] - 1:
|
|
232
|
+
odd_data = np.pad(odd_data, (0, 1), constant_values=0)
|
|
233
|
+
return np.bitwise_or(even_data, odd_data)
|
|
234
|
+
else:
|
|
235
|
+
return flattened_data
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def get_producer_schema_op_id(
|
|
239
|
+
transformation: TransformationInput,
|
|
240
|
+
) -> int:
|
|
241
|
+
"""Checks if the tensor's producer matches the given op.
|
|
242
|
+
|
|
243
|
+
Args:
|
|
244
|
+
transformation: The transformation input to check the producer of.
|
|
245
|
+
|
|
246
|
+
Returns:
|
|
247
|
+
The schema op id of the producer op. E.g.
|
|
248
|
+
schema_py_generated.BuiltinOperator.FULLY_CONNECTED.
|
|
249
|
+
"""
|
|
250
|
+
if transformation.producer == -1:
|
|
251
|
+
return False
|
|
252
|
+
else:
|
|
253
|
+
return (
|
|
254
|
+
transformation.op_codes[
|
|
255
|
+
transformation.subgraph.operators[
|
|
256
|
+
transformation.producer
|
|
257
|
+
].opcodeIndex
|
|
258
|
+
].builtinCode
|
|
259
|
+
)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def get_schema_op_id(
|
|
263
|
+
transformation: TransformationInput, op_id: int
|
|
264
|
+
) -> bool:
|
|
265
|
+
"""Returns the schema op id of the given op.
|
|
266
|
+
|
|
267
|
+
Args:
|
|
268
|
+
transformation: The transformation input to check the consumers of.
|
|
269
|
+
op_id: The op id in the list of operators to check for.
|
|
270
|
+
|
|
271
|
+
Returns:
|
|
272
|
+
The schema op id of the given op.
|
|
273
|
+
"""
|
|
274
|
+
return (
|
|
275
|
+
transformation.op_codes[
|
|
276
|
+
transformation.subgraph.operators[op_id].opcodeIndex
|
|
277
|
+
].builtinCode
|
|
278
|
+
)
|
|
@@ -41,19 +41,94 @@ class TransformationUtilsTest(parameterized.TestCase):
|
|
|
41
41
|
testcase_name="add_new_op_code",
|
|
42
42
|
op_code=schema_py_generated.BuiltinOperator.LOGISTIC,
|
|
43
43
|
expected=1,
|
|
44
|
+
custom_op_name=None,
|
|
44
45
|
),
|
|
45
46
|
dict(
|
|
46
47
|
testcase_name="add_existing_op_code",
|
|
47
48
|
op_code=schema_py_generated.BuiltinOperator.FULLY_CONNECTED,
|
|
48
49
|
expected=0,
|
|
50
|
+
custom_op_name=None,
|
|
51
|
+
),
|
|
52
|
+
dict(
|
|
53
|
+
testcase_name="add_new_custom_op_code",
|
|
54
|
+
op_code=schema_py_generated.BuiltinOperator.CUSTOM,
|
|
55
|
+
expected=1,
|
|
56
|
+
custom_op_name="random_new_custom_op",
|
|
49
57
|
),
|
|
50
58
|
)
|
|
51
|
-
def test_add_op_code(self, op_code, expected):
|
|
59
|
+
def test_add_op_code(self, op_code, expected, custom_op_name):
|
|
52
60
|
"""Tests if the op code is added to the model."""
|
|
53
61
|
got = transformation_utils.add_op_code(
|
|
54
|
-
op_code=op_code,
|
|
62
|
+
op_code=op_code,
|
|
63
|
+
model_op_codes=self.model.operatorCodes,
|
|
64
|
+
custom_op_name=custom_op_name,
|
|
55
65
|
)
|
|
56
66
|
self.assertEqual(expected, got)
|
|
67
|
+
if custom_op_name is not None:
|
|
68
|
+
self.assertEqual(self.model.operatorCodes[got].customCode, custom_op_name)
|
|
69
|
+
|
|
70
|
+
def test_add_custom_op_code_without_op_string_raises_error(self):
|
|
71
|
+
with self.assertRaisesRegex(ValueError, "Custom string is required"):
|
|
72
|
+
transformation_utils.add_op_code(
|
|
73
|
+
op_code=schema_py_generated.BuiltinOperator.CUSTOM,
|
|
74
|
+
model_op_codes=self.model.operatorCodes,
|
|
75
|
+
custom_op_name=None,
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
def test_add_two_custom_op_codes(self):
|
|
79
|
+
custom_op_name = "random_new_custom_op"
|
|
80
|
+
added_index = transformation_utils.add_op_code(
|
|
81
|
+
op_code=schema_py_generated.BuiltinOperator.CUSTOM,
|
|
82
|
+
model_op_codes=self.model.operatorCodes,
|
|
83
|
+
custom_op_name=custom_op_name,
|
|
84
|
+
)
|
|
85
|
+
self.assertEqual(1, added_index)
|
|
86
|
+
self.assertEqual(
|
|
87
|
+
self.model.operatorCodes[added_index].customCode, custom_op_name
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
custom_op_name_2 = "random_new_custom_op_2"
|
|
91
|
+
added_index = transformation_utils.add_op_code(
|
|
92
|
+
op_code=schema_py_generated.BuiltinOperator.CUSTOM,
|
|
93
|
+
model_op_codes=self.model.operatorCodes,
|
|
94
|
+
custom_op_name=custom_op_name_2,
|
|
95
|
+
)
|
|
96
|
+
self.assertEqual(2, added_index)
|
|
97
|
+
self.assertEqual(
|
|
98
|
+
self.model.operatorCodes[added_index].customCode, custom_op_name_2
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
@parameterized.named_parameters(
|
|
102
|
+
dict(
|
|
103
|
+
testcase_name="float32",
|
|
104
|
+
data=np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
|
|
105
|
+
),
|
|
106
|
+
dict(
|
|
107
|
+
testcase_name="int8",
|
|
108
|
+
data=np.array([[1, 2], [3, 4]], dtype=np.int8),
|
|
109
|
+
),
|
|
110
|
+
)
|
|
111
|
+
def test_add_new_constant_buffer(self, data):
|
|
112
|
+
"""Tests if the constant buffer is added to the model."""
|
|
113
|
+
prev_num_buffers = len(self.model.buffers) - 1
|
|
114
|
+
new_buffer_idx = transformation_utils.get_constant_buffer(
|
|
115
|
+
data=data,
|
|
116
|
+
buffers=self.model.buffers,
|
|
117
|
+
)
|
|
118
|
+
self.assertEqual(new_buffer_idx, prev_num_buffers + 1)
|
|
119
|
+
|
|
120
|
+
expected_buffer_data = (
|
|
121
|
+
np.frombuffer(
|
|
122
|
+
data.tobytes(),
|
|
123
|
+
dtype=np.uint8,
|
|
124
|
+
)
|
|
125
|
+
.flatten()
|
|
126
|
+
.tolist()
|
|
127
|
+
)
|
|
128
|
+
self.assertEqual(
|
|
129
|
+
self.model.buffers[new_buffer_idx].data.tolist(),
|
|
130
|
+
expected_buffer_data,
|
|
131
|
+
)
|
|
57
132
|
|
|
58
133
|
@parameterized.named_parameters(
|
|
59
134
|
dict(
|
|
@@ -157,6 +232,25 @@ class TransformationUtilsTest(parameterized.TestCase):
|
|
|
157
232
|
self.model.subgraphs[0].tensors[-1].shape,
|
|
158
233
|
)
|
|
159
234
|
|
|
235
|
+
def test_add_new_activation_tensor_with_dynamic_shape(self):
|
|
236
|
+
"""Tests adding an activation tensor with dynamic shape."""
|
|
237
|
+
subgraph = self.model.subgraphs[0]
|
|
238
|
+
new_id = transformation_utils.add_new_activation_tensor(
|
|
239
|
+
tensor_name="test_tensor",
|
|
240
|
+
shape=[1, -1, -1, 1],
|
|
241
|
+
tensor_type=schema_py_generated.TensorType.FLOAT32,
|
|
242
|
+
subgraph=subgraph,
|
|
243
|
+
)
|
|
244
|
+
# Originally had 4 tensors, new tensor is added at index 4.
|
|
245
|
+
self.assertEqual(new_id, 4)
|
|
246
|
+
self.assertLen(subgraph.tensors, 5)
|
|
247
|
+
self.assertEqual(subgraph.tensors[-1].name, "test_tensor")
|
|
248
|
+
self.assertEqual(
|
|
249
|
+
subgraph.tensors[-1].type, schema_py_generated.TensorType.FLOAT32
|
|
250
|
+
)
|
|
251
|
+
self.assertEqual(subgraph.tensors[-1].shape, [1, 1, 1, 1])
|
|
252
|
+
self.assertEqual(subgraph.tensors[-1].shapeSignature, [1, -1, -1, 1])
|
|
253
|
+
|
|
160
254
|
|
|
161
255
|
if __name__ == "__main__":
|
|
162
256
|
googletest.main()
|