ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/algorithm_manager.py +158 -0
  2. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  3. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  5. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
  6. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  7. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  8. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
  12. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
  13. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
  14. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
  15. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
  16. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
  17. ai_edge_quantizer/calibrator.py +11 -60
  18. ai_edge_quantizer/calibrator_test.py +4 -73
  19. ai_edge_quantizer/default_policy.py +61 -26
  20. ai_edge_quantizer/model_modifier.py +97 -7
  21. ai_edge_quantizer/model_modifier_test.py +81 -1
  22. ai_edge_quantizer/model_validator.py +31 -8
  23. ai_edge_quantizer/params_generator.py +17 -10
  24. ai_edge_quantizer/params_generator_test.py +2 -7
  25. ai_edge_quantizer/qtyping.py +86 -6
  26. ai_edge_quantizer/quantizer.py +166 -21
  27. ai_edge_quantizer/quantizer_test.py +284 -16
  28. ai_edge_quantizer/recipe.py +154 -42
  29. ai_edge_quantizer/recipe_manager.py +158 -1
  30. ai_edge_quantizer/recipe_manager_test.py +146 -32
  31. ai_edge_quantizer/recipe_test.py +93 -17
  32. ai_edge_quantizer/transformation_instruction_generator.py +118 -13
  33. ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
  34. ai_edge_quantizer/transformation_performer.py +55 -25
  35. ai_edge_quantizer/transformation_performer_test.py +127 -5
  36. ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
  37. ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
  38. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  39. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  40. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  41. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  42. ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
  43. ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
  44. ai_edge_quantizer/transformations/transformation_utils.py +129 -6
  45. ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
  46. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  47. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  48. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  49. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  50. ai_edge_quantizer/utils/test_utils.py +75 -2
  51. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
  52. ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
  53. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  54. ai_edge_quantizer/utils/validation_utils.py +114 -4
  55. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  56. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
  57. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  58. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  59. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  60. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  61. ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
  62. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  63. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,195 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import os
