ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/algorithm_manager.py +158 -0
  2. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  3. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  5. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
  6. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  7. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  8. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
  12. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
  13. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
  14. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
  15. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
  16. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
  17. ai_edge_quantizer/calibrator.py +11 -60
  18. ai_edge_quantizer/calibrator_test.py +4 -73
  19. ai_edge_quantizer/default_policy.py +61 -26
  20. ai_edge_quantizer/model_modifier.py +97 -7
  21. ai_edge_quantizer/model_modifier_test.py +81 -1
  22. ai_edge_quantizer/model_validator.py +31 -8
  23. ai_edge_quantizer/params_generator.py +17 -10
  24. ai_edge_quantizer/params_generator_test.py +2 -7
  25. ai_edge_quantizer/qtyping.py +86 -6
  26. ai_edge_quantizer/quantizer.py +166 -21
  27. ai_edge_quantizer/quantizer_test.py +284 -16
  28. ai_edge_quantizer/recipe.py +154 -42
  29. ai_edge_quantizer/recipe_manager.py +158 -1
  30. ai_edge_quantizer/recipe_manager_test.py +146 -32
  31. ai_edge_quantizer/recipe_test.py +93 -17
  32. ai_edge_quantizer/transformation_instruction_generator.py +118 -13
  33. ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
  34. ai_edge_quantizer/transformation_performer.py +55 -25
  35. ai_edge_quantizer/transformation_performer_test.py +127 -5
  36. ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
  37. ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
  38. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  39. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  40. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  41. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  42. ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
  43. ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
  44. ai_edge_quantizer/transformations/transformation_utils.py +129 -6
  45. ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
  46. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  47. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  48. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  49. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  50. ai_edge_quantizer/utils/test_utils.py +75 -2
  51. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
  52. ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
  53. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  54. ai_edge_quantizer/utils/validation_utils.py +114 -4
  55. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  56. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
  57. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  58. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  59. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  60. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  61. ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
  62. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  63. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,244 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Test insertion of the Decomposed Hadamard rotation ops."""
17
+
18
+ import os
19
+ import numpy as np
20
+ from tensorflow.python.platform import googletest
21
+ from ai_edge_quantizer import qtyping
22
+ from ai_edge_quantizer.transformations import insert_decomposed_hadamard_rotation
23
+ from ai_edge_quantizer.transformations import transformation_utils
24
+ from ai_edge_quantizer.utils import test_utils
25
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
26
+ from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
27
+
28
# Base directory for test assets; the tests below join
# 'tests/models/<model>.tflite' onto it. The '..' is resolved by
# test_utils.get_path_to_datafile (presumably relative to this test module).
_TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile(
    '..'
)
29
+
30
+
31
class InsertDecomposedHadamardRotationFullyConnectedTest(googletest.TestCase):
  """Decomposed Hadamard rotation inserted in front of a FULLY_CONNECTED op."""

  def setUp(self):
    super().setUp()
    fc_model_path = os.path.join(
        _TEST_DATA_PREFIX_PATH, 'tests/models/single_fc_bias.tflite'
    )
    self.model = tfl_flatbuffer_utils.read_model(fc_model_path)
    rotation = qtyping.UniformQuantParams.HadamardRotationParams(
        random_binary_vector=np.ones(1),
        hadamard_size=2,
    )
    self.params = qtyping.UniformQuantParams(
        num_bits=8,
        quantized_dimension=None,
        scale=np.ones(1),
        zero_point=np.zeros(1),
        hadamard=rotation,
    )

  def _insert(self, quant_params, consumers):
    """Runs the transformation on tensor 0 of subgraph 0 (no producer op)."""
    return (
        insert_decomposed_hadamard_rotation.insert_decomposed_hadamard_rotation(
            transformation_utils.TransformationInput(
                tensor_id=0,
                op_codes=self.model.operatorCodes,
                buffers=self.model.buffers,
                subgraph=self.model.subgraphs[0],
                producer=-1,
                consumers=consumers,
                quant_params=quant_params,
            )
        )
    )

  def _assert_builtin(self, op, expected_builtin):
    """Asserts that `op` resolves to `expected_builtin` via the op codes."""
    self.assertEqual(
        self.model.operatorCodes[op.opcodeIndex].builtinCode, expected_builtin
    )

  def test_raise_unsupported_qparams(self):
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: 'uniform quantization' in str(err)
    ):
      self._insert(
          qtyping.NonLinearQuantParams(num_bits=16, quantized_data=None),
          consumers=[-1],
      )

  def test_raise_missing_hadamard_data(self):
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: 'quantization params are not set' in str(err)
    ):
      self._insert(
          qtyping.UniformQuantParams(
              num_bits=8,
              quantized_dimension=None,
              scale=np.ones(1),
              zero_point=np.zeros(1),
          ),
          consumers=[-1],
      )

  def test_raise_non_float32_tensor(self):
    self.model.subgraphs[0].tensors[0].type = (
        schema_py_generated.TensorType.INT32
    )
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: 'float32 tensors' in str(err)
    ):
      self._insert(self.params, consumers=[-1])

  def test_insert_decomposed_ops(self):
    # Insert the decomposed Hadamard ops before fully_connected; op 0 (the FC)
    # is the consumer of tensor 0.
    info = self._insert(self.params, consumers=[0])
    subgraph = self.model.subgraphs[0]
    ops = subgraph.operators
    builtin = schema_py_generated.BuiltinOperator

    self.assertEqual(info.op_id, 0)
    self.assertEqual(info.num_ops_added, 3)
    # 4 original tensors + 6 new ones (3 activations, 3 constants).
    self.assertEqual(info.output_tensor_id, 9)
    self.assertLen(subgraph.tensors, 10)
    # 1 original op code + RESHAPE + FULLY_CONNECTED.
    self.assertLen(self.model.operatorCodes, 3)
    self.assertEqual(self.model.operatorCodes[1].builtinCode, builtin.RESHAPE)
    self.assertEqual(
        self.model.operatorCodes[2].builtinCode, builtin.FULLY_CONNECTED
    )

    # Op 0: pre-RESHAPE, fed by the graph input.
    self._assert_builtin(ops[0], builtin.RESHAPE)
    self.assertEqual(ops[0].inputs[0], 0)
    self.assertEqual(ops[0].outputs[0], 5)

    # Op 1: FULLY_CONNECTED against the Hadamard matrix tensor (6).
    self._assert_builtin(ops[1], builtin.FULLY_CONNECTED)
    self.assertEqual(ops[1].inputs[0], 5)
    self.assertEqual(ops[1].inputs[1], 6)
    self.assertEqual(ops[1].outputs[0], 7)

    # Op 2: post-RESHAPE, producing the rotated tensor.
    self._assert_builtin(ops[2], builtin.RESHAPE)
    self.assertEqual(ops[2].inputs[0], 7)
    self.assertEqual(ops[2].outputs[0], 9)

    # Op 3: the original FULLY_CONNECTED now reads the rotated tensor.
    self._assert_builtin(ops[3], builtin.FULLY_CONNECTED)
    self.assertEqual(ops[3].inputs[0], 9)
+
177
+
178
class InsertDecomposedHadamardRotationEmbeddingLookupTest(googletest.TestCase):
  """Decomposed Hadamard rotation inserted after an EMBEDDING_LOOKUP op."""

  def setUp(self):
    super().setUp()
    self.model = tfl_flatbuffer_utils.read_model(
        os.path.join(
            _TEST_DATA_PREFIX_PATH, 'tests/models/embedding_lookup.tflite'
        )
    )
    rotation = qtyping.UniformQuantParams.HadamardRotationParams(
        random_binary_vector=np.ones(1),
        hadamard_size=2,
    )
    self.params = qtyping.UniformQuantParams(
        num_bits=8,
        quantized_dimension=None,
        scale=np.ones(1),
        zero_point=np.zeros(1),
        hadamard=rotation,
    )

  def test_insert_decomposed_ops(self):
    # Rotate the embedding_lookup output (tensor 2, produced by op 0), which
    # is consumed only as a graph output.
    transformation = transformation_utils.TransformationInput(
        tensor_id=2,
        op_codes=self.model.operatorCodes,
        buffers=self.model.buffers,
        subgraph=self.model.subgraphs[0],
        producer=0,
        consumers=[-1],
        quant_params=self.params,
    )
    info = (
        insert_decomposed_hadamard_rotation.insert_decomposed_hadamard_rotation(
            transformation
        )
    )

    subgraph = self.model.subgraphs[0]
    self.assertEqual(info.op_id, 1)
    self.assertEqual(info.num_ops_added, 3)
    # 3 original tensors + 6 new ones (3 activations, 3 constants).
    self.assertEqual(info.output_tensor_id, 8)
    self.assertLen(subgraph.tensors, 9)
    # 1 original op code + RESHAPE + FULLY_CONNECTED.
    self.assertLen(self.model.operatorCodes, 3)

    # Op 0 is the original EMBEDDING_LOOKUP; ops 1-3 are the inserted chain.
    pre_reshape = subgraph.operators[1]
    self.assertEqual(pre_reshape.inputs[0], 2)
    self.assertEqual(pre_reshape.outputs[0], 4)

    hadamard_fc = subgraph.operators[2]
    self.assertEqual(hadamard_fc.inputs[0], 4)
    self.assertEqual(hadamard_fc.inputs[1], 5)  # Hadamard matrix.
    self.assertEqual(hadamard_fc.outputs[0], 6)

    post_reshape = subgraph.operators[3]
    self.assertEqual(post_reshape.inputs[0], 6)
    self.assertEqual(post_reshape.outputs[0], 8)

    # The graph now exposes the rotated tensor instead of the original one.
    self.assertIn(8, subgraph.outputs)
    self.assertNotIn(2, subgraph.outputs)
241
+
242
+
243
if __name__ == "__main__":
  # Run the test cases in this module through TensorFlow's googletest entry
  # point.
  googletest.main()
@@ -0,0 +1,186 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Hadamard rotation pattern transformation."""
17
+
18
+ from flatbuffers import flexbuffers
19
+ import numpy as np
20
+ from ai_edge_quantizer import qtyping
21
+ from ai_edge_quantizer.transformations import transformation_utils
22
+ from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
23
+
24
+
25
def _to_flexbuffer(
    hadamard_size: int,
    random_binary_vector: list[np.int8],
) -> bytes:
  """Encodes the Hadamard rotation parameters as custom-op options.

  The result is a flexbuffer map with two entries:
    'hadamard_size': the integer `hadamard_size`.
    'random_binary_vector': a flexbuffer vector built from the elements of
      `random_binary_vector`.

  Args:
    hadamard_size: Value stored under the 'hadamard_size' key.
    random_binary_vector: Elements stored under the 'random_binary_vector' key.

  Returns:
    The buffer returned by `flexbuffers.Builder.Finish()`.
  """
  # NOTE(review): the caller in this module passes `ndarray.tolist()`, which
  # yields Python scalars rather than np.int8 values; the annotation is loose.
  builder = flexbuffers.Builder()
  with builder.Map():
    builder.Int('hadamard_size', hadamard_size)
    builder.VectorFromElements('random_binary_vector', random_binary_vector)
  return builder.Finish()
