ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (69):
  1. ai_edge_quantizer/algorithm_manager.py +224 -0
  2. ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
  3. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
  5. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  6. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
  7. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
  8. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
  13. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
  14. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
  15. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
  16. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
  17. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
  18. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
  19. ai_edge_quantizer/calibrator.py +58 -94
  20. ai_edge_quantizer/calibrator_test.py +5 -74
  21. ai_edge_quantizer/default_policy.py +108 -16
  22. ai_edge_quantizer/model_modifier.py +132 -8
  23. ai_edge_quantizer/model_modifier_test.py +81 -1
  24. ai_edge_quantizer/model_validator.py +38 -10
  25. ai_edge_quantizer/model_validator_test.py +2 -1
  26. ai_edge_quantizer/params_generator.py +230 -47
  27. ai_edge_quantizer/params_generator_test.py +366 -261
  28. ai_edge_quantizer/qtyping.py +92 -6
  29. ai_edge_quantizer/quantizer.py +167 -23
  30. ai_edge_quantizer/quantizer_test.py +288 -26
  31. ai_edge_quantizer/recipe.py +156 -21
  32. ai_edge_quantizer/recipe_manager.py +158 -1
  33. ai_edge_quantizer/recipe_manager_test.py +146 -32
  34. ai_edge_quantizer/recipe_test.py +93 -17
  35. ai_edge_quantizer/transformation_instruction_generator.py +313 -46
  36. ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
  37. ai_edge_quantizer/transformation_performer.py +112 -58
  38. ai_edge_quantizer/transformation_performer_test.py +176 -4
  39. ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
  40. ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
  41. ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
  42. ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
  43. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  44. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  45. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  46. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  47. ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
  48. ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
  49. ai_edge_quantizer/transformations/transformation_utils.py +157 -11
  50. ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
  51. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  52. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  53. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  54. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  55. ai_edge_quantizer/utils/test_utils.py +191 -58
  56. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
  57. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
  58. ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
  59. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  60. ai_edge_quantizer/utils/validation_utils.py +114 -4
  61. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  62. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
  63. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  64. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  65. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  66. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  67. ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
  68. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  69. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
@@ -18,12 +18,55 @@
 import inspect as _inspect
 import os.path as _os_path
 import sys as _sys
-from typing import Any, Union
+from typing import Optional, Union

-import numpy as np
+from absl.testing import parameterized

+from ai_edge_quantizer import model_validator
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer import quantizer
 from ai_edge_quantizer.utils import tfl_interpreter_utils

+_ComputePrecision = qtyping.ComputePrecision
+_OpName = qtyping.TFLOperationName
+_TensorQuantConfig = qtyping.TensorQuantizationConfig
+_OpQuantConfig = qtyping.OpQuantizationConfig
+_AlgorithmName = quantizer.AlgorithmName
+_Numeric = Union[int, float]
+
+
+DEFAULT_ACTIVATION_QUANT_SETTING = _TensorQuantConfig(
+    num_bits=8,
+    symmetric=False,
+    granularity=qtyping.QuantGranularity.TENSORWISE,
+)
+DEFAULT_WEIGHT_QUANT_SETTING = _TensorQuantConfig(
+    num_bits=8,
+    symmetric=True,
+    granularity=qtyping.QuantGranularity.CHANNELWISE,
+)
+
+
+def get_static_activation_quant_setting(
+    num_bits: int, symmetric: bool
+) -> _TensorQuantConfig:
+  return _TensorQuantConfig(
+      num_bits=num_bits,
+      symmetric=symmetric,
+      granularity=qtyping.QuantGranularity.TENSORWISE,
+  )
+
+
+def get_static_op_quant_config(
+    activation_config: _TensorQuantConfig = DEFAULT_ACTIVATION_QUANT_SETTING,
+    weight_config: _TensorQuantConfig = DEFAULT_WEIGHT_QUANT_SETTING,
+) -> _OpQuantConfig:
+  return qtyping.OpQuantizationConfig(
+      activation_tensor_config=activation_config,
+      weight_tensor_config=weight_config,
+      compute_precision=_ComputePrecision.INTEGER,
+  )
+

 def get_path_to_datafile(path):
   """Get the path to the specified file in the data dependencies.
@@ -46,62 +89,152 @@ def get_path_to_datafile(path):
   return path


-def create_random_normal_dataset(
-    input_details: dict[str, Any],
-    num_samples: int,
-    random_seed: Union[int, np._typing.ArrayLike],
-) -> list[dict[str, Any]]:
-  """create random dataset following random distribution.
-
-  Args:
-    input_details: list of dict created by
-      tensorflow.lite.interpreter.get_input_details() for generating dataset
-    num_samples: number of input samples to be generated
-    random_seed: random seed to be used for function
-
-  Returns:
-    a list of inputs to the given interpreter, for a single interpreter we may
-    have multiple input tensors so each set of inputs is also represented as
-    list
-  """
-  rng = np.random.default_rng(random_seed)
-  dataset = []
-  for _ in range(num_samples):
-    input_data = {}
-    for arg_name, input_tensor in input_details.items():
-      new_data = rng.normal(size=input_tensor['shape']).astype(
-          input_tensor['dtype']
+class BaseOpTestCase(parameterized.TestCase):
+  """Base class for op-level tests."""
+
+  def quantize_and_validate(
+      self,
+      model_path: str,
+      algorithm_key: _AlgorithmName,
+      op_name: _OpName,
+      op_config: _OpQuantConfig,
+      num_validation_samples: int = 4,
+      num_calibration_samples: Optional[int] = None,
+      error_metric: str = 'mse',
+      min_max_range: Optional[tuple[_Numeric, _Numeric]] = None,
+  ) -> model_validator.ComparisonResult:
+    """Quantizes and validates the given model with the given configurations.
+
+    Args:
+      model_path: The path to the model to be quantized.
+      algorithm_key: The algorithm to be used for quantization.
+      op_name: The name of the operation to be quantized.
+      op_config: The configuration for the operation to be quantized.
+      num_validation_samples: The number of samples to use for validation.
+      num_calibration_samples: The number of samples to use for calibration. If
+        None then it will be set to num_validation_samples * 8.
+      error_metric: The error error_metric to use for validation.
+      min_max_range: The min and max of the input range.
+
+    Returns:
+      The comparison result of the validation.
+    """
+    quantizer_instance = quantizer.Quantizer(model_path)
+    quantizer_instance.update_quantization_recipe(
+        algorithm_key=algorithm_key,
+        regex='.*',
+        operation_name=op_name,
+        op_config=op_config,
+    )
+    if quantizer_instance.need_calibration:
+      if num_calibration_samples is None:
+        num_calibration_samples = num_validation_samples * 8
+      calibration_data = tfl_interpreter_utils.create_random_normal_input_data(
+          quantizer_instance.float_model,
+          num_samples=num_calibration_samples,
+          min_max_range=min_max_range,
+      )
+      calibration_result = quantizer_instance.calibrate(calibration_data)
+      quantization_result = quantizer_instance.quantize(calibration_result)
+    else:
+      quantization_result = quantizer_instance.quantize()
+    test_data = tfl_interpreter_utils.create_random_normal_input_data(
+        quantization_result.quantized_model,
+        num_samples=num_validation_samples,
+        min_max_range=min_max_range,
+    )
+    return quantizer_instance.validate(test_data, error_metric)
+
+  def assert_model_size_reduction_above_min_pct(
+      self,
+      validation_result: model_validator.ComparisonResult,
+      min_pct: float,
+  ):
+    """Checks the model size reduction (percentage) against the given expectation."""
+    _, reduction_pct = validation_result.get_model_size_reduction()
+    self.assertGreater(reduction_pct, min_pct)
+
+  def assert_weights_errors_below_tolerance(
+      self,
+      validation_result: model_validator.ComparisonResult,
+      weight_tolerance: float,
+  ):
+    """Checks the weight tensors' numerical behavior against the given tolerance."""
+    self.assertNotEmpty(validation_result.available_signature_keys())
+    for signature_key in validation_result.available_signature_keys():
+      signature_result = validation_result.get_signature_comparison_result(
+          signature_key
+      )
+      for result in signature_result.constant_tensors.values():
+        self.assertLess(result, weight_tolerance)
+
+  def assert_output_errors_below_tolerance(
+      self,
+      validation_result: model_validator.ComparisonResult,
+      output_tolerance: float,
+  ):
+    """Checks the output tensor numerical behavior against the given tolerance."""
+    self.assertNotEmpty(validation_result.available_signature_keys())
+    for signature_key in validation_result.available_signature_keys():
+      signature_result = validation_result.get_signature_comparison_result(
+          signature_key
+      )
+      for result in signature_result.output_tensors.values():
+        self.assertLess(result, output_tolerance)
+
+  def assert_quantization_accuracy_and_size(
+      self,
+      algorithm_key: _AlgorithmName,
+      model_path: str,
+      op_name: _OpName,
+      op_config: _OpQuantConfig,
+      expected_model_size_reduction: float,
+      weight_tolerance: float = 1e-4,
+      output_tolerance: float = 1e-4,
+      min_max_range: Optional[tuple[_Numeric, _Numeric]] = None,
+  ):
+    """Check if the quantization is successful and the result is valid."""
+    validation_result = self.quantize_and_validate(
+        model_path=model_path,
+        algorithm_key=algorithm_key,
+        op_name=op_name,
+        op_config=op_config,
+        min_max_range=min_max_range,
+    )
+    with self.subTest(name='ModelSizeReduction'):
+      self.assert_model_size_reduction_above_min_pct(
+          validation_result, expected_model_size_reduction
+      )
+    with self.subTest(name='WeightsErrors'):
+      self.assert_weights_errors_below_tolerance(
+          validation_result, weight_tolerance
+      )
+    with self.subTest(name='OutputErrors'):
+      self.assert_output_errors_below_tolerance(
+          validation_result, output_tolerance
       )
-      input_data[arg_name] = new_data
-    dataset.append(input_data)
-  return dataset
-
-
-def create_random_normal_input_data(
-    tflite_model: Union[str, bytes],
-    num_samples: int = 4,
-    random_seed: int = 666,
-) -> dict[str, list[dict[str, Any]]]:
-  """create random dataset following random distribution for signature runner.
-
-  Args:
-    tflite_model: TFLite model path or bytearray
-    num_samples: number of input samples to be generated
-    random_seed: random seed to be used for function

-  Returns:
-    a list of inputs to the given interpreter, for a single interpreter we may
-    have multiple signatures so each set of inputs is also represented as
-    list
-  """
-  tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(tflite_model)
-  signature_defs = tfl_interpreter.get_signature_list()
-  signature_keys = list(signature_defs.keys())
-  test_data = {}
-  for signature_key in signature_keys:
-    signature_runner = tfl_interpreter.get_signature_runner(signature_key)
-    input_details = signature_runner.get_input_details()
-    test_data[signature_key] = create_random_normal_dataset(
-        input_details, num_samples, random_seed
+  def assert_quantization_accuracy(
+      self,
+      algorithm_key: _AlgorithmName,
+      model_path: str,
+      op_name: _OpName,
+      op_config: _OpQuantConfig,
+      num_validation_samples: int = 4,
+      num_calibration_samples: Optional[int] = None,
+      output_tolerance: float = 1e-4,
+      min_max_range: Optional[tuple[_Numeric, _Numeric]] = None,
+  ):
+    """Checks if the output errors after quantization are within the tolerance."""
+    validation_result = self.quantize_and_validate(
+        model_path=model_path,
+        algorithm_key=algorithm_key,
+        num_validation_samples=num_validation_samples,
+        num_calibration_samples=num_calibration_samples,
+        op_name=op_name,
+        op_config=op_config,
+        min_max_range=min_max_range,
+    )
+    self.assert_output_errors_below_tolerance(
+        validation_result, output_tolerance
     )
-  return test_data
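For orientation, a minimal usage sketch (not part of this diff): the two hunks above appear to belong to ai_edge_quantizer/utils/test_utils.py, which swaps the old random-dataset helpers for a reusable BaseOpTestCase, and an op-level test could subclass it roughly as follows. The model filename, the MIN_MAX_UNIFORM_QUANT algorithm key, and the size-reduction threshold are assumptions for illustration, not values taken from the diff.

from absl.testing import absltest

from ai_edge_quantizer import qtyping
from ai_edge_quantizer import quantizer
from ai_edge_quantizer.utils import test_utils


class FullyConnectedInt8Test(test_utils.BaseOpTestCase):

  def test_fully_connected_int8_static_quantization(self):
    # Quantizes every FULLY_CONNECTED op with the default int8 static config,
    # then checks output error, weight error, and model size reduction.
    self.assert_quantization_accuracy_and_size(
        algorithm_key=quantizer.AlgorithmName.MIN_MAX_UNIFORM_QUANT,  # assumed key
        model_path=test_utils.get_path_to_datafile('single_fc.tflite'),  # placeholder model
        op_name=qtyping.TFLOperationName.FULLY_CONNECTED,
        op_config=test_utils.get_static_op_quant_config(),
        expected_model_size_reduction=50.0,  # illustrative percentage threshold
    )


if __name__ == '__main__':
  absltest.main()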
@@ -20,48 +20,66 @@ from typing import Any, Optional, Union
 import immutabledict
 import numpy as np

+from ai_edge_litert.tools import flatbuffer_utils
 from ai_edge_quantizer import qtyping
-from ai_edge_litert import schema_py_generated  # pylint:disable=g-direct-tensorflow-import
-from tensorflow.lite.tools import flatbuffer_utils  # pylint: disable=g-direct-tensorflow-import
-from tensorflow.python.platform import gfile  # pylint: disable=g-direct-tensorflow-import
+from ai_edge_litert import schema_py_generated as schema  # pylint:disable=g-direct-tensorflow-import
+import os  # tensorflow.python.platform.gfile  # pylint: disable=g-direct-tensorflow-import

 _TFLOpName = qtyping.TFLOperationName

 TFL_OP_NAME_TO_CODE = immutabledict.immutabledict({
-    _TFLOpName.FULLY_CONNECTED: (
-        schema_py_generated.BuiltinOperator.FULLY_CONNECTED
+    _TFLOpName.FULLY_CONNECTED: schema.BuiltinOperator.FULLY_CONNECTED,
+    _TFLOpName.BATCH_MATMUL: schema.BuiltinOperator.BATCH_MATMUL,
+    _TFLOpName.CONV_2D: schema.BuiltinOperator.CONV_2D,
+    _TFLOpName.DEPTHWISE_CONV_2D: schema.BuiltinOperator.DEPTHWISE_CONV_2D,
+    _TFLOpName.CONV_2D_TRANSPOSE: schema.BuiltinOperator.TRANSPOSE_CONV,
+    _TFLOpName.EMBEDDING_LOOKUP: schema.BuiltinOperator.EMBEDDING_LOOKUP,
+    _TFLOpName.SOFTMAX: schema.BuiltinOperator.SOFTMAX,
+    _TFLOpName.AVERAGE_POOL_2D: schema.BuiltinOperator.AVERAGE_POOL_2D,
+    _TFLOpName.RESHAPE: schema.BuiltinOperator.RESHAPE,
+    _TFLOpName.TANH: schema.BuiltinOperator.TANH,
+    _TFLOpName.TRANSPOSE: schema.BuiltinOperator.TRANSPOSE,
+    _TFLOpName.GELU: schema.BuiltinOperator.GELU,
+    _TFLOpName.ADD: schema.BuiltinOperator.ADD,
+    _TFLOpName.SUB: schema.BuiltinOperator.SUB,
+    _TFLOpName.MUL: schema.BuiltinOperator.MUL,
+    _TFLOpName.MEAN: schema.BuiltinOperator.MEAN,
+    _TFLOpName.RSQRT: schema.BuiltinOperator.RSQRT,
+    _TFLOpName.CONCATENATION: schema.BuiltinOperator.CONCATENATION,
+    _TFLOpName.STRIDED_SLICE: schema.BuiltinOperator.STRIDED_SLICE,
+    _TFLOpName.SPLIT: schema.BuiltinOperator.SPLIT,
+    _TFLOpName.LOGISTIC: schema.BuiltinOperator.LOGISTIC,
+    _TFLOpName.SLICE: schema.BuiltinOperator.SLICE,
+    _TFLOpName.SUM: schema.BuiltinOperator.SUM,
+    _TFLOpName.SELECT: schema.BuiltinOperator.SELECT,
+    _TFLOpName.SELECT_V2: schema.BuiltinOperator.SELECT_V2,
+    _TFLOpName.STABLEHLO_COMPOSITE: schema.BuiltinOperator.STABLEHLO_COMPOSITE,
+    _TFLOpName.DYNAMIC_UPDATE_SLICE: (
+        schema.BuiltinOperator.DYNAMIC_UPDATE_SLICE
     ),
-    _TFLOpName.BATCH_MATMUL: schema_py_generated.BuiltinOperator.BATCH_MATMUL,
-    _TFLOpName.CONV_2D: schema_py_generated.BuiltinOperator.CONV_2D,
-    _TFLOpName.DEPTHWISE_CONV_2D: (
-        schema_py_generated.BuiltinOperator.DEPTHWISE_CONV_2D
+    _TFLOpName.PAD: schema.BuiltinOperator.PAD,
+    _TFLOpName.SQUARED_DIFFERENCE: schema.BuiltinOperator.SQUARED_DIFFERENCE,
+    _TFLOpName.MAX_POOL_2D: schema.BuiltinOperator.MAX_POOL_2D,
+    _TFLOpName.RESIZE_BILINEAR: schema.BuiltinOperator.RESIZE_BILINEAR,
+    _TFLOpName.RESIZE_NEAREST_NEIGHBOR: (
+        schema.BuiltinOperator.RESIZE_NEAREST_NEIGHBOR
     ),
-    _TFLOpName.CONV_2D_TRANSPOSE: (
-        schema_py_generated.BuiltinOperator.TRANSPOSE_CONV
-    ),
-    _TFLOpName.EMBEDDING_LOOKUP: (
-        schema_py_generated.BuiltinOperator.EMBEDDING_LOOKUP
-    ),
-    _TFLOpName.SOFTMAX: schema_py_generated.BuiltinOperator.SOFTMAX,
-    _TFLOpName.AVERAGE_POOL_2D: (
-        schema_py_generated.BuiltinOperator.AVERAGE_POOL_2D
-    ),
-    _TFLOpName.RESHAPE: schema_py_generated.BuiltinOperator.RESHAPE,
-    _TFLOpName.TANH: schema_py_generated.BuiltinOperator.TANH,
-    _TFLOpName.TRANSPOSE: schema_py_generated.BuiltinOperator.TRANSPOSE,
-    _TFLOpName.GELU: schema_py_generated.BuiltinOperator.GELU,
-    _TFLOpName.ADD: schema_py_generated.BuiltinOperator.ADD,
-    _TFLOpName.SUB: schema_py_generated.BuiltinOperator.SUB,
-    _TFLOpName.MUL: schema_py_generated.BuiltinOperator.MUL,
-    _TFLOpName.MEAN: schema_py_generated.BuiltinOperator.MEAN,
-    _TFLOpName.RSQRT: schema_py_generated.BuiltinOperator.RSQRT,
-    _TFLOpName.CONCATENATION: schema_py_generated.BuiltinOperator.CONCATENATION,
-    _TFLOpName.STRIDED_SLICE: schema_py_generated.BuiltinOperator.STRIDED_SLICE,
-    _TFLOpName.SPLIT: schema_py_generated.BuiltinOperator.SPLIT,
-    _TFLOpName.LOGISTIC: schema_py_generated.BuiltinOperator.LOGISTIC,
-    _TFLOpName.SLICE: schema_py_generated.BuiltinOperator.SLICE,
-    _TFLOpName.SUM: schema_py_generated.BuiltinOperator.SUM,
-    _TFLOpName.SELECT_V2: schema_py_generated.BuiltinOperator.SELECT_V2,
+    _TFLOpName.GATHER_ND: schema.BuiltinOperator.GATHER_ND,
+    _TFLOpName.PACK: schema.BuiltinOperator.PACK,
+    _TFLOpName.UNPACK: schema.BuiltinOperator.UNPACK,
+    _TFLOpName.DIV: schema.BuiltinOperator.DIV,
+    _TFLOpName.BROADCAST_TO: schema.BuiltinOperator.BROADCAST_TO,
+    _TFLOpName.SQRT: schema.BuiltinOperator.SQRT,
+    _TFLOpName.GATHER: schema.BuiltinOperator.GATHER,
+    _TFLOpName.HARD_SWISH: schema.BuiltinOperator.HARD_SWISH,
+    _TFLOpName.MAXIMUM: schema.BuiltinOperator.MAXIMUM,
+    _TFLOpName.PADV2: schema.BuiltinOperator.PADV2,
+    _TFLOpName.REDUCE_MIN: schema.BuiltinOperator.REDUCE_MIN,
+    _TFLOpName.EQUAL: schema.BuiltinOperator.EQUAL,
+    _TFLOpName.NOT_EQUAL: schema.BuiltinOperator.NOT_EQUAL,
+    _TFLOpName.MIRROR_PAD: schema.BuiltinOperator.MIRROR_PAD,
+    _TFLOpName.SPACE_TO_DEPTH: schema.BuiltinOperator.SPACE_TO_DEPTH,
+    _TFLOpName.RELU: schema.BuiltinOperator.RELU,
 })

 TFL_OP_CODE_TO_NAME = immutabledict.immutabledict(
@@ -78,6 +96,11 @@ TFL_OP_TO_WEIGHT_QUANTIZED_DIM = immutabledict.immutabledict({
     _TFLOpName.CONV_2D_TRANSPOSE: 0,
 })

+TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM = immutabledict.immutabledict({
+    _TFLOpName.FULLY_CONNECTED: 1,
+    _TFLOpName.EMBEDDING_LOOKUP: 1,
+})
+
 NUM_TFL_DATATYPES = 18
 TENSOR_CODE_TO_TYPE = {}
 for dtype_code in range(NUM_TFL_DATATYPES):
@@ -87,7 +110,7 @@ TENSOR_TYPE_TO_CODE = immutabledict.immutabledict(
     (reversed(item) for item in TENSOR_CODE_TO_TYPE.items())
 )

-# Expose functions in tensorflow.lite.tools.flatbuffer_utils
+# Expose functions in litert.python.tools.flatbuffer_utils
 write_model = flatbuffer_utils.write_model


@@ -122,7 +145,7 @@ def get_model_content(tflite_path: str) -> bytes:
   Returns:
     The model bytes.
   """
-  with gfile.Open(tflite_path, "rb") as tflite_file:
+  with open(tflite_path, "rb") as tflite_file:
     return tflite_file.read()


@@ -135,7 +158,7 @@ def get_model_buffer(tflite_path: str) -> bytearray:
   Returns:
     model_buffer: the model buffer.
   """
-  with gfile.Open(tflite_path, "rb") as tflite_file:
+  with open(tflite_path, "rb") as tflite_file:
     return bytearray(tflite_file.read())


@@ -188,25 +211,18 @@ def parse_fc_bmm_conv_tensors(
   return input_tensor, weight_tensor, bias_tensor, output_tensor


-# flatbuffer_model has Any type since tensorflow.lite.tools.flatbuffer_utils
+# flatbuffer_model has Any type since litert.python.tools.flatbuffer_utils
 # is not type annotated.
 def buffer_to_tensors(flatbuffer_model: Any) -> dict[int, list[Any]]:
-  """Get the buffer to tensor map for a tflite model.
-
-  Args:
-    flatbuffer_model: the flatbuffer_model.
-
-  Returns:
-    buffer_to_tensor_map: key as buffer index, value as list of tensors share
-      the buffer
-  """
+  """Returns a map from buffer id to tensors that use it."""
   buffer_to_tensor_map = {}
   for subgraph in flatbuffer_model.subgraphs:
     for op in subgraph.operators:
       for tensor in parse_op_tensors(op, subgraph.tensors):
         if tensor.buffer not in buffer_to_tensor_map:
           buffer_to_tensor_map[tensor.buffer] = []
-        buffer_to_tensor_map[tensor.buffer].append(tensor)
+        if tensor not in buffer_to_tensor_map[tensor.buffer]:
+          buffer_to_tensor_map[tensor.buffer].append(tensor)
   return buffer_to_tensor_map


@@ -239,7 +255,8 @@ def get_tensor_data(tensor: Any, buffers: list[Any]) -> Optional[np.ndarray]:
   data = np.frombuffer(
       buffer_data, dtype=TENSOR_CODE_TO_TYPE[tensor.type].lower()
   )
-  data = np.reshape(data, tensor.shape)
+  if tensor.shape is not None:
+    data = np.reshape(data, tensor.shape)
   return data


@@ -315,3 +332,32 @@ def get_subgraph_input_output_operators(
       op_key=qtyping.TFLOperationName.OUTPUT,
   )
   return [input_op, output_op]
+
+
+def get_op_side_effect_subgraphs(
+    op: Union[schema.Operator, schema.OperatorT],
+) -> list[int]:
+  """Get indices of any subgraphs invoked as a side effect of the operator.
+
+  Args:
+    op: The operator object.
+
+  Returns:
+    A list of subgraph indices invoked by the operator. Empty if the operator
+    does not invoke any subgraphs.
+  """
+  if opts := flatbuffer_utils.get_options_as(
+      op, schema.StableHLOCompositeOptionsT
+  ):
+    return [opts.decompositionSubgraphIndex]
+  # Can add other nested ops here (control flow ops, etc).
+  return []
+
+
+def get_op_name_by_index(
+    flatbuffer_model: Any, subgraph_id: int, op_index: int
+) -> str:
+  """Get the op name from the flatbuffer model."""
+  op = flatbuffer_model.subgraphs[subgraph_id].operators[op_index]
+  builtin_code = flatbuffer_model.operatorCodes[op.opcodeIndex].builtinCode
+  return TFL_OP_CODE_TO_NAME[builtin_code]
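The hunks above appear to belong to ai_edge_quantizer/utils/tfl_flatbuffer_utils.py. As a rough, hypothetical sketch (not part of the diff) of how the new helpers compose, a graph walker could use get_op_side_effect_subgraphs to descend into a StableHLO composite's decomposition subgraph; walk_ops and visit_op below are placeholders introduced only for illustration.

from ai_edge_quantizer.utils import tfl_flatbuffer_utils


def walk_ops(flatbuffer_model, subgraph_id, visit_op):
  """Visits every op in a subgraph, recursing into composite decompositions."""
  subgraph = flatbuffer_model.subgraphs[subgraph_id]
  for op_index, op in enumerate(subgraph.operators):
    op_name = tfl_flatbuffer_utils.get_op_name_by_index(
        flatbuffer_model, subgraph_id, op_index
    )
    visit_op(subgraph_id, op_index, op_name)
    # For STABLEHLO_COMPOSITE ops this returns the decomposition subgraph
    # index; for all other ops it currently returns an empty list.
    for nested_id in tfl_flatbuffer_utils.get_op_side_effect_subgraphs(op):
      walk_ops(flatbuffer_model, nested_id, visit_op)


# Example use, assuming `model` was loaded with tfl_flatbuffer_utils.read_model:
# walk_ops(model, 0, lambda sg, idx, name: print(sg, idx, name))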
@@ -105,6 +105,26 @@ class FlatbufferUtilsTest(googletest.TestCase):
     conv2d_filter_tensor = tensors[0]
     self.assertEqual(tuple(conv2d_filter_tensor.shape), (8, 3, 3, 1))

+  def test_buffer_to_tensors_has_unique_values(self):
+    test_model_path = os.path.join(
+        TEST_DATA_PREFIX_PATH,
+        "constant_tensor_and_buffer_only_sharing_weight_fcs.tflite",
+    )
+    test_model = tfl_flatbuffer_utils.read_model(test_model_path)
+    buffer_to_tensor_map = tfl_flatbuffer_utils.buffer_to_tensors(test_model)
+    self.assertLen(buffer_to_tensor_map, 7)
+    # The following buffer is shared by two tensors, each shared by two FC ops.
+    # This is where before we had multiple enrties for the same tensor.
+    self.assertLen(buffer_to_tensor_map[2], 2)
+    got_tensor_names = [
+        tfl_flatbuffer_utils.get_tensor_name(tensor)
+        for tensor in buffer_to_tensor_map[2]
+    ]
+    self.assertEqual(
+        got_tensor_names,
+        ["arith.constant", "arith.constant1"],
+    )
+
   def test_get_tensor_name(self):
     subgraph0 = self._test_model.subgraphs[0]
     subgraph_tensors = subgraph0.tensors