ai-edge-quantizer-nightly 0.1.0.dev20250512__py3-none-any.whl → 0.1.0.dev20250514__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +34 -0
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +37 -12
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +3 -5
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +357 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +265 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +7 -31
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py +27 -17
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +93 -38
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +133 -3
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +11 -2
- ai_edge_quantizer/algorithms/utils/common_utils.py +21 -8
- ai_edge_quantizer/default_policy.py +4 -2
- ai_edge_quantizer/params_generator.py +1 -0
- ai_edge_quantizer/qtyping.py +34 -1
- ai_edge_quantizer/transformation_performer.py +5 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +209 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
- ai_edge_quantizer/utils/test_utils.py +33 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +1 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/METADATA +1 -1
- {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/RECORD +25 -21
- {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/WHEEL +0 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/top_level.txt +0 -0
@@ -24,6 +24,7 @@ from ai_edge_quantizer import qtyping
|
|
24
24
|
from ai_edge_quantizer.algorithms.nonlinear_quantize import float_casting
|
25
25
|
from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
|
26
26
|
from ai_edge_quantizer.algorithms.uniform_quantize import dequantized_weight_recovery
|
27
|
+
from ai_edge_quantizer.algorithms.uniform_quantize import hadamard_rotation
|
27
28
|
from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
|
28
29
|
from ai_edge_quantizer.algorithms.uniform_quantize import octav
|
29
30
|
|
@@ -58,6 +59,8 @@ class AlgorithmName(str, enum.Enum):
|
|
58
59
|
FLOAT_CASTING = float_casting.ALGORITHM_KEY
|
59
60
|
DEQUANTIZED_WEIGHT_RECOVERY = dequantized_weight_recovery.ALGORITHM_KEY
|
60
61
|
OCTAV = octav.ALGORITHM_KEY
|
62
|
+
HADAMARD_ROTATION = hadamard_rotation.ALGORITHM_KEY
|
63
|
+
|
61
64
|
|
62
65
|
### MIN/MAX_UNIFORM_QUANT ###
|
63
66
|
|
@@ -104,6 +107,7 @@ MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT = {
|
|
104
107
|
common_quantize.materialize_dynamic_update_slice
|
105
108
|
),
|
106
109
|
_TFLOpName.STABLEHLO_COMPOSITE: common_quantize.materialize_composite,
|
110
|
+
_TFLOpName.PAD: common_quantize.materialize_pad,
|
107
111
|
}
|
108
112
|
for op_name, materialize_func in MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT.items():
|
109
113
|
register_quantized_op(
|
@@ -237,6 +241,7 @@ _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
|
|
237
241
|
common_quantize.materialize_dynamic_update_slice
|
238
242
|
),
|
239
243
|
_TFLOpName.STABLEHLO_COMPOSITE: common_quantize.materialize_composite,
|
244
|
+
_TFLOpName.PAD: common_quantize.materialize_pad,
|
240
245
|
})
|
241
246
|
|
242
247
|
for op_name, materialize_func in _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT.items():
|
@@ -250,3 +255,32 @@ for op_name, materialize_func in _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT.items():
|
|
250
255
|
octav.get_tensor_quant_params,
|
251
256
|
),
|
252
257
|
)
|
258
|
+
|
259
|
+
# Hadamard Rotation: op quantization configs are validated by the shared
# common_quantize checker.
register_op_quant_config_validation_func(
    AlgorithmName.HADAMARD_ROTATION,
    common_quantize.check_op_quantization_config,
)

# Hadamard Rotation reuses the default config check policy.
register_config_check_policy_func(
    AlgorithmName.HADAMARD_ROTATION,
    default_policy.DEFAULT_CONFIG_CHECK_POLICY,
)

# Only these ops have Hadamard-rotation-specific materializers. QSV
# initialization and calibration are borrowed from the naive min/max
# algorithm.
_HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
    _TFLOpName.FULLY_CONNECTED: hadamard_rotation.materialize_fully_connected,
    _TFLOpName.EMBEDDING_LOOKUP: (
        hadamard_rotation.materialize_embedding_lookup
    ),
})
for op_name, materialize_func in (
    _HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT.items()
):
  register_quantized_op(
      AlgorithmName.HADAMARD_ROTATION,
      op_name,
      naive_min_max_quantize.init_qsvs,
      calibration_func=naive_min_max_quantize.min_max_calibrate,
      materialize_func=materialize_func,
  )
