ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. ai_edge_quantizer/algorithm_manager.py +224 -0
  2. ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
  3. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
  5. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  6. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
  7. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
  8. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
  13. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
  14. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
  15. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
  16. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
  17. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
  18. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
  19. ai_edge_quantizer/calibrator.py +58 -94
  20. ai_edge_quantizer/calibrator_test.py +5 -74
  21. ai_edge_quantizer/default_policy.py +108 -16
  22. ai_edge_quantizer/model_modifier.py +132 -8
  23. ai_edge_quantizer/model_modifier_test.py +81 -1
  24. ai_edge_quantizer/model_validator.py +38 -10
  25. ai_edge_quantizer/model_validator_test.py +2 -1
  26. ai_edge_quantizer/params_generator.py +230 -47
  27. ai_edge_quantizer/params_generator_test.py +366 -261
  28. ai_edge_quantizer/qtyping.py +92 -6
  29. ai_edge_quantizer/quantizer.py +167 -23
  30. ai_edge_quantizer/quantizer_test.py +288 -26
  31. ai_edge_quantizer/recipe.py +156 -21
  32. ai_edge_quantizer/recipe_manager.py +158 -1
  33. ai_edge_quantizer/recipe_manager_test.py +146 -32
  34. ai_edge_quantizer/recipe_test.py +93 -17
  35. ai_edge_quantizer/transformation_instruction_generator.py +313 -46
  36. ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
  37. ai_edge_quantizer/transformation_performer.py +112 -58
  38. ai_edge_quantizer/transformation_performer_test.py +176 -4
  39. ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
  40. ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
  41. ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
  42. ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
  43. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  44. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  45. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  46. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  47. ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
  48. ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
  49. ai_edge_quantizer/transformations/transformation_utils.py +157 -11
  50. ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
  51. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  52. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  53. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  54. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  55. ai_edge_quantizer/utils/test_utils.py +191 -58
  56. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
  57. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
  58. ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
  59. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  60. ai_edge_quantizer/utils/validation_utils.py +114 -4
  61. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  62. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
  63. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  64. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  65. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  66. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  67. ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
  68. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  69. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@
16
16
  """quantize a given tensor."""
17
17
 
18
18
  from typing import Optional, cast
19
+ import ml_dtypes
19
20
  import numpy as np
20
21
  from ai_edge_quantizer import qtyping
21
22
  from ai_edge_quantizer.transformations import transformation_utils
