ai-edge-quantizer-nightly 0.4.0.dev20250930__py3-none-any.whl → 0.4.0.dev20251001__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -61,7 +61,8 @@ class AlgorithmName(str, enum.Enum):
61
61
  FLOAT_CASTING = float_casting.ALGORITHM_KEY
62
62
  DEQUANTIZED_WEIGHT_RECOVERY = dequantized_weight_recovery.ALGORITHM_KEY
63
63
  OCTAV = octav.ALGORITHM_KEY
64
- HADAMARD_ROTATION = hadamard_rotation.ALGORITHM_KEY
64
+ HADAMARD_ROTATION = hadamard_rotation.CUSTOM_OP_ALGORITHM_KEY
65
+ DECOMPOSED_HADAMARD_ROTATION = hadamard_rotation.DECOMPOSED_ALGORITHM_KEY
65
66
  MSE = mse.ALGORITHM_KEY
66
67
 
67
68
 
@@ -311,8 +312,12 @@ register_config_check_policy_func(
311
312
 
312
313
  # Register specialized hadamard rotation materialize functions.
313
314
  _HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
314
- _TFLOpName.FULLY_CONNECTED: hadamard_rotation.materialize_fully_connected,
315
- _TFLOpName.EMBEDDING_LOOKUP: hadamard_rotation.materialize_embedding_lookup,
315
+ _TFLOpName.FULLY_CONNECTED: (
316
+ hadamard_rotation.materialize_fully_connected_custom_op
317
+ ),
318
+ _TFLOpName.EMBEDDING_LOOKUP: (
319
+ hadamard_rotation.materialize_embedding_lookup_custom_op
320
+ ),
316
321
  })
317
322
  for (
318
323
  op_name,
@@ -326,6 +331,36 @@ for (
326
331
  materialize_func=materialize_func,
327
332
  )
328
333
 
334
+ register_op_quant_config_validation_func(
335
+ AlgorithmName.DECOMPOSED_HADAMARD_ROTATION,
336
+ common_quantize.check_op_quantization_config,
337
+ )
338
+
339
+ register_config_check_policy_func(
340
+ AlgorithmName.DECOMPOSED_HADAMARD_ROTATION,
341
+ default_policy.DEFAULT_CONFIG_CHECK_POLICY,
342
+ )
343
+
344
+ _DECOMPOSED_HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
345
+ _TFLOpName.FULLY_CONNECTED: (
346
+ hadamard_rotation.materialize_fully_connected_decomposed
347
+ ),
348
+ _TFLOpName.EMBEDDING_LOOKUP: (
349
+ hadamard_rotation.materialize_embedding_lookup_decomposed
350
+ ),
351
+ })
352
+ for (
353
+ op_name,
354
+ materialize_func,
355
+ ) in _DECOMPOSED_HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT.items():
356
+ register_quantized_op(
357
+ AlgorithmName.DECOMPOSED_HADAMARD_ROTATION,
358
+ op_name,
359
+ naive_min_max_quantize.init_qsvs,
360
+ calibration_func=naive_min_max_quantize.min_max_calibrate,
361
+ materialize_func=materialize_func,
362
+ )
363
+
329
364
 
330
365
  # Register the MSE algorithm.
331
366
  register_op_quant_config_validation_func(
@@ -23,16 +23,17 @@ from ai_edge_quantizer.algorithms.utils import common_utils
23
23
  from ai_edge_quantizer.utils import tfl_flatbuffer_utils
24
24
 
25
25
 
26
- ALGORITHM_KEY = "HADAMARD_ROTATION"
26
+ CUSTOM_OP_ALGORITHM_KEY = "HADAMARD_ROTATION"
27
+ DECOMPOSED_ALGORITHM_KEY = "DECOMPOSED_HADAMARD_ROTATION"
27
28
 
28
29
 
29
30
  def _make_hadamard_matrix(size: int) -> np.ndarray:
30
31
  """Generates a Hadamard matrix of the given size.
31
32
 
32
33
  Args:
33
- size: The size of the Hadamard matrix. Must be a power of 2. This
34
- represents a single dimension. E.g. if size is 4, then the Hadamard matrix
35
- is a 4x4 matrix.
34
+ size: The size of the Hadamard matrix. Must be a power of 2. This represents
35
+ a single dimension. E.g. if size is 4, then the Hadamard matrix is a 4x4
36
+ matrix.
36
37
 
37
38
  Returns:
38
39
  The Hadamard matrix.
@@ -157,9 +158,10 @@ def get_tensor_quant_params(
157
158
  )
158
159
 
159
160
 
160
- def materialize_fully_connected(
161
+ def _materialize_fully_connected(
161
162
  op_info: qtyping.OpInfo,
162
163
  graph_info: qtyping.GraphInfo,
164
+ is_decomposed: bool = False,
163
165
  tensor_name_to_qsv: Optional[dict[str, Any]] = None, # pylint: disable=unused-argument
164
166
  ) -> list[qtyping.TensorTransformationParams]:
165
167
  """Materialize the fully_connected op.
@@ -167,12 +169,20 @@ def materialize_fully_connected(
167
169
  Args:
168
170
  op_info: Aggregated information about the op (e.g., quantization config).
169
171
  graph_info: Graph information needed to perform quantization for the op.
172
+ is_decomposed: Whether to use decomposed Hadamard rotation ops or a custom
173
+ op.
170
174
  tensor_name_to_qsv: A map of tensor name to quantization parameters.
171
175
 
172
176
  Returns:
173
177
  Quantization configuration for the tensors associated with the op (e.g.,
174
178
  weights, bias).
175
179
  """
180
+ if op_info.op_quant_config.weight_tensor_config is None:
181
+ raise ValueError(
182
+ "Weight tensor quantization config is not provided for Hadamard"
183
+ " Rotation quantization."
184
+ )
185
+
176
186
  op_tensor_params = []
177
187
 
178
188
  # Materialize weight.
@@ -209,7 +219,9 @@ def materialize_fully_connected(
209
219
  op_info.op.inputs[input_tensor_index]
210
220
  ]
211
221
  transformations = [
212
- qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
222
+ qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION
223
+ if is_decomposed
224
+ else qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
213
225
  ]
214
226
  op2tensor_params = qtyping.OpToTensorParams(
215
227
  subgraph_op_id=op_info.subgraph_op_index,
@@ -258,16 +270,45 @@ def materialize_fully_connected(
258
270
  return op_tensor_params
259
271
 
260
272
 
261
- def materialize_embedding_lookup(
273
+ def materialize_fully_connected_custom_op(
274
+ op_info: qtyping.OpInfo,
275
+ graph_info: qtyping.GraphInfo,
276
+ tensor_name_to_qsv: Optional[dict[str, Any]] = None, # pylint: disable=unused-argument
277
+ ) -> list[qtyping.TensorTransformationParams]:
278
+ return _materialize_fully_connected(
279
+ op_info,
280
+ graph_info,
281
+ is_decomposed=False,
282
+ tensor_name_to_qsv=tensor_name_to_qsv,
283
+ )
284
+
285
+
286
+ def materialize_fully_connected_decomposed(
262
287
  op_info: qtyping.OpInfo,
263
288
  graph_info: qtyping.GraphInfo,
264
289
  tensor_name_to_qsv: Optional[dict[str, Any]] = None, # pylint: disable=unused-argument
290
+ ) -> list[qtyping.TensorTransformationParams]:
291
+ return _materialize_fully_connected(
292
+ op_info,
293
+ graph_info,
294
+ is_decomposed=True,
295
+ tensor_name_to_qsv=tensor_name_to_qsv,
296
+ )
297
+
298
+
299
+ def _materialize_embedding_lookup(
300
+ op_info: qtyping.OpInfo,
301
+ graph_info: qtyping.GraphInfo,
302
+ is_decomposed: bool = False,
303
+ tensor_name_to_qsv: Optional[dict[str, Any]] = None, # pylint: disable=unused-argument
265
304
  ) -> list[qtyping.TensorTransformationParams]:
266
305
  """Materialize the embedding_lookup op.
267
306
 
268
307
  Args:
269
308
  op_info: Aggregated information about the op (e.g., quantization config).
270
309
  graph_info: Graph information needed to perform quantization for the op.
310
+ is_decomposed: Whether to use decomposed Hadamard rotation ops or a custom
311
+ op.
271
312
  tensor_name_to_qsv: A map of tensor name to quantization parameters.
272
313
 
273
314
  Returns:
@@ -329,7 +370,9 @@ def materialize_embedding_lookup(
329
370
  op_info.op.outputs[output_tensor_index]
330
371
  ]
331
372
  transformations = [
332
- qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
373
+ qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION
374
+ if is_decomposed
375
+ else qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
333
376
  ]
334
377
  op2tensor_params = qtyping.OpToTensorParams(
335
378
  subgraph_op_id=op_info.subgraph_op_index,
@@ -343,3 +386,29 @@ def materialize_embedding_lookup(
343
386
  op_tensor_params.append(output_transformation_params)
344
387
 
345
388
  return op_tensor_params
389
+
390
+
391
+ def materialize_embedding_lookup_custom_op(
392
+ op_info: qtyping.OpInfo,
393
+ graph_info: qtyping.GraphInfo,
394
+ tensor_name_to_qsv: Optional[dict[str, Any]] = None, # pylint: disable=unused-argument
395
+ ) -> list[qtyping.TensorTransformationParams]:
396
+ return _materialize_embedding_lookup(
397
+ op_info,
398
+ graph_info,
399
+ is_decomposed=False,
400
+ tensor_name_to_qsv=tensor_name_to_qsv,
401
+ )
402
+
403
+
404
+ def materialize_embedding_lookup_decomposed(
405
+ op_info: qtyping.OpInfo,
406
+ graph_info: qtyping.GraphInfo,
407
+ tensor_name_to_qsv: Optional[dict[str, Any]] = None, # pylint: disable=unused-argument
408
+ ) -> list[qtyping.TensorTransformationParams]:
409
+ return _materialize_embedding_lookup(
410
+ op_info,
411
+ graph_info,
412
+ is_decomposed=True,
413
+ tensor_name_to_qsv=tensor_name_to_qsv,
414
+ )
@@ -63,7 +63,7 @@ class HadamardRotationFullyConnectedTest(parameterized.TestCase):
63
63
  )
64
64
 
65
65
  def test_materialize_fully_connected_basic(self):
66
- params = hadamard_rotation.materialize_fully_connected(
66
+ params = hadamard_rotation.materialize_fully_connected_custom_op(
67
67
  self._op_info, self._graph_info, self._tensor_name_to_qsv
68
68
  )
69
69
  fc_input = params[0]
@@ -111,7 +111,7 @@ class HadamardRotationFullyConnectedTest(parameterized.TestCase):
111
111
  ),
112
112
  ),
113
113
  )
114
- params = hadamard_rotation.materialize_fully_connected(
114
+ params = hadamard_rotation.materialize_fully_connected_custom_op(
115
115
  self._op_info, self._graph_info, self._tensor_name_to_qsv
116
116
  )
117
117
  self.assertLen(params, 4)
@@ -152,7 +152,7 @@ class HadamardRotationFullyConnectedTest(parameterized.TestCase):
152
152
  ),
153
153
  ),
