ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. ai_edge_quantizer/algorithm_manager.py +224 -0
  2. ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
  3. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
  5. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  6. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
  7. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
  8. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
  13. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
  14. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
  15. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
  16. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
  17. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
  18. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
  19. ai_edge_quantizer/calibrator.py +58 -94
  20. ai_edge_quantizer/calibrator_test.py +5 -74
  21. ai_edge_quantizer/default_policy.py +108 -16
  22. ai_edge_quantizer/model_modifier.py +132 -8
  23. ai_edge_quantizer/model_modifier_test.py +81 -1
  24. ai_edge_quantizer/model_validator.py +38 -10
  25. ai_edge_quantizer/model_validator_test.py +2 -1
  26. ai_edge_quantizer/params_generator.py +230 -47
  27. ai_edge_quantizer/params_generator_test.py +366 -261
  28. ai_edge_quantizer/qtyping.py +92 -6
  29. ai_edge_quantizer/quantizer.py +167 -23
  30. ai_edge_quantizer/quantizer_test.py +288 -26
  31. ai_edge_quantizer/recipe.py +156 -21
  32. ai_edge_quantizer/recipe_manager.py +158 -1
  33. ai_edge_quantizer/recipe_manager_test.py +146 -32
  34. ai_edge_quantizer/recipe_test.py +93 -17
  35. ai_edge_quantizer/transformation_instruction_generator.py +313 -46
  36. ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
  37. ai_edge_quantizer/transformation_performer.py +112 -58
  38. ai_edge_quantizer/transformation_performer_test.py +176 -4
  39. ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
  40. ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
  41. ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
  42. ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
  43. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  44. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  45. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  46. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  47. ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
  48. ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
  49. ai_edge_quantizer/transformations/transformation_utils.py +157 -11
  50. ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
  51. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  52. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  53. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  54. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  55. ai_edge_quantizer/utils/test_utils.py +191 -58
  56. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
  57. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
  58. ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
  59. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  60. ai_edge_quantizer/utils/validation_utils.py +114 -4
  61. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  62. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
  63. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  64. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  65. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  66. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  67. ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
  68. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  69. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
@@ -15,10 +15,17 @@
15
15
 
16
16
  """Python manager for transformations to be applied to TFlite models."""
17
17
 
18
+ from collections.abc import Sequence
19
+ from typing import Optional
20
+
18
21
  import numpy as np
22
+
19
23
  from ai_edge_quantizer import qtyping
20
24
  from ai_edge_quantizer.transformations import dequant_insert
21
- from ai_edge_quantizer.transformations import emulated_subchannel
25
+ from ai_edge_quantizer.transformations import duplicate_buffer
26
+ from ai_edge_quantizer.transformations import duplicate_tensor
27
+ from ai_edge_quantizer.transformations import insert_decomposed_hadamard_rotation
28
+ from ai_edge_quantizer.transformations import insert_hadamard_rotation
22
29
  from ai_edge_quantizer.transformations import quant_insert
23
30
  from ai_edge_quantizer.transformations import quantize_tensor
24
31
  from ai_edge_quantizer.transformations import transformation_utils
@@ -65,9 +72,21 @@ class TransformationPerformer:
65
72
  quantize_tensor.quantize_tensor
66
73
  ),
67
74
  qtyping.QuantTransformation.EMULATED_SUBCHANNEL: (
68
- emulated_subchannel.emulated_subchannel
75
+ transformation_utils.raise_deprecated_error
69
76
  ),
70
77
  qtyping.QuantTransformation.ADD_QUANTIZE: quant_insert.insert_quant,
78
+ qtyping.QuantTransformation.DUPLICATE_BUFFER: (
79
+ duplicate_buffer.duplicate_buffer
80
+ ),
81
+ qtyping.QuantTransformation.DUPLICATE_TENSOR: (
82
+ duplicate_tensor.duplicate_tensor
83
+ ),
84
+ qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION: (
85
+ insert_hadamard_rotation.insert_hadamard_rotation
86
+ ),
87
+ qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION: (
88
+ insert_decomposed_hadamard_rotation.insert_decomposed_hadamard_rotation
89
+ ),
71
90
  }
72
91
  # transformations are separated in two categories:
73
92
  # op_insertion_transformations are transformations that only insert ops
