ai-edge-quantizer-nightly 0.1.0.dev20250512__py3-none-any.whl → 0.1.0.dev20250514__py3-none-any.whl

This diff reflects the changes between two publicly released versions of this package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (25)
  1. ai_edge_quantizer/algorithm_manager.py +34 -0
  2. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +37 -12
  3. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  4. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +3 -5
  5. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +357 -0
  6. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +265 -0
  7. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +7 -31
  8. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +27 -17
  9. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +93 -38
  10. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +133 -3
  11. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +11 -2
  12. ai_edge_quantizer/algorithms/utils/common_utils.py +21 -8
  13. ai_edge_quantizer/default_policy.py +4 -2
  14. ai_edge_quantizer/params_generator.py +1 -0
  15. ai_edge_quantizer/qtyping.py +34 -1
  16. ai_edge_quantizer/transformation_performer.py +5 -0
  17. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +209 -0
  18. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  19. ai_edge_quantizer/utils/test_utils.py +33 -0
  20. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +1 -0
  21. {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/METADATA +1 -1
  22. {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/RECORD +25 -21
  23. {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/LICENSE +0 -0
  24. {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/WHEEL +0 -0
  25. {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/top_level.txt +0 -0
@@ -16,9 +16,11 @@
 """Uniform quantize in tensor level."""
 
 import dataclasses
-from typing import Optional
+from typing import Optional, Sequence
+import ml_dtypes
 import numpy as np
 from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.utils import tfl_flatbuffer_utils
 
 
 @dataclasses.dataclass(frozen=True)
@@ -120,19 +122,127 @@ def fix_quantization_params_rank(
   )
 
 
+def _get_tensor_shape_for_blockwise(
+    tensor_shape: Sequence[int], quantized_dim: int, block_size: int
+) -> list[int]:
+  """Get the tensor shape for blockwise quantization.
+
+  This function splits the quantize dimension of the tensor into blocks and the
+  dim/blocks. Hence, min/max of the tensor can be calculated for each block
+  using existing functions.
+
+  Args:
+    tensor_shape: The original shape of the tensor.
+    quantized_dim: The dimension to be quantized blockwise.
+    block_size: The size of the block.
+
+  Returns:
+    The new tensor shape for calculating scale and zp for blockwise
+    quantization.
+  """
+  new_shape = []
+  for index, val in enumerate(tensor_shape):
+    if index == quantized_dim:
+      new_shape.append(int(val / block_size))
+      new_shape.append(block_size)
+    else:
+      new_shape.append(val)
+  return new_shape
+
+
+def reshape_data_for_blockwise(
+    tensor_data: np.ndarray, op_name: qtyping.TFLOperationName, block_size: int
+) -> tuple[np.ndarray, int]:
+  """Reshapes data for blockwise quantization.
+
+  Args:
+    tensor_data: The original tensor data.
+    op_name: The name of the TFL op.
+    block_size: The size of the block.
+
+  Returns:
+    A tuple containing the reshaped tensor data and the new reduce dimension.
+  """
+  quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM[
+      op_name
+  ]
+  new_shape = _get_tensor_shape_for_blockwise(
+      tensor_data.shape, quantized_dim, block_size
+  )
+  reshaped_data = tensor_data.reshape(new_shape)
+  return reshaped_data, quantized_dim + 1
+
+
+def _broadcast_scale_zp_for_blockwise(
+    tensor_content: np.ndarray,
+    quant_params: qtyping.UniformQuantParams,
+) -> qtyping.UniformQuantParams:
+  """Broadcasts scale and zp for blockwise quantization.
+
+  Args:
+    tensor_content: The original tensor data.
+    quant_params: The quantization parameters.
+      `quant_params.quantized_dimension` must be specified.
+      `quant_params.block_size` must be specified and positive.
+
+  Returns:
+    The updated quantization parameters with broadcasted scale and zp for
+    correct constant quantization.
+  """
+  if quant_params.quantized_dimension is None:
+    raise ValueError("Quantized dimension must be specified.")
+  if quant_params.block_size is None or quant_params.block_size <= 0:
+    raise ValueError("Block size must be specified and positive.")
+  quantized_dim = quant_params.quantized_dimension
+  expanded_tensor_shape = _get_tensor_shape_for_blockwise(
+      tensor_content.shape, quantized_dim, quant_params.block_size
+  )
+  expanded_scale = np.reshape(
+      np.broadcast_to(
+          np.expand_dims(quant_params.scale, quantized_dim + 1),
+          expanded_tensor_shape,
+      ),
+      tensor_content.shape,
+  )
+  expanded_zp = np.reshape(
+      np.broadcast_to(
+          np.expand_dims(quant_params.zero_point, quantized_dim + 1),
+          expanded_tensor_shape,
+      ),
+      tensor_content.shape,
+  )
+  return qtyping.UniformQuantParams(
+      scale=expanded_scale,
+      zero_point=expanded_zp,
+      num_bits=quant_params.num_bits,
+      symmetric=quant_params.symmetric,
+      quantized_dimension=quantized_dim,
+      block_size=quant_params.block_size,
+  )
+
+
 def uniform_quantize(
     tensor_data: np.ndarray,
     quantization_params: qtyping.UniformQuantParams,
+    is_blockwise: bool = False,
 ):
   """Uniform quantize a tensor.
 
   Args:
     tensor_data: The tensor to be quantized.
     quantization_params: The quantization parameters.
+    is_blockwise: Whether the tensor is blockwise quantized.
 
   Returns:
     The quantized tensor.
   """
+  # The reshaping for blockwise quantization is unique hence we do this here
+  # to avoid unexpected broadcast behavior downstream.
+  if is_blockwise:
+    quantization_params = _broadcast_scale_zp_for_blockwise(
+        tensor_data, quantization_params
+    )
+
   # quant params in flatbuffer is flattened, expand the rank to be the same
   # as the tensor rank to avoid ambiguous broadcasting.
   quantization_params = fix_quantization_params_rank(
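
Note: a minimal NumPy sketch of the blockwise reshape-and-broadcast idea used by the helpers above (not part of the package; the (4, 8) shape, block_size=4, and symmetric int4 range are illustrative assumptions):

```python
import numpy as np

# A (4, 8) weight quantized blockwise along dim 1 with block_size=4 is viewed
# as (4, 2, 4): reducing over the last axis gives one scale per (row, block),
# mirroring _get_tensor_shape_for_blockwise / reshape_data_for_blockwise.
weights = np.linspace(-1.0, 1.0, 32, dtype=np.float32).reshape(4, 8)
block_size = 4
blocked = weights.reshape(4, 8 // block_size, block_size)  # shape (4, 2, 4)
scale = np.abs(blocked).max(axis=2) / 7.0                  # (4, 2), int4 symmetric

# Mirroring _broadcast_scale_zp_for_blockwise: expand the per-block scale back
# to the full weight shape so elementwise quantization broadcasts unambiguously.
expanded_scale = np.broadcast_to(
    np.expand_dims(scale, axis=2), blocked.shape
).reshape(weights.shape)                                   # back to (4, 8)
quantized = np.rint(weights / expanded_scale).astype(np.int8)
```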
@@ -242,15 +352,19 @@ def tensor_zp_scale_from_min_max(
     max_value,
     num_bits: int,
     symmetric: bool,
+    granularity: qtyping.QuantGranularity,
     clipping_values: Optional[np.ndarray] = None,
 ):
   """Get zero point and scale from min and max value.
 
   Args:
-    min_value: The minimum value of the tensor (channel-wise supported).
-    max_value: The maximum value of the tensor (channel-wise supported).
+    min_value: The minimum value of the tensor (channelwise and blockwise
+      supported).
+    max_value: The maximum value of the tensor (channelwise and blockwise
+      supported).
     num_bits: The number of bits of the tensor.
     symmetric: Whether the tensor is symmetric.
+    granularity: The granularity of the tensor.
     clipping_values: Absolute clipping values to apply to the tensor. This will
       clip the tensors to the range [-clipping_values, clipping_values]. This
       should be the same shape as min_value and max_value. If None, no clipping
@@ -267,6 +381,16 @@ def tensor_zp_scale_from_min_max(
   qmin, qmax = get_quantized_range(qtype)
   min_bound = 1e-4  # 1e-6 precision for int8 and 1e-8 for int16.
 
+  if granularity == qtyping.QuantGranularity.BLOCKWISE:
+    # Blockwise quantization uses float16 scale, with 7 bit mantissa,
+    # so the maximum representable value is 65280.
+    float16_max = np.broadcast_to(np.array(65280), min_value.shape)
+    clipping_values = (
+        float16_max
+        if clipping_values is None
+        else np.minimum(clipping_values, float16_max)
+    )
+
   if symmetric:
     bound = np.maximum(np.abs(min_value), np.abs(max_value))
     bound = np.maximum(bound, min_bound)
@@ -292,6 +416,12 @@ def tensor_zp_scale_from_min_max(
     zp = qmin - bound_min / scale
     zp = np.rint(zp)
 
+  if granularity == qtyping.QuantGranularity.BLOCKWISE:
+    # Round the scale values to 7 bit mantissa.
+    scale = (
+        scale.astype(ml_dtypes.bfloat16).astype(np.float16).astype(np.float32)
+    )
+
   # It's safe to cast zp to qtype without clipping because we can infer
   # qmin <= zp <= qmax from bound_min <= 0 <= bound_max.
   zp = assign_quantized_type(zp, qtype)
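
A quick sanity-check sketch (not from the package) of the two blockwise-specific steps above: scales are rounded to bfloat16's 7-bit mantissa before being stored as float16, and 65280 (0xFF00) is the largest such value that remains finite in float16, which is why it is used as the clipping cap:

```python
import ml_dtypes
import numpy as np

# bfloat16 keeps 7 explicit mantissa bits, so 1 + 2**-9 rounds back to 1.0,
# even though float16 alone (10 mantissa bits) would have preserved it.
s = np.float32(1.0 + 2**-9)
rounded = s.astype(ml_dtypes.bfloat16).astype(np.float16).astype(np.float32)
assert rounded == np.float32(1.0)
assert np.float16(s) != np.float16(1.0)

# 65280 = 0xFF00 is exactly representable in both bfloat16 and float16, while
# the float16 max (65504) rounds up to 65536 in bfloat16 and overflows to inf
# when cast back to float16.
assert np.float16(65280.0) == 65280.0
assert np.isinf(np.float16(np.float32(ml_dtypes.bfloat16(65504.0))))
```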
@@ -336,7 +336,11 @@ class TensorUtilsTest(parameterized.TestCase):
     max_val = np.max(self._test_data, keepdims=True)
 
     zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
-        min_val, max_val, num_bits, symmetric
+        min_val,
+        max_val,
+        num_bits,
+        symmetric,
+        qtyping.QuantGranularity.TENSORWISE,
     )
     self.assertEqual(zp.shape, scale.shape)
     max_q = 2**num_bits / 2 - 1
@@ -364,7 +368,12 @@ class TensorUtilsTest(parameterized.TestCase):
     max_val = np.array([[5.0]])
     clipping_values = np.array([4.0])
     zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
-        min_val, max_val, num_bits, symmetric, clipping_values
+        min_val,
+        max_val,
+        num_bits,
+        symmetric,
+        qtyping.QuantGranularity.TENSORWISE,
+        clipping_values,
     )
     expected_scale = clipping_values / quantized_bound
 
@@ -905,23 +905,36 @@ def get_tensor_transformation_params(
   )
 
 
-def get_weight_quantized_dim(op_info: qtyping.OpInfo, tensor_data: np.ndarray):
+def get_weight_quantized_dim(
+    op_info: qtyping.OpInfo,
+    tensor_data: np.ndarray,
+    granularity: qtyping.QuantGranularity,
+):
   """Get the quantized dimension for the weight tensor.
 
   Args:
     op_info: Aggregated information about the op (e.g., quantization config).
     tensor_data: The weight tensor data.
+    granularity: The granularity of the weight tensor.
 
   Returns:
     The quantized dimension for the weight tensor.
   """
-  if op_info.op_name == _TFLOpName.BATCH_MATMUL:
-    quantized_dim = get_bmm_weight_quantized_dim(
-        tensor_data, adj_y=op_info.op.builtinOptions.adjY
-    )
-  else:
-    quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM.get(
-        op_info.op_name, None
+  quantized_dim = None
+  if granularity == qtyping.QuantGranularity.CHANNELWISE:
+    if op_info.op_name == _TFLOpName.BATCH_MATMUL:
+      quantized_dim = get_bmm_weight_quantized_dim(
+          tensor_data, adj_y=op_info.op.builtinOptions.adjY
+      )
+    else:
+      quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM.get(
+          op_info.op_name, None
+      )
+  elif granularity == qtyping.QuantGranularity.BLOCKWISE:
+    quantized_dim = (
+        tfl_flatbuffer_utils.TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM[
+            op_info.op_name
+        ]
     )
   return quantized_dim
 
@@ -183,7 +183,8 @@ DEFAULT_JSON_POLICY = """
     "SELECT_V2",
     "DYNAMIC_UPDATE_SLICE",
     "SELECT_V2",
-    "STABLEHLO_COMPOSITE"
+    "STABLEHLO_COMPOSITE",
+    "PAD"
   ],
   "static_wi8_ai8": [
     "ADD",
@@ -214,7 +215,8 @@ DEFAULT_JSON_POLICY = """
     "SELECT_V2",
     "DYNAMIC_UPDATE_SLICE",
     "SELECT_V2",
-    "STABLEHLO_COMPOSITE"
+    "STABLEHLO_COMPOSITE",
+    "PAD"
   ],
   "static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
   "static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
@@ -508,6 +508,7 @@ def _compatible_tensor_params(
   float_source_transformations = [
       _QuantTrans.ADD_QUANTIZE,
       _QuantTrans.NO_QUANTIZE,
+      _QuantTrans.INSERT_HADAMARD_ROTATION,
   ]
   quantized_source_transformations = [
       _QuantTrans.QUANTIZE_TENSOR,
@@ -20,7 +20,7 @@ from collections.abc import MutableMapping
 import copy
 import dataclasses
 import enum
-from typing import Any, Optional, Union, Callable
+from typing import Any, Callable, Optional, Union
 
 import numpy as np
 from typing_extensions import TypeAlias
@@ -62,6 +62,7 @@ class TFLOperationName(str, enum.Enum):
   SELECT_V2 = 'SELECT_V2'
   DYNAMIC_UPDATE_SLICE = 'DYNAMIC_UPDATE_SLICE'
   STABLEHLO_COMPOSITE = 'STABLEHLO_COMPOSITE'
+  PAD = 'PAD'
 
 
 class QuantizeMode(enum.Enum):
@@ -113,6 +114,8 @@ class QuantTransformation(enum.Enum):
   DUPLICATE_BUFFER = 5
   # Duplicate the tensor.
   DUPLICATE_TENSOR = 6
+  # Insert the aeq.hadamard_rotation op.
+  INSERT_HADAMARD_ROTATION = 7
 
 
 @dataclasses.dataclass(frozen=True)
@@ -128,8 +131,35 @@ class UniformQuantParams:
     quantized_data: The quantized data.
     block_size: The block size for blockwise quantization, block_size=0 meaning
       no blockwise quantization.
+    hadamard: The Hadamard rotation parameters, if set.
   """
 
+  class HadamardRotationParams:
+    """Parameters for the Hadamard rotation.
+
+    Attributes:
+      random_binary_vector: The random binary vector for the Hadamard rotation.
+        TODO(b/415392354): Randomization is an experimental feature that's
+        currently not implemented yet hence this is always 1. We will add
+        support or remove in the future.
+      hadamard_size: The size of the Hadamard matrix.
+    """
+
+    random_binary_vector: np.ndarray
+    hadamard_size: int
+
+    def __init__(self, random_binary_vector: np.ndarray, hadamard_size: int):
+      self.random_binary_vector = random_binary_vector
+      self.hadamard_size = hadamard_size
+
+    def __eq__(self, other):
+      if other.__class__ is not self.__class__:
+        return NotImplemented
+      return (
+          np.array_equal(self.random_binary_vector, other.random_binary_vector)
+          and self.hadamard_size == other.hadamard_size
+      )
+
   num_bits: int
   quantized_dimension: Optional[int]
   scale: np.ndarray
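
For reference, a hedged usage sketch of the nested params class added above (values are illustrative; the vector is all ones per the TODO, since randomization is not implemented yet):

```python
import numpy as np
from ai_edge_quantizer import qtyping

# Two instances built from the same vector and size compare equal via the
# value-based __eq__ defined above.
a = qtyping.UniformQuantParams.HadamardRotationParams(
    random_binary_vector=np.ones(4, dtype=np.int8),
    hadamard_size=4,
)
b = qtyping.UniformQuantParams.HadamardRotationParams(
    random_binary_vector=np.ones(4, dtype=np.int8),
    hadamard_size=4,
)
assert a == b
```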
@@ -137,6 +167,7 @@ class UniformQuantParams:
   symmetric: bool = True
   quantized_data: Optional[np.ndarray] = None
   block_size: int = 0
+  hadamard: Optional[HadamardRotationParams] = None
 
   @classmethod
   def from_tfl_tensor_details(cls, tensor_detail) -> 'UniformQuantParams':
@@ -180,6 +211,7 @@ class UniformQuantParams:
         and self.symmetric == other.symmetric
         and _compare_array_or_none(self.quantized_data, other.quantized_data)
         and self.block_size == other.block_size
+        and self.hadamard == other.hadamard
     )
 
 
@@ -492,6 +524,7 @@ class IOOperator:
   outputs: list[int]
   op_key: TFLOperationName
 
+
 # The function signature for `get_tensor_quant_params_fn`.
 GetTensorQuantParamsFuncSignature = Callable[
     [
@@ -25,6 +25,7 @@ from ai_edge_quantizer.transformations import dequant_insert
 from ai_edge_quantizer.transformations import duplicate_buffer
 from ai_edge_quantizer.transformations import duplicate_tensor
 from ai_edge_quantizer.transformations import emulated_subchannel
+from ai_edge_quantizer.transformations import insert_hadamard_rotation
 from ai_edge_quantizer.transformations import quant_insert
 from ai_edge_quantizer.transformations import quantize_tensor
 from ai_edge_quantizer.transformations import transformation_utils
@@ -80,6 +81,9 @@ class TransformationPerformer:
         qtyping.QuantTransformation.DUPLICATE_TENSOR: (
             duplicate_tensor.duplicate_tensor
         ),
+        qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION: (
+            insert_hadamard_rotation.insert_hadamard_rotation
+        ),
     }
     # transformations are seprated in two categories:
     # op_insertion_transformations are transformations that only insert ops
@@ -91,6 +95,7 @@ class TransformationPerformer:
         qtyping.QuantTransformation.ADD_QUANTIZE,
         qtyping.QuantTransformation.DUPLICATE_BUFFER,
         qtyping.QuantTransformation.DUPLICATE_TENSOR,
+        qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
     ])
     self._op_replacement_transformations = set(
         [qtyping.QuantTransformation.EMULATED_SUBCHANNEL]
@@ -0,0 +1,209 @@
+# Copyright 2024 The AI Edge Quantizer Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Hadamard rotation pattern transformation."""
+
+from flatbuffers import flexbuffers
+import numpy as np
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.transformations import transformation_utils
+from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
+
+
+def _to_flexbuffer(
+    hadamard_size: int,
+    random_binary_vector: list[np.int8],
+) -> bytes:
+  """Converts hadamard_size to flexbuffer."""
+  fbb = flexbuffers.Builder()
+  with fbb.Map():
+    fbb.Int('hadamard_size', hadamard_size)
+    fbb.VectorFromElements('random_binary_vector', random_binary_vector)
+  return fbb.Finish()
+
+
+def _is_producer_embedding_lookup(
+    transformation: transformation_utils.TransformationInput,
+) -> bool:
+  """Checks if the tensor's producer is an embedding lookup op."""
+  if transformation.producer == -1:
+    return False
+  else:
+    return (
+        transformation.op_codes[
+            transformation.subgraph.operators[
+                transformation.producer
+            ].opcodeIndex
+        ].builtinCode
+        == schema_py_generated.BuiltinOperator.EMBEDDING_LOOKUP
+    )
+
+
+def _is_fully_connected(
+    transformation: transformation_utils.TransformationInput, op_id: int
+) -> bool:
+  """Checks if the any of the tensor's consumers is a fully connected op."""
+  return (
+      transformation.op_codes[
+          transformation.subgraph.operators[op_id].opcodeIndex
+      ].builtinCode
+      == schema_py_generated.BuiltinOperator.FULLY_CONNECTED
+  )
+
+
+def _update_embedding_lookup_consumers(
+    transformation: transformation_utils.TransformationInput,
+    new_tensor_id: int,
+) -> bool:
+  """Updates the consumers of the embedding lookup op to use the new tensor.
+
+  Args:
+    transformation: The transformation input to update the consumers of.
+    new_tensor_id: The new tensor id to use as the input to the embedding lookup
+      consumers.
+  """
+  for consumer in transformation.consumers:
+    # If the consumer is a graph output and not an op, we can ignore it here
+    # since the graph output will be updated later.
+    if consumer == -1:
+      continue
+    consumer_op = transformation.subgraph.operators[consumer]
+    # Find the input that was attached to the insertion point, and replace it
+    # with the new tensor.
+    for i in range(len(consumer_op.inputs)):
+      if consumer_op.inputs[i] == transformation.tensor_id:
+        consumer_op.inputs[i] = new_tensor_id
+
+
+def _update_fully_connected_consumers(
+    transformation: transformation_utils.TransformationInput,
+    new_tensor_id: int,
+) -> bool:
+  """Updates the fully connected op(s) to use the new tensor.
+
+  Since the new tensor is inserted to the fully_connected's input, we need to
+  scan each consumer (in case of multiple fully_connected ops), and update
+  the input tensor to the new tensor.
+
+  Args:
+    transformation: The transformation input to update the consumers of.
+    new_tensor_id: The new tensor id to use as the input to the fully connected
+      consumers.
+
+  Returns:
+    True if the fully connected op(s) were updated to use the new tensor.
+  """
+  updated = False
+  for consumer in transformation.consumers:
+    if _is_fully_connected(transformation, consumer):
+      transformation.subgraph.operators[consumer].inputs[0] = new_tensor_id
+      updated = True
+  return updated
+
+
+def insert_hadamard_rotation(
+    transformation_input: transformation_utils.TransformationInput,
+) -> qtyping.TransformationInfo:
+  """Inserts a custom aeq.hadamard_rotation op on this tensor.
+
+  This function works for float32 tensors only.
+
+  Args:
+    transformation_input: The transformation input to insert the custom op on.
+
+  Returns:
+    The transformation info of the inserted custom op.
+
+  Raises:
+    ValueError: If the transformation input is not a uniform quantization
+      transformation.
+    ValueError: If the Hadamard quantization params are not set.
+    ValueError: If the tensor is not a float32 tensor.
+    ValueError: If no supported ops were found as the tensor's producer or
+      consumers.
+  """
+  if not isinstance(
+      transformation_input.quant_params, qtyping.UniformQuantParams
+  ):
+    raise ValueError('Hadamard rotation supports uniform quantization only')
+
+  if transformation_input.quant_params.hadamard is None:
+    raise ValueError(
+        'Hadamard rotation quantization params are not set but op insertion is'
+        ' requested.'
+    )
+
+  tensor = transformation_input.subgraph.tensors[transformation_input.tensor_id]
+  if tensor.type != schema_py_generated.TensorType.FLOAT32:
+    raise ValueError(
+        'The Hadamard rotation op supports float32 tensors only. Got'
+        f' {tensor.type} tensor.'
+    )
+
+  # Create new custom op with the current tensor as input and a new activation
+  # tensor as output.
+  custom_op_code_idx = transformation_utils.add_op_code(
+      schema_py_generated.BuiltinOperator.CUSTOM,
+      transformation_input.op_codes,
+      'aeq.hadamard_rotation',
+  )
+  custom_op = schema_py_generated.OperatorT()
+  custom_op.opcodeIndex = custom_op_code_idx
+  custom_op.inputs = [transformation_input.tensor_id]
+  custom_op.customOptions = _to_flexbuffer(
+      transformation_input.quant_params.hadamard.hadamard_size,
+      transformation_input.quant_params.hadamard.random_binary_vector.tolist(),
+  )
+  new_tensor_id = transformation_utils.add_new_activation_tensor(
+      tensor.name + b'_rotated',
+      tensor.shapeSignature
+      if tensor.shapeSignature is not None
+      else tensor.shape,
+      schema_py_generated.TensorType.FLOAT32,
+      transformation_input.subgraph,
+  )
+  custom_op.outputs = [new_tensor_id]
+
+  # Update the users of this tensor to use the new tensor.
+  if _is_producer_embedding_lookup(transformation_input):
+    _update_embedding_lookup_consumers(transformation_input, new_tensor_id)
+  elif not _update_fully_connected_consumers(
+      transformation_input, new_tensor_id
+  ):
+    raise ValueError(
+        'The Hadamard rotation op supports embedding lookup and fully connected'
+        ' ops only, but no such ops were found.'
+    )
+
+  # If the tensor is a graph output, we need to replace the tensor with the
+  # new tensor.
+  for i, output in enumerate(transformation_input.subgraph.outputs):
+    if output == transformation_input.tensor_id:
+      transformation_input.subgraph.outputs[i] = new_tensor_id
+
+  # Find the actual insertion point. The insertion point should be after the
+  # producer op and before the first consumer op. The max() operation ensures
+  # that we're not using -1 as the insertion point.
+  first_consumer_id = min(transformation_input.consumers)
+  op_id = max(transformation_input.producer + 1, first_consumer_id)
+
+  # Insert the custom op.
+  transformation_input.subgraph.operators.insert(op_id, custom_op)
+
+  return qtyping.TransformationInfo(
+      op_id=op_id,
+      num_ops_added=1,
+      output_tensor_id=new_tensor_id,
+  )
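
For context, a hedged NumPy sketch (not from this diff; the runtime kernel for `aeq.hadamard_rotation` is not part of this package diff) of what a size-4 Hadamard rotation plausibly does to one block of activations. The usual motivation is that the orthonormal rotation spreads outliers across the block before uniform quantization, and it can be undone or folded into the following fully-connected weights:

```python
import numpy as np

def hadamard(n: int) -> np.ndarray:
  """Sylvester-construction Hadamard matrix for n a power of two."""
  h = np.array([[1.0]])
  while h.shape[0] < n:
    h = np.block([[h, h], [h, -h]])
  return h

hadamard_size = 4
h = hadamard(hadamard_size) / np.sqrt(hadamard_size)  # orthonormal: h @ h.T == I
x = np.array([1.0, -2.0, 3.0, 0.5])                   # one block of activations

rotated = x @ h
# The rotation is invertible and norm-preserving.
np.testing.assert_allclose(rotated @ h.T, x, atol=1e-6)
np.testing.assert_allclose(np.linalg.norm(rotated), np.linalg.norm(x), rtol=1e-6)
```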