ai-edge-quantizer-nightly 0.0.1.dev20250221__py3-none-any.whl → 0.0.1.dev20250222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,11 +15,13 @@
15
15
 
16
16
  """Recovers quantized weights from dequantized weights (often from QAT)."""
17
17
 
18
- from typing import Optional
18
+ import dataclasses
19
+ from typing import Any, Optional
19
20
  import numpy as np
20
21
  from ai_edge_quantizer import qtyping
22
+ from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
21
23
  from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor
22
-
24
+ from ai_edge_quantizer.algorithms.utils import common_utils
23
25
 
24
26
  ALGORITHM_KEY = "dequantized_weight_recovery"
25
27
  _TFLOpName = qtyping.TFLOperationName
@@ -27,24 +29,26 @@ _QuantTransformation = qtyping.QuantTransformation
27
29
  _IntType = uniform_quantize_tensor.IntType
28
30
 
29
31
 
30
- def _validate_recovered_scale(
31
- dequant_vals: np.ndarray, scale: np.ndarray, tol: float = 1e-4
32
+ def _validate_recovered_weights(
33
+ original_vals: np.ndarray,
34
+ quant_vals: np.ndarray,
35
+ scale: np.ndarray,
36
+ tol: float = 1e-4,
32
37
  ):
33
- """Validates if the recovered quantized values match the dequantized values.
38
+ """Validates if recovered weights (from the quantized values) are close enough to the original ones.
34
39
 
35
40
  Args:
36
- dequant_vals: The dequantized weight values.
37
- scale: The scale values.
38
- tol: The tolerance for the difference between the recovered and original
39
- values.
41
+ original_vals: Original values before quantization.
42
+ quant_vals: Quantized values.
43
+ scale: Scale used for quantization.
44
+ tol: Tolerance for the difference between original and recovered values.
40
45
 
41
46
  Raises:
42
- RuntimeError: If the maximum difference between the recovered and
43
- original values exceeds the tolerance.
47
+ RuntimeError: If the maximum difference between original and recovered
48
+ values exceeds the tolerance.
44
49
  """
45
- quant_vals = np.round(dequant_vals / scale) # no need to clamp.
46
50
  recovered_vals = quant_vals * scale
47
- diff = np.abs(recovered_vals - dequant_vals).flatten()
51
+ diff = np.abs(recovered_vals - original_vals).flatten()
48
52
  max_diff = diff.max()
49
53
  if max_diff > tol:
50
54
  raise RuntimeError(
@@ -127,5 +131,120 @@ def get_zp_scale_from_2d_dequantized_symmetric_weights(
127
131
  )
128
132
 
129
133
  zero_points = np.zeros_like(scales, dtype=np.int32)
130
- _validate_recovered_scale(dequant_vals, scales)
131
134
  return zero_points, scales
135
+
136
+
137
+ def get_tensor_quant_params(
138
+ op_info: qtyping.OpInfo,
139
+ tensor_quant_config: qtyping.TensorQuantizationConfig,
140
+ tensor_content: Optional[np.ndarray] = None,
141
+ tensor_qsv: Optional[dict[str, Any]] = None,
142
+ ) -> qtyping.UniformQuantParams:
143
+ """Get the quantization parameters for a tensor.
144
+
145
+ Args:
146
+ op_info: Aggregated information about the op (e.g., quantization config).
147
+ tensor_quant_config: The quantization config for the tensor.
148
+ tensor_content: The content of the tensor.
149
+ tensor_qsv: A dictionary containing the min/max of the tensor.
150
+
151
+ Returns:
152
+ The quantization parameters for the tensor.
153
+
154
+ Raises:
155
+ ValueError: If the quantization granularity is blockwise, or if the tensor
156
+ is not a 2D symmetric weight tensor.
157
+ """
158
+ # Fall back to naive_min_max_quantize.py for non-weight tensors.
159
+ if tensor_content is None:
160
+ return naive_min_max_quantize.get_tensor_quant_params(
161
+ op_info, tensor_quant_config, tensor_content, tensor_qsv
162
+ )
163
+
164
+ if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
165
+ raise ValueError(
166
+ "Blockwise quantization is not supported for dequantized weight"
167
+ " recovery."
168
+ )
169
+ if tensor_content.ndim != 2 or not tensor_quant_config.symmetric:
170
+ raise ValueError(
171
+ "Only 2D symmetric weights are supported for dequantized weight"
172
+ " recovery."
173
+ )
174
+
175
+ quantized_dim = None
176
+ if tensor_quant_config.granularity == qtyping.QuantGranularity.CHANNELWISE:
177
+ quantized_dim = common_utils.get_weight_quantized_dim(
178
+ op_info, tensor_content
179
+ )
180
+
181
+ zp, scale = get_zp_scale_from_2d_dequantized_symmetric_weights(
182
+ dequant_vals=tensor_content,
183
+ quantized_dimension=quantized_dim,
184
+ )
185
+ quant_params = qtyping.UniformQuantParams(
186
+ scale=scale,
187
+ zero_point=zp,
188
+ num_bits=tensor_quant_config.num_bits,
189
+ symmetric=tensor_quant_config.symmetric,
190
+ quantized_dimension=quantized_dim,
191
+ )
192
+ quantized_vars = uniform_quantize_tensor.uniform_quantize(
193
+ tensor_content, quant_params
194
+ )
195
+ _validate_recovered_weights(tensor_content, quantized_vars, scale)
196
+ return dataclasses.replace(quant_params, quantized_data=quantized_vars)
197
+
198
+
199
+ def calibrate(
200
+ tfl_op: Any,
201
+ graph_info: qtyping.GraphInfo,
202
+ tensor_content_map: dict[str, np.ndarray],
203
+ inputs_to_ignore: Optional[list[int]] = None,
204
+ outputs_to_ignore: Optional[list[int]] = None,
205
+ ) -> dict[str, qtyping.QSV]:
206
+ """Collect quantization statistics variable (QSV, e.g., min/max) for the op.
207
+
208
+ Args:
209
+ tfl_op: The tfl operation.
210
+ graph_info: Graph information needed to perform quantization for the op.
211
+ tensor_content_map: A map of tensor name to tensor content.
212
+ inputs_to_ignore: Input tensor indices to ignore.
213
+ outputs_to_ignore: Output tensor indices to ignore.
214
+
215
+ Returns:
216
+ A dictionary with key as tensor name and value as the collected QSV.
217
+ """
218
+ # Reuse the min/max calibration algorithm from naive_min_max_quantize.py since
219
+ # only weights need to be handled differently.
220
+ return naive_min_max_quantize.min_max_calibrate(
221
+ tfl_op,
222
+ graph_info,
223
+ tensor_content_map,
224
+ inputs_to_ignore,
225
+ outputs_to_ignore,
226
+ )
227
+
228
+
229
+ def init_qsvs(
230
+ op_info: qtyping.OpInfo,
231
+ graph_info: qtyping.GraphInfo,
232
+ inputs_to_ignore: Optional[list[int]] = None,
233
+ outputs_to_ignore: Optional[list[int]] = None,
234
+ ) -> qtyping.QSV:
235
+ """Initialize the QSVs.
236
+
237
+ Args:
238
+ op_info: Aggregated information about the op (e.g., quantization config).
239
+ graph_info: Graph information needed to perform quantization for the op.
240
+ inputs_to_ignore: Input tensor indices to ignore.
241
+ outputs_to_ignore: Output tensor indices to ignore.
242
+
243
+ Returns:
244
+ QSVs.
245
+ """
246
+ # Reuse the min/max calibration algorithm from naive_min_max_quantize.py since
247
+ # only weights need to be handled differently.
248
+ return naive_min_max_quantize.init_qsvs(
249
+ op_info, graph_info, inputs_to_ignore, outputs_to_ignore
250
+ )
@@ -19,7 +19,6 @@ import numpy as np
19
19
  from tensorflow.python.platform import googletest
20
20
  from ai_edge_quantizer import qtyping
21
21
  from ai_edge_quantizer.algorithms.uniform_quantize import dequantized_weight_recovery
22
- from ai_edge_quantizer.utils import test_utils
23
22
 
24
23
  _TFLOpName = qtyping.TFLOperationName
25
24
  _TensorQuantConfig = qtyping.TensorQuantizationConfig
@@ -31,9 +30,15 @@ class DequantizedWeightRecoveryTest(parameterized.TestCase):
31
30
  super().setUp()
32
31
  self._dummy_quantized_weights = np.array([
33
32
  [1, -2, 3, 4],
34
- [6, 7, -8, 5],
35
- [-1, 8, -7, -4],
33
+ [6, 7, -6, 5],
34
+ [2, -6, -7, -4],
36
35
  ])
36
+ self._dummy_op_info = qtyping.OpInfo(
37
+ op=None,
38
+ op_name=_TFLOpName.FULLY_CONNECTED,
39
+ subgraph_op_index=0,
40
+ op_quant_config=qtyping.OpQuantizationConfig(),
41
+ )
37
42
 
38
43
  @parameterized.named_parameters(
39
44
  dict(
@@ -96,18 +101,104 @@ class DequantizedWeightRecoveryTest(parameterized.TestCase):
96
101
 
97
102
  @parameterized.named_parameters(
98
103
  dict(
99
- testcase_name="recovery_on_wrong_dimension",
100
- quantized_dimension=1, # should be 0.
104
+ testcase_name="tensor-recovery-tensor-quant",
105
+ tensor_quant_config=qtyping.TensorQuantizationConfig(
106
+ num_bits=4,
107
+ granularity=qtyping.QuantGranularity.TENSORWISE,
108
+ ),
109
+ scale=np.array([0.1875]).reshape(1, 1),
110
+ ),
111
+ dict(
112
+ testcase_name="channel-recovery-channel-quant",
113
+ tensor_quant_config=qtyping.TensorQuantizationConfig(
114
+ num_bits=4,
115
+ granularity=qtyping.QuantGranularity.CHANNELWISE,
116
+ ),
101
117
  scale=np.array([0.1875, 1e-4, 12.3]).reshape(3, 1),
102
118
  ),
103
119
  dict(
104
- testcase_name="tensor_recovery_for_channel_quantization",
105
- quantized_dimension=None, # should be 0.
120
+ testcase_name="channel-recovery-excessive-bits",
121
+ tensor_quant_config=qtyping.TensorQuantizationConfig(
122
+ num_bits=8, # int4 is enough for the sample weights.
123
+ granularity=qtyping.QuantGranularity.CHANNELWISE,
124
+ ),
125
+ scale=np.array([0.1875, 1e-4, 12.3]).reshape(3, 1),
126
+ ),
127
+ )
128
+ def test_get_tensor_quant_params_success_with_dequantized_weights(
129
+ self, tensor_quant_config, scale
130
+ ):
131
+ dequant_vals = scale * self._dummy_quantized_weights
132
+ tensor_quant_params = dequantized_weight_recovery.get_tensor_quant_params(
133
+ self._dummy_op_info, tensor_quant_config, dequant_vals
134
+ )
135
+
136
+ if tensor_quant_config.granularity is qtyping.QuantGranularity.TENSORWISE:
137
+ self.assertIsNone(tensor_quant_params.quantized_dimension)
138
+ else:
139
+ self.assertEqual(tensor_quant_params.quantized_dimension, 0)
140
+
141
+ recovered_scale = tensor_quant_params.scale
142
+ self.assertEqual(recovered_scale.shape, scale.shape)
143
+ self.assertSequenceAlmostEqual(recovered_scale.flatten(), scale.flatten())
144
+
145
+ # Zero point should be zero for symmetric quantization.
146
+ recovered_zp = tensor_quant_params.zero_point
147
+ self.assertEqual(np.sum(recovered_zp), 0)
148
+ self.assertEqual(recovered_zp.shape, scale.shape)
149
+
150
+ def test_get_tensor_quant_params_success_with_qsv(self):
151
+ # Fall back to naive_min_max_quantize.py for non-weight tensors.
152
+ tensor_quant_params = dequantized_weight_recovery.get_tensor_quant_params(
153
+ self._dummy_op_info,
154
+ tensor_quant_config=qtyping.TensorQuantizationConfig(
155
+ num_bits=8,
156
+ granularity=qtyping.QuantGranularity.TENSORWISE,
157
+ ),
158
+ tensor_qsv={
159
+ "min": np.array([-1]),
160
+ "max": np.array([1]),
161
+ },
162
+ )
163
+
164
+ self.assertIsNone(tensor_quant_params.quantized_dimension)
165
+ recovered_scale = tensor_quant_params.scale
166
+ self.assertEqual(recovered_scale.shape, (1,))
167
+ self.assertSequenceAlmostEqual(recovered_scale.flatten(), [1 / 127])
168
+
169
+ # Zero point should be zero for symmetric quantization.
170
+ recovered_zp = tensor_quant_params.zero_point
171
+ self.assertEqual(np.sum(recovered_zp), 0)
172
+ self.assertEqual(recovered_zp.shape, (1,))
173
+
174
+ @parameterized.named_parameters(
175
+ dict(
176
+ testcase_name="recovery_on_wrong_dimension",
177
+ tensor_quant_config=qtyping.TensorQuantizationConfig(
178
+ num_bits=4,
179
+ granularity=qtyping.QuantGranularity.CHANNELWISE,
180
+ ),
106
181
  scale=np.array([0.003, 1.234, 12.65, 2.24e-4]).reshape(1, 4),
107
182
  ),
183
+ dict(
184
+ testcase_name="tensor_recovery_for_channel_quantization",
185
+ tensor_quant_config=qtyping.TensorQuantizationConfig(
186
+ num_bits=4,
187
+ granularity=qtyping.QuantGranularity.TENSORWISE,
188
+ ),
189
+ scale=np.array([0.1875, 1e-2, 12.3]).reshape(3, 1),
190
+ ),
191
+ dict(
192
+ testcase_name="insufficient_bits",
193
+ tensor_quant_config=qtyping.TensorQuantizationConfig(
194
+ num_bits=2,
195
+ granularity=qtyping.QuantGranularity.CHANNELWISE,
196
+ ),
197
+ scale=np.array([0.1875, 1e-2, 12.3]).reshape(3, 1),
198
+ ),
108
199
  )
109
- def test_tensor_zp_scale_from_2d_dequantized_symmetric_weights_raises_error_big_recovery_error(
110
- self, quantized_dimension, scale
200
+ def test_get_tensor_quant_params_raises_error_big_recovery_error(
201
+ self, tensor_quant_config, scale
111
202
  ):
112
203
  dequant_vals = scale * self._dummy_quantized_weights
113
204
  with self.assertRaisesRegex(
@@ -115,8 +206,8 @@ class DequantizedWeightRecoveryTest(parameterized.TestCase):
115
206
  "Failed to recover the original quantized values from dequantized"
116
207
  " values. Max diff between recovered and original values: ",
117
208
  ):
118
- dequantized_weight_recovery.get_zp_scale_from_2d_dequantized_symmetric_weights(
119
- dequant_vals, quantized_dimension
209
+ dequantized_weight_recovery.get_tensor_quant_params(
210
+ self._dummy_op_info, tensor_quant_config, dequant_vals
120
211
  )
121
212
 
122
213
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-quantizer-nightly
3
- Version: 0.0.1.dev20250221
3
+ Version: 0.0.1.dev20250222
4
4
  Summary: A quantizer for advanced developers to quantize converted AI Edge models.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
6
6
  Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
@@ -30,8 +30,8 @@ ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py,sha256=s64
30
30
  ai_edge_quantizer/algorithms/uniform_quantize/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
31
31
  ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=wPZevOuowJczG9t4Gynzv7tIeH6zhOnaKPsfr2K_fsk,21259
32
32
  ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py,sha256=qMmKbWqxrCoVKbLKHn9WuCrGKPfHkEyU0Nmhokh8Qeo,2597
33
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py,sha256=LfwZgZwkPZVZBS6VEwaskLNw3BoeymIjxAVw3ZkjjsI,4597
34
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py,sha256=6FPYx4M2-W-SSV6iqQdggd5q5cnciqFI7Ci3Wo5Wyog,4566
33
+ ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py,sha256=OTXjEZ3Ctq3ffYzisX-6HwgK_DuA7uos_aap5PiIUPE,8686
34
+ ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py,sha256=y7BK11fkF63Ex_Jzg3fbIdy0D_Ca6HuvChVZR7Uwggc,8073
35
35
  ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py,sha256=fBqSidFVKZmdO-xIFfwZPdIN1eLJjOik8mUZxZj2ljk,12149
36
36
  ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py,sha256=Hok09dloSyBfD0oDM5VABdSZjM9JWSQhm_hDHNbFujA,7640
37
37
  ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py,sha256=Q_vx7YN7KMpjubsngxRdJ4bfdSIV-gmXjtVuxIkZuX4,11078
@@ -60,8 +60,8 @@ ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=SM8H4i7Jq_nfdsJpImopHndN
60
60
  ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=Op3JxtOqlrjzmYF18jnnstL1k9xiY9kKJ8S2vklKGkc,11327
61
61
  ai_edge_quantizer/utils/validation_utils.py,sha256=oYw33Sg547AqtGw-choPUJmp9SAKkV46J_ddqSsum2Q,3950
62
62
  ai_edge_quantizer/utils/validation_utils_test.py,sha256=V_qNDikPD4OPB-siOLQCWNVWTAu87h2IgNYt7teFd-o,2934
63
- ai_edge_quantizer_nightly-0.0.1.dev20250221.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
64
- ai_edge_quantizer_nightly-0.0.1.dev20250221.dist-info/METADATA,sha256=3so5mv89zJCWrCAs77PXQqIq3sGmyue7jkZsmIyO_mQ,1484
65
- ai_edge_quantizer_nightly-0.0.1.dev20250221.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
66
- ai_edge_quantizer_nightly-0.0.1.dev20250221.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
67
- ai_edge_quantizer_nightly-0.0.1.dev20250221.dist-info/RECORD,,
63
+ ai_edge_quantizer_nightly-0.0.1.dev20250222.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
64
+ ai_edge_quantizer_nightly-0.0.1.dev20250222.dist-info/METADATA,sha256=e9r1p0vAQtBGj4RIEtBbjmiyDyUVUmdNYNU8LqfDVGk,1484
65
+ ai_edge_quantizer_nightly-0.0.1.dev20250222.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
66
+ ai_edge_quantizer_nightly-0.0.1.dev20250222.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
67
+ ai_edge_quantizer_nightly-0.0.1.dev20250222.dist-info/RECORD,,