ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/algorithm_manager.py +158 -0
  2. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  3. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  5. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
  6. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  7. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  8. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
  12. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
  13. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
  14. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
  15. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
  16. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
  17. ai_edge_quantizer/calibrator.py +11 -60
  18. ai_edge_quantizer/calibrator_test.py +4 -73
  19. ai_edge_quantizer/default_policy.py +61 -26
  20. ai_edge_quantizer/model_modifier.py +97 -7
  21. ai_edge_quantizer/model_modifier_test.py +81 -1
  22. ai_edge_quantizer/model_validator.py +31 -8
  23. ai_edge_quantizer/params_generator.py +17 -10
  24. ai_edge_quantizer/params_generator_test.py +2 -7
  25. ai_edge_quantizer/qtyping.py +86 -6
  26. ai_edge_quantizer/quantizer.py +166 -21
  27. ai_edge_quantizer/quantizer_test.py +284 -16
  28. ai_edge_quantizer/recipe.py +154 -42
  29. ai_edge_quantizer/recipe_manager.py +158 -1
  30. ai_edge_quantizer/recipe_manager_test.py +146 -32
  31. ai_edge_quantizer/recipe_test.py +93 -17
  32. ai_edge_quantizer/transformation_instruction_generator.py +118 -13
  33. ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
  34. ai_edge_quantizer/transformation_performer.py +55 -25
  35. ai_edge_quantizer/transformation_performer_test.py +127 -5
  36. ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
  37. ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
  38. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  39. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  40. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  41. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  42. ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
  43. ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
  44. ai_edge_quantizer/transformations/transformation_utils.py +129 -6
  45. ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
  46. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  47. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  48. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  49. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  50. ai_edge_quantizer/utils/test_utils.py +75 -2
  51. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
  52. ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
  53. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  54. ai_edge_quantizer/utils/validation_utils.py +114 -4
  55. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  56. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
  57. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  58. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  59. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  60. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  61. ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
  62. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  63. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
@@ -24,9 +24,12 @@ from ai_edge_quantizer import qtyping
24
24
  from ai_edge_quantizer.algorithms.nonlinear_quantize import float_casting
25
25
  from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
26
26
  from ai_edge_quantizer.algorithms.uniform_quantize import dequantized_weight_recovery
27
+ from ai_edge_quantizer.algorithms.uniform_quantize import hadamard_rotation
28
+ from ai_edge_quantizer.algorithms.uniform_quantize import mse
27
29
  from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
28
30
  from ai_edge_quantizer.algorithms.uniform_quantize import octav
29
31
 
32
+
30
33
  # TODO: b/399775701 - Clean up this file.
31
34
 
32
35
  _TFLOpName = qtyping.TFLOperationName
@@ -58,6 +61,10 @@ class AlgorithmName(str, enum.Enum):
58
61
  FLOAT_CASTING = float_casting.ALGORITHM_KEY
59
62
  DEQUANTIZED_WEIGHT_RECOVERY = dequantized_weight_recovery.ALGORITHM_KEY
60
63
  OCTAV = octav.ALGORITHM_KEY
64
+ HADAMARD_ROTATION = hadamard_rotation.CUSTOM_OP_ALGORITHM_KEY
65
+ DECOMPOSED_HADAMARD_ROTATION = hadamard_rotation.DECOMPOSED_ALGORITHM_KEY
66
+ MSE = mse.ALGORITHM_KEY
67
+
61
68
 
62
69
  ### MIN/MAX_UNIFORM_QUANT ###
63
70
 
@@ -99,11 +106,37 @@ MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT = {
99
106
  _TFLOpName.LOGISTIC: common_quantize.materialize_softmax_and_logistic,
100
107
  _TFLOpName.SLICE: common_quantize.materialize_slice,
101
108
  _TFLOpName.SUM: common_quantize.materialize_sum,
109
+ _TFLOpName.SELECT: common_quantize.materialize_select,
102
110
  _TFLOpName.SELECT_V2: common_quantize.materialize_select_v2,
103
111
  _TFLOpName.DYNAMIC_UPDATE_SLICE: (
104
112
  common_quantize.materialize_dynamic_update_slice
105
113
  ),
106
114
  _TFLOpName.STABLEHLO_COMPOSITE: common_quantize.materialize_composite,
115
+ _TFLOpName.PAD: common_quantize.materialize_pad,
116
+ _TFLOpName.SQUARED_DIFFERENCE: (
117
+ common_quantize.materialize_squared_difference
118
+ ),
119
+ _TFLOpName.MAX_POOL_2D: common_quantize.materialize_max_pool_2d,
120
+ _TFLOpName.RESIZE_BILINEAR: common_quantize.materialize_resize_bilinear,
121
+ _TFLOpName.RESIZE_NEAREST_NEIGHBOR: (
122
+ common_quantize.materialize_resize_nearest_neighbor
123
+ ),
124
+ _TFLOpName.GATHER_ND: common_quantize.materialize_gather_nd,
125
+ _TFLOpName.PACK: common_quantize.materialize_pack,
126
+ _TFLOpName.UNPACK: common_quantize.materialize_unpack,
127
+ _TFLOpName.DIV: common_quantize.materialize_div,
128
+ _TFLOpName.BROADCAST_TO: common_quantize.materialize_broadcast_to,
129
+ _TFLOpName.SQRT: common_quantize.materialize_sqrt,
130
+ _TFLOpName.GATHER: common_quantize.materialize_gather,
131
+ _TFLOpName.HARD_SWISH: common_quantize.materialize_hard_swish,
132
+ _TFLOpName.MAXIMUM: common_quantize.materialize_maximum,
133
+ _TFLOpName.PADV2: common_quantize.materialize_padv2,
134
+ _TFLOpName.REDUCE_MIN: common_quantize.materialize_reduce_min,
135
+ _TFLOpName.EQUAL: common_quantize.materialize_equal,
136
+ _TFLOpName.NOT_EQUAL: common_quantize.materialize_not_equal,
137
+ _TFLOpName.MIRROR_PAD: common_quantize.materialize_mirror_pad,
138
+ _TFLOpName.SPACE_TO_DEPTH: common_quantize.materialize_space_to_depth,
139
+ _TFLOpName.RELU: common_quantize.materialize_relu,
107
140
  }
