ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69):
  1. ai_edge_quantizer/algorithm_manager.py +224 -0
  2. ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
  3. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
  5. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  6. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
  7. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
  8. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
  13. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
  14. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
  15. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
  16. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
  17. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
  18. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
  19. ai_edge_quantizer/calibrator.py +58 -94
  20. ai_edge_quantizer/calibrator_test.py +5 -74
  21. ai_edge_quantizer/default_policy.py +108 -16
  22. ai_edge_quantizer/model_modifier.py +132 -8
  23. ai_edge_quantizer/model_modifier_test.py +81 -1
  24. ai_edge_quantizer/model_validator.py +38 -10
  25. ai_edge_quantizer/model_validator_test.py +2 -1
  26. ai_edge_quantizer/params_generator.py +230 -47
  27. ai_edge_quantizer/params_generator_test.py +366 -261
  28. ai_edge_quantizer/qtyping.py +92 -6
  29. ai_edge_quantizer/quantizer.py +167 -23
  30. ai_edge_quantizer/quantizer_test.py +288 -26
  31. ai_edge_quantizer/recipe.py +156 -21
  32. ai_edge_quantizer/recipe_manager.py +158 -1
  33. ai_edge_quantizer/recipe_manager_test.py +146 -32
  34. ai_edge_quantizer/recipe_test.py +93 -17
  35. ai_edge_quantizer/transformation_instruction_generator.py +313 -46
  36. ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
  37. ai_edge_quantizer/transformation_performer.py +112 -58
  38. ai_edge_quantizer/transformation_performer_test.py +176 -4
  39. ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
  40. ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
  41. ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
  42. ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
  43. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  44. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  45. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  46. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  47. ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
  48. ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
  49. ai_edge_quantizer/transformations/transformation_utils.py +157 -11
  50. ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
  51. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  52. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  53. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  54. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  55. ai_edge_quantizer/utils/test_utils.py +191 -58
  56. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
  57. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
  58. ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
  59. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  60. ai_edge_quantizer/utils/validation_utils.py +114 -4
  61. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  62. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
  63. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  64. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  65. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  66. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  67. ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
  68. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  69. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
@@ -17,13 +17,18 @@
17
17
 
18
18
  import enum
19
19
  import functools
20
+ from immutabledict import immutabledict
20
21
  from ai_edge_quantizer import algorithm_manager_api
21
22
  from ai_edge_quantizer import default_policy
22
23
  from ai_edge_quantizer import qtyping
23
24
  from ai_edge_quantizer.algorithms.nonlinear_quantize import float_casting
24
25
  from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
25
26
  from ai_edge_quantizer.algorithms.uniform_quantize import dequantized_weight_recovery
27
+ from ai_edge_quantizer.algorithms.uniform_quantize import hadamard_rotation
28
+ from ai_edge_quantizer.algorithms.uniform_quantize import mse
26
29
  from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
30
+ from ai_edge_quantizer.algorithms.uniform_quantize import octav
31
+
27
32
 
28
33
  # TODO: b/399775701 - Clean up this file.
29
34
 
@@ -55,6 +60,11 @@ class AlgorithmName(str, enum.Enum):
55
60
  MIN_MAX_UNIFORM_QUANT = naive_min_max_quantize.ALGORITHM_KEY
56
61
  FLOAT_CASTING = float_casting.ALGORITHM_KEY
57
62
  DEQUANTIZED_WEIGHT_RECOVERY = dequantized_weight_recovery.ALGORITHM_KEY
63
+ OCTAV = octav.ALGORITHM_KEY
64
+ HADAMARD_ROTATION = hadamard_rotation.CUSTOM_OP_ALGORITHM_KEY
65
+ DECOMPOSED_HADAMARD_ROTATION = hadamard_rotation.DECOMPOSED_ALGORITHM_KEY
66
+ MSE = mse.ALGORITHM_KEY
67
+
58
68
 
59
69
  ### MIN/MAX_UNIFORM_QUANT ###
