ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/algorithm_manager.py +158 -0
  2. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  3. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  5. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
  6. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  7. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  8. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
  12. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
  13. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
  14. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
  15. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
  16. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
  17. ai_edge_quantizer/calibrator.py +11 -60
  18. ai_edge_quantizer/calibrator_test.py +4 -73
  19. ai_edge_quantizer/default_policy.py +61 -26
  20. ai_edge_quantizer/model_modifier.py +97 -7
  21. ai_edge_quantizer/model_modifier_test.py +81 -1
  22. ai_edge_quantizer/model_validator.py +31 -8
  23. ai_edge_quantizer/params_generator.py +17 -10
  24. ai_edge_quantizer/params_generator_test.py +2 -7
  25. ai_edge_quantizer/qtyping.py +86 -6
  26. ai_edge_quantizer/quantizer.py +166 -21
  27. ai_edge_quantizer/quantizer_test.py +284 -16
  28. ai_edge_quantizer/recipe.py +154 -42
  29. ai_edge_quantizer/recipe_manager.py +158 -1
  30. ai_edge_quantizer/recipe_manager_test.py +146 -32
  31. ai_edge_quantizer/recipe_test.py +93 -17
  32. ai_edge_quantizer/transformation_instruction_generator.py +118 -13
  33. ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
  34. ai_edge_quantizer/transformation_performer.py +55 -25
  35. ai_edge_quantizer/transformation_performer_test.py +127 -5
  36. ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
  37. ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
  38. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  39. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  40. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  41. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  42. ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
  43. ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
  44. ai_edge_quantizer/transformations/transformation_utils.py +129 -6
  45. ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
  46. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  47. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  48. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  49. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  50. ai_edge_quantizer/utils/test_utils.py +75 -2
  51. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
  52. ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
  53. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  54. ai_edge_quantizer/utils/validation_utils.py +114 -4
  55. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  56. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
  57. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  58. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  59. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  60. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  61. ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
  62. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  63. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py
@@ -158,7 +158,7 @@ def get_tensor_quant_params(
       op_info, tensor_quant_config, tensor_content, tensor_qsv
   )

-  if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
+  if uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity):
     raise ValueError(
         "Blockwise quantization is not supported for dequantized weight"
         " recovery."
@@ -168,11 +168,9 @@ def get_tensor_quant_params(
       "Only symmetric weights are supported for dequantized weight recovery."
   )

-  quantized_dim = None
-  if tensor_quant_config.granularity == qtyping.QuantGranularity.CHANNELWISE:
-    quantized_dim = common_utils.get_weight_quantized_dim(
-        op_info, tensor_content
-    )
+  quantized_dim = common_utils.get_weight_quantized_dim(
+      op_info, tensor_content, tensor_quant_config.granularity
+  )

   zp, scale = get_zp_scale_from_dequantized_symmetric_weights(
       dequant_vals=tensor_content,
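
A minimal sketch of the contract the new call site assumes: get_weight_quantized_dim now takes the granularity and returns a channel axis only when one applies. The names below are illustrative stand-ins; the real helper lives in common_utils and its body is not shown in this diff.

from enum import Enum

class Granularity(Enum):  # stand-in for qtyping.QuantGranularity
  TENSORWISE = "TENSORWISE"
  CHANNELWISE = "CHANNELWISE"

def quantized_dim_sketch(granularity, channel_dim=0):
  # The granularity branch is folded into the helper, so callers like the
  # hunk above no longer special-case CHANNELWISE themselves.
  return channel_dim if granularity == Granularity.CHANNELWISE else None

assert quantized_dim_sketch(Granularity.TENSORWISE) is None
assert quantized_dim_sketch(Granularity.CHANNELWISE) == 0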
ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py
@@ -0,0 +1,414 @@
+# Copyright 2024 The AI Edge Quantizer Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Implements the Hadamard Rotation quantization."""
+
+from typing import Any, Optional
+import numpy as np
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.algorithms.uniform_quantize import octav
+from ai_edge_quantizer.algorithms.utils import common_utils
+from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+
+
+CUSTOM_OP_ALGORITHM_KEY = "HADAMARD_ROTATION"
+DECOMPOSED_ALGORITHM_KEY = "DECOMPOSED_HADAMARD_ROTATION"
+
+
+def _make_hadamard_matrix(size: int) -> np.ndarray:
+  """Generates a Hadamard matrix of the given size.
+
+  Args:
+    size: The size of the Hadamard matrix. Must be a power of 2. This
+      represents a single dimension. E.g. if size is 4, then the Hadamard
+      matrix is a 4x4 matrix.
+
+  Returns:
+    The Hadamard matrix.
+
+  Raises:
+    ValueError: If the size is not a power of 2.
+  """
+  if size <= 0 or (size & (size - 1)) != 0:
+    raise ValueError("Hadamard matrix size must be a power of 2.")
+  h = h2 = np.array([[1, 1], [1, -1]])
+  current_size = 2
+  while current_size < size:
+    h = np.kron(h, h2)
+    current_size *= 2
+  return h / np.sqrt(size)
+
+
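A quick numerical check of the Sylvester construction above (a standalone sketch, not part of the diff): the recursive Kronecker product is equivalent to scipy.linalg.hadamard(size) / np.sqrt(size), and the normalized matrix is orthogonal and symmetric, hence its own inverse.

import numpy as np
from scipy.linalg import hadamard  # Sylvester construction, like the kron loop

h = hadamard(4) / np.sqrt(4)          # matches _make_hadamard_matrix(4)
assert np.allclose(h, h.T)            # symmetric
assert np.allclose(h @ h, np.eye(4))  # involution: H is its own inverse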
+def _rotate_with_diagonal_hadamard(
+    tensor_content: np.ndarray,
+    axis: int,
+):
+  """Rotates the given float array with a block-diagonal Hadamard matrix.
+
+  Args:
+    tensor_content: The float array to rotate.
+    axis: The axis of the tensor to rotate.
+
+  Returns:
+    A tuple of the rotated array, the Hadamard block size, and the random
+    binary vector.
+
+  Raises:
+    ValueError: If the axis is not the last axis of tensor_content. To support
+      other axes, please add support to the matrix multiplication.
+  """
+  if axis != tensor_content.ndim - 1:
+    raise ValueError(
+        "Hadamard rotation is only supported for tensors with quantized"
+        " dimension 0 (rotate last dimension)."
+    )
+
+  # Use the largest power of 2 that is a factor of the dimension and then
+  # tile this Hadamard matrix along the diagonal. 2**30 is just a large power
+  # of 2 to calculate this factor.
+  hadamard_size = np.gcd(tensor_content.shape[axis], 2**30)
+  diagonal_size = tensor_content.shape[axis] // hadamard_size
+  # Output size is the product of all dimensions except the one being rotated.
+  output_size = np.prod(np.delete(tensor_content.shape, axis))
+  random_vector = np.ones(hadamard_size, dtype=np.int8)
+
+  # Use a canonical Hadamard matrix.
+  hadamard = _make_hadamard_matrix(hadamard_size)
+  reshaped_tensor = tensor_content.reshape(
+      diagonal_size * output_size, hadamard_size
+  )
+  w_rotated = np.matmul(hadamard, reshaped_tensor.mT).mT
+  return w_rotated.reshape(tensor_content.shape), hadamard_size, random_vector
+
+
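A sketch of why this rotation is lossless (toy shapes assumed, not part of the diff): because the normalized matrix is a symmetric involution, applying the same rotation a second time recovers the original tensor exactly.

import numpy as np
from scipy.linalg import hadamard

rng = np.random.default_rng(0)
w = rng.standard_normal((3, 8))   # toy 2-D weight; last axis is rotated
h = hadamard(8) / np.sqrt(8)

w_rotated = (h @ w.T).T                     # same product the function computes
assert np.allclose((h @ w_rotated.T).T, w)  # rotating again undoes it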
+def get_tensor_quant_params(
+    op_info: qtyping.OpInfo,
+    tensor_quant_config: qtyping.TensorQuantizationConfig,
+    tensor_content: Optional[np.ndarray] = None,
+    tensor_qsv: Optional[dict[str, Any]] = None,
+) -> qtyping.UniformQuantParams:
+  """Returns the quantization parameters for a tensor.
+
+  This function rotates the tensor with a Hadamard matrix and then quantizes
+  it with OCTAV.
+
+  Args:
+    op_info: Aggregated information about the op (e.g., quantization config).
+    tensor_quant_config: The quantization config for the tensor.
+    tensor_content: The content of the tensor. When None, it means the tensor
+      is not a weight tensor (e.g. static quantization).
+    tensor_qsv: A dictionary containing the min/max of the tensor.
+
+  Returns:
+    The quantization parameters, including the Hadamard rotation metadata.
+
+  Raises:
+    ValueError: If `tensor_content` is None (only weight tensors are
+      supported), if `tensor_qsv` is provided (static quantization is not
+      supported), or if the tensor rank is less than 2.
+  """
+  if tensor_content is None:
+    raise ValueError("Hadamard rotation is only supported for weight tensors.")
+
+  if tensor_qsv is not None:
+    raise ValueError(
+        "Hadamard rotation is not supported for static quantization."
+    )
+
+  if tensor_content.ndim < 2:
+    raise ValueError(
+        "Hadamard rotation is only supported for tensors with rank >= 2."
+    )
+
+  # Reduction axis is the last non-quantized dimension. Since we only support
+  # quantized_dim of 0 (or 1 for blockwise), the reduction axis is the last
+  # axis.
+  reduce_axis = tensor_content.ndim - 1
+
+  # Rotate the tensor with a Hadamard matrix.
+  w_rotated, hadamard_size, random_vector = _rotate_with_diagonal_hadamard(
+      tensor_content, axis=reduce_axis
+  )
+
+  # Get the quantized values of the rotated tensor.
+  qparams = octav.get_tensor_quant_params(
+      op_info, tensor_quant_config, w_rotated, tensor_qsv
+  )
+
+  return qtyping.UniformQuantParams(
+      quantized_dimension=qparams.quantized_dimension,
+      num_bits=qparams.num_bits,
+      scale=qparams.scale,
+      zero_point=qparams.zero_point,
+      symmetric=qparams.symmetric,
+      quantized_data=qparams.quantized_data,
+      block_size=qparams.block_size,
+      hadamard=qtyping.UniformQuantParams.HadamardRotationParams(
+          random_binary_vector=random_vector,
+          hadamard_size=hadamard_size,
+      ),
+  )
+
+
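The materializers below rely on a simple identity: quantizing the rotated weight and inserting the inverse rotation on the input leaves the float matmul unchanged, because the rotation cancels inside the product. A numpy sketch with assumed toy shapes:

import numpy as np
from scipy.linalg import hadamard

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 8))   # activations
w = rng.standard_normal((4, 8))   # fully_connected weight (out_dim, in_dim)
h = hadamard(8) / np.sqrt(8)

# The rotated weight (w @ h) is what gets quantized; the inserted op
# rotates the input (x @ h). Orthogonality makes the rotations cancel.
assert np.allclose((x @ h) @ (w @ h).T, x @ w.T)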
+def _materialize_fully_connected(
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    is_decomposed: bool = False,
+    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
+) -> list[qtyping.TensorTransformationParams]:
+  """Materializes the fully_connected op.
+
+  Args:
+    op_info: Aggregated information about the op (e.g., quantization config).
+    graph_info: Graph information needed to perform quantization for the op.
+    is_decomposed: Whether to use decomposed Hadamard rotation ops or a custom
+      op.
+    tensor_name_to_qsv: A map of tensor name to quantization parameters.
+
+  Returns:
+    Quantization configuration for the tensors associated with the op (e.g.,
+    weights, bias).
+
+  Raises:
+    ValueError: If the weight tensor quantization config is not provided.
+  """
+  if op_info.op_quant_config.weight_tensor_config is None:
+    raise ValueError(
+        "Weight tensor quantization config is not provided for Hadamard"
+        " Rotation quantization."
+    )
+
+  op_tensor_params = []
+
+  # Materialize weight.
+  weight_tensor_index = 1
+  weight_tensor = graph_info.subgraph_tensors[
+      op_info.op.inputs[weight_tensor_index]
+  ]
+  tensor_data = tfl_flatbuffer_utils.get_tensor_data(
+      weight_tensor, graph_info.buffers
+  )
+  # quant_params contains the rotated and quantized weights done by
+  # get_tensor_quant_params().
+  quant_params = get_tensor_quant_params(
+      op_info,
+      op_info.op_quant_config.weight_tensor_config,
+      tensor_data,
+      None,
+  )
+  transformations = [qtyping.QuantTransformation.QUANTIZE_TENSOR]
+  op2tensor_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      parameters=quant_params,
+      transformations=transformations,
+  )
+  weight_transformation_params = qtyping.TensorTransformationParams(
+      tensor_name=tfl_flatbuffer_utils.get_tensor_name(weight_tensor),
+      consumers=[op2tensor_params],
+  )
+
+  # Materialize input. A Hadamard rotation op should be inserted on the input
+  # tensor to do the inverse of the weight's transformation.
+  input_tensor_index = 0
+  input_tensor = graph_info.subgraph_tensors[
+      op_info.op.inputs[input_tensor_index]
+  ]
+  transformations = [
+      qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION
+      if is_decomposed
+      else qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
+  ]
+  op2tensor_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      parameters=quant_params,
+      transformations=transformations,
+  )
+  input_transformation_params = qtyping.TensorTransformationParams(
+      tensor_name=tfl_flatbuffer_utils.get_tensor_name(input_tensor),
+      consumers=[op2tensor_params],
+  )
+  op_tensor_params.append(input_transformation_params)
+  op_tensor_params.append(weight_transformation_params)
+
+  # Materialize bias. Since static quantization is not supported, we do not
+  # quantize the bias tensor.
+  bias_tensor_index = 2
+  bias_tensor = graph_info.subgraph_tensors[
+      op_info.op.inputs[bias_tensor_index]
+  ]
+  no_quant_tensor_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
+  )
+  bias_transformation_params = qtyping.TensorTransformationParams(
+      tensor_name=tfl_flatbuffer_utils.get_tensor_name(bias_tensor),
+      consumers=[no_quant_tensor_params],
+  )
+  op_tensor_params.append(bias_transformation_params)
+
+  # Materialize output. Since static quantization is not supported, we do not
+  # quantize the output tensor.
+  output_tensor_index = 0
+  output_tensor = graph_info.subgraph_tensors[
+      op_info.op.outputs[output_tensor_index]
+  ]
+  no_quant_tensor_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
+  )
+  output_transformation_params = qtyping.TensorTransformationParams(
+      tensor_name=tfl_flatbuffer_utils.get_tensor_name(output_tensor),
+      producer=no_quant_tensor_params,
+  )
+  op_tensor_params.append(output_transformation_params)
+
+  return op_tensor_params
+
+
+def materialize_fully_connected_custom_op(
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
+) -> list[qtyping.TensorTransformationParams]:
+  return _materialize_fully_connected(
+      op_info,
+      graph_info,
+      is_decomposed=False,
+      tensor_name_to_qsv=tensor_name_to_qsv,
+  )
+
+
+def materialize_fully_connected_decomposed(
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
+) -> list[qtyping.TensorTransformationParams]:
+  return _materialize_fully_connected(
+      op_info,
+      graph_info,
+      is_decomposed=True,
+      tensor_name_to_qsv=tensor_name_to_qsv,
+  )
+
+
+def _materialize_embedding_lookup(
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    is_decomposed: bool = False,
+    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
+) -> list[qtyping.TensorTransformationParams]:
+  """Materializes the embedding_lookup op.
+
+  Args:
+    op_info: Aggregated information about the op (e.g., quantization config).
+    graph_info: Graph information needed to perform quantization for the op.
+    is_decomposed: Whether to use decomposed Hadamard rotation ops or a custom
+      op.
+    tensor_name_to_qsv: A map of tensor name to quantization parameters.
+
+  Returns:
+    Quantization configuration for the tensors associated with the op (e.g.,
+    the embedding table).
+  """
+  op_tensor_params = []
+
+  # Materialize lookup.
+  lookup_tensor_index = 0
+  lookup_tensor = graph_info.subgraph_tensors[
+      op_info.op.inputs[lookup_tensor_index]
+  ]
+  transformations = [
+      qtyping.QuantTransformation.NO_QUANTIZE,
+  ]
+  op2tensor_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      parameters=None,
+      transformations=transformations,
+  )
+  lookup_transformation_params = qtyping.TensorTransformationParams(
+      tensor_name=tfl_flatbuffer_utils.get_tensor_name(lookup_tensor),
+      consumers=[op2tensor_params],
+  )
+  op_tensor_params.append(lookup_transformation_params)
+
+  # Materialize embedding. The embedding table should be rotated and then
+  # quantized.
+  embedding_tensor_index = 1
+  embedding_tensor = graph_info.subgraph_tensors[
+      op_info.op.inputs[embedding_tensor_index]
+  ]
+  tensor_data = tfl_flatbuffer_utils.get_tensor_data(
+      embedding_tensor, graph_info.buffers
+  )
+  quant_params = get_tensor_quant_params(
+      op_info,
+      op_info.op_quant_config.weight_tensor_config,
+      tensor_data,
+      None,
+  )
+  transformations = [qtyping.QuantTransformation.QUANTIZE_TENSOR]
+  op2tensor_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      parameters=quant_params,
+      transformations=transformations,
+  )
+  weight_transformation_params = qtyping.TensorTransformationParams(
+      tensor_name=tfl_flatbuffer_utils.get_tensor_name(embedding_tensor),
+      consumers=[op2tensor_params],
+  )
+  op_tensor_params.append(weight_transformation_params)
+
+  # Materialize output. A Hadamard rotation op should be inserted on the
+  # output tensor to do the inverse of the embedding's transformation.
+  output_tensor_index = 0
+  output_tensor = graph_info.subgraph_tensors[
+      op_info.op.outputs[output_tensor_index]
+  ]
+  transformations = [
+      qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION
+      if is_decomposed
+      else qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
+  ]
+  op2tensor_params = qtyping.OpToTensorParams(
+      subgraph_op_id=op_info.subgraph_op_index,
+      parameters=quant_params,
+      transformations=transformations,
+  )
+  output_transformation_params = qtyping.TensorTransformationParams(
+      tensor_name=tfl_flatbuffer_utils.get_tensor_name(output_tensor),
+      producer=op2tensor_params,
+  )
+  op_tensor_params.append(output_transformation_params)
+
+  return op_tensor_params
+
+
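The same cancellation, seen from the embedding side (a sketch with assumed shapes): the stored table is rotated before quantization, and the rotation op inserted on the output restores the original embeddings.

import numpy as np
from scipy.linalg import hadamard

rng = np.random.default_rng(0)
table = rng.standard_normal((16, 8))  # embedding table (vocab, dim)
h = hadamard(8) / np.sqrt(8)

rotated = table @ h                   # form that is stored and quantized
ids = np.array([3, 7])
assert np.allclose(rotated[ids] @ h, table[ids])  # output rotation undoes it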
+def materialize_embedding_lookup_custom_op(
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
+) -> list[qtyping.TensorTransformationParams]:
+  return _materialize_embedding_lookup(
+      op_info,
+      graph_info,
+      is_decomposed=False,
+      tensor_name_to_qsv=tensor_name_to_qsv,
+  )
+
+
+def materialize_embedding_lookup_decomposed(
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
+) -> list[qtyping.TensorTransformationParams]:
+  return _materialize_embedding_lookup(
+      op_info,
+      graph_info,
+      is_decomposed=True,
+      tensor_name_to_qsv=tensor_name_to_qsv,
+  )