ai-edge-quantizer-nightly 0.0.1.dev20250211__py3-none-any.whl → 0.0.1.dev20250212__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,10 +16,12 @@
16
16
  """Quantizer Algorithm Manager Interface."""
17
17
 
18
18
  import enum
19
+ import functools
19
20
  from ai_edge_quantizer import algorithm_manager_api
20
21
  from ai_edge_quantizer import default_policy
21
22
  from ai_edge_quantizer import qtyping
22
23
  from ai_edge_quantizer.algorithms.nonlinear_quantize import float_casting
24
+ from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
23
25
  from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
24
26
 
25
27
  _TFLOpName = qtyping.TFLOperationName
@@ -54,7 +56,7 @@ class AlgorithmName(str, enum.Enum):
54
56
  # Register MIN_MAX_UNIFORM_QUANT algorithm.
55
57
  register_op_quant_config_validation_func(
56
58
  AlgorithmName.MIN_MAX_UNIFORM_QUANT,
57
- naive_min_max_quantize.check_op_quantization_config,
59
+ common_quantize.check_op_quantization_config,
58
60
  )
59
61
 
60
62
  # Register a config check policy for MIN_MAX_UNIFORM_QUANT algorithm.
@@ -63,71 +65,48 @@ register_config_check_policy_func(
63
65
  default_policy.DEFAULT_CONFIG_CHECK_POLICY,
64
66
  )
65
67
 
66
-
67
- for op_name, materialize_func in zip(
68
- (
69
- _TFLOpName.INPUT,
70
- _TFLOpName.OUTPUT,
71
- _TFLOpName.FULLY_CONNECTED,
72
- _TFLOpName.BATCH_MATMUL,
73
- _TFLOpName.CONV_2D,
74
- _TFLOpName.DEPTHWISE_CONV_2D,
75
- _TFLOpName.CONV_2D_TRANSPOSE,
76
- _TFLOpName.RESHAPE,
77
- _TFLOpName.AVERAGE_POOL_2D,
78
- _TFLOpName.EMBEDDING_LOOKUP,
79
- _TFLOpName.SOFTMAX,
80
- _TFLOpName.TANH,
81
- _TFLOpName.TRANSPOSE,
82
- _TFLOpName.GELU,
83
- _TFLOpName.ADD,
84
- _TFLOpName.SUB,
85
- _TFLOpName.MUL,
86
- _TFLOpName.MEAN,
87
- _TFLOpName.RSQRT,
88
- _TFLOpName.CONCATENATION,
89
- _TFLOpName.STRIDED_SLICE,
90
- _TFLOpName.SPLIT,
91
- _TFLOpName.LOGISTIC, # Sigmoid
92
- _TFLOpName.SLICE,
93
- _TFLOpName.SUM,
94
- _TFLOpName.SELECT_V2,
95
- ),
96
- (
97
- naive_min_max_quantize.materialize_input,
98
- naive_min_max_quantize.materialize_output,
99
- naive_min_max_quantize.materialize_fc_conv,
100
- naive_min_max_quantize.materialize_batch_matmul,
101
- naive_min_max_quantize.materialize_fc_conv,
102
- naive_min_max_quantize.materialize_fc_conv,
103
- naive_min_max_quantize.materialize_conv2d_transpose,
104
- naive_min_max_quantize.materialize_reshape,
105
- naive_min_max_quantize.materialize_average_pool_2d,
106
- naive_min_max_quantize.materialize_embedding_lookup,
107
- naive_min_max_quantize.materialize_softmax_and_logistic,
108
- naive_min_max_quantize.materialize_tanh,
109
- naive_min_max_quantize.materialize_transpose,
110
- naive_min_max_quantize.materialize_gelu,
111
- naive_min_max_quantize.materialize_add,
112
- naive_min_max_quantize.materialize_sub,
113
- naive_min_max_quantize.materialize_mul,
114
- naive_min_max_quantize.materialize_mean,
115
- naive_min_max_quantize.materialize_rsqrt,
116
- naive_min_max_quantize.materialize_concatenation,
117
- naive_min_max_quantize.materialize_strided_slice,
118
- naive_min_max_quantize.materialize_split,
119
- naive_min_max_quantize.materialize_softmax_and_logistic,
120
- naive_min_max_quantize.materialize_slice,
121
- naive_min_max_quantize.materialize_sum,
122
- naive_min_max_quantize.materialize_select_v2,
123
- ),
124
- ):
68
+ MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT = {
69
+ _TFLOpName.INPUT: common_quantize.materialize_input,
70
+ _TFLOpName.OUTPUT: common_quantize.materialize_output,
71
+ _TFLOpName.FULLY_CONNECTED: common_quantize.materialize_fc_conv,
72
+ _TFLOpName.BATCH_MATMUL: common_quantize.materialize_batch_matmul,
73
+ _TFLOpName.CONV_2D: common_quantize.materialize_fc_conv,
74
+ _TFLOpName.DEPTHWISE_CONV_2D: common_quantize.materialize_fc_conv,
75
+ _TFLOpName.CONV_2D_TRANSPOSE: common_quantize.materialize_conv2d_transpose,
76
+ _TFLOpName.RESHAPE: common_quantize.materialize_reshape,
77
+ _TFLOpName.AVERAGE_POOL_2D: common_quantize.materialize_average_pool_2d,
78
+ _TFLOpName.EMBEDDING_LOOKUP: common_quantize.materialize_embedding_lookup,
79
+ _TFLOpName.SOFTMAX: common_quantize.materialize_softmax_and_logistic,
80
+ _TFLOpName.TANH: common_quantize.materialize_tanh,
81
+ _TFLOpName.TRANSPOSE: common_quantize.materialize_transpose,
82
+ _TFLOpName.GELU: common_quantize.materialize_gelu,
83
+ _TFLOpName.ADD: common_quantize.materialize_add,
84
+ _TFLOpName.SUB: common_quantize.materialize_sub,
85
+ _TFLOpName.MUL: common_quantize.materialize_mul,
86
+ _TFLOpName.MEAN: common_quantize.materialize_mean,
87
+ _TFLOpName.RSQRT: common_quantize.materialize_rsqrt,
88
+ _TFLOpName.CONCATENATION: common_quantize.materialize_concatenation,
89
+ _TFLOpName.STRIDED_SLICE: common_quantize.materialize_strided_slice,
90
+ _TFLOpName.SPLIT: common_quantize.materialize_split,
91
+ _TFLOpName.LOGISTIC: common_quantize.materialize_softmax_and_logistic,
92
+ _TFLOpName.SLICE: common_quantize.materialize_slice,
93
+ _TFLOpName.SUM: common_quantize.materialize_sum,
94
+ _TFLOpName.SELECT_V2: common_quantize.materialize_select_v2,
95
+ }
96
+ for op_name, materialize_func in MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT.items():
125
97
  register_quantized_op(
126
98
  AlgorithmName.MIN_MAX_UNIFORM_QUANT,
127
99
  op_name,
128
100
  naive_min_max_quantize.init_qsvs,
129
101
  calibration_func=naive_min_max_quantize.min_max_calibrate,
130
- materialize_func=materialize_func,
102
+ # Most of the materialize op functions are common for all algorithms
103
+ # except for the function to get scale and zero point, i.e.,
104
+ # get_tensor_quant_params. So we use functools.partial here to pass in the
105
+ # common utility function and the algorithm-specific function.
106
+ materialize_func=functools.partial(
107
+ materialize_func,
108
+ naive_min_max_quantize.get_tensor_quant_params,
109
+ ),
131
110
  )
132
111
 
133
112
  # Register FLOAT_CASTING algorithm.