|
@@ -680,6 +680,23 @@ def materialize_split(
|
|
680
680
|
)
|
681
681
|
|
682
682
|
|
683
|
+
def materialize_pad(
    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.pad.

  Delegates to the standard-op materializer with the SAME_AS_INPUT_SCALE
  constraint, skipping input index 1.

  Args:
    get_tensor_quant_params_fn: Function that computes quantization parameters
      for a tensor.
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: A map of tensor name to quantization statistics.

  Returns:
    The tensor transformation params produced by
    `common_utils.materialize_standard_op`.
  """
  # Input 1 is excluded from quantization. NOTE(review): for tfl.pad this is
  # presumably the integer `paddings` tensor (pad amounts), not a pad value.
  ignored_input_indices = [1]
  return common_utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
      get_tensor_quant_params_fn,
      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
      inputs_to_ignore=ignored_input_indices,
  )
|
698
|
+
|
699
|
+
|
683
700
|
def _get_tensor_shape_for_blockwise(
|
684
701
|
tensor_shape: Sequence[int], quantized_dim: int, block_size: int
|
685
702
|
) -> list[int]:
|
@@ -709,18 +726,29 @@ def _get_tensor_shape_for_blockwise(
|
|
709
726
|
|
710
727
|
|
711
728
|
def _reshape_data_for_blockwise(
|
712
|
-
tensor_data: np.ndarray,
|
729
|
+
tensor_data: np.ndarray,
|
730
|
+
quantized_dim: int,
|
731
|
+
block_size: int,
|
713
732
|
) -> tuple[np.ndarray, int]:
|
714
733
|
"""Reshapes data for blockwise quantization.
|
715
734
|
|
716
735
|
Args:
|
717
736
|
tensor_data: The original tensor data.
|
718
737
|
quantized_dim: The dimension to be quantized blockwise.
|
719
|
-
block_size: The size of the block.
|
738
|
+
block_size: The size of the block. `block_size must be a multiple of 32. `
|
739
|
+
`The tensor quantized dimension shape must be divisible by block_size.
|
720
740
|
|
721
741
|
Returns:
|
722
742
|
A tuple containing the reshaped tensor data and the new reduce dimension.
|
723
743
|
"""
|
744
|
+
|
745
|
+
# TODO: b/417508018 - create AEQ specific error class instead of
|
746
|
+
# using generic ValueError.
|
747
|
+
if tensor_data.shape[quantized_dim] % block_size != 0:
|
748
|
+
raise ValueError(
|
749
|
+
"Tensor quantization dimension must be divisible by block size for"
|
750
|
+
" blockwise quantization."
|
751
|
+
)
|
724
752
|
new_shape = _get_tensor_shape_for_blockwise(
|
725
753
|
tensor_data.shape, quantized_dim, block_size
|
726
754
|
)
|
@@ -801,22 +829,19 @@ def init_tensor_min_max(
|
|
801
829
|
weight_tensor_config.granularity == qtyping.QuantGranularity.CHANNELWISE
|
802
830
|
):
|
803
831
|
quantized_dim = common_utils.get_weight_quantized_dim(
|
804
|
-
op_info, tensor_data
|
832
|
+
op_info, tensor_data, weight_tensor_config.granularity
|
805
833
|
)
|
806
834
|
if (
|
807
835
|
weight_tensor_config is not None
|
808
836
|
and weight_tensor_config.granularity
|
809
837
|
== qtyping.QuantGranularity.BLOCKWISE
|
810
838
|
):
|
811
|
-
|
812
|
-
|
813
|
-
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
-
tensor_data,
|
818
|
-
quantized_dim,
|
819
|
-
weight_tensor_config.block_size,
|
839
|
+
reshaped_data, reduce_dims = (
|
840
|
+
uniform_quantize_tensor.reshape_data_for_blockwise(
|
841
|
+
tensor_data,
|
842
|
+
op_info.op_name,
|
843
|
+
weight_tensor_config.block_size,
|
844
|
+
)
|
820
845
|
)
|
821
846
|
return {
|
822
847
|
"min": np.min(reshaped_data, axis=reduce_dims, keepdims=False),
|
@@ -31,8 +31,7 @@ _TensorQuantConfig = qtyping.TensorQuantizationConfig
|
|
31
31
|
|
32
32
|
|
33
33
|
class CommonQuantizeTest(parameterized.TestCase):
|
34
|
-
"""Tests for general quantize functions.
|
35
|
-
"""
|
34
|
+
"""Tests for general quantize functions."""
|
36
35
|
|
37
36
|
def setUp(self):
|
38
37
|
super().setUp()
|
@@ -69,6 +68,34 @@ class CommonQuantizeTest(parameterized.TestCase):
|
|
69
68
|
default_policy.DEFAULT_CONFIG_CHECK_POLICY,
|
70
69
|
)
|
71
70
|
|
71
|
+
def test_reshape_data_for_blockwise_raises_error_when_quantized_dim_not_divisible_by_block_size(
    self,
):
  # The quantized dimension (size 128) cannot be split into blocks of 256.
  data = np.ones((24, 128), dtype=np.float32)
  expected_message = (
      "Tensor quantization dimension must be divisible by block"
      " size for blockwise quantization."
  )
  with self.assertRaisesWithPredicateMatch(
      ValueError, lambda err: expected_message in str(err)
  ):
    common_quantize._reshape_data_for_blockwise(data, 1, 256)
|
88
|
+
|
89
|
+
def test_reshape_data_for_blockwise_returns_correct_values(self):
  data = np.ones((24, 128), dtype=np.float32)
  reshaped, reduce_dim = common_quantize._reshape_data_for_blockwise(
      data, 1, 32
  )
  # Dimension 1 (128) is split into 4 blocks of 32, and the reduction axis
  # is the trailing in-block axis.
  self.assertEqual(reshaped.shape, (24, 4, 32))
  self.assertEqual(reduce_dim, 2)
|
98
|
+
|
72
99
|
|
73
100
|
if __name__ == "__main__":
|
74
101
|
googletest.main()
|
@@ -168,11 +168,9 @@ def get_tensor_quant_params(
|
|
168
168
|
"Only symmetric weights are supported for dequantized weight recovery."
|
169
169
|
)
|
170
170
|
|
171
|
-
quantized_dim =
|
172
|
-
|
173
|
-
|
174
|
-
op_info, tensor_content
|
175
|
-
)
|
171
|
+
quantized_dim = common_utils.get_weight_quantized_dim(
|
172
|
+
op_info, tensor_content, tensor_quant_config.granularity
|
173
|
+
)
|
176
174
|
|
177
175
|
zp, scale = get_zp_scale_from_dequantized_symmetric_weights(
|
178
176
|
dequant_vals=tensor_content,
|
@@ -0,0 +1,357 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Quantizer Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Implements the Hadamard Rotation quantization."""
|
17
|
+
|
18
|
+
from typing import Any, Optional
|
19
|
+
import numpy as np
|
20
|
+
from ai_edge_quantizer import qtyping
|
21
|
+
from ai_edge_quantizer.algorithms.uniform_quantize import octav
|
22
|
+
from ai_edge_quantizer.algorithms.utils import common_utils
|
23
|
+
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
|
24
|
+
|
25
|
+
|
26
|
+
ALGORITHM_KEY = "HADAMARD_ROTATION"
|
27
|
+
|
28
|
+
|
29
|
+
def _make_hadamard_matrix(size: int) -> np.ndarray:
  """Generates a normalized Hadamard matrix of the given size.

  Uses Sylvester's construction: start from the 1x1 matrix [[1]] and
  repeatedly take the Kronecker product with [[1, 1], [1, -1]] until the
  requested size is reached. The result is scaled by 1/sqrt(size), which
  makes it orthonormal (H @ H.T == I).

  Args:
    size: The size of the Hadamard matrix. Must be a power of 2 (1 is allowed
      and yields [[1.0]]). This represents a single dimension. E.g. if size is
      4, then the Hadamard matrix is a 4x4 matrix.

  Returns:
    The (size, size) normalized Hadamard matrix as a float array.

  Raises:
    ValueError: If the size is not a power of 2.
  """
  if size <= 0 or (size & (size - 1)) != 0:
    raise ValueError("Hadamard matrix size must be a power of 2. ")
  h2 = np.array([[1, 1], [1, -1]])
  # Start from the 1x1 Hadamard matrix so that size == 1 yields a 1x1 result
  # instead of the 2x2 base matrix.
  h = np.array([[1]])
  while h.shape[0] < size:
    h = np.kron(h, h2)
  return h / np.sqrt(size)
|
51
|
+
|
52
|
+
|
53
|
+
def _rotate_with_diagonal_hadamard(
    tensor_content: np.ndarray,
    axis: int,
) -> tuple[np.ndarray, int, np.ndarray]:
  """Rotates a 2D tensor along `axis` with a block-diagonal Hadamard matrix.

  The block size is the largest power of 2 (capped at 2**30) that divides
  `tensor_content.shape[axis]`. Every consecutive run of `hadamard_size`
  elements along `axis` is multiplied by the same normalized Hadamard matrix.
  This is equivalent to multiplying by a block-diagonal matrix with that
  Hadamard block repeated on the diagonal.

  Args:
    tensor_content: The float array to rotate.
    axis: The axis of the tensor to rotate along. Only 1 is supported.

  Returns:
    A tuple `(rotated, hadamard_size, random_vector)`:
      rotated: the rotated array, reshaped to `tensor_content.shape`.
      hadamard_size: the size of each diagonal Hadamard block.
      random_vector: an all-ones int8 vector of length `hadamard_size`. No
        random sign flips are currently applied.

  Raises:
    ValueError: If the axis is not 1. To support other axes, please add
      support to the matrix multiplication.
  """
  if axis != 1:
    raise ValueError(
        "Hadamard rotation is only supported for 2D tensors with quantized"
        " dimension 0."
    )

  dim_size = tensor_content.shape[axis]
  # Use the largest power of 2 that is a factor of the dimension and then
  # tile this Hadamard matrix along the diagonal. 2**30 is just a large power
  # of 2 used to calculate this factor. Convert to a Python int so callers get
  # a plain integer rather than a NumPy scalar.
  hadamard_size = int(np.gcd(dim_size, 2**30))
  diagonal_size = dim_size // hadamard_size
  output_size = tensor_content.shape[1 - axis]
  random_vector = np.ones(hadamard_size, dtype=np.int8)

  # Use a canonical Hadamard matrix.
  hadamard = _make_hadamard_matrix(hadamard_size)
  # Only the trailing axis (length hadamard_size) is contracted. The leading
  # (diagonal_size, output_size) split is just a row-major view, so each
  # consecutive run of hadamard_size values is rotated independently.
  reshaped_tensor = tensor_content.reshape(
      diagonal_size, output_size, hadamard_size
  )
  w_rotated = np.einsum("jk,ilk->ilj", hadamard, reshaped_tensor)
  return w_rotated.reshape(tensor_content.shape), hadamard_size, random_vector
|
91
|
+
|
92
|
+
|
93
|
+
def get_tensor_quant_params(
    op_info: qtyping.OpInfo,
    tensor_quant_config: qtyping.TensorQuantizationConfig,
    tensor_content: Optional[np.ndarray] = None,
    tensor_qsv: Optional[dict[str, Any]] = None,
) -> qtyping.UniformQuantParams:
  """Returns the quantization parameters for a weight tensor.

  The tensor is rotated with a block-diagonal Hadamard matrix along its
  non-quantized dimension (axis 1), and the rotated values are then quantized
  with OCTAV. The rotation metadata is attached via the `hadamard` field so
  that callers can insert the matching rotation elsewhere in the graph.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    tensor_quant_config: The quantization config for the tensor. Only
      CHANNELWISE granularity is supported.
    tensor_content: The content of the weight tensor. Required, and it must be
      2D. None (i.e. a non-weight tensor) is rejected.
    tensor_qsv: Must be None. Static quantization is not supported.

  Returns:
    The OCTAV quantization parameters computed on the rotated tensor, with
    `hadamard` populated (random binary vector and Hadamard block size).

  Raises:
    ValueError: If `tensor_content` is None.
    ValueError: If `tensor_qsv` is not None (static quantization).
    ValueError: If `tensor_content` is not 2D.
    ValueError: If the granularity is not CHANNELWISE.
    ValueError: If the weight's quantized dimension is not 0.
  """
  if tensor_content is None:
    raise ValueError("Hadamard rotation is only supported for weight tensors.")

  if tensor_qsv is not None:
    raise ValueError(
        "Hadamard rotation is not supported for static quantization."
    )

  if tensor_content.ndim != 2:
    raise ValueError("Hadamard rotation is only supported for 2D tensors.")

  if tensor_quant_config.granularity != qtyping.QuantGranularity.CHANNELWISE:
    raise ValueError(
        "Hadamard rotation is not supported for"
        f" {tensor_quant_config.granularity} granularity."
    )

  quantized_dim = common_utils.get_weight_quantized_dim(
      op_info, tensor_content, tensor_quant_config.granularity
  )
  if quantized_dim != 0:
    raise ValueError(
        f"Unsupported quantized dimension: {quantized_dim}. Only 0 is"
        " supported."
    )

  # Reduction axis is the non-quantized dimension. Since we only support 2D
  # tensors and quantized_dim of 0, the reduction axis is 1.
  reduce_axis = 1

  # Rotate the tensor with a Hadamard matrix.
  w_rotated, hadamard_size, random_vector = _rotate_with_diagonal_hadamard(
      tensor_content, axis=reduce_axis
  )

  # Quantize the rotated tensor. tensor_qsv is guaranteed to be None here.
  qparams = octav.get_tensor_quant_params(
      op_info, tensor_quant_config, w_rotated, tensor_qsv
  )

  return qtyping.UniformQuantParams(
      quantized_dimension=qparams.quantized_dimension,
      num_bits=qparams.num_bits,
      scale=qparams.scale,
      zero_point=qparams.zero_point,
      symmetric=qparams.symmetric,
      quantized_data=qparams.quantized_data,
      block_size=qparams.block_size,
      hadamard=qtyping.UniformQuantParams.HadamardRotationParams(
          random_binary_vector=random_vector,
          hadamard_size=hadamard_size,
      ),
  )
|
170
|
+
|
171
|
+
|
172
|
+
def materialize_fully_connected(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
) -> list[qtyping.TensorTransformationParams]:
  """Materialize the fully_connected op with Hadamard rotation.

  Each tensor of the op is handled as follows:
    - Weight (input 1): rotated and quantized via `get_tensor_quant_params`.
    - Activation input (input 0): gets an INSERT_HADAMARD_ROTATION
      transformation.
    - Bias (input 2, only when present): left unquantized.
    - Output (output 0): left unquantized.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Unused; accepted so the signature matches the other
      materialize functions.

  Returns:
    Quantization configuration for the tensors associated with the op, in the
    order: input, weight, bias (only if present), output.
  """
  op_tensor_params = []

  # Materialize weight.
  weight_tensor_index = 1
  weight_tensor = graph_info.subgraph_tensors[
      op_info.op.inputs[weight_tensor_index]
  ]
  tensor_data = tfl_flatbuffer_utils.get_tensor_data(
      weight_tensor, graph_info.buffers
  )
  # quant_params contains the rotated and quantized weights done by
  # get_tensor_quant_params().
  quant_params = get_tensor_quant_params(
      op_info,
      op_info.op_quant_config.weight_tensor_config,
      tensor_data,
      None,
  )
  transformations = [qtyping.QuantTransformation.QUANTIZE_TENSOR]
  op2tensor_params = qtyping.OpToTensorParams(
      subgraph_op_id=op_info.subgraph_op_index,
      parameters=quant_params,
      transformations=transformations,
  )
  weight_transformation_params = qtyping.TensorTransformationParams(
      tensor_name=tfl_flatbuffer_utils.get_tensor_name(weight_tensor),
      consumers=[op2tensor_params],
  )

  # Materialize input. A hadamard rotation op should be inserted on the input
  # tensor to do the inverse of the weight's transformation.
  input_tensor_index = 0
  input_tensor = graph_info.subgraph_tensors[
      op_info.op.inputs[input_tensor_index]
  ]
  transformations = [
      qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
  ]
  op2tensor_params = qtyping.OpToTensorParams(
      subgraph_op_id=op_info.subgraph_op_index,
      parameters=quant_params,
      transformations=transformations,
  )
  input_transformation_params = qtyping.TensorTransformationParams(
      tensor_name=tfl_flatbuffer_utils.get_tensor_name(input_tensor),
      consumers=[op2tensor_params],
  )
  op_tensor_params.append(input_transformation_params)
  op_tensor_params.append(weight_transformation_params)

  # Materialize bias. Since static quantization is not supported, we do not
  # quantize the bias tensor. The bias is optional: the op may list only two
  # inputs, or mark the omitted bias with index -1 (the TFLite convention for
  # absent optional inputs). Indexing subgraph_tensors with -1 would silently
  # select the last tensor of the subgraph, so skip the bias in that case.
  bias_tensor_index = 2
  if (
      len(op_info.op.inputs) > bias_tensor_index
      and op_info.op.inputs[bias_tensor_index] != -1
  ):
    bias_tensor = graph_info.subgraph_tensors[
        op_info.op.inputs[bias_tensor_index]
    ]
    no_quant_tensor_params = qtyping.OpToTensorParams(
        subgraph_op_id=op_info.subgraph_op_index,
        transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
    )
    bias_transformation_params = qtyping.TensorTransformationParams(
        tensor_name=tfl_flatbuffer_utils.get_tensor_name(bias_tensor),
        consumers=[no_quant_tensor_params],
    )
    op_tensor_params.append(bias_transformation_params)

  # Materialize output. Since static quantization is not supported, we do not
  # quantize the output tensor.
  output_tensor_index = 0
  output_tensor = graph_info.subgraph_tensors[
      op_info.op.outputs[output_tensor_index]
  ]
  no_quant_tensor_params = qtyping.OpToTensorParams(
      subgraph_op_id=op_info.subgraph_op_index,
      transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
  )
  output_transformation_params = qtyping.TensorTransformationParams(
      tensor_name=tfl_flatbuffer_utils.get_tensor_name(output_tensor),
      producer=no_quant_tensor_params,
  )
  op_tensor_params.append(output_transformation_params)

  return op_tensor_params
|
271
|
+
|
272
|
+
|
273
|
+
def materialize_embedding_lookup(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
) -> list[qtyping.TensorTransformationParams]:
  """Materialize the embedding_lookup op with Hadamard rotation.

  Each tensor of the op is handled as follows:
    - Lookup tensor (input 0): marked NO_QUANTIZE.
    - Embedding table (input 1): rotated and quantized via
      `get_tensor_quant_params`.
    - Output (output 0): produced with an INSERT_HADAMARD_ROTATION
      transformation that carries the table's rotation parameters.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Unused; accepted so the signature matches the other
      materialize functions.

  Returns:
    Three TensorTransformationParams, in order: lookup tensor, embedding table,
    output.
  """
  tensors = graph_info.subgraph_tensors
  op_index = op_info.subgraph_op_index

  # Lookup tensor: left unquantized.
  lookup = tensors[op_info.op.inputs[0]]
  lookup_params = qtyping.TensorTransformationParams(
      tensor_name=tfl_flatbuffer_utils.get_tensor_name(lookup),
      consumers=[
          qtyping.OpToTensorParams(
              subgraph_op_id=op_index,
              parameters=None,
              transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
          )
      ],
  )

  # Embedding table: rotated, then quantized.
  table = tensors[op_info.op.inputs[1]]
  table_data = tfl_flatbuffer_utils.get_tensor_data(table, graph_info.buffers)
  rotated_qparams = get_tensor_quant_params(
      op_info, op_info.op_quant_config.weight_tensor_config, table_data, None
  )
  table_params = qtyping.TensorTransformationParams(
      tensor_name=tfl_flatbuffer_utils.get_tensor_name(table),
      consumers=[
          qtyping.OpToTensorParams(
              subgraph_op_id=op_index,
              parameters=rotated_qparams,
              transformations=[qtyping.QuantTransformation.QUANTIZE_TENSOR],
          )
      ],
  )

  # Output: a Hadamard rotation is inserted here to invert the rotation that
  # was applied to the embedding table.
  output = tensors[op_info.op.outputs[0]]
  output_params = qtyping.TensorTransformationParams(
      tensor_name=tfl_flatbuffer_utils.get_tensor_name(output),
      producer=qtyping.OpToTensorParams(
          subgraph_op_id=op_index,
          parameters=rotated_qparams,
          transformations=[
              qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION
          ],
      ),
  )

  return [lookup_params, table_params, output_params]
|