ai-edge-quantizer-nightly 0.4.0.dev20250930__py3-none-any.whl → 0.4.0.dev20251002__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22)
  1. ai_edge_quantizer/algorithm_manager.py +40 -3
  2. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +28 -0
  3. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +77 -8
  4. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +69 -4
  5. ai_edge_quantizer/default_policy.py +4 -2
  6. ai_edge_quantizer/params_generator.py +1 -0
  7. ai_edge_quantizer/qtyping.py +5 -0
  8. ai_edge_quantizer/transformation_performer.py +5 -0
  9. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +291 -0
  10. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  11. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +8 -31
  12. ai_edge_quantizer/transformations/quantize_tensor.py +11 -31
  13. ai_edge_quantizer/transformations/transformation_utils.py +66 -0
  14. ai_edge_quantizer/utils/constrained_ops_utils_test.py +1 -1
  15. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +1 -0
  16. ai_edge_quantizer/utils/validation_utils.py +29 -0
  17. ai_edge_quantizer/utils/validation_utils_test.py +24 -0
  18. {ai_edge_quantizer_nightly-0.4.0.dev20250930.dist-info → ai_edge_quantizer_nightly-0.4.0.dev20251002.dist-info}/METADATA +1 -1
  19. {ai_edge_quantizer_nightly-0.4.0.dev20250930.dist-info → ai_edge_quantizer_nightly-0.4.0.dev20251002.dist-info}/RECORD +22 -20
  20. {ai_edge_quantizer_nightly-0.4.0.dev20250930.dist-info → ai_edge_quantizer_nightly-0.4.0.dev20251002.dist-info}/LICENSE +0 -0
  21. {ai_edge_quantizer_nightly-0.4.0.dev20250930.dist-info → ai_edge_quantizer_nightly-0.4.0.dev20251002.dist-info}/WHEEL +0 -0
  22. {ai_edge_quantizer_nightly-0.4.0.dev20250930.dist-info → ai_edge_quantizer_nightly-0.4.0.dev20251002.dist-info}/top_level.txt +0 -0
@@ -61,7 +61,8 @@ class AlgorithmName(str, enum.Enum):
61
61
  FLOAT_CASTING = float_casting.ALGORITHM_KEY
62
62
  DEQUANTIZED_WEIGHT_RECOVERY = dequantized_weight_recovery.ALGORITHM_KEY
63
63
  OCTAV = octav.ALGORITHM_KEY
64
- HADAMARD_ROTATION = hadamard_rotation.ALGORITHM_KEY
64
+ HADAMARD_ROTATION = hadamard_rotation.CUSTOM_OP_ALGORITHM_KEY
65
+ DECOMPOSED_HADAMARD_ROTATION = hadamard_rotation.DECOMPOSED_ALGORITHM_KEY
65
66
  MSE = mse.ALGORITHM_KEY
66
67
 
67
68
 
@@ -130,6 +131,7 @@ MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT = {
130
131
  _TFLOpName.REDUCE_MIN: common_quantize.materialize_reduce_min,
131
132
  _TFLOpName.EQUAL: common_quantize.materialize_equal,
132
133
  _TFLOpName.NOT_EQUAL: common_quantize.materialize_not_equal,
134
+ _TFLOpName.MIRROR_PAD: common_quantize.materialize_mirror_pad,
133
135
  }
134
136
  for op_name, materialize_func in MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT.items():
135
137
  register_quantized_op(
@@ -283,6 +285,7 @@ _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
283
285
  _TFLOpName.REDUCE_MIN: common_quantize.materialize_reduce_min,
284
286
  _TFLOpName.EQUAL: common_quantize.materialize_equal,
285
287
  _TFLOpName.NOT_EQUAL: common_quantize.materialize_not_equal,
288
+ _TFLOpName.MIRROR_PAD: common_quantize.materialize_mirror_pad,
286
289
  })
287
290
 
288
291
  for op_name, materialize_func in _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT.items():