@@ -77,6 +96,10 @@ class TransformationPerformer:
77
96
  qtyping.QuantTransformation.ADD_DEQUANTIZE,
78
97
  qtyping.QuantTransformation.QUANTIZE_TENSOR,
79
98
  qtyping.QuantTransformation.ADD_QUANTIZE,
99
+ qtyping.QuantTransformation.DUPLICATE_BUFFER,
100
+ qtyping.QuantTransformation.DUPLICATE_TENSOR,
101
+ qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
102
+ qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION,
80
103
  ])
81
104
  self._op_replacement_transformations = set(
82
105
  [qtyping.QuantTransformation.EMULATED_SUBCHANNEL]
@@ -123,44 +146,81 @@ class TransformationPerformer:
123
146
  transformations: list[qtyping.TransformationInst],
124
147
  subgraph_id: int,
125
148
  trans_info: qtyping.TransformationInfo,
126
- ):
127
- """Update the instructions after the graph is modified.
149
+ ) -> None:
150
+ """Update the instructions in-place after the graph is modified.
128
151
 
129
- After an op is inserted, the topology is changed and this may impact the
130
- following transformation to be applied. So we need to update instructions
131
- that have yet to be applied.
152
+ After an op is inserted or a tensor is duplicated, the topology is changed
153
+ and this may impact the following transformation to be applied. So we need
154
+ to update instructions that have yet to be applied.
132
155
 
133
156
  Args:
134
- prev_transformation_index: the index of the last applied transformation
135
- transformations: the list of transformations we're applying
136
- subgraph_id: the subgraph where the provided instrucitons belongs to
137
- trans_info: transformation info returned by a transformation
138
-
139
- Returns:
140
- None, modifies the transformation in place
157
+ prev_transformation_index: The index of the last applied transformation.
158
+ transformations: The list of transformations we're applying.
159
+ subgraph_id: The subgraph where the provided instructions belong to.
160
+ trans_info: Transformation info returned by a transformation.
141
161
  """
142
- # if no ops were added, then no need for update
143
- if trans_info.num_ops_added == 0:
144
- return
145
162
  prev_transformation = transformations[prev_transformation_index]
146
- self._added_op_id_map[subgraph_id].append(
147
- trans_info.op_id + trans_info.num_ops_added - 1
163
+ is_prev_not_duplicate_tensor = (
164
+ prev_transformation.transformation
165
+ != qtyping.QuantTransformation.DUPLICATE_TENSOR
148
166
  )
167
+ was_op_added = trans_info.num_ops_added > 0
168
+ if not was_op_added and is_prev_not_duplicate_tensor:
169
+ return
170
+
171
+ if was_op_added:
172
+ self._added_op_id_map[subgraph_id].append(
173
+ trans_info.op_id + trans_info.num_ops_added - 1
174
+ )
175
+
149
176
  for transformations_index in range(
150
177
  prev_transformation_index + 1, len(transformations)
151
178
  ):
152
179
  transformation = transformations[transformations_index]
153
180
  for consumer_index in transformation.consumers:
154
- # if the consumer need to use newly added ops, then the new added op
181
+ # If the consumer needs to use newly added ops, then the new added op
155
182
  # index needs to be outside of the range of the original op ids.
156
183
  if consumer_index in prev_transformation.consumers:
157
- transformation.producer = (
158
- len(self._original_op_id_map[subgraph_id])
159
- + len(self._added_op_id_map[subgraph_id])
160
- - 1
161
- )
184
+ if was_op_added:
185
+ transformation.producer = (
186
+ len(self._original_op_id_map[subgraph_id])
187
+ + len(self._added_op_id_map[subgraph_id])
188
+ - 1
189
+ )
162
190
  transformation.tensor_id = trans_info.output_tensor_id
163
191
 
192
+ def _get_updated_producer_id(
193
+ self, original_producer_id: int, subgraph_id: int
194
+ ) -> int:
195
+ """Update the producer of a transformation instruction."""
196
+ if original_producer_id is None or original_producer_id < 0:
197
+ producer = -1
198
+ elif original_producer_id < len(self._original_op_id_map[subgraph_id]):
199
+ producer = self._original_op_id_map[subgraph_id][original_producer_id]
200
+ else:
201
+ # If the producer id is not in the original op map, it's an added op,
202
+ # go to the added op map to find the producer.
203
+ producer = self._added_op_id_map[subgraph_id][
204
+ original_producer_id - len(self._original_op_id_map[subgraph_id])
205
+ ]
206
+ return producer
207
+
208
+ def _get_updated_consumer_ids(
209
+ self,
210
+ original_consumer_ids: list[int],
211
+ subgraph_id: int,
212
+ ) -> list[int]:
213
+ """Update the consumers of a transformation instruction."""
214
+ consumers = []
215
+ for original_op_id in original_consumer_ids:
216
+ new_consumer_id = (
217
+ -1
218
+ if original_op_id == -1
219
+ else self._original_op_id_map[subgraph_id][original_op_id]
220
+ )
221
+ consumers.append(new_consumer_id)
222
+ return consumers
223
+
164
224
  def _apply_single_transformation(
165
225
  self,
166
226
  transformation_inst: qtyping.TensorTransformationInsts,
@@ -179,28 +239,12 @@ class TransformationPerformer:
179
239
  None, update the transformation_inst & tflite_model in place
180
240
  """