17
+ from typing import cast
18
+
19
+ from absl.testing import parameterized
20
+ import numpy as np
21
+
22
+ from tensorflow.python.platform import googletest
23
+ from ai_edge_quantizer import qtyping
24
+ from ai_edge_quantizer.algorithms.uniform_quantize import mse
25
+ from ai_edge_quantizer.utils import test_utils
26
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
27
+
28
+
29
class MseQuantizeTest(parameterized.TestCase):
  """Tests for `mse.get_tensor_quant_params`."""

  def setUp(self):
    super().setUp()
    np.random.seed(666)
    models_dir = test_utils.get_path_to_datafile("../../tests/models")
    self._test_model_path = os.path.join(models_dir, "conv_fc_mnist.tflite")
    self._test_model = tfl_flatbuffer_utils.read_model(self._test_model_path)
    # Only subgraph 0 is used; the test model has a single subgraph for now.
    subgraph = self._test_model.subgraphs[0]
    self._graph_info = qtyping.GraphInfo(
        subgraph_tensors=subgraph.tensors,
        buffers=self._test_model.buffers,
    )
    self._tensor_name_to_qsv = {}
    # Operator 3 of subgraph 0 is used as the FULLY_CONNECTED op under test.
    self._subgraph_op_index = 3
    self._fc_op = subgraph.operators[self._subgraph_op_index]
    # Op info without a weight config, shared by the error and QSV tests.
    self._fc_op_info = self._build_fc_op_info(weight_tensor_config=None)

  def _build_fc_op_info(self, weight_tensor_config):
    """Returns an OpInfo for the FC op carrying `weight_tensor_config`."""
    return qtyping.OpInfo(
        op=self._fc_op,
        op_name=qtyping.TFLOperationName.FULLY_CONNECTED,
        subgraph_op_index=self._subgraph_op_index,
        op_quant_config=qtyping.OpQuantizationConfig(
            weight_tensor_config=weight_tensor_config,
        ),
    )

  @staticmethod
  def _small_weights():
    """Returns the tiny 4x2 weight matrix used by the error-path tests."""
    return np.array([[-7, 7], [4, -4], [4, -4], [7, 7]])

  @staticmethod
  def _weights_with_outliers():
    """Returns a 3x6 weight matrix whose rows 0 and 1 hold +/-1e5 outliers."""
    return np.array([
        [-1e5, 25, -50, 75, -100, 125],
        [25, -30, 50, -75, 1e5, -125],
        [50, -60, 70, -80, 90, -100],
    ])

  def _quantize_weights_with_outliers(self, granularity):
    """Runs MSE quantization (4-bit, symmetric) on the outlier weights.

    Args:
      granularity: The `qtyping.QuantGranularity` to quantize with.

    Returns:
      A `(weights, quant_params)` tuple.
    """
    weights = self._weights_with_outliers()
    weight_config = qtyping.TensorQuantizationConfig(
        num_bits=4,
        symmetric=True,
        granularity=granularity,
    )
    quant_params = mse.get_tensor_quant_params(
        op_info=self._build_fc_op_info(weight_tensor_config=weight_config),
        tensor_quant_config=weight_config,
        tensor_content=weights,
    )
    return weights, quant_params

  def test_get_tensor_quant_params_raises_error_with_unsupported_symmetry(self):
    with self.assertRaisesRegex(ValueError, "Unsupported symmetry"):
      mse.get_tensor_quant_params(
          op_info=self._fc_op_info,
          tensor_quant_config=qtyping.TensorQuantizationConfig(
              num_bits=4,
              symmetric=False,
              granularity=qtyping.QuantGranularity.CHANNELWISE,
          ),
          tensor_content=self._small_weights(),
      )

  def test_get_tensor_quant_params_raises_error_with_unsupported_granularity(
      self,
  ):
    with self.assertRaisesRegex(
        ValueError, "Blockwise quantization is not supported"
    ):
      mse.get_tensor_quant_params(
          op_info=self._fc_op_info,
          tensor_quant_config=qtyping.TensorQuantizationConfig(
              num_bits=4,
              symmetric=True,
              granularity=qtyping.QuantGranularity.BLOCKWISE_32,
          ),
          tensor_content=self._small_weights(),
      )

  def test_get_tensor_quant_params_succeeds_with_qsv(self):
    # Non-weight tensors fall back to naive_min_max_quantize.py, so the
    # params are derived from the QSV alone (no tensor content is passed).
    params = mse.get_tensor_quant_params(
        op_info=self._fc_op_info,
        tensor_quant_config=qtyping.TensorQuantizationConfig(
            num_bits=8,
            granularity=qtyping.QuantGranularity.TENSORWISE,
        ),
        tensor_qsv={"min": np.array([-1]), "max": np.array([1])},
    )

    self.assertIsNone(params.quantized_dimension)
    self.assertEqual(params.scale.shape, (1,))
    # Expected scale for an 8-bit range of [-1, 1] is 1 / 127.
    self.assertSequenceAlmostEqual(params.scale.flatten(), [1 / 127])

    # Zero point should be zero for symmetric quantization.
    zero_point = params.zero_point
    self.assertEqual(np.sum(zero_point), 0)
    self.assertEqual(zero_point.shape, (1,))

  def test_get_tensor_quant_params_succeeds_with_tensorwise_granularity(self):
    weights, quant_params = self._quantize_weights_with_outliers(
        qtyping.QuantGranularity.TENSORWISE
    )

    with self.subTest(name="CheckQuantParamsShapes"):
      self.assertEqual(quant_params.zero_point.shape, (1, 1))
      self.assertEqual(quant_params.scale.shape, (1, 1))
      self.assertIsNone(quant_params.quantized_dimension)
      self.assertIsNotNone(quant_params.quantized_data)
      self.assertTupleEqual(
          cast(np.ndarray, quant_params.quantized_data).shape, weights.shape
      )

    with self.subTest(name="CheckQuantParamsValues"):
      self.assertTrue(np.all(quant_params.zero_point == 0))

  def test_get_tensor_quant_params_succeeds_with_channelwise_granularity(self):
    # Checks per-channel shapes and the config-derived values (zero point,
    # quantized dimension); the numeric scales are not verified here.
    weights, quant_params = self._quantize_weights_with_outliers(
        qtyping.QuantGranularity.CHANNELWISE
    )

    with self.subTest(name="CheckQuantParamsShapes"):
      per_channel_shape = (weights.shape[0], 1)
      self.assertEqual(quant_params.zero_point.shape, per_channel_shape)
      self.assertEqual(quant_params.scale.shape, per_channel_shape)
      self.assertIsNotNone(quant_params.quantized_data)
      self.assertTupleEqual(
          cast(np.ndarray, quant_params.quantized_data).shape, weights.shape
      )

    with self.subTest(name="CheckQuantParamsValues"):
      self.assertTrue(np.all(quant_params.zero_point == 0))
      self.assertEqual(quant_params.quantized_dimension, 0)