@@ -67,29 +68,6 @@ def nonlinear_quant_params_to_tflite_type(
67
68
  raise ValueError(f"Unsupported nonlinear params: {bitwidth}")
68
69
 
69
70
 
70
- def _pack_data(bitwidth: int, flattened_data: np.ndarray) -> np.ndarray:
71
- """Pack the data to the corresponding bit width.
72
-
73
- Currently only support 4 bits. If no packing is needed, the original data is
74
- returned.
75
-
76
- Args:
77
- bitwidth: Bit width from NonLinearQuantParams.
78
- flattened_data: The data to be packed.
79
-
80
- Returns:
81
- Packed data.
82
- """
83
- if bitwidth == 4:
84
- even_data = flattened_data[::2] & 0x0F
85
- odd_data = np.left_shift(flattened_data[1::2], 4).astype(np.uint8)
86
- if odd_data.shape[0] == even_data.shape[0] - 1:
87
- odd_data = np.pad(odd_data, (0, 1), constant_values=0)
88
- return np.bitwise_or(even_data, odd_data)
89
- else:
90
- return flattened_data
91
-
92
-
93
71
  def _perform_channelwise_quantization(
94
72
  transformation_input: transformation_utils.TransformationInput,
95
73
  ) -> schema_py_generated.QuantizationParametersT():
@@ -142,26 +120,25 @@ def _perform_blockwise_quantization(
142
120
  )
143
121
  tensor = transformation_input.subgraph.tensors[transformation_input.tensor_id]
144
122
  blockwise_details = schema_py_generated.BlockwiseQuantizationT()
123
+ # Downcast and round the scale to fp16 with 7 bit mantissa.
145
124
  scale_tensor_id = transformation_utils.add_new_constant_tensor(
146
- tensor.name + b"_scale",
147
- transformation_input.quant_params.scale,
125
+ tensor.name + b"_scales",
126
+ transformation_input.quant_params.scale.astype(ml_dtypes.bfloat16).astype(
127
+ np.float16
128
+ ),
148
129
  schema_py_generated.TensorType.FLOAT16,
149
130
  transformation_input.subgraph,
150
131
  transformation_input.buffers,
151
132
  )
152
- blockwise_details.scale = scale_tensor_id
133
+ blockwise_details.scales = scale_tensor_id
134
+ # Blockwise quantization does not support zero point yet, so this points to
135
+ # a -1 buffer index.
136
+ # TODO: b/404909258 - Add optional zero point to blockwise quantization.
137
+ blockwise_details.zeroPoints = -1
153
138
  blockwise_details.blockSize = transformation_input.quant_params.block_size
154
- # blockwise quantization allows optional zero point.
155
- if transformation_input.quant_params.zero_point is not None:
156
- zero_point_tensor_id = transformation_utils.add_new_constant_tensor(
157
- tensor.name + b"_zero_point",
158
- transformation_input.quant_params.zero_point,
159
- schema_py_generated.TensorType.INT32,
160
- transformation_input.subgraph,
161
- transformation_input.buffers,
162
- )
163
- blockwise_details.zeroPoint = zero_point_tensor_id
164
139
  flatbuffer_quantization.details = blockwise_details
140
+ # TODO: b/443830202 - Hardcoding to 0 for now.
141
+ flatbuffer_quantization.quantizedDimension = 0
165
142
  return flatbuffer_quantization
166
143
 
167
144
 
@@ -185,14 +162,17 @@ def quantize_tensor(
185
162
  # is not provided.
186
163
  if tensor.buffer:
187
164
  if transformation_input.quant_params.quantized_data is not None:
188
- transformation_input.buffers[tensor.buffer].data = _pack_data(
189
- transformation_input.quant_params.num_bits,
190
- np.frombuffer(
191
- cast(
192
- np.ndarray, transformation_input.quant_params.quantized_data
193
- ).tobytes(),
194
- dtype=np.uint8,
195
- ).flatten(),
165
+ transformation_input.buffers[tensor.buffer].data = (
166
+ transformation_utils.pack_data(
167
+ transformation_input.quant_params.num_bits,
168
+ np.frombuffer(
169
+ cast(
170
+ np.ndarray,
171
+ transformation_input.quant_params.quantized_data,
172
+ ).tobytes(),
173
+ dtype=np.uint8,
174
+ ).flatten(),
175
+ )
196
176
  )
197
177
 
198
178
  if isinstance(transformation_input.quant_params, qtyping.UniformQuantParams):
@@ -168,8 +168,9 @@ class QuantizeTensorTest(parameterized.TestCase):
168
168
  )
169
169
  self.assertEqual(quant_param.details.blockSize, 32)
170
170
  # Check if the scale and zero point tensors are inserted correctly.
171
- self.assertEqual(quant_param.details.scale, 9)
172
- self.assertEqual(quant_param.details.zeroPoint, 10)
171
+ self.assertEqual(quant_param.details.scales, 9)
172
+ # So far we don't have zero point in blockwise quantization.
173
+ self.assertEqual(quant_param.details.zeroPoints, -1)
173
174
 
174
175
  def test_int4_constant_packed_correctly(self):
175
176
  subgraph = self._model.subgraphs[0]
@@ -15,8 +15,9 @@
15
15
 
16
16
  """Utility functions for graph transformations."""
17
17
 
18
+ import copy
18
19
  import dataclasses
19
- from typing import Union
20
+ from typing import Optional, Union
20
21
 
21
22
  import numpy as np
22
23
 
@@ -51,30 +52,94 @@ class TransformationInput:
51
52
  def add_op_code(
52
53
  op_code: schema_py_generated.OperatorCodeT,
53
54
  model_op_codes: list[schema_py_generated.OperatorCodeT],
55
+ custom_op_name: Optional[str] = None,
54
56
  ) -> int:
55
57
  """Add an op code into a model if it's not present.
56
58
 
57
59
  Args:
58
60
  op_code: The op code to be added.
59
61
  model_op_codes: The op codes of the model.
62
+ custom_op_name: The custom string of the op code. If None, the op code will
63
+ be added as a builtin op code.
60
64
 
61
65
  Returns:
62
66
  The index of the op code in the model.
63
67
  """
68
+ if (
69
+ op_code == schema_py_generated.BuiltinOperator.CUSTOM
70
+ and custom_op_name is None
71
+ ):
72
+ raise ValueError('Custom string is required for custom op code.')
73
+
64
74
  for i, model_op_code in enumerate(model_op_codes):
75
+ # If the model already has the op code, just return the index.
65
76
  if model_op_code.builtinCode == op_code:
66
- return i
77
+ if custom_op_name is not None:
78
+ if model_op_code.customCode == custom_op_name:
79
+ return i
80
+ else:
81
+ # Built-in op
82
+ return i
83
+
67
84
  model_op_codes.append(schema_py_generated.OperatorCodeT())
68
85
  model_op_codes[-1].builtinCode = op_code
86
+ if custom_op_name is not None:
87
+ model_op_codes[-1].customCode = custom_op_name
69
88
  return len(model_op_codes) - 1
70
89
 
71
90
 
91
+ def get_constant_buffer(
92
+ data: np.ndarray,
93
+ buffers: list[schema_py_generated.BufferT],
94
+ force_duplicate_buffer: bool = False,
95
+ ) -> int:
96
+ """Get the index of the constant buffer that contains the given data.
97
+
98
+ creating new buffer if provided data is not found in buffers list.
99
+
100
+ Args:
101
+ data: The data of the new tensor.
102
+ buffers: The buffers of the model.
103
+ force_duplicate_buffer: Whether to add a new buffer even if the same buffer
104
+ already exists.
105
+
106
+ Returns:
107
+ The index of the new buffer in the model.
108
+ """
109
+
110
+ if isinstance(data, np.ndarray):
111
+ # in the case where the data is passed from quantization_params.
112
+ new_data = np.frombuffer(data.tobytes(), dtype=np.uint8).flatten()
113
+ elif isinstance(data, bytes):
114
+ # in the case where the data is coming from duplicating buffers, we need to
115
+ # make a copy of the data to avoid having two buffers pointing to the same
116
+ # data.
117
+ new_data = copy.deepcopy(data)
118
+ else:
119
+ raise ValueError('data passed in must be either np.ndarray or bytes.')
120
+ # TODO: b/417811116 - we should make this more efficient.
121
+ if not force_duplicate_buffer:
122
+ for index, buffer in enumerate(buffers):
123
+ if np.array_equal(buffer.data, new_data):
124
+ return index
125
+ new_buffer = schema_py_generated.BufferT()
126
+ new_buffer.data = new_data
127
+ new_buffer.offset = 0
128
+ new_buffer.size = 0
129
+ new_buffer_id = len(buffers)
130
+ buffers.append(new_buffer)
131
+
132
+ return new_buffer_id
133
+
134
+
72
135
  def add_new_constant_tensor(
73
136
  tensor_name: str,
74
137
  data: np.ndarray,
75
138
  tensor_type: schema_py_generated.TensorType,
76
139
  subgraph: schema_py_generated.SubGraphT,
77
140
  buffers: list[schema_py_generated.BufferT],
141
+ tensor_shape: Optional[list[int]] = None,
142
+ force_duplicate_buffer: bool = False,
78
143
  ) -> int:
79
144
  """Add a new constant tensor to the model.
80
145
 
@@ -84,20 +149,21 @@ def add_new_constant_tensor(
84
149
  tensor_type: The type of the new tensor.
85
150
  subgraph: The subgraph where the new tensor is added.
86
151
  buffers: The buffers of the model.
152
+ tensor_shape: The shape of the new tensor. If not provided, the shape of the
153
+ data will be used.
154
+ force_duplicate_buffer: Whether to add a new buffer even if the same buffer
155
+ already exists.
87
156
 
88
157
  Returns:
89
158
  The index of the new tensor in the subgraph.
90
159
  """
91
- tensor_buffer = schema_py_generated.BufferT()
92
- tensor_buffer.data = np.frombuffer(data.tobytes(), dtype=np.uint8).flatten()
93
- tensor_buffer.offset = 0
94
- tensor_buffer.size = 0
95
- tensor_buffer_id = len(buffers)
96
- buffers.append(tensor_buffer)
160
+ new_buffer_id = get_constant_buffer(data, buffers, force_duplicate_buffer)
97
161
 
98
162
  new_tensor = schema_py_generated.TensorT()
99
- new_tensor.shape = data.shape
100
- new_tensor.buffer = tensor_buffer_id
163
+ if tensor_shape is None:
164
+ tensor_shape = data.shape
165
+ new_tensor.shape = tensor_shape
166
+ new_tensor.buffer = new_buffer_id
101
167
  new_tensor.type = tensor_type
102
168
  new_tensor.name = tensor_name
103
169
  new_tensor_id = len(subgraph.tensors)
@@ -123,10 +189,90 @@ def add_new_activation_tensor(
123
189
  The index of the new tensor in the subgraph.
124
190
  """
125
191
  new_tensor = schema_py_generated.TensorT()
126
- new_tensor.shape = shape
192
+ # If there's a dynamic shape, we need to read from the shapeSignature field
193
+ # instead of shape. Shape should contain just 1 for the dynamic dimension but
194
+ # shapeSignature should contain the true shape.
195
+ if -1 in shape:
196
+ new_tensor.shapeSignature = shape
197
+ new_tensor.shape = [1 if i == -1 else i for i in shape]
198
+ else:
199
+ new_tensor.shape = shape
127
200
  new_tensor.type = tensor_type
128
201
  new_tensor.name = tensor_name
129
202
  new_tensor.buffer = 0
130
203
  new_tensor_id = len(subgraph.tensors)
131
204
  subgraph.tensors.append(new_tensor)
132
205
  return new_tensor_id
206
+
207
+
208
+ def raise_deprecated_error(_: TransformationInput):
209
+ raise NotImplementedError(
210
+ 'This transformation is deprecated. Please contact AI Edge Quantizer team'
211
+ ' if you see this error.'
212
+ )
213
+
214
+
215
+ def pack_data(bitwidth: int, flattened_data: np.ndarray) -> np.ndarray:
216
+ """Pack the data to the corresponding bit width.
217
+
218
+ Currently only support 4 bits. If no packing is needed, the original data is
219
+ returned.
220
+
221
+ Args:
222
+ bitwidth: Bit width from NonLinearQuantParams.
223
+ flattened_data: The data to be packed.
224
+
225
+ Returns:
226
+ Packed data.
227
+ """
228
+ if bitwidth == 4:
229
+ even_data = flattened_data[::2] & 0x0F
230
+ odd_data = np.left_shift(flattened_data[1::2], 4).astype(np.uint8)
231
+ if odd_data.shape[0] == even_data.shape[0] - 1:
232
+ odd_data = np.pad(odd_data, (0, 1), constant_values=0)
233
+ return np.bitwise_or(even_data, odd_data)
234
+ else:
235
+ return flattened_data
236
+
237
+
238
+ def get_producer_schema_op_id(
239
+ transformation: TransformationInput,
240
+ ) -> int:
241
+ """Checks if the tensor's producer matches the given op.
242
+
243
+ Args:
244
+ transformation: The transformation input to check the producer of.
245
+
246
+ Returns:
247
+ The schema op id of the producer op. E.g.
248
+ schema_py_generated.BuiltinOperator.FULLY_CONNECTED.
249
+ """
250
+ if transformation.producer == -1:
251
+ return False
252
+ else:
253
+ return (
254
+ transformation.op_codes[
255
+ transformation.subgraph.operators[
256
+ transformation.producer
257
+ ].opcodeIndex
258
+ ].builtinCode
259
+ )
260
+
261
+
262
+ def get_schema_op_id(
263
+ transformation: TransformationInput, op_id: int
264
+ ) -> bool:
265
+ """Returns the schema op id of the given op.
266
+
267
+ Args:
268
+ transformation: The transformation input to check the consumers of.
269
+ op_id: The op id in the list of operators to check for.
270
+
271
+ Returns:
272
+ The schema op id of the given op.
273
+ """
274
+ return (
275
+ transformation.op_codes[
276
+ transformation.subgraph.operators[op_id].opcodeIndex
277
+ ].builtinCode
278
+ )
@@ -41,19 +41,94 @@ class TransformationUtilsTest(parameterized.TestCase):
41
41
  testcase_name="add_new_op_code",
42
42
  op_code=schema_py_generated.BuiltinOperator.LOGISTIC,
43
43
  expected=1,
44
+ custom_op_name=None,
44
45
  ),
45
46
  dict(
46
47
  testcase_name="add_existing_op_code",
47
48
  op_code=schema_py_generated.BuiltinOperator.FULLY_CONNECTED,
48
49
  expected=0,
50
+ custom_op_name=None,
51
+ ),
52
+ dict(
53
+ testcase_name="add_new_custom_op_code",
54
+ op_code=schema_py_generated.BuiltinOperator.CUSTOM,
55
+ expected=1,
56
+ custom_op_name="random_new_custom_op",
49
57
  ),
50
58
  )
51
- def test_add_op_code(self, op_code, expected):
59
+ def test_add_op_code(self, op_code, expected, custom_op_name):
52
60
  """Tests if the op code is added to the model."""
53
61
  got = transformation_utils.add_op_code(
54
- op_code=op_code, model_op_codes=self.model.operatorCodes
62
+ op_code=op_code,
63
+ model_op_codes=self.model.operatorCodes,
64
+ custom_op_name=custom_op_name,
55
65
  )
56
66
  self.assertEqual(expected, got)
67
+ if custom_op_name is not None:
68
+ self.assertEqual(self.model.operatorCodes[got].customCode, custom_op_name)
69
+
70
+ def test_add_custom_op_code_without_op_string_raises_error(self):
71
+ with self.assertRaisesRegex(ValueError, "Custom string is required"):
72
+ transformation_utils.add_op_code(
73
+ op_code=schema_py_generated.BuiltinOperator.CUSTOM,
74
+ model_op_codes=self.model.operatorCodes,
75
+ custom_op_name=None,
76
+ )
77
+
78
+ def test_add_two_custom_op_codes(self):
79
+ custom_op_name = "random_new_custom_op"
80
+ added_index = transformation_utils.add_op_code(
81
+ op_code=schema_py_generated.BuiltinOperator.CUSTOM,
82
+ model_op_codes=self.model.operatorCodes,
83
+ custom_op_name=custom_op_name,
84
+ )
85
+ self.assertEqual(1, added_index)
86
+ self.assertEqual(
87
+ self.model.operatorCodes[added_index].customCode, custom_op_name
88
+ )
89
+
90
+ custom_op_name_2 = "random_new_custom_op_2"
91
+ added_index = transformation_utils.add_op_code(
92
+ op_code=schema_py_generated.BuiltinOperator.CUSTOM,
93
+ model_op_codes=self.model.operatorCodes,
94
+ custom_op_name=custom_op_name_2,
95
+ )
96
+ self.assertEqual(2, added_index)
97
+ self.assertEqual(
98
+ self.model.operatorCodes[added_index].customCode, custom_op_name_2
99
+ )
100
+
101
+ @parameterized.named_parameters(
102
+ dict(
103
+ testcase_name="float32",
104
+ data=np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
105
+ ),
106
+ dict(
107
+ testcase_name="int8",
108
+ data=np.array([[1, 2], [3, 4]], dtype=np.int8),
109
+ ),
110
+ )
111
+ def test_add_new_constant_buffer(self, data):
112
+ """Tests if the constant buffer is added to the model."""
113
+ prev_num_buffers = len(self.model.buffers) - 1
114
+ new_buffer_idx = transformation_utils.get_constant_buffer(
115
+ data=data,
116
+ buffers=self.model.buffers,
117
+ )
118
+ self.assertEqual(new_buffer_idx, prev_num_buffers + 1)
119
+
120
+ expected_buffer_data = (
121
+ np.frombuffer(
122
+ data.tobytes(),
123
+ dtype=np.uint8,
124
+ )
125
+ .flatten()
126
+ .tolist()
127
+ )
128
+ self.assertEqual(
129
+ self.model.buffers[new_buffer_idx].data.tolist(),
130
+ expected_buffer_data,
131
+ )
57
132
 
58
133
  @parameterized.named_parameters(
59
134
  dict(
@@ -157,6 +232,25 @@ class TransformationUtilsTest(parameterized.TestCase):
157
232
  self.model.subgraphs[0].tensors[-1].shape,
158
233
  )
159
234
 
235
+ def test_add_new_activation_tensor_with_dynamic_shape(self):
236
+ """Tests adding an activation tensor with dynamic shape."""
237
+ subgraph = self.model.subgraphs[0]
238
+ new_id = transformation_utils.add_new_activation_tensor(
239
+ tensor_name="test_tensor",
240
+ shape=[1, -1, -1, 1],
241
+ tensor_type=schema_py_generated.TensorType.FLOAT32,
242
+ subgraph=subgraph,
243
+ )
244
+ # Originally had 4 tensors, new tensor is added at index 4.
245
+ self.assertEqual(new_id, 4)
246
+ self.assertLen(subgraph.tensors, 5)
247
+ self.assertEqual(subgraph.tensors[-1].name, "test_tensor")
248
+ self.assertEqual(
249
+ subgraph.tensors[-1].type, schema_py_generated.TensorType.FLOAT32
250
+ )
251
+ self.assertEqual(subgraph.tensors[-1].shape, [1, 1, 1, 1])
252
+ self.assertEqual(subgraph.tensors[-1].shapeSignature, [1, -1, -1, 1])
253
+
160
254
 
161
255
  if __name__ == "__main__":
162
256
  googletest.main()