ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/algorithm_manager.py +158 -0
  2. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  3. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  5. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
  6. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  7. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  8. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
  12. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
  13. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
  14. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
  15. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
  16. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
  17. ai_edge_quantizer/calibrator.py +11 -60
  18. ai_edge_quantizer/calibrator_test.py +4 -73
  19. ai_edge_quantizer/default_policy.py +61 -26
  20. ai_edge_quantizer/model_modifier.py +97 -7
  21. ai_edge_quantizer/model_modifier_test.py +81 -1
  22. ai_edge_quantizer/model_validator.py +31 -8
  23. ai_edge_quantizer/params_generator.py +17 -10
  24. ai_edge_quantizer/params_generator_test.py +2 -7
  25. ai_edge_quantizer/qtyping.py +86 -6
  26. ai_edge_quantizer/quantizer.py +166 -21
  27. ai_edge_quantizer/quantizer_test.py +284 -16
  28. ai_edge_quantizer/recipe.py +154 -42
  29. ai_edge_quantizer/recipe_manager.py +158 -1
  30. ai_edge_quantizer/recipe_manager_test.py +146 -32
  31. ai_edge_quantizer/recipe_test.py +93 -17
  32. ai_edge_quantizer/transformation_instruction_generator.py +118 -13
  33. ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
  34. ai_edge_quantizer/transformation_performer.py +55 -25
  35. ai_edge_quantizer/transformation_performer_test.py +127 -5
  36. ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
  37. ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
  38. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  39. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  40. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  41. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  42. ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
  43. ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
  44. ai_edge_quantizer/transformations/transformation_utils.py +129 -6
  45. ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
  46. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  47. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  48. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  49. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  50. ai_edge_quantizer/utils/test_utils.py +75 -2
  51. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
  52. ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
  53. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  54. ai_edge_quantizer/utils/validation_utils.py +114 -4
  55. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  56. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
  57. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  58. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  59. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  60. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  61. ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
  62. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  63. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/transformation_performer_test.py
@@ -112,6 +112,32 @@ class TransformationPerformerTest(parameterized.TestCase):
     for index, op_id in enumerate(op_id_map[0]):
       self.assertEqual(op_id, index)
 
+  def test_update_op_id_map_not_changing_value_single_op_model(self):
+    """Test for _update_op_id_map."""
+    model = tfl_flatbuffer_utils.read_model(
+        os.path.join(
+            TEST_DATA_PREFIX_PATH, "tests/models/single_fc_bias.tflite"
+        )
+    )
+    self._transformation_performer._create_op_id_map(model)
+    instruction = qtyping.TransformationInst(
+        transformation=qtyping.QuantTransformation.QUANTIZE_TENSOR,
+        tensor_id=0,
+        producer=0,
+        consumers=[-1],
+        parameters=qtyping.UniformQuantParams(
+            8, None, np.array([1]), np.array([0])
+        ),
+    )
+    producer = self._transformation_performer._get_updated_producer_id(
+        instruction.producer, 0
+    )
+    consumers = self._transformation_performer._get_updated_consumer_ids(
+        instruction.consumers, 0
+    )
+    self.assertEqual(producer, 0)
+    self.assertEqual(consumers, [-1])
+
   @parameterized.named_parameters(
       dict(
           testcase_name="test_no_update",
@@ -271,7 +297,7 @@ class TransformationPerformerTest(parameterized.TestCase):
         expected_added_op_id_map,
     )
 
-  def test__update_instructions_updates_tensor_id_after_duplicate_tensor(self):
+  def test_update_instructions_updates_tensor_id_after_duplicate_tensor(self):
     def get_test_instruction(transformation, consumers):
       return qtyping.TransformationInst(
           transformation=transformation,
@@ -325,6 +351,8 @@ class TransformationPerformerTest(parameterized.TestCase):
             tensor_name="sequential/conv2d/Relu;sequential/conv2d/BiasAdd;"
             + "sequential/conv2d/Conv2D;sequential/conv2d/BiasAdd/ReadVariableOp1",
             subgraph_id=0,
+            # Conv2d: op_id=0, output_tensor_id=7.
+            # This should add two sequential dequants after the conv2d.
             instructions=[
                 qtyping.TransformationInst(
                     transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
@@ -349,6 +377,8 @@ class TransformationPerformerTest(parameterized.TestCase):
         "sequential/average_pooling2d/AvgPool": qtyping.TensorTransformationInsts(
             tensor_name="sequential/average_pooling2d/AvgPool",
             subgraph_id=0,
+            # Avg_pool: op_id=1, output_tensor_id=8.
+            # This should add two sequential dequants after the avg_pool.
             instructions=[
                 qtyping.TransformationInst(
                     transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
@@ -376,19 +406,111 @@ class TransformationPerformerTest(parameterized.TestCase):
     )
     self.assertLen(self._test_model.subgraphs, 1)
     self.assertLen(self._test_model.subgraphs[0].operators, 10)
+    # The original model has 13 tensors; each dequant adds 1 tensor.
     self.assertLen(self._test_model.subgraphs[0].tensors, 17)
+    # Check that the dequant opcode is added to the model.
     self.assertEqual(
         self._test_model.subgraphs[0].operators[1].opcodeIndex,
         len(self._test_model.operatorCodes) - 1,
     )
+    # Conv2d, dequant, dequant, avgpool, dequant, dequant, etc.
+    expected_builtin_op_order = [3, 6, 6, 1, 6, 6, 22, 9, 9, 25]
+    for i, op in enumerate(self._test_model.subgraphs[0].operators):
+      op_code = self._test_model.operatorCodes[op.opcodeIndex].builtinCode
+      self.assertEqual(op_code, expected_builtin_op_order[i])
+    # Check that the first dequant input is connected to the conv2d output.
+    self.assertEqual(self._test_model.subgraphs[0].operators[1].inputs[0], 7)
+    # Output is a new tensor just added.
+    self.assertEqual(self._test_model.subgraphs[0].operators[1].outputs[0], 13)
+    # Second dequant has new tensors.
     self.assertEqual(self._test_model.subgraphs[0].operators[2].inputs[0], 13)
     self.assertEqual(self._test_model.subgraphs[0].operators[2].outputs[0], 14)
-    self.assertEqual(
-        self._test_model.subgraphs[0].operators[2].outputs[0],
-        self._test_model.subgraphs[0].operators[3].inputs[0],
-    )
+    # Avgpool's input is the second dequant's output.
+    self.assertEqual(self._test_model.subgraphs[0].operators[3].inputs[0], 14)
+    # Avgpool's output remains the same.
     self.assertEqual(self._test_model.subgraphs[0].operators[3].outputs[0], 8)
+    # Third dequant's output is a new tensor.
     self.assertEqual(self._test_model.subgraphs[0].operators[4].outputs[0], 15)
+    # Fourth dequant.
+    self.assertEqual(self._test_model.subgraphs[0].operators[5].inputs[0], 15)
+    self.assertEqual(self._test_model.subgraphs[0].operators[5].outputs[0], 16)
+
+    # Avgpool (op_id=1) and reshape (op_id=2) are bumped by 2 and 4,
+    # respectively, due to the two dequants added after conv2d and avg_pool.
+    expected_op_id_map = [0, 3, 6, 7, 8, 9]
+    self.assertEqual(
+        self._transformation_performer._original_op_id_map[0],
+        expected_op_id_map,
+    )
+    # New dequants are added at these indices.
+    expected_added_op_id_map = [1, 2, 4, 5]
+    self.assertEqual(
+        self._transformation_performer._added_op_id_map[0],
+        expected_added_op_id_map,
+    )
+
+  def test_op_insertion_at_input_and_output(self):
+    """Test for op insertion at a graph input and a graph output."""
+    model = tfl_flatbuffer_utils.read_model(
+        os.path.join(
+            TEST_DATA_PREFIX_PATH, "tests/models/single_fc_bias.tflite"
+        )
+    )
+    self._transformation_performer._create_op_id_map(model)
+    instructions = {
+        # Fully_connected: op_id=0, input_tensor_id=0, output_tensor_id=3.
+        # Add a new quantize op to the input of the fully_connected.
+        "serving_default_input_2:0": qtyping.TensorTransformationInsts(
+            tensor_name="serving_default_input_2:0",
+            subgraph_id=0,
+            instructions=[
+                qtyping.TransformationInst(
+                    transformation=qtyping.QuantTransformation.ADD_QUANTIZE,
+                    tensor_id=0,
+                    producer=-1,
+                    consumers=[0],
+                    parameters=qtyping.UniformQuantParams(
+                        8, None, np.array([1]), np.array([0])
+                    ),
+                ),
+            ],
+        ),
+        # Add a new dequantize op to the output of the fully_connected.
+        "StatefulPartitionedCall:0": qtyping.TensorTransformationInsts(
+            tensor_name="StatefulPartitionedCall:0",
+            subgraph_id=0,
+            instructions=[
+                qtyping.TransformationInst(
+                    transformation=qtyping.QuantTransformation.ADD_DEQUANTIZE,
+                    tensor_id=3,
+                    producer=0,
+                    consumers=[-1],
+                    parameters=qtyping.UniformQuantParams(
+                        8, None, np.array([1]), np.array([0])
+                    ),
+                ),
+            ],
+        ),
+    }
+    self._transformation_performer.transform_graph(instructions, model)
+
+    # The original fc (op_id=0) should be bumped to op_id=1.
+    self.assertEqual(
+        self._transformation_performer._original_op_id_map[0],
+        [1],
+    )
+    # New quantize added at op_id=0, dequantize added at op_id=2.
+    expected_added_op_id_map = [0, 2]
+    self.assertEqual(
+        self._transformation_performer._added_op_id_map[0],
+        expected_added_op_id_map,
+    )
+    # Quantize, fully_connected, dequantize.
+    expected_builtin_op_order = [114, 9, 6]
+    for i, op in enumerate(model.subgraphs[0].operators):
+      op_code = model.operatorCodes[op.opcodeIndex].builtinCode
+      self.assertEqual(op_code, expected_builtin_op_order[i])
+
 
 if __name__ == "__main__":
   googletest.main()
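
The op-id bookkeeping these assertions encode is simple to state: every op inserted at index i bumps each original op at or after i by one, with `_original_op_id_map` recording where the original ops landed and `_added_op_id_map` recording where the inserted ops landed. (The numeric `builtinCode` values asserted above come from the TFLite schema: 3 = CONV_2D, 6 = DEQUANTIZE, 1 = AVERAGE_POOL_2D, 22 = RESHAPE, 9 = FULLY_CONNECTED, 25 = SOFTMAX, 114 = QUANTIZE.) A minimal standalone sketch of the remapping arithmetic, using a hypothetical `remap` helper rather than the package API:

def remap(num_original_ops: int, insertions: dict[int, int]):
  """Maps original op ids to new ids after op insertion.

  insertions maps an original op_id to the number of new ops inserted
  immediately after it. Returns (original_op_id_map, added_op_id_map).
  """
  original_map, added_map = [], []
  new_id = 0
  for op_id in range(num_original_ops):
    original_map.append(new_id)
    new_id += 1
    for _ in range(insertions.get(op_id, 0)):
      added_map.append(new_id)
      new_id += 1
  return original_map, added_map

# Two dequants after conv2d (op 0) and two after avg_pool (op 1) in the
# six-op test model reproduce the maps asserted in the test above.
assert remap(6, {0: 2, 1: 2}) == ([0, 3, 6, 7, 8, 9], [1, 2, 4, 5])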
ai_edge_quantizer/transformations/duplicate_buffer.py
@@ -34,9 +34,10 @@ def duplicate_buffer(
         f' Tensor {tensor_name} is not constant.'
     )
 
-  duplicated_buffer_id = transformation_utils.add_new_constant_buffer(
+  duplicated_buffer_id = transformation_utils.get_constant_buffer(
       data=buffer_data,
       buffers=transformation_input.buffers,
+      force_duplicate_buffer=True,
   )
   tensor.buffer = duplicated_buffer_id
 
ai_edge_quantizer/transformations/duplicate_tensor.py
@@ -41,6 +41,7 @@ def duplicate_tensor(
       tensor_shape=tensor.shape,
       subgraph=subgraph,
       buffers=transformation_input.buffers,
+      force_duplicate_buffer=True,
   )
   # Update the tensor name to avoid name collision in case the tensor is
   # duplicated multiple times.
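
Both call sites above opt out of buffer deduplication via `force_duplicate_buffer=True`, so each duplicated tensor gets its own backing buffer that can later be modified independently of the original. The diff does not show `transformation_utils.get_constant_buffer` itself; a hedged sketch of plausible semantics for such a helper (reuse a byte-identical buffer unless forced to duplicate) could look like this — the body below is an illustration, not the package's implementation:

import numpy as np
from ai_edge_litert import schema_py_generated  # Assumed available, as in the diff.

def get_constant_buffer(
    data: np.ndarray,
    buffers: list[schema_py_generated.BufferT],
    force_duplicate_buffer: bool = False,
) -> int:
  """Hypothetical sketch: returns a buffer id holding `data`."""
  raw = np.frombuffer(data.tobytes(), dtype=np.uint8)
  if not force_duplicate_buffer:
    for buffer_id, buffer in enumerate(buffers):
      if buffer.data is not None and np.array_equal(buffer.data, raw):
        return buffer_id  # Deduplicate: reuse the byte-identical buffer.
  # Forced duplication (or no match): append a fresh buffer.
  new_buffer = schema_py_generated.BufferT()
  new_buffer.data = raw
  buffers.append(new_buffer)
  return len(buffers) - 1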
ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py
@@ -0,0 +1,299 @@
+# Copyright 2024 The AI Edge Quantizer Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Hadamard rotation decomposed pattern transformation."""
+
+from flatbuffers import flexbuffers
+import numpy as np
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.transformations import transformation_utils
+from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
+
+
+def _to_flexbuffer(
+    hadamard_size: int,
+    random_binary_vector: list[np.int8],
+) -> bytes:
+  """Serializes hadamard_size and random_binary_vector to a flexbuffer."""
+  fbb = flexbuffers.Builder()
+  with fbb.Map():
+    fbb.Int('hadamard_size', hadamard_size)
+    fbb.VectorFromElements('random_binary_vector', random_binary_vector)
+  return fbb.Finish()
+
+
+def _update_embedding_lookup_consumers(
+    transformation: transformation_utils.TransformationInput,
+    new_tensor_id: int,
+) -> None:
+  """Updates the consumers of the embedding lookup op to use the new tensor.
+
+  Args:
+    transformation: The transformation input to update the consumers of.
+    new_tensor_id: The new tensor id to use as the input to the embedding
+      lookup consumers.
+  """
+  for consumer in transformation.consumers:
+    # If the consumer is a graph output and not an op, we can ignore it here
+    # since the graph output will be updated later.
+    if consumer == -1:
+      continue
+    consumer_op = transformation.subgraph.operators[consumer]
+    # Find the input that was attached to the insertion point, and replace it
+    # with the new tensor.
+    for i in range(len(consumer_op.inputs)):
+      if consumer_op.inputs[i] == transformation.tensor_id:
+        consumer_op.inputs[i] = new_tensor_id
+
+
+def _update_fully_connected_consumers(
+    transformation: transformation_utils.TransformationInput,
+    new_tensor_id: int,
+) -> bool:
+  """Updates the fully connected op(s) to use the new tensor.
+
+  Since the new tensor is inserted at the fully_connected's input, we need to
+  scan each consumer (in case of multiple fully_connected ops) and update the
+  input tensor to the new tensor.
+
+  Args:
+    transformation: The transformation input to update the consumers of.
+    new_tensor_id: The new tensor id to use as the input to the fully
+      connected consumers.
+
+  Returns:
+    True if the fully connected op(s) were updated to use the new tensor.
+  """
+  updated = False
+  for consumer in transformation.consumers:
+    if (
+        transformation_utils.get_schema_op_id(transformation, consumer)
+        == schema_py_generated.BuiltinOperator.FULLY_CONNECTED
+    ):
+      transformation.subgraph.operators[consumer].inputs[0] = new_tensor_id
+      updated = True
+  return updated
+
+
+def _make_hadamard_matrix(size: int):
+  """Generates a Hadamard matrix of the given size.
+
+  Args:
+    size: The size of the Hadamard matrix. Must be a power of 2. This
+      represents a single dimension, e.g. if size is 4, then the Hadamard
+      matrix is a 4x4 matrix.
+
+  Returns:
+    The Hadamard matrix.
+
+  Raises:
+    ValueError: If the size is not a power of 2.
+  """
+  if size <= 0 or (size & (size - 1)) != 0:
+    raise ValueError('Hadamard matrix size must be a power of 2.')
+  h = h2 = np.array([[1, 1], [1, -1]])
+  current_size = 2
+  while current_size < size:
+    h = np.kron(h, h2)
+    current_size *= 2
+  return h / np.sqrt(size)
+
+
+def insert_decomposed_hadamard_rotation(
+    transformation_input: transformation_utils.TransformationInput,
+) -> qtyping.TransformationInfo:
+  """Inserts a decomposed pattern of Hadamard rotation on this tensor.
+
+  This function works for float32 tensors only. Instead of inserting a single
+  custom op (aeq.hadamard_rotation), this inserts the mathematical equivalent
+  expressed in built-in TFLite ops:
+    x' = reshape(x, (-1, hadamard_size))
+    x' = x' @ H(hadamard_size)
+    x' = reshape(x', x.shape)
+  where H(n) is a Hadamard matrix of size n.
+
+  Args:
+    transformation_input: The transformation input to insert the ops on.
+
+  Returns:
+    The transformation info of the inserted ops.
+
+  Raises:
+    ValueError: If the transformation input is not a uniform quantization
+      transformation.
+    ValueError: If the Hadamard quantization params are not set.
+    ValueError: If the tensor is not a float32 tensor.
+    ValueError: If no supported ops were found as the tensor's producer or
+      consumers.
+  """
+  if not isinstance(
+      transformation_input.quant_params, qtyping.UniformQuantParams
+  ):
+    raise ValueError('Hadamard rotation supports uniform quantization only')
+
+  if transformation_input.quant_params.hadamard is None:
+    raise ValueError(
+        'Hadamard rotation quantization params are not set but op insertion is'
+        ' requested.'
+    )
+
+  tensor = transformation_input.subgraph.tensors[transformation_input.tensor_id]
+  if tensor.type != schema_py_generated.TensorType.FLOAT32:
+    raise ValueError(
+        'The Hadamard rotation op supports float32 tensors only. Got'
+        f' {tensor.type} tensor.'
+    )
+
+  # Insert x' = tfl.reshape to reshape x to (-1, hadamard_size).
+  hadamard_size = transformation_input.quant_params.hadamard.hadamard_size
+  tensor_size = np.prod(tensor.shape)
+  num_hadamard_blocks = tensor_size // hadamard_size
+  prerotate_shape = [num_hadamard_blocks, hadamard_size]
+  prerotate_shape_tensor_id = transformation_utils.add_new_constant_tensor(
+      tensor.name + b'_prerotate_shape',
+      np.array(prerotate_shape, dtype=np.int32),
+      schema_py_generated.TensorType.INT32,
+      transformation_input.subgraph,
+      transformation_input.buffers,
+  )
+  prerotate_reshape_output_tensor_id = (
+      transformation_utils.add_new_activation_tensor(
+          tensor.name + b'_prerotate_reshaped',
+          prerotate_shape,
+          schema_py_generated.TensorType.FLOAT32,
+          transformation_input.subgraph,
+      )
+  )
+
+  prerotate_reshape_op_code_idx = transformation_utils.add_op_code(
+      schema_py_generated.BuiltinOperator.RESHAPE,
+      transformation_input.op_codes,
+      'RESHAPE',
+  )
+  prerotate_reshape_op = schema_py_generated.OperatorT()
+  prerotate_reshape_op.opcodeIndex = prerotate_reshape_op_code_idx
+  prerotate_reshape_op.inputs = [
+      transformation_input.tensor_id,
+      prerotate_shape_tensor_id,
+  ]
+  prerotate_reshape_op.outputs = [prerotate_reshape_output_tensor_id]
+
+  # Generate hadamard_matrix(hadamard_size).
+  # We could quantize this to INT4 for better memory efficiency, but for large
+  # models the memory overhead is not significant, and floating point
+  # computation does seem to result in better accuracy.
+  hadamard_matrix = _make_hadamard_matrix(hadamard_size)
+  hadamard_matrix_tensor_id = transformation_utils.add_new_constant_tensor(
+      tensor.name + b'_hadamard_matrix',
+      hadamard_matrix.astype(np.float32),
+      schema_py_generated.TensorType.FLOAT32,
+      transformation_input.subgraph,
+      transformation_input.buffers,
+  )
+
+  # Insert x' = tfl.fully_connected(x', hadamard_matrix).
+  fc_output_tensor_id = transformation_utils.add_new_activation_tensor(
+      tensor.name + b'_rotated',
+      prerotate_shape,
+      schema_py_generated.TensorType.FLOAT32,
+      transformation_input.subgraph,
+  )
+
+  fc_op_code_idx = transformation_utils.add_op_code(
+      schema_py_generated.BuiltinOperator.FULLY_CONNECTED,
+      transformation_input.op_codes,
+      'FULLY_CONNECTED',
+  )
+  fc_op = schema_py_generated.OperatorT()
+  fc_op.opcodeIndex = fc_op_code_idx
+  fc_op.inputs = [prerotate_reshape_output_tensor_id, hadamard_matrix_tensor_id]
+  fc_op.outputs = [fc_output_tensor_id]
+  fc_options = schema_py_generated.FullyConnectedOptionsT()
+  fc_options.fusedActivationFunction = (
+      schema_py_generated.ActivationFunctionType.NONE
+  )
+  fc_op.builtinOptionsType = (
+      schema_py_generated.BuiltinOptions.FullyConnectedOptions
+  )
+  fc_op.builtinOptions = fc_options
+
+  # Insert x' = tfl.reshape(x', x.shape).
+  post_reshape_op_code_idx = transformation_utils.add_op_code(
+      schema_py_generated.BuiltinOperator.RESHAPE,
+      transformation_input.op_codes,
+      'RESHAPE',
+  )
+  post_reshape_op = schema_py_generated.OperatorT()
+  post_reshape_op.opcodeIndex = post_reshape_op_code_idx
+  post_reshape_shape_tensor_id = transformation_utils.add_new_constant_tensor(
+      tensor.name + b'_postrotate_shape',
+      np.array(tensor.shape, dtype=np.int32),
+      schema_py_generated.TensorType.INT32,
+      transformation_input.subgraph,
+      transformation_input.buffers,
+  )
+
+  post_reshape_output_tensor_id = (
+      transformation_utils.add_new_activation_tensor(
+          tensor.name + b'_postrotate_reshaped',
+          tensor.shape,
+          schema_py_generated.TensorType.FLOAT32,
+          transformation_input.subgraph,
+      )
+  )
+  post_reshape_op.inputs = [
+      fc_output_tensor_id,
+      post_reshape_shape_tensor_id,
+  ]
+  post_reshape_op.outputs = [post_reshape_output_tensor_id]
+
+  # Update the users of this tensor to use the new tensor.
+  if (
+      transformation_utils.get_producer_schema_op_id(transformation_input)
+      == schema_py_generated.BuiltinOperator.EMBEDDING_LOOKUP
+  ):
+    _update_embedding_lookup_consumers(
+        transformation_input, post_reshape_output_tensor_id
+    )
+  elif not _update_fully_connected_consumers(
+      transformation_input, post_reshape_output_tensor_id
+  ):
+    raise ValueError(
+        'The Hadamard rotation op supports embedding lookup and fully connected'
+        ' ops only, but no such ops were found.'
+    )
+
+  # If the tensor is a graph output, we need to replace the tensor with the
+  # new tensor.
+  for i, output in enumerate(transformation_input.subgraph.outputs):
+    if output == transformation_input.tensor_id:
+      transformation_input.subgraph.outputs[i] = post_reshape_output_tensor_id
+
+  # Find the actual insertion point. The insertion point should be after the
+  # producer op and before the first consumer op. The max() operation ensures
+  # that we're not using -1 as the insertion point.
+  first_consumer_id = min(transformation_input.consumers)
+  op_id = max(transformation_input.producer + 1, first_consumer_id)
+
+  # Insert the new ops in the correct order.
+  transformation_input.subgraph.operators.insert(op_id, prerotate_reshape_op)
+  transformation_input.subgraph.operators.insert(op_id + 1, fc_op)
+  transformation_input.subgraph.operators.insert(op_id + 2, post_reshape_op)
+
+  return qtyping.TransformationInfo(
+      op_id=op_id,
+      num_ops_added=3,
+      output_tensor_id=post_reshape_output_tensor_id,
+  )