if __name__ == "__main__":
  googletest.main()
@@ -15,8 +15,8 @@
15
15
 
16
16
  """Performs naive min/max uniform quantization."""
17
17
 
18
+ import dataclasses
18
19
  from typing import Any, Optional
19
- import ml_dtypes
20
20
  import numpy as np
21
21
  from ai_edge_quantizer import qtyping
22
22
  from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
@@ -75,65 +75,43 @@ def get_tensor_quant_params(
75
75
  " the ParamsGenerator."
76
76
  )
77
77
  clipping_values = None
78
- # Blockwise quantization uses float16 scale, with 7 bit mantissa,
79
- # so the maximum representable value is 65280.
80
- if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
81
- clipping_values = np.broadcast_to(
82
- np.array(65280), tensor_min_max["min"].shape
83
- )
84
78
  zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
85
79
  tensor_min_max["min"],
86
80
  tensor_min_max["max"],
87
81
  tensor_quant_config.num_bits,
88
82
  tensor_quant_config.symmetric,
83
+ tensor_quant_config.granularity,
89
84
  clipping_values,
90
85
  )
91
- # Round the scale values to 7 bit mantissa.
92
- if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
93
- scale = (
94
- scale.astype(ml_dtypes.bfloat16).astype(np.float16).astype(np.float32)
95
- )
96
- quantized_dim = None
97
- if tensor_quant_config.granularity == qtyping.QuantGranularity.CHANNELWISE:
98
- quantized_dim = common_utils.get_weight_quantized_dim(
99
- op_info, tensor_content
100
- )
101
- elif tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
102
- quantized_dim = (
103
- tfl_flatbuffer_utils.TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM[
104
- op_info.op_name
105
- ]
106
- )
86
+ quantized_dim = common_utils.get_weight_quantized_dim(
87
+ op_info, tensor_content, tensor_quant_config.granularity
88
+ )
107
89
  quant_params = qtyping.UniformQuantParams(
108
90
  scale=scale,
109
91
  zero_point=zp,
110
92
  num_bits=tensor_quant_config.num_bits,
111
93
  symmetric=tensor_quant_config.symmetric,
112
94
  quantized_dimension=quantized_dim,
113
- block_size=tensor_quant_config.block_size,
95
+ block_size=uniform_quantize_tensor.extract_block_size_from_granularity(
96
+ tensor_quant_config.granularity
97
+ ),
114
98
  )
115
99
  if tensor_content is None:
116
100
  return quant_params
117
101
 
118
- # The reshaping for blockwise quantization is unique hence we do this here
119
- # to avoid unexpected broadcast behavior downstream.
120
- if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
121
- quant_params = common_quantize.broadcast_scale_zp_for_blockwise(
122
- tensor_content, quant_params
123
- )
124
-
125
102
  quantized_vars = uniform_quantize_tensor.uniform_quantize(
126
- tensor_content, quant_params
103
+ tensor_content,
104
+ quant_params,
105
+ uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity),
127
106
  )
128
107
  # Update with quantized values.
