ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/algorithm_manager.py +158 -0
  2. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  3. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  5. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
  6. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  7. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  8. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
  12. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
  13. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
  14. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
  15. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
  16. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
  17. ai_edge_quantizer/calibrator.py +11 -60
  18. ai_edge_quantizer/calibrator_test.py +4 -73
  19. ai_edge_quantizer/default_policy.py +61 -26
  20. ai_edge_quantizer/model_modifier.py +97 -7
  21. ai_edge_quantizer/model_modifier_test.py +81 -1
  22. ai_edge_quantizer/model_validator.py +31 -8
  23. ai_edge_quantizer/params_generator.py +17 -10
  24. ai_edge_quantizer/params_generator_test.py +2 -7
  25. ai_edge_quantizer/qtyping.py +86 -6
  26. ai_edge_quantizer/quantizer.py +166 -21
  27. ai_edge_quantizer/quantizer_test.py +284 -16
  28. ai_edge_quantizer/recipe.py +154 -42
  29. ai_edge_quantizer/recipe_manager.py +158 -1
  30. ai_edge_quantizer/recipe_manager_test.py +146 -32
  31. ai_edge_quantizer/recipe_test.py +93 -17
  32. ai_edge_quantizer/transformation_instruction_generator.py +118 -13
  33. ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
  34. ai_edge_quantizer/transformation_performer.py +55 -25
  35. ai_edge_quantizer/transformation_performer_test.py +127 -5
  36. ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
  37. ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
  38. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  39. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  40. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  41. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  42. ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
  43. ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
  44. ai_edge_quantizer/transformations/transformation_utils.py +129 -6
  45. ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
  46. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  47. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  48. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  49. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  50. ai_edge_quantizer/utils/test_utils.py +75 -2
  51. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
  52. ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
  53. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  54. ai_edge_quantizer/utils/validation_utils.py +114 -4
  55. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  56. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
  57. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  58. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  59. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  60. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  61. ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
  62. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  63. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/calibrator_test.py
@@ -103,58 +103,6 @@ class CalibratorTest(googletest.TestCase):
     model_tensor_qsvs = self._calibrator.get_model_qsvs()
     self.assertEmpty(model_tensor_qsvs)
 
-  def test_calibrator_initialize_qsv(self):
-    _add_default_int8xint8_integer_recipe(self._recipe_manager)
-    # Overwrite the single op to fc
-    self._recipe_manager.add_quantization_config(
-        regex=".*Stateful.*",
-        operation_name=qtyping.TFLOperationName.FULLY_CONNECTED,
-        algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT,
-        op_config=qtyping.OpQuantizationConfig(
-            weight_tensor_config=_TENSOR_QUANT_CONFIG(
-                num_bits=4,
-                granularity=qtyping.QuantGranularity.CHANNELWISE,
-            ),
-            compute_precision=_ComputePrecision.INTEGER,
-        ),
-    )
-    self._calibrator._initialize_model_qsvs(self._recipe_manager)
-    model_tensor_qsvs = self._calibrator.get_model_qsvs()
-
-    self.assertLen(model_tensor_qsvs, 4)
-    self.assertIn("serving_default_input_1:0", model_tensor_qsvs)  # input
-    input_qsv = model_tensor_qsvs["serving_default_input_1:0"]
-    self.assertEmpty(input_qsv)
-
-    self.assertIn("sequential/dense/MatMul", model_tensor_qsvs)  # weight
-    weight_tensor_qsv = model_tensor_qsvs["sequential/dense/MatMul"]
-    mins_maxs_shape = (16, 1)
-    self.assertTupleEqual(weight_tensor_qsv["min"].shape, mins_maxs_shape)
-    self.assertAlmostEqual(weight_tensor_qsv["min"][0][0], -0.40436327)
-    self.assertTupleEqual(weight_tensor_qsv["max"].shape, mins_maxs_shape)
-    self.assertAlmostEqual(weight_tensor_qsv["max"][0][0], 0.46138108)
-
-    self.assertIn(
-        "sequential/dense/BiasAdd/ReadVariableOp", model_tensor_qsvs
-    )  # bias
-    bias_tensor_qsv = model_tensor_qsvs[
-        "sequential/dense/BiasAdd/ReadVariableOp"
-    ]
-    mins_maxs_shape = (16,)
-    self.assertTupleEqual(bias_tensor_qsv["min"].shape, mins_maxs_shape)
-    self.assertAlmostEqual(bias_tensor_qsv["min"][0], -0.26978338)
-    self.assertTupleEqual(bias_tensor_qsv["max"].shape, mins_maxs_shape)
-    # Here bias min/max will be the same as each element is a scalar
-    # Bias will be quantized with input_scale * weight_scale.
-    self.assertSequenceEqual(
-        list(bias_tensor_qsv["max"].flatten()),
-        list(bias_tensor_qsv["min"].flatten()),
-    )
-
-    self.assertIn("StatefulPartitionedCall:0", model_tensor_qsvs)  # output
-    output_qsv = model_tensor_qsvs["StatefulPartitionedCall:0"]
-    self.assertEmpty(output_qsv)
-
   def test_calibrate_single_fc_success(self):
     _add_default_int8xint8_integer_recipe(self._recipe_manager)
     self._calibrator.calibrate(
@@ -162,7 +110,7 @@ class CalibratorTest(googletest.TestCase):
     )
     model_tensor_qsvs = self._calibrator.get_model_qsvs()
 
-    self.assertLen(model_tensor_qsvs, 4)
+    self.assertLen(model_tensor_qsvs, 2)
     self.assertIn("serving_default_input_1:0", model_tensor_qsvs)  # input
     input_qsv = model_tensor_qsvs["serving_default_input_1:0"]
     self.assertSequenceAlmostEqual(
@@ -171,19 +119,6 @@ class CalibratorTest(googletest.TestCase):
     self.assertSequenceAlmostEqual(
         input_qsv["max"].flatten(), [TEST_MAX_VAL], delta=1e-5
     )
-
-    self.assertIn("sequential/dense/MatMul", model_tensor_qsvs)  # weight
-    weight_qsv = model_tensor_qsvs["sequential/dense/MatMul"]
-    self.assertSequenceAlmostEqual(weight_qsv["min"].flatten(), [-0.49114203])
-    self.assertSequenceAlmostEqual(weight_qsv["max"].flatten(), [0.4903704])
-
-    self.assertIn(
-        "sequential/dense/BiasAdd/ReadVariableOp", model_tensor_qsvs
-    )  # bias
-    bias_qsv = model_tensor_qsvs["sequential/dense/BiasAdd/ReadVariableOp"]
-    self.assertSequenceAlmostEqual(bias_qsv["min"].flatten(), [-0.38401994])
-    self.assertSequenceAlmostEqual(bias_qsv["max"].flatten(), [0.31727126])
-
     self.assertIn("StatefulPartitionedCall:0", model_tensor_qsvs)  # output
     output_qsv = model_tensor_qsvs["StatefulPartitionedCall:0"]
     # Relu, only check the min
@@ -249,15 +184,11 @@ class CalibratorAlreadyQuantizedModelTest(googletest.TestCase):
     )
     _ = calibrator.Calibrator(test_model_path)
 
-  def test_check_is_float_model_raises_error_when_model_is_quantized(self):
+  def test_check_is_quantized_model_succeeds_when_model_is_quantized(self):
     test_model_path = os.path.join(
         TEST_DATA_PREFIX_PATH, "tests/models/mnist_quantized.tflite"
     )
-    with self.assertRaisesRegex(
-        ValueError,
-        "The input model for calibration is not a float model.",
-    ):
-      _ = calibrator.Calibrator(test_model_path)
+    _ = calibrator.Calibrator(test_model_path)
 
 
 class CalibratorToyGemma2Test(googletest.TestCase):
@@ -302,7 +233,7 @@ class CalibratorToyGemma2Test(googletest.TestCase):
         self._toy_gemma2_calibration_dataset,
         model_recipe_manager=recipe_mngr,
     )
-    self.assertLen(calib.get_model_qsvs(), 282)
+    self.assertLen(calib.get_model_qsvs(), 202)
 
 
 if __name__ == "__main__":
ai_edge_quantizer/default_policy.py
@@ -19,9 +19,9 @@ import collections
 import copy
 import json
 from typing import Any, Union
+from ai_edge_litert.tools import flatbuffer_utils
 from ai_edge_quantizer import qtyping
 from ai_edge_litert import schema_py_generated as schema  # pylint:disable=g-direct-tensorflow-import
-from tensorflow.lite.tools import flatbuffer_utils  # pylint: disable=g-direct-tensorflow-import
 
 _TFLOpName = qtyping.TFLOperationName
 _OpQuantizationConfig = qtyping.OpQuantizationConfig
@@ -61,9 +61,8 @@ DEFAULT_JSON_POLICY = """
       "weight_tensor_config": {
         "num_bits": 4,
         "symmetric": [true],
-        "granularity": ["BLOCKWISE"],
-        "dtype": "INT",
-        "block_size": [32, 64, 96, 128, 256]
+        "granularity": ["BLOCKWISE_32", "BLOCKWISE_64", "BLOCKWISE_128", "BLOCKWISE_256"],
+        "dtype": "INT"
       },
       "explicit_dequantize": false,
       "compute_precision": "INTEGER"
@@ -178,12 +177,30 @@ DEFAULT_JSON_POLICY = """
       "INPUT",
       "OUTPUT",
       "SLICE",
-      "EMBEDDING_LOOKUP",
       "SUM",
+      "SELECT",
       "SELECT_V2",
       "DYNAMIC_UPDATE_SLICE",
       "SELECT_V2",
-      "STABLEHLO_COMPOSITE"
+      "STABLEHLO_COMPOSITE",
+      "PAD",
+      "MAX_POOL_2D",
+      "RESIZE_BILINEAR",
+      "RESIZE_NEAREST_NEIGHBOR",
+      "GATHER_ND",
+      "PACK",
+      "UNPACK",
+      "DIV",
+      "BROADCAST_TO",
+      "SQRT",
+      "GATHER",
+      "MAXIMUM",
+      "PADV2",
+      "REDUCE_MIN",
+      "EQUAL",
+      "NOT_EQUAL",
+      "MIRROR_PAD",
+      "RELU"
     ],
     "static_wi8_ai8": [
       "ADD",
@@ -209,15 +226,36 @@ DEFAULT_JSON_POLICY = """
       "INPUT",
       "OUTPUT",
       "SLICE",
-      "EMBEDDING_LOOKUP",
       "SUM",
+      "SELECT",
       "SELECT_V2",
       "DYNAMIC_UPDATE_SLICE",
       "SELECT_V2",
-      "STABLEHLO_COMPOSITE"
+      "STABLEHLO_COMPOSITE",
+      "PAD",
+      "SQUARED_DIFFERENCE",
+      "MAX_POOL_2D",
+      "RESIZE_BILINEAR",
+      "RESIZE_NEAREST_NEIGHBOR",
+      "GATHER_ND",
+      "PACK",
+      "UNPACK",
+      "DIV",
+      "BROADCAST_TO",
+      "SQRT",
+      "GATHER",
+      "HARD_SWISH",
+      "MAXIMUM",
+      "PADV2",
+      "REDUCE_MIN",
+      "EQUAL",
+      "NOT_EQUAL",
+      "MIRROR_PAD",
+      "SPACE_TO_DEPTH",
+      "RELU"
     ],
-    "static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
-    "static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
+    "static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"],
+    "static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"],
     "dynamic_wi8_afp32": [
       "BATCH_MATMUL",
       "CONV_2D",
@@ -240,6 +278,11 @@ DEFAULT_JSON_POLICY = """
   }
 }
 """
+QUANTIZABLE_COMPOSITES = [
+    "od" + "ml.npu_call",
+    "od" + "ml.rms_norm",
+    "od" + "ml.l2_norm",
+]
 
 
 def _unroll_json_config(
@@ -280,16 +323,9 @@ def _unroll_json_config(
         "granularity": granularity,
         "dtype": json_config["weight_tensor_config"]["dtype"],
     }
-    if "block_size" in json_config["weight_tensor_config"]:
-      for block_size in json_config["weight_tensor_config"]["block_size"]:
-        tensor_config["block_size"] = block_size
-        weight_configs.append(
-            qtyping.TensorQuantizationConfig.from_dict(tensor_config)
-        )
-    else:
-      weight_configs.append(
-          qtyping.TensorQuantizationConfig.from_dict(tensor_config)
-      )
+    weight_configs.append(
+        qtyping.TensorQuantizationConfig.from_dict(tensor_config)
+    )
 
   if activation_configs:
     for activation_config in activation_configs:
@@ -317,10 +353,10 @@
 
 
 # TODO: b/401024954 - Have a better way to specify recipes based on op options.
-def is_conditionally_unquantized(
+def is_non_quantizable_composite_op(
     op: Union[schema.Operator, schema.OperatorT],
 ) -> bool:
-  """Checks if the operator is conditionally unquantized.
+  """Checks if the operator is a non-quantizable composite op.
 
   We may want to quantize an op only when its has certain options.
   Policies/recipes
@@ -335,10 +371,9 @@ def is_conditionally_unquantized(
   if opts := flatbuffer_utils.get_options_as(
       op, schema.StableHLOCompositeOptionsT
   ):
-    name: bytes = opts.name
-    # Non npu_call composites may have a kernel and as such will not be
-    # quantized.
-    return ("od" + "ml.npu_call") not in name.decode("utf-8")
+    name = opts.name.decode("utf-8")
+    if name not in QUANTIZABLE_COMPOSITES:
+      return True
 
   return False
 
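The new allowlist widens quantizable composites from npu_call only to the three names above; any other StableHLO composite now forces NO_QUANTIZE (see the params_generator hunk further down). A minimal sketch of the resulting behavior, with a hypothetical helper name (the "od" + "ml..." concatenation mirrors the source):

# Sketch of the predicate's effect on StableHLO composite names.
QUANTIZABLE_COMPOSITES = [
    "od" + "ml.npu_call",
    "od" + "ml.rms_norm",
    "od" + "ml.l2_norm",
]

def is_left_unquantized(composite_name: str) -> bool:
  # Composites off the allowlist keep their float kernel.
  return composite_name not in QUANTIZABLE_COMPOSITES

assert not is_left_unquantized("odml.npu_call")   # quantizable, as before
assert not is_left_unquantized("odml.rms_norm")   # newly quantizable
assert is_left_unquantized("odml.my_custom_op")   # forced to NO_QUANTIZE

Note the semantics also tighten slightly: the old code did a substring check against the composite name, while the new code requires an exact match against the allowlist.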
ai_edge_quantizer/model_modifier.py
@@ -17,15 +17,21 @@
 
 from collections.abc import Sequence
 import copy
+import logging
 
 import numpy as np
 
+from ai_edge_litert.tools import flatbuffer_utils
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer import transformation_instruction_generator
 from ai_edge_quantizer import transformation_performer
 from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+from ai_edge_quantizer.utils import tfl_interpreter_utils
+from ai_edge_litert import interpreter as tfl  # pylint: disable=g-direct-tensorflow-import
 from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
-from tensorflow.lite.tools import flatbuffer_utils  # pylint: disable=g-direct-tensorflow-import
+
+
+_DEQUANT_SUFFIX = "_dequant"
 
 
 class ModelModifier:
@@ -104,11 +110,95 @@ class ModelModifier:
         instructions, quantized_model, tensor_processing_order
     )
     constant_buffer_size = self._process_constant_map(quantized_model)
-    # we leave 64MB for the model architecture.
-    if constant_buffer_size > 2**31 - 2**26:
-      return self._serialize_large_model(quantized_model)
-    else:
-      return self._serialize_small_model(quantized_model)
+    # we leave 256MB for the model architecture.
+    serialize_fun = (
+        self._serialize_large_model
+        if constant_buffer_size > 2**31 - 2**28
+        else self._serialize_small_model
+    )
+    serialized_quantized_model = serialize_fun(quantized_model)
+
+    # Update signature defs if dequant is inserted before output.
+    if self._has_dequant_before_output(instructions):
+      quantized_model = self._update_signature_defs_for_dequant_output(
+          quantized_model, serialized_quantized_model
+      )
+      serialized_quantized_model = serialize_fun(quantized_model)
+
+    return serialized_quantized_model
+
+  def _update_signature_defs_for_dequant_output(
+      self, model: schema_py_generated.ModelT, serialized_model: bytearray
+  ):
+    """Updates the signature definitions in the model.
+
+    This function is called when a dequantize operation is inserted before
+    an output tensor. It updates the tensor index in the signature
+    definitions to point to the newly inserted dequantize output tensor.
+
+    Args:
+      model: The TFlite ModelT object.
+      serialized_model: The serialized bytearray of the TFlite model.
+
+    Returns:
+      The updated TFlite ModelT object.
+    """
+    interpreter = tfl.Interpreter(model_content=bytes(serialized_model))
+
+    for signature_def in model.signatureDefs:
+      signature_key = signature_def.signatureKey.decode("utf-8")
+      logging.info("Signature = %s", signature_key)
+      subgraph_idx = tfl_interpreter_utils.get_signature_main_subgraph_index(
+          interpreter, signature_key
+      )
+      output_details = interpreter.get_signature_runner(
+          signature_key
+      ).get_output_details()
+      subgraph = model.subgraphs[subgraph_idx]
+      graph_info = qtyping.GraphInfo(subgraph.tensors, model.buffers)
+
+      for output in subgraph.outputs:
+        tensor_name = tfl_flatbuffer_utils.get_tensor_name(
+            graph_info.subgraph_tensors[output]
+        )
+        logging.info("\tOutput tensor = `%s`", tensor_name)
+
+        for signature_name, tensor_details in output_details.items():
+          if tensor_details["name"] + _DEQUANT_SUFFIX == tensor_name:
+            logging.info(
+                "\t\tfound tensor mapping: `%s`->`%s` for signature name: `%s`",
+                tensor_details["name"],
+                tensor_name,
+                signature_name,
+            )
+            for signature_item in signature_def.outputs:
+              if signature_item.name.decode("utf-8") == signature_name:
+                signature_item.tensorIndex = output
+                logging.info(
+                    "\t\t\tswapped tensor index: %s->%s",
+                    tensor_details["index"],
+                    output,
+                )
+                break
+            break
+
+    return model
+
+  def _has_dequant_before_output(
+      self, instructions: dict[str, qtyping.TensorTransformationInsts]
+  ) -> bool:
+    """Check if the model has dequant insert to output."""
+    for tensor_name, tensor_trans_insts in instructions.items():
+      for instr in tensor_trans_insts.instructions:
+        if (
+            qtyping.QuantTransformation.ADD_DEQUANTIZE == instr.transformation
+            and instr.consumers == [-1]
+        ):
+          logging.info(
+              "Found dequant insert to output for tensor: %s", tensor_name
+          )
+          return True
+    return False
 
   def _process_constant_map(
       self, quantized_model: schema_py_generated.ModelT
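
For scale, the serialization cutoff above works out as follows (pure arithmetic on the constants visible in the hunk):

# Flatbuffers are capped at 2 GiB; the reserved headroom for the non-constant
# model architecture grows from 64 MiB to 256 MiB in this release.
FLATBUFFER_CAP = 2**31  # 2,147,483,648 bytes
OLD_HEADROOM = 2**26    #    67,108,864 bytes (64 MiB)
NEW_HEADROOM = 2**28    #   268,435,456 bytes (256 MiB)

assert FLATBUFFER_CAP - OLD_HEADROOM == 2_080_374_784  # old large-model cutoff
assert FLATBUFFER_CAP - NEW_HEADROOM == 1_879_048_192  # new large-model cutoff

So models whose constant buffers exceed roughly 1.75 GiB now take the `_serialize_large_model` path, a more conservative threshold than before.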
@@ -142,7 +232,7 @@ class ModelModifier:
     remainder = len(bytearr) % 16
     if remainder != 0:
       padding_size = 16 - remainder
-      bytearr.extend(b'\0' * padding_size)
+      bytearr.extend(b"\0" * padding_size)
 
   # TODO: b/333797307 - support > 2GB output model
   def _serialize_large_model(
ai_edge_quantizer/model_modifier_test.py
@@ -19,13 +19,13 @@ import os
 import tracemalloc
 from tensorflow.python.platform import googletest
 from absl.testing import parameterized
+from ai_edge_litert.tools import flatbuffer_utils
 from ai_edge_quantizer import model_modifier
 from ai_edge_quantizer import params_generator
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer import recipe_manager
 from ai_edge_quantizer.utils import test_utils
 from ai_edge_quantizer.utils import tfl_flatbuffer_utils
-from tensorflow.lite.tools import flatbuffer_utils  # pylint: disable=g-direct-tensorflow-import
 
 TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('.')
 
@@ -125,6 +125,86 @@ class ModelModifierTest(parameterized.TestCase):
     loosen_mem_use_factor = 4.5
     self.assertLess(mem_peak / len(self._model_content), loosen_mem_use_factor)
 
+  def test_has_dequant_before_output_true(self):
+    instructions = {
+        'tensor1': qtyping.TensorTransformationInsts(
+            'tensor1',
+            0,
+            instructions=[
+                qtyping.TransformationInst(
+                    transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
+                    tensor_id=0,
+                    producer=0,
+                    consumers=[-1],
+                )
+            ],
+        )
+    }
+    self.assertTrue(
+        self._model_modifier._has_dequant_before_output(instructions)
+    )
+
+  def test_has_dequant_before_output_false(self):
+    instructions = {
+        'tensor1': qtyping.TensorTransformationInsts(
+            'tensor1',
+            0,
+            instructions=[
+                qtyping.TransformationInst(
+                    transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
+                    tensor_id=0,
+                    producer=0,
+                    consumers=[1],
+                )
+            ],
+        )
+    }
+    self.assertFalse(
+        self._model_modifier._has_dequant_before_output(instructions)
+    )
+
+  def test_pad_bytearray(self):
+    arr = bytearray(b'\x01\x02\x03')
+    self._model_modifier._pad_bytearray(arr)
+    self.assertLen(arr, 16)
+    self.assertEqual(arr, b'\x01\x02\x03' + b'\0' * 13)
+
+    arr = bytearray(b'\x01' * 16)
+    self._model_modifier._pad_bytearray(arr)
+    self.assertLen(arr, 16)
+
+    arr = bytearray(b'\x01' * 17)
+    self._model_modifier._pad_bytearray(arr)
+    self.assertLen(arr, 32)
+
+
+class ModelModifierTestWithSignature(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self._model_path = os.path.join(
+        TEST_DATA_PREFIX_PATH,
+        'tests/models/single_fc.tflite',
+    )
+    self._model_content: bytes = tfl_flatbuffer_utils.get_model_content(
+        self._model_path
+    )
+    self._model_modifier = model_modifier.ModelModifier(self._model_content)
+
+  def test_update_signature_defs_for_dequant_output_succeeds(self):
+    # This is a simplified test that only checks if the function runs without
+    # crashing and returns a model. A more thorough test with a model
+    # with a known signature was added in `quantizer_test`.
+    model_bytearray = flatbuffer_utils.read_model_from_bytearray(
+        self._model_content
+    )
+    updated_model = (
+        self._model_modifier._update_signature_defs_for_dequant_output(
+            model_bytearray, bytearray(self._model_content)
+        )
+    )
+    self.assertIsNotNone(updated_model)
+
 
 if __name__ == '__main__':
   googletest.main()
ai_edge_quantizer/model_validator.py
@@ -25,7 +25,7 @@ from typing import Any, Optional, Union
 import numpy as np
 
 from ai_edge_quantizer.utils import tfl_interpreter_utils as utils
-from tensorflow.python.platform import gfile  # pylint: disable=g-direct-tensorflow-import
+import os  # tensorflow.python.platform.gfile # pylint: disable=g-direct-tensorflow-import
 
 
 _DEFAULT_SIGNATURE_KEY = utils.DEFAULT_SIGNATURE_KEY
@@ -118,7 +118,8 @@ class ComparisonResult:
       for name in utils.get_input_tensor_names(
          self._reference_model, signature_key
      ):
-        input_tensor_results[name] = result.pop(name)
+        if name in result:
+          input_tensor_results[name] = result.pop(name)
 
      output_tensor_results = {}
      for name in utils.get_output_tensor_names(
@@ -136,7 +137,8 @@ class ComparisonResult:
          self._reference_model,
          subgraph_index,
      ):
-        constant_tensor_results[name] = result.pop(name)
+        if name in result:
+          constant_tensor_results[name] = result.pop(name)
 
      self._comparison_results[signature_key] = SingleSignatureComparisonResult(
          error_metric=error_metric,
@@ -192,7 +194,7 @@ class ComparisonResult:
     result_save_path = os.path.join(
         save_folder, model_name + '_comparison_result.json'
     )
-    with gfile.GFile(result_save_path, 'w') as output_file_handle:
+    with open(result_save_path, 'w') as output_file_handle:
       output_file_handle.write(json.dumps(result))
 
   # TODO: b/365578554 - Remove after ME is updated to use the new json format.
@@ -204,7 +206,7 @@ class ComparisonResult:
     json_save_path = os.path.join(
         save_folder, model_name + '_comparison_result_me_input.json'
     )
-    with gfile.GFile(json_save_path, 'w') as output_file_handle:
+    with open(json_save_path, 'w') as output_file_handle:
       output_file_handle.write(json_object)
 
 
@@ -214,6 +216,7 @@ def _setup_validation_interpreter(
     signature_key: Optional[str],
     use_xnnpack: bool,
     num_threads: int,
+    preserve_all_tensors: bool = True,
 ) -> tuple[Any, int, dict[str, Any]]:
   """Setup the interpreter for validation given a signature key.
 
@@ -224,13 +227,17 @@ def _setup_validation_interpreter(
       model only has one signature, this can be set to None.
     use_xnnpack: Whether to use xnnpack for the interpreter.
     num_threads: The number of threads to use for the interpreter.
+    preserve_all_tensors: Whether to preserve all tensors.
 
   Returns:
     A tuple of interpreter, subgraph_index and tensor_name_to_details.
   """
 
   interpreter = utils.create_tfl_interpreter(
-      tflite_model=model, use_xnnpack=use_xnnpack, num_threads=num_threads
+      tflite_model=model,
+      use_xnnpack=use_xnnpack,
+      num_threads=num_threads,
+      preserve_all_tensors=preserve_all_tensors,
   )
   utils.invoke_interpreter_signature(
       interpreter, signature_input, signature_key
@@ -255,6 +262,7 @@ def compare_model(
     compare_fn: Callable[[Any, Any], float],
     use_xnnpack: bool = True,
     num_threads: int = 16,
+    validate_output_tensors_only: bool = False,
 ) -> ComparisonResult:
   """Compares model tensors over a model signature using the compare_fn.
 
@@ -275,10 +283,13 @@ def compare_model(
       single float value.
     use_xnnpack: Whether to use xnnpack for the interpreter.
     num_threads: The number of threads to use for the interpreter.
+    validate_output_tensors_only: If True, only compare output tensors.
+      Otherwise, compare all tensors.
 
   Returns:
     A ComparisonResult object.
   """
+  preserve_all_tensors = not validate_output_tensors_only
   model_comparion_result = ComparisonResult(reference_model, target_model)
   for signature_key, signature_inputs in test_data.items():
     comparison_results = {}
@@ -291,6 +302,7 @@ def compare_model(
             signature_key,
             use_xnnpack,
             num_threads,
+            preserve_all_tensors=preserve_all_tensors,
         )
     )
     targ_interpreter, targ_subgraph_index, targ_tensor_name_to_details = (
@@ -300,12 +312,23 @@ def compare_model(
             signature_key,
             use_xnnpack,
             num_threads,
+            preserve_all_tensors=preserve_all_tensors,
        )
    )
-    # Compare the cached tensor values.
-    for tensor_name, detail in ref_tensor_name_to_details.items():
+    # Compare the cached tensor value
+    tensor_names_to_compare = (
+        utils.get_output_tensor_names(reference_model, signature_key)
+        if validate_output_tensors_only
+        else list(ref_tensor_name_to_details.keys())
+    )
+
+    for tensor_name in tensor_names_to_compare:
+      detail = ref_tensor_name_to_details[tensor_name]
      if detail['dtype'] == np.object_:
        continue
+      # Ignore tensors where any dimension of the shape is 0.
+      if not np.all(detail['shape']):
+        continue
      if tensor_name in targ_tensor_name_to_details:
        if tensor_name not in comparison_results:
          comparison_results[tensor_name] = []
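
A usage sketch for the new `validate_output_tensors_only` flag: when set, both interpreters are created without `preserve_all_tensors`, and only the signature output tensors are compared. Everything here other than the flag itself (model variables, test-data layout, the compare function, argument order) is an assumed stand-in, not taken from this diff:

# Hypothetical usage of compare_model with the new flag.
import numpy as np
from ai_edge_quantizer import model_validator

def mean_squared_difference(ref, target) -> float:
  # Example compare_fn: returns a single float per (reference, target) pair.
  ref = np.asarray(ref, dtype=np.float64)
  target = np.asarray(target, dtype=np.float64)
  return float(np.mean((ref - target) ** 2))

result = model_validator.compare_model(
    reference_model,  # assumed: float .tflite path or bytes
    target_model,     # assumed: quantized .tflite path or bytes
    test_data,        # assumed: {signature_key: [{input_name: np.ndarray}]}
    compare_fn=mean_squared_difference,
    validate_output_tensors_only=True,  # new flag; defaults to False
)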
ai_edge_quantizer/params_generator.py
@@ -35,12 +35,12 @@ class ParamsGenerator:
   def __init__(self, float_tflite: Union[str, bytes]):
     self.flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
 
-    if not tfl_flatbuffer_utils.is_float_model(self.flatbuffer_model):
-      raise ValueError(
-          'The input model for quantization parameters generation is not a'
-          ' float model. Please check the model (e.g., if it is already'
-          ' quantized).'
-      )
+    # if not tfl_flatbuffer_utils.is_float_model(self.flatbuffer_model):
+    #   raise ValueError(
+    #       'The input model for quantization parameters generation is not a'
+    #       ' float model. Please check the model (e.g., if it is already'
+    #       ' quantized).'
+    #   )
     self._check_tensor_names_are_unique()
     self.buffer_to_tensors: dict[int, list[Any]] = (
         tfl_flatbuffer_utils.buffer_to_tensors(self.flatbuffer_model)
@@ -78,8 +78,6 @@ class ParamsGenerator:
     skip_subgraphs = set()
     op_codes = self.flatbuffer_model.operatorCodes
     for sg_ind, subgraph in enumerate(self.flatbuffer_model.subgraphs):
-      if sg_ind in skip_subgraphs:
-        continue
 
       graph_info = qtyping.GraphInfo(
           subgraph.tensors, self.flatbuffer_model.buffers
@@ -109,7 +107,10 @@ class ParamsGenerator:
         algorithm_name, op_quant_config = (
            model_recipe_manager.get_quantization_configs(op_key, op_scope)
        )
-        if policy.is_conditionally_unquantized(op):
+
+        if sg_ind in skip_subgraphs or policy.is_non_quantizable_composite_op(
+            op
+        ):
          algorithm_name = algorithm_manager.AlgorithmName.NO_QUANTIZE
 
        if algorithm_name == algorithm_manager.AlgorithmName.NO_QUANTIZE:
@@ -408,7 +409,11 @@ class ParamsGenerator:
    buffers_to_duplicate = []
    tensor_names_to_duplicate = []
    for buffer_idx, tensors in self.buffer_to_tensors.items():
-      if not tensors:
+      # TODO: b/458797890 - Investigate if skipping buffer_idx == 0 is a
+      # correct fix, or if it just covers up a deeper issue. This is only
+      # required when statically quantizing models that have already been
+      # quantized dynamically.
+      if not tensors or buffer_idx == 0:
        continue
      # Check if any of the tensors needs to be duplicated.
      for tensor in tensors:
@@ -508,6 +513,8 @@ def _compatible_tensor_params(
  float_source_transformations = [
      _QuantTrans.ADD_QUANTIZE,
      _QuantTrans.NO_QUANTIZE,
+      _QuantTrans.INSERT_HADAMARD_ROTATION,
+      _QuantTrans.INSERT_DECOMPOSED_HADAMARD_ROTATION,
  ]
  quantized_source_transformations = [
      _QuantTrans.QUANTIZE_TENSOR,