ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/algorithm_manager.py +158 -0
  2. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  3. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  5. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
  6. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  7. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  8. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
  12. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
  13. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
  14. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
  15. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
  16. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
  17. ai_edge_quantizer/calibrator.py +11 -60
  18. ai_edge_quantizer/calibrator_test.py +4 -73
  19. ai_edge_quantizer/default_policy.py +61 -26
  20. ai_edge_quantizer/model_modifier.py +97 -7
  21. ai_edge_quantizer/model_modifier_test.py +81 -1
  22. ai_edge_quantizer/model_validator.py +31 -8
  23. ai_edge_quantizer/params_generator.py +17 -10
  24. ai_edge_quantizer/params_generator_test.py +2 -7
  25. ai_edge_quantizer/qtyping.py +86 -6
  26. ai_edge_quantizer/quantizer.py +166 -21
  27. ai_edge_quantizer/quantizer_test.py +284 -16
  28. ai_edge_quantizer/recipe.py +154 -42
  29. ai_edge_quantizer/recipe_manager.py +158 -1
  30. ai_edge_quantizer/recipe_manager_test.py +146 -32
  31. ai_edge_quantizer/recipe_test.py +93 -17
  32. ai_edge_quantizer/transformation_instruction_generator.py +118 -13
  33. ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
  34. ai_edge_quantizer/transformation_performer.py +55 -25
  35. ai_edge_quantizer/transformation_performer_test.py +127 -5
  36. ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
  37. ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
  38. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  39. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  40. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  41. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  42. ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
  43. ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
  44. ai_edge_quantizer/transformations/transformation_utils.py +129 -6
  45. ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
  46. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  47. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  48. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  49. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  50. ai_edge_quantizer/utils/test_utils.py +75 -2
  51. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
  52. ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
  53. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  54. ai_edge_quantizer/utils/validation_utils.py +114 -4
  55. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  56. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
  57. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  58. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  59. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  60. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  61. ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
  62. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  63. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py
@@ -44,33 +44,17 @@ class OctavQuantizeTest(parameterized.TestCase):
     )
     self._tensor_name_to_qsv = {}
     subgraph0 = self._test_model.subgraphs[0]
-    subgraph_op_index = 3
-    fc_op = subgraph0.operators[subgraph_op_index]
+    self._subgraph_op_index = 3
+    self._fc_op = subgraph0.operators[self._subgraph_op_index]
     self._fc_op_info = qtyping.OpInfo(
-        op=fc_op,
+        op=self._fc_op,
         op_name=qtyping.TFLOperationName.FULLY_CONNECTED,
-        subgraph_op_index=subgraph_op_index,
+        subgraph_op_index=self._subgraph_op_index,
         op_quant_config=qtyping.OpQuantizationConfig(
             weight_tensor_config=None,
         ),
     )

-  def test_get_tensor_quant_params_unsupported_granularity_assert(self):
-    err_msg = "Unsupported granularity"
-    test_data = np.array([[-7, 7], [4, -4], [4, -4], [7, 7]])
-    with self.assertRaisesWithPredicateMatch(
-        ValueError, lambda err: err_msg in str(err)
-    ):
-      _ = octav.get_tensor_quant_params(
-          op_info=self._fc_op_info,
-          tensor_quant_config=qtyping.TensorQuantizationConfig(
-              num_bits=4,
-              symmetric=True,
-              granularity=qtyping.QuantGranularity.BLOCKWISE,
-          ),
-          tensor_content=test_data,
-      )
-
   def test_get_tensor_quant_params_unsupported_symmetry(self):
     err_msg = "Unsupported symmetry"
     test_data = np.array([[-7, 7], [4, -4], [4, -4], [7, 7]])
@@ -117,13 +101,22 @@ class OctavQuantizeTest(parameterized.TestCase):
         [25, -30, 50, -75, 1e5, -125],
         [50, -60, 70, -80, 90, -100],
     ])
-    quant_params = octav.get_tensor_quant_params(
-        op_info=self._fc_op_info,
-        tensor_quant_config=qtyping.TensorQuantizationConfig(
-            num_bits=4,
-            symmetric=True,
-            granularity=qtyping.QuantGranularity.TENSORWISE,
+    tensor_config = qtyping.TensorQuantizationConfig(
+        num_bits=4,
+        symmetric=True,
+        granularity=qtyping.QuantGranularity.TENSORWISE,
+    )
+    fc_op_info = qtyping.OpInfo(
+        op=self._fc_op,
+        op_name=qtyping.TFLOperationName.FULLY_CONNECTED,
+        subgraph_op_index=self._subgraph_op_index,
+        op_quant_config=qtyping.OpQuantizationConfig(
+            weight_tensor_config=tensor_config,
         ),
+    )
+    quant_params = octav.get_tensor_quant_params(
+        op_info=fc_op_info,
+        tensor_quant_config=tensor_config,
         tensor_content=test_data,
     )
     adjusted_test_data = quant_params.quantized_data * quant_params.scale
@@ -131,10 +124,10 @@ class OctavQuantizeTest(parameterized.TestCase):
     adjusted_max = np.max(np.abs(adjusted_test_data))

     # Check that some clipping occurred.
-    with self.subTest(name="SanityCheckClipping"):
+    with self.subTest(name="CheckClipping"):
       self.assertLess(adjusted_max, real_max)

-    with self.subTest(name="SanityCheckQuantParamsShapes"):
+    with self.subTest(name="CheckQuantParamsShapes"):
       self.assertEqual(quant_params.zero_point.shape, (1, 1))
       self.assertEqual(quant_params.scale.shape, (1, 1))
       self.assertIsNone(quant_params.quantized_dimension)
@@ -143,33 +136,47 @@ class OctavQuantizeTest(parameterized.TestCase):
           cast(np.ndarray, quant_params.quantized_data).shape, test_data.shape
       )

-    with self.subTest(name="SanityCheckQuantParamsValues"):
+    with self.subTest(name="CheckQuantParamsValues"):
       self.assertTrue(np.all(quant_params.zero_point == 0))

   def test_get_tensor_quant_params_sanity_channelwise(self):
+    # Test that the call generates quant params that are appropriately shaped,
+    # have some clipping, and correct config values without checking the
+    # actual values numerically.
     test_data = np.array([
         [-1e5, 25, -50, 75, -100, 125],
         [25, -30, 50, -75, 1e5, -125],
         [50, -60, 70, -80, 90, -100],
     ])
-    quant_params = octav.get_tensor_quant_params(
-        op_info=self._fc_op_info,
-        tensor_quant_config=qtyping.TensorQuantizationConfig(
-            num_bits=4,
-            symmetric=True,
-            granularity=qtyping.QuantGranularity.CHANNELWISE,
+    tensor_config = qtyping.TensorQuantizationConfig(
+        num_bits=4,
+        symmetric=True,
+        granularity=qtyping.QuantGranularity.CHANNELWISE,
+    )
+    fc_op_info = qtyping.OpInfo(
+        op=self._fc_op,
+        op_name=qtyping.TFLOperationName.FULLY_CONNECTED,
+        subgraph_op_index=self._subgraph_op_index,
+        op_quant_config=qtyping.OpQuantizationConfig(
+            weight_tensor_config=tensor_config,
         ),
+    )
+    quant_params = octav.get_tensor_quant_params(
+        op_info=fc_op_info,
+        tensor_quant_config=tensor_config,
         tensor_content=test_data,
     )
+    # Dequantize output to compare with the original test data.
     adjusted_test_data = quant_params.quantized_data * quant_params.scale
+
     for i, row in enumerate(test_data):
       real_max = np.max(np.abs(row))
       adjusted_max = np.max(np.abs(adjusted_test_data[i]))
       # Check that some clipping occurred.
-      with self.subTest(name="SanityCheckClipping"):
+      with self.subTest(name="CheckClipping"):
         self.assertLess(adjusted_max, real_max)

-      with self.subTest(name="SanityCheckQuantParamsShapes"):
+      with self.subTest(name="CheckQuantParamsShapes"):
        self.assertEqual(quant_params.zero_point.shape, (test_data.shape[0], 1))
        self.assertEqual(quant_params.scale.shape, (test_data.shape[0], 1))
        self.assertIsNotNone(quant_params.quantized_data)
@@ -177,10 +184,57 @@ class OctavQuantizeTest(parameterized.TestCase):
           cast(np.ndarray, quant_params.quantized_data).shape, test_data.shape
       )

-      with self.subTest(name="SanityCheckQuantParamsValues"):
+      with self.subTest(name="CheckQuantParamsValues"):
        self.assertTrue(np.all(quant_params.zero_point == 0))
        self.assertEqual(quant_params.quantized_dimension, 0)

+  def test_get_tensor_quant_params_sanity_blockwise(self):
+    # Test that the call generates quant params that are appropriately shaped,
+    # have some clipping, and correct config values without checking the
+    # actual values numerically.
+    test_data = np.random.randint(0, 1024, size=(32, 128))
+    tensor_config = qtyping.TensorQuantizationConfig(
+        num_bits=4,
+        symmetric=True,
+        granularity=qtyping.QuantGranularity.BLOCKWISE_32,
+    )
+    fc_op_info = qtyping.OpInfo(
+        op=self._fc_op,
+        op_name=qtyping.TFLOperationName.FULLY_CONNECTED,
+        subgraph_op_index=self._subgraph_op_index,
+        op_quant_config=qtyping.OpQuantizationConfig(
+            weight_tensor_config=tensor_config,
+        ),
+    )
+    quant_params = octav.get_tensor_quant_params(
+        op_info=fc_op_info,
+        tensor_quant_config=tensor_config,
+        tensor_content=test_data,
+    )
+
+    with self.subTest(name="CheckQuantParamsShapes"):
+      # Check that quant params have appropriate shapes.
+      self.assertEqual(quant_params.zero_point.shape, (32, 4))
+      self.assertEqual(quant_params.scale.shape, (32, 4))
+      self.assertIsNotNone(quant_params.quantized_data)
+      self.assertTupleEqual(
+          cast(np.ndarray, quant_params.quantized_data).shape, test_data.shape
+      )
+
+    scales = np.repeat(quant_params.scale, 32, axis=1)
+    adjusted_test_data = quant_params.quantized_data * scales
+    for i, row in enumerate(test_data):
+      real_max = np.max(np.abs(row))
+      adjusted_max = np.max(np.abs(adjusted_test_data[i]))
+      # Check that some clipping occurred.
+      with self.subTest(name="CheckClipping"):
+        self.assertLess(adjusted_max, real_max)
+
+    with self.subTest(name="CheckQuantParamsValues"):
+      self.assertTrue(np.all(quant_params.zero_point == 0))
+      # See TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM.
+      self.assertEqual(quant_params.quantized_dimension, 1)
+

 if __name__ == "__main__":
   googletest.main()
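
The shape assertions in the new blockwise test follow directly from the block layout: a (32, 128) fully-connected weight quantized along dimension 1 with block size 32 holds 128 / 32 = 4 blocks per row, so one scale (and one zero point) is produced per block. A minimal NumPy sketch of that bookkeeping, independent of the quantizer and for illustration only:

    import numpy as np

    weights = np.random.randint(0, 1024, size=(32, 128)).astype(np.float32)
    block_size, quantized_dim = 32, 1

    # Split the quantized dimension into (num_blocks, block_size).
    num_blocks = weights.shape[quantized_dim] // block_size  # 128 // 32 == 4
    blocked = weights.reshape(32, num_blocks, block_size)

    # One symmetric scale per block: max|x| / qmax for 4-bit signed ([-8, 7]).
    scales = np.abs(blocked).max(axis=-1) / 7.0
    print(scales.shape)  # (32, 4), matching the test's expected scale shape.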
ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py
@@ -16,9 +16,11 @@
 """Uniform quantize in tensor level."""

 import dataclasses
-from typing import Optional
+from typing import Optional, Sequence
+import ml_dtypes
 import numpy as np
 from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.utils import tfl_flatbuffer_utils


 @dataclasses.dataclass(frozen=True)
@@ -27,6 +29,11 @@ class IntType:
   signed: bool


+def is_blockwise(granularity: qtyping.QuantGranularity) -> bool:
+  """Checks if the quantization granularity is blockwise."""
+  return "BLOCKWISE" in str(granularity)
+
+
 def get_quantized_range(qtype: IntType) -> tuple[float, float]:
   """Calculates range of the quantized type."""
   if qtype.signed:
@@ -38,6 +45,22 @@ def get_quantized_range(qtype: IntType) -> tuple[float, float]:
   return float(qmin), float(qmax)


+def extract_block_size_from_granularity(
+    granularity: qtyping.QuantGranularity,
+) -> int:
+  """Get the block size for blockwise quantization."""
+  if granularity == qtyping.QuantGranularity.BLOCKWISE_32:
+    return 32
+  elif granularity == qtyping.QuantGranularity.BLOCKWISE_64:
+    return 64
+  elif granularity == qtyping.QuantGranularity.BLOCKWISE_128:
+    return 128
+  elif granularity == qtyping.QuantGranularity.BLOCKWISE_256:
+    return 256
+  else:
+    return 0
+
+
 def _round_and_clip(
     tensor: np.ndarray, qtype: IntType, narrow: bool
 ) -> np.ndarray:
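
Usage of the helper above is a straightforward lookup over the granularity enum; non-blockwise granularities fall through to 0, which callers can treat as a "not blockwise" sentinel. A hedged usage sketch, assuming the enum values shown in this diff:

    from ai_edge_quantizer import qtyping

    assert extract_block_size_from_granularity(
        qtyping.QuantGranularity.BLOCKWISE_64
    ) == 64
    assert extract_block_size_from_granularity(
        qtyping.QuantGranularity.CHANNELWISE
    ) == 0  # not blockwise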
@@ -117,22 +140,141 @@ def fix_quantization_params_rank(
       symmetric=quantization_params.symmetric,
       quantized_dimension=quantization_params.quantized_dimension,
       quantized_data=quantization_params.quantized_data,
+      block_size=quantization_params.block_size,
+  )
+
+
+def _get_tensor_shape_for_blockwise(
+    tensor_shape: Sequence[int], quantized_dim: int, block_size: int
+) -> list[int]:
+  """Get the tensor shape for blockwise quantization.
+
+  This function splits the quantize dimension of the tensor into blocks and the
+  dim/blocks. Hence, min/max of the tensor can be calculated for each block
+  using existing functions.
+
+  Args:
+    tensor_shape: The original shape of the tensor.
+    quantized_dim: The dimension to be quantized blockwise.
+    block_size: The size of the block.
+
+  Returns:
+    The new tensor shape for calculating scale and zp for blockwise
+    quantization.
+  """
+  new_shape = []
+  for index, val in enumerate(tensor_shape):
+    if index == quantized_dim:
+      if val % block_size != 0:
+        raise ValueError(
+            f"Quantized dimension {val} in tensor shape {tensor_shape} is not"
+            f" divisible by block size {block_size}."
+        )
+      new_shape.append(int(val / block_size))
+      new_shape.append(block_size)
+    else:
+      new_shape.append(val)
+  return new_shape
+
+
+def reshape_data_for_blockwise(
+    tensor_data: np.ndarray,
+    op_name: qtyping.TFLOperationName,
+    granularity: qtyping.QuantGranularity,
+) -> tuple[np.ndarray, int]:
+  """Reshapes data for blockwise quantization.
+
+  Args:
+    tensor_data: The original tensor data.
+    op_name: The name of the TFL op.
+    granularity: The quantization granularity for the tensor.
+
+  Returns:
+    A tuple containing the reshaped tensor data and the new reduce dimension.
+  """
+  quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM[
+      op_name
+  ]
+  block_size = extract_block_size_from_granularity(granularity)
+  new_shape = _get_tensor_shape_for_blockwise(
+      tensor_data.shape, quantized_dim, block_size
+  )
+  return tensor_data.reshape(new_shape), quantized_dim + 1
+
+
+def _broadcast_scale_zp_for_blockwise(
+    tensor_content: np.ndarray,
+    quant_params: qtyping.UniformQuantParams,
+) -> qtyping.UniformQuantParams:
+  """Broadcasts scale and zp for blockwise quantization.
+
+  Args:
+    tensor_content: The original tensor data.
+    quant_params: The quantization parameters.
+      `quant_params.quantized_dimension` must be specified.
+      `quant_params.block_size` must be specified and positive.
+
+  Returns:
+    The updated quantization parameters with broadcasted scale and zp for
+    correct constant quantization.
+  """
+  if quant_params.quantized_dimension is None:
+    raise ValueError("Quantized dimension must be specified.")
+  if quant_params.block_size is None or quant_params.block_size <= 0:
+    raise ValueError("Block size must be specified and positive.")
+  quantized_dim = quant_params.quantized_dimension
+  expanded_tensor_shape = _get_tensor_shape_for_blockwise(
+      tensor_content.shape, quantized_dim, quant_params.block_size
+  )
+  expanded_scale = np.reshape(
+      np.broadcast_to(
+          np.expand_dims(quant_params.scale, quantized_dim + 1),
+          expanded_tensor_shape,
+      ),
+      tensor_content.shape,
+  )
+  if quant_params.zero_point is None or quant_params.zero_point.size == 0:
+    expanded_zp = np.zeros_like(tensor_content, dtype=np.int32)
+  else:
+    expanded_zp = np.reshape(
+        np.broadcast_to(
+            np.expand_dims(quant_params.zero_point, quantized_dim + 1),
+            expanded_tensor_shape,
+        ),
+        tensor_content.shape,
+    )
+  return qtyping.UniformQuantParams(
+      scale=expanded_scale,
+      zero_point=expanded_zp,
+      num_bits=quant_params.num_bits,
+      symmetric=quant_params.symmetric,
+      quantized_dimension=quantized_dim,
+      block_size=quant_params.block_size,
   )


 def uniform_quantize(
     tensor_data: np.ndarray,
     quantization_params: qtyping.UniformQuantParams,
+    is_blockwise_quant: bool = False,
 ):
   """Uniform quantize a tensor.

   Args:
     tensor_data: The tensor to be quantized.
     quantization_params: The quantization parameters.
+    is_blockwise_quant: Whether the tensor is blockwise quantized.

   Returns:
     The quantized tensor.
   """
+  # The reshaping for blockwise quantization is unique hence we do this here
+  # to avoid unexpected broadcast behavior downstream.
+  if is_blockwise_quant:
+    quantization_params = _broadcast_scale_zp_for_blockwise(
+        tensor_data, quantization_params
+    )
+
   # quant params in flatbuffer is flattened, expand the rank to be the same
   # as the tensor rank to avoid ambiguous broadcasting.
   quantization_params = fix_quantization_params_rank(
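
The reshape-then-broadcast pair above is the core of the blockwise path: `_get_tensor_shape_for_blockwise` splits the quantized dimension into (num_blocks, block_size) so per-block statistics can reuse the existing per-axis reductions, and `_broadcast_scale_zp_for_blockwise` expands the resulting per-block scales back to the full tensor shape. A standalone NumPy sketch of the same round trip, illustrative only and not the library API:

    import numpy as np

    tensor = np.arange(32 * 128, dtype=np.float32).reshape(32, 128)
    quantized_dim, block_size = 1, 32

    # Split dim 1 into (num_blocks, block_size): (32, 128) -> (32, 4, 32).
    blocked = tensor.reshape(32, 128 // block_size, block_size)

    # Per-block reduction happens over the new axis at quantized_dim + 1.
    per_block_scale = np.abs(blocked).max(axis=quantized_dim + 1)  # (32, 4)

    # Broadcast per-block scales over each block, then flatten back to the
    # original tensor shape so elementwise quantization broadcasts cleanly.
    expanded = np.broadcast_to(
        np.expand_dims(per_block_scale, quantized_dim + 1), blocked.shape
    ).reshape(tensor.shape)
    assert expanded.shape == (32, 128)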
@@ -146,8 +288,15 @@ def uniform_quantize(
   inverse_scales = 1.0 / scales
   # TODO: b/332574603 - support unsigned data type.
   qtype = IntType(quantization_params.num_bits, signed=True)
-  # Symmetric means narrow range (e.g., -127 to 127)
-  narrow_range = quantization_params.symmetric
+  # For quantization with more than 8 bits, symmetric narrow-range quantization
+  # is required due to assumptions made by legacy TFLite kernels. However, this
+  # method is not ideal for low-bit quantization (e.g., 2-bit quantization,
+  # which only has 4 bins), as it wastes a bin and there are no kernel
+  # requirements for a narrow range when < 8 bits because the data is unpacked
+  # to int8 before being used in the kernel.
+  narrow_range = (
+      quantization_params.symmetric and quantization_params.num_bits >= 8
+  )
   required_dtype = np.signedinteger if qtype.signed else np.unsignedinteger
   if not np.issubdtype(zero_points.dtype, required_dtype):
     raise ValueError(
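
The effect of the new condition: symmetric int8 and int16 tensors still use the narrow range expected by legacy TFLite kernels, while symmetric sub-8-bit tensors keep the full range and stop wasting a bin. A small sketch of the range selection implied by the diff, with qmin/qmax computed as in `get_quantized_range`:

    def effective_range(num_bits: int, symmetric: bool) -> tuple[int, int]:
      """Sketch of the range selection implied by the diff above."""
      qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
      narrow = symmetric and num_bits >= 8
      return (qmin + 1 if narrow else qmin, qmax)

    assert effective_range(8, True) == (-127, 127)  # narrow range kept for int8
    assert effective_range(4, True) == (-8, 7)      # full range for 4-bit
    assert effective_range(2, True) == (-2, 1)      # 2-bit keeps all 4 bins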
@@ -173,6 +322,26 @@ def uniform_dequantize(
   Returns:
     The dequantized tensor.
   """
+  if quantization_params.block_size != 0:
+    # b/443830202: The quantized dimension is currently increased by 1 because
+    # AEQ expects 1 and XNNPack expects 0.
+    quantization_params = dataclasses.replace(
+        quantization_params,
+        quantized_dimension=quantization_params.quantized_dimension + 1,
+    )
+    scale_shape = list(tensor_data.shape)
+    scale_shape[quantization_params.quantized_dimension] = (
+        scale_shape[quantization_params.quantized_dimension]
+        // quantization_params.block_size
+    )
+    quantization_params = dataclasses.replace(
+        quantization_params,
+        scale=quantization_params.scale.reshape(scale_shape),
+    )
+    quantization_params = _broadcast_scale_zp_for_blockwise(
+        tensor_data, quantization_params
+    )
+
   # quant params in flatbuffer is flattened, expand the rank to be the same
   # as the tensor rank to avoid ambiguous broadcasting.
   quantization_params = fix_quantization_params_rank(
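
On the dequantize side the flatbuffer stores scales flattened, so for a blockwise tensor the code first rebuilds the per-block scale shape by dividing the quantized dimension by `block_size`, then reuses the same broadcast helper. A sketch of the arithmetic (not the library code) for a (32, 128) tensor with block size 32 and a stored quantized dimension of 0, the XNNPack convention noted in the bug reference above:

    import numpy as np

    tensor_shape = (32, 128)
    block_size = 32
    stored_quantized_dim = 0                   # XNNPack convention.
    quantized_dim = stored_quantized_dim + 1   # AEQ-internal convention.

    flat_scales = np.ones(32 * 4, dtype=np.float32)  # flattened in flatbuffer
    scale_shape = list(tensor_shape)
    scale_shape[quantized_dim] //= block_size        # [32, 4]
    scales = flat_scales.reshape(scale_shape)
    print(scales.shape)  # (32, 4), ready for the blockwise broadcast helper.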
@@ -188,6 +357,7 @@ def symmetric_quantize_bias_tensor(
     bias_content: np.ndarray,
     input_tensor_quant_params: qtyping.UniformQuantParams,
     weight_tensor_quant_params: qtyping.UniformQuantParams,
+    check_error: bool = True,
 ) -> qtyping.UniformQuantParams:
   """Quantize bias tensor (symmetrically, i.e., zero_point = 0).

@@ -199,6 +369,12 @@ def symmetric_quantize_bias_tensor(
     bias_content: The bias content.
     input_tensor_quant_params: The quantization parameters of input tensor.
     weight_tensor_quant_params: The quantization parameters of weight tensor.
+    check_error: Whether to check if the quantization error (the difference
+      between the original and dequantized bias) is larger than the
+      quantization scale. This check is important because bias quantization
+      parameters are fixed (bias_scale = input_scale * weight_scale), which can
+      lead to large quantization errors. Raising an error when the quantization
+      error is larger than the scale helps to identify unexpected numerical
+      issues.

   Returns:
     The quantized bias tensor.
@@ -213,7 +389,8 @@ def symmetric_quantize_bias_tensor(

   # symmetric
   bias_zp = np.zeros_like(effective_output_scale, dtype=np.int32)
-  bias_number_bits = 64 if input_tensor_quant_params.num_bits == 16 else 32
+  # Fixed to 32 bits since most of the accelerators use int32 accumulator.
+  bias_number_bits = 32
   symmetric = True
   quantized_dimension = None if len(effective_output_scale) == 1 else 0
   bias_quant_params = qtyping.UniformQuantParams(
@@ -225,6 +402,24 @@ def symmetric_quantize_bias_tensor(
   )

   quantized_vars = uniform_quantize(bias_content, bias_quant_params)
+  if check_error:
+    dequantized_bias = uniform_dequantize(quantized_vars, bias_quant_params)
+    max_quant_error = np.max(np.abs(dequantized_bias - bias_content))
+    error_tolerance = np.maximum(1e-6, np.max(effective_output_scale))
+    if max_quant_error > error_tolerance:
+      raise ValueError(
+          "Quantization error is too large for bias tensor quantization. Max"
+          f" quantization error is {max_quant_error}, which exceed"
+          f" the threshold {error_tolerance}"
+      )
+
+  # Save the int32 quantized bias as int64 if the input tensor is quantized to
+  # 16 bits. This is to assume the matmul is using int64 accumulator (safe from
+  # overflow). For accelerators with int32 accumulator, it is safe to cast
+  # int64 back to int32.
+  if input_tensor_quant_params.num_bits == 16:
+    quantized_vars = quantized_vars.astype(np.int64)
+    bias_number_bits = 64

   # UniformQuantParams is frozen dataclass, need to recreate.
   return qtyping.UniformQuantParams(
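
Bias scales are not free parameters: bias_scale = input_scale * weight_scale, so the only defense against a badly scaled bias is the new error check. A small NumPy sketch of what the check computes, with illustrative values and variable names:

    import numpy as np

    input_scale = np.array([0.02], dtype=np.float32)
    weight_scale = np.array([0.5, 0.25], dtype=np.float32)  # per-channel
    bias = np.array([1.234, -0.05], dtype=np.float32)

    effective_scale = input_scale * weight_scale            # [0.01, 0.005]
    q_bias = np.clip(np.rint(bias / effective_scale), -2**31, 2**31 - 1)
    dequant = q_bias * effective_scale

    max_error = np.max(np.abs(dequant - bias))              # <= scale / 2 here
    tolerance = max(1e-6, float(np.max(effective_scale)))   # as in the diff
    assert max_error <= tolerance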
@@ -242,15 +437,19 @@ def tensor_zp_scale_from_min_max(
     max_value,
     num_bits: int,
     symmetric: bool,
+    granularity: qtyping.QuantGranularity,
     clipping_values: Optional[np.ndarray] = None,
 ):
   """Get zero point and scale from min and max value.

   Args:
-    min_value: The minimum value of the tensor (channel-wise supported).
-    max_value: The maximum value of the tensor (channel-wise supported).
+    min_value: The minimum value of the tensor (channelwise and blockwise
+      supported).
+    max_value: The maximum value of the tensor (channelwise and blockwise
+      supported).
     num_bits: The number of bits of the tensor.
     symmetric: Whether the tensor is symmetric.
+    granularity: The granularity of the tensor.
     clipping_values: Absolute clipping values to apply to the tensor. This will
       clip the tensors to the range [-clipping_values, clipping_values]. This
       should be the same shape as min_value and max_value. If None, no clipping
@@ -259,19 +458,45 @@ def tensor_zp_scale_from_min_max(
   Returns:
     The zero point and scale of the tensor.
   """
+
   # TODO: b/332574603 - support unsigned data type.
   qtype = IntType(
       num_bits,
       signed=True,
   )
   qmin, qmax = get_quantized_range(qtype)
-  min_bound = 1e-4  # 1e-6 precision for int8 and 1e-8 for int16.
+  min_bound = 1e-9  # Avoid zero scale.
+  pos_clipping_values = None if clipping_values is None else clipping_values
+  neg_clipping_values = None if clipping_values is None else -clipping_values
+
+  if is_blockwise(granularity):
+    # Blockwise quantization uses float16 scale,
+    # with 7 bit mantissa, so the maximum scale value is 65280 and maximum
+    # representable range is [-65280 * (2 ** num_bits),
+    # 65280 * (2 ** num_bits - 1)].
+    # Note that we have one extra value on the negative side.
+    float16_max = np.broadcast_to(
+        np.array(65280) * (2**num_bits - 1), max_value.shape
+    )
+    float16_min = np.broadcast_to(
+        np.array(-65280) * (2**num_bits), min_value.shape
+    )
+    pos_clipping_values = (
+        float16_max
+        if pos_clipping_values is None
+        else np.minimum(pos_clipping_values, float16_max)
+    )
+    neg_clipping_values = (
+        float16_min
+        if neg_clipping_values is None
+        else np.maximum(neg_clipping_values, float16_min)
+    )

   if symmetric:
     bound = np.maximum(np.abs(min_value), np.abs(max_value))
     bound = np.maximum(bound, min_bound)
     if clipping_values is not None:
-      bound = np.clip(bound, -clipping_values, clipping_values)
+      bound = np.clip(bound, neg_clipping_values, pos_clipping_values)
     if not qtype.signed:
       half_q = (qmax - 1) / 2
       scale = bound / half_q
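
The 65280 constant is the largest value representable with a 7-bit mantissa inside float16's exponent range: 1.1111111 (binary) x 2^15 = (255/128) x 32768 = 65280. Float16 proper tops out at 65504, but that mantissa does not survive the bfloat16 rounding applied further down. A quick check of the constant and of the rounding chain, assuming the `ml_dtypes` package from the imports above:

    import ml_dtypes
    import numpy as np

    # Largest 7-bit-mantissa value in float16 range: 1.1111111b * 2**15.
    assert (255 / 128) * 2**15 == 65280

    # The rounding chain from the diff: bfloat16 (7-bit mantissa) -> float16.
    scale = np.array([65281.0, 0.1234567], dtype=np.float32)
    rounded = (
        scale.astype(ml_dtypes.bfloat16).astype(np.float16).astype(np.float32)
    )
    print(rounded)  # 65281 snaps to a multiple of 256 near 2**16; 0.1234567
                    # keeps only about 2-3 significant decimal digits.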
@@ -292,6 +517,12 @@ def tensor_zp_scale_from_min_max(
     zp = qmin - bound_min / scale
     zp = np.rint(zp)

+  if is_blockwise(granularity):
+    # Round the scale values to 7 bit mantissa.
+    scale = (
+        scale.astype(ml_dtypes.bfloat16).astype(np.float16).astype(np.float32)
+    )
+
   # It's safe to cast zp to qtype without clipping because we can infer
   # qmin <= zp <= qmax from bound_min <= 0 <= bound_max.
   zp = assign_quantized_type(zp, qtype)
@@ -305,7 +536,8 @@ def _is_valid_quantization_params(
   """Checks if the quantization parameters are valid.

   A valid quantization params requires:
-  1. scale and zero point have the same shape (TFL Runtime requirement).
+  1. scale and zero point either have the same shape or the zero point is a
+     scalar.
   2. scale and zero point have the same rank as the tensor content (avoid
      ambiguous broadcasting).
@@ -316,17 +548,20 @@ def _is_valid_quantization_params(
   Returns:
     True if the quantization parameters are valid.
   """
-  if quantization_params.scale.shape != quantization_params.zero_point.shape:
+  if (
+      quantization_params.scale.shape != quantization_params.zero_point.shape
+      and quantization_params.zero_point.size != 1
+  ):
     raise ValueError(
-        "scale and zero_point must have the same shape. Got"
-        f" {quantization_params.scale.shape} and"
+        "scale and zero_point must have the same shape or zero_point must have"
+        f" only one element. Got {quantization_params.scale.shape} and"
         f" {quantization_params.zero_point.shape}"
     )

   tensor_rank = tensor_data.ndim
   scale_rank = quantization_params.scale.ndim
   zero_point_rank = quantization_params.zero_point.ndim
-  if (tensor_rank != scale_rank) or (tensor_rank != zero_point_rank):
+  if tensor_rank != scale_rank or (tensor_rank != zero_point_rank):
     raise ValueError(
         f"Ranks of scales ({scale_rank}) and zps"
         f" ({zero_point_rank}) must be the same as the tensor rank"