ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +158 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
- ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
- ai_edge_quantizer/calibrator.py +11 -60
- ai_edge_quantizer/calibrator_test.py +4 -73
- ai_edge_quantizer/default_policy.py +61 -26
- ai_edge_quantizer/model_modifier.py +97 -7
- ai_edge_quantizer/model_modifier_test.py +81 -1
- ai_edge_quantizer/model_validator.py +31 -8
- ai_edge_quantizer/params_generator.py +17 -10
- ai_edge_quantizer/params_generator_test.py +2 -7
- ai_edge_quantizer/qtyping.py +86 -6
- ai_edge_quantizer/quantizer.py +166 -21
- ai_edge_quantizer/quantizer_test.py +284 -16
- ai_edge_quantizer/recipe.py +154 -42
- ai_edge_quantizer/recipe_manager.py +158 -1
- ai_edge_quantizer/recipe_manager_test.py +146 -32
- ai_edge_quantizer/recipe_test.py +93 -17
- ai_edge_quantizer/transformation_instruction_generator.py +118 -13
- ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
- ai_edge_quantizer/transformation_performer.py +55 -25
- ai_edge_quantizer/transformation_performer_test.py +127 -5
- ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
- ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
- ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
- ai_edge_quantizer/transformations/transformation_utils.py +129 -6
- ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
- ai_edge_quantizer/utils/calibration_utils.py +263 -1
- ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
- ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
- ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
- ai_edge_quantizer/utils/test_utils.py +75 -2
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
- ai_edge_quantizer/utils/validation_utils.py +114 -4
- ai_edge_quantizer/utils/validation_utils_test.py +80 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
- ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
- ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
- ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
|
@@ -24,9 +24,12 @@ from ai_edge_quantizer import qtyping
|
|
|
24
24
|
from ai_edge_quantizer.algorithms.nonlinear_quantize import float_casting
|
|
25
25
|
from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
|
|
26
26
|
from ai_edge_quantizer.algorithms.uniform_quantize import dequantized_weight_recovery
|
|
27
|
+
from ai_edge_quantizer.algorithms.uniform_quantize import hadamard_rotation
|
|
28
|
+
from ai_edge_quantizer.algorithms.uniform_quantize import mse
|
|
27
29
|
from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
|
|
28
30
|
from ai_edge_quantizer.algorithms.uniform_quantize import octav
|
|
29
31
|
|
|
32
|
+
|
|
30
33
|
# TODO: b/399775701 - Clean up this file.
|
|
31
34
|
|
|
32
35
|
_TFLOpName = qtyping.TFLOperationName
|
|
@@ -58,6 +61,10 @@ class AlgorithmName(str, enum.Enum):
|
|
|
58
61
|
FLOAT_CASTING = float_casting.ALGORITHM_KEY
|
|
59
62
|
DEQUANTIZED_WEIGHT_RECOVERY = dequantized_weight_recovery.ALGORITHM_KEY
|
|
60
63
|
OCTAV = octav.ALGORITHM_KEY
|
|
64
|
+
HADAMARD_ROTATION = hadamard_rotation.CUSTOM_OP_ALGORITHM_KEY
|
|
65
|
+
DECOMPOSED_HADAMARD_ROTATION = hadamard_rotation.DECOMPOSED_ALGORITHM_KEY
|
|
66
|
+
MSE = mse.ALGORITHM_KEY
|
|
67
|
+
|
|
61
68
|
|
|
62
69
|
### MIN/MAX_UNIFORM_QUANT ###
|
|
63
70
|
|
|
@@ -99,11 +106,37 @@ MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT = {
|
|
|
99
106
|
_TFLOpName.LOGISTIC: common_quantize.materialize_softmax_and_logistic,
|
|
100
107
|
_TFLOpName.SLICE: common_quantize.materialize_slice,
|
|
101
108
|
_TFLOpName.SUM: common_quantize.materialize_sum,
|
|
109
|
+
_TFLOpName.SELECT: common_quantize.materialize_select,
|
|
102
110
|
_TFLOpName.SELECT_V2: common_quantize.materialize_select_v2,
|
|
103
111
|
_TFLOpName.DYNAMIC_UPDATE_SLICE: (
|
|
104
112
|
common_quantize.materialize_dynamic_update_slice
|
|
105
113
|
),
|
|
106
114
|
_TFLOpName.STABLEHLO_COMPOSITE: common_quantize.materialize_composite,
|
|
115
|
+
_TFLOpName.PAD: common_quantize.materialize_pad,
|
|
116
|
+
_TFLOpName.SQUARED_DIFFERENCE: (
|
|
117
|
+
common_quantize.materialize_squared_difference
|
|
118
|
+
),
|
|
119
|
+
_TFLOpName.MAX_POOL_2D: common_quantize.materialize_max_pool_2d,
|
|
120
|
+
_TFLOpName.RESIZE_BILINEAR: common_quantize.materialize_resize_bilinear,
|
|
121
|
+
_TFLOpName.RESIZE_NEAREST_NEIGHBOR: (
|
|
122
|
+
common_quantize.materialize_resize_nearest_neighbor
|
|
123
|
+
),
|
|
124
|
+
_TFLOpName.GATHER_ND: common_quantize.materialize_gather_nd,
|
|
125
|
+
_TFLOpName.PACK: common_quantize.materialize_pack,
|
|
126
|
+
_TFLOpName.UNPACK: common_quantize.materialize_unpack,
|
|
127
|
+
_TFLOpName.DIV: common_quantize.materialize_div,
|
|
128
|
+
_TFLOpName.BROADCAST_TO: common_quantize.materialize_broadcast_to,
|
|
129
|
+
_TFLOpName.SQRT: common_quantize.materialize_sqrt,
|
|
130
|
+
_TFLOpName.GATHER: common_quantize.materialize_gather,
|
|
131
|
+
_TFLOpName.HARD_SWISH: common_quantize.materialize_hard_swish,
|
|
132
|
+
_TFLOpName.MAXIMUM: common_quantize.materialize_maximum,
|
|
133
|
+
_TFLOpName.PADV2: common_quantize.materialize_padv2,
|
|
134
|
+
_TFLOpName.REDUCE_MIN: common_quantize.materialize_reduce_min,
|
|
135
|
+
_TFLOpName.EQUAL: common_quantize.materialize_equal,
|
|
136
|
+
_TFLOpName.NOT_EQUAL: common_quantize.materialize_not_equal,
|
|
137
|
+
_TFLOpName.MIRROR_PAD: common_quantize.materialize_mirror_pad,
|
|
138
|
+
_TFLOpName.SPACE_TO_DEPTH: common_quantize.materialize_space_to_depth,
|
|
139
|
+
_TFLOpName.RELU: common_quantize.materialize_relu,
|
|
107
140
|
}
|
|
108
141
|
for op_name, materialize_func in MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT.items():
|
|
109
142
|
register_quantized_op(
|
|
@@ -232,11 +265,37 @@ _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
|
|
|
232
265
|
_TFLOpName.LOGISTIC: common_quantize.materialize_softmax_and_logistic,
|
|
233
266
|
_TFLOpName.SLICE: common_quantize.materialize_slice,
|
|
234
267
|
_TFLOpName.SUM: common_quantize.materialize_sum,
|
|
268
|
+
_TFLOpName.SELECT: common_quantize.materialize_select,
|
|
235
269
|
_TFLOpName.SELECT_V2: common_quantize.materialize_select_v2,
|
|
236
270
|
_TFLOpName.DYNAMIC_UPDATE_SLICE: (
|
|
237
271
|
common_quantize.materialize_dynamic_update_slice
|
|
238
272
|
),
|
|
239
273
|
_TFLOpName.STABLEHLO_COMPOSITE: common_quantize.materialize_composite,
|
|
274
|
+
_TFLOpName.PAD: common_quantize.materialize_pad,
|
|
275
|
+
_TFLOpName.SQUARED_DIFFERENCE: (
|
|
276
|
+
common_quantize.materialize_squared_difference
|
|
277
|
+
),
|
|
278
|
+
_TFLOpName.MAX_POOL_2D: common_quantize.materialize_max_pool_2d,
|
|
279
|
+
_TFLOpName.RESIZE_BILINEAR: common_quantize.materialize_resize_bilinear,
|
|
280
|
+
_TFLOpName.RESIZE_NEAREST_NEIGHBOR: (
|
|
281
|
+
common_quantize.materialize_resize_nearest_neighbor
|
|
282
|
+
),
|
|
283
|
+
_TFLOpName.GATHER_ND: common_quantize.materialize_gather_nd,
|
|
284
|
+
_TFLOpName.PACK: common_quantize.materialize_pack,
|
|
285
|
+
_TFLOpName.UNPACK: common_quantize.materialize_unpack,
|
|
286
|
+
_TFLOpName.DIV: common_quantize.materialize_div,
|
|
287
|
+
_TFLOpName.BROADCAST_TO: common_quantize.materialize_broadcast_to,
|
|
288
|
+
_TFLOpName.SQRT: common_quantize.materialize_sqrt,
|
|
289
|
+
_TFLOpName.GATHER: common_quantize.materialize_gather,
|
|
290
|
+
_TFLOpName.HARD_SWISH: common_quantize.materialize_hard_swish,
|
|
291
|
+
_TFLOpName.MAXIMUM: common_quantize.materialize_maximum,
|
|
292
|
+
_TFLOpName.PADV2: common_quantize.materialize_padv2,
|
|
293
|
+
_TFLOpName.REDUCE_MIN: common_quantize.materialize_reduce_min,
|
|
294
|
+
_TFLOpName.EQUAL: common_quantize.materialize_equal,
|
|
295
|
+
_TFLOpName.NOT_EQUAL: common_quantize.materialize_not_equal,
|
|
296
|
+
_TFLOpName.MIRROR_PAD: common_quantize.materialize_mirror_pad,
|
|
297
|
+
_TFLOpName.SPACE_TO_DEPTH: common_quantize.materialize_space_to_depth,
|
|
298
|
+
_TFLOpName.RELU: common_quantize.materialize_relu,
|
|
240
299
|
})
|
|
241
300
|
|
|
242
301
|
for op_name, materialize_func in _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT.items():
|
|
@@ -250,3 +309,102 @@ for op_name, materialize_func in _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT.items():
|
|
|
250
309
|
octav.get_tensor_quant_params,
|
|
251
310
|
),
|
|
252
311
|
)
|
|
312
|
+
|
|
313
|
+
# Register the Hadamard Rotation algorithm.
|
|
314
|
+
register_op_quant_config_validation_func(
|
|
315
|
+
AlgorithmName.HADAMARD_ROTATION,
|
|
316
|
+
common_quantize.check_op_quantization_config,
|
|
317
|
+
)
|
|
318
|
+
|
|
319
|
+
# Register a config check policy for the Hadamard Rotation algorithm.
|
|
320
|
+
register_config_check_policy_func(
|
|
321
|
+
AlgorithmName.HADAMARD_ROTATION,
|
|
322
|
+
default_policy.DEFAULT_CONFIG_CHECK_POLICY,
|
|
323
|
+
)
|
|
324
|
+
|
|
325
|
+
# Register specialized hadamard rotation materialize functions.
|
|
326
|
+
_HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
|
|
327
|
+
_TFLOpName.FULLY_CONNECTED: (
|
|
328
|
+
hadamard_rotation.materialize_fully_connected_custom_op
|
|
329
|
+
),
|
|
330
|
+
_TFLOpName.EMBEDDING_LOOKUP: (
|
|
331
|
+
hadamard_rotation.materialize_embedding_lookup_custom_op
|
|
332
|
+
),
|
|
333
|
+
})
|
|
334
|
+
for (
|
|
335
|
+
op_name,
|
|
336
|
+
materialize_func,
|
|
337
|
+
) in _HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT.items():
|
|
338
|
+
register_quantized_op(
|
|
339
|
+
AlgorithmName.HADAMARD_ROTATION,
|
|
340
|
+
op_name,
|
|
341
|
+
naive_min_max_quantize.init_qsvs,
|
|
342
|
+
calibration_func=naive_min_max_quantize.min_max_calibrate,
|
|
343
|
+
materialize_func=materialize_func,
|
|
344
|
+
)
|
|
345
|
+
|
|
346
|
+
register_op_quant_config_validation_func(
|
|
347
|
+
AlgorithmName.DECOMPOSED_HADAMARD_ROTATION,
|
|
348
|
+
common_quantize.check_op_quantization_config,
|
|
349
|
+
)
|
|
350
|
+
|
|
351
|
+
register_config_check_policy_func(
|
|
352
|
+
AlgorithmName.DECOMPOSED_HADAMARD_ROTATION,
|
|
353
|
+
default_policy.DEFAULT_CONFIG_CHECK_POLICY,
|
|
354
|
+
)
|
|
355
|
+
|
|
356
|
+
_DECOMPOSED_HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
|
|
357
|
+
_TFLOpName.FULLY_CONNECTED: (
|
|
358
|
+
hadamard_rotation.materialize_fully_connected_decomposed
|
|
359
|
+
),
|
|
360
|
+
_TFLOpName.EMBEDDING_LOOKUP: (
|
|
361
|
+
hadamard_rotation.materialize_embedding_lookup_decomposed
|
|
362
|
+
),
|
|
363
|
+
})
|
|
364
|
+
for (
|
|
365
|
+
op_name,
|
|
366
|
+
materialize_func,
|
|
367
|
+
) in _DECOMPOSED_HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT.items():
|
|
368
|
+
register_quantized_op(
|
|
369
|
+
AlgorithmName.DECOMPOSED_HADAMARD_ROTATION,
|
|
370
|
+
op_name,
|
|
371
|
+
naive_min_max_quantize.init_qsvs,
|
|
372
|
+
calibration_func=naive_min_max_quantize.min_max_calibrate,
|
|
373
|
+
materialize_func=materialize_func,
|
|
374
|
+
)
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
# Register the MSE algorithm.
|
|
378
|
+
register_op_quant_config_validation_func(
|
|
379
|
+
AlgorithmName.MSE,
|
|
380
|
+
common_quantize.check_op_quantization_config,
|
|
381
|
+
)
|
|
382
|
+
|
|
383
|
+
# Register a config check policy for the MSE algorithm.
|
|
384
|
+
register_config_check_policy_func(
|
|
385
|
+
AlgorithmName.MSE,
|
|
386
|
+
default_policy.DEFAULT_CONFIG_CHECK_POLICY,
|
|
387
|
+
)
|
|
388
|
+
|
|
389
|
+
# Register specialized MSE materialize functions.
|
|
390
|
+
_MSE_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
|
|
391
|
+
_TFLOpName.FULLY_CONNECTED: common_quantize.materialize_fc_conv,
|
|
392
|
+
_TFLOpName.EMBEDDING_LOOKUP: common_quantize.materialize_embedding_lookup,
|
|
393
|
+
_TFLOpName.CONV_2D: common_quantize.materialize_fc_conv,
|
|
394
|
+
_TFLOpName.DEPTHWISE_CONV_2D: common_quantize.materialize_fc_conv,
|
|
395
|
+
_TFLOpName.CONV_2D_TRANSPOSE: common_quantize.materialize_conv2d_transpose,
|
|
396
|
+
})
|
|
397
|
+
for (
|
|
398
|
+
op_name,
|
|
399
|
+
materialize_func,
|
|
400
|
+
) in _MSE_OP_NAME_MATERIALIZE_FUNC_DICT.items():
|
|
401
|
+
register_quantized_op(
|
|
402
|
+
AlgorithmName.MSE,
|
|
403
|
+
op_name,
|
|
404
|
+
naive_min_max_quantize.init_qsvs,
|
|
405
|
+
calibration_func=naive_min_max_quantize.min_max_calibrate,
|
|
406
|
+
materialize_func=functools.partial(
|
|
407
|
+
materialize_func,
|
|
408
|
+
mse.get_tensor_quant_params,
|
|
409
|
+
),
|
|
410
|
+
)
|
|
@@ -531,9 +531,9 @@ class Fp16QuantizeTest(parameterized.TestCase):
|
|
|
531
531
|
|
|
532
532
|
op_tensor_names = {}
|
|
533
533
|
op_tensor_names["weight"] = (
|
|
534
|
-
"
|
|
534
|
+
"jit(export_func)/jit(main)/...y,yz->...z/dot_general;jit(export_func)/jit(main)/jit(_one_hot)/eq;jit(export_func)/jit(main)/jit(_one_hot)/convert_element_type"
|
|
535
535
|
)
|
|
536
|
-
op_tensor_names["input"] = "
|
|
536
|
+
op_tensor_names["input"] = "lookup"
|
|
537
537
|
op_tensor_names["output"] = "Identity_1"
|
|
538
538
|
|
|
539
539
|
# TODO: b/335913710 - Rename the test function.
|