108
141
  for op_name, materialize_func in MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT.items():
109
142
  register_quantized_op(
@@ -232,11 +265,37 @@ _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
232
265
  _TFLOpName.LOGISTIC: common_quantize.materialize_softmax_and_logistic,
233
266
  _TFLOpName.SLICE: common_quantize.materialize_slice,
234
267
  _TFLOpName.SUM: common_quantize.materialize_sum,
268
+ _TFLOpName.SELECT: common_quantize.materialize_select,
235
269
  _TFLOpName.SELECT_V2: common_quantize.materialize_select_v2,
236
270
  _TFLOpName.DYNAMIC_UPDATE_SLICE: (
237
271
  common_quantize.materialize_dynamic_update_slice
238
272
  ),
239
273
  _TFLOpName.STABLEHLO_COMPOSITE: common_quantize.materialize_composite,
274
+ _TFLOpName.PAD: common_quantize.materialize_pad,
275
+ _TFLOpName.SQUARED_DIFFERENCE: (
276
+ common_quantize.materialize_squared_difference
277
+ ),
278
+ _TFLOpName.MAX_POOL_2D: common_quantize.materialize_max_pool_2d,
279
+ _TFLOpName.RESIZE_BILINEAR: common_quantize.materialize_resize_bilinear,
280
+ _TFLOpName.RESIZE_NEAREST_NEIGHBOR: (
281
+ common_quantize.materialize_resize_nearest_neighbor
282
+ ),
283
+ _TFLOpName.GATHER_ND: common_quantize.materialize_gather_nd,
284
+ _TFLOpName.PACK: common_quantize.materialize_pack,
285
+ _TFLOpName.UNPACK: common_quantize.materialize_unpack,
286
+ _TFLOpName.DIV: common_quantize.materialize_div,
287
+ _TFLOpName.BROADCAST_TO: common_quantize.materialize_broadcast_to,
288
+ _TFLOpName.SQRT: common_quantize.materialize_sqrt,
289
+ _TFLOpName.GATHER: common_quantize.materialize_gather,
290
+ _TFLOpName.HARD_SWISH: common_quantize.materialize_hard_swish,
291
+ _TFLOpName.MAXIMUM: common_quantize.materialize_maximum,
292
+ _TFLOpName.PADV2: common_quantize.materialize_padv2,
293
+ _TFLOpName.REDUCE_MIN: common_quantize.materialize_reduce_min,
294
+ _TFLOpName.EQUAL: common_quantize.materialize_equal,
295
+ _TFLOpName.NOT_EQUAL: common_quantize.materialize_not_equal,
296
+ _TFLOpName.MIRROR_PAD: common_quantize.materialize_mirror_pad,
297
+ _TFLOpName.SPACE_TO_DEPTH: common_quantize.materialize_space_to_depth,
298
+ _TFLOpName.RELU: common_quantize.materialize_relu,
240
299
  })
241
300
 
242
301
  for op_name, materialize_func in _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT.items():
@@ -250,3 +309,102 @@ for op_name, materialize_func in _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT.items():
250
309
  octav.get_tensor_quant_params,
251
310
  ),
252
311
  )