181
241
  instruction = transformation_inst.instructions[transformation_index]
182
- if not instruction.producer or instruction.producer < 0:
183
- producer = -1
184
- elif instruction.producer < len(
185
- self._original_op_id_map[transformation_inst.subgraph_id]
186
- ):
187
- producer = self._original_op_id_map[transformation_inst.subgraph_id][
188
- instruction.producer
189
- ]
190
- else:
191
- # if the producer id is not in the original op map, it's an added op,
192
- # go the corresponding new maps
193
- producer = self._added_op_id_map[transformation_inst.subgraph_id][
194
- instruction.producer
195
- - len(self._original_op_id_map[transformation_inst.subgraph_id])
196
- ]
197
- consumers = []
198
- for original_op_id in instruction.consumers:
199
- consumers.append(
200
- self._original_op_id_map[transformation_inst.subgraph_id][
201
- original_op_id
202
- ]
203
- )
242
+ producer = self._get_updated_producer_id(
243
+ instruction.producer, transformation_inst.subgraph_id
244
+ )
245
+ consumers = self._get_updated_consumer_ids(
246
+ instruction.consumers, transformation_inst.subgraph_id
247
+ )
204
248
  trans_info = self._transformation_registration[instruction.transformation](
205
249
  transformation_utils.TransformationInput(
206
250
  instruction.tensor_id,
@@ -220,7 +264,12 @@ class TransformationPerformer:
220
264
  )
221
265
  self._update_op_id_map(
222
266
  transformation_inst.subgraph_id,
223
- min(instruction.consumers),
267
+ # The added op must be right before the most immediate consumer, unless
268
+ # the consumer is the graph output (id=-1), then use the producer's
269
+ # index instead.
270
+ min(instruction.consumers)
271
+ if min(instruction.consumers) >= 0
272
+ else instruction.producer + 1,
224
273
  trans_info.num_ops_added,
225
274
  )
226
275
 
@@ -260,19 +309,24 @@ class TransformationPerformer:
260
309
  self,
261
310
  transformation_instructions: dict[str, qtyping.TensorTransformationInsts],
262
311
  tflite_model: schema_py_generated.ModelT,
263
- ):
264
- """Apply all transformations to the given tflite_model.
312
+ tensor_processing_order: Optional[Sequence[str]] = None,
313
+ ) -> None:
314
+ """Apply all transformations to the given tflite_model in place.
265
315
 
266
316
  Args:
