ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. ai_edge_quantizer/algorithm_manager.py +158 -0
  2. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  3. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  5. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
  6. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  7. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  8. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
  12. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
  13. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
  14. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
  15. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
  16. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
  17. ai_edge_quantizer/calibrator.py +11 -60
  18. ai_edge_quantizer/calibrator_test.py +4 -73
  19. ai_edge_quantizer/default_policy.py +61 -26
  20. ai_edge_quantizer/model_modifier.py +97 -7
  21. ai_edge_quantizer/model_modifier_test.py +81 -1
  22. ai_edge_quantizer/model_validator.py +31 -8
  23. ai_edge_quantizer/params_generator.py +17 -10
  24. ai_edge_quantizer/params_generator_test.py +2 -7
  25. ai_edge_quantizer/qtyping.py +86 -6
  26. ai_edge_quantizer/quantizer.py +166 -21
  27. ai_edge_quantizer/quantizer_test.py +284 -16
  28. ai_edge_quantizer/recipe.py +154 -42
  29. ai_edge_quantizer/recipe_manager.py +158 -1
  30. ai_edge_quantizer/recipe_manager_test.py +146 -32
  31. ai_edge_quantizer/recipe_test.py +93 -17
  32. ai_edge_quantizer/transformation_instruction_generator.py +118 -13
  33. ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
  34. ai_edge_quantizer/transformation_performer.py +55 -25
  35. ai_edge_quantizer/transformation_performer_test.py +127 -5
  36. ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
  37. ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
  38. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  39. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  40. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  41. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  42. ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
  43. ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
  44. ai_edge_quantizer/transformations/transformation_utils.py +129 -6
  45. ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
  46. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  47. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  48. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  49. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  50. ai_edge_quantizer/utils/test_utils.py +75 -2
  51. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
  52. ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
  53. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  54. ai_edge_quantizer/utils/validation_utils.py +114 -4
  55. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  56. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
  57. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  58. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  59. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  60. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  61. ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
  62. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  63. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
@@ -170,7 +170,7 @@ class QuantizeTensorTest(parameterized.TestCase):
170
170
  # Check if the scale and zero point tensors are inserted correctly.
171
171
  self.assertEqual(quant_param.details.scales, 9)
172
172
  # So far we don't have zero point in blockwise quantization.
173
- self.assertEqual(quant_param.details.zeroPoints, 0)
173
+ self.assertEqual(quant_param.details.zeroPoints, -1)
174
174
 
175
175
  def test_int4_constant_packed_correctly(self):
176
176
  subgraph = self._model.subgraphs[0]
@@ -15,6 +15,7 @@
15
15
 
16
16
  """Utility functions for graph transformations."""
17
17
 
18
+ import copy
18
19
  import dataclasses
19
20
  from typing import Optional, Union
20
21
 
@@ -51,39 +52,78 @@ class TransformationInput:
51
52
  def add_op_code(
52
53
  op_code: schema_py_generated.OperatorCodeT,
53
54
  model_op_codes: list[schema_py_generated.OperatorCodeT],
55
+ custom_op_name: Optional[str] = None,
54
56
  ) -> int:
55
57
  """Add an op code into a model if it's not present.
56
58
 
57
59
  Args:
58
60
  op_code: The op code to be added.
59
61
  model_op_codes: The op codes of the model.
62
+ custom_op_name: The custom string of the op code. If None, the op code will
63
+ be added as a builtin op code.
60
64
 
61
65
  Returns:
62
66
  The index of the op code in the model.
63
67
  """
68
+ if (
69
+ op_code == schema_py_generated.BuiltinOperator.CUSTOM
70
+ and custom_op_name is None
71
+ ):
72
+ raise ValueError('Custom string is required for custom op code.')
73
+
64
74
  for i, model_op_code in enumerate(model_op_codes):
75
+ # If the model already has the op code, just return the index.
65
76
  if model_op_code.builtinCode == op_code:
