ai-edge-quantizer-nightly 0.1.0.dev20250511__py3-none-any.whl → 0.1.0.dev20250513__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,6 +24,7 @@ from ai_edge_quantizer import qtyping
 from ai_edge_quantizer.algorithms.nonlinear_quantize import float_casting
 from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
 from ai_edge_quantizer.algorithms.uniform_quantize import dequantized_weight_recovery
+from ai_edge_quantizer.algorithms.uniform_quantize import hadamard_rotation
 from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
 from ai_edge_quantizer.algorithms.uniform_quantize import octav
 
@@ -58,6 +59,8 @@ class AlgorithmName(str, enum.Enum):
   FLOAT_CASTING = float_casting.ALGORITHM_KEY
   DEQUANTIZED_WEIGHT_RECOVERY = dequantized_weight_recovery.ALGORITHM_KEY
   OCTAV = octav.ALGORITHM_KEY
+  HADAMARD_ROTATION = hadamard_rotation.ALGORITHM_KEY
+
 
 ### MIN/MAX_UNIFORM_QUANT ###
 
@@ -104,6 +107,7 @@ MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT = {
         common_quantize.materialize_dynamic_update_slice
     ),
     _TFLOpName.STABLEHLO_COMPOSITE: common_quantize.materialize_composite,
+    _TFLOpName.PAD: common_quantize.materialize_pad,
 }
 for op_name, materialize_func in MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT.items():
   register_quantized_op(
@@ -237,6 +241,7 @@ _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
         common_quantize.materialize_dynamic_update_slice
     ),
     _TFLOpName.STABLEHLO_COMPOSITE: common_quantize.materialize_composite,
+    _TFLOpName.PAD: common_quantize.materialize_pad,
 })
 
 for op_name, materialize_func in _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT.items():
@@ -250,3 +255,32 @@ for op_name, materialize_func in _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT.items():
           octav.get_tensor_quant_params,
       ),
   )
+
+# Register the Hadamard Rotation algorithm.
+register_op_quant_config_validation_func(
+    AlgorithmName.HADAMARD_ROTATION,
+    common_quantize.check_op_quantization_config,
+)
+
+# Register a config check policy for the Hadamard Rotation algorithm.
+register_config_check_policy_func(
+    AlgorithmName.HADAMARD_ROTATION,
+    default_policy.DEFAULT_CONFIG_CHECK_POLICY,
+)
+
+# Register specialized hadamard rotation materialize functions.
+_HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
+    _TFLOpName.FULLY_CONNECTED: hadamard_rotation.materialize_fully_connected,
+    _TFLOpName.EMBEDDING_LOOKUP: hadamard_rotation.materialize_embedding_lookup,
+})
+for (
+    op_name,
+    materialize_func,
+) in _HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT.items():
+  register_quantized_op(
+      AlgorithmName.HADAMARD_ROTATION,
+      op_name,
+      naive_min_max_quantize.init_qsvs,
+      calibration_func=naive_min_max_quantize.min_max_calibrate,
+      materialize_func=materialize_func,
+  )
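Note: the registration block above follows the same shape as the OCTAV block before it: validate the op config, attach a config-check policy, then map each supported op to a materialize function. As a minimal, self-contained sketch of that registry pattern (illustrative names only, not the package's internal API):

from typing import Callable

# Hypothetical registry mirroring register_quantized_op's role: one table per
# algorithm, mapping an op name to the function that materializes its tensors.
_REGISTRY: dict[str, dict[str, Callable[..., list]]] = {}

def register(algorithm: str, op_name: str, materialize_fn: Callable[..., list]):
  _REGISTRY.setdefault(algorithm, {})[op_name] = materialize_fn

def materialize(algorithm: str, op_name: str, *args):
  # Dispatch to the registered function, as algorithm_manager does per op.
  return _REGISTRY[algorithm][op_name](*args)

register("HADAMARD_ROTATION", "FULLY_CONNECTED", lambda *a: ["fc params"])
assert materialize("HADAMARD_ROTATION", "FULLY_CONNECTED") == ["fc params"]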
@@ -680,6 +680,23 @@ def materialize_split(
   )
 
 
+def materialize_pad(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.pad."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Padding value does not need to be quantized.
+  )
+
+
 def _get_tensor_shape_for_blockwise(
     tensor_shape: Sequence[int], quantized_dim: int, block_size: int
 ) -> list[int]:
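Why PAD can use SAME_AS_INPUT_SCALE: padding inserts the real value 0.0, which maps exactly onto the zero point, so the output can safely reuse the input's quantization parameters. A small numpy sketch of that property (symmetric int8, illustrative values, not the package's code):

import numpy as np

scale = 0.02
x = np.array([[0.5, -0.3], [0.1, 0.4]], dtype=np.float32)

def quantize(t, s):
  # Symmetric int8 quantization with a zero point of 0.
  return np.clip(np.round(t / s), -128, 127).astype(np.int8)

pad_then_quantize = quantize(np.pad(x, 1), scale)       # pads with float 0.0
quantize_then_pad = np.pad(quantize(x, scale), 1)       # pads with int 0
assert np.array_equal(pad_then_quantize, quantize_then_pad)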
@@ -0,0 +1,352 @@
+# Copyright 2024 The AI Edge Quantizer Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Implements the Hadamard Rotation quantization."""
+
+from typing import Any, Optional
+import numpy as np
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.algorithms.uniform_quantize import octav
+from ai_edge_quantizer.algorithms.utils import common_utils
+from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+
+
+ALGORITHM_KEY = "HADAMARD_ROTATION"
+
+
+def _make_hadamard_matrix(size: int) -> np.ndarray:
+  """Generates a Hadamard matrix of the given size.
+
+  Args:
+    size: The size of the Hadamard matrix. Must be a power of 2. This
+      represents a single dimension. E.g. if size is 4, then the Hadamard
+      matrix is a 4x4 matrix.
+
+  Returns:
+    The Hadamard matrix.
+
+  Raises:
+    ValueError: If the size is not a power of 2.
+  """
+  if size <= 0 or (size & (size - 1)) != 0:
+    raise ValueError("Hadamard matrix size must be a power of 2.")
+  h = h2 = np.array([[1, 1], [1, -1]])
+  current_size = 2
+  while current_size < size:
+    h = np.kron(h, h2)
+    current_size *= 2
+  return h / np.sqrt(size)
+
+
+def _rotate_with_diagonal_hadamard(
+    tensor_content: np.ndarray,
+    axis: int,
+):
+  """Rotates the given float array with a block-diagonal Hadamard matrix.
+
+  Args:
+    tensor_content: The float array to rotate.
+    axis: The axis of the tensor to rotate along.
+
+  Returns:
+    A tuple of the rotated array, the Hadamard block size, and the random
+    binary vector.
+
+  Raises:
+    ValueError: If the axis is not 1. To support other axes, please add
+      support to the matrix multiplication.
+  """
+  if axis != 1:
+    raise ValueError(
+        "Hadamard rotation is only supported for 2D tensors with quantized"
+        " dimension 0."
+    )
+
+  # Use the largest power of 2 that is a factor of the dimension and then
+  # tile this Hadamard matrix along the diagonal. 2**30 is just a large power
+  # of 2 to calculate this factor.
+  hadamard_size = np.gcd(tensor_content.shape[axis], 2**30)
+  diagonal_size = tensor_content.shape[axis] // hadamard_size
+  random_vector = np.ones(hadamard_size, dtype=np.int8)
+
+  # Use a canonical Hadamard matrix.
+  hadamard = _make_hadamard_matrix(hadamard_size)
+  hadamard_diagonal = np.kron(np.eye(diagonal_size), hadamard)
+  w_rotated = np.einsum("ij,aj->ai", hadamard_diagonal, tensor_content)
+  return w_rotated, hadamard_size, random_vector
+
+
+def get_tensor_quant_params(
+    op_info: qtyping.OpInfo,
+    tensor_quant_config: qtyping.TensorQuantizationConfig,
+    tensor_content: Optional[np.ndarray] = None,
+    tensor_qsv: Optional[dict[str, Any]] = None,
+) -> qtyping.UniformQuantParams:
+  """Returns the quantization parameters for a tensor.
+
+  This function will rotate the tensor with a Hadamard matrix and then
+  quantize it with OCTAV.
+
+  Args:
+    op_info: Aggregated information about the op (e.g., quantization config).
+    tensor_quant_config: The quantization config for the tensor.
+    tensor_content: The content of the tensor. When None, it means the tensor
+      is not a weight tensor (e.g. static quantization).
+    tensor_qsv: A dictionary containing the min/max of the tensor.
+
+  Raises:
+    ValueError: If `tensor_content` is None, i.e. the tensor is not a weight
+      tensor.
+    ValueError: If `tensor_qsv` is set, i.e. static quantization is requested.
+    ValueError: If the tensor is not 2D, the granularity is not channelwise,
+      or the quantized dimension is not 0.
+  """
+  if tensor_content is None:
+    raise ValueError("Hadamard rotation is only supported for weight tensors.")
+
+  if tensor_qsv is not None:
+    raise ValueError(
+        "Hadamard rotation is not supported for static quantization."
+    )
+
+  if tensor_content.ndim != 2:
+    raise ValueError("Hadamard rotation is only supported for 2D tensors.")
+
+  if tensor_quant_config.granularity != qtyping.QuantGranularity.CHANNELWISE:
+    raise ValueError(
+        "Hadamard rotation is not supported for"
+        f" {tensor_quant_config.granularity} granularity."
+    )
+
+  quantized_dim = common_utils.get_weight_quantized_dim(op_info, tensor_content)
+  if quantized_dim != 0:
+    raise ValueError(
+        f"Unsupported quantized dimension: {quantized_dim}. Only 0 is"
+        " supported."
+    )
+
+  # The reduction axis is the non-quantized dimension. Since we only support
+  # 2D tensors and a quantized_dim of 0, the reduction axis is 1.
+  reduce_axis = 1
+
+  # Rotate the tensor with a Hadamard matrix.
+  w_rotated, hadamard_size, random_vector = _rotate_with_diagonal_hadamard(
+      tensor_content, axis=reduce_axis
+  )
+
+  # Get the quantized values of the rotated tensor.
+  qparams = octav.get_tensor_quant_params(
+      op_info, tensor_quant_config, w_rotated, tensor_qsv
+  )
+
+  return qtyping.UniformQuantParams(
+      quantized_dimension=qparams.quantized_dimension,
+      num_bits=qparams.num_bits,
+      scale=qparams.scale,
+      zero_point=qparams.zero_point,
+      symmetric=qparams.symmetric,
+      quantized_data=qparams.quantized_data,
+      block_size=qparams.block_size,
+      hadamard=qtyping.UniformQuantParams.HadamardRotationParams(
+          random_binary_vector=random_vector,
+          hadamard_size=hadamard_size,
+      ),
+  )
+
+
+def materialize_fully_connected(
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize the fully_connected op.
+
+  Args:
+    op_info: Aggregated information about the op (e.g., quantization config).
+    graph_info: Graph information needed to perform quantization for the op.
+    tensor_name_to_qsv: A map of tensor name to quantization parameters.
+
+  Returns:
+    Quantization configuration for the tensors associated with the op (e.g.,
+    weights, bias).
+  """
+  op_tensor_params = []
+
+  # Materialize weight.
+  weight_tensor_index = 1
+  weight_tensor = graph_info.subgraph_tensors[
+      op_info.op.inputs[weight_tensor_index]
+  ]
+  tensor_data = tfl_flatbuffer_utils.get_tensor_data(
+      weight_tensor, graph_info.buffers
+  )
+  # quant_params contains the rotated and quantized weights produced by
+  # get_tensor_quant_params().
+  quant_params = get_tensor_quant_params(
+      op_info,
+      op_info.op_quant_config.weight_tensor_config,
+      tensor_data,
+      None,
+  )
+  transformations = [qtyping.QuantTransformation.QUANTIZE_TENSOR]
+  op2tensor_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      parameters=quant_params,
+      transformations=transformations,
+  )
+  weight_transformation_params = qtyping.TensorTransformationParams(
+      tensor_name=tfl_flatbuffer_utils.get_tensor_name(weight_tensor),
+      consumers=[op2tensor_params],
+  )
+
+  # Materialize input. A hadamard rotation op should be inserted on the input
+  # tensor to do the inverse of the weight's transformation.
+  input_tensor_index = 0
+  input_tensor = graph_info.subgraph_tensors[
+      op_info.op.inputs[input_tensor_index]
+  ]
+  transformations = [
+      qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
+  ]
+  op2tensor_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      parameters=quant_params,
+      transformations=transformations,
+  )
+  input_transformation_params = qtyping.TensorTransformationParams(
+      tensor_name=tfl_flatbuffer_utils.get_tensor_name(input_tensor),
+      consumers=[op2tensor_params],
+  )
+  op_tensor_params.append(input_transformation_params)
+  op_tensor_params.append(weight_transformation_params)
+
+  # Materialize bias. Since static quantization is not supported, we do not
+  # quantize the bias tensor.
+  bias_tensor_index = 2
+  bias_tensor = graph_info.subgraph_tensors[
+      op_info.op.inputs[bias_tensor_index]
+  ]
+  no_quant_tensor_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
+  )
+  bias_transformation_params = qtyping.TensorTransformationParams(
+      tensor_name=tfl_flatbuffer_utils.get_tensor_name(bias_tensor),
+      consumers=[no_quant_tensor_params],
+  )
+  op_tensor_params.append(bias_transformation_params)
+
+  # Materialize output. Since static quantization is not supported, we do not
+  # quantize the output tensor.
+  output_tensor_index = 0
+  output_tensor = graph_info.subgraph_tensors[
+      op_info.op.outputs[output_tensor_index]
+  ]
+  no_quant_tensor_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
+  )
+  output_transformation_params = qtyping.TensorTransformationParams(
+      tensor_name=tfl_flatbuffer_utils.get_tensor_name(output_tensor),
+      producer=no_quant_tensor_params,
+  )
+  op_tensor_params.append(output_transformation_params)
+
+  return op_tensor_params
+
+
+def materialize_embedding_lookup(
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize the embedding_lookup op.
+
+  Args:
+    op_info: Aggregated information about the op (e.g., quantization config).
+    graph_info: Graph information needed to perform quantization for the op.
+    tensor_name_to_qsv: A map of tensor name to quantization parameters.
+
+  Returns:
+    Quantization configuration for the tensors associated with the op (e.g.,
+    weights, bias).
+  """
+  op_tensor_params = []
+
+  # Materialize lookup.
+  lookup_tensor_index = 0
+  lookup_tensor = graph_info.subgraph_tensors[
+      op_info.op.inputs[lookup_tensor_index]
+  ]
+  transformations = [
+      qtyping.QuantTransformation.NO_QUANTIZE,
+  ]
+  op2tensor_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      parameters=None,
+      transformations=transformations,
+  )
+  lookup_transformation_params = qtyping.TensorTransformationParams(
+      tensor_name=tfl_flatbuffer_utils.get_tensor_name(lookup_tensor),
+      consumers=[op2tensor_params],
+  )
+  op_tensor_params.append(lookup_transformation_params)
+
+  # Materialize embedding. The embedding table should be rotated and then
+  # quantized.
+  embedding_tensor_index = 1
+  embedding_tensor = graph_info.subgraph_tensors[
+      op_info.op.inputs[embedding_tensor_index]
+  ]
+  tensor_data = tfl_flatbuffer_utils.get_tensor_data(
+      embedding_tensor, graph_info.buffers
+  )
+  quant_params = get_tensor_quant_params(
+      op_info,
+      op_info.op_quant_config.weight_tensor_config,
+      tensor_data,
+      None,
+  )
+  transformations = [qtyping.QuantTransformation.QUANTIZE_TENSOR]
+  op2tensor_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      parameters=quant_params,
+      transformations=transformations,
+  )
+  weight_transformation_params = qtyping.TensorTransformationParams(
+      tensor_name=tfl_flatbuffer_utils.get_tensor_name(embedding_tensor),
+      consumers=[op2tensor_params],
+  )
+  op_tensor_params.append(weight_transformation_params)
+
+  # Materialize output. A hadamard rotation op should be inserted on the
+  # output tensor to do the inverse of the embedding's transformation.
+  output_tensor_index = 0
+  output_tensor = graph_info.subgraph_tensors[
+      op_info.op.outputs[output_tensor_index]
+  ]
+  transformations = [
+      qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
+  ]
+  op2tensor_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      parameters=quant_params,
+      transformations=transformations,
+  )
+  output_transformation_params = qtyping.TensorTransformationParams(
+      tensor_name=tfl_flatbuffer_utils.get_tensor_name(output_tensor),
+      producer=op2tensor_params,
+  )
+  op_tensor_params.append(output_transformation_params)
+
+  return op_tensor_params
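As a sanity check on the math above: the normalized Sylvester-Hadamard matrix is symmetric and orthogonal, so rotating the weights along the reduction axis and applying the same rotation to the activations leaves a fully_connected product unchanged. A self-contained numpy sketch (mirroring the einsum contraction in _rotate_with_diagonal_hadamard; not the package's code):

import numpy as np

def hadamard(size):
  # Same Sylvester construction as _make_hadamard_matrix, normalized.
  h = h2 = np.array([[1.0, 1.0], [1.0, -1.0]])
  while h.shape[0] < size:
    h = np.kron(h, h2)
  return h / np.sqrt(size)

rng = np.random.default_rng(0)
w = rng.normal(size=(16, 8))  # (out_features, in_features)
x = rng.normal(size=(4, 8))   # (batch, in_features)

h = hadamard(8)
w_rot = np.einsum("ij,aj->ai", h, w)  # same contraction as the source
x_rot = np.einsum("ij,aj->ai", h, x)

# H is symmetric and orthogonal, so (x H)(W H)^T == x W^T up to float error.
np.testing.assert_allclose(x_rot @ w_rot.T, x @ w.T, atol=1e-9)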
@@ -0,0 +1,216 @@
+# Copyright 2024 The AI Edge Quantizer Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Test Hadamard rotation materialization."""
+
+import os
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.platform import googletest
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.algorithms.uniform_quantize import hadamard_rotation
+from ai_edge_quantizer.utils import test_utils
+from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+
+_TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("../../tests/models")
+_TFLOpName = qtyping.TFLOperationName
+_TensorQuantConfig = qtyping.TensorQuantizationConfig
+
+
+class HadamardRotationFullyConnectedTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    np.random.seed(888)
+    self._test_model_path = os.path.join(
+        _TEST_DATA_PREFIX_PATH, "conv_fc_mnist.tflite"
+    )
+    self._test_model = tfl_flatbuffer_utils.read_model(self._test_model_path)
+    self._graph_info = qtyping.GraphInfo(
+        subgraph_tensors=self._test_model.subgraphs[0].tensors,
+        buffers=self._test_model.buffers,
+    )
+    self._tensor_name_to_qsv = None
+    self._subgraph = self._test_model.subgraphs[0]
+    fc_subgraph_op_index = 3
+    self._fc_op = self._subgraph.operators[fc_subgraph_op_index]
+    self._fc_buffer_id = self._subgraph.tensors[self._fc_op.inputs[1]].buffer
+    self._op_info = qtyping.OpInfo(
+        op=self._fc_op,
+        op_name=_TFLOpName.FULLY_CONNECTED,
+        subgraph_op_index=fc_subgraph_op_index,
+        op_quant_config=qtyping.OpQuantizationConfig(
+            weight_tensor_config=_TensorQuantConfig(
+                num_bits=8,
+                symmetric=True,
+                granularity=qtyping.QuantGranularity.CHANNELWISE,
+            ),
+        ),
+    )
+
+  def test_materialize_fully_connected_basic(self):
+    params = hadamard_rotation.materialize_fully_connected(
+        self._op_info, self._graph_info, self._tensor_name_to_qsv
+    )
+    fc_input = params[0]
+    weight = params[1]
+    bias = params[2]
+    output = params[3]
+
+    self.assertLen(params, 4)
+    self.assertIsNone(fc_input.producer)
+    self.assertIsNotNone(fc_input.consumers)
+    self.assertIsNone(weight.producer)
+    self.assertIsNotNone(weight.consumers)
+    self.assertIsNone(bias.producer)
+    self.assertIsNotNone(bias.consumers)
+    self.assertIsNotNone(output.producer)
+    self.assertIsNone(output.consumers)
+    self.assertEqual(
+        fc_input.consumers[0].transformations,
+        [qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION],
+    )
+    self.assertEqual(
+        weight.consumers[0].transformations,
+        [qtyping.QuantTransformation.QUANTIZE_TENSOR],
+    )
+    self.assertEqual(
+        bias.consumers[0].transformations,
+        [qtyping.QuantTransformation.NO_QUANTIZE],
+    )
+    if output.producer is not None:
+      self.assertEqual(
+          output.producer.transformations,
+          [qtyping.QuantTransformation.NO_QUANTIZE],
+      )
+
+  def test_get_tensor_quant_params_basic(self):
+    input_tensor = self._subgraph.tensors[self._fc_op.inputs[1]]
+    buffer = self._graph_info.buffers[self._fc_buffer_id]
+    np_buffer = np.frombuffer(buffer.data, dtype=np.float32).reshape(
+        input_tensor.shape
+    )
+    qparams = hadamard_rotation.get_tensor_quant_params(
+        self._op_info,
+        self._op_info.op_quant_config.weight_tensor_config,
+        np_buffer,
+        self._tensor_name_to_qsv,
+    )
+    self.assertEqual(qparams.num_bits, 8)
+    self.assertEqual(qparams.zero_point.all(), 0)
+    self.assertEqual(qparams.symmetric, True)
+    self.assertIsNotNone(qparams.quantized_data)
+    self.assertEqual(qparams.block_size, 0)
+    self.assertIsNotNone(qparams.hadamard)
+    if qparams.hadamard is not None:
+      self.assertEqual(qparams.hadamard.hadamard_size, 32)
+
+  def test_raise_missing_tensor_content(self):
+    with self.assertRaisesWithPredicateMatch(
+        ValueError, lambda err: "weight tensor" in str(err)
+    ):
+      hadamard_rotation.get_tensor_quant_params(
+          self._op_info,
+          self._op_info.op_quant_config.weight_tensor_config,
+          None,
+          self._tensor_name_to_qsv,
+      )
+
+  def test_raise_qsv_set(self):
+    with self.assertRaisesWithPredicateMatch(
+        ValueError, lambda err: "static quantization" in str(err)
+    ):
+      hadamard_rotation.get_tensor_quant_params(
+          self._op_info,
+          self._op_info.op_quant_config.weight_tensor_config,
+          self._graph_info.buffers[self._fc_buffer_id],
+          self._graph_info.buffers[self._fc_buffer_id],
+      )
+
+  def test_raise_non_2d_constant(self):
+    with self.assertRaisesWithPredicateMatch(
+        ValueError, lambda err: "2D tensors" in str(err)
+    ):
+      hadamard_rotation.get_tensor_quant_params(
+          self._op_info,
+          self._op_info.op_quant_config.weight_tensor_config,
+          np.array([1.0, 2.0, 3.0]),
+          self._tensor_name_to_qsv,
+      )
+
+
+class HadamardRotationEmbeddingLookupTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    np.random.seed(888)
+    self._test_model_path = os.path.join(
+        _TEST_DATA_PREFIX_PATH, "embedding_lookup.tflite"
+    )
+    self._test_model = tfl_flatbuffer_utils.read_model(self._test_model_path)
+    self._graph_info = qtyping.GraphInfo(
+        subgraph_tensors=self._test_model.subgraphs[0].tensors,
+        buffers=self._test_model.buffers,
+    )
+    self._tensor_name_to_qsv = None
+
+  def test_materialize_embedding_lookup_basic(self):
+    subgraph = self._test_model.subgraphs[0]
+    embedding_subgraph_op_index = 0
+    embedding_op = subgraph.operators[embedding_subgraph_op_index]
+    op_info = qtyping.OpInfo(
+        op=embedding_op,
+        op_name=_TFLOpName.EMBEDDING_LOOKUP,
+        subgraph_op_index=embedding_subgraph_op_index,
+        op_quant_config=qtyping.OpQuantizationConfig(
+            weight_tensor_config=_TensorQuantConfig(
+                num_bits=8,
+                symmetric=True,
+                granularity=qtyping.QuantGranularity.CHANNELWISE,
+            ),
+        ),
+    )
+    params = hadamard_rotation.materialize_embedding_lookup(
+        op_info, self._graph_info, self._tensor_name_to_qsv
+    )
+    self.assertLen(params, 3)
+    lookup = params[0]
+    value = params[1]
+    output = params[2]
+    self.assertIsNone(lookup.producer)
+    self.assertIsNotNone(lookup.consumers)
+    self.assertIsNone(value.producer)
+    self.assertIsNotNone(value.consumers)
+    self.assertIsNotNone(output.producer)
+    self.assertIsNone(output.consumers)
+    self.assertEqual(
+        lookup.consumers[0].transformations,
+        [qtyping.QuantTransformation.NO_QUANTIZE],
+    )
+    self.assertEqual(
+        value.consumers[0].transformations,
+        [qtyping.QuantTransformation.QUANTIZE_TENSOR],
+    )
+    if output.producer is not None:
+      self.assertEqual(
+          output.producer.transformations,
+          [qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION],
+      )
+
+
+if __name__ == "__main__":
+  googletest.main()
@@ -183,7 +183,8 @@ DEFAULT_JSON_POLICY = """
       "SELECT_V2",
       "DYNAMIC_UPDATE_SLICE",
       "SELECT_V2",
-      "STABLEHLO_COMPOSITE"
+      "STABLEHLO_COMPOSITE",
+      "PAD"
     ],
     "static_wi8_ai8": [
       "ADD",
@@ -214,7 +215,8 @@ DEFAULT_JSON_POLICY = """
       "SELECT_V2",
       "DYNAMIC_UPDATE_SLICE",
       "SELECT_V2",
-      "STABLEHLO_COMPOSITE"
+      "STABLEHLO_COMPOSITE",
+      "PAD"
     ],
     "static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
    "static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
@@ -508,6 +508,7 @@ def _compatible_tensor_params(
   float_source_transformations = [
       _QuantTrans.ADD_QUANTIZE,
       _QuantTrans.NO_QUANTIZE,
+      _QuantTrans.INSERT_HADAMARD_ROTATION,
   ]
   quantized_source_transformations = [
       _QuantTrans.QUANTIZE_TENSOR,
@@ -20,7 +20,7 @@ from collections.abc import MutableMapping
 import copy
 import dataclasses
 import enum
-from typing import Any, Optional, Union, Callable
+from typing import Any, Callable, Optional, Union
 
 import numpy as np
 from typing_extensions import TypeAlias
@@ -62,6 +62,7 @@ class TFLOperationName(str, enum.Enum):
   SELECT_V2 = 'SELECT_V2'
   DYNAMIC_UPDATE_SLICE = 'DYNAMIC_UPDATE_SLICE'
   STABLEHLO_COMPOSITE = 'STABLEHLO_COMPOSITE'
+  PAD = 'PAD'
 
 
 class QuantizeMode(enum.Enum):
@@ -113,6 +114,8 @@ class QuantTransformation(enum.Enum):
   DUPLICATE_BUFFER = 5
   # Duplicate the tensor.
   DUPLICATE_TENSOR = 6
+  # Insert the aeq.hadamard_rotation op.
+  INSERT_HADAMARD_ROTATION = 7
 
 
 @dataclasses.dataclass(frozen=True)
@@ -128,8 +131,35 @@ class UniformQuantParams:
     quantized_data: The quantized data.
     block_size: The block size for blockwise quantization, block_size=0 meaning
       no blockwise quantization.
+    hadamard: The Hadamard rotation parameters, if set.
   """
 
+  class HadamardRotationParams:
+    """Parameters for the Hadamard rotation.
+
+    Attributes:
+      random_binary_vector: The random binary vector for the Hadamard rotation.
+        TODO(b/415392354): Randomization is an experimental feature that's
+        currently not implemented yet, hence this is always 1. We will add
+        support or remove it in the future.
+      hadamard_size: The size of the Hadamard matrix.
+    """
+
+    random_binary_vector: np.ndarray
+    hadamard_size: int
+
+    def __init__(self, random_binary_vector: np.ndarray, hadamard_size: int):
+      self.random_binary_vector = random_binary_vector
+      self.hadamard_size = hadamard_size
+
+    def __eq__(self, other):
+      if other.__class__ is not self.__class__:
+        return NotImplemented
+      return (
+          np.array_equal(self.random_binary_vector, other.random_binary_vector)
+          and self.hadamard_size == other.hadamard_size
+      )
+
   num_bits: int
   quantized_dimension: Optional[int]
   scale: np.ndarray
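Note: HadamardRotationParams defines __eq__ by hand because np.ndarray equality is elementwise, which would break a naive field-by-field comparison. A small usage sketch, assuming this diff is applied to an installed ai_edge_quantizer:

import numpy as np
from ai_edge_quantizer import qtyping

a = qtyping.UniformQuantParams.HadamardRotationParams(
    random_binary_vector=np.ones(4, dtype=np.int8), hadamard_size=4
)
b = qtyping.UniformQuantParams.HadamardRotationParams(
    random_binary_vector=np.ones(4, dtype=np.int8), hadamard_size=4
)
c = qtyping.UniformQuantParams.HadamardRotationParams(
    random_binary_vector=np.ones(4, dtype=np.int8), hadamard_size=8
)
assert a == b  # arrays compared via np.array_equal, not elementwise ==
assert a != c  # differing hadamard_size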
@@ -137,6 +167,7 @@ class UniformQuantParams:
   symmetric: bool = True
   quantized_data: Optional[np.ndarray] = None
   block_size: int = 0
+  hadamard: Optional[HadamardRotationParams] = None
 
   @classmethod
   def from_tfl_tensor_details(cls, tensor_detail) -> 'UniformQuantParams':
@@ -180,6 +211,7 @@ class UniformQuantParams:
         and self.symmetric == other.symmetric
         and _compare_array_or_none(self.quantized_data, other.quantized_data)
         and self.block_size == other.block_size
+        and self.hadamard == other.hadamard
     )
 
 
@@ -492,6 +524,7 @@ class IOOperator:
   outputs: list[int]
   op_key: TFLOperationName
 
+
 # The function signature for `get_tensor_quant_params_fn`.
 GetTensorQuantParamsFuncSignature = Callable[
     [
@@ -25,6 +25,7 @@ from ai_edge_quantizer.transformations import dequant_insert
 from ai_edge_quantizer.transformations import duplicate_buffer
 from ai_edge_quantizer.transformations import duplicate_tensor
 from ai_edge_quantizer.transformations import emulated_subchannel
+from ai_edge_quantizer.transformations import insert_hadamard_rotation
 from ai_edge_quantizer.transformations import quant_insert
 from ai_edge_quantizer.transformations import quantize_tensor
 from ai_edge_quantizer.transformations import transformation_utils
@@ -80,6 +81,9 @@ class TransformationPerformer:
         qtyping.QuantTransformation.DUPLICATE_TENSOR: (
             duplicate_tensor.duplicate_tensor
         ),
+        qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION: (
+            insert_hadamard_rotation.insert_hadamard_rotation
+        ),
     }
     # transformations are separated in two categories:
     # op_insertion_transformations are transformations that only insert ops
@@ -91,6 +95,7 @@ class TransformationPerformer:
         qtyping.QuantTransformation.ADD_QUANTIZE,
         qtyping.QuantTransformation.DUPLICATE_BUFFER,
         qtyping.QuantTransformation.DUPLICATE_TENSOR,
+        qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
     ])
     self._op_replacement_transformations = set(
         [qtyping.QuantTransformation.EMULATED_SUBCHANNEL]
@@ -0,0 +1,209 @@
+# Copyright 2024 The AI Edge Quantizer Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Hadamard rotation pattern transformation."""
+
+from flatbuffers import flexbuffers
+import numpy as np
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.transformations import transformation_utils
+from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
+
+
+def _to_flexbuffer(
+    hadamard_size: int,
+    random_binary_vector: list[np.int8],
+) -> bytes:
+  """Serializes hadamard_size and random_binary_vector to a flexbuffer."""
+  fbb = flexbuffers.Builder()
+  with fbb.Map():
+    fbb.Int('hadamard_size', hadamard_size)
+    fbb.VectorFromElements('random_binary_vector', random_binary_vector)
+  return fbb.Finish()
+
+
+def _is_producer_embedding_lookup(
+    transformation: transformation_utils.TransformationInput,
+) -> bool:
+  """Checks if the tensor's producer is an embedding lookup op."""
+  if transformation.producer == -1:
+    return False
+  else:
+    return (
+        transformation.op_codes[
+            transformation.subgraph.operators[
+                transformation.producer
+            ].opcodeIndex
+        ].builtinCode
+        == schema_py_generated.BuiltinOperator.EMBEDDING_LOOKUP
+    )
+
+
+def _is_fully_connected(
+    transformation: transformation_utils.TransformationInput, op_id: int
+) -> bool:
+  """Checks if the op at the given op_id is a fully connected op."""
+  return (
+      transformation.op_codes[
+          transformation.subgraph.operators[op_id].opcodeIndex
+      ].builtinCode
+      == schema_py_generated.BuiltinOperator.FULLY_CONNECTED
+  )
+
+
+def _update_embedding_lookup_consumers(
+    transformation: transformation_utils.TransformationInput,
+    new_tensor_id: int,
+) -> None:
+  """Updates the consumers of the embedding lookup op to use the new tensor.
+
+  Args:
+    transformation: The transformation input to update the consumers of.
+    new_tensor_id: The new tensor id to use as the input to the embedding
+      lookup consumers.
+  """
+  for consumer in transformation.consumers:
+    # If the consumer is a graph output and not an op, we can ignore it here
+    # since the graph output will be updated later.
+    if consumer == -1:
+      continue
+    consumer_op = transformation.subgraph.operators[consumer]
+    # Find the input that was attached to the insertion point, and replace it
+    # with the new tensor.
+    for i in range(len(consumer_op.inputs)):
+      if consumer_op.inputs[i] == transformation.tensor_id:
+        consumer_op.inputs[i] = new_tensor_id
+
+
+def _update_fully_connected_consumers(
+    transformation: transformation_utils.TransformationInput,
+    new_tensor_id: int,
+) -> bool:
+  """Updates the fully connected op(s) to use the new tensor.
+
+  Since the new tensor is inserted to the fully_connected's input, we need to
+  scan each consumer (in case of multiple fully_connected ops), and update
+  the input tensor to the new tensor.
+
+  Args:
+    transformation: The transformation input to update the consumers of.
+    new_tensor_id: The new tensor id to use as the input to the fully
+      connected consumers.
+
+  Returns:
+    True if the fully connected op(s) were updated to use the new tensor.
+  """
+  updated = False
+  for consumer in transformation.consumers:
+    if _is_fully_connected(transformation, consumer):
+      transformation.subgraph.operators[consumer].inputs[0] = new_tensor_id
+      updated = True
+  return updated
+
+
+def insert_hadamard_rotation(
+    transformation_input: transformation_utils.TransformationInput,
+) -> qtyping.TransformationInfo:
+  """Inserts a custom aeq.hadamard_rotation op on this tensor.
+
+  This function works for float32 tensors only.
+
+  Args:
+    transformation_input: The transformation input to insert the custom op on.
+
+  Returns:
+    The transformation info of the inserted custom op.
+
+  Raises:
+    ValueError: If the transformation input is not a uniform quantization
+      transformation.
+    ValueError: If the Hadamard quantization params are not set.
+    ValueError: If the tensor is not a float32 tensor.
+    ValueError: If no supported ops were found as the tensor's producer or
+      consumers.
+  """
+  if not isinstance(
+      transformation_input.quant_params, qtyping.UniformQuantParams
+  ):
+    raise ValueError('Hadamard rotation supports uniform quantization only')
+
+  if transformation_input.quant_params.hadamard is None:
+    raise ValueError(
+        'Hadamard rotation quantization params are not set but op insertion is'
+        ' requested.'
+    )
+
+  tensor = transformation_input.subgraph.tensors[transformation_input.tensor_id]
+  if tensor.type != schema_py_generated.TensorType.FLOAT32:
+    raise ValueError(
+        'The Hadamard rotation op supports float32 tensors only. Got'
+        f' {tensor.type} tensor.'
+    )
+
+  # Create a new custom op with the current tensor as input and a new
+  # activation tensor as output.
+  custom_op_code_idx = transformation_utils.add_op_code(
+      schema_py_generated.BuiltinOperator.CUSTOM,
+      transformation_input.op_codes,
+      'aeq.hadamard_rotation',
+  )
+  custom_op = schema_py_generated.OperatorT()
+  custom_op.opcodeIndex = custom_op_code_idx
+  custom_op.inputs = [transformation_input.tensor_id]
+  custom_op.customOptions = _to_flexbuffer(
+      transformation_input.quant_params.hadamard.hadamard_size,
+      transformation_input.quant_params.hadamard.random_binary_vector.tolist(),
+  )
+  new_tensor_id = transformation_utils.add_new_activation_tensor(
+      tensor.name + b'_rotated',
+      tensor.shapeSignature
+      if tensor.shapeSignature is not None
+      else tensor.shape,
+      schema_py_generated.TensorType.FLOAT32,
+      transformation_input.subgraph,
+  )
+  custom_op.outputs = [new_tensor_id]
+
+  # Update the users of this tensor to use the new tensor.
+  if _is_producer_embedding_lookup(transformation_input):
+    _update_embedding_lookup_consumers(transformation_input, new_tensor_id)
+  elif not _update_fully_connected_consumers(
+      transformation_input, new_tensor_id
+  ):
+    raise ValueError(
+        'The Hadamard rotation op supports embedding lookup and fully'
+        ' connected ops only, but no such ops were found.'
+    )
+
+  # If the tensor is a graph output, we need to replace the tensor with the
+  # new tensor.
+  for i, output in enumerate(transformation_input.subgraph.outputs):
+    if output == transformation_input.tensor_id:
+      transformation_input.subgraph.outputs[i] = new_tensor_id
+
+  # Find the actual insertion point. The insertion point should be after the
+  # producer op and before the first consumer op. The max() operation ensures
+  # that we're not using -1 as the insertion point.
+  first_consumer_id = min(transformation_input.consumers)
+  op_id = max(transformation_input.producer + 1, first_consumer_id)
+
+  # Insert the custom op.
+  transformation_input.subgraph.operators.insert(op_id, custom_op)
+
+  return qtyping.TransformationInfo(
+      op_id=op_id,
+      num_ops_added=1,
+      output_tensor_id=new_tensor_id,
+  )
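For reference, the custom-op options written by _to_flexbuffer above can be decoded with the stock flexbuffers reader that ships in the same flatbuffers package. flexbuffers.Loads is assumed available here (present in recent flatbuffers releases; GetRoot offers the same data otherwise). A round-trip sketch:

from flatbuffers import flexbuffers

fbb = flexbuffers.Builder()
with fbb.Map():
  # Same keys as _to_flexbuffer writes into customOptions.
  fbb.Int('hadamard_size', 32)
  fbb.VectorFromElements('random_binary_vector', [1, 1, 1, 1])
buf = fbb.Finish()

decoded = flexbuffers.Loads(buf)
assert decoded['hadamard_size'] == 32
assert list(decoded['random_binary_vector']) == [1, 1, 1, 1]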
@@ -0,0 +1,200 @@
+# Copyright 2024 The AI Edge Quantizer Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Test insertion of the Hadamard rotation custom op."""
+
+import os
+import numpy as np
+from tensorflow.python.platform import googletest
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.transformations import insert_hadamard_rotation
+from ai_edge_quantizer.transformations import transformation_utils
+from ai_edge_quantizer.utils import test_utils
+from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
+
+_TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('..')
+
+
+class InsertHadamardRotationFullyConnectedTest(googletest.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    model_path = os.path.join(
+        _TEST_DATA_PREFIX_PATH, 'tests/models/single_fc_bias.tflite'
+    )
+    self.model = tfl_flatbuffer_utils.read_model(model_path)
+    self.params = qtyping.UniformQuantParams(
+        num_bits=8,
+        quantized_dimension=None,
+        scale=np.ones(1),
+        zero_point=np.zeros(1),
+        hadamard=qtyping.UniformQuantParams.HadamardRotationParams(
+            random_binary_vector=np.ones(1),
+            hadamard_size=2,
+        ),
+    )
+
+  def test_raise_unsupported_qparams(self):
+    with self.assertRaisesWithPredicateMatch(
+        ValueError, lambda err: 'uniform quantization' in str(err)
+    ):
+      insert_hadamard_rotation.insert_hadamard_rotation(
+          transformation_utils.TransformationInput(
+              tensor_id=0,
+              op_codes=self.model.operatorCodes,
+              buffers=self.model.buffers,
+              subgraph=self.model.subgraphs[0],
+              producer=-1,
+              consumers=[-1],
+              quant_params=qtyping.NonLinearQuantParams(
+                  num_bits=16, quantized_data=None
+              ),
+          )
+      )
+
+  def test_raise_missing_hadamard_data(self):
+    with self.assertRaisesWithPredicateMatch(
+        ValueError, lambda err: 'quantization params are not set' in str(err)
+    ):
+      insert_hadamard_rotation.insert_hadamard_rotation(
+          transformation_utils.TransformationInput(
+              tensor_id=0,
+              op_codes=self.model.operatorCodes,
+              buffers=self.model.buffers,
+              subgraph=self.model.subgraphs[0],
+              producer=-1,
+              consumers=[-1],
+              quant_params=qtyping.UniformQuantParams(
+                  num_bits=8,
+                  quantized_dimension=None,
+                  scale=np.ones(1),
+                  zero_point=np.zeros(1),
+              ),
+          )
+      )
+
+  def test_raise_non_float32_tensor(self):
+    self.model.subgraphs[0].tensors[
+        0
+    ].type = schema_py_generated.TensorType.INT32
+    with self.assertRaisesWithPredicateMatch(
+        ValueError, lambda err: 'float32 tensors' in str(err)
+    ):
+      insert_hadamard_rotation.insert_hadamard_rotation(
+          transformation_utils.TransformationInput(
+              tensor_id=0,
+              op_codes=self.model.operatorCodes,
+              buffers=self.model.buffers,
+              subgraph=self.model.subgraphs[0],
+              producer=-1,
+              consumers=[-1],
+              quant_params=self.params,
+          ),
+      )
+
+  def test_insert_single_custom_op(self):
+    # Insert aeq.hadamard_rotation before fully_connected.
+    info = insert_hadamard_rotation.insert_hadamard_rotation(
+        transformation_utils.TransformationInput(
+            tensor_id=0,
+            op_codes=self.model.operatorCodes,
+            buffers=self.model.buffers,
+            subgraph=self.model.subgraphs[0],
+            producer=-1,
+            consumers=[-1],
+            quant_params=self.params,
+        )
+    )
+    subgraph = self.model.subgraphs[0]
+    self.assertEqual(info.op_id, 0)
+    self.assertEqual(info.num_ops_added, 1)
+    # Model had 4 tensors, added 1.
+    self.assertEqual(info.output_tensor_id, 4)
+    self.assertLen(subgraph.tensors, 5)
+    # Model had 1 op, added a new one.
+    self.assertLen(self.model.operatorCodes, 2)
+    self.assertEqual(
+        self.model.operatorCodes[1].builtinCode,
+        schema_py_generated.BuiltinOperator.CUSTOM,
+    )
+    # First op is now the custom op, precedes fully_connected.
+    self.assertEqual(
+        self.model.operatorCodes[subgraph.operators[0].opcodeIndex].builtinCode,
+        schema_py_generated.BuiltinOperator.CUSTOM,
+    )
+    # Input to the custom op is the graph input.
+    self.assertEqual(subgraph.operators[0].inputs[0], 0)
+    # Input to the FC is the custom op output.
+    self.assertEqual(subgraph.operators[1].inputs[0], 4)
+
+
+class InsertHadamardRotationEmbeddingLookupTest(googletest.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    model_path = os.path.join(
+        _TEST_DATA_PREFIX_PATH, 'tests/models/embedding_lookup.tflite'
+    )
+    self.model = tfl_flatbuffer_utils.read_model(model_path)
+    self.params = qtyping.UniformQuantParams(
+        num_bits=8,
+        quantized_dimension=None,
+        scale=np.ones(1),
+        zero_point=np.zeros(1),
+        hadamard=qtyping.UniformQuantParams.HadamardRotationParams(
+            random_binary_vector=np.ones(1),
+            hadamard_size=2,
+        ),
+    )
+
+  def test_insert_single_custom_op(self):
+    # Insert aeq.hadamard_rotation after embedding_lookup.
+    info = insert_hadamard_rotation.insert_hadamard_rotation(
+        transformation_utils.TransformationInput(
+            tensor_id=2,
+            op_codes=self.model.operatorCodes,
+            buffers=self.model.buffers,
+            subgraph=self.model.subgraphs[0],
+            producer=0,
+            consumers=[-1],
+            quant_params=self.params,
+        )
+    )
+    subgraph = self.model.subgraphs[0]
+    self.assertEqual(info.op_id, 1)
+    self.assertEqual(info.num_ops_added, 1)
+    # Model had 3 tensors, added 1.
+    self.assertEqual(info.output_tensor_id, 3)
+    self.assertLen(subgraph.tensors, 4)
+    # Model had 1 op, added a new one.
+    self.assertLen(self.model.operatorCodes, 2)
+    self.assertEqual(
+        self.model.operatorCodes[1].builtinCode,
+        schema_py_generated.BuiltinOperator.CUSTOM,
+    )
+    # Second op is now the custom op, after embedding_lookup.
+    self.assertEqual(
+        self.model.operatorCodes[subgraph.operators[1].opcodeIndex].builtinCode,
+        schema_py_generated.BuiltinOperator.CUSTOM,
+    )
+    # Input to the custom op is the embedding's output.
+    self.assertEqual(subgraph.operators[1].inputs[0], 2)
+    # The custom op's output is the new tensor.
+    self.assertEqual(subgraph.operators[1].outputs[0], 3)
+
+
+if __name__ == '__main__':
+  googletest.main()
@@ -33,6 +33,39 @@ _OpQuantConfig = qtyping.OpQuantizationConfig
 _AlgorithmName = quantizer.AlgorithmName
 
 
+DEFAULT_ACTIVATION_QUANT_SETTING = _TensorQuantConfig(
+    num_bits=8,
+    symmetric=False,
+    granularity=qtyping.QuantGranularity.TENSORWISE,
+)
+DEFAULT_WEIGHT_QUANT_SETTING = _TensorQuantConfig(
+    num_bits=8,
+    symmetric=True,
+    granularity=qtyping.QuantGranularity.CHANNELWISE,
+)
+
+
+def get_static_activation_quant_setting(
+    num_bits: int, symmetric: bool
+) -> _TensorQuantConfig:
+  return _TensorQuantConfig(
+      num_bits=num_bits,
+      symmetric=symmetric,
+      granularity=qtyping.QuantGranularity.TENSORWISE,
+  )
+
+
+def get_static_op_quant_config(
+    activation_config: _TensorQuantConfig = DEFAULT_ACTIVATION_QUANT_SETTING,
+    weight_config: _TensorQuantConfig = DEFAULT_WEIGHT_QUANT_SETTING,
+) -> _OpQuantConfig:
+  return qtyping.OpQuantizationConfig(
+      activation_tensor_config=activation_config,
+      weight_tensor_config=weight_config,
+      compute_precision=_ComputePrecision.INTEGER,
+  )
+
+
 def get_path_to_datafile(path):
   """Get the path to the specified file in the data dependencies.
 
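A short usage sketch of the new test_utils helpers, assuming this diff is applied: the defaults combine int8 asymmetric tensorwise activations with int8 symmetric channelwise weights into a static (integer compute) op config.

from ai_edge_quantizer.utils import test_utils

op_config = test_utils.get_static_op_quant_config()
assert op_config.activation_tensor_config.num_bits == 8
assert not op_config.activation_tensor_config.symmetric
assert op_config.weight_tensor_config.symmetric

# Override the activation side, e.g. for int16 symmetric activations.
int16_act = test_utils.get_static_activation_quant_setting(16, True)
custom = test_utils.get_static_op_quant_config(activation_config=int16_act)
assert custom.activation_tensor_config.num_bits == 16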
@@ -56,6 +56,7 @@ TFL_OP_NAME_TO_CODE = immutabledict.immutabledict({
     _TFLOpName.DYNAMIC_UPDATE_SLICE: (
         schema.BuiltinOperator.DYNAMIC_UPDATE_SLICE
     ),
+    _TFLOpName.PAD: schema.BuiltinOperator.PAD,
 })
 
 TFL_OP_CODE_TO_NAME = immutabledict.immutabledict(
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-quantizer-nightly
-Version: 0.1.0.dev20250511
+Version: 0.1.0.dev20250513
 Summary: A quantizer for advanced developers to quantize converted AI Edge models.
 Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
 Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
@@ -1,18 +1,18 @@
 ai_edge_quantizer/__init__.py,sha256=4pFSkukSwahYyzwqia0yPRyz8TnFQfGRthVJhYpMWas,793
-ai_edge_quantizer/algorithm_manager.py,sha256=0uootLsVD6h9ph9TrnXZMI-ExkX8UvXSV0lbWxBLybU,10492
+ai_edge_quantizer/algorithm_manager.py,sha256=p-wX2ksIV1hbWEQz-uUnbNMVgDJrsIiIOU2ZYX2ZrTM,11726
 ai_edge_quantizer/algorithm_manager_api.py,sha256=u903TG0s1uIDhJqfeJne3CFl8A93phZrwgV2-hwdcXU,9247
 ai_edge_quantizer/algorithm_manager_api_test.py,sha256=w6bSONvXkX6bzXAGc0-7b6gNDt9oz9ieq97KP8Sg_JU,7666
 ai_edge_quantizer/calibrator.py,sha256=n7AD9j7UScR-CieoI6DQRMeiG_fhLBfSLRiM4460xaM,11895
 ai_edge_quantizer/calibrator_test.py,sha256=C_oWOaRugPKYX74jF-eRFH-k6nGOdA8I9_uPiocaOuE,11900
 ai_edge_quantizer/conftest.py,sha256=SxCz-5LlRD_lQm4hQc4c6IGG7DS8d7IyEWY9gnscPN0,794
-ai_edge_quantizer/default_policy.py,sha256=81z4cruBK7mGFt8xFRZK5LKya65axuZwo2zpbcYSicc,11099
+ai_edge_quantizer/default_policy.py,sha256=zNTeiI_eP5-dLL3P_VWIQB3RzXBrb06peJKngLnSSFY,11125
 ai_edge_quantizer/model_modifier.py,sha256=teGa8I6kGvn6TQY6Xv53YFIc_pQEhNvM9Zb4bvhezyw,7110
 ai_edge_quantizer/model_modifier_test.py,sha256=cJd04SLOG-fQZZNZPcisoBLx3cLtWEwGqUBbLb-pif4,4751
 ai_edge_quantizer/model_validator.py,sha256=Hj0_5o-Oa3dSlJ3ryVjRhvsyelHNyek1GrtG9buMczg,13153
 ai_edge_quantizer/model_validator_test.py,sha256=EeqOP_mrZsnZ3rug756s0ryDDqd2KgIDld5Lm_gDuWY,13020
-ai_edge_quantizer/params_generator.py,sha256=NEZeHVVIeynmhRzPjl9o-acvWfauFgCS4i45pWFw3V8,20052
+ai_edge_quantizer/params_generator.py,sha256=j1BV2cGFLlQmUY6aoW5uglYqf77b9ytN8oZ1gh6o0mM,20096
 ai_edge_quantizer/params_generator_test.py,sha256=RDYoRZDJfEZRtjlTAU2kZ_4t3JHOqEHxfJX9V4ETAhg,40597
-ai_edge_quantizer/qtyping.py,sha256=FqelZu7j0fGBRSCv_VVsuf3VmbfVlYJGgsjvdMXGgaw,15284
+ai_edge_quantizer/qtyping.py,sha256=LKn9w53wmw3gPO0E4DKOhj8gkx9efjXMoipGnsJyGiU,16453
 ai_edge_quantizer/quantizer.py,sha256=g3DMqFMrMpt9jQttCE0WcdNbMtk0JZnmN5MmCHrNdyM,13202
 ai_edge_quantizer/quantizer_test.py,sha256=K_HBA56JkFI3HL8VLWCqGEfC0ISh5ldMKoNyBdGRAJg,20368
 ai_edge_quantizer/recipe.py,sha256=FR0uJceumZrnle2VRSOQZ1uXup4S1cTYKRH-N53mWRo,2919
@@ -21,17 +21,19 @@ ai_edge_quantizer/recipe_manager_test.py,sha256=LulVxsYp6TBGFI2PLCUCd4VsFq8ELpC7
 ai_edge_quantizer/recipe_test.py,sha256=Fg_sfxovI2fRjk5qdu18ghOvXdUvhDR1TxbE0GHDczc,3381
 ai_edge_quantizer/transformation_instruction_generator.py,sha256=R7A90Qj6iQQROrznXmXLJd-5yXq0PRHbLOdNY51dEu4,27913
 ai_edge_quantizer/transformation_instruction_generator_test.py,sha256=E0QSDCav6N6izlJ-a1ZJOsb2VEUxuxBmTbt0-EgDdxY,49890
-ai_edge_quantizer/transformation_performer.py,sha256=zAzrQOb2n2IpB3qopmKV59e5E99HmTOL60QTCn9-7kA,12821
+ai_edge_quantizer/transformation_performer.py,sha256=nkkqbs81ITB5u2FoWeG9z5d8EtLtCiltOxcQ34okN8E,13091
 ai_edge_quantizer/transformation_performer_test.py,sha256=xk6A3LStCyPclN51--9uO7XjSxNfZmpdfvrzOL0maNM,20349
 ai_edge_quantizer/algorithms/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
 ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
 ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py,sha256=Bs9CK7wZAw6jNaZ8xEtbwO2vM34VYXNZSMVWvxJo9nw,9297
 ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py,sha256=EqIHGEZ1LgUrTN7zf880RuAzEv3Qy7kgh5ivObJGHSo,22646
 ai_edge_quantizer/algorithms/uniform_quantize/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
-ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=JyMuqVns3meKOA1DomXYopxqhZE65yamlVZjBF6yOmY,27731
+ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=-ugXQ4cZoVMrgOVs4m73ozI-49CRyT0YuKrLS5begW8,28297
 ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py,sha256=qMmKbWqxrCoVKbLKHn9WuCrGKPfHkEyU0Nmhokh8Qeo,2597
 ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py,sha256=Fk3s9Qy2A_hjUepFOUmTwIZ_wKYVPbdDX4eoP-eoAQU,8726
 ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py,sha256=sT5eX5TLZEHTtPfnSkCPDlS0sQxlTFWbCsbvOuj--yY,8889
+ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py,sha256=f9HhFCAavbrdYkQQH37ivbKRuRXC1g1TO2FmILMApN8,12389
+ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py,sha256=kN9aCPt1yTleiDBiH4g2RZ1vMBm7WAf5pmVFjmYCH-0,7617
 ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py,sha256=Divlsn3NjNGtH0vlvE91wxL-VHb4q1nUE0JTDGiEtYc,8572
 ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py,sha256=zoF_EHjYqsKkuev8wfuutIITEmp_maa70IpJI_Df3ck,7431
 ai_edge_quantizer/algorithms/uniform_quantize/octav.py,sha256=e5wYtki-vl739gSVAZHAKcs2hA87GvFUjVoSUPlnkyM,6433
@@ -50,6 +52,8 @@ ai_edge_quantizer/transformations/duplicate_tensor.py,sha256=HF1uuKFm5kFF6X0XUpd
 ai_edge_quantizer/transformations/duplicate_tensor_test.py,sha256=s-RqSxNBMfVJyCunXz2eb7-KA6UiBmbOmL7phLslENQ,5056
 ai_edge_quantizer/transformations/emulated_subchannel.py,sha256=HVaRxoC8PCAvy3xeMv3OIymukUy_yW1zK0xN8Ann6I4,13602
 ai_edge_quantizer/transformations/emulated_subchannel_test.py,sha256=gZP6u9NdPXl7s19qB_Un8evou9ZZV6I9Gy0E1rdobHM,7722
+ai_edge_quantizer/transformations/insert_hadamard_rotation.py,sha256=rBbKgcVKHie38NT2UQ7KQ1xCb2tRu_rVl0yFloOAW_A,7562
+ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py,sha256=iV1p3nZfHUATV2YRoBOYurnu3pLy8n3aFppLWGQOPdA,7268
 ai_edge_quantizer/transformations/quant_insert.py,sha256=jn6HsJaV-sqBiFPY-Aqbd64t8zgcYVkEkZI375x_FWY,3958
 ai_edge_quantizer/transformations/quant_insert_test.py,sha256=X9ptPDvJCFkR5tejKnD1SlHFGPazQTW-wNNMV9MEAuw,10107
 ai_edge_quantizer/transformations/quantize_tensor.py,sha256=kjaNrw9mnrn0t8u0vey9S_uPz3iVUicwy4rluxVqV3E,7617
@@ -59,15 +63,15 @@ ai_edge_quantizer/transformations/transformation_utils_test.py,sha256=E90O4PYSjz
 ai_edge_quantizer/utils/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
 ai_edge_quantizer/utils/calibration_utils.py,sha256=1Fj9MIO6aLZIRgyd4axvZN4S_O64nB_-Miu1WP664js,2536
 ai_edge_quantizer/utils/calibration_utils_test.py,sha256=Z-AcdTieesWFKyKBb08ZXm4Mgu6cvJ4bg2-MJ7hLD10,2856
-ai_edge_quantizer/utils/test_utils.py,sha256=HwZCIpO9fJRAhuN6t6voXKOYQtcioFtt_tpkAlDsAYk,6205
-ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=NKtw60BJAjIE6Yww8B1vJpxXwp4MSERmpKajXJWm5rI,10568
+ai_edge_quantizer/utils/test_utils.py,sha256=fXwQ353P7tSy7W4Hs6YskIbCLLaBYGA724hMMbcqCUk,7129
+ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=zNlR_SJAkDi-EX63O3pNpFLVqSktysScZKgKk1XT3c8,10616
 ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py,sha256=K1SbK8q92qYVtiVj0I0GtugsPTkpIpEKv9zakvFV_Sc,8555
 ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=WoewyiZpaua80oP0tpgyrw5Ws1v7f4vl88vdzS0UjDE,13490
 ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=6fjkM-rycZ95L4yfvlr0TN6RlrhfPzxNUYrZaYO_F0A,12013
 ai_edge_quantizer/utils/validation_utils.py,sha256=oYw33Sg547AqtGw-choPUJmp9SAKkV46J_ddqSsum2Q,3950
 ai_edge_quantizer/utils/validation_utils_test.py,sha256=V_qNDikPD4OPB-siOLQCWNVWTAu87h2IgNYt7teFd-o,2934
-ai_edge_quantizer_nightly-0.1.0.dev20250511.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-ai_edge_quantizer_nightly-0.1.0.dev20250511.dist-info/METADATA,sha256=ENLYvQfun3PLDrzvKxWBsjTrbyWjM7KRBETP2EKs8Kk,1528
-ai_edge_quantizer_nightly-0.1.0.dev20250511.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_quantizer_nightly-0.1.0.dev20250511.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
-ai_edge_quantizer_nightly-0.1.0.dev20250511.dist-info/RECORD,,
+ai_edge_quantizer_nightly-0.1.0.dev20250513.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ai_edge_quantizer_nightly-0.1.0.dev20250513.dist-info/METADATA,sha256=zL_JxmjzCHEwIUmLkDGzI6B7IACt6YnVQSpaxaNUujY,1528
+ai_edge_quantizer_nightly-0.1.0.dev20250513.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_quantizer_nightly-0.1.0.dev20250513.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
+ai_edge_quantizer_nightly-0.1.0.dev20250513.dist-info/RECORD,,