312
+
313
+ # Register the Hadamard Rotation algorithm.
314
+ register_op_quant_config_validation_func(
315
+ AlgorithmName.HADAMARD_ROTATION,
316
+ common_quantize.check_op_quantization_config,
317
+ )
318
+
319
+ # Register a config check policy for the Hadamard Rotation algorithm.
320
+ register_config_check_policy_func(
321
+ AlgorithmName.HADAMARD_ROTATION,
322
+ default_policy.DEFAULT_CONFIG_CHECK_POLICY,
323
+ )
324
+
325
+ # Register specialized hadamard rotation materialize functions.
326
+ _HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
327
+ _TFLOpName.FULLY_CONNECTED: (
328
+ hadamard_rotation.materialize_fully_connected_custom_op
329
+ ),
330
+ _TFLOpName.EMBEDDING_LOOKUP: (
331
+ hadamard_rotation.materialize_embedding_lookup_custom_op
332
+ ),
333
+ })
334
+ for (
335
+ op_name,
336
+ materialize_func,
337
+ ) in _HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT.items():
338
+ register_quantized_op(
339
+ AlgorithmName.HADAMARD_ROTATION,
340
+ op_name,
341
+ naive_min_max_quantize.init_qsvs,
342
+ calibration_func=naive_min_max_quantize.min_max_calibrate,
343
+ materialize_func=materialize_func,
344
+ )
345
+
346
+ register_op_quant_config_validation_func(
347
+ AlgorithmName.DECOMPOSED_HADAMARD_ROTATION,
348
+ common_quantize.check_op_quantization_config,
349
+ )
350
+
351
+ register_config_check_policy_func(
352
+ AlgorithmName.DECOMPOSED_HADAMARD_ROTATION,
353
+ default_policy.DEFAULT_CONFIG_CHECK_POLICY,
354
+ )
355
+
356
+ _DECOMPOSED_HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
357
+ _TFLOpName.FULLY_CONNECTED: (
358
+ hadamard_rotation.materialize_fully_connected_decomposed
359
+ ),
360
+ _TFLOpName.EMBEDDING_LOOKUP: (
361
+ hadamard_rotation.materialize_embedding_lookup_decomposed
362
+ ),
363
+ })
364
+ for (
365
+ op_name,
366
+ materialize_func,
367
+ ) in _DECOMPOSED_HADAMARD_ROTATION_OP_NAME_MATERIALIZE_FUNC_DICT.items():
368
+ register_quantized_op(
369
+ AlgorithmName.DECOMPOSED_HADAMARD_ROTATION,
370
+ op_name,
371
+ naive_min_max_quantize.init_qsvs,
372
+ calibration_func=naive_min_max_quantize.min_max_calibrate,
373
+ materialize_func=materialize_func,
374
+ )
375
+
376
+
377
+ # Register the MSE algorithm.
378
+ register_op_quant_config_validation_func(
379
+ AlgorithmName.MSE,
380
+ common_quantize.check_op_quantization_config,
381
+ )
382
+
383
+ # Register a config check policy for the MSE algorithm.
384
+ register_config_check_policy_func(
385
+ AlgorithmName.MSE,
386
+ default_policy.DEFAULT_CONFIG_CHECK_POLICY,
387
+ )
388
+
389
+ # Register specialized MSE materialize functions.
390
+ _MSE_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
391
+ _TFLOpName.FULLY_CONNECTED: common_quantize.materialize_fc_conv,
392
+ _TFLOpName.EMBEDDING_LOOKUP: common_quantize.materialize_embedding_lookup,
393
+ _TFLOpName.CONV_2D: common_quantize.materialize_fc_conv,
394
+ _TFLOpName.DEPTHWISE_CONV_2D: common_quantize.materialize_fc_conv,
395
+ _TFLOpName.CONV_2D_TRANSPOSE: common_quantize.materialize_conv2d_transpose,
396
+ })
397
+ for (
398
+ op_name,
399
+ materialize_func,
400
+ ) in _MSE_OP_NAME_MATERIALIZE_FUNC_DICT.items():
401
+ register_quantized_op(
402
+ AlgorithmName.MSE,
403
+ op_name,
404
+ naive_min_max_quantize.init_qsvs,
405
+ calibration_func=naive_min_max_quantize.min_max_calibrate,
406
+ materialize_func=functools.partial(
407
+ materialize_func,
408
+ mse.get_tensor_quant_params,
409
+ ),
410
+ )
@@ -531,9 +531,9 @@ class Fp16QuantizeTest(parameterized.TestCase):
531
531
 
532
532
  op_tensor_names = {}
533
533
  op_tensor_names["weight"] = (
534
- "jax2tf_export_func_/...y_yz-_...z/pjit__einsum_/MatMul;jax2tf_export_func_/pjit__one_hot_/Equal;jax2tf_export_func_/pjit__one_hot_/Cast_1"
534
+ "jit(export_func)/jit(main)/...y,yz->...z/dot_general;jit(export_func)/jit(main)/jit(_one_hot)/eq;jit(export_func)/jit(main)/jit(_one_hot)/convert_element_type"
535
535
  )
536
- op_tensor_names["input"] = "inputs"
536
+ op_tensor_names["input"] = "lookup"
537
537
  op_tensor_names["output"] = "Identity_1"
538
538
 
539
539
  # TODO: b/335913710 - Rename the test function.