@@ -311,8 +314,12 @@ register_config_check_policy_func(
311
314
 
312
315
  # Register specialized hadamard rotation materialize functions.
313
316
  _HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
314
- _TFLOpName.FULLY_CONNECTED: hadamard_rotation.materialize_fully_connected,
315
- _TFLOpName.EMBEDDING_LOOKUP: hadamard_rotation.materialize_embedding_lookup,
317
+ _TFLOpName.FULLY_CONNECTED: (
318
+ hadamard_rotation.materialize_fully_connected_custom_op
319
+ ),
320
+ _TFLOpName.EMBEDDING_LOOKUP: (
321
+ hadamard_rotation.materialize_embedding_lookup_custom_op
322
+ ),
316
323
  })
317
324
  for (
318
325
  op_name,
@@ -326,6 +333,36 @@ for (
326
333
  materialize_func=materialize_func,
327
334
  )
328
335
 
336
+ register_op_quant_config_validation_func(
337
+ AlgorithmName.DECOMPOSED_HADAMARD_ROTATION,
338
+ common_quantize.check_op_quantization_config,
339
+ )
340
+
341
+ register_config_check_policy_func(
342
+ AlgorithmName.DECOMPOSED_HADAMARD_ROTATION,
343
+ default_policy.DEFAULT_CONFIG_CHECK_POLICY,
344
+ )
345
+
346
+ _DECOMPOSED_HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
347
+ _TFLOpName.FULLY_CONNECTED: (
348
+ hadamard_rotation.materialize_fully_connected_decomposed
349
+ ),
350
+ _TFLOpName.EMBEDDING_LOOKUP: (
351
+ hadamard_rotation.materialize_embedding_lookup_decomposed
352
+ ),
353
+ })
354
+ for (
355
+ op_name,
356
+ materialize_func,
357
+ ) in _DECOMPOSED_HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT.items():
358
+ register_quantized_op(
359
+ AlgorithmName.DECOMPOSED_HADAMARD_ROTATION,
360
+ op_name,
361
+ naive_min_max_quantize.init_qsvs,
362
+ calibration_func=naive_min_max_quantize.min_max_calibrate,
363
+ materialize_func=materialize_func,
364
+ )
365
+
329
366
 
330
367
  # Register the MSE algorithm.
