ai-edge-quantizer-nightly 0.1.0.dev20250512__py3-none-any.whl → 0.1.0.dev20250514__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +34 -0
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +37 -12
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +3 -5
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +357 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +265 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +7 -31
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py +27 -17
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +93 -38
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +133 -3
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +11 -2
- ai_edge_quantizer/algorithms/utils/common_utils.py +21 -8
- ai_edge_quantizer/default_policy.py +4 -2
- ai_edge_quantizer/params_generator.py +1 -0
- ai_edge_quantizer/qtyping.py +34 -1
- ai_edge_quantizer/transformation_performer.py +5 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +209 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
- ai_edge_quantizer/utils/test_utils.py +33 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +1 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/METADATA +1 -1
- {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/RECORD +25 -21
- {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/WHEEL +0 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py CHANGED

```diff
@@ -16,9 +16,11 @@
 """Uniform quantize in tensor level."""
 
 import dataclasses
-from typing import Optional
+from typing import Optional, Sequence
+import ml_dtypes
 import numpy as np
 from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.utils import tfl_flatbuffer_utils
 
 
 @dataclasses.dataclass(frozen=True)
```
```diff
@@ -120,19 +122,127 @@ def fix_quantization_params_rank(
 )
 
 
+def _get_tensor_shape_for_blockwise(
+    tensor_shape: Sequence[int], quantized_dim: int, block_size: int
+) -> list[int]:
+  """Get the tensor shape for blockwise quantization.
+
+  This function splits the quantize dimension of the tensor into blocks and the
+  dim/blocks. Hence, min/max of the tensor can be calculated for each block
+  using existing functions.
+
+  Args:
+    tensor_shape: The original shape of the tensor.
+    quantized_dim: The dimension to be quantized blockwise.
+    block_size: The size of the block.
+
+  Returns:
+    The new tensor shape for calculating scale and zp for blockwise
+    quantization.
+  """
+  new_shape = []
+  for index, val in enumerate(tensor_shape):
+    if index == quantized_dim:
+      new_shape.append(int(val / block_size))
+      new_shape.append(block_size)
+    else:
+      new_shape.append(val)
+  return new_shape
+
+
+def reshape_data_for_blockwise(
+    tensor_data: np.ndarray, op_name: qtyping.TFLOperationName, block_size: int
+) -> tuple[np.ndarray, int]:
+  """Reshapes data for blockwise quantization.
+
+  Args:
+    tensor_data: The original tensor data.
+    op_name: The name of the TFL op.
+    block_size: The size of the block.
+
+  Returns:
+    A tuple containing the reshaped tensor data and the new reduce dimension.
+  """
+  quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM[
+      op_name
+  ]
+  new_shape = _get_tensor_shape_for_blockwise(
+      tensor_data.shape, quantized_dim, block_size
+  )
+  reshaped_data = tensor_data.reshape(new_shape)
+  return reshaped_data, quantized_dim + 1
+
+
+def _broadcast_scale_zp_for_blockwise(
+    tensor_content: np.ndarray,
+    quant_params: qtyping.UniformQuantParams,
+) -> qtyping.UniformQuantParams:
+  """Broadcasts scale and zp for blockwise quantization.
+
+  Args:
+    tensor_content: The original tensor data.
+    quant_params: The quantization parameters.
+      `quant_params.quantized_dimension` must be specified.
+      `quant_params.block_size` must be specified and positive.
+
+  Returns:
+    The updated quantization parameters with broadcasted scale and zp for
+    correct constant quantization.
+  """
+  if quant_params.quantized_dimension is None:
+    raise ValueError("Quantized dimension must be specified.")
+  if quant_params.block_size is None or quant_params.block_size <= 0:
+    raise ValueError("Block size must be specified and positive.")
+  quantized_dim = quant_params.quantized_dimension
+  expanded_tensor_shape = _get_tensor_shape_for_blockwise(
+      tensor_content.shape, quantized_dim, quant_params.block_size
+  )
+  expanded_scale = np.reshape(
+      np.broadcast_to(
+          np.expand_dims(quant_params.scale, quantized_dim + 1),
+          expanded_tensor_shape,
+      ),
+      tensor_content.shape,
+  )
+  expanded_zp = np.reshape(
+      np.broadcast_to(
+          np.expand_dims(quant_params.zero_point, quantized_dim + 1),
+          expanded_tensor_shape,
+      ),
+      tensor_content.shape,
+  )
+  return qtyping.UniformQuantParams(
+      scale=expanded_scale,
+      zero_point=expanded_zp,
+      num_bits=quant_params.num_bits,
+      symmetric=quant_params.symmetric,
+      quantized_dimension=quantized_dim,
+      block_size=quant_params.block_size,
+  )
+
+
 def uniform_quantize(
     tensor_data: np.ndarray,
     quantization_params: qtyping.UniformQuantParams,
+    is_blockwise: bool = False,
 ):
   """Uniform quantize a tensor.
 
   Args:
     tensor_data: The tensor to be quantized.
     quantization_params: The quantization parameters.
+    is_blockwise: Whether the tensor is blockwise quantized.
 
   Returns:
     The quantized tensor.
   """
+  # The reshaping for blockwise quantization is unique hence we do this here
+  # to avoid unexpected broadcast behavior downstream.
+  if is_blockwise:
+    quantization_params = _broadcast_scale_zp_for_blockwise(
+        tensor_data, quantization_params
+    )
+
   # quant params in flatbuffer is flattened, expand the rank to be the same
   # as the tensor rank to avoid ambiguous broadcasting.
   quantization_params = fix_quantization_params_rank(
```
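To make the reshape concrete: splitting the quantized dimension into `(dim / block_size, block_size)` turns a per-block min/max into an ordinary reduction over the trailing axis, exactly the `quantized_dim + 1` reduce dimension returned above. A minimal standalone sketch in plain NumPy (toy shapes assumed, not part of the package):

```python
import numpy as np

# A (4, 8) weight, blockwise-quantized along dim 1 with block_size 4,
# mirroring _get_tensor_shape_for_blockwise above.
tensor = np.arange(32, dtype=np.float32).reshape(4, 8)
quantized_dim, block_size = 1, 4

new_shape = []
for index, val in enumerate(tensor.shape):
  if index == quantized_dim:
    new_shape.append(val // block_size)  # number of blocks
    new_shape.append(block_size)         # elements per block
  else:
    new_shape.append(val)

blocks = tensor.reshape(new_shape)        # shape (4, 2, 4)
reduce_dim = quantized_dim + 1            # the new per-block axis
block_min = blocks.min(axis=reduce_dim)   # (4, 2): one min per block
block_max = blocks.max(axis=reduce_dim)   # (4, 2): one max per block
```

`_broadcast_scale_zp_for_blockwise` then runs the same reshape in reverse: the per-block scale/zero-point of shape `(4, 2)` is expanded back to the full `(4, 8)` tensor shape so constant quantization broadcasts unambiguously.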
```diff
@@ -242,15 +352,19 @@ def tensor_zp_scale_from_min_max(
     max_value,
     num_bits: int,
     symmetric: bool,
+    granularity: qtyping.QuantGranularity,
     clipping_values: Optional[np.ndarray] = None,
 ):
   """Get zero point and scale from min and max value.
 
   Args:
-    min_value: The minimum value of the tensor (
-
+    min_value: The minimum value of the tensor (channelwise and blockwise
+      supported).
+    max_value: The maximum value of the tensor (channelwise and blockwise
+      supported).
     num_bits: The number of bits of the tensor.
     symmetric: Whether the tensor is symmetric.
+    granularity: The granularity of the tensor.
     clipping_values: Absolute clipping values to apply to the tensor. This will
       clip the tensors to the range [-clipping_values, clipping_values]. This
       should be the same shape as min_value and max_value. If None, no clipping
```
```diff
@@ -267,6 +381,16 @@ def tensor_zp_scale_from_min_max(
   qmin, qmax = get_quantized_range(qtype)
   min_bound = 1e-4  # 1e-6 precision for int8 and 1e-8 for int16.
 
+  if granularity == qtyping.QuantGranularity.BLOCKWISE:
+    # Blockwise quantization uses float16 scale, with 7 bit mantissa,
+    # so the maximum representable value is 65280.
+    float16_max = np.broadcast_to(np.array(65280), min_value.shape)
+    clipping_values = (
+        float16_max
+        if clipping_values is None
+        else np.minimum(clipping_values, float16_max)
+    )
+
   if symmetric:
     bound = np.maximum(np.abs(min_value), np.abs(max_value))
     bound = np.maximum(bound, min_bound)
```
```diff
@@ -292,6 +416,12 @@ def tensor_zp_scale_from_min_max(
     zp = qmin - bound_min / scale
     zp = np.rint(zp)
 
+  if granularity == qtyping.QuantGranularity.BLOCKWISE:
+    # Round the scale values to 7 bit mantissa.
+    scale = (
+        scale.astype(ml_dtypes.bfloat16).astype(np.float16).astype(np.float32)
+    )
+
   # It's safe to cast zp to qtype without clipping because we can infer
   # qmin <= zp <= qmax from bound_min <= 0 <= bound_max.
   zp = assign_quantized_type(zp, qtype)
```
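The two blockwise branches work together: min/max are pre-clipped to 65280 so that the later round-trip through bfloat16 (7-bit mantissa) and float16 cannot overflow to infinity. A small demonstration of that round-trip, assuming only the `ml_dtypes` package that the diff itself imports:

```python
import numpy as np
import ml_dtypes

scale = np.array([0.12345678, 65504.0], dtype=np.float32)

# 65504 (float16 max) rounds up to 65536 in bfloat16, which then overflows
# float16 to inf -- hence the 65280 cap applied before the rounding.
unclipped = scale.astype(ml_dtypes.bfloat16).astype(np.float16)
# -> [~0.1235, inf]

clipped = np.minimum(scale, 65280.0)
rounded = clipped.astype(ml_dtypes.bfloat16).astype(np.float16).astype(np.float32)
# -> [~0.1235, 65280.0]; every value now fits in a 7-bit mantissa.
```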
ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py CHANGED

```diff
@@ -336,7 +336,11 @@ class TensorUtilsTest(parameterized.TestCase):
     max_val = np.max(self._test_data, keepdims=True)
 
     zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
-        min_val, max_val, num_bits, symmetric
+        min_val,
+        max_val,
+        num_bits,
+        symmetric,
+        qtyping.QuantGranularity.TENSORWISE,
     )
     self.assertEqual(zp.shape, scale.shape)
     max_q = 2**num_bits / 2 - 1
```
```diff
@@ -364,7 +368,12 @@ class TensorUtilsTest(parameterized.TestCase):
     max_val = np.array([[5.0]])
     clipping_values = np.array([4.0])
     zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
-        min_val, max_val, num_bits, symmetric, clipping_values
+        min_val,
+        max_val,
+        num_bits,
+        symmetric,
+        qtyping.QuantGranularity.TENSORWISE,
+        clipping_values,
     )
     expected_scale = clipping_values / quantized_bound
 
```
ai_edge_quantizer/algorithms/utils/common_utils.py CHANGED

```diff
@@ -905,23 +905,36 @@ def get_tensor_transformation_params(
   )
 
 
-def get_weight_quantized_dim(op_info: qtyping.OpInfo, tensor_data: np.ndarray):
+def get_weight_quantized_dim(
+    op_info: qtyping.OpInfo,
+    tensor_data: np.ndarray,
+    granularity: qtyping.QuantGranularity,
+):
   """Get the quantized dimension for the weight tensor.
 
   Args:
     op_info: Aggregated information about the op (e.g., quantization config).
     tensor_data: The weight tensor data.
+    granularity: The granularity of the weight tensor.
 
   Returns:
     The quantized dimension for the weight tensor.
   """
-  if op_info.op_name == _TFLOpName.BATCH_MATMUL:
-    quantized_dim = get_bmm_weight_quantized_dim(
-        tensor_data, adj_y=op_info.op.builtinOptions.adjY
-    )
-  else:
-    quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM.get(
-        op_info.op_name, None
+  quantized_dim = None
+  if granularity == qtyping.QuantGranularity.CHANNELWISE:
+    if op_info.op_name == _TFLOpName.BATCH_MATMUL:
+      quantized_dim = get_bmm_weight_quantized_dim(
+          tensor_data, adj_y=op_info.op.builtinOptions.adjY
+      )
+    else:
+      quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM.get(
+          op_info.op_name, None
+      )
+  elif granularity == qtyping.QuantGranularity.BLOCKWISE:
+    quantized_dim = (
+        tfl_flatbuffer_utils.TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM[
+            op_info.op_name
+        ]
     )
   return quantized_dim
 
```
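The reason the lookup table differs per granularity: channelwise quantization needs one scale per output channel, while blockwise tiles fixed-size blocks along the reduction axis. A hedged sketch of the distinction (the axis choices below mirror common TFLite fully-connected conventions and are illustrative, not a verbatim copy of the package's tables):

```python
import numpy as np

# A fully-connected weight of shape (out_features, in_features).
w = np.random.randn(16, 64).astype(np.float32)

# Channelwise: quantized dim is the output axis -> one scale per row.
channel_scales = np.abs(w).max(axis=1, keepdims=True) / 127.0     # (16, 1)

# Blockwise: blocks of size 32 along the input (reduction) axis ->
# one scale per 32-element block; /7.0 assumes symmetric int4.
block_scales = np.abs(w.reshape(16, 2, 32)).max(axis=2) / 7.0     # (16, 2)
```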
ai_edge_quantizer/default_policy.py CHANGED

```diff
@@ -183,7 +183,8 @@ DEFAULT_JSON_POLICY = """
       "SELECT_V2",
       "DYNAMIC_UPDATE_SLICE",
       "SELECT_V2",
-      "STABLEHLO_COMPOSITE"
+      "STABLEHLO_COMPOSITE",
+      "PAD"
     ],
     "static_wi8_ai8": [
       "ADD",
```
```diff
@@ -214,7 +215,8 @@ DEFAULT_JSON_POLICY = """
       "SELECT_V2",
       "DYNAMIC_UPDATE_SLICE",
       "SELECT_V2",
-      "STABLEHLO_COMPOSITE"
+      "STABLEHLO_COMPOSITE",
+      "PAD"
     ],
     "static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
     "static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
```
ai_edge_quantizer/qtyping.py CHANGED

```diff
@@ -20,7 +20,7 @@ from collections.abc import MutableMapping
 import copy
 import dataclasses
 import enum
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import numpy as np
 from typing_extensions import TypeAlias
```
```diff
@@ -62,6 +62,7 @@ class TFLOperationName(str, enum.Enum):
   SELECT_V2 = 'SELECT_V2'
   DYNAMIC_UPDATE_SLICE = 'DYNAMIC_UPDATE_SLICE'
   STABLEHLO_COMPOSITE = 'STABLEHLO_COMPOSITE'
+  PAD = 'PAD'
 
 
 class QuantizeMode(enum.Enum):
```
```diff
@@ -113,6 +114,8 @@ class QuantTransformation(enum.Enum):
   DUPLICATE_BUFFER = 5
   # Duplicate the tensor.
   DUPLICATE_TENSOR = 6
+  # Insert the aeq.hadamard_rotation op.
+  INSERT_HADAMARD_ROTATION = 7
 
 
 @dataclasses.dataclass(frozen=True)
```
```diff
@@ -128,8 +131,35 @@ class UniformQuantParams:
     quantized_data: The quantized data.
     block_size: The block size for blockwise quantization, block_size=0 meaning
       no blockwise quantization.
+    hadamard: The Hadamard rotation parameters, if set.
   """
 
+  class HadamardRotationParams:
+    """Parameters for the Hadamard rotation.
+
+    Attributes:
+      random_binary_vector: The random binary vector for the Hadamard rotation.
+        TODO(b/415392354): Randomization is an experimental feature that's
+        currently not implemented yet hence this is always 1. We will add
+        support or remove in the future.
+      hadamard_size: The size of the Hadamard matrix.
+    """
+
+    random_binary_vector: np.ndarray
+    hadamard_size: int
+
+    def __init__(self, random_binary_vector: np.ndarray, hadamard_size: int):
+      self.random_binary_vector = random_binary_vector
+      self.hadamard_size = hadamard_size
+
+    def __eq__(self, other):
+      if other.__class__ is not self.__class__:
+        return NotImplemented
+      return (
+          np.array_equal(self.random_binary_vector, other.random_binary_vector)
+          and self.hadamard_size == other.hadamard_size
+      )
+
   num_bits: int
   quantized_dimension: Optional[int]
   scale: np.ndarray
```
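For context on what these parameters drive: a Hadamard rotation multiplies weights (and, via the inserted custom op, the matching activations) by an orthonormal Hadamard matrix, which leaves the matmul result unchanged while spreading outliers across elements, so the rotated weights quantize with less error. A standalone sketch of the underlying identity (uses `scipy` purely for the demo matrix; the package itself does not depend on it):

```python
import numpy as np
from scipy.linalg import hadamard  # demo-only dependency

size = 4                                # plays the role of hadamard_size
h = hadamard(size) / np.sqrt(size)      # orthonormal: h @ h.T == identity

x = np.random.randn(2, size).astype(np.float32)   # activations
w = np.random.randn(size, 3).astype(np.float32)   # weights

# Rotating both sides cancels out: (x @ h) @ (h.T @ w) == x @ w.
np.testing.assert_allclose((x @ h) @ (h.T @ w), x @ w, rtol=1e-4, atol=1e-5)
```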
```diff
@@ -137,6 +167,7 @@ class UniformQuantParams:
   symmetric: bool = True
   quantized_data: Optional[np.ndarray] = None
   block_size: int = 0
+  hadamard: Optional[HadamardRotationParams] = None
 
   @classmethod
   def from_tfl_tensor_details(cls, tensor_detail) -> 'UniformQuantParams':
```
```diff
@@ -180,6 +211,7 @@ class UniformQuantParams:
         and self.symmetric == other.symmetric
         and _compare_array_or_none(self.quantized_data, other.quantized_data)
         and self.block_size == other.block_size
+        and self.hadamard == other.hadamard
     )
 
 
```
```diff
@@ -492,6 +524,7 @@ class IOOperator:
   outputs: list[int]
   op_key: TFLOperationName
 
+
 # The function signature for `get_tensor_quant_params_fn`.
 GetTensorQuantParamsFuncSignature = Callable[
     [
```
ai_edge_quantizer/transformation_performer.py CHANGED

```diff
@@ -25,6 +25,7 @@ from ai_edge_quantizer.transformations import dequant_insert
 from ai_edge_quantizer.transformations import duplicate_buffer
 from ai_edge_quantizer.transformations import duplicate_tensor
 from ai_edge_quantizer.transformations import emulated_subchannel
+from ai_edge_quantizer.transformations import insert_hadamard_rotation
 from ai_edge_quantizer.transformations import quant_insert
 from ai_edge_quantizer.transformations import quantize_tensor
 from ai_edge_quantizer.transformations import transformation_utils
```
```diff
@@ -80,6 +81,9 @@ class TransformationPerformer:
         qtyping.QuantTransformation.DUPLICATE_TENSOR: (
             duplicate_tensor.duplicate_tensor
         ),
+        qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION: (
+            insert_hadamard_rotation.insert_hadamard_rotation
+        ),
     }
     # transformations are seprated in two categories:
     # op_insertion_transformations are transformations that only insert ops
```
```diff
@@ -91,6 +95,7 @@ class TransformationPerformer:
         qtyping.QuantTransformation.ADD_QUANTIZE,
         qtyping.QuantTransformation.DUPLICATE_BUFFER,
         qtyping.QuantTransformation.DUPLICATE_TENSOR,
+        qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
     ])
     self._op_replacement_transformations = set(
         [qtyping.QuantTransformation.EMULATED_SUBCHANNEL]
```
ai_edge_quantizer/transformations/insert_hadamard_rotation.py ADDED

```diff
@@ -0,0 +1,209 @@
+# Copyright 2024 The AI Edge Quantizer Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Hadamard rotation pattern transformation."""
+
+from flatbuffers import flexbuffers
+import numpy as np
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.transformations import transformation_utils
+from ai_edge_litert import schema_py_generated  # pylint: disable=g-direct-tensorflow-import
+
+
+def _to_flexbuffer(
+    hadamard_size: int,
+    random_binary_vector: list[np.int8],
+) -> bytes:
+  """Converts hadamard_size to flexbuffer."""
+  fbb = flexbuffers.Builder()
+  with fbb.Map():
+    fbb.Int('hadamard_size', hadamard_size)
+    fbb.VectorFromElements('random_binary_vector', random_binary_vector)
+  return fbb.Finish()
+
+
+def _is_producer_embedding_lookup(
+    transformation: transformation_utils.TransformationInput,
+) -> bool:
+  """Checks if the tensor's producer is an embedding lookup op."""
+  if transformation.producer == -1:
+    return False
+  else:
+    return (
+        transformation.op_codes[
+            transformation.subgraph.operators[
+                transformation.producer
+            ].opcodeIndex
+        ].builtinCode
+        == schema_py_generated.BuiltinOperator.EMBEDDING_LOOKUP
+    )
+
+
+def _is_fully_connected(
+    transformation: transformation_utils.TransformationInput, op_id: int
+) -> bool:
+  """Checks if the any of the tensor's consumers is a fully connected op."""
+  return (
+      transformation.op_codes[
+          transformation.subgraph.operators[op_id].opcodeIndex
+      ].builtinCode
+      == schema_py_generated.BuiltinOperator.FULLY_CONNECTED
+  )
+
+
+def _update_embedding_lookup_consumers(
+    transformation: transformation_utils.TransformationInput,
+    new_tensor_id: int,
+) -> bool:
+  """Updates the consumers of the embedding lookup op to use the new tensor.
+
+  Args:
+    transformation: The transformation input to update the consumers of.
+    new_tensor_id: The new tensor id to use as the input to the embedding lookup
+      consumers.
+  """
+  for consumer in transformation.consumers:
+    # If the consumer is a graph output and not an op, we can ignore it here
+    # since the graph output will be updated later.
+    if consumer == -1:
+      continue
+    consumer_op = transformation.subgraph.operators[consumer]
+    # Find the input that was attached to the insertion point, and replace it
+    # with the new tensor.
+    for i in range(len(consumer_op.inputs)):
+      if consumer_op.inputs[i] == transformation.tensor_id:
+        consumer_op.inputs[i] = new_tensor_id
+
+
+def _update_fully_connected_consumers(
+    transformation: transformation_utils.TransformationInput,
+    new_tensor_id: int,
+) -> bool:
+  """Updates the fully connected op(s) to use the new tensor.
+
+  Since the new tensor is inserted to the fully_connected's input, we need to
+  scan each consumer (in case of multiple fully_connected ops), and update
+  the input tensor to the new tensor.
+
+  Args:
+    transformation: The transformation input to update the consumers of.
+    new_tensor_id: The new tensor id to use as the input to the fully connected
+      consumers.
+
+  Returns:
+    True if the fully connected op(s) were updated to use the new tensor.
+  """
+  updated = False
+  for consumer in transformation.consumers:
+    if _is_fully_connected(transformation, consumer):
+      transformation.subgraph.operators[consumer].inputs[0] = new_tensor_id
+      updated = True
+  return updated
+
+
+def insert_hadamard_rotation(
+    transformation_input: transformation_utils.TransformationInput,
+) -> qtyping.TransformationInfo:
+  """Inserts a custom aeq.hadamard_rotation op on this tensor.
+
+  This function works for float32 tensors only.
+
+  Args:
+    transformation_input: The transformation input to insert the custom op on.
+
+  Returns:
+    The transformation info of the inserted custom op.
+
+  Raises:
+    ValueError: If the transformation input is not a uniform quantization
+      transformation.
+    ValueError: If the Hadamard quantization params are not set.
+    ValueError: If the tensor is not a float32 tensor.
+    ValueError: If no supported ops were found as the tensor's producer or
+      consumers.
+  """
+  if not isinstance(
+      transformation_input.quant_params, qtyping.UniformQuantParams
+  ):
+    raise ValueError('Hadamard rotation supports uniform quantization only')
+
+  if transformation_input.quant_params.hadamard is None:
+    raise ValueError(
+        'Hadamard rotation quantization params are not set but op insertion is'
+        ' requested.'
+    )
+
+  tensor = transformation_input.subgraph.tensors[transformation_input.tensor_id]
+  if tensor.type != schema_py_generated.TensorType.FLOAT32:
+    raise ValueError(
+        'The Hadamard rotation op supports float32 tensors only. Got'
+        f' {tensor.type} tensor.'
+    )
+
+  # Create new custom op with the current tensor as input and a new activation
+  # tensor as output.
+  custom_op_code_idx = transformation_utils.add_op_code(
+      schema_py_generated.BuiltinOperator.CUSTOM,
+      transformation_input.op_codes,
+      'aeq.hadamard_rotation',
+  )
+  custom_op = schema_py_generated.OperatorT()
+  custom_op.opcodeIndex = custom_op_code_idx
+  custom_op.inputs = [transformation_input.tensor_id]
+  custom_op.customOptions = _to_flexbuffer(
+      transformation_input.quant_params.hadamard.hadamard_size,
+      transformation_input.quant_params.hadamard.random_binary_vector.tolist(),
+  )
+  new_tensor_id = transformation_utils.add_new_activation_tensor(
+      tensor.name + b'_rotated',
+      tensor.shapeSignature
+      if tensor.shapeSignature is not None
+      else tensor.shape,
+      schema_py_generated.TensorType.FLOAT32,
+      transformation_input.subgraph,
+  )
+  custom_op.outputs = [new_tensor_id]
+
+  # Update the users of this tensor to use the new tensor.
+  if _is_producer_embedding_lookup(transformation_input):
+    _update_embedding_lookup_consumers(transformation_input, new_tensor_id)
+  elif not _update_fully_connected_consumers(
+      transformation_input, new_tensor_id
+  ):
+    raise ValueError(
+        'The Hadamard rotation op supports embedding lookup and fully connected'
+        ' ops only, but no such ops were found.'
+    )
+
+  # If the tensor is a graph output, we need to replace the tensor with the
+  # new tensor.
+  for i, output in enumerate(transformation_input.subgraph.outputs):
+    if output == transformation_input.tensor_id:
+      transformation_input.subgraph.outputs[i] = new_tensor_id
+
+  # Find the actual insertion point. The insertion point should be after the
+  # producer op and before the first consumer op. The max() operation ensures
+  # that we're not using -1 as the insertion point.
+  first_consumer_id = min(transformation_input.consumers)
+  op_id = max(transformation_input.producer + 1, first_consumer_id)
+
+  # Insert the custom op.
+  transformation_input.subgraph.operators.insert(op_id, custom_op)
+
+  return qtyping.TransformationInfo(
+      op_id=op_id,
+      num_ops_added=1,
+      output_tensor_id=new_tensor_id,
+  )
```