129
- return qtyping.UniformQuantParams(
130
- scale=scale,
131
- zero_point=zp,
132
- num_bits=tensor_quant_config.num_bits,
133
- symmetric=tensor_quant_config.symmetric,
134
- quantized_dimension=quantized_dim,
135
- quantized_data=quantized_vars,
136
- block_size=tensor_quant_config.block_size,
108
+ return dataclasses.replace(quant_params, quantized_data=quantized_vars)
109
+
110
+
111
def check_if_quantized(tensor: Any) -> bool:
  """Checks whether a tensor already carries quantization parameters.

  Args:
    tensor: A tensor object exposing a `quantization` attribute. Callers pass
      entries of `graph_info.subgraph_tensors` (presumably TFLite flatbuffer
      tensors). `quantization` may be None, and its `scale` may be None, a
      list, or a numpy array.

  Returns:
    True if the tensor has quantization parameters with at least one scale
    value, False otherwise.
  """
  quantization = tensor.quantization
  if quantization is None or quantization.scale is None:
    return False
  # A present-but-empty scale vector carries no usable quantization info, so
  # it must not be treated as "already quantized" (which would make
  # calibration skip the tensor).
  return np.size(quantization.scale) > 0
138
116
 
139
117
 
@@ -158,6 +136,13 @@ def init_qsvs(
158
136
  op_qsvs = {}
159
137
 
160
138
  inputs_to_ignore = inputs_to_ignore or []
139
+ quantized_inputs_to_ignore = [
140
+ opr_idx
141
+ for opr_idx, tensor_idx in enumerate(op_info.op.inputs)
142
+ if check_if_quantized(graph_info.subgraph_tensors[tensor_idx])
143
+ ]
144
+ inputs_to_ignore.extend(quantized_inputs_to_ignore)
145
+
161
146
  outputs_to_ignore = outputs_to_ignore or []
162
147
  for opr_idx, tensor_idx in enumerate(op_info.op.inputs):
163
148
  if tensor_idx != -1 and opr_idx not in inputs_to_ignore:
@@ -190,6 +175,7 @@ def min_max_calibrate(
190
175
  tensor_content_map: dict[str, np.ndarray],
191
176
  inputs_to_ignore: Optional[list[int]] = None,
192
177
  outputs_to_ignore: Optional[list[int]] = None,
178
+ valid_range: tuple[float, float] = (-3e38, 3e38),
193
179
  ) -> dict[str, qtyping.QSV]:
194
180
  """Collect quantization statistics variable (QSV, e.g., min/max) for the op.
195
181
 
@@ -199,11 +185,18 @@ def min_max_calibrate(
199
185
  tensor_content_map: A map of tensor name to tensor content.
200
186
  inputs_to_ignore: Input tensor indices to ignore.
201
187
  outputs_to_ignore: Output tensor indices to ignore.
188
+ valid_range: The valid range for tensor content, excluding the boundaries.
189
+ Tensor values outside this range are ignored during calibration. Defaults
190
+ to an approximate bfloat16 range. This range is chosen to address issues
191
+ with `padv2` where a bfloat16 -inf padding constant can cause problems.
192
+ Values exceeding this range can lead to quantization issues and are
193
+ therefore excluded from min/max calibration.
202
194
 
203
195
  Returns:
204
196
  A dictionary with key as tensor name and value as the collected QSV.
205
197
  """
206
198
  op_qsvs = {}
199
+ min_val, max_val = valid_range
207
200
 
208
201
  def _collect_activation_tensor_min_max(tensor_idx):
209
202
  tensor = graph_info.subgraph_tensors[tensor_idx]
@@ -215,12 +208,25 @@ def min_max_calibrate(
215
208
  return
216
209
  tensor_name = tfl_flatbuffer_utils.get_tensor_name(tensor)
217
210
  tensor_content = tensor_content_map[tensor_name]
211
+ qsv_shape = (1,) * tensor_content.ndim
212
+ filter_mask = (tensor_content > min_val) & (tensor_content < max_val)
213
+ if np.any(filter_mask):
214
+ tensor_content = tensor_content[filter_mask]
215
+ # Reshape is needed to ensure the scalar min/max have the same number of
216
+ # dimensions as the input tensor array, for compatibility with subsequent
217
+ # operations.
218
218
  op_qsvs[tensor_name] = {
219
- "min": np.min(tensor_content, axis=None, keepdims=True),
220
- "max": np.max(tensor_content, axis=None, keepdims=True),
219
+ "min": np.min(tensor_content, axis=None).reshape(qsv_shape),
220
+ "max": np.max(tensor_content, axis=None).reshape(qsv_shape),
221
221
  }
222
222
 
223
223
  inputs_to_ignore = inputs_to_ignore or []
224
+ quantized_inputs_to_ignore = [
225
+ opr_idx
226
+ for opr_idx, tensor_idx in enumerate(tfl_op.inputs)
227
+ if check_if_quantized(graph_info.subgraph_tensors[tensor_idx])
228
+ ]
229
+ inputs_to_ignore.extend(quantized_inputs_to_ignore)
224
230
  outputs_to_ignore = outputs_to_ignore or []
225
231
  for i, tensor_idx in enumerate(tfl_op.inputs):
226
232
  if tensor_idx != -1 and i not in inputs_to_ignore:
@@ -17,6 +17,7 @@ import os
17
17
  from typing import cast
18
18
 
19
19
  from absl.testing import parameterized
20
+ import ml_dtypes
20
21
  import numpy as np
21
22
 
22
23
  from tensorflow.python.platform import googletest
@@ -165,8 +166,7 @@ class NaiveMinMaxQuantizeTest(parameterized.TestCase):
165
166
  weight_tensor_config = _TensorQuantConfig(
166
167
  num_bits=4,
167
168
  symmetric=True,
168
- granularity=qtyping.QuantGranularity.BLOCKWISE,
169
- block_size=2,
169
+ granularity=qtyping.QuantGranularity.BLOCKWISE_32,
170
170
  )
171
171
  op_info = qtyping.OpInfo(
172
172
  op=fc_op,
@@ -176,30 +176,69 @@ class NaiveMinMaxQuantizeTest(parameterized.TestCase):
176
176
  weight_tensor_config=weight_tensor_config,
177
177
  ),
178
178
  )
179
- test_data = np.array([[-7, 7], [4, -4], [4, -4], [7, 7]])
179
+ test_data = np.random.uniform(low=-10, high=10, size=(4, 32)).astype(
180
+ np.float32
181
+ )
180
182
  quant_params = naive_min_max_quantize.get_tensor_quant_params(
181
183
  op_info=op_info,
182
184
  tensor_quant_config=weight_tensor_config,
183
185
  tensor_content=test_data,
184
186
  )
185
- scale = quant_params.scale
186
187
  zp = quant_params.zero_point
187
- expected_scale = np.array([
188
- [1],
189
- [0.5703125],
190
- [0.5703125],
191
- [1],
192
- ])
193
- expected_zp = np.zeros([4, 1])
194
- self.assertTrue(np.array_equal(zp, expected_zp))
195
- self.assertTrue(np.array_equal(scale, expected_scale))
188
+ self.assertEqual(zp.shape, (4, 1))
189
+ self.assertTrue(np.array_equal(zp, np.zeros([4, 1])))
190
+
191
+ self.assertEqual(quant_params.scale.shape, (4, 1))
192
+ expected_scales = np.max(np.abs(test_data), axis=1, keepdims=True) / 7.0
193
+ expected_scales = (
194
+ expected_scales.astype(ml_dtypes.bfloat16)
195
+ .astype(np.float16)
196
+ .astype(np.float32)
197
+ )
198
+ self.assertTrue(np.allclose(quant_params.scale, expected_scales, atol=1e-5))
199
+
196
200
  self.assertIsNotNone(quant_params.quantized_data)
197
201
  self.assertTupleEqual(
198
202
  cast(np.ndarray, quant_params.quantized_data).shape, test_data.shape
199
203
  )
200
- self.assertEqual(quant_params.block_size, 2)
204
+ self.assertEqual(quant_params.block_size, 32)
201
205
  self.assertEqual(quant_params.quantized_dimension, 1)
202
206
 
207
+ def test_calibrate_ignores_inf_min_max(self):
208
+ """Tests that calibration ignores infinity values."""
209
+ # Sample input/output data for the fc op.
210
+ input_tensor_name = "sequential/flatten/Reshape"
211
+ output_tensor_name = (
212
+ "sequential/dense/MatMul;sequential/dense/Relu;sequential/dense/BiasAdd"
213
+ )
214
+ bloat16_inf = 3.39e38
215
+ tensor_content_map = {
216
+ input_tensor_name: np.array(
217
+ [[-np.inf, 1.0, 5.0, np.inf, bloat16_inf]], dtype=np.float32
218
+ ),
219
+ output_tensor_name: np.array(
220
+ [[6.0, 7.0, -bloat16_inf, 9.0, np.inf]], dtype=np.float32
221
+ ),
222
+ }
223
+ # Read from Model Explorer.
224
+ subgraph0 = self._test_model.subgraphs[0]
225
+ fc_op = subgraph0.operators[3]
226
+ op_qsvs = naive_min_max_quantize.min_max_calibrate(
227
+ fc_op,
228
+ self._graph_info,
229
+ tensor_content_map,
230
+ inputs_to_ignore=[1, 2], # Ignore weight and bias.
231
+ outputs_to_ignore=[],
232
+ )
233
+
234
+ self.assertIn(input_tensor_name, op_qsvs)
235
+ self.assertEqual(op_qsvs[input_tensor_name]["min"], 1.0)
236
+ self.assertEqual(op_qsvs[input_tensor_name]["max"], 5.0)
237
+
238
+ self.assertIn(output_tensor_name, op_qsvs)
239
+ self.assertEqual(op_qsvs[output_tensor_name]["min"], 6.0)
240
+ self.assertEqual(op_qsvs[output_tensor_name]["max"], 9.0)
241
+
203
242
 
204
243
  if __name__ == "__main__":
205
244
  googletest.main()
@@ -102,21 +102,13 @@ def get_tensor_quant_params(
102
102
  op_info, tensor_quant_config, tensor_content, tensor_qsv
103
103
  )
104
104
 
105
- if (
106
- tensor_quant_config.granularity != qtyping.QuantGranularity.CHANNELWISE
107
- and tensor_quant_config.granularity != qtyping.QuantGranularity.TENSORWISE
108
- ):
109
- raise ValueError(
110
- f"Unsupported granularity: {tensor_quant_config.granularity}."
111
- )
112
-
113
105
  if not tensor_quant_config.symmetric:
114
106
  raise ValueError(
115
107
  f"Unsupported symmetry: {tensor_quant_config.symmetric}. OCTAV"
116
108
  " supports symmetric quantization only for now."
117
109
  )
118
110
 
119
- if tensor_qsv is None:
111
+ if not tensor_qsv:
120
112
  # We need min/max to calculate quantization parameters, which
121
113
  # should be collected during the calibration process. However,
122
114
  # weight-only and DRQ do not require calibration, thus it is
@@ -136,25 +128,41 @@ def get_tensor_quant_params(
136
128
  " the ParamsGenerator."
137
129
  )
138
130
 
139
- quantized_dim = None
140
- if tensor_quant_config.granularity == qtyping.QuantGranularity.CHANNELWISE:
141
- quantized_dim = common_utils.get_weight_quantized_dim(
142
- op_info, tensor_content
131
+ quantized_dim = common_utils.get_weight_quantized_dim(
132
+ op_info, tensor_content, tensor_quant_config.granularity
133
+ )
134
+ if uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity):
135
+ reshaped_data, reduce_dims = (
136
+ uniform_quantize_tensor.reshape_data_for_blockwise(
137
+ tensor_content,
138
+ op_info.op_name,
139
+ tensor_quant_config.granularity,
140
+ )
141
+ )
142
+ else:
143
+ reshaped_data = tensor_content
144
+ reduce_dims = common_utils.get_reduce_dims(
145
+ quantized_dim, tensor_content.shape
143
146
  )
144
-
145
147
  clipping_constants = _guess_clipping_with_octav(
146
- tensor_content,
148
+ reshaped_data,
147
149
  tensor_quant_config.num_bits,
148
- common_utils.get_reduce_dims(quantized_dim, tensor_content.shape),
150
+ reduce_dims,
149
151
  max_iterations=10,
150
152
  exponent_divisor=3.0 if tensor_quant_config.symmetric else 12.0,
151
153
  )
154
+ # We created a new dimension in order to reduce properly for blockwise
155
+ # quantization, so we need to reshape the clipping constants back to the
156
+ # min/max shape for the next step.
157
+ if uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity):
158
+ clipping_constants = clipping_constants.reshape(tensor_min_max["min"].shape)
152
159
 
153
160
  zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
154
161
  tensor_min_max["min"],
155
162
  tensor_min_max["max"],
156
163
  tensor_quant_config.num_bits,
157
164
  tensor_quant_config.symmetric,
165
+ tensor_quant_config.granularity,
158
166
  clipping_constants,
159
167
  )
160
168
 
@@ -164,11 +172,17 @@ def get_tensor_quant_params(
164
172
  num_bits=tensor_quant_config.num_bits,
165
173
  symmetric=tensor_quant_config.symmetric,
166
174
  quantized_dimension=quantized_dim,
167
- block_size=tensor_quant_config.block_size,
175
+ block_size=uniform_quantize_tensor.extract_block_size_from_granularity(
176
+ tensor_quant_config.granularity
177
+ ),
168
178
  )
169
179
 
170
180
  quantized_vars = uniform_quantize_tensor.uniform_quantize(
171
- tensor_content, quant_params
181
+ tensor_content,
182
+ quant_params,
183
+ is_blockwise_quant=uniform_quantize_tensor.is_blockwise(
184
+ tensor_quant_config.granularity
185
+ ),
172
186
  )
173
187
 
174
188
  return dataclasses.replace(quant_params, quantized_data=quantized_vars)