ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. ai_edge_quantizer/algorithm_manager.py +224 -0
  2. ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
  3. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
  5. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  6. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
  7. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
  8. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
  13. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
  14. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
  15. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
  16. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
  17. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
  18. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
  19. ai_edge_quantizer/calibrator.py +58 -94
  20. ai_edge_quantizer/calibrator_test.py +5 -74
  21. ai_edge_quantizer/default_policy.py +108 -16
  22. ai_edge_quantizer/model_modifier.py +132 -8
  23. ai_edge_quantizer/model_modifier_test.py +81 -1
  24. ai_edge_quantizer/model_validator.py +38 -10
  25. ai_edge_quantizer/model_validator_test.py +2 -1
  26. ai_edge_quantizer/params_generator.py +230 -47
  27. ai_edge_quantizer/params_generator_test.py +366 -261
  28. ai_edge_quantizer/qtyping.py +92 -6
  29. ai_edge_quantizer/quantizer.py +167 -23
  30. ai_edge_quantizer/quantizer_test.py +288 -26
  31. ai_edge_quantizer/recipe.py +156 -21
  32. ai_edge_quantizer/recipe_manager.py +158 -1
  33. ai_edge_quantizer/recipe_manager_test.py +146 -32
  34. ai_edge_quantizer/recipe_test.py +93 -17
  35. ai_edge_quantizer/transformation_instruction_generator.py +313 -46
  36. ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
  37. ai_edge_quantizer/transformation_performer.py +112 -58
  38. ai_edge_quantizer/transformation_performer_test.py +176 -4
  39. ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
  40. ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
  41. ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
  42. ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
  43. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  44. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  45. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  46. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  47. ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
  48. ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
  49. ai_edge_quantizer/transformations/transformation_utils.py +157 -11
  50. ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
  51. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  52. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  53. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  54. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  55. ai_edge_quantizer/utils/test_utils.py +191 -58
  56. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
  57. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
  58. ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
  59. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  60. ai_edge_quantizer/utils/validation_utils.py +114 -4
  61. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  62. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
  63. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  64. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  65. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  66. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  67. ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
  68. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  69. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/default_policy.py
@@ -18,8 +18,10 @@
 import collections
 import copy
 import json
-from typing import Any
+from typing import Any, Union
+from ai_edge_litert.tools import flatbuffer_utils
 from ai_edge_quantizer import qtyping
+from ai_edge_litert import schema_py_generated as schema  # pylint:disable=g-direct-tensorflow-import

 _TFLOpName = qtyping.TFLOperationName
 _OpQuantizationConfig = qtyping.OpQuantizationConfig
@@ -55,6 +57,16 @@ DEFAULT_JSON_POLICY = """
55
57
  "explicit_dequantize": false,
56
58
  "compute_precision": "INTEGER"
57
59
  },
60
+ "dynamic_wi4_afp32_blockwise": {
61
+ "weight_tensor_config": {
62
+ "num_bits": 4,
63
+ "symmetric": [true],
64
+ "granularity": ["BLOCKWISE_32", "BLOCKWISE_64", "BLOCKWISE_128", "BLOCKWISE_256"],
65
+ "dtype": "INT"
66
+ },
67
+ "explicit_dequantize": false,
68
+ "compute_precision": "INTEGER"
69
+ },
58
70
  "static_wi8_ai16": {
59
71
  "activation_tensor_config": {
60
72
  "num_bits": 16,
@@ -165,9 +177,30 @@ DEFAULT_JSON_POLICY = """
     "INPUT",
     "OUTPUT",
     "SLICE",
-    "EMBEDDING_LOOKUP",
     "SUM",
-    "SELECT_V2"
+    "SELECT",
+    "SELECT_V2",
+    "DYNAMIC_UPDATE_SLICE",
+    "SELECT_V2",
+    "STABLEHLO_COMPOSITE",
+    "PAD",
+    "MAX_POOL_2D",
+    "RESIZE_BILINEAR",
+    "RESIZE_NEAREST_NEIGHBOR",
+    "GATHER_ND",
+    "PACK",
+    "UNPACK",
+    "DIV",
+    "BROADCAST_TO",
+    "SQRT",
+    "GATHER",
+    "MAXIMUM",
+    "PADV2",
+    "REDUCE_MIN",
+    "EQUAL",
+    "NOT_EQUAL",
+    "MIRROR_PAD",
+    "RELU"
   ],
   "static_wi8_ai8": [
     "ADD",
@@ -193,12 +226,36 @@ DEFAULT_JSON_POLICY = """
     "INPUT",
     "OUTPUT",
     "SLICE",
-    "EMBEDDING_LOOKUP",
     "SUM",
-    "SELECT_V2"
+    "SELECT",
+    "SELECT_V2",
+    "DYNAMIC_UPDATE_SLICE",
+    "SELECT_V2",
+    "STABLEHLO_COMPOSITE",
+    "PAD",
+    "SQUARED_DIFFERENCE",
+    "MAX_POOL_2D",
+    "RESIZE_BILINEAR",
+    "RESIZE_NEAREST_NEIGHBOR",
+    "GATHER_ND",
+    "PACK",
+    "UNPACK",
+    "DIV",
+    "BROADCAST_TO",
+    "SQRT",
+    "GATHER",
+    "HARD_SWISH",
+    "MAXIMUM",
+    "PADV2",
+    "REDUCE_MIN",
+    "EQUAL",
+    "NOT_EQUAL",
+    "MIRROR_PAD",
+    "SPACE_TO_DEPTH",
+    "RELU"
   ],
-  "static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
-  "static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
+  "static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"],
+  "static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"],
   "dynamic_wi8_afp32": [
     "BATCH_MATMUL",
     "CONV_2D",
@@ -208,6 +265,7 @@ DEFAULT_JSON_POLICY = """
     "FULLY_CONNECTED"
   ],
   "dynamic_wi4_afp32": ["FULLY_CONNECTED", "EMBEDDING_LOOKUP", "CONV_2D"],
+  "dynamic_wi4_afp32_blockwise": ["EMBEDDING_LOOKUP", "FULLY_CONNECTED"],
   "weightonly_wi8_afp32": [
     "BATCH_MATMUL",
     "CONV_2D",
@@ -220,6 +278,11 @@ DEFAULT_JSON_POLICY = """
   }
 }
 """
+QUANTIZABLE_COMPOSITES = [
+    "od" + "ml.npu_call",
+    "od" + "ml.rms_norm",
+    "od" + "ml.l2_norm",
+]


 def _unroll_json_config(
@@ -251,6 +314,7 @@ def _unroll_json_config(

   # Then unroll weight configs and turn them into quantization configs.
   quant_configs = []
+  weight_configs = []
   for symmetric in json_config["weight_tensor_config"]["symmetric"]:
     for granularity in json_config["weight_tensor_config"]["granularity"]:
       tensor_config = {
@@ -259,6 +323,9 @@
           "granularity": granularity,
           "dtype": json_config["weight_tensor_config"]["dtype"],
       }
+      weight_configs.append(
+          qtyping.TensorQuantizationConfig.from_dict(tensor_config)
+      )

       if activation_configs:
         for activation_config in activation_configs:
@@ -273,19 +340,44 @@
             )
           )
       else:
-        quant_configs.append(
-            qtyping.OpQuantizationConfig(
-                weight_tensor_config=qtyping.TensorQuantizationConfig.from_dict(
-                    tensor_config
-                ),
-                compute_precision=json_config["compute_precision"],
-                explicit_dequantize=json_config["explicit_dequantize"],
-            )
-        )
+        for weight_config in weight_configs:
+          quant_configs.append(
+              qtyping.OpQuantizationConfig(
+                  weight_tensor_config=weight_config,
+                  compute_precision=json_config["compute_precision"],
+                  explicit_dequantize=json_config["explicit_dequantize"],
+              )
+          )

   return quant_configs


+# TODO: b/401024954 - Have a better way to specify recipes based on op options.
+def is_non_quantizable_composite_op(
+    op: Union[schema.Operator, schema.OperatorT],
+) -> bool:
+  """Checks if the operator is a non-quantizable composite op.
+
+  We may want to quantize an op only when its has certain options.
+  Policies/recipes
+  are not aware of op options, so it is checked here.
+
+  Args:
+    op: The operator to check.
+
+  Returns:
+    True if the operator is conditionally unquantized, False otherwise.
+  """
+  if opts := flatbuffer_utils.get_options_as(
+      op, schema.StableHLOCompositeOptionsT
+  ):
+    name = opts.name.decode("utf-8")
+    if name not in QUANTIZABLE_COMPOSITES:
+      return True
+
+  return False
+
+
 def update_default_config_policy(raw_json_policy: str):
   """Updates the default config check policy."""
   json_policy_content = json.loads(raw_json_policy)
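
Note that the gate above reduces to a name check: the "od" + "ml...." concatenations evaluate to odml.npu_call, odml.rms_norm, and odml.l2_norm, and any StableHLO composite outside that allowlist is reported as non-quantizable. A standalone sketch of the same check on plain strings (the real code reads the name from StableHLOCompositeOptionsT via flatbuffer_utils.get_options_as):

    QUANTIZABLE_COMPOSITES = ["odml.npu_call", "odml.rms_norm", "odml.l2_norm"]

    def is_non_quantizable_composite_name(name: str) -> bool:
      # Composites are quantized only when their name is explicitly allowlisted.
      return name not in QUANTIZABLE_COMPOSITES

    assert not is_non_quantizable_composite_name("odml.rms_norm")
    assert is_non_quantizable_composite_name("odml.custom_block")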
ai_edge_quantizer/model_modifier.py
@@ -15,15 +15,23 @@

 """Model Modifier class that produce the final quantized TFlite model."""

+from collections.abc import Sequence
 import copy
+import logging

 import numpy as np

+from ai_edge_litert.tools import flatbuffer_utils
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer import transformation_instruction_generator
 from ai_edge_quantizer import transformation_performer
+from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+from ai_edge_quantizer.utils import tfl_interpreter_utils
+from ai_edge_litert import interpreter as tfl  # pylint: disable=g-direct-tensorflow-import
 from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
-from tensorflow.lite.tools import flatbuffer_utils  # pylint: disable=g-direct-tensorflow-import
+
+
+_DEQUANT_SUFFIX = "_dequant"


 class ModelModifier:
@@ -46,6 +54,35 @@ class ModelModifier:
         transformation_performer.TransformationPerformer()
     )

+  def _get_tensor_processing_order(
+      self,
+      tensor_names: Sequence[str],
+      flatbuffer_model: schema_py_generated.ModelT,
+  ) -> list[str]:
+    """Get the tensor processing order obtained from `buffer_to_tensors`.
+
+    The processing order is used to ensure that last tensor in a buffer is
+    processed the last. This is required for the correctness of buffer
+    duplication, as the last tensor in a buffer won't be duplicated.
+
+    Args:
+      tensor_names: Names of the tensors that need to be processed.
+      flatbuffer_model: TFlite model.
+
+    Returns:
+      A list of tensor names in the processing order.
+    """
+    buffer_to_tensors = tfl_flatbuffer_utils.buffer_to_tensors(flatbuffer_model)
+
+    processing_order = []
+    for buffer_tensors in buffer_to_tensors.values():
+      for tensor in buffer_tensors:
+        tensor_name = tfl_flatbuffer_utils.get_tensor_name(tensor)
+        if tensor_name in tensor_names:
+          processing_order.append(tensor_name)
+
+    return processing_order
+
   def modify_model(
       self, params: dict[str, qtyping.TensorTransformationParams]
   ) -> bytearray:
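
The ordering contract of _get_tensor_processing_order is simply "iterate buffers, then tensors within each buffer": tensors sharing a buffer stay grouped, so the last tensor of each buffer (the one the performer will not duplicate) comes last within its group. A toy illustration with hypothetical tensor and buffer names:

    buffer_to_tensors = {
        0: ["weights_a", "weights_a_shared"],  # two tensors share buffer 0
        1: ["bias_b"],
    }
    requested = {"weights_a_shared", "bias_b", "weights_a"}

    processing_order = [
        name
        for tensors in buffer_to_tensors.values()
        for name in tensors
        if name in requested
    ]
    assert processing_order == ["weights_a", "weights_a_shared", "bias_b"]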
@@ -66,15 +103,102 @@
         params, quantized_model
     )

+    tensor_processing_order = self._get_tensor_processing_order(
+        list(instructions.keys()), quantized_model
+    )
     self._transformation_performer.transform_graph(
-        instructions, quantized_model
+        instructions, quantized_model, tensor_processing_order
     )
     constant_buffer_size = self._process_constant_map(quantized_model)
-    # we leave 64MB for the model architecture.
-    if constant_buffer_size > 2**31 - 2**26:
-      return self._serialize_large_model(quantized_model)
-    else:
-      return self._serialize_small_model(quantized_model)
+    # we leave 256MB for the model architecture.
+    serialize_fun = (
+        self._serialize_large_model
+        if constant_buffer_size > 2**31 - 2**28
+        else self._serialize_small_model
+    )
+    serialized_quantized_model = serialize_fun(quantized_model)
+
+    # Update signature defs if dequant is inserted before output.
+    if self._has_dequant_before_output(instructions):
+      quantized_model = self._update_signature_defs_for_dequant_output(
+          quantized_model, serialized_quantized_model
+      )
+      serialized_quantized_model = serialize_fun(quantized_model)
+
+    return serialized_quantized_model
+
+  def _update_signature_defs_for_dequant_output(
+      self, model: schema_py_generated.ModelT, serialized_model: bytearray
+  ):
+    """Updates the signature definitions in the model.
+
+    This function is called when a dequantize operation is inserted before
+    an output tensor. It updates the tensor index in the signature
+    definitions to point to the newly inserted dequantize output tensor.
+
+    Args:
+      model: The TFlite ModelT object.
+      serialized_model: The serialized bytearray of the TFlite model.
+
+    Returns:
+      The updated TFlite ModelT object.
+    """
+    interpreter = tfl.Interpreter(model_content=bytes(serialized_model))
+
+    for signature_def in model.signatureDefs:
+      signature_key = signature_def.signatureKey.decode("utf-8")
+      logging.info("Signature = %s", signature_key)
+      subgraph_idx = tfl_interpreter_utils.get_signature_main_subgraph_index(
+          interpreter, signature_key
+      )
+      output_details = interpreter.get_signature_runner(
+          signature_key
+      ).get_output_details()
+      subgraph = model.subgraphs[subgraph_idx]
+      graph_info = qtyping.GraphInfo(subgraph.tensors, model.buffers)
+
+      for output in subgraph.outputs:
+        tensor_name = tfl_flatbuffer_utils.get_tensor_name(
+            graph_info.subgraph_tensors[output]
+        )
+        logging.info("\tOutput tensor = `%s`", tensor_name)
+
+        for signature_name, tensor_details in output_details.items():
+          if tensor_details["name"] + _DEQUANT_SUFFIX == tensor_name:
+            logging.info(
+                "\t\tfound tensor mapping: `%s`->`%s` for signature name: `%s`",
+                tensor_details["name"],
+                tensor_name,
+                signature_name,
+            )
+            for signature_item in signature_def.outputs:
+              if signature_item.name.decode("utf-8") == signature_name:
+                signature_item.tensorIndex = output
+                logging.info(
+                    "\t\t\tswapped tensor index: %s->%s",
+                    tensor_details["index"],
+                    output,
+                )
+                break
+            break
+
+    return model
+
+  def _has_dequant_before_output(
+      self, instructions: dict[str, qtyping.TensorTransformationInsts]
+  ) -> bool:
+    """Check if the model has dequant insert to output."""
+    for tensor_name, tensor_trans_insts in instructions.items():
+      for instr in tensor_trans_insts.instructions:
+        if (
+            qtyping.QuantTransformation.ADD_DEQUANTIZE == instr.transformation
+            and instr.consumers == [-1]
+        ):
+          logging.info(
+              "Found dequant insert to output for tensor: %s", tensor_name
+          )
+          return True
+    return False

   def _process_constant_map(
       self, quantized_model: schema_py_generated.ModelT
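
The new serialization cutoff in modify_model is worth spelling out: the headroom reserved for the non-constant part of the flatbuffer grows from 64 MiB to 256 MiB, so large-model serialization now kicks in earlier.

    old_threshold = 2**31 - 2**26  # 2 GiB minus 64 MiB headroom
    new_threshold = 2**31 - 2**28  # 2 GiB minus 256 MiB headroom
    assert old_threshold == 2_080_374_784
    assert new_threshold == 1_879_048_192
    # Constant buffers above the threshold go through _serialize_large_model;
    # everything else uses _serialize_small_model.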
@@ -108,7 +232,7 @@
     remainder = len(bytearr) % 16
     if remainder != 0:
       padding_size = 16 - remainder
-      bytearr.extend(b'\0' * padding_size)
+      bytearr.extend(b"\0" * padding_size)

   # TODO: b/333797307 - support > 2GB output model
   def _serialize_large_model(
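
For reference, the padding logic above (exercised by the new test_pad_bytearray further down) aligns each constant buffer to a 16-byte boundary; on a concrete value:

    buf = bytearray(b"\x01\x02\x03")          # len 3
    remainder = len(buf) % 16                 # 3
    if remainder != 0:
      buf.extend(b"\0" * (16 - remainder))    # pad with 13 zero bytes
    assert len(buf) == 16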
ai_edge_quantizer/model_modifier_test.py
@@ -19,13 +19,13 @@ import os
 import tracemalloc
 from tensorflow.python.platform import googletest
 from absl.testing import parameterized
+from ai_edge_litert.tools import flatbuffer_utils
 from ai_edge_quantizer import model_modifier
 from ai_edge_quantizer import params_generator
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer import recipe_manager
 from ai_edge_quantizer.utils import test_utils
 from ai_edge_quantizer.utils import tfl_flatbuffer_utils
-from tensorflow.lite.tools import flatbuffer_utils  # pylint: disable=g-direct-tensorflow-import

 TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('.')

@@ -125,6 +125,86 @@ class ModelModifierTest(parameterized.TestCase):
     loosen_mem_use_factor = 4.5
     self.assertLess(mem_peak / len(self._model_content), loosen_mem_use_factor)

+  def test_has_dequant_before_output_true(self):
+    instructions = {
+        'tensor1': qtyping.TensorTransformationInsts(
+            'tensor1',
+            0,
+            instructions=[
+                qtyping.TransformationInst(
+                    transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
+                    tensor_id=0,
+                    producer=0,
+                    consumers=[-1],
+                )
+            ],
+        )
+    }
+    self.assertTrue(
+        self._model_modifier._has_dequant_before_output(instructions)
+    )
+
+  def test_has_dequant_before_output_false(self):
+    instructions = {
+        'tensor1': qtyping.TensorTransformationInsts(
+            'tensor1',
+            0,
+            instructions=[
+                qtyping.TransformationInst(
+                    transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
+                    tensor_id=0,
+                    producer=0,
+                    consumers=[1],
+                )
+            ],
+        )
+    }
+    self.assertFalse(
+        self._model_modifier._has_dequant_before_output(instructions)
+    )
+
+  def test_pad_bytearray(self):
+    arr = bytearray(b'\x01\x02\x03')
+    self._model_modifier._pad_bytearray(arr)
+    self.assertLen(arr, 16)
+    self.assertEqual(arr, b'\x01\x02\x03' + b'\0' * 13)
+
+    arr = bytearray(b'\x01' * 16)
+    self._model_modifier._pad_bytearray(arr)
+    self.assertLen(arr, 16)
+
+    arr = bytearray(b'\x01' * 17)
+    self._model_modifier._pad_bytearray(arr)
+    self.assertLen(arr, 32)
+
+
+class ModelModifierTestWithSignature(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self._model_path = os.path.join(
+        TEST_DATA_PREFIX_PATH,
+        'tests/models/single_fc.tflite',
+    )
+    self._model_content: bytes = tfl_flatbuffer_utils.get_model_content(
+        self._model_path
+    )
+    self._model_modifier = model_modifier.ModelModifier(self._model_content)
+
+  def test_update_signature_defs_for_dequant_output_succeeds(self):
+    # This is a simplified test that only checks if the function runs without
+    # crashing and returns a model. A more thorough test with a model
+    # with a known signature was added in `quantizer_test`.
+    model_bytearray = flatbuffer_utils.read_model_from_bytearray(
+        self._model_content
+    )
+    updated_model = (
+        self._model_modifier._update_signature_defs_for_dequant_output(
+            model_bytearray, bytearray(self._model_content)
+        )
+    )
+    self.assertIsNotNone(updated_model)
+

 if __name__ == '__main__':
   googletest.main()
ai_edge_quantizer/model_validator.py
@@ -25,7 +25,7 @@ from typing import Any, Optional, Union
 import numpy as np

 from ai_edge_quantizer.utils import tfl_interpreter_utils as utils
-from tensorflow.python.platform import gfile  # pylint: disable=g-direct-tensorflow-import
+import os  # tensorflow.python.platform.gfile  # pylint: disable=g-direct-tensorflow-import


 _DEFAULT_SIGNATURE_KEY = utils.DEFAULT_SIGNATURE_KEY
@@ -118,7 +118,8 @@ class ComparisonResult:
       for name in utils.get_input_tensor_names(
           self._reference_model, signature_key
       ):
-        input_tensor_results[name] = result.pop(name)
+        if name in result:
+          input_tensor_results[name] = result.pop(name)

       output_tensor_results = {}
       for name in utils.get_output_tensor_names(
@@ -136,7 +137,8 @@
           self._reference_model,
           subgraph_index,
       ):
-        constant_tensor_results[name] = result.pop(name)
+        if name in result:
+          constant_tensor_results[name] = result.pop(name)

       self._comparison_results[signature_key] = SingleSignatureComparisonResult(
           error_metric=error_metric,
@@ -160,6 +162,12 @@
       result.update(signature_comparison_result.intermediate_tensors)
     return result

+  def get_model_size_reduction(self) -> tuple[int, float]:
+    """Get the model size reduction in bytes and percentage."""
+    reduced_model_size = len(self._reference_model) - len(self._target_model)
+    reduction_perc = reduced_model_size / len(self._reference_model) * 100
+    return reduced_model_size, reduction_perc
+
   def save(self, save_folder: str, model_name: str) -> None:
     """Saves the model comparison result.

@@ -170,8 +178,7 @@
     Raises:
       RuntimeError: If no quantized model is available.
     """
-    reduced_model_size = len(self._reference_model) - len(self._target_model)
-    reduction_ratio = reduced_model_size / len(self._reference_model) * 100
+    reduced_model_size, reduction_ratio = self.get_model_size_reduction()
     result = {
         'reduced_size_bytes': reduced_model_size,
         'reduced_size_percentage': reduction_ratio,
@@ -187,7 +194,7 @@
     result_save_path = os.path.join(
         save_folder, model_name + '_comparison_result.json'
     )
-    with gfile.GFile(result_save_path, 'w') as output_file_handle:
+    with open(result_save_path, 'w') as output_file_handle:
       output_file_handle.write(json.dumps(result))

   # TODO: b/365578554 - Remove after ME is updated to use the new json format.
@@ -199,7 +206,7 @@
     json_save_path = os.path.join(
        save_folder, model_name + '_comparison_result_me_input.json'
     )
-    with gfile.GFile(json_save_path, 'w') as output_file_handle:
+    with open(json_save_path, 'w') as output_file_handle:
       output_file_handle.write(json_object)

 
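The new get_model_size_reduction accessor depends only on the lengths of the two model byte strings held by ComparisonResult, so its arithmetic can be checked directly. A sketch with stand-in byte strings (real callers get a ComparisonResult from compare_model below):

    reference_model = b"\0" * 1000  # stand-in for the float model content
    target_model = b"\0" * 250      # stand-in for the quantized model content

    reduced = len(reference_model) - len(target_model)
    pct = reduced / len(reference_model) * 100
    assert (reduced, pct) == (750, 75.0)  # (bytes saved, percent saved)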
@@ -209,6 +216,7 @@ def _setup_validation_interpreter(
     signature_key: Optional[str],
     use_xnnpack: bool,
     num_threads: int,
+    preserve_all_tensors: bool = True,
 ) -> tuple[Any, int, dict[str, Any]]:
   """Setup the interpreter for validation given a signature key.

@@ -219,13 +227,17 @@
       model only has one signature, this can be set to None.
     use_xnnpack: Whether to use xnnpack for the interpreter.
     num_threads: The number of threads to use for the interpreter.
+    preserve_all_tensors: Whether to preserve all tensors.

   Returns:
     A tuple of interpreter, subgraph_index and tensor_name_to_details.
   """

   interpreter = utils.create_tfl_interpreter(
-      tflite_model=model, use_xnnpack=use_xnnpack, num_threads=num_threads
+      tflite_model=model,
+      use_xnnpack=use_xnnpack,
+      num_threads=num_threads,
+      preserve_all_tensors=preserve_all_tensors,
   )
   utils.invoke_interpreter_signature(
       interpreter, signature_input, signature_key
@@ -250,6 +262,7 @@ def compare_model(
     compare_fn: Callable[[Any, Any], float],
     use_xnnpack: bool = True,
     num_threads: int = 16,
+    validate_output_tensors_only: bool = False,
 ) -> ComparisonResult:

   """Compares model tensors over a model signature using the compare_fn.
@@ -270,10 +283,13 @@
       single float value.
     use_xnnpack: Whether to use xnnpack for the interpreter.
     num_threads: The number of threads to use for the interpreter.
+    validate_output_tensors_only: If True, only compare output tensors.
+      Otherwise, compare all tensors.

   Returns:
     A ComparisonResult object.
   """
+  preserve_all_tensors = not validate_output_tensors_only
   model_comparion_result = ComparisonResult(reference_model, target_model)
   for signature_key, signature_inputs in test_data.items():
     comparison_results = {}
@@ -286,6 +302,7 @@
             signature_key,
             use_xnnpack,
             num_threads,
+            preserve_all_tensors=preserve_all_tensors,
         )
     )
     targ_interpreter, targ_subgraph_index, targ_tensor_name_to_details = (
@@ -295,12 +312,23 @@
             signature_key,
             use_xnnpack,
             num_threads,
+            preserve_all_tensors=preserve_all_tensors,
         )
     )
-    # Compare the cached tensor values.
-    for tensor_name, detail in ref_tensor_name_to_details.items():
+    # Compare the cached tensor value
+    tensor_names_to_compare = (
+        utils.get_output_tensor_names(reference_model, signature_key)
+        if validate_output_tensors_only
+        else list(ref_tensor_name_to_details.keys())
+    )
+
+    for tensor_name in tensor_names_to_compare:
+      detail = ref_tensor_name_to_details[tensor_name]
       if detail['dtype'] == np.object_:
         continue
+      # Ignore tensors where any dimension of the shape is 0.
+      if not np.all(detail['shape']):
+        continue
       if tensor_name in targ_tensor_name_to_details:
         if tensor_name not in comparison_results:
           comparison_results[tensor_name] = []
ai_edge_quantizer/model_validator_test.py
@@ -21,6 +21,7 @@ from tensorflow.python.platform import googletest
 from ai_edge_quantizer import model_validator
 from ai_edge_quantizer.utils import test_utils
 from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+from ai_edge_quantizer.utils import tfl_interpreter_utils
 from ai_edge_quantizer.utils import validation_utils

 TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('.')
@@ -194,7 +195,7 @@ class ModelValidatorCompareTest(googletest.TestCase):
         self.target_model_path
     )
     self.signature_key = 'serving_default'  # single signature.
-    self.test_data = test_utils.create_random_normal_input_data(
+    self.test_data = tfl_interpreter_utils.create_random_normal_input_data(
        self.reference_model_path
     )
     self.test_dir = self.create_tempdir()
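
Putting the validator changes together, a hedged end-to-end sketch of output-only validation. Parameter order follows the hunks above; the error_metric argument, the mean_squared_difference compare function, and the file names are assumptions not visible in this diff:

    from ai_edge_quantizer import model_validator
    from ai_edge_quantizer.utils import tfl_flatbuffer_utils
    from ai_edge_quantizer.utils import tfl_interpreter_utils
    from ai_edge_quantizer.utils import validation_utils

    reference_model = tfl_flatbuffer_utils.get_model_content("model_float.tflite")
    target_model = tfl_flatbuffer_utils.get_model_content("model_quantized.tflite")
    test_data = tfl_interpreter_utils.create_random_normal_input_data(
        "model_float.tflite"
    )
    result = model_validator.compare_model(
        reference_model,
        target_model,
        test_data,
        error_metric="mse",  # assumption: metric label recorded with results
        compare_fn=validation_utils.mean_squared_difference,  # assumption
        validate_output_tensors_only=True,  # compare only signature outputs
    )
    print(result.get_model_size_reduction())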