35
+
36
+
37
def _update_embedding_lookup_consumers(
    transformation: transformation_utils.TransformationInput,
    new_tensor_id: int,
) -> bool:
  """Rewires the consumers of the embedding lookup output to a new tensor.

  Every input slot of every consumer op that references
  `transformation.tensor_id` is redirected to `new_tensor_id`. Consumer id -1
  denotes a graph output rather than an op, so it is skipped here; graph
  outputs are handled by the caller.

  Args:
    transformation: The transformation input to update the consumers of.
    new_tensor_id: The new tensor id to use as the input to the embedding lookup
      consumers.

  Returns:
    True if at least one consumer input was rewired to the new tensor.
  """
  updated = False
  for consumer in transformation.consumers:
    # A graph output is not an op and has no inputs to rewire.
    if consumer == -1:
      continue
    consumer_inputs = transformation.subgraph.operators[consumer].inputs
    # Replace every input attached to the insertion point with the new tensor.
    for i, input_id in enumerate(consumer_inputs):
      if input_id == transformation.tensor_id:
        consumer_inputs[i] = new_tensor_id
        updated = True
  return updated
59
+
60
+
61
def _update_fully_connected_consumers(
    transformation: transformation_utils.TransformationInput,
    new_tensor_id: int,
) -> bool:
  """Updates the fully connected op(s) to use the new tensor.

  The new tensor is inserted at the fully_connected's input (slot 0). Each
  consumer is scanned, since there may be several fully_connected ops. When a
  consumer is a FULLY_CONNECTED op whose input 0 is `transformation.tensor_id`,
  that input is replaced with `new_tensor_id`.

  Args:
    transformation: The transformation input to update the consumers of.
    new_tensor_id: The new tensor id to use as the input to the fully connected
      consumers.

  Returns:
    True if the fully connected op(s) were updated to use the new tensor.
  """
  updated = False
  for consumer in transformation.consumers:
    # NOTE(review): a consumer id of -1 (graph output) is passed straight to
    # get_schema_op_id and used to index `subgraph.operators`; Python negative
    # indexing then addresses the last op. Confirm get_schema_op_id's handling
    # of -1. The tensor_id guard below keeps this from rewiring an op that
    # does not actually read the tensor.
    if (
        transformation_utils.get_schema_op_id(transformation, consumer)
        != schema_py_generated.BuiltinOperator.FULLY_CONNECTED
    ):
      continue
    fc_inputs = transformation.subgraph.operators[consumer].inputs
    # Only rewire slot 0 when it really holds the tensor being rotated;
    # otherwise an unrelated input would be silently overwritten.
    if fc_inputs[0] != transformation.tensor_id:
      continue
    fc_inputs[0] = new_tensor_id
    updated = True
  return updated