60
70
 
@@ -96,7 +106,37 @@ MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT = {
96
106
  _TFLOpName.LOGISTIC: common_quantize.materialize_softmax_and_logistic,
97
107
  _TFLOpName.SLICE: common_quantize.materialize_slice,
98
108
  _TFLOpName.SUM: common_quantize.materialize_sum,
109
+ _TFLOpName.SELECT: common_quantize.materialize_select,
99
110
  _TFLOpName.SELECT_V2: common_quantize.materialize_select_v2,
111
+ _TFLOpName.DYNAMIC_UPDATE_SLICE: (
112
+ common_quantize.materialize_dynamic_update_slice
113
+ ),
114
+ _TFLOpName.STABLEHLO_COMPOSITE: common_quantize.materialize_composite,
115
+ _TFLOpName.PAD: common_quantize.materialize_pad,
116
+ _TFLOpName.SQUARED_DIFFERENCE: (
117
+ common_quantize.materialize_squared_difference
118
+ ),
119
+ _TFLOpName.MAX_POOL_2D: common_quantize.materialize_max_pool_2d,
120
+ _TFLOpName.RESIZE_BILINEAR: common_quantize.materialize_resize_bilinear,
121
+ _TFLOpName.RESIZE_NEAREST_NEIGHBOR: (
122
+ common_quantize.materialize_resize_nearest_neighbor
123
+ ),
124
+ _TFLOpName.GATHER_ND: common_quantize.materialize_gather_nd,
125
+ _TFLOpName.PACK: common_quantize.materialize_pack,
126
+ _TFLOpName.UNPACK: common_quantize.materialize_unpack,
127
+ _TFLOpName.DIV: common_quantize.materialize_div,
128
+ _TFLOpName.BROADCAST_TO: common_quantize.materialize_broadcast_to,
129
+ _TFLOpName.SQRT: common_quantize.materialize_sqrt,
130
+ _TFLOpName.GATHER: common_quantize.materialize_gather,
131
+ _TFLOpName.HARD_SWISH: common_quantize.materialize_hard_swish,
132
+ _TFLOpName.MAXIMUM: common_quantize.materialize_maximum,
133
+ _TFLOpName.PADV2: common_quantize.materialize_padv2,
134
+ _TFLOpName.REDUCE_MIN: common_quantize.materialize_reduce_min,
135
+ _TFLOpName.EQUAL: common_quantize.materialize_equal,
136
+ _TFLOpName.NOT_EQUAL: common_quantize.materialize_not_equal,
137
+ _TFLOpName.MIRROR_PAD: common_quantize.materialize_mirror_pad,
138
+ _TFLOpName.SPACE_TO_DEPTH: common_quantize.materialize_space_to_depth,
139
+ _TFLOpName.RELU: common_quantize.materialize_relu,
100
140
  }
101
141
  for op_name, materialize_func in MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT.items():