154
154
  )
155
- params = hadamard_rotation.materialize_fully_connected(
155
+ params = hadamard_rotation.materialize_fully_connected_custom_op(
156
156
  self._op_info, self._graph_info, self._tensor_name_to_qsv
157
157
  )
158
158
  self.assertLen(params, 4)
@@ -179,6 +179,34 @@ class HadamardRotationFullyConnectedTest(parameterized.TestCase):
179
179
  ):
180
180
  self.assertEqual(weight.consumers[0].parameters.quantized_dimension, 1)
181
181
 
182
+ def test_materialize_fully_connected_decomposed(self):
183
+ params = hadamard_rotation.materialize_fully_connected_decomposed(
184
+ self._op_info, self._graph_info, self._tensor_name_to_qsv
185
+ )
186
+ fc_input = params[0]
187
+ weight = params[1]
188
+ bias = params[2]
189
+ output = params[3]
190
+
191
+ self.assertLen(params, 4)
192
+ self.assertEqual(
193
+ fc_input.consumers[0].transformations,
194
+ [qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION],
195
+ )
196
+ self.assertEqual(
197
+ weight.consumers[0].transformations,
198
+ [qtyping.QuantTransformation.QUANTIZE_TENSOR],
199
+ )
200
+ self.assertEqual(
201
+ bias.consumers[0].transformations,
202
+ [qtyping.QuantTransformation.NO_QUANTIZE],
203
+ )
204
+ if output.producer is not None:
205
+ self.assertEqual(
206
+ output.producer.transformations,
207
+ [qtyping.QuantTransformation.NO_QUANTIZE],
208
+ )
209
+
182
210
  def test_get_tensor_quant_params_basic(self):
183
211
  input_tensor = self._subgraph.tensors[self._fc_op.inputs[1]]
184
212
  buffer = self._graph_info.buffers[self._fc_buffer_id]
@@ -344,7 +372,7 @@ class HadamardRotationEmbeddingLookupTest(parameterized.TestCase):
344
372
  ),
345
373
  ),
346
374
  )
347
- params = hadamard_rotation.materialize_embedding_lookup(
375
+ params = hadamard_rotation.materialize_embedding_lookup_custom_op(
348
376
  op_info, self._graph_info, self._tensor_name_to_qsv
349
377
  )
350
378
  self.assertLen(params, 3)
@@ -371,6 +399,43 @@ class HadamardRotationEmbeddingLookupTest(parameterized.TestCase):
371
399
  [qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION],
372
400
  )
373
401
 
402
+ def test_materialize_embedding_lookup_decomposed(self):
403
+ subgraph = self._test_model.subgraphs[0]
404
+ embedding_subgraph_op_index = 0
405
+ embedding_op = subgraph.operators[embedding_subgraph_op_index]
406
+ op_info = qtyping.OpInfo(
407
+ op=embedding_op,
408
+ op_name=_TFLOpName.EMBEDDING_LOOKUP,
409
+ subgraph_op_index=embedding_subgraph_op_index,
410
+ op_quant_config=qtyping.OpQuantizationConfig(
411
+ weight_tensor_config=_TensorQuantConfig(
412
+ num_bits=8,
413
+ symmetric=True,
414
+ granularity=qtyping.QuantGranularity.CHANNELWISE,
415
+ ),
416
+ ),
417
+ )
418
+ params = hadamard_rotation.materialize_embedding_lookup_decomposed(
419
+ op_info, self._graph_info, self._tensor_name_to_qsv
420
+ )
421
+ self.assertLen(params, 3)
422
+ lookup = params[0]
423
+ value = params[1]
424
+ output = params[2]
425
+ self.assertEqual(
426
+ lookup.consumers[0].transformations,
427
+ [qtyping.QuantTransformation.NO_QUANTIZE],
428
+ )
429
+ self.assertEqual(
430
+ value.consumers[0].transformations,
431
+ [qtyping.QuantTransformation.QUANTIZE_TENSOR],
432
+ )
433
+ if output.producer is not None:
434
+ self.assertEqual(
435
+ output.producer.transformations,
436
+ [qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION],
437
+ )
438
+
374
439
 
375
440
  if __name__ == "__main__":
376
441
  googletest.main()
@@ -510,6 +510,7 @@ def _compatible_tensor_params(
510
510
  _QuantTrans.ADD_QUANTIZE,
511
511
  _QuantTrans.NO_QUANTIZE,
512
512
  _QuantTrans.INSERT_HADAMARD_ROTATION,
513
+ _QuantTrans.INSERT_DECOMPOSED_HADAMARD_ROTATION,
513
514
  ]
