ai-edge-quantizer-nightly 0.1.0.dev20250513__py3-none-any.whl → 0.1.0.dev20250515__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -726,18 +726,29 @@ def _get_tensor_shape_for_blockwise(
726
726
 
727
727
 
728
728
  def _reshape_data_for_blockwise(
729
- tensor_data: np.ndarray, quantized_dim: int, block_size: int
729
+ tensor_data: np.ndarray,
730
+ quantized_dim: int,
731
+ block_size: int,
730
732
  ) -> tuple[np.ndarray, int]:
731
733
  """Reshapes data for blockwise quantization.
732
734
 
733
735
  Args:
734
736
  tensor_data: The original tensor data.
735
737
  quantized_dim: The dimension to be quantized blockwise.
736
- block_size: The size of the block.
738
+ block_size: The size of the block. `block_size` must be a multiple of 32.
739
+ The tensor quantized dimension shape must be divisible by `block_size`.
737
740
 
738
741
  Returns:
739
742
  A tuple containing the reshaped tensor data and the new reduce dimension.
740
743
  """
744
+
745
+ # TODO: b/417508018 - create AEQ specific error class instead of
746
+ # using generic ValueError.
747
+ if tensor_data.shape[quantized_dim] % block_size != 0:
748
+ raise ValueError(
749
+ "Tensor quantization dimension must be divisible by block size for"
750
+ " blockwise quantization."
751
+ )
741
752
  new_shape = _get_tensor_shape_for_blockwise(
742
753
  tensor_data.shape, quantized_dim, block_size
743
754
  )
@@ -818,22 +829,19 @@ def init_tensor_min_max(
818
829
  weight_tensor_config.granularity == qtyping.QuantGranularity.CHANNELWISE
819
830
  ):
820
831
  quantized_dim = common_utils.get_weight_quantized_dim(
821
- op_info, tensor_data
832
+ op_info, tensor_data, weight_tensor_config.granularity
822
833
  )
823
834
  if (
824
835
  weight_tensor_config is not None
825
836
  and weight_tensor_config.granularity
826
837
  == qtyping.QuantGranularity.BLOCKWISE
827
838
  ):
828
- quantized_dim = (
829
- tfl_flatbuffer_utils.TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM[
830
- op_info.op_name
831
- ]
832
- )
833
- reshaped_data, reduce_dims = _reshape_data_for_blockwise(
834
- tensor_data,
835
- quantized_dim,
836
- weight_tensor_config.block_size,
839
+ reshaped_data, reduce_dims = (
840
+ uniform_quantize_tensor.reshape_data_for_blockwise(
841
+ tensor_data,
842
+ op_info.op_name,
843
+ weight_tensor_config.block_size,
844
+ )
837
845
  )
838
846
  return {
839
847
  "min": np.min(reshaped_data, axis=reduce_dims, keepdims=False),
@@ -31,8 +31,7 @@ _TensorQuantConfig = qtyping.TensorQuantizationConfig
31
31
 
32
32
 
33
33
  class CommonQuantizeTest(parameterized.TestCase):
34
- """Tests for general quantize functions.
35
- """
34
+ """Tests for general quantize functions."""
36
35
 
37
36
  def setUp(self):
38
37
  super().setUp()
@@ -69,6 +68,34 @@ class CommonQuantizeTest(parameterized.TestCase):
69
68
  default_policy.DEFAULT_CONFIG_CHECK_POLICY,
70
69
  )
71
70
 
71
+ def test_reshape_data_for_blockwise_raises_error_when_quantized_dim_not_divisible_by_block_size(
72
+ self,
73
+ ):
74
+ tensor_data = np.ones((24, 128), dtype=np.float32)
75
+ block_size = 256
76
+ quantized_dim = 1
77
+ with self.assertRaisesWithPredicateMatch(
78
+ ValueError,
79
+ lambda err: (
80
+ "Tensor quantization dimension must be divisible by block"
81
+ " size for blockwise quantization."
82
+ )
83
+ in str(err),
84
+ ):
85
+ common_quantize._reshape_data_for_blockwise(
86
+ tensor_data, quantized_dim, block_size
87
+ )
88
+
89
+ def test_reshape_data_for_blockwise_returns_correct_values(self):
90
+ tensor_data = np.ones((24, 128), dtype=np.float32)
91
+ block_size = 32
92
+ quantized_dim = 1
93
+ new_tensor_data, reduce_dim = common_quantize._reshape_data_for_blockwise(
94
+ tensor_data, quantized_dim, block_size
95
+ )
96
+ self.assertEqual(new_tensor_data.shape, (24, 4, 32))
97
+ self.assertEqual(reduce_dim, 2)
98
+
72
99
 
73
100
  if __name__ == "__main__":
74
101
  googletest.main()
@@ -168,11 +168,9 @@ def get_tensor_quant_params(
168
168
  "Only symmetric weights are supported for dequantized weight recovery."
169
169
  )
170
170
 
171
- quantized_dim = None
172
- if tensor_quant_config.granularity == qtyping.QuantGranularity.CHANNELWISE:
173
- quantized_dim = common_utils.get_weight_quantized_dim(
174
- op_info, tensor_content
175
- )
171
+ quantized_dim = common_utils.get_weight_quantized_dim(
172
+ op_info, tensor_content, tensor_quant_config.granularity
173
+ )
176
174
 
177
175
  zp, scale = get_zp_scale_from_dequantized_symmetric_weights(
178
176
  dequant_vals=tensor_content,
@@ -78,13 +78,16 @@ def _rotate_with_diagonal_hadamard(
78
78
  # of 2 to calculate this factor.
79
79
  hadamard_size = np.gcd(tensor_content.shape[axis], 2 ** 30)
80
80
  diagonal_size = tensor_content.shape[axis] // hadamard_size
81
+ output_size = tensor_content.shape[1 - axis]
81
82
  random_vector = np.ones(hadamard_size, dtype=np.int8)
82
83
 
83
84
  # Use a canonical Hadamard matrix.
84
85
  hadamard = _make_hadamard_matrix(hadamard_size)
85
- hadamard_diagonal = np.kron(np.eye(diagonal_size), hadamard)
86
- w_rotated = np.einsum("ij,aj->ai", hadamard_diagonal, tensor_content)
87
- return w_rotated, hadamard_size, random_vector
86
+ reshaped_tensor = tensor_content.reshape(
87
+ diagonal_size, output_size, hadamard_size
88
+ )
89
+ w_rotated = np.einsum("jk,ilk->ilj", hadamard, reshaped_tensor)
90
+ return w_rotated.reshape(tensor_content.shape), hadamard_size, random_vector
88
91
 
89
92
 
90
93
  def get_tensor_quant_params(
@@ -128,7 +131,9 @@ def get_tensor_quant_params(
128
131
  f" {tensor_quant_config.granularity} granularity."
129
132
  )
130
133
 
131
- quantized_dim = common_utils.get_weight_quantized_dim(op_info, tensor_content)
134
+ quantized_dim = common_utils.get_weight_quantized_dim(
135
+ op_info, tensor_content, tensor_quant_config.granularity
136
+ )
132
137
  if quantized_dim != 0:
133
138
  raise ValueError(
134
139
  f"Unsupported quantized dimension: {quantized_dim}. Only 0 is"
@@ -119,6 +119,55 @@ class HadamardRotationFullyConnectedTest(parameterized.TestCase):
119
119
  if qparams.hadamard is not None:
120
120
  self.assertEqual(qparams.hadamard.hadamard_size, 32)
121
121
 
122
+ def test_get_tensor_quant_params_golden_1(self):
123
+ test_data = np.ones((6, 6))
124
+ # expected:
125
+ # [[127 0 127 0 127 0]
126
+ # [127 0 127 0 127 0]
127
+ # [127 0 127 0 127 0]
128
+ # [127 0 127 0 127 0]
129
+ # [127 0 127 0 127 0]
130
+ # [127 0 127 0 127 0]]
131
+ expected = np.tile([127, 0], [6, 3])
132
+ qparams = hadamard_rotation.get_tensor_quant_params(
133
+ self._op_info,
134
+ self._op_info.op_quant_config.weight_tensor_config,
135
+ test_data,
136
+ self._tensor_name_to_qsv,
137
+ )
138
+ self.assertIsNotNone(qparams.quantized_data)
139
+ np.testing.assert_array_equal(
140
+ np.array(qparams.quantized_data), expected
141
+ )
142
+
143
+ def test_get_tensor_quant_params_golden_2(self):
144
+ # test_data:
145
+ # [[1 2 1 2 1 2]
146
+ # [3 4 3 4 3 4]
147
+ # [1 2 1 2 1 2]
148
+ # [3 4 3 4 3 4]
149
+ # [1 2 1 2 1 2]
150
+ # [3 4 3 4 3 4]]
151
+ test_data = np.tile([[1, 2], [3, 4]], [3, 3])
152
+ # expected:
153
+ # [[127 -42 127 -42 127 -42]
154
+ # [127 -18 127 -18 127 -18]
155
+ # [127 -42 127 -42 127 -42]
156
+ # [127 -18 127 -18 127 -18]
157
+ # [127 -42 127 -42 127 -42]
158
+ # [127 -18 127 -18 127 -18]]
159
+ expected = np.tile([[127, -42], [127, -18]], [3, 3])
160
+ qparams = hadamard_rotation.get_tensor_quant_params(
161
+ self._op_info,
162
+ self._op_info.op_quant_config.weight_tensor_config,
163
+ test_data,
164
+ self._tensor_name_to_qsv,
165
+ )
166
+ self.assertIsNotNone(qparams.quantized_data)
167
+ np.testing.assert_array_equal(
168
+ np.array(qparams.quantized_data), expected
169
+ )
170
+
122
171
  def test_raise_missing_tensor_content(self):
123
172
  with self.assertRaisesWithPredicateMatch(
124
173
  ValueError, lambda err: "weight tensor" in str(err)
@@ -16,7 +16,6 @@
16
16
  """Performs naive min/max uniform quantization."""
17
17
 
18
18
  from typing import Any, Optional
19
- import ml_dtypes
20
19
  import numpy as np
21
20
  from ai_edge_quantizer import qtyping
22
21
  from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
@@ -75,35 +74,17 @@ def get_tensor_quant_params(
75
74
  " the ParamsGenerator."
76
75
  )
77
76
  clipping_values = None
78
- # Blockwise quantization uses float16 scale, with 7 bit mantissa,
79
- # so the maximum representable value is 65280.
80
- if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
81
- clipping_values = np.broadcast_to(
82
- np.array(65280), tensor_min_max["min"].shape
83
- )
84
77
  zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
85
78
  tensor_min_max["min"],
86
79
  tensor_min_max["max"],
87
80
  tensor_quant_config.num_bits,
88
81
  tensor_quant_config.symmetric,
82
+ tensor_quant_config.granularity,
89
83
  clipping_values,
90
84
  )
91
- # Round the scale values to 7 bit mantissa.
92
- if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
93
- scale = (
94
- scale.astype(ml_dtypes.bfloat16).astype(np.float16).astype(np.float32)
95
- )
96
- quantized_dim = None
97
- if tensor_quant_config.granularity == qtyping.QuantGranularity.CHANNELWISE:
98
- quantized_dim = common_utils.get_weight_quantized_dim(
99
- op_info, tensor_content
100
- )
101
- elif tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
102
- quantized_dim = (
103
- tfl_flatbuffer_utils.TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM[
104
- op_info.op_name
105
- ]
106
- )
85
+ quantized_dim = common_utils.get_weight_quantized_dim(
86
+ op_info, tensor_content, tensor_quant_config.granularity
87
+ )
107
88
  quant_params = qtyping.UniformQuantParams(
108
89
  scale=scale,
109
90
  zero_point=zp,
@@ -115,15 +96,10 @@ def get_tensor_quant_params(
115
96
  if tensor_content is None:
116
97
  return quant_params
117
98
 
118
- # The reshaping for blockwise quantization is unique hence we do this here
119
- # to avoid unexpected broadcast behavior downstream.
120
- if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
121
- quant_params = common_quantize.broadcast_scale_zp_for_blockwise(
122
- tensor_content, quant_params
123
- )
124
-
125
99
  quantized_vars = uniform_quantize_tensor.uniform_quantize(
126
- tensor_content, quant_params
100
+ tensor_content,
101
+ quant_params,
102
+ tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE,
127
103
  )
128
104
  # Update with quantized values.
129
105
  return qtyping.UniformQuantParams(
@@ -102,21 +102,13 @@ def get_tensor_quant_params(
102
102
  op_info, tensor_quant_config, tensor_content, tensor_qsv
103
103
  )
104
104
 
105
- if (
106
- tensor_quant_config.granularity != qtyping.QuantGranularity.CHANNELWISE
107
- and tensor_quant_config.granularity != qtyping.QuantGranularity.TENSORWISE
108
- ):
109
- raise ValueError(
110
- f"Unsupported granularity: {tensor_quant_config.granularity}."
111
- )
112
-
113
105
  if not tensor_quant_config.symmetric:
114
106
  raise ValueError(
115
107
  f"Unsupported symmetry: {tensor_quant_config.symmetric}. OCTAV"
116
108
  " supports symmetric quantization only for now."
117
109
  )
118
110
 
119
- if tensor_qsv is None:
111
+ if not tensor_qsv:
120
112
  # We need min/max to calculate quantization parameters, which
121
113
  # should be collected during the calibration process. However,
122
114
  # weight-only and DRQ do not require calibration, thus it is
@@ -136,25 +128,41 @@ def get_tensor_quant_params(
136
128
  " the ParamsGenerator."
137
129
  )
138
130
 
139
- quantized_dim = None
140
- if tensor_quant_config.granularity == qtyping.QuantGranularity.CHANNELWISE:
141
- quantized_dim = common_utils.get_weight_quantized_dim(
142
- op_info, tensor_content
131
+ quantized_dim = common_utils.get_weight_quantized_dim(
132
+ op_info, tensor_content, tensor_quant_config.granularity
133
+ )
134
+ if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
135
+ reshaped_data, reduce_dims = (
136
+ uniform_quantize_tensor.reshape_data_for_blockwise(
137
+ tensor_content,
138
+ op_info.op_name,
139
+ tensor_quant_config.block_size,
140
+ )
141
+ )
142
+ else:
143
+ reshaped_data = tensor_content
144
+ reduce_dims = common_utils.get_reduce_dims(
145
+ quantized_dim, tensor_content.shape
143
146
  )
144
-
145
147
  clipping_constants = _guess_clipping_with_octav(
146
- tensor_content,
148
+ reshaped_data,
147
149
  tensor_quant_config.num_bits,
148
- common_utils.get_reduce_dims(quantized_dim, tensor_content.shape),
150
+ reduce_dims,
149
151
  max_iterations=10,
150
152
  exponent_divisor=3.0 if tensor_quant_config.symmetric else 12.0,
151
153
  )
154
+ # We created a new dimension in order to reduce properly for blockwise
155
+ # quantization, so we need to reshape the clipping constants back to the
156
+ # min/max shape for the next step.
157
+ if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
158
+ clipping_constants = clipping_constants.reshape(tensor_min_max["min"].shape)
152
159
 
153
160
  zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
154
161
  tensor_min_max["min"],
155
162
  tensor_min_max["max"],
156
163
  tensor_quant_config.num_bits,
157
164
  tensor_quant_config.symmetric,
165
+ tensor_quant_config.granularity,
158
166
  clipping_constants,
159
167
  )
160
168
 
@@ -168,7 +176,9 @@ def get_tensor_quant_params(
168
176
  )
169
177
 
170
178
  quantized_vars = uniform_quantize_tensor.uniform_quantize(
171
- tensor_content, quant_params
179
+ tensor_content,
180
+ quant_params,
181
+ tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE,
172
182
  )
173
183
 
174
184
  return dataclasses.replace(quant_params, quantized_data=quantized_vars)
@@ -44,33 +44,17 @@ class OctavQuantizeTest(parameterized.TestCase):
44
44
  )
45
45
  self._tensor_name_to_qsv = {}
46
46
  subgraph0 = self._test_model.subgraphs[0]
47
- subgraph_op_index = 3
48
- fc_op = subgraph0.operators[subgraph_op_index]
47
+ self._subgraph_op_index = 3
48
+ self._fc_op = subgraph0.operators[self._subgraph_op_index]
49
49
  self._fc_op_info = qtyping.OpInfo(
50
- op=fc_op,
50
+ op=self._fc_op,
51
51
  op_name=qtyping.TFLOperationName.FULLY_CONNECTED,
52
- subgraph_op_index=subgraph_op_index,
52
+ subgraph_op_index=self._subgraph_op_index,
53
53
  op_quant_config=qtyping.OpQuantizationConfig(
54
54
  weight_tensor_config=None,
55
55
  ),
56
56
  )
57
57
 
58
- def test_get_tensor_quant_params_unsupported_granularity_assert(self):
59
- err_msg = "Unsupported granularity"
60
- test_data = np.array([[-7, 7], [4, -4], [4, -4], [7, 7]])
61
- with self.assertRaisesWithPredicateMatch(
62
- ValueError, lambda err: err_msg in str(err)
63
- ):
64
- _ = octav.get_tensor_quant_params(
65
- op_info=self._fc_op_info,
66
- tensor_quant_config=qtyping.TensorQuantizationConfig(
67
- num_bits=4,
68
- symmetric=True,
69
- granularity=qtyping.QuantGranularity.BLOCKWISE,
70
- ),
71
- tensor_content=test_data,
72
- )
73
-
74
58
  def test_get_tensor_quant_params_unsupported_symmetry(self):
75
59
  err_msg = "Unsupported symmetry"
76
60
  test_data = np.array([[-7, 7], [4, -4], [4, -4], [7, 7]])
@@ -117,13 +101,22 @@ class OctavQuantizeTest(parameterized.TestCase):
117
101
  [25, -30, 50, -75, 1e5, -125],
118
102
  [50, -60, 70, -80, 90, -100],
119
103
  ])
120
- quant_params = octav.get_tensor_quant_params(
121
- op_info=self._fc_op_info,
122
- tensor_quant_config=qtyping.TensorQuantizationConfig(
123
- num_bits=4,
124
- symmetric=True,
125
- granularity=qtyping.QuantGranularity.TENSORWISE,
104
+ tensor_config = qtyping.TensorQuantizationConfig(
105
+ num_bits=4,
106
+ symmetric=True,
107
+ granularity=qtyping.QuantGranularity.TENSORWISE,
108
+ )
109
+ fc_op_info = qtyping.OpInfo(
110
+ op=self._fc_op,
111
+ op_name=qtyping.TFLOperationName.FULLY_CONNECTED,
112
+ subgraph_op_index=self._subgraph_op_index,
113
+ op_quant_config=qtyping.OpQuantizationConfig(
114
+ weight_tensor_config=tensor_config,
126
115
  ),
116
+ )
117
+ quant_params = octav.get_tensor_quant_params(
118
+ op_info=fc_op_info,
119
+ tensor_quant_config=tensor_config,
127
120
  tensor_content=test_data,
128
121
  )
129
122
  adjusted_test_data = quant_params.quantized_data * quant_params.scale
@@ -131,10 +124,10 @@ class OctavQuantizeTest(parameterized.TestCase):
131
124
  adjusted_max = np.max(np.abs(adjusted_test_data))
132
125
 
133
126
  # Check that some clipping occurred.
134
- with self.subTest(name="SanityCheckClipping"):
127
+ with self.subTest(name="CheckClipping"):
135
128
  self.assertLess(adjusted_max, real_max)
136
129
 
137
- with self.subTest(name="SanityCheckQuantParamsShapes"):
130
+ with self.subTest(name="CheckQuantParamsShapes"):
138
131
  self.assertEqual(quant_params.zero_point.shape, (1, 1))
139
132
  self.assertEqual(quant_params.scale.shape, (1, 1))
140
133
  self.assertIsNone(quant_params.quantized_dimension)
@@ -143,33 +136,47 @@ class OctavQuantizeTest(parameterized.TestCase):
143
136
  cast(np.ndarray, quant_params.quantized_data).shape, test_data.shape
144
137
  )
145
138
 
146
- with self.subTest(name="SanityCheckQuantParamsValues"):
139
+ with self.subTest(name="CheckQuantParamsValues"):
147
140
  self.assertTrue(np.all(quant_params.zero_point == 0))
148
141
 
149
142
  def test_get_tensor_quant_params_sanity_channelwise(self):
143
+ # Test that the call generates quant params that are appropriately shaped,
144
+ # have some clipping, and correct config values without checking the
145
+ # actual values numerically.
150
146
  test_data = np.array([
151
147
  [-1e5, 25, -50, 75, -100, 125],
152
148
  [25, -30, 50, -75, 1e5, -125],
153
149
  [50, -60, 70, -80, 90, -100],
154
150
  ])
155
- quant_params = octav.get_tensor_quant_params(
156
- op_info=self._fc_op_info,
157
- tensor_quant_config=qtyping.TensorQuantizationConfig(
158
- num_bits=4,
159
- symmetric=True,
160
- granularity=qtyping.QuantGranularity.CHANNELWISE,
151
+ tensor_config = qtyping.TensorQuantizationConfig(
152
+ num_bits=4,
153
+ symmetric=True,
154
+ granularity=qtyping.QuantGranularity.CHANNELWISE,
155
+ )
156
+ fc_op_info = qtyping.OpInfo(
157
+ op=self._fc_op,
158
+ op_name=qtyping.TFLOperationName.FULLY_CONNECTED,
159
+ subgraph_op_index=self._subgraph_op_index,
160
+ op_quant_config=qtyping.OpQuantizationConfig(
161
+ weight_tensor_config=tensor_config,
161
162
  ),
163
+ )
164
+ quant_params = octav.get_tensor_quant_params(
165
+ op_info=fc_op_info,
166
+ tensor_quant_config=tensor_config,
162
167
  tensor_content=test_data,
163
168
  )
169
+ # Dequantize output to compare with the original test data.
164
170
  adjusted_test_data = quant_params.quantized_data * quant_params.scale
171
+
165
172
  for i, row in enumerate(test_data):
166
173
  real_max = np.max(np.abs(row))
167
174
  adjusted_max = np.max(np.abs(adjusted_test_data[i]))
168
175
  # Check that some clipping occurred.
169
- with self.subTest(name="SanityCheckClipping"):
176
+ with self.subTest(name="CheckClipping"):
170
177
  self.assertLess(adjusted_max, real_max)
171
178
 
172
- with self.subTest(name="SanityCheckQuantParamsShapes"):
179
+ with self.subTest(name="CheckQuantParamsShapes"):
173
180
  self.assertEqual(quant_params.zero_point.shape, (test_data.shape[0], 1))
174
181
  self.assertEqual(quant_params.scale.shape, (test_data.shape[0], 1))
175
182
  self.assertIsNotNone(quant_params.quantized_data)
@@ -177,10 +184,58 @@ class OctavQuantizeTest(parameterized.TestCase):
177
184
  cast(np.ndarray, quant_params.quantized_data).shape, test_data.shape
178
185
  )
179
186
 
180
- with self.subTest(name="SanityCheckQuantParamsValues"):
187
+ with self.subTest(name="CheckQuantParamsValues"):
181
188
  self.assertTrue(np.all(quant_params.zero_point == 0))
182
189
  self.assertEqual(quant_params.quantized_dimension, 0)
183
190
 
191
+ def test_get_tensor_quant_params_sanity_blockwise(self):
192
+ # Test that the call generates quant params that are appropriately shaped,
193
+ # have some clipping, and correct config values without checking the
194
+ # actual values numerically.
195
+ test_data = np.random.randint(0, 1024, size=(32, 128))
196
+ tensor_config = qtyping.TensorQuantizationConfig(
197
+ num_bits=4,
198
+ symmetric=True,
199
+ granularity=qtyping.QuantGranularity.BLOCKWISE,
200
+ block_size=32,
201
+ )
202
+ fc_op_info = qtyping.OpInfo(
203
+ op=self._fc_op,
204
+ op_name=qtyping.TFLOperationName.FULLY_CONNECTED,
205
+ subgraph_op_index=self._subgraph_op_index,
206
+ op_quant_config=qtyping.OpQuantizationConfig(
207
+ weight_tensor_config=tensor_config,
208
+ ),
209
+ )
210
+ quant_params = octav.get_tensor_quant_params(
211
+ op_info=fc_op_info,
212
+ tensor_quant_config=tensor_config,
213
+ tensor_content=test_data,
214
+ )
215
+
216
+ with self.subTest(name="CheckQuantParamsShapes"):
217
+ # Check that quant params have appropriate shapes.
218
+ self.assertEqual(quant_params.zero_point.shape, (32, 4))
219
+ self.assertEqual(quant_params.scale.shape, (32, 4))
220
+ self.assertIsNotNone(quant_params.quantized_data)
221
+ self.assertTupleEqual(
222
+ cast(np.ndarray, quant_params.quantized_data).shape, test_data.shape
223
+ )
224
+
225
+ scales = np.repeat(quant_params.scale, 32, axis=1)
226
+ adjusted_test_data = quant_params.quantized_data * scales
227
+ for i, row in enumerate(test_data):
228
+ real_max = np.max(np.abs(row))
229
+ adjusted_max = np.max(np.abs(adjusted_test_data[i]))
230
+ # Check that some clipping occurred.
231
+ with self.subTest(name="CheckClipping"):
232
+ self.assertLess(adjusted_max, real_max)
233
+
234
+ with self.subTest(name="CheckQuantParamsValues"):
235
+ self.assertTrue(np.all(quant_params.zero_point == 0))
236
+ # See TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM.
237
+ self.assertEqual(quant_params.quantized_dimension, 1)
238
+
184
239
 
185
240
  if __name__ == "__main__":
186
241
  googletest.main()
@@ -16,9 +16,11 @@
16
16
  """Uniform quantize in tensor level."""
17
17
 
18
18
  import dataclasses
19
- from typing import Optional
19
+ from typing import Optional, Sequence
20
+ import ml_dtypes
20
21
  import numpy as np
21
22
  from ai_edge_quantizer import qtyping
23
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
22
24
 
23
25
 
24
26
  @dataclasses.dataclass(frozen=True)
@@ -120,19 +122,127 @@ def fix_quantization_params_rank(
120
122
  )
121
123
 
122
124
 
125
+ def _get_tensor_shape_for_blockwise(
126
+ tensor_shape: Sequence[int], quantized_dim: int, block_size: int
127
+ ) -> list[int]:
128
+ """Get the tensor shape for blockwise quantization.
129
+
130
+ This function splits the quantized dimension of the tensor into two
131
+ dimensions, (dim / block_size, block_size). Hence, min/max of the tensor can
132
+ be calculated for each block using existing functions.
133
+
134
+ Args:
135
+ tensor_shape: The original shape of the tensor.
136
+ quantized_dim: The dimension to be quantized blockwise.
137
+ block_size: The size of the block.
138
+
139
+ Returns:
140
+ The new tensor shape for calculating scale and zp for blockwise
141
+ quantization.
142
+ """
143
+ new_shape = []
144
+ for index, val in enumerate(tensor_shape):
145
+ if index == quantized_dim:
146
+ new_shape.append(int(val / block_size))
147
+ new_shape.append(block_size)
148
+ else:
149
+ new_shape.append(val)
150
+ return new_shape
151
+
152
+
153
+ def reshape_data_for_blockwise(
154
+ tensor_data: np.ndarray, op_name: qtyping.TFLOperationName, block_size: int
155
+ ) -> tuple[np.ndarray, int]:
156
+ """Reshapes data for blockwise quantization.
157
+
158
+ Args:
159
+ tensor_data: The original tensor data.
160
+ op_name: The name of the TFL op.
161
+ block_size: The size of the block.
162
+
163
+ Returns:
164
+ A tuple containing the reshaped tensor data and the new reduce dimension.
165
+ """
166
+ quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM[
167
+ op_name
168
+ ]
169
+ new_shape = _get_tensor_shape_for_blockwise(
170
+ tensor_data.shape, quantized_dim, block_size
171
+ )
172
+ reshaped_data = tensor_data.reshape(new_shape)
173
+ return reshaped_data, quantized_dim + 1
174
+
175
+
176
+ def _broadcast_scale_zp_for_blockwise(
177
+ tensor_content: np.ndarray,
178
+ quant_params: qtyping.UniformQuantParams,
179
+ ) -> qtyping.UniformQuantParams:
180
+ """Broadcasts scale and zp for blockwise quantization.
181
+
182
+ Args:
183
+ tensor_content: The original tensor data.
184
+ quant_params: The quantization parameters.
185
+ `quant_params.quantized_dimension` must be specified.
186
+ `quant_params.block_size` must be specified and positive.
187
+
188
+ Returns:
189
+ The updated quantization parameters with broadcasted scale and zp for
190
+ correct constant quantization.
191
+ """
192
+ if quant_params.quantized_dimension is None:
193
+ raise ValueError("Quantized dimension must be specified.")
194
+ if quant_params.block_size is None or quant_params.block_size <= 0:
195
+ raise ValueError("Block size must be specified and positive.")
196
+ quantized_dim = quant_params.quantized_dimension
197
+ expanded_tensor_shape = _get_tensor_shape_for_blockwise(
198
+ tensor_content.shape, quantized_dim, quant_params.block_size
199
+ )
200
+ expanded_scale = np.reshape(
201
+ np.broadcast_to(
202
+ np.expand_dims(quant_params.scale, quantized_dim + 1),
203
+ expanded_tensor_shape,
204
+ ),
205
+ tensor_content.shape,
206
+ )
207
+ expanded_zp = np.reshape(
208
+ np.broadcast_to(
209
+ np.expand_dims(quant_params.zero_point, quantized_dim + 1),
210
+ expanded_tensor_shape,
211
+ ),
212
+ tensor_content.shape,
213
+ )
214
+ return qtyping.UniformQuantParams(
215
+ scale=expanded_scale,
216
+ zero_point=expanded_zp,
217
+ num_bits=quant_params.num_bits,
218
+ symmetric=quant_params.symmetric,
219
+ quantized_dimension=quantized_dim,
220
+ block_size=quant_params.block_size,
221
+ )
222
+
223
+
123
224
  def uniform_quantize(
124
225
  tensor_data: np.ndarray,
125
226
  quantization_params: qtyping.UniformQuantParams,
227
+ is_blockwise: bool = False,
126
228
  ):
127
229
  """Uniform quantize a tensor.
128
230
 
129
231
  Args:
130
232
  tensor_data: The tensor to be quantized.
131
233
  quantization_params: The quantization parameters.
234
+ is_blockwise: Whether the tensor is blockwise quantized.
132
235
 
133
236
  Returns:
134
237
  The quantized tensor.
135
238
  """
239
+ # The reshaping for blockwise quantization is unique hence we do this here
240
+ # to avoid unexpected broadcast behavior downstream.
241
+ if is_blockwise:
242
+ quantization_params = _broadcast_scale_zp_for_blockwise(
243
+ tensor_data, quantization_params
244
+ )
245
+
136
246
  # quant params in flatbuffer is flattened, expand the rank to be the same
137
247
  # as the tensor rank to avoid ambiguous broadcasting.
138
248
  quantization_params = fix_quantization_params_rank(
@@ -242,15 +352,19 @@ def tensor_zp_scale_from_min_max(
242
352
  max_value,
243
353
  num_bits: int,
244
354
  symmetric: bool,
355
+ granularity: qtyping.QuantGranularity,
245
356
  clipping_values: Optional[np.ndarray] = None,
246
357
  ):
247
358
  """Get zero point and scale from min and max value.
248
359
 
249
360
  Args:
250
- min_value: The minimum value of the tensor (channel-wise supported).
251
- max_value: The maximum value of the tensor (channel-wise supported).
361
+ min_value: The minimum value of the tensor (channelwise and blockwise
362
+ supported).
363
+ max_value: The maximum value of the tensor (channelwise and blockwise
364
+ supported).
252
365
  num_bits: The number of bits of the tensor.
253
366
  symmetric: Whether the tensor is symmetric.
367
+ granularity: The granularity of the tensor.
254
368
  clipping_values: Absolute clipping values to apply to the tensor. This will
255
369
  clip the tensors to the range [-clipping_values, clipping_values]. This
256
370
  should be the same shape as min_value and max_value. If None, no clipping
@@ -267,6 +381,16 @@ def tensor_zp_scale_from_min_max(
267
381
  qmin, qmax = get_quantized_range(qtype)
268
382
  min_bound = 1e-4 # 1e-6 precision for int8 and 1e-8 for int16.
269
383
 
384
+ if granularity == qtyping.QuantGranularity.BLOCKWISE:
385
+ # Blockwise quantization uses float16 scale, with 7 bit mantissa,
386
+ # so the maximum representable value is 65280.
387
+ float16_max = np.broadcast_to(np.array(65280), min_value.shape)
388
+ clipping_values = (
389
+ float16_max
390
+ if clipping_values is None
391
+ else np.minimum(clipping_values, float16_max)
392
+ )
393
+
270
394
  if symmetric:
271
395
  bound = np.maximum(np.abs(min_value), np.abs(max_value))
272
396
  bound = np.maximum(bound, min_bound)
@@ -292,6 +416,12 @@ def tensor_zp_scale_from_min_max(
292
416
  zp = qmin - bound_min / scale
293
417
  zp = np.rint(zp)
294
418
 
419
+ if granularity == qtyping.QuantGranularity.BLOCKWISE:
420
+ # Round the scale values to 7 bit mantissa.
421
+ scale = (
422
+ scale.astype(ml_dtypes.bfloat16).astype(np.float16).astype(np.float32)
423
+ )
424
+
295
425
  # It's safe to cast zp to qtype without clipping because we can infer
296
426
  # qmin <= zp <= qmax from bound_min <= 0 <= bound_max.
297
427
  zp = assign_quantized_type(zp, qtype)
@@ -336,7 +336,11 @@ class TensorUtilsTest(parameterized.TestCase):
336
336
  max_val = np.max(self._test_data, keepdims=True)
337
337
 
338
338
  zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
339
- min_val, max_val, num_bits, symmetric
339
+ min_val,
340
+ max_val,
341
+ num_bits,
342
+ symmetric,
343
+ qtyping.QuantGranularity.TENSORWISE,
340
344
  )
341
345
  self.assertEqual(zp.shape, scale.shape)
342
346
  max_q = 2**num_bits / 2 - 1
@@ -364,7 +368,12 @@ class TensorUtilsTest(parameterized.TestCase):
364
368
  max_val = np.array([[5.0]])
365
369
  clipping_values = np.array([4.0])
366
370
  zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
367
- min_val, max_val, num_bits, symmetric, clipping_values
371
+ min_val,
372
+ max_val,
373
+ num_bits,
374
+ symmetric,
375
+ qtyping.QuantGranularity.TENSORWISE,
376
+ clipping_values,
368
377
  )
369
378
  expected_scale = clipping_values / quantized_bound
370
379
 
@@ -905,23 +905,36 @@ def get_tensor_transformation_params(
905
905
  )
906
906
 
907
907
 
908
- def get_weight_quantized_dim(op_info: qtyping.OpInfo, tensor_data: np.ndarray):
908
+ def get_weight_quantized_dim(
909
+ op_info: qtyping.OpInfo,
910
+ tensor_data: np.ndarray,
911
+ granularity: qtyping.QuantGranularity,
912
+ ):
909
913
  """Get the quantized dimension for the weight tensor.
910
914
 
911
915
  Args:
912
916
  op_info: Aggregated information about the op (e.g., quantization config).
913
917
  tensor_data: The weight tensor data.
918
+ granularity: The granularity of the weight tensor.
914
919
 
915
920
  Returns:
916
921
  The quantized dimension for the weight tensor.
917
922
  """
918
- if op_info.op_name == _TFLOpName.BATCH_MATMUL:
919
- quantized_dim = get_bmm_weight_quantized_dim(
920
- tensor_data, adj_y=op_info.op.builtinOptions.adjY
921
- )
922
- else:
923
- quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM.get(
924
- op_info.op_name, None
923
+ quantized_dim = None
924
+ if granularity == qtyping.QuantGranularity.CHANNELWISE:
925
+ if op_info.op_name == _TFLOpName.BATCH_MATMUL:
926
+ quantized_dim = get_bmm_weight_quantized_dim(
927
+ tensor_data, adj_y=op_info.op.builtinOptions.adjY
928
+ )
929
+ else:
930
+ quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM.get(
931
+ op_info.op_name, None
932
+ )
933
+ elif granularity == qtyping.QuantGranularity.BLOCKWISE:
934
+ quantized_dim = (
935
+ tfl_flatbuffer_utils.TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM[
936
+ op_info.op_name
937
+ ]
925
938
  )
926
939
  return quantized_dim
927
940
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-quantizer-nightly
3
- Version: 0.1.0.dev20250513
3
+ Version: 0.1.0.dev20250515
4
4
  Summary: A quantizer for advanced developers to quantize converted AI Edge models.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
6
6
  Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
@@ -28,20 +28,20 @@ ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py,sha256=lpq1g2ayg3lCP
28
28
  ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py,sha256=Bs9CK7wZAw6jNaZ8xEtbwO2vM34VYXNZSMVWvxJo9nw,9297
29
29
  ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py,sha256=EqIHGEZ1LgUrTN7zf880RuAzEv3Qy7kgh5ivObJGHSo,22646
30
30
  ai_edge_quantizer/algorithms/uniform_quantize/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
31
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=-ugXQ4cZoVMrgOVs4m73ozI-49CRyT0YuKrLS5begW8,28297
32
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py,sha256=qMmKbWqxrCoVKbLKHn9WuCrGKPfHkEyU0Nmhokh8Qeo,2597
33
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py,sha256=Fk3s9Qy2A_hjUepFOUmTwIZ_wKYVPbdDX4eoP-eoAQU,8726
31
+ ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=NpZ-JvZt2OhpTqH7Z81YYVjzOX_pHoDCt8rr3VIXJUY,28665
32
+ ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py,sha256=GGf_n3wIeg3GB_eGsmyNJ0fTcxgpeMMbugTMRONK6TQ,3553
33
+ ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py,sha256=BDdn_uBZakfHyzdMJPKadsOqxqyC-s6W2ZzFH99L4fE,8652
34
34
  ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py,sha256=sT5eX5TLZEHTtPfnSkCPDlS0sQxlTFWbCsbvOuj--yY,8889
35
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py,sha256=f9HhFCAavbrdYkQQH37ivbKRuRXC1g1TO2FmILMApN8,12389
36
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py,sha256=kN9aCPt1yTleiDBiH4g2RZ1vMBm7WAf5pmVFjmYCH-0,7617
37
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py,sha256=Divlsn3NjNGtH0vlvE91wxL-VHb4q1nUE0JTDGiEtYc,8572
35
+ ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py,sha256=pN4hwggrdI4eBdqvsdwnFagFxpd4D8LkWK0o4HG_xxk,12536
36
+ ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py,sha256=MajG6DqpP4HvVzcZwgiKojWL3RBxCpkU3u2mKyeB0hA,9191
37
+ ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py,sha256=8_tNLTbOWTKId4DfHBjkOR9RvELUyIpxlGxKu7tv5Ko,7556
38
38
  ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py,sha256=zoF_EHjYqsKkuev8wfuutIITEmp_maa70IpJI_Df3ck,7431
39
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py,sha256=e5wYtki-vl739gSVAZHAKcs2hA87GvFUjVoSUPlnkyM,6433
40
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py,sha256=IcTOaJ1pxtqsitqxOEP9LROVEP_19VFutHalqNied4I,6940
41
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py,sha256=WmZzKQlzfu9gFr9SbUDoPY3rFqTl363om8-0rTLwotw,11629
42
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py,sha256=G2PFpHhF-6OOuAwQ1lei63QEIm7uzIZJ62qpgA02qTM,12288
39
+ ai_edge_quantizer/algorithms/uniform_quantize/octav.py,sha256=Umxh4kJyeHddZf-Wd4aXE5MTI1XWFa5KRuM17uYU714,6922
40
+ ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py,sha256=sha1d99Xk87bI87tgz0g5LeDC-EeE4WMfM5rRC98-m4,9140
41
+ ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py,sha256=W2QbXP96xeleAmA7qFwco1iq_bOtArGDK6Qj_g6kNl8,15986
42
+ ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py,sha256=MgG7Qh2_z4I6InBqEEDSVlaR0q48aMz4xqAlxeG2EMk,12436
43
43
  ai_edge_quantizer/algorithms/utils/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
44
- ai_edge_quantizer/algorithms/utils/common_utils.py,sha256=4qSlVNx3-91kJufnnJV1RdVRXBPapylZkrAp2nywoao,34581
44
+ ai_edge_quantizer/algorithms/utils/common_utils.py,sha256=UoZxeAQmZk3b3hK51KFwq6XfdbeduXVjdYIxAxlAzB8,34982
45
45
  ai_edge_quantizer/algorithms/utils/common_utils_test.py,sha256=zqapGEfYhjQWe9cNGPLmdbwtEUUYQRhlO_kNe0cXX6E,18104
46
46
  ai_edge_quantizer/transformations/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
47
47
  ai_edge_quantizer/transformations/dequant_insert.py,sha256=sL1LHFVzBDSd9jgrzlHz38LWU0bwmVX7iBkaNcui0ts,3566
@@ -70,8 +70,8 @@ ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=WoewyiZpaua80oP0tpgyrw5W
70
70
  ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=6fjkM-rycZ95L4yfvlr0TN6RlrhfPzxNUYrZaYO_F0A,12013
71
71
  ai_edge_quantizer/utils/validation_utils.py,sha256=oYw33Sg547AqtGw-choPUJmp9SAKkV46J_ddqSsum2Q,3950
72
72
  ai_edge_quantizer/utils/validation_utils_test.py,sha256=V_qNDikPD4OPB-siOLQCWNVWTAu87h2IgNYt7teFd-o,2934
73
- ai_edge_quantizer_nightly-0.1.0.dev20250513.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
74
- ai_edge_quantizer_nightly-0.1.0.dev20250513.dist-info/METADATA,sha256=zL_JxmjzCHEwIUmLkDGzI6B7IACt6YnVQSpaxaNUujY,1528
75
- ai_edge_quantizer_nightly-0.1.0.dev20250513.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
76
- ai_edge_quantizer_nightly-0.1.0.dev20250513.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
77
- ai_edge_quantizer_nightly-0.1.0.dev20250513.dist-info/RECORD,,
73
+ ai_edge_quantizer_nightly-0.1.0.dev20250515.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
74
+ ai_edge_quantizer_nightly-0.1.0.dev20250515.dist-info/METADATA,sha256=Rwa9ls9ryiTwntWB8-SCfO_uYjWMj3bqPTjEhIiQMyo,1528
75
+ ai_edge_quantizer_nightly-0.1.0.dev20250515.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
76
+ ai_edge_quantizer_nightly-0.1.0.dev20250515.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
77
+ ai_edge_quantizer_nightly-0.1.0.dev20250515.dist-info/RECORD,,