88
+
89
+
90
def insert_hadamard_rotation(
    transformation_input: transformation_utils.TransformationInput,
) -> qtyping.TransformationInfo:
  """Inserts a custom aeq.hadamard_rotation op on this tensor.

  A CUSTOM op with code 'aeq.hadamard_rotation' is added. It reads the tensor
  and writes a new float32 activation tensor named `<tensor name>_rotated`.
  Users of the original tensor are then rewired to the new one:
    * if the producer is EMBEDDING_LOOKUP, the consumer-op inputs that
      reference the tensor;
    * otherwise, the first input of FULLY_CONNECTED consumers (see
      `_update_fully_connected_consumers`);
    * in both cases, any matching entry of `subgraph.outputs`.

  This function works for float32 tensors only.

  Args:
    transformation_input: The transformation input to insert the custom op on.

  Returns:
    The transformation info of the inserted custom op: its operator index, the
    number of ops added (always 1), and the id of the rotated output tensor.

  Raises:
    ValueError: If the transformation input is not a uniform quantization
      transformation.
    ValueError: If the Hadamard quantization params are not set.
    ValueError: If the tensor is not a float32 tensor.
    ValueError: If no supported ops were found as the tensor's producer or
      consumers. Note that the new op code and tensor have already been added
      to the model when this error is raised.
  """
  quant_params = transformation_input.quant_params
  if not isinstance(quant_params, qtyping.UniformQuantParams):
    raise ValueError('Hadamard rotation supports uniform quantization only')

  if quant_params.hadamard is None:
    raise ValueError(
        'Hadamard rotation quantization params are not set but op insertion is'
        ' requested.'
    )

  subgraph = transformation_input.subgraph
  tensor = subgraph.tensors[transformation_input.tensor_id]
  if tensor.type != schema_py_generated.TensorType.FLOAT32:
    raise ValueError(
        'The Hadamard rotation op supports float32 tensors only. Got'
        f' {tensor.type} tensor.'
    )

  # Create new custom op with the current tensor as input and a new activation
  # tensor as output.
  custom_op_code_idx = transformation_utils.add_op_code(
      schema_py_generated.BuiltinOperator.CUSTOM,
      transformation_input.op_codes,
      'aeq.hadamard_rotation',
  )
  custom_op = schema_py_generated.OperatorT()
  custom_op.opcodeIndex = custom_op_code_idx
  custom_op.inputs = [transformation_input.tensor_id]
  custom_op.customOptions = _to_flexbuffer(
      quant_params.hadamard.hadamard_size,
      quant_params.hadamard.random_binary_vector.tolist(),
  )
  # Use shapeSignature when it is set, falling back to shape otherwise.
  new_tensor_id = transformation_utils.add_new_activation_tensor(
      tensor.name + b'_rotated',
      tensor.shapeSignature
      if tensor.shapeSignature is not None
      else tensor.shape,
      schema_py_generated.TensorType.FLOAT32,
      subgraph,
  )
  custom_op.outputs = [new_tensor_id]

  # Update the users of this tensor to use the new tensor.
  if (
      transformation_utils.get_producer_schema_op_id(transformation_input)
      == schema_py_generated.BuiltinOperator.EMBEDDING_LOOKUP
  ):
    _update_embedding_lookup_consumers(transformation_input, new_tensor_id)
  elif not _update_fully_connected_consumers(
      transformation_input, new_tensor_id
  ):
    raise ValueError(
        'The Hadamard rotation op supports embedding lookup and fully connected'
        ' ops only, but no such ops were found.'
    )

  # If the tensor is a graph output, replace it with the new tensor.
  for i, output in enumerate(subgraph.outputs):
    if output == transformation_input.tensor_id:
      subgraph.outputs[i] = new_tensor_id

  # Find the actual insertion point: after the producer op and before the
  # first consumer op. A consumer id of -1 denotes a graph output; default=-1
  # maps an empty consumer list to that same case instead of letting min()
  # raise after the model has already been modified. The max() ensures that
  # -1 is never used as the insertion point.
  first_consumer_id = min(transformation_input.consumers, default=-1)
  op_id = max(transformation_input.producer + 1, first_consumer_id)

  # Insert the custom op.
  subgraph.operators.insert(op_id, custom_op)

  return qtyping.TransformationInfo(
      op_id=op_id,
      num_ops_added=1,
      output_tensor_id=new_tensor_id,
  )