514
515
  quantized_source_transformations = [
515
516
  _QuantTrans.QUANTIZE_TENSOR,
@@ -133,6 +133,9 @@ class QuantTransformation(enum.Enum):
133
133
  DUPLICATE_TENSOR = 6
134
134
  # Insert the aeq.hadamard_rotation op.
135
135
  INSERT_HADAMARD_ROTATION = 7
136
+ # Insert decomposed Hadamard rotation ops. This expresses the Hadamard
137
+ # rotation as matrix multiplication with Hadamard matrices.
138
+ INSERT_DECOMPOSED_HADAMARD_ROTATION = 8
136
139
 
137
140
 
138
141
  @dataclasses.dataclass(frozen=True)
@@ -305,6 +308,7 @@ class TensorQuantizationConfig:
305
308
  quantization.
306
309
  dtype: The data type of the tensor.
307
310
  block_size: The block size for blockwise quantization, ignored otherwise.
311
+ algorithm_key: The algorithm key to use for quantization.
308
312
  """
309
313
 
310
314
  num_bits: int
@@ -24,6 +24,7 @@ from ai_edge_quantizer import qtyping
24
24
  from ai_edge_quantizer.transformations import dequant_insert
25
25
  from ai_edge_quantizer.transformations import duplicate_buffer
26
26
  from ai_edge_quantizer.transformations import duplicate_tensor
27
+ from ai_edge_quantizer.transformations import insert_decomposed_hadamard_rotation
27
28
  from ai_edge_quantizer.transformations import insert_hadamard_rotation
28
29
  from ai_edge_quantizer.transformations import quant_insert
29
30
  from ai_edge_quantizer.transformations import quantize_tensor
@@ -83,6 +84,9 @@ class TransformationPerformer:
83
84
  qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION: (
84
85
  insert_hadamard_rotation.insert_hadamard_rotation
85
86
  ),
87
+ qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION: (
88
+ insert_decomposed_hadamard_rotation.insert_decomposed_hadamard_rotation
89
+ ),
86
90
  }
87
91
  # transformations are separated in two categories:
88
92
  # op_insertion_transformations are transformations that only insert ops
@@ -95,6 +99,7 @@ class TransformationPerformer:
95
99
  qtyping.QuantTransformation.DUPLICATE_BUFFER,
96
100
  qtyping.QuantTransformation.DUPLICATE_TENSOR,
97
101
  qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
102
+ qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION,
98
103
  ])
99
104
  self._op_replacement_transformations = set(
100
105
  [qtyping.QuantTransformation.EMULATED_SUBCHANNEL]
@@ -0,0 +1,291 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Hadamard rotation decomposed pattern transformation."""
17
+
18
+ from flatbuffers import flexbuffers
19
+ import numpy as np
20
+ from ai_edge_quantizer import qtyping
21
+ from ai_edge_quantizer.transformations import transformation_utils
22
+ from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
23
+
24
+
25
+ def _to_flexbuffer(
26
+ hadamard_size: int,
27
+ random_binary_vector: list[np.int8],
28
+ ) -> bytes:
29
+ """Converts hadamard_size to flexbuffer."""
30
+ fbb = flexbuffers.Builder()
31
+ with fbb.Map():
32
+ fbb.Int('hadamard_size', hadamard_size)
33
+ fbb.VectorFromElements('random_binary_vector', random_binary_vector)
34
+ return fbb.Finish()
35
+
36
+
37
+ def _update_embedding_lookup_consumers(
38
+ transformation: transformation_utils.TransformationInput,
39
+ new_tensor_id: int,
40
+ ) -> bool:
41
+ """Updates the consumers of the embedding lookup op to use the new tensor.
42
+
43
+ Args:
44
+ transformation: The transformation input to update the consumers of.
45
+ new_tensor_id: The new tensor id to use as the input to the embedding lookup
46
+ consumers.
47
+ """
48
+ for consumer in transformation.consumers:
49
+ # If the consumer is a graph output and not an op, we can ignore it here
50
+ # since the graph output will be updated later.
51
+ if consumer == -1:
52
+ continue
53
+ consumer_op = transformation.subgraph.operators[consumer]
54
+ # Find the input that was attached to the insertion point, and replace it
55
+ # with the new tensor.
56
+ for i in range(len(consumer_op.inputs)):
57
+ if consumer_op.inputs[i] == transformation.tensor_id:
58
+ consumer_op.inputs[i] = new_tensor_id
59
+
60
+
61
+ def _update_fully_connected_consumers(
62
+ transformation: transformation_utils.TransformationInput,
63
+ new_tensor_id: int,
64
+ ) -> bool:
65
+ """Updates the fully connected op(s) to use the new tensor.
66
+
67
+ Since the new tensor is inserted to the fully_connected's input, we need to
68
+ scan each consumer (in case of multiple fully_connected ops), and update
69
+ the input tensor to the new tensor.
70
+
71
+ Args:
72
+ transformation: The transformation input to update the consumers of.
73
+ new_tensor_id: The new tensor id to use as the input to the fully connected
74
+ consumers.
75
+
76
+ Returns:
77
+ True if the fully connected op(s) were updated to use the new tensor.
78
+ """
79
+ updated = False
80
+ for consumer in transformation.consumers:
81
+ if (
82
+ transformation_utils.get_schema_op_id(transformation, consumer)
83
+ == schema_py_generated.BuiltinOperator.FULLY_CONNECTED
84
+ ):
85
+ transformation.subgraph.operators[consumer].inputs[0] = new_tensor_id
86
+ updated = True
87
+ return updated
88
+
89
+
90
+ def _make_hadamard_matrix(size: int):
91
+ """Generates a Hadamard matrix of the given size.
92
+
93
+ Args:
94
+ size: The size of the Hadamard matrix. Must be a power of 2. This represents
95
+ a single dimension. E.g. if size is 4, then the Hadamard matrix is a 4x4
96
+ matrix.
97
+
98
+ Returns:
99
+ The Hadamard matrix.
100
+
101
+ Raises:
102
+ ValueError: If the size is not a power of 2.
103
+ """
104
+ if size <= 0 or (size & (size - 1)) != 0:
105
+ raise ValueError('Hadamard matrix size must be a power of 2. ')
106
+ h = h2 = np.array([[1, 1], [1, -1]])
107
+ current_size = 2
108
+ while current_size < size:
109
+ h = np.kron(h, h2)
110
+ current_size *= 2
111
+ return h / np.sqrt(size)
112
+
113
+
114
+ def insert_decomposed_hadamard_rotation(
115
+ transformation_input: transformation_utils.TransformationInput,
116
+ ) -> qtyping.TransformationInfo:
117
+ """Inserts a decomposed pattern of Hadamard rotation on this tensor.
118
+
119
+ This function works for float32 tensors only. Instead of inserting a single
120
+ custom op (aeq.hadamard_rotation), this inserts the mathematical equivalent
121
+ expressed in built-in TFLite ops. The mathematical equivalent is:
122
+ x' = reshape(x, (-1, hadamard_size))
123
+ x' = x' @ H(hadamard_size)
124
+ x' = reshape(x', x.shape)
125
+ where H(n) is a Hadamard matrix of size n.
126
+
127
+ Args:
128
+ transformation_input: The transformation input to insert the ops on.
129
+
130
+ Returns:
131
+ The transformation info of the inserted ops.
132
+
133
+ Raises:
134
+ ValueError: If the transformation input is not a uniform quantization
135
+ transformation.
136
+ ValueError: If the Hadamard quantization params are not set.
137
+ ValueError: If the tensor is not a float32 tensor.
138
+ ValueError: If no supported ops were found as the tensor's producer or
139
+ consumers.
140
+ """
141
+ if not isinstance(
142
+ transformation_input.quant_params, qtyping.UniformQuantParams
143
+ ):
144
+ raise ValueError('Hadamard rotation supports uniform quantization only')
145
+
146
+ if transformation_input.quant_params.hadamard is None:
147
+ raise ValueError(
148
+ 'Hadamard rotation quantization params are not set but op insertion is'
149
+ ' requested.'
150
+ )
151
+
152
+ tensor = transformation_input.subgraph.tensors[transformation_input.tensor_id]
153
+ if tensor.type != schema_py_generated.TensorType.FLOAT32:
154
+ raise ValueError(
155
+ 'The Hadamard rotation op supports float32 tensors only. Got'
156
+ f' {tensor.type} tensor.'
157
+ )
158
+
159
+ # Insert x' = tfl.reshape to reshape x to (-1, hadamard_size)
160
+ hadamard_size = transformation_input.quant_params.hadamard.hadamard_size
161
+ tensor_size = np.prod(tensor.shape)
162
+ num_hadamard_blocks = tensor_size // hadamard_size
163
+ prerotate_shape = [num_hadamard_blocks, hadamard_size]
164
+ prerotate_shape_tensor_id = transformation_utils.add_new_constant_tensor(
165
+ tensor.name + b'_prerotate_shape',
166
+ np.array(prerotate_shape, dtype=np.int32),
167
+ schema_py_generated.TensorType.INT32,
168
+ transformation_input.subgraph,
169
+ transformation_input.buffers,
170
+ )
171
+ prerotate_reshape_output_tensor_id = (
172
+ transformation_utils.add_new_activation_tensor(
173
+ tensor.name + b'_prerotate_reshaped',
174
+ prerotate_shape,
175
+ schema_py_generated.TensorType.FLOAT32,
176
+ transformation_input.subgraph,
177
+ )
178
+ )
179
+
180
+ prerotate_reshape_op_code_idx = transformation_utils.add_op_code(
181
+ schema_py_generated.BuiltinOperator.RESHAPE,
182
+ transformation_input.op_codes,
183
+ 'RESHAPE',
184
+ )
185
+ prerorate_reshape_op = schema_py_generated.OperatorT()
186
+ prerorate_reshape_op.opcodeIndex = prerotate_reshape_op_code_idx
187
+ prerorate_reshape_op.inputs = [
188
+ transformation_input.tensor_id,
189
+ prerotate_shape_tensor_id,
190
+ ]
191
+ prerorate_reshape_op.outputs = [prerotate_reshape_output_tensor_id]
192
+
193
+ # Generate hadamard_matrix(hadamard_size).
194
+ # We could quantize this to INT4 for better memory efficiency, but for large
195
+ # models the memory overhead is not significant, and floating point
196
+ # computation does seem to result in better accuracy.
197
+ hadamard_matrix = _make_hadamard_matrix(hadamard_size)
198
+ hadamard_matrix_tensor_id = transformation_utils.add_new_constant_tensor(
199
+ tensor.name + b'_hadamard_matrix',
200
+ hadamard_matrix.astype(np.float32),
201
+ schema_py_generated.TensorType.FLOAT32,
202
+ transformation_input.subgraph,
203
+ transformation_input.buffers,
204
+ )
205
+
206
+ # Insert x' = tfl.fully_connected(x', hadamard_matrix)
207
+ fc_output_tensor_id = transformation_utils.add_new_activation_tensor(
208
+ tensor.name + b'_rotated',
209
+ prerotate_shape,
210
+ schema_py_generated.TensorType.FLOAT32,
211
+ transformation_input.subgraph,
212
+ )
213
+
214
+ fc_op_code_idx = transformation_utils.add_op_code(
215
+ schema_py_generated.BuiltinOperator.FULLY_CONNECTED,
216
+ transformation_input.op_codes,
217
+ 'FULLY_CONNECTED',
218
+ )
219
+ fc_op = schema_py_generated.OperatorT()
220
+ fc_op.opcodeIndex = fc_op_code_idx
221
+ fc_op.inputs = [prerotate_reshape_output_tensor_id, hadamard_matrix_tensor_id]
222
+ fc_op.outputs = [fc_output_tensor_id]
223
+
224
+ # Insert x' = tfl.reshape(x', x.shape)
225
+ post_reshape_op_code_idx = transformation_utils.add_op_code(
226
+ schema_py_generated.BuiltinOperator.RESHAPE,
227
+ transformation_input.op_codes,
228
+ 'RESHAPE',
229
+ )
230
+ post_reshape_op = schema_py_generated.OperatorT()
231
+ post_reshape_op.opcodeIndex = post_reshape_op_code_idx
232
+ post_reshape_shape_tensor_id = transformation_utils.add_new_constant_tensor(
233
+ tensor.name + b'_postrotate_shape',
234
+ np.array(tensor.shape, dtype=np.int32),
235
+ schema_py_generated.TensorType.INT32,
236
+ transformation_input.subgraph,
237
+ transformation_input.buffers,
238
+ )
239
+
240
+ post_reshape_output_tensor_id = (
241
+ transformation_utils.add_new_activation_tensor(
242
+ tensor.name + b'_postrotate_reshaped',
243
+ tensor.shape,
244
+ schema_py_generated.TensorType.FLOAT32,
245
+ transformation_input.subgraph,
246
+ )
247
+ )
248
+ post_reshape_op.inputs = [
249
+ fc_output_tensor_id,
250
+ post_reshape_shape_tensor_id,
251
+ ]
252
+ post_reshape_op.outputs = [post_reshape_output_tensor_id]
253
+
254
+ # Update the users of this tensor to use the new tensor.
255
+ if (
256
+ transformation_utils.get_producer_schema_op_id(transformation_input)
257
+ == schema_py_generated.BuiltinOperator.EMBEDDING_LOOKUP
258
+ ):
259
+ _update_embedding_lookup_consumers(
260
+ transformation_input, post_reshape_output_tensor_id
261
+ )
262
+ elif not _update_fully_connected_consumers(
263
+ transformation_input, post_reshape_output_tensor_id
264
+ ):
265
+ raise ValueError(
266
+ 'The Hadamard rotation op supports embedding lookup and fully connected'
267
+ ' ops only, but no such ops were found.'
268
+ )
269
+
270
+ # If the tensor is a graph output, we need to replace the tensor with the
271
+ # new tensor.
272
+ for i, output in enumerate(transformation_input.subgraph.outputs):
273
+ if output == transformation_input.tensor_id:
274
+ transformation_input.subgraph.outputs[i] = post_reshape_output_tensor_id
275
+
276
+ # Find the actual insertion point. The insertion point should be after the
277
+ # producer op and before the first consumer op. The max() operation ensures
278
+ # that we're not using -1 as the insertion point.
279
+ first_consumer_id = min(transformation_input.consumers)
280
+ op_id = max(transformation_input.producer + 1, first_consumer_id)
281
+
282
+ # Insert the new ops in the correct order.
283
+ transformation_input.subgraph.operators.insert(op_id, prerorate_reshape_op)
284
+ transformation_input.subgraph.operators.insert(op_id + 1, fc_op)
285
+ transformation_input.subgraph.operators.insert(op_id + 2, post_reshape_op)
286
+
287
+ return qtyping.TransformationInfo(
288
+ op_id=op_id,
289
+ num_ops_added=3,
290
+ output_tensor_id=post_reshape_output_tensor_id,
291
+ )
@@ -0,0 +1,244 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Test insertion of the Decomposed Hadamard rotation ops."""
17
+
18
+ import os
19
+ import numpy as np
20
+ from tensorflow.python.platform import googletest
21
+ from ai_edge_quantizer import qtyping
22
+ from ai_edge_quantizer.transformations import insert_decomposed_hadamard_rotation
23
+ from ai_edge_quantizer.transformations import transformation_utils
24
+ from ai_edge_quantizer.utils import test_utils
25
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
26
+ from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
27
+
28
+ _TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('..')
29
+
30
+
31
+ class InsertDecomposedHadamardRotationFullyConnectedTest(googletest.TestCase):
32
+
33
+ def setUp(self):
34
+ super().setUp()
35
+ model_path = os.path.join(
36
+ _TEST_DATA_PREFIX_PATH, 'tests/models/single_fc_bias.tflite'
37
+ )
38
+ self.model = tfl_flatbuffer_utils.read_model(model_path)
39
+ self.params = qtyping.UniformQuantParams(
40
+ num_bits=8,
41
+ quantized_dimension=None,
42
+ scale=np.ones(1),
43
+ zero_point=np.zeros(1),
44
+ hadamard=qtyping.UniformQuantParams.HadamardRotationParams(
45
+ random_binary_vector=np.ones(1),
46
+ hadamard_size=2,
47
+ ),
48
+ )
49
+
50
+ def test_raise_unsupported_qparams(self):
51
+ with self.assertRaisesWithPredicateMatch(
52
+ ValueError, lambda err: 'uniform quantization' in str(err)
53
+ ):
54
+ insert_decomposed_hadamard_rotation.insert_decomposed_hadamard_rotation(
55
+ transformation_utils.TransformationInput(
56
+ tensor_id=0,
57
+ op_codes=self.model.operatorCodes,
58
+ buffers=self.model.buffers,
59
+ subgraph=self.model.subgraphs[0],
60
+ producer=-1,
61
+ consumers=[-1],
62
+ quant_params=qtyping.NonLinearQuantParams(
63
+ num_bits=16, quantized_data=None
64
+ ),
65
+ )
66
+ )
67
+
68
+ def test_raise_missing_hadamard_data(self):
69
+ with self.assertRaisesWithPredicateMatch(
70
+ ValueError, lambda err: 'quantization params are not set' in str(err)
71
+ ):
72
+ insert_decomposed_hadamard_rotation.insert_decomposed_hadamard_rotation(
73
+ transformation_utils.TransformationInput(
74
+ tensor_id=0,
75
+ op_codes=self.model.operatorCodes,
76
+ buffers=self.model.buffers,
77
+ subgraph=self.model.subgraphs[0],
78
+ producer=-1,
79
+ consumers=[-1],
80
+ quant_params=qtyping.UniformQuantParams(
81
+ num_bits=8,
82
+ quantized_dimension=None,
83
+ scale=np.ones(1),
84
+ zero_point=np.zeros(1),
85
+ ),
86
+ )
87
+ )
88
+
89
+ def test_raise_non_float32_tensor(self):
90
+ self.model.subgraphs[0].tensors[
91
+ 0
92
+ ].type = schema_py_generated.TensorType.INT32
93
+ with self.assertRaisesWithPredicateMatch(
94
+ ValueError, lambda err: 'float32 tensors' in str(err)
95
+ ):
96
+ insert_decomposed_hadamard_rotation.insert_decomposed_hadamard_rotation(
97
+ transformation_utils.TransformationInput(
98
+ tensor_id=0,
99
+ op_codes=self.model.operatorCodes,
100
+ buffers=self.model.buffers,
101
+ subgraph=self.model.subgraphs[0],
102
+ producer=-1,
103
+ consumers=[-1],
104
+ quant_params=self.params,
105
+ ),
106
+ )
107
+
108
+ def test_insert_decomposed_ops(self):
109
+ # Insert Decomposed Hadamard ops before fully_connected
110
+ info = (
111
+ insert_decomposed_hadamard_rotation.insert_decomposed_hadamard_rotation(
112
+ transformation_utils.TransformationInput(
113
+ tensor_id=0,
114
+ op_codes=self.model.operatorCodes,
115
+ buffers=self.model.buffers,
116
+ subgraph=self.model.subgraphs[0],
117
+ producer=-1,
118
+ consumers=[0], # Consumer is the FC op
119
+ quant_params=self.params,
120
+ )
121
+ )
122
+ )
123
+ subgraph = self.model.subgraphs[0]
124
+ self.assertEqual(info.op_id, 0)
125
+ self.assertEqual(info.num_ops_added, 3)
126
+ # Model had 4 tensors, added 6 tensors (3 activations 3 constants).
127
+ self.assertEqual(info.output_tensor_id, 9)
128
+ self.assertLen(subgraph.tensors, 10)
129
+ # Model had 1 op code, added RESHAPE and FC.
130
+ self.assertLen(self.model.operatorCodes, 3)
131
+ self.assertEqual(
132
+ self.model.operatorCodes[1].builtinCode,
133
+ schema_py_generated.BuiltinOperator.RESHAPE,
134
+ )
135
+ self.assertEqual(
136
+ self.model.operatorCodes[2].builtinCode,
137
+ schema_py_generated.BuiltinOperator.FULLY_CONNECTED,
138
+ )
139
+
140
+ # Op 0: RESHAPE
141
+ reshape_op = subgraph.operators[0]
142
+ self.assertEqual(
143
+ self.model.operatorCodes[reshape_op.opcodeIndex].builtinCode,
144
+ schema_py_generated.BuiltinOperator.RESHAPE,
145
+ )
146
+ self.assertEqual(reshape_op.inputs[0], 0) # Graph input
147
+ self.assertEqual(reshape_op.outputs[0], 5) # Reshape output
148
+
149
+ # Op 1: FULLY_CONNECTED
150
+ fc_op = subgraph.operators[1]
151
+ self.assertEqual(
152
+ self.model.operatorCodes[fc_op.opcodeIndex].builtinCode,
153
+ schema_py_generated.BuiltinOperator.FULLY_CONNECTED,
154
+ )
155
+ self.assertEqual(fc_op.inputs[0], 5) # Reshape output
156
+ self.assertEqual(fc_op.inputs[1], 6) # Hadamard matrix tensor
157
+ self.assertEqual(fc_op.outputs[0], 7) # FC output
158
+
159
+ # Op 2: RESHAPE (post)
160
+ post_reshape_op = subgraph.operators[2]
161
+ self.assertEqual(
162
+ self.model.operatorCodes[post_reshape_op.opcodeIndex].builtinCode,
163
+ schema_py_generated.BuiltinOperator.RESHAPE,
164
+ )
165
+ self.assertEqual(post_reshape_op.inputs[0], 7) # FC output
166
+ self.assertEqual(post_reshape_op.outputs[0], 9) # Post Reshape output
167
+
168
+ # Op 3: Original FULLY_CONNECTED
169
+ orig_fc_op = subgraph.operators[3]
170
+ self.assertEqual(
171
+ self.model.operatorCodes[orig_fc_op.opcodeIndex].builtinCode,
172
+ schema_py_generated.BuiltinOperator.FULLY_CONNECTED,
173
+ )
174
+ # Input to the original FC is the post reshape output
175
+ self.assertEqual(orig_fc_op.inputs[0], 9)
176
+
177
+
178
+ class InsertDecomposedHadamardRotationEmbeddingLookupTest(googletest.TestCase):
179
+
180
+ def setUp(self):
181
+ super().setUp()
182
+ model_path = os.path.join(
183
+ _TEST_DATA_PREFIX_PATH, 'tests/models/embedding_lookup.tflite'
184
+ )
185
+ self.model = tfl_flatbuffer_utils.read_model(model_path)
186
+ self.params = qtyping.UniformQuantParams(
187
+ num_bits=8,
188
+ quantized_dimension=None,
189
+ scale=np.ones(1),
190
+ zero_point=np.zeros(1),
191
+ hadamard=qtyping.UniformQuantParams.HadamardRotationParams(
192
+ random_binary_vector=np.ones(1),
193
+ hadamard_size=2,
194
+ ),
195
+ )
196
+
197
+ def test_insert_decomposed_ops(self):
198
+ # Insert Decomposed Hadamard ops after embedding_lookup
199
+ info = (
200
+ insert_decomposed_hadamard_rotation.insert_decomposed_hadamard_rotation(
201
+ transformation_utils.TransformationInput(
202
+ tensor_id=2, # Output of embedding_lookup
203
+ op_codes=self.model.operatorCodes,
204
+ buffers=self.model.buffers,
205
+ subgraph=self.model.subgraphs[0],
206
+ producer=0,
207
+ consumers=[-1], # Output is a graph output
208
+ quant_params=self.params,
209
+ )
210
+ )
211
+ )
212
+ subgraph = self.model.subgraphs[0]
213
+ self.assertEqual(info.op_id, 1)
214
+ self.assertEqual(info.num_ops_added, 3)
215
+ # Model had 3 tensors, added 6 (3 activations 3 constants).
216
+ self.assertEqual(info.output_tensor_id, 8)
217
+ self.assertLen(subgraph.tensors, 9)
218
+ # Model had 1 op code, added RESHAPE and FC.
219
+ self.assertLen(self.model.operatorCodes, 3)
220
+
221
+ # Op 0: EMBEDDING_LOOKUP (Original)
222
+ # Op 1: RESHAPE
223
+ reshape_op = subgraph.operators[1]
224
+ self.assertEqual(reshape_op.inputs[0], 2) # Embedding lookup output
225
+ self.assertEqual(reshape_op.outputs[0], 4)
226
+
227
+ # Op 2: FULLY_CONNECTED
228
+ fc_op = subgraph.operators[2]
229
+ self.assertEqual(fc_op.inputs[0], 4)
230
+ self.assertEqual(fc_op.inputs[1], 5) # Hadamard matrix
231
+ self.assertEqual(fc_op.outputs[0], 6)
232
+
233
+ # Op 3: RESHAPE (post)
234
+ post_reshape_op = subgraph.operators[3]
235
+ self.assertEqual(post_reshape_op.inputs[0], 6)
236
+ self.assertEqual(post_reshape_op.outputs[0], 8)
237
+
238
+ # Check graph output
239
+ self.assertIn(8, subgraph.outputs)
240
+ self.assertNotIn(2, subgraph.outputs)
241
+
242
+
243
+ if __name__ == '__main__':
244
+ googletest.main()
@@ -34,35 +34,6 @@ def _to_flexbuffer(
34
34
  return fbb.Finish()
35
35
 
36
36
 
37
- def _is_producer_embedding_lookup(
38
- transformation: transformation_utils.TransformationInput,
39
- ) -> bool:
40
- """Checks if the tensor's producer is an embedding lookup op."""
41
- if transformation.producer == -1:
42
- return False
43
- else:
44
- return (
45
- transformation.op_codes[
46
- transformation.subgraph.operators[
47
- transformation.producer
48
- ].opcodeIndex
49
- ].builtinCode
50
- == schema_py_generated.BuiltinOperator.EMBEDDING_LOOKUP
51
- )
52
-
53
-
54
- def _is_fully_connected(
55
- transformation: transformation_utils.TransformationInput, op_id: int
56
- ) -> bool:
57
- """Checks if the any of the tensor's consumers is a fully connected op."""
58
- return (
59
- transformation.op_codes[
60
- transformation.subgraph.operators[op_id].opcodeIndex
61
- ].builtinCode
62
- == schema_py_generated.BuiltinOperator.FULLY_CONNECTED
63
- )
64
-
65
-
66
37
  def _update_embedding_lookup_consumers(
67
38
  transformation: transformation_utils.TransformationInput,
68
39
  new_tensor_id: int,
@@ -107,7 +78,10 @@ def _update_fully_connected_consumers(
107
78
  """
108
79
  updated = False
109
80
  for consumer in transformation.consumers:
110
- if _is_fully_connected(transformation, consumer):
81
+ if (
82
+ transformation_utils.get_schema_op_id(transformation, consumer)
83
+ == schema_py_generated.BuiltinOperator.FULLY_CONNECTED
84
+ ):
111
85
  transformation.subgraph.operators[consumer].inputs[0] = new_tensor_id
112
86
  updated = True
113
87
  return updated
@@ -177,7 +151,10 @@ def insert_hadamard_rotation(
177
151
  custom_op.outputs = [new_tensor_id]
178
152
 
179
153
  # Update the users of this tensor to use the new tensor.
180
- if _is_producer_embedding_lookup(transformation_input):
154
+ if (
155
+ transformation_utils.get_producer_schema_op_id(transformation_input)
156
+ == schema_py_generated.BuiltinOperator.EMBEDDING_LOOKUP
157
+ ):
181
158
  _update_embedding_lookup_consumers(transformation_input, new_tensor_id)
182
159
  elif not _update_fully_connected_consumers(
183
160
  transformation_input, new_tensor_id
@@ -68,29 +68,6 @@ def nonlinear_quant_params_to_tflite_type(
68
68
  raise ValueError(f"Unsupported nonlinear params: {bitwidth}")
69
69
 
70
70
 
71
- def _pack_data(bitwidth: int, flattened_data: np.ndarray) -> np.ndarray:
72
- """Pack the data to the corresponding bit width.
73
-
74
- Currently only support 4 bits. If no packing is needed, the original data is
75
- returned.
76
-
77
- Args:
78
- bitwidth: Bit width from NonLinearQuantParams.
79
- flattened_data: The data to be packed.
80
-
81
- Returns:
82
- Packed data.
83
- """
84
- if bitwidth == 4:
85
- even_data = flattened_data[::2] & 0x0F
86
- odd_data = np.left_shift(flattened_data[1::2], 4).astype(np.uint8)
87
- if odd_data.shape[0] == even_data.shape[0] - 1:
88
- odd_data = np.pad(odd_data, (0, 1), constant_values=0)
89
- return np.bitwise_or(even_data, odd_data)
90
- else:
91
- return flattened_data
92
-
93
-
94
71
  def _perform_channelwise_quantization(
95
72
  transformation_input: transformation_utils.TransformationInput,
96
73
  ) -> schema_py_generated.QuantizationParametersT():
@@ -180,14 +157,17 @@ def quantize_tensor(
180
157
  # is not provided.
181
158
  if tensor.buffer:
182
159
  if transformation_input.quant_params.quantized_data is not None:
183
- transformation_input.buffers[tensor.buffer].data = _pack_data(
184
- transformation_input.quant_params.num_bits,
185
- np.frombuffer(
186
- cast(
187
- np.ndarray, transformation_input.quant_params.quantized_data
188
- ).tobytes(),
189
- dtype=np.uint8,
190
- ).flatten(),
160
+ transformation_input.buffers[tensor.buffer].data = (
161
+ transformation_utils.pack_data(
162
+ transformation_input.quant_params.num_bits,
163
+ np.frombuffer(
164
+ cast(
165
+ np.ndarray,
166
+ transformation_input.quant_params.quantized_data,
167
+ ).tobytes(),
168
+ dtype=np.uint8,
169
+ ).flatten(),
170
+ )
191
171
  )
192
172
 
193
173
  if isinstance(transformation_input.quant_params, qtyping.UniformQuantParams):
@@ -210,3 +210,69 @@ def raise_deprecated_error(_: TransformationInput):
210
210
  'This transformation is deprecated. Please contact AI Edge Quantizer team'
211
211
  ' if you see this error.'
212
212
  )
213
+
214
+
215
+ def pack_data(bitwidth: int, flattened_data: np.ndarray) -> np.ndarray:
216
+ """Pack the data to the corresponding bit width.
217
+
218
+ Currently only supports 4 bits. If no packing is needed, the original data is
219
+ returned.
220
+
221
+ Args:
222
+ bitwidth: Bit width from NonLinearQuantParams.
223
+ flattened_data: The data to be packed.
224
+
225
+ Returns:
226
+ Packed data.
227
+ """
228
+ if bitwidth == 4:
229
+ even_data = flattened_data[::2] & 0x0F
230
+ odd_data = np.left_shift(flattened_data[1::2], 4).astype(np.uint8)
231
+ if odd_data.shape[0] == even_data.shape[0] - 1:
232
+ odd_data = np.pad(odd_data, (0, 1), constant_values=0)
233
+ return np.bitwise_or(even_data, odd_data)
234
+ else:
235
+ return flattened_data
236
+
237
+
238
+ def get_producer_schema_op_id(
239
+ transformation: TransformationInput,
240
+ ) -> int:
241
+ """Returns the schema op id of the op that produces the tensor.
242
+
243
+ Args:
244
+ transformation: The transformation input to check the producer of.
245
+
246
+ Returns:
247
+ The schema op id of the producer op (e.g.
248
+ schema_py_generated.BuiltinOperator.FULLY_CONNECTED), or -1 if none.
249
+ """
250
+ if transformation.producer == -1:
251
+ return -1
252
+ else:
253
+ return (
254
+ transformation.op_codes[
255
+ transformation.subgraph.operators[
256
+ transformation.producer
257
+ ].opcodeIndex
258
+ ].builtinCode
259
+ )
260
+
261
+
262
+ def get_schema_op_id(
263
+ transformation: TransformationInput, op_id: int
264
+ ) -> int:
265
+ """Returns the schema op id of the given op.
266
+
267
+ Args:
268
+ transformation: The transformation input holding the op codes and subgraph.
269
+ op_id: The op id in the list of operators to check for.
270
+
271
+ Returns:
272
+ The schema op id of the given op.
273
+ """
274
+ return (
275
+ transformation.op_codes[
276
+ transformation.subgraph.operators[op_id].opcodeIndex
277
+ ].builtinCode
278
+ )
@@ -38,6 +38,8 @@ def get_validation_func(
38
38
  return mean_squared_difference
39
39
  elif func_name == "median_diff_ratio":
40
40
  return median_diff_ratio
41
+ elif func_name == "cosine_similarity":
42
+ return cosine_similarity
41
43
  else:
42
44
  raise ValueError(f"Validation function {func_name} not supported")
43
45
 
@@ -99,6 +101,33 @@ def median_diff_ratio(
99
101
  return median_ratio
100
102
 
101
103
 
104
+ def cosine_similarity(
105
+ data1: np._typing.ArrayLike, data2: np._typing.ArrayLike
106
+ ) -> float:
107
+ """Calculates the cosine similarity between data1 & data2.
108
+
109
+ ref: https://en.wikipedia.org/wiki/Cosine_similarity
110
+
111
+ Args:
112
+ data1: input data to be used for comparison
113
+ data2: input data to be used for comparison, data1 & 2 must be of the same
114
+ shape
115
+
116
+ Returns:
117
+ a float value representing the cosine similarity between data1 & 2
118
+
119
+ Raises:
120
+ ValueError: if the two inputs don't have the same number of elements
121
+ """
122
+ data1, data2 = _preprocess_same_size_arrays(data1, data2)
123
+ # special handling for tensor of size 0
124
+ if data1.size == 0:
125
+ return float(0)
126
+ return float(
127
+ np.dot(data1, data2) / (np.linalg.norm(data1) * np.linalg.norm(data2))
128
+ )
129
+
130
+
102
131
  def _preprocess_same_size_arrays(
103
132
  data1: np._typing.ArrayLike, data2: np._typing.ArrayLike
104
133
  ) -> Tuple[np.ndarray, np.ndarray]:
@@ -82,6 +82,30 @@ class ValidationUtilTest(googletest.TestCase):
82
82
  result = validation_utils.median_diff_ratio(data1, data2)
83
83
  self.assertEqual(result, 0)
84
84
 
85
+ def test_cosine_similarity(self):
86
+ data1 = [1, 2, 3]
87
+ data2 = [1, 2, 3]
88
+ result = validation_utils.cosine_similarity(data1, data2)
89
+ self.assertAlmostEqual(result, 1.0, 6)
90
+
91
+ def test_cosine_similarity_perpendicular(self):
92
+ data1 = [1, 0, 0]
93
+ data2 = [0, 1, 0]
94
+ result = validation_utils.cosine_similarity(data1, data2)
95
+ self.assertAlmostEqual(result, 0.0, 6)
96
+
97
+ def test_cosine_similarity_multidim(self):
98
+ data1 = [[1, 2], [4, 5]]
99
+ data2 = [[1, 3], [2, 2]]
100
+ result = validation_utils.cosine_similarity(data1, data2)
101
+ self.assertAlmostEqual(result, 0.86881, 6)
102
+
103
+ def test_cosine_similarity_0d(self):
104
+ data1 = []
105
+ data2 = []
106
+ result = validation_utils.cosine_similarity(data1, data2)
107
+ self.assertEqual(result, 0)
108
+
85
109
 
86
110
  if __name__ == "__main__":
87
111
  googletest.main()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-quantizer-nightly
3
- Version: 0.4.0.dev20250930
3
+ Version: 0.4.0.dev20251001
4
4
  Summary: A quantizer for advanced developers to quantize converted AI Edge models.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
6
6
  Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
@@ -1,5 +1,5 @@
1
1
  ai_edge_quantizer/__init__.py,sha256=4pFSkukSwahYyzwqia0yPRyz8TnFQfGRthVJhYpMWas,793
2
- ai_edge_quantizer/algorithm_manager.py,sha256=XkLMG_wQqf_X6swp6YBIhJpIbIdRcOt2LJ_6oTZ3GzU,14956
2
+ ai_edge_quantizer/algorithm_manager.py,sha256=6-kM-pV6N7vCY7c043fWXuGTNHWdXqvI3ezDxjBdrx0,16022
3
3
  ai_edge_quantizer/algorithm_manager_api.py,sha256=u903TG0s1uIDhJqfeJne3CFl8A93phZrwgV2-hwdcXU,9247
4
4
  ai_edge_quantizer/algorithm_manager_api_test.py,sha256=w6bSONvXkX6bzXAGc0-7b6gNDt9oz9ieq97KP8Sg_JU,7666
5
5
  ai_edge_quantizer/calibrator.py,sha256=Sms7_AIHPH9G5xFaz5Ef3a5gPhxuIWQI8d2LUM8C96I,12071
@@ -10,9 +10,9 @@ ai_edge_quantizer/model_modifier.py,sha256=teGa8I6kGvn6TQY6Xv53YFIc_pQEhNvM9Zb4b
10
10
  ai_edge_quantizer/model_modifier_test.py,sha256=cJd04SLOG-fQZZNZPcisoBLx3cLtWEwGqUBbLb-pif4,4751
11
11
  ai_edge_quantizer/model_validator.py,sha256=Hj0_5o-Oa3dSlJ3ryVjRhvsyelHNyek1GrtG9buMczg,13153
12
12
  ai_edge_quantizer/model_validator_test.py,sha256=EeqOP_mrZsnZ3rug756s0ryDDqd2KgIDld5Lm_gDuWY,13020
13
- ai_edge_quantizer/params_generator.py,sha256=hcgMHJlERZERUyIAEi6AHJcLJ8gsKIBAEojzFFz-tqk,20098
13
+ ai_edge_quantizer/params_generator.py,sha256=0w-sDGk84sVNkXoduon1wDqq30sGOHVgBVbdg44QVF4,20153
14
14
  ai_edge_quantizer/params_generator_test.py,sha256=RDYoRZDJfEZRtjlTAU2kZ_4t3JHOqEHxfJX9V4ETAhg,40597
15
- ai_edge_quantizer/qtyping.py,sha256=tfrPip-uzJuF_PASgUExx5Oy9gghWUbQaApR0XaBpNw,16882
15
+ ai_edge_quantizer/qtyping.py,sha256=gBNBPuh488IujrpCTIoNrPKRXJXdsuzbzl7oi7MPZpc,17121
16
16
  ai_edge_quantizer/quantizer.py,sha256=ckAEOnnBxuCKZuvlzdChevCKPuE-IeDPHCNtFTWr250,17857
17
17
  ai_edge_quantizer/quantizer_test.py,sha256=m6f4ayyaF3yQb9i4V0aFAbmGw0OKZ2Zam1RoTPh-u24,22917
18
18
  ai_edge_quantizer/recipe.py,sha256=MEkfQ2Sg3KAE9LAORHWcbjYNPg06EUbwc1d-VspQA2U,6461
@@ -21,7 +21,7 @@ ai_edge_quantizer/recipe_manager_test.py,sha256=qjgGUF-wggXnSXqZ5khmqrDMIQI5CShk
21
21
  ai_edge_quantizer/recipe_test.py,sha256=QisyaTol8JRZFcGOGyee7QRCvqj5VbF4guKWdIoMUOE,6213
22
22
  ai_edge_quantizer/transformation_instruction_generator.py,sha256=O0U2aZcB8aXQgOV8r9g1rGNzDUiuI5Ta53XnxZbVffE,31576
23
23
  ai_edge_quantizer/transformation_instruction_generator_test.py,sha256=KW5-WoTTo9IqLEVnWxVC8ut8eWLi_91xfKgGqVQ9QDk,54635
24
- ai_edge_quantizer/transformation_performer.py,sha256=o4J6OUbI0dLoobVYjkOFw5Po3yH0gZJXrfuTIYais4o,13029
24
+ ai_edge_quantizer/transformation_performer.py,sha256=mFsig0E5Isy7cnG1wMO2jzBn3Wql8fElM_PSpaL8okw,13354
25
25
  ai_edge_quantizer/transformation_performer_test.py,sha256=xk6A3LStCyPclN51--9uO7XjSxNfZmpdfvrzOL0maNM,20349
26
26
  ai_edge_quantizer/algorithms/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
27
27
  ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
@@ -32,8 +32,8 @@ ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=E17cSR-M
32
32
  ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py,sha256=GGf_n3wIeg3GB_eGsmyNJ0fTcxgpeMMbugTMRONK6TQ,3553
33
33
  ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py,sha256=BDdn_uBZakfHyzdMJPKadsOqxqyC-s6W2ZzFH99L4fE,8652
34
34
  ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py,sha256=sT5eX5TLZEHTtPfnSkCPDlS0sQxlTFWbCsbvOuj--yY,8889
35
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py,sha256=otKRiZn_C0QH0891pxLsIPIBT1mLDwbKYYP7bI-MXAA,12279
36
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py,sha256=_SpP12aDLujv_7tWf_mCt89WknNXTSGE-JpZWO1bYSE,13238
35
+ ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py,sha256=qxt9CPDcidVWIxp5nSWPN2hKKj1XZcsOOLBd2SYIvW0,14572
36
+ ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py,sha256=mgv6aGIqQouxfA8_GacuGdOftvL75XBF1_h5tlCCYJQ,15468
37
37
  ai_edge_quantizer/algorithms/uniform_quantize/mse.py,sha256=qiIyzogATGVxjYwxzH0cZvgwPSPBJv_3y8NSumHZXTk,4561
38
38
  ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py,sha256=-_P4jQJ7gVo0FNSapP3sIGcnhwfjQHW1AKLfoiAlS_s,7142
39
39
  ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py,sha256=1sB2j1vlvvWDKyjcGvA_JLCpN2KbCmMslGCBUc4--V4,8461
@@ -52,13 +52,15 @@ ai_edge_quantizer/transformations/duplicate_buffer.py,sha256=TvTHbm24IiICNkWOlvR
52
52
  ai_edge_quantizer/transformations/duplicate_buffer_test.py,sha256=YYWl3Q5WF60s8T8pLzzA8TCSxz-i7dqc03dJt1LtMw4,3880
53
53
  ai_edge_quantizer/transformations/duplicate_tensor.py,sha256=WKhf2LIAL0MnZe88b6942A37lvHXe1cFjUDqE5VNmvU,2490
54
54
  ai_edge_quantizer/transformations/duplicate_tensor_test.py,sha256=s-RqSxNBMfVJyCunXz2eb7-KA6UiBmbOmL7phLslENQ,5056
55
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py,sha256=rBbKgcVKHie38NT2UQ7KQ1xCb2tRu_rVl0yFloOAW_A,7562
55
+ ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py,sha256=D47xTbMQM-R2X3SwSG1RjOAKxvGp76y61aaZA1VyN8E,10791
56
+ ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py,sha256=Z9Nr5e5aEeEMahhhizFyOkAMEXkEg1EKYZ_bGb5Vbvw,8993
57
+ ai_edge_quantizer/transformations/insert_hadamard_rotation.py,sha256=5D5WwrJCE6hQoANbMwa6YGBbjcG5HcL_rkkoXIAIW9w,6883
56
58
  ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py,sha256=iV1p3nZfHUATV2YRoBOYurnu3pLy8n3aFppLWGQOPdA,7268
57
59
  ai_edge_quantizer/transformations/quant_insert.py,sha256=jn6HsJaV-sqBiFPY-Aqbd64t8zgcYVkEkZI375x_FWY,3958
58
60
  ai_edge_quantizer/transformations/quant_insert_test.py,sha256=X9ptPDvJCFkR5tejKnD1SlHFGPazQTW-wNNMV9MEAuw,10107
59
- ai_edge_quantizer/transformations/quantize_tensor.py,sha256=kjaNrw9mnrn0t8u0vey9S_uPz3iVUicwy4rluxVqV3E,7617
61
+ ai_edge_quantizer/transformations/quantize_tensor.py,sha256=unqInO0we6_cgwPjtHB3tLWIHPajfNuJSLGW-IFnI9E,7029
60
62
  ai_edge_quantizer/transformations/quantize_tensor_test.py,sha256=mHLO3_MRt36A8-ZN8ADn5tBBJlqjTWa7ZUN8Mmu5Rcw,9116
61
- ai_edge_quantizer/transformations/transformation_utils.py,sha256=efJdAkA24wlg6Vj5NFO7_7MDuvQLSNn-l11Vs_JPktI,7123
63
+ ai_edge_quantizer/transformations/transformation_utils.py,sha256=IKrtXJNH0msiTcI7KXkCYn2EkzmbZKWMMX_r5PMEx2U,8857
62
64
  ai_edge_quantizer/transformations/transformation_utils_test.py,sha256=MWgq29t7rvxRQIfi4ny9IoODFCTcbpjnIwoCL40zDKk,8698
63
65
  ai_edge_quantizer/utils/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
64
66
  ai_edge_quantizer/utils/calibration_utils.py,sha256=iMf_bSCf-O86MzDt5D9hLKqbTydqLwirluaC6BJ9yHo,11553
@@ -70,10 +72,10 @@ ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=anbxbIKS7t8iIkJZJH7AkAR18
70
72
  ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py,sha256=K1SbK8q92qYVtiVj0I0GtugsPTkpIpEKv9zakvFV_Sc,8555
71
73
  ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=EoVjI_hplX_Rml3hfRsGmQOihexmizeJqt4SQcET9aA,14925
72
74
  ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=6fjkM-rycZ95L4yfvlr0TN6RlrhfPzxNUYrZaYO_F0A,12013
73
- ai_edge_quantizer/utils/validation_utils.py,sha256=oYw33Sg547AqtGw-choPUJmp9SAKkV46J_ddqSsum2Q,3950
74
- ai_edge_quantizer/utils/validation_utils_test.py,sha256=V_qNDikPD4OPB-siOLQCWNVWTAu87h2IgNYt7teFd-o,2934
75
- ai_edge_quantizer_nightly-0.4.0.dev20250930.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
76
- ai_edge_quantizer_nightly-0.4.0.dev20250930.dist-info/METADATA,sha256=2ScDdoSyEtkSTdXwzLZehwVjwd9rENGfotDREu05Ec4,1508
77
- ai_edge_quantizer_nightly-0.4.0.dev20250930.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
78
- ai_edge_quantizer_nightly-0.4.0.dev20250930.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
79
- ai_edge_quantizer_nightly-0.4.0.dev20250930.dist-info/RECORD,,
75
+ ai_edge_quantizer/utils/validation_utils.py,sha256=yJH9Cvepr_XWn-3Hsh91j7HuC5iLQHAyskyQ48bGNoc,4797
76
+ ai_edge_quantizer/utils/validation_utils_test.py,sha256=1sblJWHLTYTbn1Qi9rwnrREOSXRy5KwHAWSwgI1e_aU,3697
77
+ ai_edge_quantizer_nightly-0.4.0.dev20251001.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
78
+ ai_edge_quantizer_nightly-0.4.0.dev20251001.dist-info/METADATA,sha256=Lz_ZDdMl2fFALbGbeUJwv3sP8ZrgOsYLLgdlD5p5G-I,1508
79
+ ai_edge_quantizer_nightly-0.4.0.dev20251001.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
80
+ ai_edge_quantizer_nightly-0.4.0.dev20251001.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
81
+ ai_edge_quantizer_nightly-0.4.0.dev20251001.dist-info/RECORD,,