331
368
  register_op_quant_config_validation_func(
@@ -748,6 +748,34 @@ def materialize_padv2(
748
748
  )
749
749
 
750
750
 
751
+ def materialize_mirror_pad(
752
+ get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
753
+ op_info: qtyping.OpInfo,
754
+ graph_info: qtyping.GraphInfo,
755
+ tensor_name_to_qsv: dict[str, Any],
756
+ ) -> list[qtyping.TensorTransformationParams]:
757
+ """Materialize tensors in tfl.mirror_pad.
758
+
759
+ Args:
760
+ get_tensor_quant_params_fn: Function to get quantization parameters for the
761
+ tensor.
762
+ op_info: Aggregated information about the op (e.g., quantization config).
763
+ graph_info: Graph information needed to perform quantization for the op.
764
+ tensor_name_to_qsv: A map of tensor name to quantization parameters.
765
+
766
+ Returns:
767
+ A list of `qtyping.TensorTransformationParams` for the tensors in the op.
768
+ """
769
+ return common_utils.materialize_standard_op(
770
+ op_info,
771
+ graph_info,
772
+ tensor_name_to_qsv,
773
+ get_tensor_quant_params_fn,
774
+ constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
775
+ inputs_to_ignore=[1], # Paddings tensor does not need to be quantized.
776
+ )
777
+
778
+
751
779
  def materialize_squared_difference(
752
780
  get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
753
781
  op_info: qtyping.OpInfo,
@@ -23,16 +23,17 @@ from ai_edge_quantizer.algorithms.utils import common_utils
23
23
  from ai_edge_quantizer.utils import tfl_flatbuffer_utils
24
24
 
25
25
 
26
- ALGORITHM_KEY = "HADAMARD_ROTATION"
26
+ CUSTOM_OP_ALGORITHM_KEY = "HADAMARD_ROTATION"
27
+ DECOMPOSED_ALGORITHM_KEY = "DECOMPOSED_HADAMARD_ROTATION"
27
28
 
28
29
 
29
30
  def _make_hadamard_matrix(size: int) -> np.ndarray:
30
31
  """Generates a Hadamard matrix of the given size.
31
32
 
32
33
  Args:
33
- size: The size of the Hadamard matrix. Must be a power of 2. This
34
- represents a single dimension. E.g. if size is 4, then the Hadamard matrix
35
- is a 4x4 matrix.
34
+ size: The size of the Hadamard matrix. Must be a power of 2. This represents
35
+ a single dimension. E.g. if size is 4, then the Hadamard matrix is a 4x4
36
+ matrix.
36
37
 
37
38
  Returns:
38
39
  The Hadamard matrix.
@@ -157,9 +158,10 @@ def get_tensor_quant_params(
157
158
  )
158
159
 
159
160
 
160
- def materialize_fully_connected(
161
+ def _materialize_fully_connected(
161
162
  op_info: qtyping.OpInfo,
162
163
  graph_info: qtyping.GraphInfo,
164
+ is_decomposed: bool = False,
163
165
  tensor_name_to_qsv: Optional[dict[str, Any]] = None, # pylint: disable=unused-argument
164
166
  ) -> list[qtyping.TensorTransformationParams]:
165
167
  """Materialize the fully_connected op.
@@ -167,12 +169,20 @@ def materialize_fully_connected(
167
169
  Args:
168
170
  op_info: Aggregated information about the op (e.g., quantization config).
169
171
  graph_info: Graph information needed to perform quantization for the op.
172
+ is_decomposed: Whether to use decomposed Hadamard rotation ops or a custom
173
+ op.
170
174
  tensor_name_to_qsv: A map of tensor name to quantization parameters.
171
175
 
172
176
  Returns:
173
177
  Quantization configuration for the tensors associated with the op (e.g.,
174
178
  weights, bias).
175
179
  """
180
+ if op_info.op_quant_config.weight_tensor_config is None:
181
+ raise ValueError(
182
+ "Weight tensor quantization config is not provided for Hadamard"
183
+ " Rotation quantization."
184
+ )
185
+
176
186
  op_tensor_params = []
177
187
 
178
188
  # Materialize weight.
@@ -209,7 +219,9 @@ def materialize_fully_connected(
209
219
  op_info.op.inputs[input_tensor_index]
210
220
  ]
211
221
  transformations = [
212
- qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
222
+ qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION
223
+ if is_decomposed
224
+ else qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
213
225
  ]
214
226
  op2tensor_params = qtyping.OpToTensorParams(
215
227
  subgraph_op_id=op_info.subgraph_op_index,
@@ -258,16 +270,45 @@ def materialize_fully_connected(
258
270
  return op_tensor_params
259
271
 
260
272
 
261
- def materialize_embedding_lookup(
273
+ def materialize_fully_connected_custom_op(
274
+ op_info: qtyping.OpInfo,
275
+ graph_info: qtyping.GraphInfo,
276
+ tensor_name_to_qsv: Optional[dict[str, Any]] = None, # pylint: disable=unused-argument
277
+ ) -> list[qtyping.TensorTransformationParams]:
278
+ return _materialize_fully_connected(
279
+ op_info,
280
+ graph_info,
281
+ is_decomposed=False,
282
+ tensor_name_to_qsv=tensor_name_to_qsv,
283
+ )
284
+
285
+
286
+ def materialize_fully_connected_decomposed(
262
287
  op_info: qtyping.OpInfo,
263
288
  graph_info: qtyping.GraphInfo,
264
289
  tensor_name_to_qsv: Optional[dict[str, Any]] = None, # pylint: disable=unused-argument
290
+ ) -> list[qtyping.TensorTransformationParams]:
291
+ return _materialize_fully_connected(
292
+ op_info,
293
+ graph_info,
294
+ is_decomposed=True,
295
+ tensor_name_to_qsv=tensor_name_to_qsv,
296
+ )
297
+
298
+
299
+ def _materialize_embedding_lookup(
300
+ op_info: qtyping.OpInfo,
301
+ graph_info: qtyping.GraphInfo,
302
+ is_decomposed: bool = False,
303
+ tensor_name_to_qsv: Optional[dict[str, Any]] = None, # pylint: disable=unused-argument
265
304
  ) -> list[qtyping.TensorTransformationParams]:
266
305
  """Materialize the embedding_lookup op.
267
306
 
268
307
  Args:
269
308
  op_info: Aggregated information about the op (e.g., quantization config).
270
309
  graph_info: Graph information needed to perform quantization for the op.
310
+ is_decomposed: Whether to use decomposed Hadamard rotation ops or a custom
311
+ op.
271
312
  tensor_name_to_qsv: A map of tensor name to quantization parameters.
272
313
 
273
314
  Returns:
@@ -329,7 +370,9 @@ def materialize_embedding_lookup(
329
370
  op_info.op.outputs[output_tensor_index]
330
371
  ]
331
372
  transformations = [
332
- qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
373
+ qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION
374
+ if is_decomposed
375
+ else qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
333
376
  ]
334
377
  op2tensor_params = qtyping.OpToTensorParams(
335
378
  subgraph_op_id=op_info.subgraph_op_index,
@@ -343,3 +386,29 @@ def materialize_embedding_lookup(
343
386
  op_tensor_params.append(output_transformation_params)
344
387
 
345
388
  return op_tensor_params
389
+
390
+
391
+ def materialize_embedding_lookup_custom_op(
392
+ op_info: qtyping.OpInfo,
393
+ graph_info: qtyping.GraphInfo,
394
+ tensor_name_to_qsv: Optional[dict[str, Any]] = None, # pylint: disable=unused-argument
395
+ ) -> list[qtyping.TensorTransformationParams]:
396
+ return _materialize_embedding_lookup(
397
+ op_info,
398
+ graph_info,
399
+ is_decomposed=False,
400
+ tensor_name_to_qsv=tensor_name_to_qsv,
401
+ )
402
+
403
+
404
+ def materialize_embedding_lookup_decomposed(
405
+ op_info: qtyping.OpInfo,
406
+ graph_info: qtyping.GraphInfo,
407
+ tensor_name_to_qsv: Optional[dict[str, Any]] = None, # pylint: disable=unused-argument
408
+ ) -> list[qtyping.TensorTransformationParams]:
409
+ return _materialize_embedding_lookup(
410
+ op_info,
411
+ graph_info,
412
+ is_decomposed=True,
413
+ tensor_name_to_qsv=tensor_name_to_qsv,
414
+ )
@@ -63,7 +63,7 @@ class HadamardRotationFullyConnectedTest(parameterized.TestCase):
63
63
  )
64
64
 
65
65
  def test_materialize_fully_connected_basic(self):
66
- params = hadamard_rotation.materialize_fully_connected(
66
+ params = hadamard_rotation.materialize_fully_connected_custom_op(
67
67
  self._op_info, self._graph_info, self._tensor_name_to_qsv
68
68
  )
69
69
  fc_input = params[0]
@@ -111,7 +111,7 @@ class HadamardRotationFullyConnectedTest(parameterized.TestCase):
111
111
  ),
112
112
  ),
113
113
  )
114
- params = hadamard_rotation.materialize_fully_connected(
114
+ params = hadamard_rotation.materialize_fully_connected_custom_op(
115
115
  self._op_info, self._graph_info, self._tensor_name_to_qsv
116
116
  )
117
117
  self.assertLen(params, 4)
@@ -152,7 +152,7 @@ class HadamardRotationFullyConnectedTest(parameterized.TestCase):
152
152
  ),
153
153
  ),
154
154
  )
