ai-edge-quantizer-nightly 0.4.0.dev20250930__py3-none-any.whl → 0.4.0.dev20251002__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +40 -3
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +28 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +77 -8
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +69 -4
- ai_edge_quantizer/default_policy.py +4 -2
- ai_edge_quantizer/params_generator.py +1 -0
- ai_edge_quantizer/qtyping.py +5 -0
- ai_edge_quantizer/transformation_performer.py +5 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +291 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +8 -31
- ai_edge_quantizer/transformations/quantize_tensor.py +11 -31
- ai_edge_quantizer/transformations/transformation_utils.py +66 -0
- ai_edge_quantizer/utils/constrained_ops_utils_test.py +1 -1
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +1 -0
- ai_edge_quantizer/utils/validation_utils.py +29 -0
- ai_edge_quantizer/utils/validation_utils_test.py +24 -0
- {ai_edge_quantizer_nightly-0.4.0.dev20250930.dist-info → ai_edge_quantizer_nightly-0.4.0.dev20251002.dist-info}/METADATA +1 -1
- {ai_edge_quantizer_nightly-0.4.0.dev20250930.dist-info → ai_edge_quantizer_nightly-0.4.0.dev20251002.dist-info}/RECORD +22 -20
- {ai_edge_quantizer_nightly-0.4.0.dev20250930.dist-info → ai_edge_quantizer_nightly-0.4.0.dev20251002.dist-info}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.4.0.dev20250930.dist-info → ai_edge_quantizer_nightly-0.4.0.dev20251002.dist-info}/WHEEL +0 -0
- {ai_edge_quantizer_nightly-0.4.0.dev20250930.dist-info → ai_edge_quantizer_nightly-0.4.0.dev20251002.dist-info}/top_level.txt +0 -0
|
@@ -61,7 +61,8 @@ class AlgorithmName(str, enum.Enum):
|
|
|
61
61
|
FLOAT_CASTING = float_casting.ALGORITHM_KEY
|
|
62
62
|
DEQUANTIZED_WEIGHT_RECOVERY = dequantized_weight_recovery.ALGORITHM_KEY
|
|
63
63
|
OCTAV = octav.ALGORITHM_KEY
|
|
64
|
-
HADAMARD_ROTATION = hadamard_rotation.ALGORITHM_KEY
|
|
64
|
+
HADAMARD_ROTATION = hadamard_rotation.CUSTOM_OP_ALGORITHM_KEY
|
|
65
|
+
DECOMPOSED_HADAMARD_ROTATION = hadamard_rotation.DECOMPOSED_ALGORITHM_KEY
|
|
65
66
|
MSE = mse.ALGORITHM_KEY
|
|
66
67
|
|
|
67
68
|
|
|
@@ -130,6 +131,7 @@ MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT = {
|
|
|
130
131
|
_TFLOpName.REDUCE_MIN: common_quantize.materialize_reduce_min,
|
|
131
132
|
_TFLOpName.EQUAL: common_quantize.materialize_equal,
|
|
132
133
|
_TFLOpName.NOT_EQUAL: common_quantize.materialize_not_equal,
|
|
134
|
+
_TFLOpName.MIRROR_PAD: common_quantize.materialize_mirror_pad,
|
|
133
135
|
}
|
|
134
136
|
for op_name, materialize_func in MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT.items():
|
|
135
137
|
register_quantized_op(
|
|
@@ -283,6 +285,7 @@ _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
|
|
|
283
285
|
_TFLOpName.REDUCE_MIN: common_quantize.materialize_reduce_min,
|
|
284
286
|
_TFLOpName.EQUAL: common_quantize.materialize_equal,
|
|
285
287
|
_TFLOpName.NOT_EQUAL: common_quantize.materialize_not_equal,
|
|
288
|
+
_TFLOpName.MIRROR_PAD: common_quantize.materialize_mirror_pad,
|
|
286
289
|
})
|
|
287
290
|
|
|
288
291
|
for op_name, materialize_func in _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT.items():
|
|
@@ -311,8 +314,12 @@ register_config_check_policy_func(
|
|
|
311
314
|
|
|
312
315
|
# Register specialized hadamard rotation materialize functions.
|
|
313
316
|
_HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
|
|
314
|
-
_TFLOpName.FULLY_CONNECTED:
|
|
315
|
-
|
|
317
|
+
_TFLOpName.FULLY_CONNECTED: (
|
|
318
|
+
hadamard_rotation.materialize_fully_connected_custom_op
|
|
319
|
+
),
|
|
320
|
+
_TFLOpName.EMBEDDING_LOOKUP: (
|
|
321
|
+
hadamard_rotation.materialize_embedding_lookup_custom_op
|
|
322
|
+
),
|
|
316
323
|
})
|
|
317
324
|
for (
|
|
318
325
|
op_name,
|
|
@@ -326,6 +333,36 @@ for (
|
|
|
326
333
|
materialize_func=materialize_func,
|
|
327
334
|
)
|
|
328
335
|
|
|
336
|
+
register_op_quant_config_validation_func(
|
|
337
|
+
AlgorithmName.DECOMPOSED_HADAMARD_ROTATION,
|
|
338
|
+
common_quantize.check_op_quantization_config,
|
|
339
|
+
)
|
|
340
|
+
|
|
341
|
+
register_config_check_policy_func(
|
|
342
|
+
AlgorithmName.DECOMPOSED_HADAMARD_ROTATION,
|
|
343
|
+
default_policy.DEFAULT_CONFIG_CHECK_POLICY,
|
|
344
|
+
)
|
|
345
|
+
|
|
346
|
+
_DECOMPOSED_HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
|
|
347
|
+
_TFLOpName.FULLY_CONNECTED: (
|
|
348
|
+
hadamard_rotation.materialize_fully_connected_decomposed
|
|
349
|
+
),
|
|
350
|
+
_TFLOpName.EMBEDDING_LOOKUP: (
|
|
351
|
+
hadamard_rotation.materialize_embedding_lookup_decomposed
|
|
352
|
+
),
|
|
353
|
+
})
|
|
354
|
+
for (
|
|
355
|
+
op_name,
|
|
356
|
+
materialize_func,
|
|
357
|
+
) in _DECOMPOSED_HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT.items():
|
|
358
|
+
register_quantized_op(
|
|
359
|
+
AlgorithmName.DECOMPOSED_HADAMARD_ROTATION,
|
|
360
|
+
op_name,
|
|
361
|
+
naive_min_max_quantize.init_qsvs,
|
|
362
|
+
calibration_func=naive_min_max_quantize.min_max_calibrate,
|
|
363
|
+
materialize_func=materialize_func,
|
|
364
|
+
)
|
|
365
|
+
|
|
329
366
|
|
|
330
367
|
# Register the MSE algorithm.
|
|
331
368
|
register_op_quant_config_validation_func(
|
|
@@ -748,6 +748,34 @@ def materialize_padv2(
|
|
|
748
748
|
)
|
|
749
749
|
|
|
750
750
|
|
|
751
|
+
def materialize_mirror_pad(
|
|
752
|
+
get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
|
|
753
|
+
op_info: qtyping.OpInfo,
|
|
754
|
+
graph_info: qtyping.GraphInfo,
|
|
755
|
+
tensor_name_to_qsv: dict[str, Any],
|
|
756
|
+
) -> list[qtyping.TensorTransformationParams]:
|
|
757
|
+
"""Materialize tensors in tfl.mirror_pad.
|
|
758
|
+
|
|
759
|
+
Args:
|
|
760
|
+
get_tensor_quant_params_fn: Function to get quantization parameters for the
|
|
761
|
+
tensor.
|
|
762
|
+
op_info: Aggregated information about the op (e.g., quantization config).
|
|
763
|
+
graph_info: Graph information needed to perform quantization for the op.
|
|
764
|
+
tensor_name_to_qsv: A map of tensor name to quantization parameters.
|
|
765
|
+
|
|
766
|
+
Returns:
|
|
767
|
+
A list of `qtyping.TensorTransformationParams` for the tensors in the op.
|
|
768
|
+
"""
|
|
769
|
+
return common_utils.materialize_standard_op(
|
|
770
|
+
op_info,
|
|
771
|
+
graph_info,
|
|
772
|
+
tensor_name_to_qsv,
|
|
773
|
+
get_tensor_quant_params_fn,
|
|
774
|
+
constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
|
|
775
|
+
inputs_to_ignore=[1], # Paddings tensor does not need to be quantized.
|
|
776
|
+
)
|
|
777
|
+
|
|
778
|
+
|
|
751
779
|
def materialize_squared_difference(
|
|
752
780
|
get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
|
|
753
781
|
op_info: qtyping.OpInfo,
|
|
@@ -23,16 +23,17 @@ from ai_edge_quantizer.algorithms.utils import common_utils
|
|
|
23
23
|
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
|
|
26
|
+
CUSTOM_OP_ALGORITHM_KEY = "HADAMARD_ROTATION"
|
|
27
|
+
DECOMPOSED_ALGORITHM_KEY = "DECOMPOSED_HADAMARD_ROTATION"
|
|
27
28
|
|
|
28
29
|
|
|
29
30
|
def _make_hadamard_matrix(size: int) -> np.ndarray:
|
|
30
31
|
"""Generates a Hadamard matrix of the given size.
|
|
31
32
|
|
|
32
33
|
Args:
|
|
33
|
-
size: The size of the Hadamard matrix. Must be a power of 2. This
|
|
34
|
-
|
|
35
|
-
|
|
34
|
+
size: The size of the Hadamard matrix. Must be a power of 2. This represents
|
|
35
|
+
a single dimension. E.g. if size is 4, then the Hadamard matrix is a 4x4
|
|
36
|
+
matrix.
|
|
36
37
|
|
|
37
38
|
Returns:
|
|
38
39
|
The Hadamard matrix.
|
|
@@ -157,9 +158,10 @@ def get_tensor_quant_params(
|
|
|
157
158
|
)
|
|
158
159
|
|
|
159
160
|
|
|
160
|
-
def materialize_fully_connected(
|
|
161
|
+
def _materialize_fully_connected(
|
|
161
162
|
op_info: qtyping.OpInfo,
|
|
162
163
|
graph_info: qtyping.GraphInfo,
|
|
164
|
+
is_decomposed: bool = False,
|
|
163
165
|
tensor_name_to_qsv: Optional[dict[str, Any]] = None, # pylint: disable=unused-argument
|
|
164
166
|
) -> list[qtyping.TensorTransformationParams]:
|
|
165
167
|
"""Materialize the fully_connected op.
|
|
@@ -167,12 +169,20 @@ def materialize_fully_connected(
|
|
|
167
169
|
Args:
|
|
168
170
|
op_info: Aggregated information about the op (e.g., quantization config).
|
|
169
171
|
graph_info: Graph information needed to perform quantization for the op.
|
|
172
|
+
is_decomposed: Whether to use decomposed Hadamard rotation ops or a custom
|
|
173
|
+
op.
|
|
170
174
|
tensor_name_to_qsv: A map of tensor name to quantization parameters.
|
|
171
175
|
|
|
172
176
|
Returns:
|
|
173
177
|
Quantization configuration for the tensors associated with the op (e.g.,
|
|
174
178
|
weights, bias).
|
|
175
179
|
"""
|
|
180
|
+
if op_info.op_quant_config.weight_tensor_config is None:
|
|
181
|
+
raise ValueError(
|
|
182
|
+
"Weight tensor quantization config is not provided for Hadamard"
|
|
183
|
+
" Rotation quantization."
|
|
184
|
+
)
|
|
185
|
+
|
|
176
186
|
op_tensor_params = []
|
|
177
187
|
|
|
178
188
|
# Materialize weight.
|
|
@@ -209,7 +219,9 @@ def materialize_fully_connected(
|
|
|
209
219
|
op_info.op.inputs[input_tensor_index]
|
|
210
220
|
]
|
|
211
221
|
transformations = [
|
|
212
|
-
qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
|
|
222
|
+
qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION
|
|
223
|
+
if is_decomposed
|
|
224
|
+
else qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
|
|
213
225
|
]
|
|
214
226
|
op2tensor_params = qtyping.OpToTensorParams(
|
|
215
227
|
subgraph_op_id=op_info.subgraph_op_index,
|
|
@@ -258,16 +270,45 @@ def materialize_fully_connected(
|
|
|
258
270
|
return op_tensor_params
|
|
259
271
|
|
|
260
272
|
|
|
261
|
-
def materialize_embedding_lookup(
|
|
273
|
+
def materialize_fully_connected_custom_op(
|
|
274
|
+
op_info: qtyping.OpInfo,
|
|
275
|
+
graph_info: qtyping.GraphInfo,
|
|
276
|
+
tensor_name_to_qsv: Optional[dict[str, Any]] = None, # pylint: disable=unused-argument
|
|
277
|
+
) -> list[qtyping.TensorTransformationParams]:
|
|
278
|
+
return _materialize_fully_connected(
|
|
279
|
+
op_info,
|
|
280
|
+
graph_info,
|
|
281
|
+
is_decomposed=False,
|
|
282
|
+
tensor_name_to_qsv=tensor_name_to_qsv,
|
|
283
|
+
)
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def materialize_fully_connected_decomposed(
|
|
262
287
|
op_info: qtyping.OpInfo,
|
|
263
288
|
graph_info: qtyping.GraphInfo,
|
|
264
289
|
tensor_name_to_qsv: Optional[dict[str, Any]] = None, # pylint: disable=unused-argument
|
|
290
|
+
) -> list[qtyping.TensorTransformationParams]:
|
|
291
|
+
return _materialize_fully_connected(
|
|
292
|
+
op_info,
|
|
293
|
+
graph_info,
|
|
294
|
+
is_decomposed=True,
|
|
295
|
+
tensor_name_to_qsv=tensor_name_to_qsv,
|
|
296
|
+
)
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def _materialize_embedding_lookup(
|
|
300
|
+
op_info: qtyping.OpInfo,
|
|
301
|
+
graph_info: qtyping.GraphInfo,
|
|
302
|
+
is_decomposed: bool = False,
|
|
303
|
+
tensor_name_to_qsv: Optional[dict[str, Any]] = None, # pylint: disable=unused-argument
|
|
265
304
|
) -> list[qtyping.TensorTransformationParams]:
|
|
266
305
|
"""Materialize the embedding_lookup op.
|
|
267
306
|
|
|
268
307
|
Args:
|
|
269
308
|
op_info: Aggregated information about the op (e.g., quantization config).
|
|
270
309
|
graph_info: Graph information needed to perform quantization for the op.
|
|
310
|
+
is_decomposed: Whether to use decomposed Hadamard rotation ops or a custom
|
|
311
|
+
op.
|
|
271
312
|
tensor_name_to_qsv: A map of tensor name to quantization parameters.
|
|
272
313
|
|
|
273
314
|
Returns:
|
|
@@ -329,7 +370,9 @@ def materialize_embedding_lookup(
|
|
|
329
370
|
op_info.op.outputs[output_tensor_index]
|
|
330
371
|
]
|
|
331
372
|
transformations = [
|
|
332
|
-
qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
|
|
373
|
+
qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION
|
|
374
|
+
if is_decomposed
|
|
375
|
+
else qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
|
|
333
376
|
]
|
|
334
377
|
op2tensor_params = qtyping.OpToTensorParams(
|
|
335
378
|
subgraph_op_id=op_info.subgraph_op_index,
|
|
@@ -343,3 +386,29 @@ def materialize_embedding_lookup(
|
|
|
343
386
|
op_tensor_params.append(output_transformation_params)
|
|
344
387
|
|
|
345
388
|
return op_tensor_params
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
def materialize_embedding_lookup_custom_op(
|
|
392
|
+
op_info: qtyping.OpInfo,
|
|
393
|
+
graph_info: qtyping.GraphInfo,
|
|
394
|
+
tensor_name_to_qsv: Optional[dict[str, Any]] = None, # pylint: disable=unused-argument
|
|
395
|
+
) -> list[qtyping.TensorTransformationParams]:
|
|
396
|
+
return _materialize_embedding_lookup(
|
|
397
|
+
op_info,
|
|
398
|
+
graph_info,
|
|
399
|
+
is_decomposed=False,
|
|
400
|
+
tensor_name_to_qsv=tensor_name_to_qsv,
|
|
401
|
+
)
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
def materialize_embedding_lookup_decomposed(
|
|
405
|
+
op_info: qtyping.OpInfo,
|
|
406
|
+
graph_info: qtyping.GraphInfo,
|
|
407
|
+
tensor_name_to_qsv: Optional[dict[str, Any]] = None, # pylint: disable=unused-argument
|
|
408
|
+
) -> list[qtyping.TensorTransformationParams]:
|
|
409
|
+
return _materialize_embedding_lookup(
|
|
410
|
+
op_info,
|
|
411
|
+
graph_info,
|
|
412
|
+
is_decomposed=True,
|
|
413
|
+
tensor_name_to_qsv=tensor_name_to_qsv,
|
|
414
|
+
)
|
|
@@ -63,7 +63,7 @@ class HadamardRotationFullyConnectedTest(parameterized.TestCase):
|
|
|
63
63
|
)
|
|
64
64
|
|
|
65
65
|
def test_materialize_fully_connected_basic(self):
|
|
66
|
-
params = hadamard_rotation.materialize_fully_connected(
|
|
66
|
+
params = hadamard_rotation.materialize_fully_connected_custom_op(
|
|
67
67
|
self._op_info, self._graph_info, self._tensor_name_to_qsv
|
|
68
68
|
)
|
|
69
69
|
fc_input = params[0]
|
|
@@ -111,7 +111,7 @@ class HadamardRotationFullyConnectedTest(parameterized.TestCase):
|
|
|
111
111
|
),
|
|
112
112
|
),
|
|
113
113
|
)
|
|
114
|
-
params = hadamard_rotation.materialize_fully_connected(
|
|
114
|
+
params = hadamard_rotation.materialize_fully_connected_custom_op(
|
|
115
115
|
self._op_info, self._graph_info, self._tensor_name_to_qsv
|
|
116
116
|
)
|
|
117
117
|
self.assertLen(params, 4)
|
|
@@ -152,7 +152,7 @@ class HadamardRotationFullyConnectedTest(parameterized.TestCase):
|
|
|
152
152
|
),
|
|
153
153
|
),
|
|
154
154
|
)
|
|
155
|
-
params = hadamard_rotation.materialize_fully_connected(
|
|
155
|
+
params = hadamard_rotation.materialize_fully_connected_custom_op(
|
|
156
156
|
self._op_info, self._graph_info, self._tensor_name_to_qsv
|
|
157
157
|
)
|
|
158
158
|
self.assertLen(params, 4)
|
|
@@ -179,6 +179,34 @@ class HadamardRotationFullyConnectedTest(parameterized.TestCase):
|
|
|
179
179
|
):
|
|
180
180
|
self.assertEqual(weight.consumers[0].parameters.quantized_dimension, 1)
|
|
181
181
|
|
|
182
|
+
def test_materialize_fully_connected_decomposed(self):
|
|
183
|
+
params = hadamard_rotation.materialize_fully_connected_decomposed(
|
|
184
|
+
self._op_info, self._graph_info, self._tensor_name_to_qsv
|
|
185
|
+
)
|
|
186
|
+
fc_input = params[0]
|
|
187
|
+
weight = params[1]
|
|
188
|
+
bias = params[2]
|
|
189
|
+
output = params[3]
|
|
190
|
+
|
|
191
|
+
self.assertLen(params, 4)
|
|
192
|
+
self.assertEqual(
|
|
193
|
+
fc_input.consumers[0].transformations,
|
|
194
|
+
[qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION],
|
|
195
|
+
)
|
|
196
|
+
self.assertEqual(
|
|
197
|
+
weight.consumers[0].transformations,
|
|
198
|
+
[qtyping.QuantTransformation.QUANTIZE_TENSOR],
|
|
199
|
+
)
|
|
200
|
+
self.assertEqual(
|
|
201
|
+
bias.consumers[0].transformations,
|
|
202
|
+
[qtyping.QuantTransformation.NO_QUANTIZE],
|
|
203
|
+
)
|
|
204
|
+
if output.producer is not None:
|
|
205
|
+
self.assertEqual(
|
|
206
|
+
output.producer.transformations,
|
|
207
|
+
[qtyping.QuantTransformation.NO_QUANTIZE],
|
|
208
|
+
)
|
|
209
|
+
|
|
182
210
|
def test_get_tensor_quant_params_basic(self):
|
|
183
211
|
input_tensor = self._subgraph.tensors[self._fc_op.inputs[1]]
|
|
184
212
|
buffer = self._graph_info.buffers[self._fc_buffer_id]
|
|
@@ -344,7 +372,7 @@ class HadamardRotationEmbeddingLookupTest(parameterized.TestCase):
|
|
|
344
372
|
),
|
|
345
373
|
),
|
|
346
374
|
)
|
|
347
|
-
params = hadamard_rotation.materialize_embedding_lookup(
|
|
375
|
+
params = hadamard_rotation.materialize_embedding_lookup_custom_op(
|
|
348
376
|
op_info, self._graph_info, self._tensor_name_to_qsv
|
|
349
377
|
)
|
|
350
378
|
self.assertLen(params, 3)
|
|
@@ -371,6 +399,43 @@ class HadamardRotationEmbeddingLookupTest(parameterized.TestCase):
|
|
|
371
399
|
[qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION],
|
|
372
400
|
)
|
|
373
401
|
|
|
402
|
+
def test_materialize_embedding_lookup_decomposed(self):
|
|
403
|
+
subgraph = self._test_model.subgraphs[0]
|
|
404
|
+
embedding_subgraph_op_index = 0
|
|
405
|
+
embedding_op = subgraph.operators[embedding_subgraph_op_index]
|
|
406
|
+
op_info = qtyping.OpInfo(
|
|
407
|
+
op=embedding_op,
|
|
408
|
+
op_name=_TFLOpName.EMBEDDING_LOOKUP,
|
|
409
|
+
subgraph_op_index=embedding_subgraph_op_index,
|
|
410
|
+
op_quant_config=qtyping.OpQuantizationConfig(
|
|
411
|
+
weight_tensor_config=_TensorQuantConfig(
|
|
412
|
+
num_bits=8,
|
|
413
|
+
symmetric=True,
|
|
414
|
+
granularity=qtyping.QuantGranularity.CHANNELWISE,
|
|
415
|
+
),
|
|
416
|
+
),
|
|
417
|
+
)
|
|
418
|
+
params = hadamard_rotation.materialize_embedding_lookup_decomposed(
|
|
419
|
+
op_info, self._graph_info, self._tensor_name_to_qsv
|
|
420
|
+
)
|
|
421
|
+
self.assertLen(params, 3)
|
|
422
|
+
lookup = params[0]
|
|
423
|
+
value = params[1]
|
|
424
|
+
output = params[2]
|
|
425
|
+
self.assertEqual(
|
|
426
|
+
lookup.consumers[0].transformations,
|
|
427
|
+
[qtyping.QuantTransformation.NO_QUANTIZE],
|
|
428
|
+
)
|
|
429
|
+
self.assertEqual(
|
|
430
|
+
value.consumers[0].transformations,
|
|
431
|
+
[qtyping.QuantTransformation.QUANTIZE_TENSOR],
|
|
432
|
+
)
|
|
433
|
+
if output.producer is not None:
|
|
434
|
+
self.assertEqual(
|
|
435
|
+
output.producer.transformations,
|
|
436
|
+
[qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION],
|
|
437
|
+
)
|
|
438
|
+
|
|
374
439
|
|
|
375
440
|
if __name__ == "__main__":
|
|
376
441
|
googletest.main()
|
|
@@ -199,7 +199,8 @@ DEFAULT_JSON_POLICY = """
|
|
|
199
199
|
"PADV2",
|
|
200
200
|
"REDUCE_MIN",
|
|
201
201
|
"EQUAL",
|
|
202
|
-
"NOT_EQUAL"
|
|
202
|
+
"NOT_EQUAL",
|
|
203
|
+
"MIRROR_PAD"
|
|
203
204
|
],
|
|
204
205
|
"static_wi8_ai8": [
|
|
205
206
|
"ADD",
|
|
@@ -248,7 +249,8 @@ DEFAULT_JSON_POLICY = """
|
|
|
248
249
|
"PADV2",
|
|
249
250
|
"REDUCE_MIN",
|
|
250
251
|
"EQUAL",
|
|
251
|
-
"NOT_EQUAL"
|
|
252
|
+
"NOT_EQUAL",
|
|
253
|
+
"MIRROR_PAD"
|
|
252
254
|
],
|
|
253
255
|
"static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
|
|
254
256
|
"static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
|
|
@@ -510,6 +510,7 @@ def _compatible_tensor_params(
|
|
|
510
510
|
_QuantTrans.ADD_QUANTIZE,
|
|
511
511
|
_QuantTrans.NO_QUANTIZE,
|
|
512
512
|
_QuantTrans.INSERT_HADAMARD_ROTATION,
|
|
513
|
+
_QuantTrans.INSERT_DECOMPOSED_HADAMARD_ROTATION,
|
|
513
514
|
]
|
|
514
515
|
quantized_source_transformations = [
|
|
515
516
|
_QuantTrans.QUANTIZE_TENSOR,
|
ai_edge_quantizer/qtyping.py
CHANGED
|
@@ -80,6 +80,7 @@ class TFLOperationName(str, enum.Enum):
|
|
|
80
80
|
REDUCE_MIN = 'REDUCE_MIN'
|
|
81
81
|
EQUAL = 'EQUAL'
|
|
82
82
|
NOT_EQUAL = 'NOT_EQUAL'
|
|
83
|
+
MIRROR_PAD = 'MIRROR_PAD'
|
|
83
84
|
|
|
84
85
|
|
|
85
86
|
class QuantizeMode(enum.Enum):
|
|
@@ -133,6 +134,9 @@ class QuantTransformation(enum.Enum):
|
|
|
133
134
|
DUPLICATE_TENSOR = 6
|
|
134
135
|
# Insert the aeq.hadamard_rotation op.
|
|
135
136
|
INSERT_HADAMARD_ROTATION = 7
|
|
137
|
+
# Insert decomposed Hadamard rotation ops. This expresses the Hadamard
|
|
138
|
+
# rotation as matrix multiplication with Hadamard matrices.
|
|
139
|
+
INSERT_DECOMPOSED_HADAMARD_ROTATION = 8
|
|
136
140
|
|
|
137
141
|
|
|
138
142
|
@dataclasses.dataclass(frozen=True)
|
|
@@ -305,6 +309,7 @@ class TensorQuantizationConfig:
|
|
|
305
309
|
quantization.
|
|
306
310
|
dtype: The data type of the tensor.
|
|
307
311
|
block_size: The block size for blockwise quantization, ignored otherwise.
|
|
312
|
+
algorithm_key: The algorithm key to use for quantization.
|
|
308
313
|
"""
|
|
309
314
|
|
|
310
315
|
num_bits: int
|
|
@@ -24,6 +24,7 @@ from ai_edge_quantizer import qtyping
|
|
|
24
24
|
from ai_edge_quantizer.transformations import dequant_insert
|
|
25
25
|
from ai_edge_quantizer.transformations import duplicate_buffer
|
|
26
26
|
from ai_edge_quantizer.transformations import duplicate_tensor
|
|
27
|
+
from ai_edge_quantizer.transformations import insert_decomposed_hadamard_rotation
|
|
27
28
|
from ai_edge_quantizer.transformations import insert_hadamard_rotation
|
|
28
29
|
from ai_edge_quantizer.transformations import quant_insert
|
|
29
30
|
from ai_edge_quantizer.transformations import quantize_tensor
|
|
@@ -83,6 +84,9 @@ class TransformationPerformer:
|
|
|
83
84
|
qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION: (
|
|
84
85
|
insert_hadamard_rotation.insert_hadamard_rotation
|
|
85
86
|
),
|
|
87
|
+
qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION: (
|
|
88
|
+
insert_decomposed_hadamard_rotation.insert_decomposed_hadamard_rotation
|
|
89
|
+
),
|
|
86
90
|
}
|
|
87
91
|
# transformations are seprated in two categories:
|
|
88
92
|
# op_insertion_transformations are transformations that only insert ops
|
|
@@ -95,6 +99,7 @@ class TransformationPerformer:
|
|
|
95
99
|
qtyping.QuantTransformation.DUPLICATE_BUFFER,
|
|
96
100
|
qtyping.QuantTransformation.DUPLICATE_TENSOR,
|
|
97
101
|
qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
|
|
102
|
+
qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION,
|
|
98
103
|
])
|
|
99
104
|
self._op_replacement_transformations = set(
|
|
100
105
|
[qtyping.QuantTransformation.EMULATED_SUBCHANNEL]
|