267
- transformation_instructions: a dict of transformation instructions grouped
268
- by tensors, produced by transformation_instruction_generator
269
- tflite_model: the tflite model to apply quantization on
270
-
271
- Returns:
272
- None, modifies the input tflite_model in place
317
+ transformation_instructions: Mapping from tensor name to its
318
+ transformation instructions, produced by
319
+ transformation_instruction_generator.
320
+ tflite_model: The tflite model to apply quantization to.
321
+ tensor_processing_order: The order of tensors to process. If not provided,
322
+ the order will be inferred from `transformation_instructions`.
273
323
  """
274
324
  self._original_op_id_map = []
275
325
  self._added_op_id_map = []
276
326
  self._create_op_id_map(tflite_model)
277
- for transformation_inst in transformation_instructions.values():
278
- self._apply_transformations(transformation_inst, tflite_model)
327
+ if tensor_processing_order is None:
328
+ tensor_processing_order = transformation_instructions.keys()
329
+ for tensor_name in tensor_processing_order:
330
+ self._apply_transformations(
331
+ transformation_instructions[tensor_name], tflite_model
332
+ )
@@ -15,6 +15,7 @@
15
15
 
16
16
  """Tests for transformation_performer."""
17
17
 
18
+ import copy
18
19
  import os
19
20
 
20
21
  import numpy as np
@@ -26,6 +27,9 @@ from ai_edge_quantizer import transformation_performer
26
27
  from ai_edge_quantizer.utils import test_utils
27
28
  from ai_edge_quantizer.utils import tfl_flatbuffer_utils
28
29
 
30
+ _QTransf = qtyping.QuantTransformation
31
+
32
+
29
33
  TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile(".")
30
34
 
31
35
 
@@ -108,6 +112,32 @@ class TransformationPerformerTest(parameterized.TestCase):
108
112
  for index, op_id in enumerate(op_id_map[0]):
109
113
  self.assertEqual(op_id, index)
110
114
 
115
+ def test_update_op_id_map_not_changing_value_single_op_model(self):
116
+ """test for _update_op_id_map."""
117
+ model = tfl_flatbuffer_utils.read_model(
118
+ os.path.join(
119
+ TEST_DATA_PREFIX_PATH, "tests/models/single_fc_bias.tflite"
120
+ )
121
+ )
122
+ self._transformation_performer._create_op_id_map(model)
123
+ instruction = qtyping.TransformationInst(
124
+ transformation=qtyping.QuantTransformation.QUANTIZE_TENSOR,
125
+ tensor_id=0,
126
+ producer=0,
127
+ consumers=[-1],
128
+ parameters=qtyping.UniformQuantParams(
129
+ 8, None, np.array([1]), np.array([0])
130
+ ),
131
+ )
132
+ producer = self._transformation_performer._get_updated_producer_id(
133
+ instruction.producer, 0
134
+ )
135
+ consumers = self._transformation_performer._get_updated_consumer_ids(
136
+ instruction.consumers, 0
137
+ )
138
+ self.assertEqual(producer, 0)
139
+ self.assertEqual(consumers, [-1])
140
+
111
141
  @parameterized.named_parameters(
112
142
  dict(
113
143
  testcase_name="test_no_update",
@@ -267,6 +297,52 @@ class TransformationPerformerTest(parameterized.TestCase):
267
297
  expected_added_op_id_map,
268
298
  )
269
299
 
300
+ def test_update_instructions_updates_tensor_id_after_duplicate_tensor(self):
301
+ def get_test_instruction(transformation, consumers):
302
+ return qtyping.TransformationInst(
303
+ transformation=transformation,
304
+ consumers=consumers,
305
+ # Dummy values below.
306
+ tensor_id=0,
307
+ producer=0,
308
+ parameters=qtyping.UniformQuantParams(
309
+ 8, None, np.array([1]), np.array([0])
310
+ ),
311
+ )
312
+
313
+ instructions = [
314
+ get_test_instruction(_QTransf.DUPLICATE_TENSOR, consumers=[1]),
315
+ get_test_instruction(_QTransf.ADD_QUANTIZE, consumers=[1]),
316
+ get_test_instruction(_QTransf.ADD_DEQUANTIZE, consumers=[1]),
317
+ get_test_instruction(_QTransf.QUANTIZE_TENSOR, consumers=[2]),
318
+ ]
319
+ # Simulate a situation as if the first instruction (duplicate tensor) was
320
+ # applied.
321
+ subgraph_id = 0
322
+ duplicated_tensor_id = 13
323
+ prev_trans_idx = 0
324
+ trans_info = qtyping.TransformationInfo(
325
+ # Copy of what duplicate_tensor.py returns.
326
+ op_id=0,
327
+ num_ops_added=0,
328
+ output_tensor_id=duplicated_tensor_id,
329
+ )
330
+ self._transformation_performer._create_op_id_map(self._test_model)
331
+ self._transformation_performer._update_instructions(
332
+ prev_trans_idx, instructions, subgraph_id, trans_info
333
+ )
334
+ # Expecting the ops with the same consumers as in the DUPLICATE_TENSOR
335
+ # instruction to use the new tensor id.
336
+ expected_instructions = copy.deepcopy(instructions)
337
+ expected_instructions[1].tensor_id = duplicated_tensor_id
338
+ expected_instructions[2].tensor_id = duplicated_tensor_id
339
+ self.assertSequenceEqual(instructions, expected_instructions)
340
+ # Expecting no change to the op id map.
341
+ self.assertListEqual(
342
+ self._transformation_performer._added_op_id_map,
343
+ [[]],
344
+ )
345
+
270
346
  def test_transform_graph(self):
271
347
  """test for transform_graph."""
272
348
  instructions = {
@@ -275,6 +351,8 @@ class TransformationPerformerTest(parameterized.TestCase):
275
351
  tensor_name="sequential/conv2d/Relu;sequential/conv2d/BiasAdd;"
276
352
  + "sequential/conv2d/Conv2D;sequential/conv2d/BiasAdd/ReadVariableOp1",
277
353
  subgraph_id=0,
354
+ # Conv2d: op_id=0, output_tensor_id=7.
355
+ # This should add two sequential dequants after the conv2d.
278
356
  instructions=[
279
357
  qtyping.TransformationInst(
280
358
  transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
@@ -299,6 +377,8 @@ class TransformationPerformerTest(parameterized.TestCase):
299
377
  "sequential/average_pooling2d/AvgPool": qtyping.TensorTransformationInsts(
300
378
  tensor_name="sequential/average_pooling2d/AvgPool",
301
379
  subgraph_id=0,
380
+ # Avg_pool: op_id=1, output_tensor_id=8.
381
+ # This should add two sequential dequants after the avg_pool.
302
382
  instructions=[
303
383
  qtyping.TransformationInst(
304
384
  transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
@@ -326,19 +406,111 @@ class TransformationPerformerTest(parameterized.TestCase):
326
406
  )
327
407
  self.assertLen(self._test_model.subgraphs, 1)
328
408
  self.assertLen(self._test_model.subgraphs[0].operators, 10)
409
+ # The original model has 13 tensors, each dequant adds 1 tensor.
329
410
  self.assertLen(self._test_model.subgraphs[0].tensors, 17)
411
+ # Check that the dequant opcode is added to the model.
330
412
  self.assertEqual(
331
413
  self._test_model.subgraphs[0].operators[1].opcodeIndex,
332
414
  len(self._test_model.operatorCodes) - 1,
333
415
  )
416
+ # Conv2d, dequant, dequant, avgpool, dequant, dequant, etc.
417
+ expected_builtin_op_order = [3, 6, 6, 1, 6, 6, 22, 9, 9, 25]
418
+ for i, op in enumerate(self._test_model.subgraphs[0].operators):
419
+ op_code = self._test_model.operatorCodes[op.opcodeIndex].builtinCode
420
+ self.assertEqual(op_code, expected_builtin_op_order[i])
421
+ # Check that the first dequant input is connected to the conv2d output.
422
+ self.assertEqual(self._test_model.subgraphs[0].operators[1].inputs[0], 7)
423
+ # Output is a new tensor just added.
424
+ self.assertEqual(self._test_model.subgraphs[0].operators[1].outputs[0], 13)
425
+ # Second dequant has new tensors.
334
426
  self.assertEqual(self._test_model.subgraphs[0].operators[2].inputs[0], 13)
335
427
  self.assertEqual(self._test_model.subgraphs[0].operators[2].outputs[0], 14)
336
- self.assertEqual(
337
- self._test_model.subgraphs[0].operators[2].outputs[0],
338
- self._test_model.subgraphs[0].operators[3].inputs[0],
339
- )
428
+ # Avgpool's input is second dequant's output.
429
+ self.assertEqual(self._test_model.subgraphs[0].operators[3].inputs[0], 14)
430
+ # Avgpool's output remains the same.
340
431
  self.assertEqual(self._test_model.subgraphs[0].operators[3].outputs[0], 8)
432
+ # Third dequant's output is a new tensor.
341
433
  self.assertEqual(self._test_model.subgraphs[0].operators[4].outputs[0], 15)
434
+ # Fourth dequant.
435
+ self.assertEqual(self._test_model.subgraphs[0].operators[5].inputs[0], 15)
436
+ self.assertEqual(self._test_model.subgraphs[0].operators[5].outputs[0], 16)
437
+
438
+ # Avgpool (op_id=1) and reshape (op_id=2) are bumped by 2 due to the two
439
+ # dequants added after it.
440
+ expected_op_id_map = [0, 3, 6, 7, 8, 9]
441
+ self.assertEqual(
442
+ self._transformation_performer._original_op_id_map[0],
443
+ expected_op_id_map,
444
+ )
445
+ # New dequants are added at these indices.
446
+ expected_added_op_id_map = [1, 2, 4, 5]
447
+ self.assertEqual(
448
+ self._transformation_performer._added_op_id_map[0],
449
+ expected_added_op_id_map,
450
+ )
451
+
452
+ def test_op_insertion_at_input_and_output(self):
453
+ """test for _update_op_id_map."""
454
+ model = tfl_flatbuffer_utils.read_model(
455
+ os.path.join(
456
+ TEST_DATA_PREFIX_PATH, "tests/models/single_fc_bias.tflite"
457
+ )
458
+ )
459
+ self._transformation_performer._create_op_id_map(model)
460
+ instructions = {
461
+ # Fully_connected: op_id=0, input_tensor_id=0, output_tensor_id=3.
462
+ # Add a new quantize op to the input of the fully_connected.
463
+ "serving_default_input_2:0": qtyping.TensorTransformationInsts(
464
+ tensor_name="serving_default_input_2:0",
465
+ subgraph_id=0,
466
+ instructions=[
467
+ qtyping.TransformationInst(
468
+ transformation=qtyping.QuantTransformation.ADD_QUANTIZE,
469
+ tensor_id=0,
470
+ producer=-1,
471
+ consumers=[0],
472
+ parameters=qtyping.UniformQuantParams(
473
+ 8, None, np.array([1]), np.array([0])
474
+ ),
475
+ ),
476
+ ],
477
+ ),
478
+ # Add a new dequantize op to the output of the fully_connected.
479
+ "StatefulPartitionedCall:0": qtyping.TensorTransformationInsts(
480
+ tensor_name="StatefulPartitionedCall:0",
481
+ subgraph_id=0,
482
+ instructions=[
483
+ qtyping.TransformationInst(
484
+ transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
485
+ tensor_id=3,
486
+ producer=0,
487
+ consumers=[-1],
488
+ parameters=qtyping.UniformQuantParams(
489
+ 8, None, np.array([1]), np.array([0])
490
+ ),
491
+ ),
492
+ ],
493
+ ),
494
+ }
495
+ self._transformation_performer.transform_graph(instructions, model)
496
+
497
+ # Original fc (op_id=0) should be bumped to op_id=1.
498
+ self.assertEqual(
499
+ self._transformation_performer._original_op_id_map[0],
500
+ [1],
501
+ )
502
+ # New quantize added at op_id=0, dequantize added at op_id=1.
503
+ expected_added_op_id_map = [0, 2]
504
+ self.assertEqual(
505
+ self._transformation_performer._added_op_id_map[0],
506
+ expected_added_op_id_map,
507
+ )
508
+ # Quantize, fully_connected, dequantize.
509
+ expected_builtin_op_order = [114, 9, 6]
510
+ for i, op in enumerate(model.subgraphs[0].operators):
511
+ op_code = model.operatorCodes[op.opcodeIndex].builtinCode
512
+ self.assertEqual(op_code, expected_builtin_op_order[i])
513
+
342
514
 
343
515
  if __name__ == "__main__":
344
516
  googletest.main()
@@ -0,0 +1,46 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Duplicate buffer transformation."""
17
+
18
+ from ai_edge_quantizer import qtyping
19
+ from ai_edge_quantizer.transformations import transformation_utils
20
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
21
+
22
+
23
+ def duplicate_buffer(
24
+ transformation_input: transformation_utils.TransformationInput,
25
+ ) -> qtyping.TransformationInfo:
26
+ """Duplicates the buffer of the tensor."""
27
+ tensor_id = transformation_input.tensor_id
28
+ tensor = transformation_input.subgraph.tensors[tensor_id]
29
+ buffer_data = transformation_input.buffers[tensor.buffer].data
30
+ if buffer_data is None:
31
+ tensor_name = tfl_flatbuffer_utils.get_tensor_name(tensor)
32
+ raise ValueError(
33
+ 'Duplicate Buffer transformation supports only constant tensors.'
34
+ f' Tensor {tensor_name} is not constant.'
35
+ )
36
+
37
+ duplicated_buffer_id = transformation_utils.get_constant_buffer(
38
+ data=buffer_data,
39
+ buffers=transformation_input.buffers,
40
+ force_duplicate_buffer=True,
41
+ )
42
+ tensor.buffer = duplicated_buffer_id
43
+
44
+ return qtyping.TransformationInfo(
45
+ op_id=0, num_ops_added=0, output_tensor_id=tensor_id
46
+ )
@@ -0,0 +1,106 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import os
17
+ import numpy as np
18
+ from tensorflow.python.platform import googletest
19
+ from ai_edge_quantizer import qtyping
20
+ from ai_edge_quantizer.transformations import duplicate_buffer
21
+ from ai_edge_quantizer.transformations import transformation_utils
22
+ from ai_edge_quantizer.utils import test_utils
23
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
24
+
25
+ TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('..')
26
+
27
+
28
+ class DuplicateBufferTest(googletest.TestCase):
29
+
30
+ def setUp(self):
31
+ super().setUp()
32
+ model_path = os.path.join(
33
+ TEST_DATA_PREFIX_PATH, 'tests/models/weight_sharing_fcs.tflite'
34
+ )
35
+ self.model = tfl_flatbuffer_utils.read_model(model_path)
36
+
37
+ def _get_transformation_input(
38
+ self, subgraph_idx: int, tensor_idx: int
39
+ ) -> transformation_utils.TransformationInput:
40
+ return transformation_utils.TransformationInput(
41
+ tensor_id=tensor_idx,
42
+ buffers=self.model.buffers,
43
+ # Dummy params below.
44
+ op_codes=self.model.operatorCodes,
45
+ subgraph=self.model.subgraphs[subgraph_idx],
46
+ producer=-1,
47
+ consumers=[],
48
+ quant_params=qtyping.UniformQuantParams(
49
+ num_bits=8,
50
+ quantized_dimension=None,
51
+ scale=np.ones(1),
52
+ zero_point=np.zeros(1),
53
+ ),
54
+ )
55
+
56
+ def test_constant_buffer_is_correctly_duplicated(self):
57
+ # Duplicate the FC weight tensor in the second subgraph.
58
+ subgraph_idx = 1
59
+ subgraph = self.model.subgraphs[subgraph_idx]
60
+ weight_tensor_idx = 1
61
+ prev_buffer_id = subgraph.tensors[weight_tensor_idx].buffer
62
+ prev_num_buffers = len(self.model.buffers)
63
+ transformation_input = self._get_transformation_input(
64
+ subgraph_idx, weight_tensor_idx
65
+ )
66
+ transformation_info = duplicate_buffer.duplicate_buffer(
67
+ transformation_input
68
+ )
69
+ self.assertEqual(transformation_info.op_id, 0)
70
+ self.assertEqual(transformation_info.num_ops_added, 0)
71
+ self.assertEqual(transformation_info.output_tensor_id, 1)
72
+ # Check that a new buffer was added.
73
+ self.assertLen(self.model.buffers, prev_num_buffers + 1)
74
+ # Check that the new buffer is used by the weight tensor.
75
+ new_buffer_id = len(self.model.buffers) - 1
76
+ self.assertEqual(subgraph.tensors[weight_tensor_idx].buffer, new_buffer_id)
77
+ # Check that the new buffer has the same data as the original one.
78
+ self.assertTrue(
79
+ np.all(
80
+ np.frombuffer(
81
+ self.model.buffers[new_buffer_id].data,
82
+ dtype=np.float32,
83
+ )
84
+ == np.frombuffer(
85
+ self.model.buffers[prev_buffer_id].data,
86
+ dtype=np.float32,
87
+ )
88
+ )
89
+ )
90
+
91
+ def test_duplicate_buffer_raises_error_when_tensor_is_not_constant(self):
92
+ # Duplicate the FC input tensor in the second subgraph.
93
+ subgraph_idx = 1
94
+ weight_tensor_idx = 0
95
+ transformation_input = self._get_transformation_input(
96
+ subgraph_idx, weight_tensor_idx
97
+ )
98
+ with self.assertRaisesRegex(
99
+ ValueError,
100
+ 'Duplicate Buffer transformation supports only constant tensors.',
101
+ ):
102
+ duplicate_buffer.duplicate_buffer(transformation_input)
103
+
104
+
105
+ if __name__ == '__main__':
106
+ googletest.main()