155
- params = hadamard_rotation.materialize_fully_connected(
155
+ params = hadamard_rotation.materialize_fully_connected_custom_op(
156
156
  self._op_info, self._graph_info, self._tensor_name_to_qsv
157
157
  )
158
158
  self.assertLen(params, 4)
@@ -179,6 +179,34 @@ class HadamardRotationFullyConnectedTest(parameterized.TestCase):
179
179
  ):
180
180
  self.assertEqual(weight.consumers[0].parameters.quantized_dimension, 1)
181
181
 
182
+ def test_materialize_fully_connected_decomposed(self):
183
+ params = hadamard_rotation.materialize_fully_connected_decomposed(
184
+ self._op_info, self._graph_info, self._tensor_name_to_qsv
185
+ )
186
+ fc_input = params[0]
187
+ weight = params[1]
188
+ bias = params[2]
189
+ output = params[3]
190
+
191
+ self.assertLen(params, 4)
192
+ self.assertEqual(
193
+ fc_input.consumers[0].transformations,
194
+ [qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION],
195
+ )
196
+ self.assertEqual(
197
+ weight.consumers[0].transformations,
198
+ [qtyping.QuantTransformation.QUANTIZE_TENSOR],
199
+ )
200
+ self.assertEqual(
201
+ bias.consumers[0].transformations,
202
+ [qtyping.QuantTransformation.NO_QUANTIZE],
203
+ )
204
+ if output.producer is not None:
205
+ self.assertEqual(
206
+ output.producer.transformations,
207
+ [qtyping.QuantTransformation.NO_QUANTIZE],
208
+ )
209
+
182
210
  def test_get_tensor_quant_params_basic(self):