102
142
  register_quantized_op(
@@ -163,6 +203,7 @@ register_config_check_policy_func(
163
203
 
164
204
  DEQUANTIZED_WEIGHT_RECOVERY_OP_NAME_MATERIALIZE_FUNC_DICT = {
165
205
  _TFLOpName.FULLY_CONNECTED: common_quantize.materialize_fc_conv,
206
+ _TFLOpName.CONV_2D: common_quantize.materialize_fc_conv,
166
207
  _TFLOpName.EMBEDDING_LOOKUP: common_quantize.materialize_embedding_lookup,
167
208
  }
168
209
 
@@ -184,3 +225,186 @@ for (
184
225
  dequantized_weight_recovery.get_tensor_quant_params,
185
226
  ),
186
227
  )
228
+
229
+
230
+ # Register OCTAV algorithm.
231
+ register_op_quant_config_validation_func(
232
+ AlgorithmName.OCTAV,
233
+ common_quantize.check_op_quantization_config,
234
+ )
235
+
236
+ # Register a config check policy for OCTAV algorithm.
237
+ register_config_check_policy_func(
238
+ AlgorithmName.OCTAV,
239
+ default_policy.DEFAULT_CONFIG_CHECK_POLICY,
240
+ )
241
+
242
+ _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
243
+ _TFLOpName.INPUT: common_quantize.materialize_input,
244
+ _TFLOpName.OUTPUT: common_quantize.materialize_output,
245
+ _TFLOpName.FULLY_CONNECTED: common_quantize.materialize_fc_conv,
246
+ _TFLOpName.BATCH_MATMUL: common_quantize.materialize_batch_matmul,
247
+ _TFLOpName.CONV_2D: common_quantize.materialize_fc_conv,
248
+ _TFLOpName.DEPTHWISE_CONV_2D: common_quantize.materialize_fc_conv,
249
+ _TFLOpName.CONV_2D_TRANSPOSE: common_quantize.materialize_conv2d_transpose,
250
+ _TFLOpName.RESHAPE: common_quantize.materialize_reshape,
251
+ _TFLOpName.AVERAGE_POOL_2D: common_quantize.materialize_average_pool_2d,
252
+ _TFLOpName.EMBEDDING_LOOKUP: common_quantize.materialize_embedding_lookup,
253
+ _TFLOpName.SOFTMAX: common_quantize.materialize_softmax_and_logistic,
254
+ _TFLOpName.TANH: common_quantize.materialize_tanh,
255
+ _TFLOpName.TRANSPOSE: common_quantize.materialize_transpose,
256
+ _TFLOpName.GELU: common_quantize.materialize_gelu,
257
+ _TFLOpName.ADD: common_quantize.materialize_add,
258
+ _TFLOpName.SUB: common_quantize.materialize_sub,
259
+ _TFLOpName.MUL: common_quantize.materialize_mul,
260
+ _TFLOpName.MEAN: common_quantize.materialize_mean,
261
+ _TFLOpName.RSQRT: common_quantize.materialize_rsqrt,
262
+ _TFLOpName.CONCATENATION: common_quantize.materialize_concatenation,
263
+ _TFLOpName.STRIDED_SLICE: common_quantize.materialize_strided_slice,
264
+ _TFLOpName.SPLIT: common_quantize.materialize_split,
265
+ _TFLOpName.LOGISTIC: common_quantize.materialize_softmax_and_logistic,
266
+ _TFLOpName.SLICE: common_quantize.materialize_slice,
267
+ _TFLOpName.SUM: common_quantize.materialize_sum,
268
+ _TFLOpName.SELECT: common_quantize.materialize_select,
269
+ _TFLOpName.SELECT_V2: common_quantize.materialize_select_v2,
270
+ _TFLOpName.DYNAMIC_UPDATE_SLICE: (
271
+ common_quantize.materialize_dynamic_update_slice
272
+ ),
273
+ _TFLOpName.STABLEHLO_COMPOSITE: common_quantize.materialize_composite,
274
+ _TFLOpName.PAD: common_quantize.materialize_pad,
275
+ _TFLOpName.SQUARED_DIFFERENCE: (
276
+ common_quantize.materialize_squared_difference
277
+ ),
278
+ _TFLOpName.MAX_POOL_2D: common_quantize.materialize_max_pool_2d,
279
+ _TFLOpName.RESIZE_BILINEAR: common_quantize.materialize_resize_bilinear,
280
+ _TFLOpName.RESIZE_NEAREST_NEIGHBOR: (
281
+ common_quantize.materialize_resize_nearest_neighbor
282
+ ),
283
+ _TFLOpName.GATHER_ND: common_quantize.materialize_gather_nd,
284
+ _TFLOpName.PACK: common_quantize.materialize_pack,
285
+ _TFLOpName.UNPACK: common_quantize.materialize_unpack,
286
+ _TFLOpName.DIV: common_quantize.materialize_div,
287
+ _TFLOpName.BROADCAST_TO: common_quantize.materialize_broadcast_to,
288
+ _TFLOpName.SQRT: common_quantize.materialize_sqrt,
289
+ _TFLOpName.GATHER: common_quantize.materialize_gather,
290
+ _TFLOpName.HARD_SWISH: common_quantize.materialize_hard_swish,
291
+ _TFLOpName.MAXIMUM: common_quantize.materialize_maximum,
292
+ _TFLOpName.PADV2: common_quantize.materialize_padv2,
293
+ _TFLOpName.REDUCE_MIN: common_quantize.materialize_reduce_min,
294
+ _TFLOpName.EQUAL: common_quantize.materialize_equal,
295
+ _TFLOpName.NOT_EQUAL: common_quantize.materialize_not_equal,
296
+ _TFLOpName.MIRROR_PAD: common_quantize.materialize_mirror_pad,
297
+ _TFLOpName.SPACE_TO_DEPTH: common_quantize.materialize_space_to_depth,
298
+ _TFLOpName.RELU: common_quantize.materialize_relu,
299
+ })
300
+
301
+ for op_name, materialize_func in _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT.items():
302
+ register_quantized_op(
303
+ AlgorithmName.OCTAV,
304
+ op_name,
305
+ naive_min_max_quantize.init_qsvs,
306
+ calibration_func=naive_min_max_quantize.min_max_calibrate,
307
+ materialize_func=functools.partial(
308
+ materialize_func,
309
+ octav.get_tensor_quant_params,
310
+ ),
311
+ )
312
+
313
+ # Register the Hadamard Rotation algorithm.
314
+ register_op_quant_config_validation_func(
315
+ AlgorithmName.HADAMARD_ROTATION,
316
+ common_quantize.check_op_quantization_config,
317
+ )
318
+
319
+ # Register a config check policy for the Hadamard Rotation algorithm.
320
+ register_config_check_policy_func(
321
+ AlgorithmName.HADAMARD_ROTATION,
322
+ default_policy.DEFAULT_CONFIG_CHECK_POLICY,
323
+ )
324
+
325
+ # Register specialized hadamard rotation materialize functions.
326
+ _HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
327
+ _TFLOpName.FULLY_CONNECTED: (
328
+ hadamard_rotation.materialize_fully_connected_custom_op
329
+ ),
330
+ _TFLOpName.EMBEDDING_LOOKUP: (
331
+ hadamard_rotation.materialize_embedding_lookup_custom_op
332
+ ),
333
+ })
334
+ for (
335
+ op_name,
336
+ materialize_func,
337
+ ) in _HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT.items():
338
+ register_quantized_op(
339
+ AlgorithmName.HADAMARD_ROTATION,
340
+ op_name,
341
+ naive_min_max_quantize.init_qsvs,
342
+ calibration_func=naive_min_max_quantize.min_max_calibrate,
343
+ materialize_func=materialize_func,
344
+ )
345
+
346
+ register_op_quant_config_validation_func(
347
+ AlgorithmName.DECOMPOSED_HADAMARD_ROTATION,
348
+ common_quantize.check_op_quantization_config,
349
+ )
350
+
351
+ register_config_check_policy_func(
352
+ AlgorithmName.DECOMPOSED_HADAMARD_ROTATION,
353
+ default_policy.DEFAULT_CONFIG_CHECK_POLICY,
354
+ )
355
+
356
+ _DECOMPOSED_HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
357
+ _TFLOpName.FULLY_CONNECTED: (
358
+ hadamard_rotation.materialize_fully_connected_decomposed
359
+ ),
360
+ _TFLOpName.EMBEDDING_LOOKUP: (
361
+ hadamard_rotation.materialize_embedding_lookup_decomposed
362
+ ),
363
+ })
364
+ for (
365
+ op_name,
366
+ materialize_func,
367
+ ) in _DECOMPOSED_HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT.items():
368
+ register_quantized_op(
369
+ AlgorithmName.DECOMPOSED_HADAMARD_ROTATION,
370
+ op_name,
371
+ naive_min_max_quantize.init_qsvs,
372
+ calibration_func=naive_min_max_quantize.min_max_calibrate,
373
+ materialize_func=materialize_func,
374
+ )
375
+
376
+
377
+ # Register the MSE algorithm.
378
+ register_op_quant_config_validation_func(
379
+ AlgorithmName.MSE,
380
+ common_quantize.check_op_quantization_config,
381
+ )
382
+
383
+ # Register a config check policy for the MSE algorithm.
384
+ register_config_check_policy_func(
385
+ AlgorithmName.MSE,
386
+ default_policy.DEFAULT_CONFIG_CHECK_POLICY,
387
+ )
388
+
389
+ # Register specialized MSE materialize functions.
390
+ _MSE_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
391
+ _TFLOpName.FULLY_CONNECTED: common_quantize.materialize_fc_conv,
392
+ _TFLOpName.EMBEDDING_LOOKUP: common_quantize.materialize_embedding_lookup,
393
+ _TFLOpName.CONV_2D: common_quantize.materialize_fc_conv,
394
+ _TFLOpName.DEPTHWISE_CONV_2D: common_quantize.materialize_fc_conv,
395
+ _TFLOpName.CONV_2D_TRANSPOSE: common_quantize.materialize_conv2d_transpose,
396
+ })
397
+ for (
398
+ op_name,
399
+ materialize_func,
400
+ ) in _MSE_OP_NAME_MATERIALIZE_FUNC_DICT.items():
401
+ register_quantized_op(
402
+ AlgorithmName.MSE,
403
+ op_name,
404
+ naive_min_max_quantize.init_qsvs,
405
+ calibration_func=naive_min_max_quantize.min_max_calibrate,
406
+ materialize_func=functools.partial(
407
+ materialize_func,
408
+ mse.get_tensor_quant_params,
409
+ ),
410
+ )
@@ -18,6 +18,7 @@
18
18
  from absl.testing import parameterized
