ai-edge-quantizer-nightly 0.4.0.dev20251008__py3-none-any.whl → 0.5.0.dev20251121__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. ai_edge_quantizer/algorithm_manager.py +5 -0
  2. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +49 -25
  3. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +1 -1
  4. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +1 -2
  5. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +5 -3
  6. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +1 -1
  7. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +6 -11
  8. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +18 -14
  9. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +9 -5
  10. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +1 -2
  11. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +40 -13
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +5 -2
  13. ai_edge_quantizer/algorithms/utils/common_utils.py +46 -33
  14. ai_edge_quantizer/calibrator.py +1 -50
  15. ai_edge_quantizer/calibrator_test.py +2 -67
  16. ai_edge_quantizer/default_policy.py +9 -18
  17. ai_edge_quantizer/qtyping.py +25 -3
  18. ai_edge_quantizer/quantizer.py +25 -2
  19. ai_edge_quantizer/quantizer_test.py +56 -6
  20. ai_edge_quantizer/recipe_manager_test.py +0 -6
  21. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +8 -0
  22. ai_edge_quantizer/utils/constrained_ops_utils_test.py +1 -1
  23. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +1 -0
  24. ai_edge_quantizer/utils/validation_utils.py +80 -5
  25. ai_edge_quantizer/utils/validation_utils_test.py +56 -0
  26. {ai_edge_quantizer_nightly-0.4.0.dev20251008.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20251121.dist-info}/METADATA +11 -2
  27. {ai_edge_quantizer_nightly-0.4.0.dev20251008.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20251121.dist-info}/RECORD +30 -30
  28. {ai_edge_quantizer_nightly-0.4.0.dev20251008.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20251121.dist-info}/WHEEL +1 -1
  29. {ai_edge_quantizer_nightly-0.4.0.dev20251008.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20251121.dist-info/licenses}/LICENSE +0 -0
  30. {ai_edge_quantizer_nightly-0.4.0.dev20251008.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20251121.dist-info}/top_level.txt +0 -0

ai_edge_quantizer/algorithm_manager.py

@@ -132,6 +132,7 @@ MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT = {
     _TFLOpName.EQUAL: common_quantize.materialize_equal,
     _TFLOpName.NOT_EQUAL: common_quantize.materialize_not_equal,
     _TFLOpName.MIRROR_PAD: common_quantize.materialize_mirror_pad,
+    _TFLOpName.SPACE_TO_DEPTH: common_quantize.materialize_space_to_depth,
 }
 for op_name, materialize_func in MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT.items():
   register_quantized_op(

@@ -286,6 +287,7 @@ _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
     _TFLOpName.EQUAL: common_quantize.materialize_equal,
     _TFLOpName.NOT_EQUAL: common_quantize.materialize_not_equal,
     _TFLOpName.MIRROR_PAD: common_quantize.materialize_mirror_pad,
+    _TFLOpName.SPACE_TO_DEPTH: common_quantize.materialize_space_to_depth,
 })
 
 for op_name, materialize_func in _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT.items():

@@ -380,6 +382,9 @@ register_config_check_policy_func(
 _MSE_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
     _TFLOpName.FULLY_CONNECTED: common_quantize.materialize_fc_conv,
     _TFLOpName.EMBEDDING_LOOKUP: common_quantize.materialize_embedding_lookup,
+    _TFLOpName.CONV_2D: common_quantize.materialize_fc_conv,
+    _TFLOpName.DEPTHWISE_CONV_2D: common_quantize.materialize_fc_conv,
+    _TFLOpName.CONV_2D_TRANSPOSE: common_quantize.materialize_conv2d_transpose,
 })
 for (
     op_name,

ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py

@@ -776,6 +776,33 @@ def materialize_mirror_pad(
   )
 
 
+def materialize_space_to_depth(
+    get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.space_to_depth.
+
+  Args:
+    get_tensor_quant_params_fn: Function to get quantization parameters for
+      the tensor.
+    op_info: Aggregated information about the op (e.g., quantization config).
+    graph_info: Graph information needed to perform quantization for the op.
+    tensor_name_to_qsv: A map of tensor name to quantization parameters.
+
+  Returns:
+    A list of `qtyping.TensorTransformationParams` for the tensors in the op.
+  """
+  return common_utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      get_tensor_quant_params_fn,
+      constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
+  )
+
+
 def materialize_squared_difference(
     get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
     op_info: qtyping.OpInfo,
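
The new handler routes through the standard materialization path with the SAME_AS_INPUT_SCALE constraint, which fits because space_to_depth only rearranges elements and cannot change the value range. A minimal numpy sketch of why the input's quantization parameters stay exact for the output (illustration only, not library code; block size 2 and the int8 range are assumptions):

    import numpy as np

    scale, zero_point = 0.05, 0
    x = np.random.uniform(-1.0, 1.0, size=(1, 4, 4, 1)).astype(np.float32)
    q_x = np.clip(np.rint(x / scale) + zero_point, -128, 127).astype(np.int8)
    # space_to_depth(block=2) is a pure permutation of elements, so the
    # rearranged tensor dequantizes exactly with the same (scale, zero_point).
    q_y = (
        q_x.reshape(1, 2, 2, 2, 2, 1)
        .transpose(0, 1, 3, 2, 4, 5)
        .reshape(1, 2, 2, 4)
    )
    y = (q_y.astype(np.float32) - zero_point) * scale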

@@ -1138,39 +1165,36 @@ def init_tensor_min_max(
     A dictionary containing the min/max values for the tensor, or an empty
     dictionary if the tensor data is None.
   """
-  if tensor_data is None:
+  weight_tensor_config = op_info.op_quant_config.weight_tensor_config
+  if tensor_data is None or weight_tensor_config is None:
     return {}
   else:
-    weight_tensor_config = op_info.op_quant_config.weight_tensor_config
-    quantized_dim = None
-    if weight_tensor_config is not None and (
-        weight_tensor_config.granularity == qtyping.QuantGranularity.CHANNELWISE
-    ):
+    # Get reduce dimension for min/max calculation based on quantization
+    # granularity.
+    granularity = weight_tensor_config.granularity
+    if granularity == qtyping.QuantGranularity.TENSORWISE:
+      reduce_dims = None
+      keep_dims = True
+    elif granularity == qtyping.QuantGranularity.CHANNELWISE:
       quantized_dim = common_utils.get_weight_quantized_dim(
           op_info, tensor_data, weight_tensor_config.granularity
       )
-    if (
-        weight_tensor_config is not None
-        and weight_tensor_config.granularity
-        == qtyping.QuantGranularity.BLOCKWISE
-    ):
-      reshaped_data, reduce_dims = (
+      reduce_dims = common_utils.get_reduce_dims(
+          quantized_dim, tensor_data.shape
+      )
+      keep_dims = True
+    elif uniform_quantize_tensor.is_blockwise(granularity):
+      tensor_data, reduce_dims = (
           uniform_quantize_tensor.reshape_data_for_blockwise(
               tensor_data,
               op_info.op_name,
-              weight_tensor_config.block_size,
+              granularity,
          )
      )
-      return {
-          "min": np.min(reshaped_data, axis=reduce_dims, keepdims=False),
-          "max": np.max(reshaped_data, axis=reduce_dims, keepdims=False),
-      }
-
+      keep_dims = False
     else:
-      reduce_dims = common_utils.get_reduce_dims(
-          quantized_dim, tensor_data.shape
-      )
-      return {
-          "min": np.min(tensor_data, axis=reduce_dims, keepdims=True),
-          "max": np.max(tensor_data, axis=reduce_dims, keepdims=True),
-      }
+      raise ValueError(f"Unsupported granularity: {granularity}")
+    return {
+        "min": np.min(tensor_data, axis=reduce_dims, keepdims=keep_dims),
+        "max": np.max(tensor_data, axis=reduce_dims, keepdims=keep_dims),
+    }
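
A small numpy sketch of the unified reduction this rewrite converges on (plain numpy, not library code; it assumes common_utils.get_reduce_dims returns every axis except the quantized one, which follows from how it is used here):

    import numpy as np

    w = np.arange(12, dtype=np.float32).reshape(4, 3)
    # TENSORWISE: reduce_dims=None, keep_dims=True -> one min/max, shape (1, 1).
    assert np.min(w, axis=None, keepdims=True).shape == (1, 1)
    # CHANNELWISE with quantized dim 0: reduce the remaining axis -> (4, 1).
    assert np.min(w, axis=(1,), keepdims=True).shape == (4, 1)
    # BLOCKWISE_*: the tensor is first reshaped so each block gets its own
    # axis, then reduced with keep_dims=False -> one min/max per block.
    blocks = w.reshape(4, 3, 1)  # hypothetical block size 1 along dim 1
    assert np.min(blocks, axis=2, keepdims=False).shape == (4, 3)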

ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py

@@ -158,7 +158,7 @@ def get_tensor_quant_params(
       op_info, tensor_quant_config, tensor_content, tensor_qsv
   )
 
-  if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
+  if uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity):
     raise ValueError(
         "Blockwise quantization is not supported for dequantized weight"
         " recovery."

ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py

@@ -147,8 +147,7 @@ class HadamardRotationFullyConnectedTest(parameterized.TestCase):
         weight_tensor_config=_TensorQuantConfig(
             num_bits=8,
             symmetric=True,
-            granularity=qtyping.QuantGranularity.BLOCKWISE,
-            block_size=32,
+            granularity=qtyping.QuantGranularity.BLOCKWISE_32,
         ),
     ),
 )

ai_edge_quantizer/algorithms/uniform_quantize/mse.py

@@ -55,7 +55,7 @@ def get_tensor_quant_params(
     ValueError: `tensor_qsv` must contain min/max values, or `tensor_content`
       must be provided so that they can be inferred.
   """
-  if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
+  if uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity):
     raise ValueError(
         "Blockwise quantization is not supported for MSE quantization."
     )

@@ -113,13 +113,15 @@ def get_tensor_quant_params(
       num_bits=tensor_quant_config.num_bits,
       symmetric=tensor_quant_config.symmetric,
       quantized_dimension=quantized_dim,
-      block_size=tensor_quant_config.block_size,
+      block_size=uniform_quantize_tensor.extract_block_size_from_granularity(
+          tensor_quant_config.granularity
+      ),
   )
 
   quantized_vars = uniform_quantize_tensor.uniform_quantize(
       tensor_content,
       quant_params,
-      tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE,
+      uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity),
   )
 
   return dataclasses.replace(quant_params, quantized_data=quantized_vars)

ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py

@@ -84,7 +84,7 @@ class MseQuantizeTest(parameterized.TestCase):
         tensor_quant_config=qtyping.TensorQuantizationConfig(
             num_bits=4,
             symmetric=True,
-            granularity=qtyping.QuantGranularity.BLOCKWISE,
+            granularity=qtyping.QuantGranularity.BLOCKWISE_32,
         ),
         tensor_content=test_data,
     )

ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py

@@ -15,6 +15,7 @@
 
 """Performs naive min/max uniform quantization."""
 
+import dataclasses
 from typing import Any, Optional
 import numpy as np
 from ai_edge_quantizer import qtyping

@@ -91,7 +92,9 @@ def get_tensor_quant_params(
       num_bits=tensor_quant_config.num_bits,
       symmetric=tensor_quant_config.symmetric,
       quantized_dimension=quantized_dim,
-      block_size=tensor_quant_config.block_size,
+      block_size=uniform_quantize_tensor.extract_block_size_from_granularity(
+          tensor_quant_config.granularity
+      ),
   )
   if tensor_content is None:
     return quant_params

@@ -99,18 +102,10 @@ def get_tensor_quant_params(
   quantized_vars = uniform_quantize_tensor.uniform_quantize(
       tensor_content,
       quant_params,
-      tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE,
+      uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity),
   )
   # Update with quantized values.
-  return qtyping.UniformQuantParams(
-      scale=scale,
-      zero_point=zp,
-      num_bits=tensor_quant_config.num_bits,
-      symmetric=tensor_quant_config.symmetric,
-      quantized_dimension=quantized_dim,
-      quantized_data=quantized_vars,
-      block_size=tensor_quant_config.block_size,
-  )
+  return dataclasses.replace(quant_params, quantized_data=quantized_vars)
 
 
 # TODO: b/333731147 - Use named tuple to store min/max.
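
Replacing the hand-built UniformQuantParams with dataclasses.replace keeps every field of the already-constructed quant_params (including the new granularity-derived block_size) and swaps in only the quantized data. A generic sketch of the idiom, using a stand-in dataclass rather than the library's:

    import dataclasses

    @dataclasses.dataclass(frozen=True)
    class Params:
      scale: float
      quantized_data: object = None

    p = Params(scale=0.5)
    q = dataclasses.replace(p, quantized_data=[1, 2, 3])
    assert q.scale == 0.5 and q.quantized_data == [1, 2, 3]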

ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py

@@ -17,6 +17,7 @@ import os
 from typing import cast
 
 from absl.testing import parameterized
+import ml_dtypes
 import numpy as np
 
 from tensorflow.python.platform import googletest

@@ -165,8 +166,7 @@ class NaiveMinMaxQuantizeTest(parameterized.TestCase):
     weight_tensor_config = _TensorQuantConfig(
         num_bits=4,
         symmetric=True,
-        granularity=qtyping.QuantGranularity.BLOCKWISE,
-        block_size=2,
+        granularity=qtyping.QuantGranularity.BLOCKWISE_32,
     )
     op_info = qtyping.OpInfo(
         op=fc_op,

@@ -176,28 +176,32 @@ class NaiveMinMaxQuantizeTest(parameterized.TestCase):
             weight_tensor_config=weight_tensor_config,
         ),
     )
-    test_data = np.array([[-7, 7], [4, -4], [4, -4], [7, 7]])
+    test_data = np.random.uniform(low=-10, high=10, size=(4, 32)).astype(
+        np.float32
+    )
     quant_params = naive_min_max_quantize.get_tensor_quant_params(
         op_info=op_info,
         tensor_quant_config=weight_tensor_config,
         tensor_content=test_data,
     )
-    scale = quant_params.scale
     zp = quant_params.zero_point
-    expected_scale = np.array([
-        [1],
-        [0.5703125],
-        [0.5703125],
-        [1],
-    ])
-    expected_zp = np.zeros([4, 1])
-    self.assertTrue(np.array_equal(zp, expected_zp))
-    self.assertTrue(np.array_equal(scale, expected_scale))
+    self.assertEqual(zp.shape, (4, 1))
+    self.assertTrue(np.array_equal(zp, np.zeros([4, 1])))
+
+    self.assertEqual(quant_params.scale.shape, (4, 1))
+    expected_scales = np.max(np.abs(test_data), axis=1, keepdims=True) / 7.0
+    expected_scales = (
+        expected_scales.astype(ml_dtypes.bfloat16)
+        .astype(np.float16)
+        .astype(np.float32)
+    )
+    self.assertTrue(np.allclose(quant_params.scale, expected_scales, atol=1e-5))
+
     self.assertIsNotNone(quant_params.quantized_data)
     self.assertTupleEqual(
         cast(np.ndarray, quant_params.quantized_data).shape, test_data.shape
     )
-    self.assertEqual(quant_params.block_size, 2)
+    self.assertEqual(quant_params.block_size, 32)
     self.assertEqual(quant_params.quantized_dimension, 1)
 
   def test_calibrate_ignores_inf_min_max(self):

ai_edge_quantizer/algorithms/uniform_quantize/octav.py

@@ -131,12 +131,12 @@ def get_tensor_quant_params(
   quantized_dim = common_utils.get_weight_quantized_dim(
       op_info, tensor_content, tensor_quant_config.granularity
   )
-  if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
+  if uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity):
     reshaped_data, reduce_dims = (
         uniform_quantize_tensor.reshape_data_for_blockwise(
             tensor_content,
             op_info.op_name,
-            tensor_quant_config.block_size,
+            tensor_quant_config.granularity,
         )
     )
   else:

@@ -154,7 +154,7 @@ def get_tensor_quant_params(
   # We created a new dimension in order to reduce properly for blockwise
   # quantization, so we need to reshape the clipping constants back to the
   # min/max shape for the next step.
-  if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
+  if uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity):
     clipping_constants = clipping_constants.reshape(tensor_min_max["min"].shape)
 
   zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(

@@ -172,13 +172,17 @@ def get_tensor_quant_params(
       num_bits=tensor_quant_config.num_bits,
       symmetric=tensor_quant_config.symmetric,
       quantized_dimension=quantized_dim,
-      block_size=tensor_quant_config.block_size,
+      block_size=uniform_quantize_tensor.extract_block_size_from_granularity(
+          tensor_quant_config.granularity
+      ),
   )
 
   quantized_vars = uniform_quantize_tensor.uniform_quantize(
       tensor_content,
       quant_params,
-      tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE,
+      is_blockwise_quant=uniform_quantize_tensor.is_blockwise(
+          tensor_quant_config.granularity
+      ),
   )
 
   return dataclasses.replace(quant_params, quantized_data=quantized_vars)

ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py

@@ -196,8 +196,7 @@ class OctavQuantizeTest(parameterized.TestCase):
     tensor_config = qtyping.TensorQuantizationConfig(
         num_bits=4,
         symmetric=True,
-        granularity=qtyping.QuantGranularity.BLOCKWISE,
-        block_size=32,
+        granularity=qtyping.QuantGranularity.BLOCKWISE_32,
     )
     fc_op_info = qtyping.OpInfo(
         op=self._fc_op,

ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py

@@ -29,6 +29,11 @@ class IntType:
   signed: bool
 
 
+def is_blockwise(granularity: qtyping.QuantGranularity) -> bool:
+  """Checks if the quantization granularity is blockwise."""
+  return "BLOCKWISE" in str(granularity)
+
+
 def get_quantized_range(qtype: IntType) -> tuple[float, float]:
   """Calculates range of the quantized type."""
   if qtype.signed:

@@ -40,6 +45,22 @@ def get_quantized_range(qtype: IntType) -> tuple[float, float]:
   return float(qmin), float(qmax)
 
 
+def extract_block_size_from_granularity(
+    granularity: qtyping.QuantGranularity,
+) -> int:
+  """Get the block size for blockwise quantization."""
+  if granularity == qtyping.QuantGranularity.BLOCKWISE_32:
+    return 32
+  elif granularity == qtyping.QuantGranularity.BLOCKWISE_64:
+    return 64
+  elif granularity == qtyping.QuantGranularity.BLOCKWISE_128:
+    return 128
+  elif granularity == qtyping.QuantGranularity.BLOCKWISE_256:
+    return 256
+  else:
+    return 0
+
+
 def _round_and_clip(
     tensor: np.ndarray, qtype: IntType, narrow: bool
 ) -> np.ndarray:
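
Together with is_blockwise above, this helper folds the old (granularity == BLOCKWISE, block_size=N) pair into a single enum value, so per the mapping above:

    extract_block_size_from_granularity(qtyping.QuantGranularity.BLOCKWISE_64)  # 64
    extract_block_size_from_granularity(qtyping.QuantGranularity.CHANNELWISE)   # 0 (non-blockwise)

The 0 fallback presumably plays the role of the default block_size that TensorQuantizationConfig previously carried for non-blockwise granularities, which is also why the removed block_size <= 0 validation (see common_utils.py below) is no longer needed.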

@@ -157,14 +178,16 @@ def _get_tensor_shape_for_blockwise(
 
 
 def reshape_data_for_blockwise(
-    tensor_data: np.ndarray, op_name: qtyping.TFLOperationName, block_size: int
+    tensor_data: np.ndarray,
+    op_name: qtyping.TFLOperationName,
+    granularity: qtyping.QuantGranularity,
 ) -> tuple[np.ndarray, int]:
   """Reshapes data for blockwise quantization.
 
   Args:
     tensor_data: The original tensor data.
     op_name: The name of the TFL op.
-    block_size: The size of the block.
+    granularity: The quantization granularity for the tensor.
 
   Returns:
     A tuple containing the reshaped tensor data and the new reduce dimension.

@@ -172,11 +195,11 @@ def reshape_data_for_blockwise(
   quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM[
       op_name
   ]
+  block_size = extract_block_size_from_granularity(granularity)
   new_shape = _get_tensor_shape_for_blockwise(
       tensor_data.shape, quantized_dim, block_size
   )
-  reshaped_data = tensor_data.reshape(new_shape)
-  return reshaped_data, quantized_dim + 1
+  return tensor_data.reshape(new_shape), quantized_dim + 1
 
 
 def _broadcast_scale_zp_for_blockwise(

@@ -233,21 +256,21 @@ def _broadcast_scale_zp_for_blockwise(
 def uniform_quantize(
     tensor_data: np.ndarray,
     quantization_params: qtyping.UniformQuantParams,
-    is_blockwise: bool = False,
+    is_blockwise_quant: bool = False,
 ):
   """Uniform quantize a tensor.
 
   Args:
     tensor_data: The tensor to be quantized.
     quantization_params: The quantization parameters.
-    is_blockwise: Whether the tensor is blockwise quantized.
+    is_blockwise_quant: Whether the tensor is blockwise quantized.
 
   Returns:
     The quantized tensor.
   """
   # The reshaping for blockwise quantization is unique hence we do this here
   # to avoid unexpected broadcast behavior downstream.
-  if is_blockwise:
+  if is_blockwise_quant:
     quantization_params = _broadcast_scale_zp_for_blockwise(
         tensor_data, quantization_params
     )

@@ -381,10 +404,13 @@ def symmetric_quantize_bias_tensor(
   quantized_vars = uniform_quantize(bias_content, bias_quant_params)
   if check_error:
     dequantized_bias = uniform_dequantize(quantized_vars, bias_quant_params)
-    quantization_error = np.abs(dequantized_bias - bias_content)
-    if np.any(quantization_error > effective_output_scale):
+    max_quant_error = np.max(np.abs(dequantized_bias - bias_content))
+    error_tolerance = np.maximum(1e-6, np.max(effective_output_scale))
+    if max_quant_error > error_tolerance:
       raise ValueError(
-          "Quantization error is too large for bias tensor quantization."
+          "Quantization error is too large for bias tensor quantization. Max"
+          f" quantization error is {max_quant_error}, which exceed"
+          f" the threshold {error_tolerance}"
       )
 
   # Save the int32 quantized bias as int64 if the input tensor is quantized to
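
A numeric sketch of the new tolerance (not library code): the old check compared the per-element error directly against the effective output scale, so a very small scale could fail on rounding noise alone; the new check floors the tolerance at 1e-6:

    import numpy as np

    effective_output_scale = np.array([2e-7])
    error_tolerance = np.maximum(1e-6, np.max(effective_output_scale))
    assert error_tolerance == 1e-6  # the old code would have used 2e-7 here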

@@ -432,17 +458,18 @@ def tensor_zp_scale_from_min_max(
   Returns:
     The zero point and scale of the tensor.
   """
+
   # TODO: b/332574603 - support unsigned data type.
   qtype = IntType(
       num_bits,
       signed=True,
   )
   qmin, qmax = get_quantized_range(qtype)
-  min_bound = 1e-4  # 1e-6 precision for int8 and 1e-8 for int16.
+  min_bound = 1e-9  # Avoid zero scale.
   pos_clipping_values = None if clipping_values is None else clipping_values
   neg_clipping_values = None if clipping_values is None else -clipping_values
 
-  if granularity == qtyping.QuantGranularity.BLOCKWISE:
+  if is_blockwise(granularity):
     # Blockwise quantization uses float16 scale,
     # with 7 bit mantissa, so the maximum scale value is 65280 and maximum
     # representable range is [-65280 * (2 ** num_bits),

@@ -490,7 +517,7 @@ def tensor_zp_scale_from_min_max(
     zp = qmin - bound_min / scale
     zp = np.rint(zp)
 
-  if granularity == qtyping.QuantGranularity.BLOCKWISE:
+  if is_blockwise(granularity):
     # Round the scale values to 7 bit mantissa.
     scale = (
         scale.astype(ml_dtypes.bfloat16).astype(np.float16).astype(np.float32)
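
A short sketch of the 7-bit-mantissa rounding kept by this hunk (the same bfloat16 round-trip appears in the new test expectations above):

    import ml_dtypes
    import numpy as np

    scale = np.array([0.123456], dtype=np.float32)
    rounded = scale.astype(ml_dtypes.bfloat16).astype(np.float16).astype(np.float32)
    # bfloat16 keeps 7 explicit mantissa bits, so `rounded` is the nearest
    # representable value (~0.12354 here); the new 1e-9 min_bound only guards
    # against an exactly-zero scale rather than clamping small real scales.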

ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py

@@ -222,7 +222,7 @@ class TensorUtilsTest(parameterized.TestCase):
             zero_point=np.array([-6]),
             symmetric=True,
         ),
-        is_blockwise=True,
+        is_blockwise_quant=True,
     )
 
   @parameterized.parameters(

@@ -431,7 +431,10 @@ class TensorUtilsTest(parameterized.TestCase):
     )
     # This will result in quantized bias of 3e9, which is larger than int32 max.
     bias_tensor_data = np.array([3e7])
-    with self.assertRaises(ValueError):
+    with self.assertRaisesRegex(
+        ValueError,
+        "Quantization error is too large for bias tensor quantization.",
+    ):
       uniform_quantize_tensor.symmetric_quantize_bias_tensor(
           bias_tensor_data,
           input_quant_config,

ai_edge_quantizer/algorithms/utils/common_utils.py

@@ -51,8 +51,9 @@ def check_subchannel_config(
   """Checks the op quantization config for subchannel quantization."""
   if (
       op_quant_config.weight_tensor_config is not None
-      and op_quant_config.weight_tensor_config.granularity
-      == qtyping.QuantGranularity.BLOCKWISE
+      and uniform_quantize_tensor.is_blockwise(
+          op_quant_config.weight_tensor_config.granularity
+      )
   ):
     if op_name not in _SUPPORTED_SUBCHANNEL_OPS:
       raise ValueError(f"Unsupported op for blockwise quantization: {op_name}.")

@@ -66,10 +67,6 @@ def check_subchannel_config(
           "Blockwise quantization does not support for asymmetric weight"
           " quantization."
       )
-    if op_quant_config.weight_tensor_config.block_size <= 0:
-      raise ValueError(
-          "Blockwise quantization must have a non-zero block size."
-      )
 
 
 def check_if_valid_op_config(

@@ -369,11 +366,28 @@ def _materialize_standard_op_with_same_as_input_scale(
 
   # Change output qsv to be the same as input qsv. This is safe since TFL
   # subgraph is acyclic.
-  input_tensor_qsv = tensor_name_to_qsv[input_tensor_params.tensor_name]
-  for output_tensor in output_tensors:
-    tensor_name_to_qsv[tfl_flatbuffer_utils.get_tensor_name(output_tensor)] = (
-        input_tensor_qsv
+  input_tensor_qsv = tensor_name_to_qsv.get(
+      input_tensor_params.tensor_name, None
+  )
+  if input_tensor_qsv is None:
+    input_tensor_data = tfl_flatbuffer_utils.get_tensor_data(
+        input_tensors[0], graph_info.buffers
     )
+    # If the input tensor is a constant tensor without qsv, compute qsv from
+    # its quant params.
+    if input_tensor_data is None:
+      # If the only input to an op that needs to match input to
+      # output has no qsv and is not a constant tensor, then this is an error.
+      raise ValueError(
+          "Input tensor qsv is None for tensor"
+          f" {input_tensor_params.tensor_name}."
+      )
+    min_val, max_val = _get_min_max_from_quant_params(input_quant_params)
+    input_tensor_qsv = {"min": min_val, "max": max_val}
+  for output_tensor in output_tensors:
+    tensor_name_to_qsv[
+        tfl_flatbuffer_utils.get_tensor_name(output_tensor)
+    ] = input_tensor_qsv
 
   return op_tensor_params
 
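In effect, ops under the SAME_AS_INPUT_SCALE constraint no longer require a calibrated qsv for their input: a constant input's min/max are synthesized from its quantization parameters, and only a non-constant input without qsv raises. With hypothetical int8 symmetric parameters (scale 0.5, zero point 0), the synthesized entry would look like:

    input_tensor_qsv = {"min": -63.5, "max": 63.5}  # mirrored from (127 - 0) * 0.5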

@@ -697,6 +711,26 @@ def _add_non_match_tensors_to_ignored_lists(
   return inputs_to_ignore, outputs_to_ignore
 
 
+def _get_min_max_from_quant_params(
+    quant_params: qtyping.UniformQuantParams,
+) -> tuple[np.ndarray, np.ndarray]:
+  """Recalculate min/max from tensor quantization params."""
+  q_min, q_max = uniform_quantize_tensor.get_quantized_range(
+      _IntType(quant_params.num_bits, True)
+  )
+  float_min = uniform_quantize_tensor.uniform_dequantize(
+      np.array(q_min), quant_params
+  )
+  float_max = uniform_quantize_tensor.uniform_dequantize(
+      np.array(q_max), quant_params
+  )
+  # We use qmax values to compute scale for symmetric quantization (see
+  # uniform_quantize_tensor.tensor_zp_scale_from_min_max).
+  if quant_params.symmetric:
+    float_min = -float_max
+  return float_min, float_max
+
+
 def materialize_standard_op(
     op_info: qtyping.OpInfo,
     graph_info: qtyping.GraphInfo,
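
A standalone re-derivation of the helper's arithmetic for int8, symmetric, scale=0.5, zero_point=0 (hedged sketch, not calling the library):

    num_bits, scale, zero_point = 8, 0.5, 0
    q_min, q_max = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1  # -128, 127
    float_max = (q_max - zero_point) * scale  # 63.5
    float_min = -float_max  # symmetric: mirror qmax instead of using q_min
    assert (float_min, float_max) == (-63.5, 63.5)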

@@ -863,8 +897,6 @@ def materialize_op_with_output_activation_constraint(
   output_tensor_params.producer = op_tensor_params
   # Update the tensor_name_to_qsv map using the output activation constraints.
   min_val, max_val = _get_min_max_from_quant_params(
-      activation_num_bits,
-      activation_tensor_config.symmetric,
       fixed_quant_params,
   )
   tensor_name_to_qsv[output_tensor_params.tensor_name]["min"] = min_val

@@ -993,7 +1025,7 @@ def get_weight_quantized_dim(
     quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM.get(
         op_info.op_name, None
     )
-  elif granularity == qtyping.QuantGranularity.BLOCKWISE:
+  elif uniform_quantize_tensor.is_blockwise(granularity):
     quantized_dim = (
         tfl_flatbuffer_utils.TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM[
             op_info.op_name

@@ -1027,23 +1059,4 @@ def get_bmm_weight_quantized_dim(
   return rank - 1
 
 
-def _get_min_max_from_quant_params(
-    num_bits: int,
-    symmetric: bool,
-    tensor_params: qtyping.UniformQuantParams,
-) -> tuple[float, float]:
-  """Recalculate min/max from tensor quantization params."""
-  q_min, q_max = uniform_quantize_tensor.get_quantized_range(
-      _IntType(num_bits, True)
-  )
-  float_min = uniform_quantize_tensor.uniform_dequantize(
-      np.array(q_min), tensor_params
-  )
-  float_max = uniform_quantize_tensor.uniform_dequantize(
-      np.array(q_max), tensor_params
-  )
-  # We use qmax values to compute scale for symmetric quantization (see
-  # uniform_quantize_tensor.tensor_zp_scale_from_min_max).
-  if symmetric:
-    float_min = -float_max
-  return (float_min, float_max)
+