183
211
  input_tensor = self._subgraph.tensors[self._fc_op.inputs[1]]
184
212
  buffer = self._graph_info.buffers[self._fc_buffer_id]
@@ -344,7 +372,7 @@ class HadamardRotationEmbeddingLookupTest(parameterized.TestCase):
344
372
  ),
345
373
  ),
346
374
  )
347
- params = hadamard_rotation.materialize_embedding_lookup(
375
+ params = hadamard_rotation.materialize_embedding_lookup_custom_op(
348
376
  op_info, self._graph_info, self._tensor_name_to_qsv
349
377
  )
350
378
  self.assertLen(params, 3)
@@ -371,6 +399,43 @@ class HadamardRotationEmbeddingLookupTest(parameterized.TestCase):
371
399
  [qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION],
372
400
  )
373
401
 
402
+ def test_materialize_embedding_lookup_decomposed(self):
403
+ subgraph = self._test_model.subgraphs[0]
404
+ embedding_subgraph_op_index = 0
405
+ embedding_op = subgraph.operators[embedding_subgraph_op_index]
406
+ op_info = qtyping.OpInfo(
407
+ op=embedding_op,
408
+ op_name=_TFLOpName.EMBEDDING_LOOKUP,
409
+ subgraph_op_index=embedding_subgraph_op_index,
410
+ op_quant_config=qtyping.OpQuantizationConfig(
411
+ weight_tensor_config=_TensorQuantConfig(
412
+ num_bits=8,
413
+ symmetric=True,
414
+ granularity=qtyping.QuantGranularity.CHANNELWISE,
415
+ ),
416
+ ),
417
+ )
418
+ params = hadamard_rotation.materialize_embedding_lookup_decomposed(
419
+ op_info, self._graph_info, self._tensor_name_to_qsv
420
+ )
421
+ self.assertLen(params, 3)
422
+ lookup = params[0]
423
+ value = params[1]
424
+ output = params[2]
425
+ self.assertEqual(
426
+ lookup.consumers[0].transformations,
427
+ [qtyping.QuantTransformation.NO_QUANTIZE],
428
+ )
429
+ self.assertEqual(
430
+ value.consumers[0].transformations,
431
+ [qtyping.QuantTransformation.QUANTIZE_TENSOR],
432
+ )
433
+ if output.producer is not None:
434
+ self.assertEqual(
435
+ output.producer.transformations,
436
+ [qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION],
437
+ )
438
+
374
439
 
375
440
  if __name__ == "__main__":
376
441
  googletest.main()
