ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/algorithm_manager.py +224 -0
- ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
- ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
- ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
- ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
- ai_edge_quantizer/calibrator.py +58 -94
- ai_edge_quantizer/calibrator_test.py +5 -74
- ai_edge_quantizer/default_policy.py +108 -16
- ai_edge_quantizer/model_modifier.py +132 -8
- ai_edge_quantizer/model_modifier_test.py +81 -1
- ai_edge_quantizer/model_validator.py +38 -10
- ai_edge_quantizer/model_validator_test.py +2 -1
- ai_edge_quantizer/params_generator.py +230 -47
- ai_edge_quantizer/params_generator_test.py +366 -261
- ai_edge_quantizer/qtyping.py +92 -6
- ai_edge_quantizer/quantizer.py +167 -23
- ai_edge_quantizer/quantizer_test.py +288 -26
- ai_edge_quantizer/recipe.py +156 -21
- ai_edge_quantizer/recipe_manager.py +158 -1
- ai_edge_quantizer/recipe_manager_test.py +146 -32
- ai_edge_quantizer/recipe_test.py +93 -17
- ai_edge_quantizer/transformation_instruction_generator.py +313 -46
- ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
- ai_edge_quantizer/transformation_performer.py +112 -58
- ai_edge_quantizer/transformation_performer_test.py +176 -4
- ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
- ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
- ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
- ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
- ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
- ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
- ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
- ai_edge_quantizer/transformations/transformation_utils.py +157 -11
- ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
- ai_edge_quantizer/utils/calibration_utils.py +263 -1
- ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
- ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
- ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
- ai_edge_quantizer/utils/test_utils.py +191 -58
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
- ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
- ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
- ai_edge_quantizer/utils/validation_utils.py +114 -4
- ai_edge_quantizer/utils/validation_utils_test.py +80 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
- ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
- ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
- ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
- ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
|
@@ -17,13 +17,18 @@
|
|
|
17
17
|
|
|
18
18
|
import enum
|
|
19
19
|
import functools
|
|
20
|
+
from immutabledict import immutabledict
|
|
20
21
|
from ai_edge_quantizer import algorithm_manager_api
|
|
21
22
|
from ai_edge_quantizer import default_policy
|
|
22
23
|
from ai_edge_quantizer import qtyping
|
|
23
24
|
from ai_edge_quantizer.algorithms.nonlinear_quantize import float_casting
|
|
24
25
|
from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
|
|
25
26
|
from ai_edge_quantizer.algorithms.uniform_quantize import dequantized_weight_recovery
|
|
27
|
+
from ai_edge_quantizer.algorithms.uniform_quantize import hadamard_rotation
|
|
28
|
+
from ai_edge_quantizer.algorithms.uniform_quantize import mse
|
|
26
29
|
from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
|
|
30
|
+
from ai_edge_quantizer.algorithms.uniform_quantize import octav
|
|
31
|
+
|
|
27
32
|
|
|
28
33
|
# TODO: b/399775701 - Clean up this file.
|
|
29
34
|
|
|
@@ -55,6 +60,11 @@ class AlgorithmName(str, enum.Enum):
|
|
|
55
60
|
MIN_MAX_UNIFORM_QUANT = naive_min_max_quantize.ALGORITHM_KEY
|
|
56
61
|
FLOAT_CASTING = float_casting.ALGORITHM_KEY
|
|
57
62
|
DEQUANTIZED_WEIGHT_RECOVERY = dequantized_weight_recovery.ALGORITHM_KEY
|
|
63
|
+
OCTAV = octav.ALGORITHM_KEY
|
|
64
|
+
HADAMARD_ROTATION = hadamard_rotation.CUSTOM_OP_ALGORITHM_KEY
|
|
65
|
+
DECOMPOSED_HADAMARD_ROTATION = hadamard_rotation.DECOMPOSED_ALGORITHM_KEY
|
|
66
|
+
MSE = mse.ALGORITHM_KEY
|
|
67
|
+
|
|
58
68
|
|
|
59
69
|
### MIN/MAX_UNIFORM_QUANT ###
|
|
60
70
|
|
|
@@ -96,7 +106,37 @@ MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT = {
|
|
|
96
106
|
_TFLOpName.LOGISTIC: common_quantize.materialize_softmax_and_logistic,
|
|
97
107
|
_TFLOpName.SLICE: common_quantize.materialize_slice,
|
|
98
108
|
_TFLOpName.SUM: common_quantize.materialize_sum,
|
|
109
|
+
_TFLOpName.SELECT: common_quantize.materialize_select,
|
|
99
110
|
_TFLOpName.SELECT_V2: common_quantize.materialize_select_v2,
|
|
111
|
+
_TFLOpName.DYNAMIC_UPDATE_SLICE: (
|
|
112
|
+
common_quantize.materialize_dynamic_update_slice
|
|
113
|
+
),
|
|
114
|
+
_TFLOpName.STABLEHLO_COMPOSITE: common_quantize.materialize_composite,
|
|
115
|
+
_TFLOpName.PAD: common_quantize.materialize_pad,
|
|
116
|
+
_TFLOpName.SQUARED_DIFFERENCE: (
|
|
117
|
+
common_quantize.materialize_squared_difference
|
|
118
|
+
),
|
|
119
|
+
_TFLOpName.MAX_POOL_2D: common_quantize.materialize_max_pool_2d,
|
|
120
|
+
_TFLOpName.RESIZE_BILINEAR: common_quantize.materialize_resize_bilinear,
|
|
121
|
+
_TFLOpName.RESIZE_NEAREST_NEIGHBOR: (
|
|
122
|
+
common_quantize.materialize_resize_nearest_neighbor
|
|
123
|
+
),
|
|
124
|
+
_TFLOpName.GATHER_ND: common_quantize.materialize_gather_nd,
|
|
125
|
+
_TFLOpName.PACK: common_quantize.materialize_pack,
|
|
126
|
+
_TFLOpName.UNPACK: common_quantize.materialize_unpack,
|
|
127
|
+
_TFLOpName.DIV: common_quantize.materialize_div,
|
|
128
|
+
_TFLOpName.BROADCAST_TO: common_quantize.materialize_broadcast_to,
|
|
129
|
+
_TFLOpName.SQRT: common_quantize.materialize_sqrt,
|
|
130
|
+
_TFLOpName.GATHER: common_quantize.materialize_gather,
|
|
131
|
+
_TFLOpName.HARD_SWISH: common_quantize.materialize_hard_swish,
|
|
132
|
+
_TFLOpName.MAXIMUM: common_quantize.materialize_maximum,
|
|
133
|
+
_TFLOpName.PADV2: common_quantize.materialize_padv2,
|
|
134
|
+
_TFLOpName.REDUCE_MIN: common_quantize.materialize_reduce_min,
|
|
135
|
+
_TFLOpName.EQUAL: common_quantize.materialize_equal,
|
|
136
|
+
_TFLOpName.NOT_EQUAL: common_quantize.materialize_not_equal,
|
|
137
|
+
_TFLOpName.MIRROR_PAD: common_quantize.materialize_mirror_pad,
|
|
138
|
+
_TFLOpName.SPACE_TO_DEPTH: common_quantize.materialize_space_to_depth,
|
|
139
|
+
_TFLOpName.RELU: common_quantize.materialize_relu,
|
|
100
140
|
}
|
|
101
141
|
for op_name, materialize_func in MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT.items():
|
|
102
142
|
register_quantized_op(
|
|
@@ -163,6 +203,7 @@ register_config_check_policy_func(
|
|
|
163
203
|
|
|
164
204
|
DEQUANTIZED_WEIGHT_RECOVERY_OP_NAME_MATERIALIZE_FUNC_DICT = {
|
|
165
205
|
_TFLOpName.FULLY_CONNECTED: common_quantize.materialize_fc_conv,
|
|
206
|
+
_TFLOpName.CONV_2D: common_quantize.materialize_fc_conv,
|
|
166
207
|
_TFLOpName.EMBEDDING_LOOKUP: common_quantize.materialize_embedding_lookup,
|
|
167
208
|
}
|
|
168
209
|
|
|
@@ -184,3 +225,186 @@ for (
|
|
|
184
225
|
dequantized_weight_recovery.get_tensor_quant_params,
|
|
185
226
|
),
|
|
186
227
|
)
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
# Register OCTAV algorithm.
|
|
231
|
+
register_op_quant_config_validation_func(
|
|
232
|
+
AlgorithmName.OCTAV,
|
|
233
|
+
common_quantize.check_op_quantization_config,
|
|
234
|
+
)
|
|
235
|
+
|
|
236
|
+
# Register a config check policy for OCTAV algorithm.
|
|
237
|
+
register_config_check_policy_func(
|
|
238
|
+
AlgorithmName.OCTAV,
|
|
239
|
+
default_policy.DEFAULT_CONFIG_CHECK_POLICY,
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
_OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
|
|
243
|
+
_TFLOpName.INPUT: common_quantize.materialize_input,
|
|
244
|
+
_TFLOpName.OUTPUT: common_quantize.materialize_output,
|
|
245
|
+
_TFLOpName.FULLY_CONNECTED: common_quantize.materialize_fc_conv,
|
|
246
|
+
_TFLOpName.BATCH_MATMUL: common_quantize.materialize_batch_matmul,
|
|
247
|
+
_TFLOpName.CONV_2D: common_quantize.materialize_fc_conv,
|
|
248
|
+
_TFLOpName.DEPTHWISE_CONV_2D: common_quantize.materialize_fc_conv,
|
|
249
|
+
_TFLOpName.CONV_2D_TRANSPOSE: common_quantize.materialize_conv2d_transpose,
|
|
250
|
+
_TFLOpName.RESHAPE: common_quantize.materialize_reshape,
|
|
251
|
+
_TFLOpName.AVERAGE_POOL_2D: common_quantize.materialize_average_pool_2d,
|
|
252
|
+
_TFLOpName.EMBEDDING_LOOKUP: common_quantize.materialize_embedding_lookup,
|
|
253
|
+
_TFLOpName.SOFTMAX: common_quantize.materialize_softmax_and_logistic,
|
|
254
|
+
_TFLOpName.TANH: common_quantize.materialize_tanh,
|
|
255
|
+
_TFLOpName.TRANSPOSE: common_quantize.materialize_transpose,
|
|
256
|
+
_TFLOpName.GELU: common_quantize.materialize_gelu,
|
|
257
|
+
_TFLOpName.ADD: common_quantize.materialize_add,
|
|
258
|
+
_TFLOpName.SUB: common_quantize.materialize_sub,
|
|
259
|
+
_TFLOpName.MUL: common_quantize.materialize_mul,
|
|
260
|
+
_TFLOpName.MEAN: common_quantize.materialize_mean,
|
|
261
|
+
_TFLOpName.RSQRT: common_quantize.materialize_rsqrt,
|
|
262
|
+
_TFLOpName.CONCATENATION: common_quantize.materialize_concatenation,
|
|
263
|
+
_TFLOpName.STRIDED_SLICE: common_quantize.materialize_strided_slice,
|
|
264
|
+
_TFLOpName.SPLIT: common_quantize.materialize_split,
|
|
265
|
+
_TFLOpName.LOGISTIC: common_quantize.materialize_softmax_and_logistic,
|
|
266
|
+
_TFLOpName.SLICE: common_quantize.materialize_slice,
|
|
267
|
+
_TFLOpName.SUM: common_quantize.materialize_sum,
|
|
268
|
+
_TFLOpName.SELECT: common_quantize.materialize_select,
|
|
269
|
+
_TFLOpName.SELECT_V2: common_quantize.materialize_select_v2,
|
|
270
|
+
_TFLOpName.DYNAMIC_UPDATE_SLICE: (
|
|
271
|
+
common_quantize.materialize_dynamic_update_slice
|
|
272
|
+
),
|
|
273
|
+
_TFLOpName.STABLEHLO_COMPOSITE: common_quantize.materialize_composite,
|
|
274
|
+
_TFLOpName.PAD: common_quantize.materialize_pad,
|
|
275
|
+
_TFLOpName.SQUARED_DIFFERENCE: (
|
|
276
|
+
common_quantize.materialize_squared_difference
|
|
277
|
+
),
|
|
278
|
+
_TFLOpName.MAX_POOL_2D: common_quantize.materialize_max_pool_2d,
|
|
279
|
+
_TFLOpName.RESIZE_BILINEAR: common_quantize.materialize_resize_bilinear,
|
|
280
|
+
_TFLOpName.RESIZE_NEAREST_NEIGHBOR: (
|
|
281
|
+
common_quantize.materialize_resize_nearest_neighbor
|
|
282
|
+
),
|
|
283
|
+
_TFLOpName.GATHER_ND: common_quantize.materialize_gather_nd,
|
|
284
|
+
_TFLOpName.PACK: common_quantize.materialize_pack,
|
|
285
|
+
_TFLOpName.UNPACK: common_quantize.materialize_unpack,
|
|
286
|
+
_TFLOpName.DIV: common_quantize.materialize_div,
|
|
287
|
+
_TFLOpName.BROADCAST_TO: common_quantize.materialize_broadcast_to,
|
|
288
|
+
_TFLOpName.SQRT: common_quantize.materialize_sqrt,
|
|
289
|
+
_TFLOpName.GATHER: common_quantize.materialize_gather,
|
|
290
|
+
_TFLOpName.HARD_SWISH: common_quantize.materialize_hard_swish,
|
|
291
|
+
_TFLOpName.MAXIMUM: common_quantize.materialize_maximum,
|
|
292
|
+
_TFLOpName.PADV2: common_quantize.materialize_padv2,
|
|
293
|
+
_TFLOpName.REDUCE_MIN: common_quantize.materialize_reduce_min,
|
|
294
|
+
_TFLOpName.EQUAL: common_quantize.materialize_equal,
|
|
295
|
+
_TFLOpName.NOT_EQUAL: common_quantize.materialize_not_equal,
|
|
296
|
+
_TFLOpName.MIRROR_PAD: common_quantize.materialize_mirror_pad,
|
|
297
|
+
_TFLOpName.SPACE_TO_DEPTH: common_quantize.materialize_space_to_depth,
|
|
298
|
+
_TFLOpName.RELU: common_quantize.materialize_relu,
|
|
299
|
+
})
|
|
300
|
+
|
|
301
|
+
for op_name, materialize_func in _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT.items():
|
|
302
|
+
register_quantized_op(
|
|
303
|
+
AlgorithmName.OCTAV,
|
|
304
|
+
op_name,
|
|
305
|
+
naive_min_max_quantize.init_qsvs,
|
|
306
|
+
calibration_func=naive_min_max_quantize.min_max_calibrate,
|
|
307
|
+
materialize_func=functools.partial(
|
|
308
|
+
materialize_func,
|
|
309
|
+
octav.get_tensor_quant_params,
|
|
310
|
+
),
|
|
311
|
+
)
|
|
312
|
+
|
|
313
|
+
# Register the Hadamard Rotation algorithm.
|
|
314
|
+
register_op_quant_config_validation_func(
|
|
315
|
+
AlgorithmName.HADAMARD_ROTATION,
|
|
316
|
+
common_quantize.check_op_quantization_config,
|
|
317
|
+
)
|
|
318
|
+
|
|
319
|
+
# Register a config check policy for the Hadamard Rotation algorithm.
|
|
320
|
+
register_config_check_policy_func(
|
|
321
|
+
AlgorithmName.HADAMARD_ROTATION,
|
|
322
|
+
default_policy.DEFAULT_CONFIG_CHECK_POLICY,
|
|
323
|
+
)
|
|
324
|
+
|
|
325
|
+
# Register specialized hadamard rotation materialize functions.
|
|
326
|
+
_HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
|
|
327
|
+
_TFLOpName.FULLY_CONNECTED: (
|
|
328
|
+
hadamard_rotation.materialize_fully_connected_custom_op
|
|
329
|
+
),
|
|
330
|
+
_TFLOpName.EMBEDDING_LOOKUP: (
|
|
331
|
+
hadamard_rotation.materialize_embedding_lookup_custom_op
|
|
332
|
+
),
|
|
333
|
+
})
|
|
334
|
+
for (
|
|
335
|
+
op_name,
|
|
336
|
+
materialize_func,
|
|
337
|
+
) in _HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT.items():
|
|
338
|
+
register_quantized_op(
|
|
339
|
+
AlgorithmName.HADAMARD_ROTATION,
|
|
340
|
+
op_name,
|
|
341
|
+
naive_min_max_quantize.init_qsvs,
|
|
342
|
+
calibration_func=naive_min_max_quantize.min_max_calibrate,
|
|
343
|
+
materialize_func=materialize_func,
|
|
344
|
+
)
|
|
345
|
+
|
|
346
|
+
register_op_quant_config_validation_func(
|
|
347
|
+
AlgorithmName.DECOMPOSED_HADAMARD_ROTATION,
|
|
348
|
+
common_quantize.check_op_quantization_config,
|
|
349
|
+
)
|
|
350
|
+
|
|
351
|
+
register_config_check_policy_func(
|
|
352
|
+
AlgorithmName.DECOMPOSED_HADAMARD_ROTATION,
|
|
353
|
+
default_policy.DEFAULT_CONFIG_CHECK_POLICY,
|
|
354
|
+
)
|
|
355
|
+
|
|
356
|
+
_DECOMPOSED_HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
|
|
357
|
+
_TFLOpName.FULLY_CONNECTED: (
|
|
358
|
+
hadamard_rotation.materialize_fully_connected_decomposed
|
|
359
|
+
),
|
|
360
|
+
_TFLOpName.EMBEDDING_LOOKUP: (
|
|
361
|
+
hadamard_rotation.materialize_embedding_lookup_decomposed
|
|
362
|
+
),
|
|
363
|
+
})
|
|
364
|
+
for (
|
|
365
|
+
op_name,
|
|
366
|
+
materialize_func,
|
|
367
|
+
) in _DECOMPOSED_HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT.items():
|
|
368
|
+
register_quantized_op(
|
|
369
|
+
AlgorithmName.DECOMPOSED_HADAMARD_ROTATION,
|
|
370
|
+
op_name,
|
|
371
|
+
naive_min_max_quantize.init_qsvs,
|
|
372
|
+
calibration_func=naive_min_max_quantize.min_max_calibrate,
|
|
373
|
+
materialize_func=materialize_func,
|
|
374
|
+
)
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
# Register the MSE algorithm.
|
|
378
|
+
register_op_quant_config_validation_func(
|
|
379
|
+
AlgorithmName.MSE,
|
|
380
|
+
common_quantize.check_op_quantization_config,
|
|
381
|
+
)
|
|
382
|
+
|
|
383
|
+
# Register a config check policy for the MSE algorithm.
|
|
384
|
+
register_config_check_policy_func(
|
|
385
|
+
AlgorithmName.MSE,
|
|
386
|
+
default_policy.DEFAULT_CONFIG_CHECK_POLICY,
|
|
387
|
+
)
|
|
388
|
+
|
|
389
|
+
# Register specialized MSE materialize functions.
|
|
390
|
+
_MSE_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
|
|
391
|
+
_TFLOpName.FULLY_CONNECTED: common_quantize.materialize_fc_conv,
|
|
392
|
+
_TFLOpName.EMBEDDING_LOOKUP: common_quantize.materialize_embedding_lookup,
|
|
393
|
+
_TFLOpName.CONV_2D: common_quantize.materialize_fc_conv,
|
|
394
|
+
_TFLOpName.DEPTHWISE_CONV_2D: common_quantize.materialize_fc_conv,
|
|
395
|
+
_TFLOpName.CONV_2D_TRANSPOSE: common_quantize.materialize_conv2d_transpose,
|
|
396
|
+
})
|
|
397
|
+
for (
|
|
398
|
+
op_name,
|
|
399
|
+
materialize_func,
|
|
400
|
+
) in _MSE_OP_NAME_MATERIALIZE_FUNC_DICT.items():
|
|
401
|
+
register_quantized_op(
|
|
402
|
+
AlgorithmName.MSE,
|
|
403
|
+
op_name,
|
|
404
|
+
naive_min_max_quantize.init_qsvs,
|
|
405
|
+
calibration_func=naive_min_max_quantize.min_max_calibrate,
|
|
406
|
+
materialize_func=functools.partial(
|
|
407
|
+
materialize_func,
|
|
408
|
+
mse.get_tensor_quant_params,
|
|
409
|
+
),
|
|
410
|
+
)
|
|
@@ -18,6 +18,7 @@
|
|
|
18
18
|
from absl.testing import parameterized
|
|
19
19
|
from tensorflow.python.platform import googletest
|
|
20
20
|
from ai_edge_quantizer import algorithm_manager_api
|
|
21
|
+
from ai_edge_quantizer import default_policy
|
|
21
22
|
from ai_edge_quantizer import qtyping
|
|
22
23
|
|
|
23
24
|
_TFLOpName = qtyping.TFLOperationName
|
|
@@ -205,6 +206,12 @@ class AlgorithmManagerApiTest(parameterized.TestCase):
|
|
|
205
206
|
self._alg_manager._config_check_policy_registry[test_algorithm_name]
|
|
206
207
|
)
|
|
207
208
|
|
|
209
|
+
def test_default_policy_not_empty(self):
|
|
210
|
+
"""Tests that the default policy is not empty & no empty policy is generated."""
|
|
211
|
+
self.assertNotEmpty(default_policy.DEFAULT_CONFIG_CHECK_POLICY)
|
|
212
|
+
for policy in default_policy.DEFAULT_CONFIG_CHECK_POLICY.values():
|
|
213
|
+
self.assertNotEmpty(policy)
|
|
214
|
+
|
|
208
215
|
|
|
209
216
|
if __name__ == "__main__":
|
|
210
217
|
googletest.main()
|
|
@@ -531,9 +531,9 @@ class Fp16QuantizeTest(parameterized.TestCase):
|
|
|
531
531
|
|
|
532
532
|
op_tensor_names = {}
|
|
533
533
|
op_tensor_names["weight"] = (
|
|
534
|
-
"
|
|
534
|
+
"jit(export_func)/jit(main)/...y,yz->...z/dot_general;jit(export_func)/jit(main)/jit(_one_hot)/eq;jit(export_func)/jit(main)/jit(_one_hot)/convert_element_type"
|
|
535
535
|
)
|
|
536
|
-
op_tensor_names["input"] = "
|
|
536
|
+
op_tensor_names["input"] = "lookup"
|
|
537
537
|
op_tensor_names["output"] = "Identity_1"
|
|
538
538
|
|
|
539
539
|
# TODO: b/335913710 - Rename the test function.
|