66
- return i
77
+ if custom_op_name is not None:
78
+ if model_op_code.customCode == custom_op_name:
79
+ return i
80
+ else:
81
+ # Built-in op
82
+ return i
83
+
67
84
  model_op_codes.append(schema_py_generated.OperatorCodeT())
68
85
  model_op_codes[-1].builtinCode = op_code
86
+ if custom_op_name is not None:
87
+ model_op_codes[-1].customCode = custom_op_name
69
88
  return len(model_op_codes) - 1
70
89
 
71
90
 
72
- def add_new_constant_buffer(
91
+ def get_constant_buffer(
73
92
  data: np.ndarray,
74
93
  buffers: list[schema_py_generated.BufferT],
94
+ force_duplicate_buffer: bool = False,
75
95
  ) -> int:
76
- """Add a new constant buffer to the model.
96
+ """Get the index of the constant buffer that contains the given data.
97
+
98
+ A new buffer is created if the provided data is not found in the buffers list.
77
99
 
78
100
  Args:
79
101
  data: The data of the new tensor.
80
102
  buffers: The buffers of the model.
103
+ force_duplicate_buffer: Whether to add a new buffer even if the same buffer
104
+ already exists.
81
105
 
82
106
  Returns:
83
107
  The index of the new buffer in the model.
84
108
  """
109
+
110
+ if isinstance(data, np.ndarray):
111
+ # in the case where the data is passed from quantization_params.
112
+ new_data = np.frombuffer(data.tobytes(), dtype=np.uint8).flatten()
113
+ elif isinstance(data, bytes):
114
+ # in the case where the data is coming from duplicating buffers, we need to
115
+ # make a copy of the data to avoid having two buffers pointing to the same
116
+ # data.
117
+ new_data = copy.deepcopy(data)
118
+ else:
119
+ raise ValueError('data passed in must be either np.ndarray or bytes.')
120
+ # TODO: b/417811116 - we should make this more efficient.
121
+ if not force_duplicate_buffer:
122
+ for index, buffer in enumerate(buffers):
123
+ if np.array_equal(buffer.data, new_data):
124
+ return index
85
125
  new_buffer = schema_py_generated.BufferT()
86
- new_buffer.data = np.frombuffer(data.tobytes(), dtype=np.uint8).flatten()
126
+ new_buffer.data = new_data
87
127
  new_buffer.offset = 0
88
128
  new_buffer.size = 0
89
129
  new_buffer_id = len(buffers)
@@ -99,6 +139,7 @@ def add_new_constant_tensor(
99
139
  subgraph: schema_py_generated.SubGraphT,
100
140
  buffers: list[schema_py_generated.BufferT],
101
141
  tensor_shape: Optional[list[int]] = None,
142
+ force_duplicate_buffer: bool = False,
102
143
  ) -> int:
103
144
  """Add a new constant tensor to the model.
104
145
 
@@ -110,11 +151,13 @@ def add_new_constant_tensor(
110
151
  buffers: The buffers of the model.
111
152
  tensor_shape: The shape of the new tensor. If not provided, the shape of the
112
153
  data will be used.
154
+ force_duplicate_buffer: Whether to add a new buffer even if the same buffer
155
+ already exists.
113
156
 
114
157
  Returns:
115
158
  The index of the new tensor in the subgraph.
116
159
  """
117
- new_buffer_id = add_new_constant_buffer(data, buffers)
160
+ new_buffer_id = get_constant_buffer(data, buffers, force_duplicate_buffer)
118
161
 
119
162
  new_tensor = schema_py_generated.TensorT()
120
163
  if tensor_shape is None:
@@ -146,10 +189,90 @@ def add_new_activation_tensor(
146
189
  The index of the new tensor in the subgraph.
147
190
  """
148
191
  new_tensor = schema_py_generated.TensorT()
149
- new_tensor.shape = shape
192
+ # If there's a dynamic shape, we need to read from the shapeSignature field
193
+ # instead of shape. Shape should contain just 1 for the dynamic dimension but
194
+ # shapeSignature should contain the true shape.
195
+ if -1 in shape:
196
+ new_tensor.shapeSignature = shape
197
+ new_tensor.shape = [1 if i == -1 else i for i in shape]
198
+ else:
199
+ new_tensor.shape = shape
150
200
  new_tensor.type = tensor_type
151
201
  new_tensor.name = tensor_name
152
202
  new_tensor.buffer = 0
153
203
  new_tensor_id = len(subgraph.tensors)
154
204
  subgraph.tensors.append(new_tensor)
155
205
  return new_tensor_id
206
+
207
+
208
+ def raise_deprecated_error(_: TransformationInput):
209
+ raise NotImplementedError(
210
+ 'This transformation is deprecated. Please contact AI Edge Quantizer team'
211
+ ' if you see this error.'
212
+ )
213
+
214
+
215
+ def pack_data(bitwidth: int, flattened_data: np.ndarray) -> np.ndarray:
216
+ """Pack the data to the corresponding bit width.
217
+
218
+ Currently only support 4 bits. If no packing is needed, the original data is
219
+ returned.
220
+
221
+ Args:
222
+ bitwidth: Bit width from NonLinearQuantParams.
223
+ flattened_data: The data to be packed.
224
+
225
+ Returns:
226
+ Packed data.
227
+ """
228
+ if bitwidth == 4:
229
+ even_data = flattened_data[::2] & 0x0F
230
+ odd_data = np.left_shift(flattened_data[1::2], 4).astype(np.uint8)
231
+ if odd_data.shape[0] == even_data.shape[0] - 1:
232
+ odd_data = np.pad(odd_data, (0, 1), constant_values=0)
233
+ return np.bitwise_or(even_data, odd_data)
234
+ else:
235
+ return flattened_data
236
+
237
+
238
+ def get_producer_schema_op_id(
239
+ transformation: TransformationInput,
240
+ ) -> int:
241
+ """Returns the schema op id of the tensor's producer op.
242
+
243
+ Args:
244
+ transformation: The transformation input to check the producer of.
245
+
246
+ Returns:
247
+ The schema op id of the producer op (or False if the tensor has no producer). E.g.
248
+ schema_py_generated.BuiltinOperator.FULLY_CONNECTED.
249
+ """
250
+ if transformation.producer == -1:
251
+ return False
252
+ else:
253
+ return (
254
+ transformation.op_codes[
255
+ transformation.subgraph.operators[
256
+ transformation.producer
257
+ ].opcodeIndex
258
+ ].builtinCode
259
+ )
260
+
261
+
262
+ def get_schema_op_id(
263
+ transformation: TransformationInput, op_id: int
264
+ ) -> bool:
265
+ """Returns the schema op id of the given op.
266
+
267
+ Args:
268
+ transformation: The transformation input to check the consumers of.
269
+ op_id: The op id in the list of operators to check for.
270
+
271
+ Returns:
272
+ The schema op id of the given op.
273
+ """
274
+ return (
275
+ transformation.op_codes[
276
+ transformation.subgraph.operators[op_id].opcodeIndex
277
+ ].builtinCode
278
+ )
@@ -41,19 +41,62 @@ class TransformationUtilsTest(parameterized.TestCase):
41
41
  testcase_name="add_new_op_code",
42
42
  op_code=schema_py_generated.BuiltinOperator.LOGISTIC,
43
43
  expected=1,
44
+ custom_op_name=None,
44
45
  ),
45
46
  dict(
46
47
  testcase_name="add_existing_op_code",
47
48
  op_code=schema_py_generated.BuiltinOperator.FULLY_CONNECTED,
48
49
  expected=0,
50
+ custom_op_name=None,
51
+ ),
52
+ dict(
53
+ testcase_name="add_new_custom_op_code",
54
+ op_code=schema_py_generated.BuiltinOperator.CUSTOM,
55
+ expected=1,
56
+ custom_op_name="random_new_custom_op",
49
57
  ),
50
58
  )