19
19
  from tensorflow.python.platform import googletest
20
20
  from ai_edge_quantizer import algorithm_manager_api
21
+ from ai_edge_quantizer import default_policy
21
22
  from ai_edge_quantizer import qtyping
22
23
 
23
24
  _TFLOpName = qtyping.TFLOperationName
@@ -205,6 +206,12 @@ class AlgorithmManagerApiTest(parameterized.TestCase):
205
206
  self._alg_manager._config_check_policy_registry[test_algorithm_name]
206
207
  )
207
208
 
209
+ def test_default_policy_not_empty(self):
210
+ """Tests that the default policy is not empty & no empty policy is generated."""
211
+ self.assertNotEmpty(default_policy.DEFAULT_CONFIG_CHECK_POLICY)
212
+ for policy in default_policy.DEFAULT_CONFIG_CHECK_POLICY.values():
213
+ self.assertNotEmpty(policy)
214
+
208
215
 
209
216
  if __name__ == "__main__":
210
217
  googletest.main()
@@ -531,9 +531,9 @@ class Fp16QuantizeTest(parameterized.TestCase):
531
531
 
532
532
  op_tensor_names = {}
533
533
  op_tensor_names["weight"] = (
534
- "jax2tf_export_func_/...y_yz-_...z/pjit__einsum_/MatMul;jax2tf_export_func_/pjit__one_hot_/Equal;jax2tf_export_func_/pjit__one_hot_/Cast_1"
534
+ "jit(export_func)/jit(main)/...y,yz->...z/dot_general;jit(export_func)/jit(main)/jit(_one_hot)/eq;jit(export_func)/jit(main)/jit(_one_hot)/convert_element_type"
535
535
  )
536
- op_tensor_names["input"] = "inputs"
536
+ op_tensor_names["input"] = "lookup"
537
537
  op_tensor_names["output"] = "Identity_1"
538
538
 
539
539
  # TODO: b/335913710 - Rename the test function.