ai-edge-quantizer-nightly 0.0.1.dev20250220__py3-none-any.whl → 0.0.1.dev20250222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,250 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Recovers quantized weights from dequantized weights (often from QAT)."""
17
+
18
+ import dataclasses
19
+ from typing import Any, Optional
20
+ import numpy as np
21
+ from ai_edge_quantizer import qtyping
22
+ from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
23
+ from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor
24
+ from ai_edge_quantizer.algorithms.utils import common_utils
25
+
26
+ ALGORITHM_KEY = "dequantized_weight_recovery"
27
+ _TFLOpName = qtyping.TFLOperationName
28
+ _QuantTransformation = qtyping.QuantTransformation
29
+ _IntType = uniform_quantize_tensor.IntType
30
+
31
+
32
+ def _validate_recovered_weights(
33
+ original_vals: np.ndarray,
34
+ quant_vals: np.ndarray,
35
+ scale: np.ndarray,
36
+ tol: float = 1e-4,
37
+ ):
38
+ """Validates if recovered weights (from the quantized values) are close enough to the original ones.
39
+
40
+ Args:
41
+ original_vals: Original values before quantization.
42
+ quant_vals: Quantized values.
43
+ scale: Scale used for quantization.
44
+ tol: Tolerance for the difference between original and recovered values.
45
+
46
+ Raises:
47
+ RuntimeError: If the maximum difference between original and recovered
48
+ values exceeds the tolerance.
49
+ """
50
+ recovered_vals = quant_vals * scale
51
+ diff = np.abs(recovered_vals - original_vals).flatten()
52
+ max_diff = diff.max()
53
+ if max_diff > tol:
54
+ raise RuntimeError(
55
+ "Failed to recover the original quantized values from dequantized"
56
+ f" values. Max diff between recovered and original values: {max_diff}"
57
+ )
58
+
59
+
60
+ def _get_scale(arr: np.ndarray, min_scale: float) -> float:
61
+ """Helper function to calculate scale from a 1D array."""
62
+ # Make sure the array includes zero (symmetric quantization).
63
+ arr = np.append(arr, 0)
64
+ unique_vals = np.unique(arr)
65
+ if unique_vals.size > 1:
66
+ diffs = np.diff(unique_vals)
67
+ return float(
68
+ np.maximum(np.min(diffs), min_scale)
69
+ ) # Cast to float to ensure return type consistency
70
+ return min_scale
71
+
72
+
73
def get_zp_scale_from_2d_dequantized_symmetric_weights(
    dequant_vals: np.ndarray,
    quantized_dimension: Optional[int] = None,
    min_scale: float = 1e-9,
) -> tuple[np.ndarray, np.ndarray]:
  """Recovers zero point and scale from 2D dequantized, symmetric weights.

  Supports per-tensor recovery (`quantized_dimension=None`) and per-channel
  recovery along axis 0 or 1.

  Args:
    dequant_vals: The 2D dequantized weight values (numpy array).
    quantized_dimension: The axis along which the weights were quantized (0 or
      1), or None for per-tensor quantization.
    min_scale: The minimum allowed scale value.

  Returns:
    A `(zero_points, scales)` tuple:
      - zero_points: int32 zeros with the same shape as `scales` (symmetric
        quantization has a zero point of 0).
      - scales: shape (1, 1) for per-tensor, (N, 1) for axis 0 and (1, N) for
        axis 1, so that it broadcasts against `dequant_vals`.

  Raises:
    ValueError: If `dequant_vals` is not 2D, or if `quantized_dimension` is
      not 0, 1, or None.
  """
  if dequant_vals.ndim != 2:
    raise ValueError(
        f"Only 2D weights are supported. Got {dequant_vals.ndim} dimensions."
    )
  if quantized_dimension not in (0, 1, None):
    raise ValueError(
        f"quantized_dimension must be 0, 1, or None. Got {quantized_dimension}"
    )

  # Symmetric quantization: only the magnitudes matter for the scale.
  magnitudes = np.abs(dequant_vals)

  if quantized_dimension is None:
    # Per-tensor: a single scale derived from all values.
    scales = np.array([[_get_scale(magnitudes.flatten(), min_scale)]])
  else:
    # Per-channel: iterate over rows for axis 0, over columns for axis 1.
    along_rows = quantized_dimension == 0
    channels = magnitudes if along_rows else magnitudes.T
    per_channel = np.array([_get_scale(vec, min_scale) for vec in channels])
    # Keep the quantized axis and collapse the other one for broadcasting.
    scales = per_channel.reshape((-1, 1) if along_rows else (1, -1))

  return np.zeros_like(scales, dtype=np.int32), scales
135
+
136
+
137
def get_tensor_quant_params(
    op_info: qtyping.OpInfo,
    tensor_quant_config: qtyping.TensorQuantizationConfig,
    tensor_content: Optional[np.ndarray] = None,
    tensor_qsv: Optional[dict[str, Any]] = None,
) -> qtyping.UniformQuantParams:
  """Computes quantization parameters, recovering them for weight tensors.

  Tensors without content (non-weight tensors) are delegated to
  `naive_min_max_quantize`. For tensors with content, the scale is recovered
  from the dequantized values, the tensor is re-quantized with it, and the
  result is validated against the original content.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    tensor_quant_config: The quantization config for the tensor.
    tensor_content: The content of the tensor.
    tensor_qsv: A dictionary containing the min/max of the tensor.

  Returns:
    The quantization parameters for the tensor. For weight tensors these
    include the quantized data.

  Raises:
    ValueError: If the quantization granularity is blockwise, or if the tensor
      is not a 2D symmetric weight tensor.
    RuntimeError: If the re-quantized weights do not reproduce the original
      content within tolerance.
  """
  # Fallback to naive_min_max_quantize.py for non-weight tensors.
  if tensor_content is None:
    return naive_min_max_quantize.get_tensor_quant_params(
        op_info, tensor_quant_config, tensor_content, tensor_qsv
    )

  granularity = tensor_quant_config.granularity
  if granularity == qtyping.QuantGranularity.BLOCKWISE:
    raise ValueError(
        "Blockwise quantization is not supported for dequantized weight"
        " recovery."
    )
  if not (tensor_content.ndim == 2 and tensor_quant_config.symmetric):
    raise ValueError(
        "Only 2D symmetric weights are supported for dequantized weight"
        " recovery."
    )

  # Per-tensor recovery uses no quantized dimension.
  quantized_dim = (
      common_utils.get_weight_quantized_dim(op_info, tensor_content)
      if granularity == qtyping.QuantGranularity.CHANNELWISE
      else None
  )

  zp, scale = get_zp_scale_from_2d_dequantized_symmetric_weights(
      dequant_vals=tensor_content,
      quantized_dimension=quantized_dim,
  )
  recovered_params = qtyping.UniformQuantParams(
      scale=scale,
      zero_point=zp,
      num_bits=tensor_quant_config.num_bits,
      symmetric=tensor_quant_config.symmetric,
      quantized_dimension=quantized_dim,
  )
  quantized_vars = uniform_quantize_tensor.uniform_quantize(
      tensor_content, recovered_params
  )
  # Fail loudly if the recovered scale does not round-trip the weights.
  _validate_recovered_weights(tensor_content, quantized_vars, scale)
  return dataclasses.replace(recovered_params, quantized_data=quantized_vars)
197
+
198
+
199
def calibrate(
    tfl_op: Any,
    graph_info: qtyping.GraphInfo,
    tensor_content_map: dict[str, np.ndarray],
    inputs_to_ignore: Optional[list[int]] = None,
    outputs_to_ignore: Optional[list[int]] = None,
) -> dict[str, qtyping.QSV]:
  """Collects quantization statistics variables (QSVs, e.g., min/max) for the op.

  Delegates to the min/max calibration in naive_min_max_quantize.py; only
  weights are handled differently by this algorithm.

  Args:
    tfl_op: The tfl operation.
    graph_info: Graph information needed to perform quantization for the op.
    tensor_content_map: A map of tensor name to tensor content.
    inputs_to_ignore: Input tensor indices to ignore.
    outputs_to_ignore: Output tensor indices to ignore.

  Returns:
    A dictionary with key as tensor name and value as the collected QSV.
  """
  op_qsvs = naive_min_max_quantize.min_max_calibrate(
      tfl_op,
      graph_info,
      tensor_content_map,
      inputs_to_ignore,
      outputs_to_ignore,
  )
  return op_qsvs
227
+
228
+
229
def init_qsvs(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    inputs_to_ignore: Optional[list[int]] = None,
    outputs_to_ignore: Optional[list[int]] = None,
) -> qtyping.QSV:
  """Initializes the QSVs for the op.

  Delegates to the QSV initialization in naive_min_max_quantize.py; only
  weights need to be handled differently by this algorithm.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    inputs_to_ignore: Input tensor indices to ignore.
    outputs_to_ignore: Output tensor indices to ignore.

  Returns:
    QSVs.
  """
  initial_qsvs = naive_min_max_quantize.init_qsvs(
      op_info, graph_info, inputs_to_ignore, outputs_to_ignore
  )
  return initial_qsvs
@@ -0,0 +1,215 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from absl.testing import parameterized
17
+ import numpy as np
18
+
19
+ from tensorflow.python.platform import googletest
20
+ from ai_edge_quantizer import qtyping
21
+ from ai_edge_quantizer.algorithms.uniform_quantize import dequantized_weight_recovery
22
+
23
+ _TFLOpName = qtyping.TFLOperationName
24
+ _TensorQuantConfig = qtyping.TensorQuantizationConfig
25
+
26
+
27
class DequantizedWeightRecoveryTest(parameterized.TestCase):
  """Tests for the dequantized weight recovery algorithm."""

  def setUp(self):
    super().setUp()
    # Integer grid standing in for the "true" quantized weights. Test cases
    # multiply it by a scale to fabricate dequantized weights.
    self._dummy_quantized_weights = np.array([
        [1, -2, 3, 4],
        [6, 7, -6, 5],
        [2, -6, -7, -4],
    ])
    self._dummy_op_info = qtyping.OpInfo(
        op=None,
        op_name=_TFLOpName.FULLY_CONNECTED,
        subgraph_op_index=0,
        op_quant_config=qtyping.OpQuantizationConfig(),
    )

  @parameterized.named_parameters(
      dict(
          testcase_name="per-tensor-recovery",
          quantized_dimension=None,
          scale=np.array([0.1875]).reshape(1, 1),
      ),
      dict(
          testcase_name="channel0-recovery",
          quantized_dimension=0,
          scale=np.array([0.1875, 1e-4, 12.3]).reshape(3, 1),
      ),
      dict(
          testcase_name="channel1-recovery",
          quantized_dimension=1,
          scale=np.array([0.003, 1.234, 12.65, 2.24e-4]).reshape(1, 4),
      ),
  )
  def test_tensor_zp_scale_from_2d_dequantized_symmetric_weights_success(
      self, quantized_dimension, scale
  ):
    """The scale used to fabricate the weights is recovered with its shape."""
    dequant_vals = scale * self._dummy_quantized_weights
    zp, recovered_scale = (
        dequantized_weight_recovery.get_zp_scale_from_2d_dequantized_symmetric_weights(
            dequant_vals, quantized_dimension
        )
    )
    self.assertEqual(recovered_scale.shape, scale.shape)
    self.assertSequenceAlmostEqual(recovered_scale.flatten(), scale.flatten())
    # Zero point should be zero for symmetric quantization.
    self.assertEqual(np.sum(zp), 0)
    self.assertEqual(zp.shape, scale.shape)

  def test_tensor_zp_scale_from_2d_dequantized_symmetric_weights_raises_error_for_non_2d_weights(
      self,
  ):
    """A 3D weight tensor is rejected with a ValueError."""
    weights_3d = self._dummy_quantized_weights.reshape(1, 3, 4)
    weights_3d = weights_3d * 1.02
    with self.assertRaisesRegex(
        ValueError, "Only 2D weights are supported. Got 3 dimensions."
    ):
      dequantized_weight_recovery.get_zp_scale_from_2d_dequantized_symmetric_weights(
          weights_3d, quantized_dimension=None
      )

  @parameterized.named_parameters(
      dict(testcase_name="negative_dimension", quantized_dimension=-1),
      dict(testcase_name="too_large_dimension", quantized_dimension=2),
  )
  def test_tensor_zp_scale_from_2d_dequantized_symmetric_weights_raises_error_for_invalid_quantized_dimension(
      self, quantized_dimension
  ):
    """Any quantized dimension other than 0, 1 or None is rejected."""
    dequant_vals = self._dummy_quantized_weights * 1.02
    with self.assertRaisesRegex(
        ValueError, "quantized_dimension must be 0, 1, or None. Got"
    ):
      dequantized_weight_recovery.get_zp_scale_from_2d_dequantized_symmetric_weights(
          dequant_vals, quantized_dimension
      )

  @parameterized.named_parameters(
      dict(
          testcase_name="tensor-recovery-tensor-quant",
          tensor_quant_config=qtyping.TensorQuantizationConfig(
              num_bits=4,
              granularity=qtyping.QuantGranularity.TENSORWISE,
          ),
          scale=np.array([0.1875]).reshape(1, 1),
      ),
      dict(
          testcase_name="channel-recovery-channel-quant",
          tensor_quant_config=qtyping.TensorQuantizationConfig(
              num_bits=4,
              granularity=qtyping.QuantGranularity.CHANNELWISE,
          ),
          scale=np.array([0.1875, 1e-4, 12.3]).reshape(3, 1),
      ),
      dict(
          testcase_name="channel-recovery-excessive-bits",
          tensor_quant_config=qtyping.TensorQuantizationConfig(
              num_bits=8,  # int4 is enough for the sample weights.
              granularity=qtyping.QuantGranularity.CHANNELWISE,
          ),
          scale=np.array([0.1875, 1e-4, 12.3]).reshape(3, 1),
      ),
  )
  def test_get_tensor_quant_params_success_with_dequantized_weights(
      self, tensor_quant_config, scale
  ):
    """Scale, zero point and quantized dimension are recovered for weights."""
    dequant_vals = scale * self._dummy_quantized_weights
    tensor_quant_params = dequantized_weight_recovery.get_tensor_quant_params(
        self._dummy_op_info, tensor_quant_config, dequant_vals
    )

    # Tensor-wise recovery reports no quantized dimension; channel-wise
    # recovery for the FULLY_CONNECTED dummy op is expected on dimension 0.
    if tensor_quant_config.granularity is qtyping.QuantGranularity.TENSORWISE:
      self.assertIsNone(tensor_quant_params.quantized_dimension)
    else:
      self.assertEqual(tensor_quant_params.quantized_dimension, 0)

    recovered_scale = tensor_quant_params.scale
    self.assertEqual(recovered_scale.shape, scale.shape)
    self.assertSequenceAlmostEqual(recovered_scale.flatten(), scale.flatten())

    # Zero point should be zero for symmetric quantization.
    recovered_zp = tensor_quant_params.zero_point
    self.assertEqual(np.sum(recovered_zp), 0)
    self.assertEqual(recovered_zp.shape, scale.shape)

  def test_get_tensor_quant_params_success_with_qsv(self):
    """Without tensor content, parameters come from the min/max QSV."""
    # Fall back to naive_min_max_quantize.py for non-weight tensors.
    tensor_quant_params = dequantized_weight_recovery.get_tensor_quant_params(
        self._dummy_op_info,
        tensor_quant_config=qtyping.TensorQuantizationConfig(
            num_bits=8,
            granularity=qtyping.QuantGranularity.TENSORWISE,
        ),
        tensor_qsv={
            "min": np.array([-1]),
            "max": np.array([1]),
        },
    )

    self.assertIsNone(tensor_quant_params.quantized_dimension)
    recovered_scale = tensor_quant_params.scale
    self.assertEqual(recovered_scale.shape, (1,))
    self.assertSequenceAlmostEqual(recovered_scale.flatten(), [1 / 127])

    # Zero point should be zero for symmetric quantization.
    recovered_zp = tensor_quant_params.zero_point
    self.assertEqual(np.sum(recovered_zp), 0)
    self.assertEqual(recovered_zp.shape, (1,))

  @parameterized.named_parameters(
      # Weights fabricated with a per-column scale but recovered channel-wise.
      dict(
          testcase_name="recovery_on_wrong_dimension",
          tensor_quant_config=qtyping.TensorQuantizationConfig(
              num_bits=4,
              granularity=qtyping.QuantGranularity.CHANNELWISE,
          ),
          scale=np.array([0.003, 1.234, 12.65, 2.24e-4]).reshape(1, 4),
      ),
      # Weights fabricated with per-row scales but recovered with one scale.
      dict(
          testcase_name="tensor_recovery_for_channel_quantization",
          tensor_quant_config=qtyping.TensorQuantizationConfig(
              num_bits=4,
              granularity=qtyping.QuantGranularity.TENSORWISE,
          ),
          scale=np.array([0.1875, 1e-2, 12.3]).reshape(3, 1),
      ),
      # Two bits are requested for sample weights that span -7..7.
      dict(
          testcase_name="insufficient_bits",
          tensor_quant_config=qtyping.TensorQuantizationConfig(
              num_bits=2,
              granularity=qtyping.QuantGranularity.CHANNELWISE,
          ),
          scale=np.array([0.1875, 1e-2, 12.3]).reshape(3, 1),
      ),
  )
  def test_get_tensor_quant_params_raises_error_big_recovery_error(
      self, tensor_quant_config, scale
  ):
    """A config that cannot reproduce the weights raises a RuntimeError."""
    dequant_vals = scale * self._dummy_quantized_weights
    with self.assertRaisesRegex(
        RuntimeError,
        "Failed to recover the original quantized values from dequantized"
        " values. Max diff between recovered and original values: ",
    ):
      dequantized_weight_recovery.get_tensor_quant_params(
          self._dummy_op_info, tensor_quant_config, dequant_vals
      )
212
+
213
+
214
+ if __name__ == "__main__":
215
+ googletest.main()
@@ -15,6 +15,7 @@
15
15
 
16
16
  """Performs naive min/max uniform quantization."""
17
17
 
18
+ from collections.abc import Sequence
18
19
  from typing import Any, Optional
19
20
  import numpy as np
20
21
  from ai_edge_quantizer import qtyping
@@ -36,55 +37,133 @@ def _init_tensor_min_max(
36
37
  if tensor_data is None:
37
38
  return {}
38
39
  else:
40
+ weight_tensor_config = op_info.op_quant_config.weight_tensor_config
39
41
  quantized_dim = None
42
+ if weight_tensor_config is not None and (
43
+ weight_tensor_config.granularity == qtyping.QuantGranularity.CHANNELWISE
44
+ or weight_tensor_config.granularity
45
+ == qtyping.QuantGranularity.BLOCKWISE
46
+ ):
47
+ quantized_dim = common_utils.get_weight_quantized_dim(
48
+ op_info, tensor_data
49
+ )
40
50
  if (
41
- op_info.op_quant_config.weight_tensor_config is not None
42
- and op_info.op_quant_config.weight_tensor_config.granularity
51
+ weight_tensor_config is not None
52
+ and weight_tensor_config.granularity
43
53
  == qtyping.QuantGranularity.BLOCKWISE
44
54
  ):
45
- # TODO(b/346612503): emulate subchannel only supports fully connected,
46
- # will skip special handling. Once we have a spec, we can change this.
47
- block_size = op_info.op_quant_config.weight_tensor_config.block_size
48
- # assuming tensor is 2D, which is correct for FULLY_CONNECTED
49
- transposed_tensor_data = np.transpose(tensor_data, (1, 0))
50
- if transposed_tensor_data.shape[0] % block_size:
51
- raise ValueError(
52
- f"Block size {block_size} does not divide channel dimension"
53
- f" {transposed_tensor_data.shape[0]}."
54
- )
55
- reshaped_tensor_data = np.reshape(
56
- transposed_tensor_data,
57
- (
58
- 1,
59
- int(transposed_tensor_data.shape[0] / block_size),
60
- block_size,
61
- transposed_tensor_data.shape[1],
62
- ),
55
+ reshaped_data, reduce_dims = _reshape_data_for_blockwise(
56
+ tensor_data,
57
+ quantized_dim,
58
+ weight_tensor_config.block_size,
63
59
  )
64
60
  return {
65
- "min": np.min(reshaped_tensor_data, axis=(0, 1, 2), keepdims=True),
66
- "max": np.max(reshaped_tensor_data, axis=(0, 1, 2), keepdims=True),
61
+ "min": np.min(reshaped_data, axis=reduce_dims, keepdims=False),
62
+ "max": np.max(reshaped_data, axis=reduce_dims, keepdims=False),
67
63
  }
68
- if (
69
- op_info.op_quant_config.weight_tensor_config is not None
70
- and op_info.op_quant_config.weight_tensor_config.granularity
71
- == qtyping.QuantGranularity.CHANNELWISE
72
- ):
73
- if op_info.op_name == _TFLOpName.BATCH_MATMUL:
74
- quantized_dim = common_utils.get_bmm_weight_quantized_dim(
75
- tensor_data, adj_y=op_info.op.builtinOptions.adjY
76
- )
77
- else:
78
- quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM.get(
79
- op_info.op_name, None
80
- )
81
- reduce_dims = common_utils.get_reduce_dims(
82
- quantized_dim, list(tensor_data.shape)
83
- )
84
- return {
85
- "min": np.min(tensor_data, axis=reduce_dims, keepdims=True),
86
- "max": np.max(tensor_data, axis=reduce_dims, keepdims=True),
87
- }
64
+
65
+ else:
66
+ reduce_dims = common_utils.get_reduce_dims(
67
+ quantized_dim, tensor_data.shape
68
+ )
69
+ return {
70
+ "min": np.min(tensor_data, axis=reduce_dims, keepdims=True),
71
+ "max": np.max(tensor_data, axis=reduce_dims, keepdims=True),
72
+ }
73
+
74
+
75
+ def _get_tensor_shape_for_blockwise(
76
+ tensor_shape: Sequence[int], quantized_dim: int, block_size: int
77
+ ) -> list[int]:
78
+ """Get the tensor shape for blockwise quantization.
79
+
80
+ This function splits the quantize dimension of the tensor into blocks and the
81
+ dim/blocks. Hence, min/max of the tensor can be calculated for each block
82
+ using existing functions.
83
+
84
+ Args:
85
+ tensor_shape: The original shape of the tensor.
86
+ quantized_dim: The dimension to be quantized blockwise.
87
+ block_size: The size of the block.
88
+
89
+ Returns:
90
+ The new tensor shape for calculating scale and zp for blockwise
91
+ quantization.
92
+ """
93
+ new_shape = []
94
+ for index, val in enumerate(tensor_shape):
95
+ if index == quantized_dim:
96
+ new_shape.append(int(val / block_size))
97
+ new_shape.append(block_size)
98
+ else:
99
+ new_shape.append(val)
100
+ return new_shape
101
+
102
+
103
def _reshape_data_for_blockwise(
    tensor_data: np.ndarray, quantized_dim: int, block_size: int
) -> tuple[np.ndarray, int]:
  """Reshapes tensor data so that each block gets its own axis.

  Args:
    tensor_data: The original tensor data.
    quantized_dim: The dimension to be quantized blockwise.
    block_size: The size of the block.

  Returns:
    A tuple of the reshaped tensor data and the axis to reduce over, which is
    the block axis inserted right after `quantized_dim`.
  """
  blocked_shape = _get_tensor_shape_for_blockwise(
      tensor_data.shape, quantized_dim, block_size
  )
  # The block axis sits immediately after the quantized dimension.
  reduce_dim = quantized_dim + 1
  return tensor_data.reshape(blocked_shape), reduce_dim
121
+
122
+
123
def _broadcast_scale_zp_for_blockwise(
    tensor_content: np.ndarray,
    quant_params: qtyping.UniformQuantParams,
) -> qtyping.UniformQuantParams:
  """Expands per-block scale and zero point to the full tensor shape.

  Args:
    tensor_content: The original tensor data.
    quant_params: The quantization parameters, carrying one scale/zp per
      block.

  Returns:
    New quantization parameters whose scale and zero point have the same
    shape as `tensor_content`, for correct constant quantization.

  Raises:
    ValueError: If the quantized dimension is missing, or the block size is
      missing or not positive.
  """
  quantized_dim = quant_params.quantized_dimension
  if quantized_dim is None:
    raise ValueError("Quantized dimension must be specified.")
  block_size = quant_params.block_size
  if block_size is None or block_size <= 0:
    raise ValueError("Block size must be specified and positive.")

  blocked_shape = _get_tensor_shape_for_blockwise(
      tensor_content.shape, quantized_dim, block_size
  )

  def _expand_to_tensor_shape(param: np.ndarray) -> np.ndarray:
    # Insert the block axis, repeat the value across each block, then fold
    # the block axis back into the quantized dimension.
    with_block_axis = np.expand_dims(param, quantized_dim + 1)
    repeated = np.broadcast_to(with_block_axis, blocked_shape)
    return np.reshape(repeated, tensor_content.shape)

  return qtyping.UniformQuantParams(
      scale=_expand_to_tensor_shape(quant_params.scale),
      zero_point=_expand_to_tensor_shape(quant_params.zero_point),
      num_bits=quant_params.num_bits,
      symmetric=quant_params.symmetric,
      quantized_dimension=quantized_dim,
      block_size=block_size,
  )
88
167
 
89
168
 
90
169
  def get_tensor_quant_params(
@@ -138,34 +217,34 @@ def get_tensor_quant_params(
138
217
  tensor_quant_config.symmetric,
139
218
  )
140
219
  quantized_dim = None
141
- if tensor_quant_config.granularity == qtyping.QuantGranularity.CHANNELWISE:
142
- if op_info.op_name == _TFLOpName.BATCH_MATMUL:
143
- quantized_dim = common_utils.get_bmm_weight_quantized_dim(
144
- tensor_content, adj_y=op_info.op.builtinOptions.adjY
145
- )
146
- else:
147
- quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM[
148
- op_info.op_name
149
- ]
220
+ if (
221
+ tensor_quant_config.granularity == qtyping.QuantGranularity.CHANNELWISE
222
+ or tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE
223
+ ):
224
+ quantized_dim = common_utils.get_weight_quantized_dim(
225
+ op_info, tensor_content
226
+ )
150
227
  quant_params = qtyping.UniformQuantParams(
151
228
  scale=scale,
152
229
  zero_point=zp,
153
230
  num_bits=tensor_quant_config.num_bits,
154
231
  symmetric=tensor_quant_config.symmetric,
155
232
  quantized_dimension=quantized_dim,
233
+ block_size=tensor_quant_config.block_size,
156
234
  )
157
235
  if tensor_content is None:
158
236
  return quant_params
237
+
238
+ # The reshaping for blockwise quantization is unique hence we do this here
239
+ # to avoid unexpected broadcast behavior downstream.
159
240
  if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
160
- quantized_vars = (
161
- uniform_quantize_tensor.uniform_quantize_for_emulated_subchannel(
162
- tensor_content, quant_params, tensor_quant_config.block_size
163
- )
164
- )
165
- else:
166
- quantized_vars = uniform_quantize_tensor.uniform_quantize(
241
+ quant_params = _broadcast_scale_zp_for_blockwise(
167
242
  tensor_content, quant_params
168
243
  )
244
+
245
+ quantized_vars = uniform_quantize_tensor.uniform_quantize(
246
+ tensor_content, quant_params
247
+ )
169
248
  # Update with quantized values.
170
249
  return qtyping.UniformQuantParams(
171
250
  scale=scale,
@@ -174,6 +253,7 @@ def get_tensor_quant_params(
174
253
  symmetric=tensor_quant_config.symmetric,
175
254
  quantized_dimension=quantized_dim,
176
255
  quantized_data=quantized_vars,
256
+ block_size=tensor_quant_config.block_size,
177
257
  )
178
258
 
179
259
 
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
 
16
16
  import os
17
+ from typing import cast
17
18
 
18
19
  from absl.testing import parameterized
19
20
  import numpy as np
@@ -21,6 +22,7 @@ import numpy as np
21
22
  from tensorflow.python.platform import googletest
22
23
  from ai_edge_quantizer import qtyping
23
24
  from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
25
+ from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor
24
26
  from ai_edge_quantizer.utils import test_utils
25
27
  from ai_edge_quantizer.utils import tfl_flatbuffer_utils
26
28
 
@@ -157,6 +159,49 @@ class NaiveMinMaxQuantizeTest(parameterized.TestCase):
157
159
  self.assertNotIn("arith.constant1", op_qsvs)
158
160
  self.assertNotIn("arith.constant2", op_qsvs)
159
161
 
162
def test_get_tensor_quant_params_for_blockwise_weight(self):
  """Blockwise weights get per-block scale/zp and full-shape quantized data."""
  subgraph0 = self._test_model.subgraphs[0]
  subgraph_op_index = 3
  fc_op = subgraph0.operators[subgraph_op_index]
  weight_tensor_config = _TensorQuantConfig(
      num_bits=4,
      symmetric=True,
      granularity=qtyping.QuantGranularity.BLOCKWISE,
      block_size=2,
  )
  op_info = qtyping.OpInfo(
      op=fc_op,
      op_name=_TFLOpName.FULLY_CONNECTED,
      subgraph_op_index=subgraph_op_index,
      op_quant_config=qtyping.OpQuantizationConfig(
          weight_tensor_config=weight_tensor_config,
      ),
  )
  # 4x2 weights; with block_size=2 on dimension 0 this gives two blocks of
  # two rows each.
  test_data = np.array([[-7, 7], [4, -4], [4, -4], [7, 7]])
  quant_params = naive_min_max_quantize.get_tensor_quant_params(
      op_info=op_info,
      tensor_quant_config=weight_tensor_config,
      tensor_content=test_data,
  )
  scale = quant_params.scale
  zp = quant_params.zero_point
  # NOTE(review): the per-block column minima of `test_data` look like
  # [[-7, -4], [4, -4]] rather than the values below. Every entry has an
  # absolute range of 7 either way, so the symmetric scale would be
  # unaffected -- confirm and consider correcting the expected minima.
  expected_zp, expected_scale = (
      uniform_quantize_tensor.tensor_zp_scale_from_min_max(
          min_value=np.array([[-7, 4], [-4, -4]]),
          max_value=np.array([[4, 7], [7, 7]]),
          num_bits=4,
          symmetric=True,
      )
  )
  self.assertTrue(np.array_equal(zp, expected_zp))
  self.assertTrue(np.array_equal(scale, expected_scale))
  self.assertIsNotNone(quant_params.quantized_data)
  # Quantized data keeps the original tensor shape even though scale/zp are
  # per block.
  self.assertTupleEqual(
      cast(np.ndarray, quant_params.quantized_data).shape, test_data.shape
  )
  self.assertEqual(quant_params.block_size, 2)
  self.assertEqual(quant_params.quantized_dimension, 0)
204
+
160
205
 
161
206
  if __name__ == "__main__":
162
207
  googletest.main()
@@ -119,55 +119,6 @@ def fix_quantization_params_rank(
119
119
  )
120
120
 
121
121
 
122
- def uniform_quantize_for_emulated_subchannel(
123
- tensor_data: np.ndarray,
124
- quantization_params: qtyping.UniformQuantParams,
125
- block_size: int,
126
- ) -> np.ndarray:
127
- """Uniform quantize a tensor for emulated subchannel.
128
-
129
- emulation involves reshaping the tensor and quantizing value on a different
130
- axes. Hence, we use a different quantization function.
131
-
132
- Args:
133
- tensor_data: The tensor to be quantized.
134
- quantization_params: The quantization parameters.
135
- block_size: The block size of the emulated subchannel.
136
-
137
- Returns:
138
- The quantized tensor.
139
- """
140
- scales, zero_points = (
141
- quantization_params.scale,
142
- quantization_params.zero_point,
143
- )
144
- transposed_and_reshaped_tensor = np.reshape(
145
- np.transpose(tensor_data, (1, 0)),
146
- (
147
- 1,
148
- int(tensor_data.shape[1] / block_size),
149
- block_size,
150
- tensor_data.shape[0],
151
- ),
152
- )
153
- inverse_scales = 1.0 / scales
154
- qtype = IntType(quantization_params.num_bits, signed=True)
155
- # Symmetric means narrow range (e.g., -127 to 127)
156
- narrow_range = quantization_params.symmetric
157
- required_dtype = np.signedinteger if qtype.signed else np.unsignedinteger
158
- if not np.issubdtype(zero_points.dtype, required_dtype):
159
- raise ValueError(
160
- f"zero_points need to be {required_dtype}."
161
- f" But the actual type is {zero_points.dtype}."
162
- )
163
- ret = (
164
- np.multiply(transposed_and_reshaped_tensor, inverse_scales) + zero_points
165
- )
166
- ret = _round_and_clip(ret, qtype, narrow_range)
167
- ret = assign_quantized_type(ret, qtype)
168
- return ret
169
-
170
-
171
122
  def uniform_quantize(
172
123
  tensor_data: np.ndarray,
173
124
  quantization_params: qtyping.UniformQuantParams,
@@ -369,3 +320,14 @@ def _is_valid_quantization_params(
369
320
  f" ({zero_point_rank}) must be the same as the tensor rank"
370
321
  f" ({tensor_rank})."
371
322
  )
323
+ if (
324
+ quantization_params.block_size != 0
325
+ and tensor_data.shape[quantization_params.quantized_dimension]
326
+ % quantization_params.block_size
327
+ != 0
328
+ ):
329
+ raise ValueError(
330
+ "Tensor dimension must be divisible by block size. Got dimension:"
331
+ f" {tensor_data.shape[quantization_params.quantized_dimension]} and"
332
+ f" block size: {quantization_params.block_size}"
333
+ )
@@ -906,9 +906,30 @@ def get_tensor_transformation_params(
906
906
  )
907
907
 
908
908
 
909
+ def get_weight_quantized_dim(op_info: qtyping.OpInfo, tensor_data: np.ndarray):
910
+ """Get the quantized dimension for the weight tensor.
911
+
912
+ Args:
913
+ op_info: Aggregated information about the op (e.g., quantization config).
914
+ tensor_data: The weight tensor data.
915
+
916
+ Returns:
917
+ The quantized dimension for the weight tensor.
918
+ """
919
+ if op_info.op_name == _TFLOpName.BATCH_MATMUL:
920
+ quantized_dim = get_bmm_weight_quantized_dim(
921
+ tensor_data, adj_y=op_info.op.builtinOptions.adjY
922
+ )
923
+ else:
924
+ quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM.get(
925
+ op_info.op_name, None
926
+ )
927
+ return quantized_dim
928
+
929
+
909
930
  def get_reduce_dims(
910
931
  quantized_dim: Optional[int],
911
- tensor_shape: list[int],
932
+ tensor_shape: Sequence[int],
912
933
  ) -> Optional[tuple[int, ...]]:
913
934
  """Get the reduce dims of a tensor for the given quantized dimension."""
914
935
  if quantized_dim is None:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-quantizer-nightly
3
- Version: 0.0.1.dev20250220
3
+ Version: 0.0.1.dev20250222
4
4
  Summary: A quantizer for advanced developers to quantize converted AI Edge models.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
6
6
  Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
@@ -30,12 +30,14 @@ ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py,sha256=s64
30
30
  ai_edge_quantizer/algorithms/uniform_quantize/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
31
31
  ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=wPZevOuowJczG9t4Gynzv7tIeH6zhOnaKPsfr2K_fsk,21259
32
32
  ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py,sha256=qMmKbWqxrCoVKbLKHn9WuCrGKPfHkEyU0Nmhokh8Qeo,2597
33
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py,sha256=WaN6w-DqQkSwNl8xsbsSPPY97oKohHpo-5Ng_5yAerw,9958
34
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py,sha256=3UV1I-to_u6NE_yKoXOVUOQgil-tMY6VQ_L273lMfqQ,5949
35
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py,sha256=B-s1KMfb9tqvaDhHJV-M2zRR078z5Mwv-P9h77S3Mis,12229
33
+ ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py,sha256=OTXjEZ3Ctq3ffYzisX-6HwgK_DuA7uos_aap5PiIUPE,8686
34
+ ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py,sha256=y7BK11fkF63Ex_Jzg3fbIdy0D_Ca6HuvChVZR7Uwggc,8073
35
+ ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py,sha256=fBqSidFVKZmdO-xIFfwZPdIN1eLJjOik8mUZxZj2ljk,12149
36
+ ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py,sha256=Hok09dloSyBfD0oDM5VABdSZjM9JWSQhm_hDHNbFujA,7640
37
+ ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py,sha256=Q_vx7YN7KMpjubsngxRdJ4bfdSIV-gmXjtVuxIkZuX4,11078
36
38
  ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py,sha256=WZ4_bvbG999nOtCIqn7mrMnpRdoJOdiyzxhsL_QiPHA,11395
37
39
  ai_edge_quantizer/algorithms/utils/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
38
- ai_edge_quantizer/algorithms/utils/common_utils.py,sha256=Z2ziBeADwov8rRN4pRX6Qr2L_agu8RRAbOKw0_yLG7E,33936
40
+ ai_edge_quantizer/algorithms/utils/common_utils.py,sha256=nlLpUY1LTO9ZC3X0FjQ0EArCZekGUnv2-IF0AUu5zNM,34582
39
41
  ai_edge_quantizer/algorithms/utils/common_utils_test.py,sha256=zqapGEfYhjQWe9cNGPLmdbwtEUUYQRhlO_kNe0cXX6E,18104
40
42
  ai_edge_quantizer/transformations/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
41
43
  ai_edge_quantizer/transformations/dequant_insert.py,sha256=sL1LHFVzBDSd9jgrzlHz38LWU0bwmVX7iBkaNcui0ts,3566
@@ -58,8 +60,8 @@ ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=SM8H4i7Jq_nfdsJpImopHndN
58
60
  ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=Op3JxtOqlrjzmYF18jnnstL1k9xiY9kKJ8S2vklKGkc,11327
59
61
  ai_edge_quantizer/utils/validation_utils.py,sha256=oYw33Sg547AqtGw-choPUJmp9SAKkV46J_ddqSsum2Q,3950
60
62
  ai_edge_quantizer/utils/validation_utils_test.py,sha256=V_qNDikPD4OPB-siOLQCWNVWTAu87h2IgNYt7teFd-o,2934
61
- ai_edge_quantizer_nightly-0.0.1.dev20250220.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
62
- ai_edge_quantizer_nightly-0.0.1.dev20250220.dist-info/METADATA,sha256=wc1t3VKLcToSVZ6MOwmrcNhWcy967d9mAaQlFF6w50s,1484
63
- ai_edge_quantizer_nightly-0.0.1.dev20250220.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
64
- ai_edge_quantizer_nightly-0.0.1.dev20250220.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
65
- ai_edge_quantizer_nightly-0.0.1.dev20250220.dist-info/RECORD,,
63
+ ai_edge_quantizer_nightly-0.0.1.dev20250222.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
64
+ ai_edge_quantizer_nightly-0.0.1.dev20250222.dist-info/METADATA,sha256=e9r1p0vAQtBGj4RIEtBbjmiyDyUVUmdNYNU8LqfDVGk,1484
65
+ ai_edge_quantizer_nightly-0.0.1.dev20250222.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
66
+ ai_edge_quantizer_nightly-0.0.1.dev20250222.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
67
+ ai_edge_quantizer_nightly-0.0.1.dev20250222.dist-info/RECORD,,