@@ -199,7 +199,8 @@ DEFAULT_JSON_POLICY = """
199
199
  "PADV2",
200
200
  "REDUCE_MIN",
201
201
  "EQUAL",
202
- "NOT_EQUAL"
202
+ "NOT_EQUAL",
203
+ "MIRROR_PAD"
203
204
  ],
204
205
  "static_wi8_ai8": [
205
206
  "ADD",
@@ -248,7 +249,8 @@ DEFAULT_JSON_POLICY = """
248
249
  "PADV2",
249
250
  "REDUCE_MIN",
250
251
  "EQUAL",
251
- "NOT_EQUAL"
252
+ "NOT_EQUAL",
253
+ "MIRROR_PAD"
252
254
  ],
253
255
  "static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
254
256
  "static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
@@ -510,6 +510,7 @@ def _compatible_tensor_params(
510
510
  _QuantTrans.ADD_QUANTIZE,
511
511
  _QuantTrans.NO_QUANTIZE,
512
512
  _QuantTrans.INSERT_HADAMARD_ROTATION,
513
+ _QuantTrans.INSERT_DECOMPOSED_HADAMARD_ROTATION,
513
514
  ]
514
515
  quantized_source_transformations = [
515
516
  _QuantTrans.QUANTIZE_TENSOR,
@@ -80,6 +80,7 @@ class TFLOperationName(str, enum.Enum):
80
80
  REDUCE_MIN = 'REDUCE_MIN'
81
81
  EQUAL = 'EQUAL'
82
82
  NOT_EQUAL = 'NOT_EQUAL'
83
+ MIRROR_PAD = 'MIRROR_PAD'
83
84
 
84
85
 
85
86
  class QuantizeMode(enum.Enum):
@@ -133,6 +134,9 @@ class QuantTransformation(enum.Enum):
133
134
  DUPLICATE_TENSOR = 6
134
135
  # Insert the aeq.hadamard_rotation op.
135
136
  INSERT_HADAMARD_ROTATION = 7
137
+ # Insert decomposed Hadamard rotation ops. This expresses the Hadamard
138
+ # rotation as matrix multiplication with Hadamard matrices.
139
+ INSERT_DECOMPOSED_HADAMARD_ROTATION = 8
136
140
 
137
141
 
138
142
  @dataclasses.dataclass(frozen=True)
@@ -305,6 +309,7 @@ class TensorQuantizationConfig:
305
309
  quantization.
306
310
  dtype: The data type of the tensor.
307
311
  block_size: The block size for blockwise quantization, ignored otherwise.
312
+ algorithm_key: The algorithm key to use for quantization.
308
313
  """
309
314
 
310
315
  num_bits: int
@@ -24,6 +24,7 @@ from ai_edge_quantizer import qtyping
24
24
  from ai_edge_quantizer.transformations import dequant_insert
25
25
  from ai_edge_quantizer.transformations import duplicate_buffer
26
26
  from ai_edge_quantizer.transformations import duplicate_tensor
27
+ from ai_edge_quantizer.transformations import insert_decomposed_hadamard_rotation
27
28
  from ai_edge_quantizer.transformations import insert_hadamard_rotation
28
29
  from ai_edge_quantizer.transformations import quant_insert
29
30
  from ai_edge_quantizer.transformations import quantize_tensor
@@ -83,6 +84,9 @@ class TransformationPerformer:
83
84
  qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION: (
84
85
  insert_hadamard_rotation.insert_hadamard_rotation
85
86
  ),
87
+ qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION: (
88
+ insert_decomposed_hadamard_rotation.insert_decomposed_hadamard_rotation
89
+ ),
86
90
  }
87
91
  # transformations are seprated in two categories:
88
92
  # op_insertion_transformations are transformations that only insert ops
@@ -95,6 +99,7 @@ class TransformationPerformer:
95
99
  qtyping.QuantTransformation.DUPLICATE_BUFFER,
96
100
  qtyping.QuantTransformation.DUPLICATE_TENSOR,
97
101
  qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
102
+ qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION,
98
103
  ])
99
104
  self._op_replacement_transformations = set(
100
105
  [qtyping.QuantTransformation.EMULATED_SUBCHANNEL]