@@ -0,0 +1,200 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Test insertion of the Hadamard rotation custom op."""
17
+
18
+ import os
19
+ import numpy as np
20
+ from tensorflow.python.platform import googletest
21
+ from ai_edge_quantizer import qtyping
22
+ from ai_edge_quantizer.transformations import insert_hadamard_rotation
23
+ from ai_edge_quantizer.transformations import transformation_utils
24
+ from ai_edge_quantizer.utils import test_utils
25
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
26
+ from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
27
+
28
# Base directory for test assets; the tests below join
# 'tests/models/<model>.tflite' onto it. The '..' is resolved by
# test_utils.get_path_to_datafile (presumably relative to this test module).
_TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile(
    '..'
)
29
+
30
+
31
class InsertHadamardRotationFullyConnectedTest(googletest.TestCase):
  """Hadamard rotation custom op inserted in front of a FULLY_CONNECTED op."""

  def setUp(self):
    super().setUp()
    model_path = os.path.join(
        _TEST_DATA_PREFIX_PATH, 'tests/models/single_fc_bias.tflite'
    )
    self.model = tfl_flatbuffer_utils.read_model(model_path)
    self.params = qtyping.UniformQuantParams(
        num_bits=8,
        quantized_dimension=None,
        scale=np.ones(1),
        zero_point=np.zeros(1),
        hadamard=qtyping.UniformQuantParams.HadamardRotationParams(
            random_binary_vector=np.ones(1),
            hadamard_size=2,
        ),
    )

  def _insert(self, quant_params, consumers):
    """Runs the transformation on tensor 0 of subgraph 0 (no producer op)."""
    return insert_hadamard_rotation.insert_hadamard_rotation(
        transformation_utils.TransformationInput(
            tensor_id=0,
            op_codes=self.model.operatorCodes,
            buffers=self.model.buffers,
            subgraph=self.model.subgraphs[0],
            producer=-1,
            consumers=consumers,
            quant_params=quant_params,
        )
    )

  def test_raise_unsupported_qparams(self):
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: 'uniform quantization' in str(err)
    ):
      self._insert(
          qtyping.NonLinearQuantParams(num_bits=16, quantized_data=None),
          consumers=[-1],
      )

  def test_raise_missing_hadamard_data(self):
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: 'quantization params are not set' in str(err)
    ):
      self._insert(
          qtyping.UniformQuantParams(
              num_bits=8,
              quantized_dimension=None,
              scale=np.ones(1),
              zero_point=np.zeros(1),
          ),
          consumers=[-1],
      )

  def test_raise_non_float32_tensor(self):
    self.model.subgraphs[0].tensors[
        0
    ].type = schema_py_generated.TensorType.INT32
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda err: 'float32 tensors' in str(err)
    ):
      self._insert(self.params, consumers=[-1])

  def test_insert_single_custom_op(self):
    # Insert aeq.hadamard_rotation before fully_connected. Op 0 (the FC) is
    # the real consumer of tensor 0; -1 is reserved for graph outputs.
    info = self._insert(self.params, consumers=[0])
    subgraph = self.model.subgraphs[0]
    self.assertEqual(info.op_id, 0)
    self.assertEqual(info.num_ops_added, 1)
    # Model had 4 tensors, added 1.
    self.assertEqual(info.output_tensor_id, 4)
    self.assertLen(subgraph.tensors, 5)
    # Model had 1 op code, added CUSTOM.
    self.assertLen(self.model.operatorCodes, 2)
    self.assertEqual(
        self.model.operatorCodes[1].builtinCode,
        schema_py_generated.BuiltinOperator.CUSTOM,
    )
    # First op is now the custom op, preceding fully_connected.
    self.assertEqual(
        self.model.operatorCodes[subgraph.operators[0].opcodeIndex].builtinCode,
        schema_py_generated.BuiltinOperator.CUSTOM,
    )
    # Input to the custom op is the graph input.
    self.assertEqual(subgraph.operators[0].inputs[0], 0)
    # Input to the FC is the custom op output.
    self.assertEqual(subgraph.operators[1].inputs[0], 4)
142
+
143
+
144
class InsertHadamardRotationEmbeddingLookupTest(googletest.TestCase):
  """Hadamard rotation custom op inserted after an EMBEDDING_LOOKUP op."""

  def setUp(self):
    super().setUp()
    self.model = tfl_flatbuffer_utils.read_model(
        os.path.join(
            _TEST_DATA_PREFIX_PATH, 'tests/models/embedding_lookup.tflite'
        )
    )
    rotation = qtyping.UniformQuantParams.HadamardRotationParams(
        random_binary_vector=np.ones(1),
        hadamard_size=2,
    )
    self.params = qtyping.UniformQuantParams(
        num_bits=8,
        quantized_dimension=None,
        scale=np.ones(1),
        zero_point=np.zeros(1),
        hadamard=rotation,
    )

  def test_insert_single_custom_op(self):
    # Rotate the embedding_lookup output (tensor 2, produced by op 0).
    transformation = transformation_utils.TransformationInput(
        tensor_id=2,
        op_codes=self.model.operatorCodes,
        buffers=self.model.buffers,
        subgraph=self.model.subgraphs[0],
        producer=0,
        consumers=[-1],
        quant_params=self.params,
    )
    info = insert_hadamard_rotation.insert_hadamard_rotation(transformation)

    subgraph = self.model.subgraphs[0]
    custom_code = schema_py_generated.BuiltinOperator.CUSTOM
    self.assertEqual(info.op_id, 1)
    self.assertEqual(info.num_ops_added, 1)
    # 3 original tensors + the rotated output.
    self.assertEqual(info.output_tensor_id, 3)
    self.assertLen(subgraph.tensors, 4)
    # 1 original op code + CUSTOM.
    self.assertLen(self.model.operatorCodes, 2)
    self.assertEqual(self.model.operatorCodes[1].builtinCode, custom_code)

    # The custom op sits right after embedding_lookup, consuming its output
    # and producing the new tensor.
    rotation_op = subgraph.operators[1]
    self.assertEqual(
        self.model.operatorCodes[rotation_op.opcodeIndex].builtinCode,
        custom_code,
    )
    self.assertEqual(rotation_op.inputs[0], 2)
    self.assertEqual(rotation_op.outputs[0], 3)
197
+
198
+
199
if __name__ == "__main__":
  # Run the test cases in this module through TensorFlow's googletest entry
  # point.
  googletest.main()
@@ -68,29 +68,6 @@ def nonlinear_quant_params_to_tflite_type(
68
68
  raise ValueError(f"Unsupported nonlinear params: {bitwidth}")
69
69
 
70
70
 
71
- def _pack_data(bitwidth: int, flattened_data: np.ndarray) -> np.ndarray:
72
- """Pack the data to the corresponding bit width.
73
-
74
- Currently only support 4 bits. If no packing is needed, the original data is
75
- returned.
76
-
77
- Args:
78
- bitwidth: Bit width from NonLinearQuantParams.
79
- flattened_data: The data to be packed.
80
-
81
- Returns:
82
- Packed data.
83
- """
84
- if bitwidth == 4:
85
- even_data = flattened_data[::2] & 0x0F
86
- odd_data = np.left_shift(flattened_data[1::2], 4).astype(np.uint8)
87
- if odd_data.shape[0] == even_data.shape[0] - 1:
88
- odd_data = np.pad(odd_data, (0, 1), constant_values=0)
89
- return np.bitwise_or(even_data, odd_data)
90
- else:
91
- return flattened_data
92
-
93
-
94
71
  def _perform_channelwise_quantization(
95
72
  transformation_input: transformation_utils.TransformationInput,
96
73
  ) -> schema_py_generated.QuantizationParametersT():
@@ -154,9 +131,14 @@ def _perform_blockwise_quantization(
154
131
  transformation_input.buffers,
155
132
  )
156
133
  blockwise_details.scales = scale_tensor_id
157
- blockwise_details.blockSize = transformation_input.quant_params.block_size
134
+ # Blockwise quantization does not support zero point yet, so this points to
135
+ # a -1 buffer index.
158
136
  # TODO: b/404909258 - Add optional zero point to blockwise quantization.
137
+ blockwise_details.zeroPoints = -1
138
+ blockwise_details.blockSize = transformation_input.quant_params.block_size
159
139
  flatbuffer_quantization.details = blockwise_details
140
+ # TODO: b/443830202 - Hardcoding to 0 for now.
141
+ flatbuffer_quantization.quantizedDimension = 0
160
142
  return flatbuffer_quantization
161
143
 
162
144
 
@@ -180,14 +162,17 @@ def quantize_tensor(
180
162
  # is not provided.
181
163
  if tensor.buffer:
182
164
  if transformation_input.quant_params.quantized_data is not None:
183
- transformation_input.buffers[tensor.buffer].data = _pack_data(
184
- transformation_input.quant_params.num_bits,
185
- np.frombuffer(
186
- cast(
187
- np.ndarray, transformation_input.quant_params.quantized_data
188
- ).tobytes(),
189
- dtype=np.uint8,
190
- ).flatten(),
165
+ transformation_input.buffers[tensor.buffer].data = (
166
+ transformation_utils.pack_data(
167
+ transformation_input.quant_params.num_bits,
168
+ np.frombuffer(
169
+ cast(
170
+ np.ndarray,
171
+ transformation_input.quant_params.quantized_data,
172
+ ).tobytes(),
173
+ dtype=np.uint8,
174
+ ).flatten(),
175
+ )
191
176
  )
192
177
 
193
178
  if isinstance(transformation_input.quant_params, qtyping.UniformQuantParams):