ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +158 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
- ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
- ai_edge_quantizer/calibrator.py +11 -60
- ai_edge_quantizer/calibrator_test.py +4 -73
- ai_edge_quantizer/default_policy.py +61 -26
- ai_edge_quantizer/model_modifier.py +97 -7
- ai_edge_quantizer/model_modifier_test.py +81 -1
- ai_edge_quantizer/model_validator.py +31 -8
- ai_edge_quantizer/params_generator.py +17 -10
- ai_edge_quantizer/params_generator_test.py +2 -7
- ai_edge_quantizer/qtyping.py +86 -6
- ai_edge_quantizer/quantizer.py +166 -21
- ai_edge_quantizer/quantizer_test.py +284 -16
- ai_edge_quantizer/recipe.py +154 -42
- ai_edge_quantizer/recipe_manager.py +158 -1
- ai_edge_quantizer/recipe_manager_test.py +146 -32
- ai_edge_quantizer/recipe_test.py +93 -17
- ai_edge_quantizer/transformation_instruction_generator.py +118 -13
- ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
- ai_edge_quantizer/transformation_performer.py +55 -25
- ai_edge_quantizer/transformation_performer_test.py +127 -5
- ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
- ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
- ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
- ai_edge_quantizer/transformations/transformation_utils.py +129 -6
- ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
- ai_edge_quantizer/utils/calibration_utils.py +263 -1
- ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
- ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
- ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
- ai_edge_quantizer/utils/test_utils.py +75 -2
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
- ai_edge_quantizer/utils/validation_utils.py +114 -4
- ai_edge_quantizer/utils/validation_utils_test.py +80 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
- ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
- ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
- ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
|
@@ -158,7 +158,7 @@ def get_tensor_quant_params(
|
|
|
158
158
|
op_info, tensor_quant_config, tensor_content, tensor_qsv
|
|
159
159
|
)
|
|
160
160
|
|
|
161
|
-
if tensor_quant_config.granularity
|
|
161
|
+
if uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity):
|
|
162
162
|
raise ValueError(
|
|
163
163
|
"Blockwise quantization is not supported for dequantized weight"
|
|
164
164
|
" recovery."
|
|
@@ -168,11 +168,9 @@ def get_tensor_quant_params(
|
|
|
168
168
|
"Only symmetric weights are supported for dequantized weight recovery."
|
|
169
169
|
)
|
|
170
170
|
|
|
171
|
-
quantized_dim =
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
op_info, tensor_content
|
|
175
|
-
)
|
|
171
|
+
quantized_dim = common_utils.get_weight_quantized_dim(
|
|
172
|
+
op_info, tensor_content, tensor_quant_config.granularity
|
|
173
|
+
)
|
|
176
174
|
|
|
177
175
|
zp, scale = get_zp_scale_from_dequantized_symmetric_weights(
|
|
178
176
|
dequant_vals=tensor_content,
|
|
@@ -0,0 +1,414 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Implements the Hadamard Rotation quantization."""
|
|
17
|
+
|
|
18
|
+
from typing import Any, Optional
|
|
19
|
+
import numpy as np
|
|
20
|
+
from ai_edge_quantizer import qtyping
|
|
21
|
+
from ai_edge_quantizer.algorithms.uniform_quantize import octav
|
|
22
|
+
from ai_edge_quantizer.algorithms.utils import common_utils
|
|
23
|
+
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Algorithm key for the Hadamard rotation variant that inserts a single
# rotation op (presumably a custom op; confirm against the algorithm manager
# registration).
CUSTOM_OP_ALGORITHM_KEY: str = "HADAMARD_ROTATION"
# Algorithm key for the Hadamard rotation variant that inserts decomposed
# rotation ops instead of a single one.
DECOMPOSED_ALGORITHM_KEY: str = "DECOMPOSED_HADAMARD_ROTATION"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _make_hadamard_matrix(size: int) -> np.ndarray:
  """Generates a normalized (orthonormal) Sylvester Hadamard matrix.

  Args:
    size: The size of the Hadamard matrix. Must be a power of 2 (including
      2**0 == 1). This represents a single dimension. E.g. if size is 4, then
      the Hadamard matrix is a 4x4 matrix.

  Returns:
    A `size` x `size` float array scaled by 1/sqrt(size), so that the matrix
    times its transpose is the identity. For `size == 1` this is `[[1.0]]`.

  Raises:
    ValueError: If the size is not a power of 2.
  """
  if size <= 0 or (size & (size - 1)) != 0:
    raise ValueError("Hadamard matrix size must be a power of 2. ")
  h2 = np.array([[1, 1], [1, -1]])
  # Start from the 1x1 Hadamard matrix so that size == 1 (which passes the
  # power-of-2 check) yields a 1x1 matrix rather than a 2x2 one. Each
  # Kronecker product with h2 doubles the matrix size.
  h = np.array([[1]])
  current_size = 1
  while current_size < size:
    h = np.kron(h, h2)
    current_size *= 2
  return h / np.sqrt(size)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _rotate_with_diagonal_hadamard(
    tensor_content: np.ndarray,
    axis: int,
):
  """Rotates a tensor along `axis` with a block-diagonal Hadamard matrix.

  The Hadamard block size is the largest power of 2 that divides
  `tensor_content.shape[axis]`. The block is tiled along the diagonal, so each
  contiguous chunk of `hadamard_size` elements along `axis` is rotated
  independently.

  Args:
    tensor_content: The float array to rotate.
    axis: The axis of the tensor to rotate. Must be the last axis.

  Returns:
    A tuple `(rotated, hadamard_size, random_vector)`:
      rotated: The rotated array, with the same shape as `tensor_content`.
      hadamard_size: The size of each Hadamard block on the diagonal.
      random_vector: An all-ones int8 vector of length `hadamard_size` (no
        random sign flipping is applied).

  Raises:
    ValueError: If the axis is not the last axis of tensor_content. To support
      other axes, please add support to the matrix multiplication.
  """
  if axis != tensor_content.ndim - 1:
    raise ValueError(
        "Hadamard rotation is only supported for tensors with quantized"
        " dimension 0 (rotate last dimension)."
    )

  # Use the largest power of 2 that is a factor of the dimension and then
  # tile this Hadamard matrix along the diagonal. 2**30 is just a large power
  # of 2 to calculate this factor.
  hadamard_size = np.gcd(tensor_content.shape[axis], 2**30)
  random_vector = np.ones(hadamard_size, dtype=np.int8)

  # Use a canonical Hadamard matrix.
  hadamard = _make_hadamard_matrix(hadamard_size)
  # Since `axis` is the last axis, each row of this 2-D view holds one
  # contiguous chunk of `hadamard_size` elements, so a single matmul applies
  # the block-diagonal rotation. `-1` also handles rank-1 inputs, for which
  # the product of the remaining dims would be an (empty) float product.
  reshaped_tensor = tensor_content.reshape(-1, hadamard_size)
  # Equivalent to `(hadamard @ reshaped_tensor.mT).mT`, but avoids
  # `ndarray.mT`, which only exists in NumPy >= 2.0.
  w_rotated = reshaped_tensor @ hadamard.T
  return w_rotated.reshape(tensor_content.shape), hadamard_size, random_vector
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def get_tensor_quant_params(
    op_info: qtyping.OpInfo,
    tensor_quant_config: qtyping.TensorQuantizationConfig,
    tensor_content: Optional[np.ndarray] = None,
    tensor_qsv: Optional[dict[str, Any]] = None,
) -> qtyping.UniformQuantParams:
  """Rotates a weight tensor with a Hadamard matrix, then quantizes it with OCTAV.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    tensor_quant_config: The quantization config for the tensor.
    tensor_content: The constant content of the weight tensor. Must be
      provided and have rank >= 2. The last axis is the one that is rotated.
    tensor_qsv: Calibration statistics for the tensor. Must be None, because
      static quantization is not supported.

  Returns:
    The OCTAV quantization parameters computed on the rotated tensor,
    extended with the Hadamard rotation parameters (block size and binary
    vector).

  Raises:
    ValueError: If `tensor_content` is None, if `tensor_qsv` is not None, or
      if `tensor_content` has rank < 2. Errors from
      `octav.get_tensor_quant_params` (presumably for unsupported
      configurations such as blockwise or asymmetric quantization; confirm in
      octav.py) propagate unchanged.
  """
  if tensor_content is None:
    raise ValueError("Hadamard rotation is only supported for weight tensors.")
  if tensor_qsv is not None:
    raise ValueError(
        "Hadamard rotation is not supported for static quantization."
    )
  if tensor_content.ndim < 2:
    raise ValueError(
        "Hadamard rotation is only supported for tensors with rank >= 2."
    )

  # The last axis is the reduction (non-quantized) dimension when the
  # quantized dimension is 0, so that is the axis we rotate.
  last_axis = tensor_content.ndim - 1
  rotated, block_dim, binary_vector = _rotate_with_diagonal_hadamard(
      tensor_content, axis=last_axis
  )

  octav_params = octav.get_tensor_quant_params(
      op_info, tensor_quant_config, rotated, tensor_qsv
  )
  rotation_params = qtyping.UniformQuantParams.HadamardRotationParams(
      random_binary_vector=binary_vector,
      hadamard_size=block_dim,
  )
  return qtyping.UniformQuantParams(
      quantized_dimension=octav_params.quantized_dimension,
      num_bits=octav_params.num_bits,
      scale=octav_params.scale,
      zero_point=octav_params.zero_point,
      symmetric=octav_params.symmetric,
      quantized_data=octav_params.quantized_data,
      block_size=octav_params.block_size,
      hadamard=rotation_params,
  )
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def _materialize_fully_connected(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    is_decomposed: bool = False,
    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
) -> list[qtyping.TensorTransformationParams]:
  """Materialize the fully_connected op.

  The weight tensor (input 1) is rotated with a Hadamard matrix and
  quantized. A Hadamard rotation transformation is attached to the
  activation input (input 0). The bias (input 2, when the op has one) and
  the output are left unquantized.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    is_decomposed: Whether to use decomposed Hadamard rotation ops or a custom
      op.
    tensor_name_to_qsv: A map of tensor name to quantization parameters.
      Unused.

  Returns:
    Quantization configuration for the tensors associated with the op, in the
    order: input, weight, bias (only when the op has one), output.

  Raises:
    ValueError: If no weight tensor quantization config is provided.
  """
  weight_config = op_info.op_quant_config.weight_tensor_config
  if weight_config is None:
    raise ValueError(
        "Weight tensor quantization config is not provided for Hadamard"
        " Rotation quantization."
    )

  op_tensor_params = []

  # Materialize weight.
  weight_tensor_index = 1
  weight_tensor = graph_info.subgraph_tensors[
      op_info.op.inputs[weight_tensor_index]
  ]
  tensor_data = tfl_flatbuffer_utils.get_tensor_data(
      weight_tensor, graph_info.buffers
  )
  # quant_params contains the rotated and quantized weights done by
  # get_tensor_quant_params().
  quant_params = get_tensor_quant_params(
      op_info,
      weight_config,
      tensor_data,
      None,
  )
  transformations = [qtyping.QuantTransformation.QUANTIZE_TENSOR]
  op2tensor_params = qtyping.OpToTensorParams(
      subgraph_op_id=op_info.subgraph_op_index,
      parameters=quant_params,
      transformations=transformations,
  )
  weight_transformation_params = qtyping.TensorTransformationParams(
      tensor_name=tfl_flatbuffer_utils.get_tensor_name(weight_tensor),
      consumers=[op2tensor_params],
  )

  # Materialize input. A hadamard rotation op should be inserted on the input
  # tensor to do the inverse of the weight's transformation.
  input_tensor_index = 0
  input_tensor = graph_info.subgraph_tensors[
      op_info.op.inputs[input_tensor_index]
  ]
  transformations = [
      qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION
      if is_decomposed
      else qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
  ]
  op2tensor_params = qtyping.OpToTensorParams(
      subgraph_op_id=op_info.subgraph_op_index,
      parameters=quant_params,
      transformations=transformations,
  )
  input_transformation_params = qtyping.TensorTransformationParams(
      tensor_name=tfl_flatbuffer_utils.get_tensor_name(input_tensor),
      consumers=[op2tensor_params],
  )
  op_tensor_params.append(input_transformation_params)
  op_tensor_params.append(weight_transformation_params)

  # Materialize bias. Since static quantization is not supported, we do not
  # quantize the bias tensor. The bias is optional: it may be missing from
  # the op's inputs, or present as a negative index (TFLite encodes an
  # omitted optional input as -1). A negative index must not be used to look
  # up subgraph_tensors, because Python would silently return the last tensor
  # of the subgraph.
  bias_tensor_index = 2
  op_inputs = op_info.op.inputs
  if len(op_inputs) > bias_tensor_index and op_inputs[bias_tensor_index] >= 0:
    bias_tensor = graph_info.subgraph_tensors[op_inputs[bias_tensor_index]]
    no_quant_tensor_params = qtyping.OpToTensorParams(
        subgraph_op_id=op_info.subgraph_op_index,
        transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
    )
    bias_transformation_params = qtyping.TensorTransformationParams(
        tensor_name=tfl_flatbuffer_utils.get_tensor_name(bias_tensor),
        consumers=[no_quant_tensor_params],
    )
    op_tensor_params.append(bias_transformation_params)

  # Materialize output. Since static quantization is not supported, we do not
  # quantize the output tensor.
  output_tensor_index = 0
  output_tensor = graph_info.subgraph_tensors[
      op_info.op.outputs[output_tensor_index]
  ]
  no_quant_tensor_params = qtyping.OpToTensorParams(
      subgraph_op_id=op_info.subgraph_op_index,
      transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
  )
  output_transformation_params = qtyping.TensorTransformationParams(
      tensor_name=tfl_flatbuffer_utils.get_tensor_name(output_tensor),
      producer=no_quant_tensor_params,
  )
  op_tensor_params.append(output_transformation_params)

  return op_tensor_params
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def materialize_fully_connected_custom_op(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
) -> list[qtyping.TensorTransformationParams]:
  """Materializes fully_connected using a single custom Hadamard rotation op.

  Thin wrapper around `_materialize_fully_connected` with
  `is_decomposed=False`.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: A map of tensor name to quantization parameters.
      Forwarded unchanged.

  Returns:
    Quantization configuration for the tensors associated with the op.
  """
  use_decomposed_ops = False
  return _materialize_fully_connected(
      op_info=op_info,
      graph_info=graph_info,
      is_decomposed=use_decomposed_ops,
      tensor_name_to_qsv=tensor_name_to_qsv,
  )
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def materialize_fully_connected_decomposed(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
) -> list[qtyping.TensorTransformationParams]:
  """Materializes fully_connected using decomposed Hadamard rotation ops.

  Thin wrapper around `_materialize_fully_connected` with
  `is_decomposed=True`.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: A map of tensor name to quantization parameters.
      Forwarded unchanged.

  Returns:
    Quantization configuration for the tensors associated with the op.
  """
  use_decomposed_ops = True
  return _materialize_fully_connected(
      op_info=op_info,
      graph_info=graph_info,
      is_decomposed=use_decomposed_ops,
      tensor_name_to_qsv=tensor_name_to_qsv,
  )
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def _materialize_embedding_lookup(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    is_decomposed: bool = False,
    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
) -> list[qtyping.TensorTransformationParams]:
  """Materialize the embedding_lookup op.

  The lookup (index) tensor, input 0, is left unquantized. The embedding
  table, input 1, is rotated with a Hadamard matrix and quantized. The output
  gets a Hadamard rotation transformation to undo the table's rotation.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    is_decomposed: Whether to use decomposed Hadamard rotation ops or a custom
      op.
    tensor_name_to_qsv: A map of tensor name to quantization parameters.
      Unused.

  Returns:
    Quantization configuration for the tensors associated with the op, in the
    order: lookup, embedding table, output.

  Raises:
    ValueError: If no weight tensor quantization config is provided (checked
      up front, consistent with the fully_connected path).
  """
  weight_config = op_info.op_quant_config.weight_tensor_config
  if weight_config is None:
    raise ValueError(
        "Weight tensor quantization config is not provided for Hadamard"
        " Rotation quantization."
    )

  op_tensor_params = []

  # Materialize lookup.
  lookup_tensor_index = 0
  lookup_tensor = graph_info.subgraph_tensors[
      op_info.op.inputs[lookup_tensor_index]
  ]
  transformations = [
      qtyping.QuantTransformation.NO_QUANTIZE,
  ]
  op2tensor_params = qtyping.OpToTensorParams(
      subgraph_op_id=op_info.subgraph_op_index,
      parameters=None,
      transformations=transformations,
  )
  lookup_transformation_params = qtyping.TensorTransformationParams(
      tensor_name=tfl_flatbuffer_utils.get_tensor_name(lookup_tensor),
      consumers=[op2tensor_params],
  )
  op_tensor_params.append(lookup_transformation_params)

  # Materialize embedding. The embedding table should be rotated and then
  # quantized.
  embedding_tensor_index = 1
  embedding_tensor = graph_info.subgraph_tensors[
      op_info.op.inputs[embedding_tensor_index]
  ]
  tensor_data = tfl_flatbuffer_utils.get_tensor_data(
      embedding_tensor, graph_info.buffers
  )
  quant_params = get_tensor_quant_params(
      op_info,
      weight_config,
      tensor_data,
      None,
  )
  transformations = [qtyping.QuantTransformation.QUANTIZE_TENSOR]
  op2tensor_params = qtyping.OpToTensorParams(
      subgraph_op_id=op_info.subgraph_op_index,
      parameters=quant_params,
      transformations=transformations,
  )
  weight_transformation_params = qtyping.TensorTransformationParams(
      tensor_name=tfl_flatbuffer_utils.get_tensor_name(embedding_tensor),
      consumers=[op2tensor_params],
  )
  op_tensor_params.append(weight_transformation_params)

  # Materialize output. A hadamard rotation op should be inserted on the output
  # tensor to do the inverse of the embedding's transformation.
  output_tensor_index = 0
  output_tensor = graph_info.subgraph_tensors[
      op_info.op.outputs[output_tensor_index]
  ]
  transformations = [
      qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION
      if is_decomposed
      else qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
  ]
  op2tensor_params = qtyping.OpToTensorParams(
      subgraph_op_id=op_info.subgraph_op_index,
      parameters=quant_params,
      transformations=transformations,
  )
  output_transformation_params = qtyping.TensorTransformationParams(
      tensor_name=tfl_flatbuffer_utils.get_tensor_name(output_tensor),
      producer=op2tensor_params,
  )
  op_tensor_params.append(output_transformation_params)

  return op_tensor_params
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
def materialize_embedding_lookup_custom_op(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
) -> list[qtyping.TensorTransformationParams]:
  """Materializes embedding_lookup using a single custom Hadamard rotation op.

  Thin wrapper around `_materialize_embedding_lookup` with
  `is_decomposed=False`.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: A map of tensor name to quantization parameters.
      Forwarded unchanged.

  Returns:
    Quantization configuration for the tensors associated with the op.
  """
  use_decomposed_ops = False
  return _materialize_embedding_lookup(
      op_info=op_info,
      graph_info=graph_info,
      is_decomposed=use_decomposed_ops,
      tensor_name_to_qsv=tensor_name_to_qsv,
  )
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
def materialize_embedding_lookup_decomposed(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
) -> list[qtyping.TensorTransformationParams]:
  """Materializes embedding_lookup using decomposed Hadamard rotation ops.

  Thin wrapper around `_materialize_embedding_lookup` with
  `is_decomposed=True`.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: A map of tensor name to quantization parameters.
      Forwarded unchanged.

  Returns:
    Quantization configuration for the tensors associated with the op.
  """
  use_decomposed_ops = True
  return _materialize_embedding_lookup(
      op_info=op_info,
      graph_info=graph_info,
      is_decomposed=use_decomposed_ops,
      tensor_name_to_qsv=tensor_name_to_qsv,
  )
|