51
- def test_add_op_code(self, op_code, expected):
59
+ def test_add_op_code(self, op_code, expected, custom_op_name):
52
60
  """Tests if the op code is added to the model."""
53
61
  got = transformation_utils.add_op_code(
54
- op_code=op_code, model_op_codes=self.model.operatorCodes
62
+ op_code=op_code,
63
+ model_op_codes=self.model.operatorCodes,
64
+ custom_op_name=custom_op_name,
55
65
  )
56
66
  self.assertEqual(expected, got)
67
+ if custom_op_name is not None:
68
+ self.assertEqual(self.model.operatorCodes[got].customCode, custom_op_name)
69
+
70
+ def test_add_custom_op_code_without_op_string_raises_error(self):
71
+ with self.assertRaisesRegex(ValueError, "Custom string is required"):
72
+ transformation_utils.add_op_code(
73
+ op_code=schema_py_generated.BuiltinOperator.CUSTOM,
74
+ model_op_codes=self.model.operatorCodes,
75
+ custom_op_name=None,
76
+ )
77
+
78
+ def test_add_two_custom_op_codes(self):
79
+ custom_op_name = "random_new_custom_op"
80
+ added_index = transformation_utils.add_op_code(
81
+ op_code=schema_py_generated.BuiltinOperator.CUSTOM,
82
+ model_op_codes=self.model.operatorCodes,
83
+ custom_op_name=custom_op_name,
84
+ )
85
+ self.assertEqual(1, added_index)
86
+ self.assertEqual(
87
+ self.model.operatorCodes[added_index].customCode, custom_op_name
88
+ )
89
+
90
+ custom_op_name_2 = "random_new_custom_op_2"
91
+ added_index = transformation_utils.add_op_code(
92
+ op_code=schema_py_generated.BuiltinOperator.CUSTOM,
93
+ model_op_codes=self.model.operatorCodes,
94
+ custom_op_name=custom_op_name_2,
95
+ )
96
+ self.assertEqual(2, added_index)
97
+ self.assertEqual(
98
+ self.model.operatorCodes[added_index].customCode, custom_op_name_2
99
+ )
57
100
 
