ai-edge-quantizer-nightly 0.1.0.dev20250512__py3-none-any.whl → 0.1.0.dev20250514__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. ai_edge_quantizer/algorithm_manager.py +34 -0
  2. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +37 -12
  3. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  4. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +3 -5
  5. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +357 -0
  6. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +265 -0
  7. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +7 -31
  8. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +27 -17
  9. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +93 -38
  10. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +133 -3
  11. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +11 -2
  12. ai_edge_quantizer/algorithms/utils/common_utils.py +21 -8
  13. ai_edge_quantizer/default_policy.py +4 -2
  14. ai_edge_quantizer/params_generator.py +1 -0
  15. ai_edge_quantizer/qtyping.py +34 -1
  16. ai_edge_quantizer/transformation_performer.py +5 -0
  17. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +209 -0
  18. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  19. ai_edge_quantizer/utils/test_utils.py +33 -0
  20. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +1 -0
  21. {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/METADATA +1 -1
  22. {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/RECORD +25 -21
  23. {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/LICENSE +0 -0
  24. {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/WHEEL +0 -0
  25. {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/algorithm_manager.py
@@ -24,6 +24,7 @@ from ai_edge_quantizer import qtyping
 from ai_edge_quantizer.algorithms.nonlinear_quantize import float_casting
 from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
 from ai_edge_quantizer.algorithms.uniform_quantize import dequantized_weight_recovery
+from ai_edge_quantizer.algorithms.uniform_quantize import hadamard_rotation
 from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
 from ai_edge_quantizer.algorithms.uniform_quantize import octav
 
@@ -58,6 +59,8 @@ class AlgorithmName(str, enum.Enum):
   FLOAT_CASTING = float_casting.ALGORITHM_KEY
   DEQUANTIZED_WEIGHT_RECOVERY = dequantized_weight_recovery.ALGORITHM_KEY
   OCTAV = octav.ALGORITHM_KEY
+  HADAMARD_ROTATION = hadamard_rotation.ALGORITHM_KEY
+
 
 ### MIN/MAX_UNIFORM_QUANT ###
 
@@ -104,6 +107,7 @@ MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT = {
         common_quantize.materialize_dynamic_update_slice
     ),
     _TFLOpName.STABLEHLO_COMPOSITE: common_quantize.materialize_composite,
+    _TFLOpName.PAD: common_quantize.materialize_pad,
 }
 for op_name, materialize_func in MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT.items():
   register_quantized_op(
@@ -237,6 +241,7 @@ _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
         common_quantize.materialize_dynamic_update_slice
     ),
     _TFLOpName.STABLEHLO_COMPOSITE: common_quantize.materialize_composite,
+    _TFLOpName.PAD: common_quantize.materialize_pad,
 })
 
 for op_name, materialize_func in _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT.items():
@@ -250,3 +255,32 @@ for op_name, materialize_func in _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT.items():
           octav.get_tensor_quant_params,
       ),
   )
+
+# Register the Hadamard Rotation algorithm.
+register_op_quant_config_validation_func(
+    AlgorithmName.HADAMARD_ROTATION,
+    common_quantize.check_op_quantization_config,
+)
+
+# Register a config check policy for the Hadamard Rotation algorithm.
+register_config_check_policy_func(
+    AlgorithmName.HADAMARD_ROTATION,
+    default_policy.DEFAULT_CONFIG_CHECK_POLICY,
+)
+
+# Register specialized hadamard rotation materialize functions.
+_HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
+    _TFLOpName.FULLY_CONNECTED: hadamard_rotation.materialize_fully_connected,
+    _TFLOpName.EMBEDDING_LOOKUP: hadamard_rotation.materialize_embedding_lookup,
+})
+for (
+    op_name,
+    materialize_func,
+) in _HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT.items():
+  register_quantized_op(
+      AlgorithmName.HADAMARD_ROTATION,
+      op_name,
+      naive_min_max_quantize.init_qsvs,
+      calibration_func=naive_min_max_quantize.min_max_calibrate,
+      materialize_func=materialize_func,
+  )
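
These registrations make the new algorithm addressable by its key, HADAMARD_ROTATION, anywhere a quantization recipe names an algorithm. As a rough usage sketch only: `Quantizer`, `load_quantization_recipe`, and `quantize` are the package's public entry points, but the recipe field names below are assumptions based on the package's JSON recipe format and are not verified against this release:

    from ai_edge_quantizer import quantizer

    # Hypothetical recipe entry: Hadamard rotation with int8 symmetric
    # channelwise weights for every fully_connected op. Field names are
    # assumptions; check the package docs before relying on them.
    recipe = [{
        "regex": ".*",
        "operation": "FULLY_CONNECTED",
        "algorithm_key": "HADAMARD_ROTATION",
        "op_config": {
            "weight_tensor_config": {
                "num_bits": 8,
                "symmetric": True,
                "granularity": "CHANNELWISE",
            },
        },
    }]

    qt = quantizer.Quantizer("model.tflite")
    qt.load_quantization_recipe(recipe)
    result = qt.quantize()  # weight-only, so no calibration pass should be needed
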
ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py
@@ -680,6 +680,23 @@ def materialize_split(
   )
 
 
+def materialize_pad(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.pad."""
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+      inputs_to_ignore=[1],  # Padding value does not need to be quantized.
+  )
+
+
 def _get_tensor_shape_for_blockwise(
     tensor_shape: Sequence[int], quantized_dim: int, block_size: int
 ) -> list[int]:
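
The SAME_AS_INPUT_SCALE constraint above is what makes a quantized pad cheap: tfl.pad only copies input values and fills with a constant, so letting the output reuse the input's scale and zero point avoids a requantization step. A standalone NumPy sketch of that invariant (illustrative only, not package code):

    import numpy as np

    scale, zero_point = 0.05, 0
    x = np.array([[1.0, -0.5], [0.25, 0.75]], dtype=np.float32)
    q = np.clip(np.round(x / scale) + zero_point, -128, 127).astype(np.int8)

    # Pad in the quantized domain: the float pad value 0.0 maps to the zero
    # point, and every other element keeps its quantized value, so the output
    # shares the input's (scale, zero_point) unchanged.
    q_padded = np.pad(q, 1, constant_values=zero_point)
    x_padded = (q_padded.astype(np.float32) - zero_point) * scale
    assert np.allclose(x_padded[1:-1, 1:-1], x)
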
@@ -709,18 +726,29 @@
 
 
 def _reshape_data_for_blockwise(
-    tensor_data: np.ndarray, quantized_dim: int, block_size: int
+    tensor_data: np.ndarray,
+    quantized_dim: int,
+    block_size: int,
 ) -> tuple[np.ndarray, int]:
   """Reshapes data for blockwise quantization.
 
   Args:
     tensor_data: The original tensor data.
     quantized_dim: The dimension to be quantized blockwise.
-    block_size: The size of the block.
+    block_size: The size of the block. `block_size` must be a multiple of 32.
+      The tensor quantized dimension shape must be divisible by `block_size`.
 
   Returns:
     A tuple containing the reshaped tensor data and the new reduce dimension.
   """
+
+  # TODO: b/417508018 - create AEQ specific error class instead of
+  # using generic ValueError.
+  if tensor_data.shape[quantized_dim] % block_size != 0:
+    raise ValueError(
+        "Tensor quantization dimension must be divisible by block size for"
+        " blockwise quantization."
+    )
   new_shape = _get_tensor_shape_for_blockwise(
       tensor_data.shape, quantized_dim, block_size
   )
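
Concretely, for the shapes exercised by the new tests in common_quantize_test.py below: a (24, 128) tensor with block_size=32 along quantized_dim=1 is reshaped to (24, 4, 32), and per-block statistics reduce over the trailing axis (the returned reduce dimension, 2). A standalone sketch:

    import numpy as np

    tensor = np.arange(24 * 128, dtype=np.float32).reshape(24, 128)
    block_size = 32

    # Dim 1 splits into 128 // 32 = 4 blocks of 32 elements; the new reduce
    # dimension is the trailing block axis, 2.
    blocked = tensor.reshape(24, 128 // block_size, block_size)
    block_min = np.min(blocked, axis=2)
    block_max = np.max(blocked, axis=2)
    assert blocked.shape == (24, 4, 32)
    assert block_min.shape == block_max.shape == (24, 4)
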
@@ -801,22 +829,19 @@ def init_tensor_min_max(
       weight_tensor_config.granularity == qtyping.QuantGranularity.CHANNELWISE
   ):
     quantized_dim = common_utils.get_weight_quantized_dim(
-        op_info, tensor_data
+        op_info, tensor_data, weight_tensor_config.granularity
     )
   if (
       weight_tensor_config is not None
       and weight_tensor_config.granularity
       == qtyping.QuantGranularity.BLOCKWISE
   ):
-    quantized_dim = (
-        tfl_flatbuffer_utils.TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM[
-            op_info.op_name
-        ]
-    )
-    reshaped_data, reduce_dims = _reshape_data_for_blockwise(
-        tensor_data,
-        quantized_dim,
-        weight_tensor_config.block_size,
+    reshaped_data, reduce_dims = (
+        uniform_quantize_tensor.reshape_data_for_blockwise(
+            tensor_data,
+            op_info.op_name,
+            weight_tensor_config.block_size,
+        )
     )
     return {
         "min": np.min(reshaped_data, axis=reduce_dims, keepdims=False),
ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py
@@ -31,8 +31,7 @@ _TensorQuantConfig = qtyping.TensorQuantizationConfig
 
 
 class CommonQuantizeTest(parameterized.TestCase):
-  """Tests for general quantize functions.
-  """
+  """Tests for general quantize functions."""
 
   def setUp(self):
     super().setUp()
@@ -69,6 +68,34 @@ class CommonQuantizeTest(parameterized.TestCase):
         default_policy.DEFAULT_CONFIG_CHECK_POLICY,
     )
 
+  def test_reshape_data_for_blockwise_raises_error_when_quantized_dim_not_divisible_by_block_size(
+      self,
+  ):
+    tensor_data = np.ones((24, 128), dtype=np.float32)
+    block_size = 256
+    quantized_dim = 1
+    with self.assertRaisesWithPredicateMatch(
+        ValueError,
+        lambda err: (
+            "Tensor quantization dimension must be divisible by block"
+            " size for blockwise quantization."
+        )
+        in str(err),
+    ):
+      common_quantize._reshape_data_for_blockwise(
+          tensor_data, quantized_dim, block_size
+      )
+
+  def test_reshape_data_for_blockwise_returns_correct_values(self):
+    tensor_data = np.ones((24, 128), dtype=np.float32)
+    block_size = 32
+    quantized_dim = 1
+    new_tensor_data, reduce_dim = common_quantize._reshape_data_for_blockwise(
+        tensor_data, quantized_dim, block_size
+    )
+    self.assertEqual(new_tensor_data.shape, (24, 4, 32))
+    self.assertEqual(reduce_dim, 2)
+
 
 if __name__ == "__main__":
   googletest.main()
ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py
@@ -168,11 +168,9 @@ def get_tensor_quant_params(
         "Only symmetric weights are supported for dequantized weight recovery."
     )
 
-  quantized_dim = None
-  if tensor_quant_config.granularity == qtyping.QuantGranularity.CHANNELWISE:
-    quantized_dim = common_utils.get_weight_quantized_dim(
-        op_info, tensor_content
-    )
+  quantized_dim = common_utils.get_weight_quantized_dim(
+      op_info, tensor_content, tensor_quant_config.granularity
+  )
 
   zp, scale = get_zp_scale_from_dequantized_symmetric_weights(
       dequant_vals=tensor_content,
ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py
@@ -0,0 +1,357 @@
+# Copyright 2024 The AI Edge Quantizer Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Implements the Hadamard Rotation quantization."""
+
+from typing import Any, Optional
+import numpy as np
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.algorithms.uniform_quantize import octav
+from ai_edge_quantizer.algorithms.utils import common_utils
+from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+
+
+ALGORITHM_KEY = "HADAMARD_ROTATION"
+
+
+def _make_hadamard_matrix(size: int) -> np.ndarray:
+  """Generates a Hadamard matrix of the given size.
+
+  Args:
+    size: The size of the Hadamard matrix. Must be a power of 2. This
+      represents a single dimension. E.g. if size is 4, then the Hadamard
+      matrix is a 4x4 matrix.
+
+  Returns:
+    The Hadamard matrix.
+
+  Raises:
+    ValueError: If the size is not a power of 2.
+  """
+  if size <= 0 or (size & (size - 1)) != 0:
+    raise ValueError("Hadamard matrix size must be a power of 2.")
+  h = h2 = np.array([[1, 1], [1, -1]])
+  current_size = 2
+  while current_size < size:
+    h = np.kron(h, h2)
+    current_size *= 2
+  return h / np.sqrt(size)
+
+
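
The loop above is Sylvester's construction, H_{2n} = H_n ⊗ H_2, scaled by 1/sqrt(size) so the result is orthogonal. Orthogonality is the property the whole scheme rests on, since it makes the rotation exactly invertible. A quick standalone check (scipy's `hadamard` uses the same construction, unnormalized):

    import numpy as np
    from scipy.linalg import hadamard

    size = 8  # must be a power of 2
    h = hadamard(size) / np.sqrt(size)
    assert np.allclose(h @ h.T, np.eye(size))  # invertible by its transpose
    assert np.allclose(h, h.T)                 # and symmetric
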
+def _rotate_with_diagonal_hadamard(
+    tensor_content: np.ndarray,
+    axis: int,
+):
+  """Rotates the given float array with a diagonally tiled Hadamard matrix.
+
+  Args:
+    tensor_content: The float array to rotate.
+    axis: The axis of the tensor to rotate along.
+
+  Returns:
+    A tuple of the rotated array, the Hadamard size, and the random vector.
+
+  Raises:
+    ValueError: If the axis is not 1. To support other axes, please add
+      support to the matrix multiplication.
+  """
+  if axis != 1:
+    raise ValueError(
+        "Hadamard rotation is only supported for 2D tensors with quantized"
+        " dimension 0."
+    )
+
+  # Use the largest power of 2 that is a factor of the dimension and then
+  # tile this Hadamard matrix along the diagonal. 2**30 is just a large power
+  # of 2 to calculate this factor.
+  hadamard_size = np.gcd(tensor_content.shape[axis], 2 ** 30)
+  diagonal_size = tensor_content.shape[axis] // hadamard_size
+  output_size = tensor_content.shape[1 - axis]
+  random_vector = np.ones(hadamard_size, dtype=np.int8)
+
+  # Use a canonical Hadamard matrix.
+  hadamard = _make_hadamard_matrix(hadamard_size)
+  reshaped_tensor = tensor_content.reshape(
+      diagonal_size, output_size, hadamard_size
+  )
+  w_rotated = np.einsum("jk,ilk->ilj", hadamard, reshaped_tensor)
+  return w_rotated.reshape(tensor_content.shape), hadamard_size, random_vector
+
+
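
Why rotating along the reduction axis is safe: H is orthogonal, so applying the matching rotation to the other matmul operand cancels it exactly. A standalone check for the simple case where the whole axis is a single Hadamard block (diagonal_size = 1), using the same einsum as above:

    import numpy as np
    from functools import reduce

    def make_hadamard(size):
      h2 = np.array([[1.0, 1.0], [1.0, -1.0]])
      return reduce(np.kron, [h2] * int(np.log2(size))) / np.sqrt(size)

    rng = np.random.default_rng(0)
    x = rng.normal(size=(5, 64))   # activations
    w = rng.normal(size=(16, 64))  # fully_connected weight, shape (out, in)
    h = make_hadamard(64)

    # With diagonal_size = 1 the einsum is a plain right-multiply by H^T.
    w_rot = np.einsum("jk,ilk->ilj", h, w.reshape(1, 16, 64)).reshape(16, 64)
    assert np.allclose(w_rot, w @ h.T)
    assert np.allclose((x @ h.T) @ w_rot.T, x @ w.T)  # rotations cancel
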
+def get_tensor_quant_params(
+    op_info: qtyping.OpInfo,
+    tensor_quant_config: qtyping.TensorQuantizationConfig,
+    tensor_content: Optional[np.ndarray] = None,
+    tensor_qsv: Optional[dict[str, Any]] = None,
+) -> qtyping.UniformQuantParams:
+  """Returns the quantization parameters for a tensor.
+
+  This function will rotate the tensor with a Hadamard matrix and then
+  quantize it with OCTAV.
+
+  Args:
+    op_info: Aggregated information about the op (e.g., quantization config).
+    tensor_quant_config: The quantization config for the tensor.
+    tensor_content: The content of the tensor. When None, it means the tensor
+      is not a weight tensor (e.g. static quantization).
+    tensor_qsv: A dictionary containing the min/max of the tensor.
+
+  Raises:
+    ValueError: If `tensor_content` is None (only weight tensors are
+      supported), if `tensor_qsv` is given (static quantization is not
+      supported), if the tensor is not 2D, or if anything other than
+      channelwise granularity with quantized dimension 0 is requested.
+  """
+  if tensor_content is None:
+    raise ValueError("Hadamard rotation is only supported for weight tensors.")
+
+  if tensor_qsv is not None:
+    raise ValueError(
+        "Hadamard rotation is not supported for static quantization."
+    )
+
+  if tensor_content.ndim != 2:
+    raise ValueError("Hadamard rotation is only supported for 2D tensors.")
+
+  if tensor_quant_config.granularity != qtyping.QuantGranularity.CHANNELWISE:
+    raise ValueError(
+        "Hadamard rotation is not supported for"
+        f" {tensor_quant_config.granularity} granularity."
+    )
+
+  quantized_dim = common_utils.get_weight_quantized_dim(
+      op_info, tensor_content, tensor_quant_config.granularity
+  )
+  if quantized_dim != 0:
+    raise ValueError(
+        f"Unsupported quantized dimension: {quantized_dim}. Only 0 is"
+        " supported."
+    )
+
+  # Reduction axis is the non-quantized dimension. Since we only support 2D
+  # tensors and quantized_dim of 0, the reduction axis is 1.
+  reduce_axis = 1
+
+  # Rotate the tensor with a Hadamard matrix.
+  w_rotated, hadamard_size, random_vector = _rotate_with_diagonal_hadamard(
+      tensor_content, axis=reduce_axis
+  )
+
+  # Get the quantized values of the rotated tensor.
+  qparams = octav.get_tensor_quant_params(
+      op_info, tensor_quant_config, w_rotated, tensor_qsv
+  )
+
+  return qtyping.UniformQuantParams(
+      quantized_dimension=qparams.quantized_dimension,
+      num_bits=qparams.num_bits,
+      scale=qparams.scale,
+      zero_point=qparams.zero_point,
+      symmetric=qparams.symmetric,
+      quantized_data=qparams.quantized_data,
+      block_size=qparams.block_size,
+      hadamard=qtyping.UniformQuantParams.HadamardRotationParams(
+          random_binary_vector=random_vector,
+          hadamard_size=hadamard_size,
+      ),
+  )
+
+
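
The reason to rotate before handing the tensor to OCTAV: a Hadamard rotation spreads concentrated outliers across the whole block, shrinking the per-channel range a uniform quantizer must cover. A toy illustration (not package code):

    import numpy as np
    from functools import reduce

    h2 = np.array([[1.0, 1.0], [1.0, -1.0]])
    h = reduce(np.kron, [h2] * 6) / np.sqrt(64)  # 64x64 Hadamard

    w = np.zeros((1, 64))
    w[0, 0] = 8.0      # a single outlier dominates the channel's range
    w_rot = w @ h.T    # its energy spreads evenly across the block

    print(np.abs(w).max())      # 8.0
    print(np.abs(w_rot).max())  # 1.0, an 8x smaller quantization range
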
+def materialize_fully_connected(
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize the fully_connected op.
+
+  Args:
+    op_info: Aggregated information about the op (e.g., quantization config).
+    graph_info: Graph information needed to perform quantization for the op.
+    tensor_name_to_qsv: A map of tensor name to quantization parameters.
+
+  Returns:
+    Quantization configuration for the tensors associated with the op (e.g.,
+    weights, bias).
+  """
+  op_tensor_params = []
+
+  # Materialize weight.
+  weight_tensor_index = 1
+  weight_tensor = graph_info.subgraph_tensors[
+      op_info.op.inputs[weight_tensor_index]
+  ]
+  tensor_data = tfl_flatbuffer_utils.get_tensor_data(
+      weight_tensor, graph_info.buffers
+  )
+  # quant_params contains the rotated and quantized weights done by
+  # get_tensor_quant_params().
+  quant_params = get_tensor_quant_params(
+      op_info,
+      op_info.op_quant_config.weight_tensor_config,
+      tensor_data,
+      None,
+  )
+  transformations = [qtyping.QuantTransformation.QUANTIZE_TENSOR]
+  op2tensor_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      parameters=quant_params,
+      transformations=transformations,
+  )
+  weight_transformation_params = qtyping.TensorTransformationParams(
+      tensor_name=tfl_flatbuffer_utils.get_tensor_name(weight_tensor),
+      consumers=[op2tensor_params],
+  )
+
+  # Materialize input. A hadamard rotation op should be inserted on the input
+  # tensor to do the inverse of the weight's transformation.
+  input_tensor_index = 0
+  input_tensor = graph_info.subgraph_tensors[
+      op_info.op.inputs[input_tensor_index]
+  ]
+  transformations = [
+      qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
+  ]
+  op2tensor_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      parameters=quant_params,
+      transformations=transformations,
+  )
+  input_transformation_params = qtyping.TensorTransformationParams(
+      tensor_name=tfl_flatbuffer_utils.get_tensor_name(input_tensor),
+      consumers=[op2tensor_params],
+  )
+  op_tensor_params.append(input_transformation_params)
+  op_tensor_params.append(weight_transformation_params)
+
+  # Materialize bias. Since static quantization is not supported, we do not
+  # quantize the bias tensor.
+  bias_tensor_index = 2
+  bias_tensor = graph_info.subgraph_tensors[
+      op_info.op.inputs[bias_tensor_index]
+  ]
+  no_quant_tensor_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
+  )
+  bias_transformation_params = qtyping.TensorTransformationParams(
+      tensor_name=tfl_flatbuffer_utils.get_tensor_name(bias_tensor),
+      consumers=[no_quant_tensor_params],
+  )
+  op_tensor_params.append(bias_transformation_params)
+
+  # Materialize output. Since static quantization is not supported, we do not
+  # quantize the output tensor.
+  output_tensor_index = 0
+  output_tensor = graph_info.subgraph_tensors[
+      op_info.op.outputs[output_tensor_index]
+  ]
+  no_quant_tensor_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
+  )
+  output_transformation_params = qtyping.TensorTransformationParams(
+      tensor_name=tfl_flatbuffer_utils.get_tensor_name(output_tensor),
+      producer=no_quant_tensor_params,
+  )
+  op_tensor_params.append(output_transformation_params)
+
+  return op_tensor_params
+
+
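
Putting the fully_connected pieces together: the weight is rotated and quantized offline (QUANTIZE_TENSOR) while the inserted op rotates activations at runtime, so the two rotations cancel up to rounding error. A standalone sketch with plain symmetric int8 quantization standing in for OCTAV:

    import numpy as np
    from functools import reduce

    h2 = np.array([[1.0, 1.0], [1.0, -1.0]])
    h = reduce(np.kron, [h2] * 6) / np.sqrt(64)

    rng = np.random.default_rng(0)
    x = rng.normal(size=(2, 64))
    w = rng.normal(size=(8, 64))

    # Offline: rotate, then per-channel symmetric int8 quantization.
    w_rot = w @ h.T
    scale = np.abs(w_rot).max(axis=1, keepdims=True) / 127.0
    w_q = np.round(w_rot / scale).astype(np.int8)

    # Runtime: rotate activations, then the dequantized matmul approximates
    # the original float output up to int8 rounding error.
    y = (x @ h.T) @ (w_q * scale).T
    assert np.allclose(y, x @ w.T, atol=0.25)
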
+def materialize_embedding_lookup(
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize the embedding_lookup op.
+
+  Args:
+    op_info: Aggregated information about the op (e.g., quantization config).
+    graph_info: Graph information needed to perform quantization for the op.
+    tensor_name_to_qsv: A map of tensor name to quantization parameters.
+
+  Returns:
+    Quantization configuration for the tensors associated with the op (e.g.,
+    weights, bias).
+  """
+  op_tensor_params = []
+
+  # Materialize lookup.
+  lookup_tensor_index = 0
+  lookup_tensor = graph_info.subgraph_tensors[
+      op_info.op.inputs[lookup_tensor_index]
+  ]
+  transformations = [
+      qtyping.QuantTransformation.NO_QUANTIZE,
+  ]
+  op2tensor_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      parameters=None,
+      transformations=transformations,
+  )
+  lookup_transformation_params = qtyping.TensorTransformationParams(
+      tensor_name=tfl_flatbuffer_utils.get_tensor_name(lookup_tensor),
+      consumers=[op2tensor_params],
+  )
+  op_tensor_params.append(lookup_transformation_params)
+
+  # Materialize embedding. The embedding table should be rotated and then
+  # quantized.
+  embedding_tensor_index = 1
+  embedding_tensor = graph_info.subgraph_tensors[
+      op_info.op.inputs[embedding_tensor_index]
+  ]
+  tensor_data = tfl_flatbuffer_utils.get_tensor_data(
+      embedding_tensor, graph_info.buffers
+  )
+  quant_params = get_tensor_quant_params(
+      op_info,
+      op_info.op_quant_config.weight_tensor_config,
+      tensor_data,
+      None,
+  )
+  transformations = [qtyping.QuantTransformation.QUANTIZE_TENSOR]
+  op2tensor_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      parameters=quant_params,
+      transformations=transformations,
+  )
+  weight_transformation_params = qtyping.TensorTransformationParams(
+      tensor_name=tfl_flatbuffer_utils.get_tensor_name(embedding_tensor),
+      consumers=[op2tensor_params],
+  )
+  op_tensor_params.append(weight_transformation_params)
+
+  # Materialize output. A hadamard rotation op should be inserted on the output
+  # tensor to do the inverse of the embedding's transformation.
+  output_tensor_index = 0
+  output_tensor = graph_info.subgraph_tensors[
+      op_info.op.outputs[output_tensor_index]
+  ]
+  transformations = [
+      qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
+  ]
+  op2tensor_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      parameters=quant_params,
+      transformations=transformations,
+  )
+  output_transformation_params = qtyping.TensorTransformationParams(
+      tensor_name=tfl_flatbuffer_utils.get_tensor_name(output_tensor),
+      producer=op2tensor_params,
+  )
+  op_tensor_params.append(output_transformation_params)
+
+  return op_tensor_params
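
For embedding_lookup the direction flips: the table is rotated and quantized offline, and the rotation op inserted on the output undoes it after the lookup. Because the Sylvester Hadamard matrix is symmetric, the inverse rotation is the same multiply. A standalone sketch (float only, quantization elided):

    import numpy as np
    from functools import reduce

    h2 = np.array([[1.0, 1.0], [1.0, -1.0]])
    h = reduce(np.kron, [h2] * 6) / np.sqrt(64)

    table = np.random.default_rng(1).normal(size=(100, 64))  # (vocab, dim)
    table_rot = table @ h.T     # stored, and in practice quantized, rotated

    looked_up = table_rot[42]   # embedding_lookup output, still rotated
    recovered = looked_up @ h   # inserted rotation undoes it (H = H.T)
    assert np.allclose(recovered, table[42])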