58
101
  @parameterized.named_parameters(
59
102
  dict(
@@ -68,7 +111,7 @@ class TransformationUtilsTest(parameterized.TestCase):
68
111
  def test_add_new_constant_buffer(self, data):
69
112
  """Tests if the constant buffer is added to the model."""
70
113
  prev_num_buffers = len(self.model.buffers) - 1
71
- new_buffer_idx = transformation_utils.add_new_constant_buffer(
114
+ new_buffer_idx = transformation_utils.get_constant_buffer(
72
115
  data=data,
73
116
  buffers=self.model.buffers,
74
117
  )
@@ -189,6 +232,25 @@ class TransformationUtilsTest(parameterized.TestCase):
189
232
  self.model.subgraphs[0].tensors[-1].shape,
190
233
  )
191
234
 
235
+ def test_add_new_activation_tensor_with_dynamic_shape(self):
236
+ """Tests adding an activation tensor with dynamic shape."""
237
+ subgraph = self.model.subgraphs[0]
238
+ new_id = transformation_utils.add_new_activation_tensor(
239
+ tensor_name="test_tensor",
240
+ shape=[1, -1, -1, 1],
241
+ tensor_type=schema_py_generated.TensorType.FLOAT32,
242
+ subgraph=subgraph,
243
+ )
244
+ # Originally had 4 tensors, new tensor is added at index 4.
245
+ self.assertEqual(new_id, 4)
246
+ self.assertLen(subgraph.tensors, 5)
247
+ self.assertEqual(subgraph.tensors[-1].name, "test_tensor")
248
+ self.assertEqual(
249
+ subgraph.tensors[-1].type, schema_py_generated.TensorType.FLOAT32
250
+ )
251
+ self.assertEqual(subgraph.tensors[-1].shape, [1, 1, 1, 1])
252
+ self.assertEqual(subgraph.tensors[-1].shapeSignature, [1, -1, -1, 1])
253
+
192
254
 
193
255
  if __name__ == "__main__":
194
256
  googletest.main()
@@ -15,9 +15,24 @@
15
15
 
16
16
  """Utilities for model calibration."""
17
17
 
18
- from typing import Union
18
+ import copy
19
+ from typing import Any, Union
20
+
19
21
  import numpy as np
22
+
23
+ from ai_edge_litert.tools import flatbuffer_utils
20
24
  from ai_edge_quantizer import qtyping
25
+ from ai_edge_quantizer.algorithms.utils import common_utils
26
+ from ai_edge_quantizer.utils import constrained_ops_utils
27
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
28
+ from ai_edge_quantizer.utils import tfl_interpreter_utils
29
+
30
+
31
+ _SignatureInput = dict[str, Any]
32
+ _OpQuantConstraint = common_utils.OpQuantConstraint
33
+ _SignatureData = dict[
34
+ str, list[str]
35
+ ] # signature_key -> list of signature_names.
21
36
 
22
37
 
23
38
  def _update_moving_average(
@@ -84,3 +99,250 @@ def min_max_update(qsv: qtyping.QSV, new_qsv: qtyping.QSV) -> qtyping.QSV:
84
99
  updated_qsv["min"] = np.minimum(qsv["min"], new_qsv["min"])
85
100
  updated_qsv["max"] = np.maximum(qsv["max"], new_qsv["max"])
86
101
  return updated_qsv
102
+
103
+
104
+ def _find_overall_min_max(
105
+ qsv: qtyping.QSV, tensor_names: list[str]
106
+ ) -> tuple[np.ndarray, np.ndarray]:
107
+ """Finds the overall minimum and maximum values for the given tensors.
108
+
109
+ Args:
110
+ qsv: Mapping of tensor name to its quantization statistics (min/max).
111
+ tensor_names: The list of tensor names to find the minimum and maximum
112
+ values.
113
+
114
+ Returns:
115
+ The minimum and maximum values for the given tensors.
116
+ """
117
+ min_value = np.inf
118
+ max_value = -np.inf
119
+ for tensor_name in tensor_names:
120
+ min_value = min(min_value, qsv[tensor_name]["min"])
121
+ max_value = max(max_value, qsv[tensor_name]["max"])
122
+ return min_value, max_value
123
+
124
+
125
+ class CalibrationQsvAlignmentUtils:
126
+ """Calibration utils for alignment of QSVs.
127
+
128
+ This class is used to align QSVs for a given model. It builds a list of ops
129
+ that need to be constrained to the same scale as the input. Based on this
130
+ list, it finds the corresponding tensor names for a given signature data.
131
+ """
132
+
133
+ def __init__(self, model_path: str):
134
+ self._same_as_input_scale_ops = (
135
+ constrained_ops_utils.get_constrained_op_list(
136
+ _OpQuantConstraint.SAME_AS_INPUT_SCALE
137
+ )
138
+ )
139
+
140
+ tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(model_path)
141
+ self._flatbuffer_object = tfl_flatbuffer_utils.read_model(model_path)
142
+
143
+ signature_keys = list(tfl_interpreter.get_signature_list().keys())
144
+
145
+ # Build a dict of signature runners.
146
+ self._signature_runners = {}
147
+ for signature_key in signature_keys:
148
+ signature_runner = tfl_interpreter.get_signature_runner(signature_key)
149
+ self._signature_runners[signature_key] = signature_runner
150
+
151
+ def _search_tensor_by_signature_name(
152
+ self, signature_key: str, signature_input_output_name: str, verbose=False
153
+ ) -> list[str]:
154
+ """Searches for a tensor name for a given signature by signature input or output name.
155
+
156
+ Args:
157
+ signature_key: Name of the signature.
158
+ signature_input_output_name: Name of the input or output in the signature.
159
+ verbose: Flag to enable verbose output.
160
+
161
+ Returns:
162
+ The list with one or two tensor names. The first one is the input tensor
163
+ name, and the second one is the output tensor name.
164
+ """
165
+
166
+ if verbose:
167
+ print("Searching tensor by signature name.")
168
+
169
+ tensor_names = []
170
+
171
+ # Search among inputs.
172
+ input_details = self._signature_runners[signature_key].get_input_details()
173
+ if signature_input_output_name in input_details.keys():
174
+ tensor_names.append(input_details[signature_input_output_name]["name"])
175
+
176
+ # Search among outputs.
177
+ output_details = self._signature_runners[signature_key].get_output_details()
178
+ if signature_input_output_name not in output_details:
179
+ if not tensor_names:
180
+ raise ValueError(
181
+ f"Signature {signature_key} does not have input or output"
182
+ f" `{signature_input_output_name}`"
183
+ )
184
+ return tensor_names
185
+
186
+ output_tensor_name = output_details[signature_input_output_name]["name"]
187
+ if verbose:
188
+ print(
189
+ ">> Starting recursive search for the output tensor name:"
190
+ f" {output_tensor_name}"
191
+ )
192
+
193
+ idx = self._signature_runners[signature_key]._subgraph_index # pylint: disable=protected-access
194
+ subgraph = self._flatbuffer_object.subgraphs[idx]
195
+ graph_info = qtyping.GraphInfo(
196
+ subgraph.tensors, self._flatbuffer_object.buffers
197
+ )
198
+
199
+ # Recursively search the graph for the output tensor name since it may be
200
+ # `SAME_AS_INPUT` constrained.
201
+ operators = copy.deepcopy(subgraph.operators)
202
+ tensor_name = self._search_reverse_order_recursively(
203
+ graph_info, operators, output_tensor_name, indent=" ", verbose=verbose
204
+ )
205
+ tensor_names.append(tensor_name)
206
+
207
+ if verbose:
208
+ print(f"\n\nFound tensor name: {tensor_name}")
209
+
210
+ return tensor_names
211
+
212
+ def _search_reverse_order_recursively(
213
+ self,
214
+ graph_info: qtyping.GraphInfo,
215
+ operators: list[Any],
216
+ output_tensor_name: str,
217
+ indent: str,
218
+ verbose: bool = False,
219
+ ):
220
+ """Searches for a tensor name in reverse order recursively.
221
+
222
+ Stop criteria is when the tensor belongs to an operator that is not
223
+ `SAME_AS_INPUT` constrained.
224
+
225
+ Args:
226
+ graph_info: Graph information.
227
+ operators: List of operators.
228
+ output_tensor_name: Name of the output tensor to search for.
229
+ indent: Indentation string for debug output.
230
+ verbose: Flag to enable verbose output.
231
+
232
+ Returns:
233
+ The name of the tensor found (the original output tensor name if no producer matches).
234
+ """
235
+ op_codes = self._flatbuffer_object.operatorCodes
236
+
237
+ while operators:
238
+ op = operators.pop()
239
+ op_code = op_codes[op.opcodeIndex].builtinCode
240
+ op_name = flatbuffer_utils.opcode_to_name(
241
+ self._flatbuffer_object, op.opcodeIndex
242
+ )
243
+ if op_code not in tfl_flatbuffer_utils.TFL_OP_CODE_TO_NAME:
244
+ continue
245
+ for output_idx in op.outputs:
246
+ if output_tensor_name == tfl_flatbuffer_utils.get_tensor_name(
247
+ graph_info.subgraph_tensors[output_idx]
248
+ ):
249
+ dbg_str = (
250
+ f"{indent}>> Found `{op_name}`, output tensor"
251
+ f" '{output_tensor_name}'"
252
+ )
253
+
254
+ if op_name not in self._same_as_input_scale_ops:
255
+ if verbose:
256
+ print(f"{dbg_str}, returning...")
257
+ return output_tensor_name
258
+
259
+ if verbose:
260
+ print(f"{dbg_str}, with SAME_AS_INPUT, search recursively among:")
261
+ for input_idx in op.inputs:
262
+ input_tensor_name = graph_info.subgraph_tensors[
263
+ input_idx
264
+ ].name.decode("utf-8")
265
+
266
+ if verbose:
267
+ print(f"{indent} Input: {input_tensor_name}")
268
+
269
+ return self._search_reverse_order_recursively(
270
+ graph_info,
271
+ operators,
272
+ input_tensor_name,
273
+ indent=f"{indent} ",
274
+ verbose=verbose,
275
+ )
276
+ return output_tensor_name
277
+
278
+ def align_quant_stats(
279
+ self, qsv: dict[str, Any], signature_data: _SignatureData
280
+ ) -> tuple[np.ndarray, np.ndarray]:
281
+ """Aligns quantization statistics for a given signature data.
282
+
283
+ This function takes quantization statistics and signature data as input,
284
+ identifies the tensors associated with the signature data, and aligns
285
+ the quantization statistics of these tensors by setting their minimum
286
+ and maximum values to the same value. This ensures that the tensors
287
+ have the same quantization parameters.
288
+
289
+ Args:
290
+ qsv: Quantization statistics.
291
+ signature_data: Signature data.
292
+
293
+ Returns:
294
+ Tuple of min and max values.
295
+ """
296
+ # Go over all signature info and find the corresponding tensor names.
297
+ tensor_names = []
298
+ for signature_key, signature_names in signature_data.items():
299
+ for signature_name in signature_names:
300
+ tensor_name = self._search_tensor_by_signature_name(
301
+ signature_key, signature_name
302
+ )
303
+ tensor_names.extend(tensor_name)
304
+
305
+ # Find min and max values across all tensors.
306
+ min_value, max_value = _find_overall_min_max(qsv, tensor_names)
307
+
308
+ # Overwrite the min and max values in the QSV.
309
+ for tensor_name in tensor_names:
310
+ qsv[tensor_name]["min"] = min_value
311
+ qsv[tensor_name]["max"] = max_value
312
+
313
+ return min_value, max_value
314
+
315
+ def update_quant_stats(
316
+ self,
317
+ qsv: dict[str, Any],
318
+ signature_data: _SignatureData,
319
+ min_value: np.ndarray,
320
+ max_value: np.ndarray,
321
+ ):
322
+ """Updates quantization statistics for a given signature data.
323
+
324
+ This function updates the quantization statistics with the provided min, max
325
+ values for the tensors specified in the signature data.
326
+
327
+ Args:
328
+ qsv: Quantization statistics.
329
+ signature_data: Signature data.
330
+ min_value: Minimum value to update.
331
+ max_value: Maximum value to update.
332
+
333
+ Returns:
334
+ None. Note that `qsv` is updated in place.
335
+ """
336
+ # Go over all signature info and find the corresponding tensor names.
337
+ tensor_names = []
338
+ for signature_key, signature_names in signature_data.items():
339
+ for signature_name in signature_names:
340
+ tensor_name = self._search_tensor_by_signature_name(
341
+ signature_key, signature_name
342
+ )
343
+ tensor_names.extend(tensor_name)
344
+
345
+ # Overwrite the min and max values in the QSV.
346
+ for tensor_name in tensor_names:
347
+ qsv[tensor_name]["min"] = min_value
348
+ qsv[tensor_name]["max"] = max_value