ai-edge-quantizer-nightly 0.1.0.dev20250512__py3-none-any.whl → 0.1.0.dev20250514__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. ai_edge_quantizer/algorithm_manager.py +34 -0
  2. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +37 -12
  3. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  4. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +3 -5
  5. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +357 -0
  6. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +265 -0
  7. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +7 -31
  8. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +27 -17
  9. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +93 -38
  10. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +133 -3
  11. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +11 -2
  12. ai_edge_quantizer/algorithms/utils/common_utils.py +21 -8
  13. ai_edge_quantizer/default_policy.py +4 -2
  14. ai_edge_quantizer/params_generator.py +1 -0
  15. ai_edge_quantizer/qtyping.py +34 -1
  16. ai_edge_quantizer/transformation_performer.py +5 -0
  17. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +209 -0
  18. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  19. ai_edge_quantizer/utils/test_utils.py +33 -0
  20. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +1 -0
  21. {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/METADATA +1 -1
  22. {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/RECORD +25 -21
  23. {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/LICENSE +0 -0
  24. {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/WHEEL +0 -0
  25. {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py
@@ -0,0 +1,265 @@
+ # Copyright 2024 The AI Edge Quantizer Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Test Hadamard rotation materialization."""
+
+ import os
+
+ from absl.testing import parameterized
+ import numpy as np
+
+ from tensorflow.python.platform import googletest
+ from ai_edge_quantizer import qtyping
+ from ai_edge_quantizer.algorithms.uniform_quantize import hadamard_rotation
+ from ai_edge_quantizer.utils import test_utils
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+
+ _TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("../../tests/models")
+ _TFLOpName = qtyping.TFLOperationName
+ _TensorQuantConfig = qtyping.TensorQuantizationConfig
+
+
+ class HadamardRotationFullyConnectedTest(parameterized.TestCase):
+
+   def setUp(self):
+     super().setUp()
+     np.random.seed(888)
+     self._test_model_path = os.path.join(
+         _TEST_DATA_PREFIX_PATH, "conv_fc_mnist.tflite"
+     )
+     self._test_model = tfl_flatbuffer_utils.read_model(self._test_model_path)
+     self._graph_info = qtyping.GraphInfo(
+         subgraph_tensors=self._test_model.subgraphs[0].tensors,
+         buffers=self._test_model.buffers,
+     )
+     self._tensor_name_to_qsv = None
+     self._subgraph = self._test_model.subgraphs[0]
+     fc_subgraph_op_index = 3
+     self._fc_op = self._subgraph.operators[fc_subgraph_op_index]
+     self._fc_buffer_id = self._subgraph.tensors[self._fc_op.inputs[1]].buffer
+     self._op_info = qtyping.OpInfo(
+         op=self._fc_op,
+         op_name=_TFLOpName.FULLY_CONNECTED,
+         subgraph_op_index=fc_subgraph_op_index,
+         op_quant_config=qtyping.OpQuantizationConfig(
+             weight_tensor_config=_TensorQuantConfig(
+                 num_bits=8,
+                 symmetric=True,
+                 granularity=qtyping.QuantGranularity.CHANNELWISE,
+             ),
+         ),
+     )
+
+   def test_materialize_fully_connected_basic(self):
+     params = hadamard_rotation.materialize_fully_connected(
+         self._op_info, self._graph_info, self._tensor_name_to_qsv
+     )
+     fc_input = params[0]
+     weight = params[1]
+     bias = params[2]
+     output = params[3]
+
+     self.assertLen(params, 4)
+     self.assertIsNone(fc_input.producer)
+     self.assertIsNotNone(fc_input.consumers)
+     self.assertIsNone(weight.producer)
+     self.assertIsNotNone(weight.consumers)
+     self.assertIsNone(bias.producer)
+     self.assertIsNotNone(bias.consumers)
+     self.assertIsNotNone(output.producer)
+     self.assertIsNone(output.consumers)
+     self.assertEqual(
+         fc_input.consumers[0].transformations,
+         [qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION],
+     )
+     self.assertEqual(
+         weight.consumers[0].transformations,
+         [qtyping.QuantTransformation.QUANTIZE_TENSOR],
+     )
+     self.assertEqual(
+         bias.consumers[0].transformations,
+         [qtyping.QuantTransformation.NO_QUANTIZE],
+     )
+     if output.producer is not None:
+       self.assertEqual(
+           output.producer.transformations,
+           [qtyping.QuantTransformation.NO_QUANTIZE],
+       )
+
+   def test_get_tensor_quant_params_basic(self):
+     input_tensor = self._subgraph.tensors[self._fc_op.inputs[1]]
+     buffer = self._graph_info.buffers[self._fc_buffer_id]
+     np_buffer = np.frombuffer(buffer.data, dtype=np.float32).reshape(
+         input_tensor.shape
+     )
+     qparams = hadamard_rotation.get_tensor_quant_params(
+         self._op_info,
+         self._op_info.op_quant_config.weight_tensor_config,
+         np_buffer,
+         self._tensor_name_to_qsv,
+     )
+     self.assertEqual(qparams.num_bits, 8)
+     self.assertEqual(qparams.zero_point.all(), 0)
+     self.assertEqual(qparams.symmetric, True)
+     self.assertIsNotNone(qparams.quantized_data)
+     self.assertEqual(qparams.block_size, 0)
+     self.assertIsNotNone(qparams.hadamard)
+     if qparams.hadamard is not None:
+       self.assertEqual(qparams.hadamard.hadamard_size, 32)
+
+   def test_get_tensor_quant_params_golden_1(self):
+     test_data = np.ones((6, 6))
+     # expected:
+     # [[127 0 127 0 127 0]
+     #  [127 0 127 0 127 0]
+     #  [127 0 127 0 127 0]
+     #  [127 0 127 0 127 0]
+     #  [127 0 127 0 127 0]
+     #  [127 0 127 0 127 0]]
+     expected = np.tile([127, 0], [6, 3])
+     qparams = hadamard_rotation.get_tensor_quant_params(
+         self._op_info,
+         self._op_info.op_quant_config.weight_tensor_config,
+         test_data,
+         self._tensor_name_to_qsv,
+     )
+     self.assertIsNotNone(qparams.quantized_data)
+     np.testing.assert_array_equal(
+         np.array(qparams.quantized_data), expected
+     )
+
+   def test_get_tensor_quant_params_golden_2(self):
+     # test_data:
+     # [[1 2 1 2 1 2]
+     #  [3 4 3 4 3 4]
+     #  [1 2 1 2 1 2]
+     #  [3 4 3 4 3 4]
+     #  [1 2 1 2 1 2]
+     #  [3 4 3 4 3 4]]
+     test_data = np.tile([[1, 2], [3, 4]], [3, 3])
+     # expected:
+     # [[127 -42 127 -42 127 -42]
+     #  [127 -18 127 -18 127 -18]
+     #  [127 -42 127 -42 127 -42]
+     #  [127 -18 127 -18 127 -18]
+     #  [127 -42 127 -42 127 -42]
+     #  [127 -18 127 -18 127 -18]]
+     expected = np.tile([[127, -42], [127, -18]], [3, 3])
+     qparams = hadamard_rotation.get_tensor_quant_params(
+         self._op_info,
+         self._op_info.op_quant_config.weight_tensor_config,
+         test_data,
+         self._tensor_name_to_qsv,
+     )
+     self.assertIsNotNone(qparams.quantized_data)
+     np.testing.assert_array_equal(
+         np.array(qparams.quantized_data), expected
+     )
+
+   def test_raise_missing_tensor_content(self):
+     with self.assertRaisesWithPredicateMatch(
+         ValueError, lambda err: "weight tensor" in str(err)
+     ):
+       hadamard_rotation.get_tensor_quant_params(
+           self._op_info,
+           self._op_info.op_quant_config.weight_tensor_config,
+           None,
+           self._tensor_name_to_qsv,
+       )
+
+   def test_raise_qsv_set(self):
+     with self.assertRaisesWithPredicateMatch(
+         ValueError, lambda err: "static quantization" in str(err)
+     ):
+       hadamard_rotation.get_tensor_quant_params(
+           self._op_info,
+           self._op_info.op_quant_config.weight_tensor_config,
+           self._graph_info.buffers[self._fc_buffer_id],
+           self._graph_info.buffers[self._fc_buffer_id],
+       )
+
+   def test_raise_non_2d_constant(self):
+     with self.assertRaisesWithPredicateMatch(
+         ValueError, lambda err: "2D tensors" in str(err)
+     ):
+       hadamard_rotation.get_tensor_quant_params(
+           self._op_info,
+           self._op_info.op_quant_config.weight_tensor_config,
+           np.array([1.0, 2.0, 3.0]),
+           self._tensor_name_to_qsv,
+       )
+
+
+ class HadamardRotationEmbeddingLookupTest(parameterized.TestCase):
+
+   def setUp(self):
+     super().setUp()
+     np.random.seed(888)
+     self._test_model_path = os.path.join(
+         _TEST_DATA_PREFIX_PATH, "embedding_lookup.tflite"
+     )
+     self._test_model = tfl_flatbuffer_utils.read_model(self._test_model_path)
+     self._graph_info = qtyping.GraphInfo(
+         subgraph_tensors=self._test_model.subgraphs[0].tensors,
+         buffers=self._test_model.buffers,
+     )
+     self._tensor_name_to_qsv = None
+
+   def test_materialize_embedding_lookup_basic(self):
+     subgraph = self._test_model.subgraphs[0]
+     embedding_subgraph_op_index = 0
+     embedding_op = subgraph.operators[embedding_subgraph_op_index]
+     op_info = qtyping.OpInfo(
+         op=embedding_op,
+         op_name=_TFLOpName.EMBEDDING_LOOKUP,
+         subgraph_op_index=embedding_subgraph_op_index,
+         op_quant_config=qtyping.OpQuantizationConfig(
+             weight_tensor_config=_TensorQuantConfig(
+                 num_bits=8,
+                 symmetric=True,
+                 granularity=qtyping.QuantGranularity.CHANNELWISE,
+             ),
+         ),
+     )
+     params = hadamard_rotation.materialize_embedding_lookup(
+         op_info, self._graph_info, self._tensor_name_to_qsv
+     )
+     self.assertLen(params, 3)
+     lookup = params[0]
+     value = params[1]
+     output = params[2]
+     self.assertIsNone(lookup.producer)
+     self.assertIsNotNone(lookup.consumers)
+     self.assertIsNone(value.producer)
+     self.assertIsNotNone(value.consumers)
+     self.assertIsNotNone(output.producer)
+     self.assertIsNone(output.consumers)
+     self.assertEqual(
+         lookup.consumers[0].transformations,
+         [qtyping.QuantTransformation.NO_QUANTIZE],
+     )
+     self.assertEqual(
+         value.consumers[0].transformations,
+         [qtyping.QuantTransformation.QUANTIZE_TENSOR],
+     )
+     if output.producer is not None:
+       self.assertEqual(
+           output.producer.transformations,
+           [qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION],
+       )
+
+
+ if __name__ == "__main__":
+   googletest.main()
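
Note: the golden values in the two tests above can be derived by hand. Below is a minimal numpy sketch (not the quantizer's implementation) that reproduces them, assuming a block-diagonal 2x2 Hadamard rotation over column pairs followed by symmetric per-channel int8 quantization, which is consistent with the 6-column test inputs:

import numpy as np

def rotate_and_quantize(weights):
  # Rotate each column pair by the 2x2 Hadamard matrix [[1, 1], [1, -1]] / sqrt(2).
  h2 = np.array([[1.0, 1.0], [1.0, -1.0]]) / np.sqrt(2.0)
  rows, cols = weights.shape
  rotated = (weights.reshape(rows, cols // 2, 2) @ h2.T).reshape(rows, cols)
  # Symmetric per-channel (per-row) quantization: map max|row| to 127.
  scale = np.abs(rotated).max(axis=1, keepdims=True) / 127.0
  return np.round(rotated / scale).astype(np.int8)

print(rotate_and_quantize(np.ones((6, 6))))  # rows of [127, 0, ...]
print(rotate_and_quantize(np.tile([[1.0, 2.0], [3.0, 4.0]], [3, 3])))

For the all-ones input, each rotated pair is (sqrt(2), 0), which quantizes to [127, 0]; for the tiled input the rotated pairs are (3, -1)/sqrt(2) and (7, -1)/sqrt(2), giving the [127, -42] and [127, -18] rows in the golden comments.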
ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py
@@ -16,7 +16,6 @@
  """Performs naive min/max uniform quantization."""

  from typing import Any, Optional
- import ml_dtypes
  import numpy as np
  from ai_edge_quantizer import qtyping
  from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
@@ -75,35 +74,17 @@ def get_tensor_quant_params(
          " the ParamsGenerator."
      )
  clipping_values = None
- # Blockwise quantization uses float16 scale, with 7 bit mantissa,
- # so the maximum representable value is 65280.
- if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
-   clipping_values = np.broadcast_to(
-       np.array(65280), tensor_min_max["min"].shape
-   )
  zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
      tensor_min_max["min"],
      tensor_min_max["max"],
      tensor_quant_config.num_bits,
      tensor_quant_config.symmetric,
+     tensor_quant_config.granularity,
      clipping_values,
  )
- # Round the scale values to 7 bit mantissa.
- if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
-   scale = (
-       scale.astype(ml_dtypes.bfloat16).astype(np.float16).astype(np.float32)
-   )
- quantized_dim = None
- if tensor_quant_config.granularity == qtyping.QuantGranularity.CHANNELWISE:
-   quantized_dim = common_utils.get_weight_quantized_dim(
-       op_info, tensor_content
-   )
- elif tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
-   quantized_dim = (
-       tfl_flatbuffer_utils.TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM[
-           op_info.op_name
-       ]
-   )
+ quantized_dim = common_utils.get_weight_quantized_dim(
+     op_info, tensor_content, tensor_quant_config.granularity
+ )
  quant_params = qtyping.UniformQuantParams(
      scale=scale,
      zero_point=zp,
@@ -115,15 +96,10 @@ def get_tensor_quant_params(
  if tensor_content is None:
    return quant_params

- # The reshaping for blockwise quantization is unique hence we do this here
- # to avoid unexpected broadcast behavior downstream.
- if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
-   quant_params = common_quantize.broadcast_scale_zp_for_blockwise(
-       tensor_content, quant_params
-   )
-
  quantized_vars = uniform_quantize_tensor.uniform_quantize(
-     tensor_content, quant_params
+     tensor_content,
+     quant_params,
+     tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE,
  )
  # Update with quantized values.
  return qtyping.UniformQuantParams(
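
Note: this change removes the blockwise special cases from the min/max path; the float16-representable clipping cap, the 7-bit-mantissa scale rounding, and the scale/zero-point broadcast now live behind the shared helpers, keyed by the new granularity argument to tensor_zp_scale_from_min_max and the blockwise flag to uniform_quantize. For intuition, a minimal sketch of the symmetric min/max-to-scale step (illustrative only, not the library's exact code):

import numpy as np

def symmetric_zp_scale(min_val, max_val, num_bits):
  # Map the largest magnitude in [min, max] onto the top quantized level.
  qmax = 2 ** (num_bits - 1) - 1  # e.g. 127 for 8 bits
  scale = np.maximum(np.abs(min_val), np.abs(max_val)) / qmax
  zero_point = np.zeros_like(scale)  # symmetric quantization centers at zero
  return zero_point, scale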
ai_edge_quantizer/algorithms/uniform_quantize/octav.py
@@ -102,21 +102,13 @@ def get_tensor_quant_params(
      op_info, tensor_quant_config, tensor_content, tensor_qsv
  )

- if (
-     tensor_quant_config.granularity != qtyping.QuantGranularity.CHANNELWISE
-     and tensor_quant_config.granularity != qtyping.QuantGranularity.TENSORWISE
- ):
-   raise ValueError(
-       f"Unsupported granularity: {tensor_quant_config.granularity}."
-   )
-
  if not tensor_quant_config.symmetric:
    raise ValueError(
        f"Unsupported symmetry: {tensor_quant_config.symmetric}. OCTAV"
        " supports symmetric quantization only for now."
    )

- if tensor_qsv is None:
+ if not tensor_qsv:
    # We need min/max to calculate quantization parameters, which
    # should be collected during the calibration process. However,
    # weight-only and DRQ do not require calibration, thus it is
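
Note: switching from `tensor_qsv is None` to `not tensor_qsv` also treats an empty QSV dict like a missing one, so a caller passing {} now falls into the min/max initialization branch instead of skipping it. A quick illustration:

tensor_qsv = {}
print(tensor_qsv is None)  # False: the old check skipped initialization
print(not tensor_qsv)      # True: the new check initializes min/max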
@@ -136,25 +128,41 @@ def get_tensor_quant_params(
          " the ParamsGenerator."
      )

- quantized_dim = None
- if tensor_quant_config.granularity == qtyping.QuantGranularity.CHANNELWISE:
-   quantized_dim = common_utils.get_weight_quantized_dim(
-       op_info, tensor_content
+ quantized_dim = common_utils.get_weight_quantized_dim(
+     op_info, tensor_content, tensor_quant_config.granularity
+ )
+ if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
+   reshaped_data, reduce_dims = (
+       uniform_quantize_tensor.reshape_data_for_blockwise(
+           tensor_content,
+           op_info.op_name,
+           tensor_quant_config.block_size,
+       )
+   )
+ else:
+   reshaped_data = tensor_content
+   reduce_dims = common_utils.get_reduce_dims(
+       quantized_dim, tensor_content.shape
    )
-
  clipping_constants = _guess_clipping_with_octav(
-     tensor_content,
+     reshaped_data,
      tensor_quant_config.num_bits,
-     common_utils.get_reduce_dims(quantized_dim, tensor_content.shape),
+     reduce_dims,
      max_iterations=10,
      exponent_divisor=3.0 if tensor_quant_config.symmetric else 12.0,
  )
+ # We created a new dimension in order to reduce properly for blockwise
+ # quantization, so we need to reshape the clipping constants back to the
+ # min/max shape for the next step.
+ if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
+   clipping_constants = clipping_constants.reshape(tensor_min_max["min"].shape)

  zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
      tensor_min_max["min"],
      tensor_min_max["max"],
      tensor_quant_config.num_bits,
      tensor_quant_config.symmetric,
+     tensor_quant_config.granularity,
      clipping_constants,
  )

@@ -168,7 +176,9 @@ def get_tensor_quant_params(
  )

  quantized_vars = uniform_quantize_tensor.uniform_quantize(
-     tensor_content, quant_params
+     tensor_content,
+     quant_params,
+     tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE,
  )

  return dataclasses.replace(quant_params, quantized_data=quantized_vars)
ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py
@@ -44,33 +44,17 @@ class OctavQuantizeTest(parameterized.TestCase):
      )
      self._tensor_name_to_qsv = {}
      subgraph0 = self._test_model.subgraphs[0]
-     subgraph_op_index = 3
-     fc_op = subgraph0.operators[subgraph_op_index]
+     self._subgraph_op_index = 3
+     self._fc_op = subgraph0.operators[self._subgraph_op_index]
      self._fc_op_info = qtyping.OpInfo(
-         op=fc_op,
+         op=self._fc_op,
          op_name=qtyping.TFLOperationName.FULLY_CONNECTED,
-         subgraph_op_index=subgraph_op_index,
+         subgraph_op_index=self._subgraph_op_index,
          op_quant_config=qtyping.OpQuantizationConfig(
              weight_tensor_config=None,
          ),
      )

-   def test_get_tensor_quant_params_unsupported_granularity_assert(self):
-     err_msg = "Unsupported granularity"
-     test_data = np.array([[-7, 7], [4, -4], [4, -4], [7, 7]])
-     with self.assertRaisesWithPredicateMatch(
-         ValueError, lambda err: err_msg in str(err)
-     ):
-       _ = octav.get_tensor_quant_params(
-           op_info=self._fc_op_info,
-           tensor_quant_config=qtyping.TensorQuantizationConfig(
-               num_bits=4,
-               symmetric=True,
-               granularity=qtyping.QuantGranularity.BLOCKWISE,
-           ),
-           tensor_content=test_data,
-       )
-
    def test_get_tensor_quant_params_unsupported_symmetry(self):
      err_msg = "Unsupported symmetry"
      test_data = np.array([[-7, 7], [4, -4], [4, -4], [7, 7]])
@@ -117,13 +101,22 @@ class OctavQuantizeTest(parameterized.TestCase):
          [25, -30, 50, -75, 1e5, -125],
          [50, -60, 70, -80, 90, -100],
      ])
-     quant_params = octav.get_tensor_quant_params(
-         op_info=self._fc_op_info,
-         tensor_quant_config=qtyping.TensorQuantizationConfig(
-             num_bits=4,
-             symmetric=True,
-             granularity=qtyping.QuantGranularity.TENSORWISE,
+     tensor_config = qtyping.TensorQuantizationConfig(
+         num_bits=4,
+         symmetric=True,
+         granularity=qtyping.QuantGranularity.TENSORWISE,
+     )
+     fc_op_info = qtyping.OpInfo(
+         op=self._fc_op,
+         op_name=qtyping.TFLOperationName.FULLY_CONNECTED,
+         subgraph_op_index=self._subgraph_op_index,
+         op_quant_config=qtyping.OpQuantizationConfig(
+             weight_tensor_config=tensor_config,
          ),
+     )
+     quant_params = octav.get_tensor_quant_params(
+         op_info=fc_op_info,
+         tensor_quant_config=tensor_config,
          tensor_content=test_data,
      )
      adjusted_test_data = quant_params.quantized_data * quant_params.scale
@@ -131,10 +124,10 @@
      adjusted_max = np.max(np.abs(adjusted_test_data))

      # Check that some clipping occurred.
-     with self.subTest(name="SanityCheckClipping"):
+     with self.subTest(name="CheckClipping"):
        self.assertLess(adjusted_max, real_max)

-     with self.subTest(name="SanityCheckQuantParamsShapes"):
+     with self.subTest(name="CheckQuantParamsShapes"):
        self.assertEqual(quant_params.zero_point.shape, (1, 1))
        self.assertEqual(quant_params.scale.shape, (1, 1))
        self.assertIsNone(quant_params.quantized_dimension)
@@ -143,33 +136,47 @@
          cast(np.ndarray, quant_params.quantized_data).shape, test_data.shape
      )

-     with self.subTest(name="SanityCheckQuantParamsValues"):
+     with self.subTest(name="CheckQuantParamsValues"):
        self.assertTrue(np.all(quant_params.zero_point == 0))

    def test_get_tensor_quant_params_sanity_channelwise(self):
+     # Test that the call generates quant params that are appropriately shaped,
+     # have some clipping, and correct config values without checking the
+     # actual values numerically.
      test_data = np.array([
          [-1e5, 25, -50, 75, -100, 125],
          [25, -30, 50, -75, 1e5, -125],
          [50, -60, 70, -80, 90, -100],
      ])
-     quant_params = octav.get_tensor_quant_params(
-         op_info=self._fc_op_info,
-         tensor_quant_config=qtyping.TensorQuantizationConfig(
-             num_bits=4,
-             symmetric=True,
-             granularity=qtyping.QuantGranularity.CHANNELWISE,
+     tensor_config = qtyping.TensorQuantizationConfig(
+         num_bits=4,
+         symmetric=True,
+         granularity=qtyping.QuantGranularity.CHANNELWISE,
+     )
+     fc_op_info = qtyping.OpInfo(
+         op=self._fc_op,
+         op_name=qtyping.TFLOperationName.FULLY_CONNECTED,
+         subgraph_op_index=self._subgraph_op_index,
+         op_quant_config=qtyping.OpQuantizationConfig(
+             weight_tensor_config=tensor_config,
          ),
+     )
+     quant_params = octav.get_tensor_quant_params(
+         op_info=fc_op_info,
+         tensor_quant_config=tensor_config,
          tensor_content=test_data,
      )
+     # Dequantize output to compare with the original test data.
      adjusted_test_data = quant_params.quantized_data * quant_params.scale
+
      for i, row in enumerate(test_data):
        real_max = np.max(np.abs(row))
        adjusted_max = np.max(np.abs(adjusted_test_data[i]))
        # Check that some clipping occurred.
-       with self.subTest(name="SanityCheckClipping"):
+       with self.subTest(name="CheckClipping"):
          self.assertLess(adjusted_max, real_max)

-     with self.subTest(name="SanityCheckQuantParamsShapes"):
+     with self.subTest(name="CheckQuantParamsShapes"):
        self.assertEqual(quant_params.zero_point.shape, (test_data.shape[0], 1))
        self.assertEqual(quant_params.scale.shape, (test_data.shape[0], 1))
        self.assertIsNotNone(quant_params.quantized_data)
@@ -177,10 +184,58 @@
          cast(np.ndarray, quant_params.quantized_data).shape, test_data.shape
      )

-     with self.subTest(name="SanityCheckQuantParamsValues"):
+     with self.subTest(name="CheckQuantParamsValues"):
        self.assertTrue(np.all(quant_params.zero_point == 0))
        self.assertEqual(quant_params.quantized_dimension, 0)

+   def test_get_tensor_quant_params_sanity_blockwise(self):
+     # Test that the call generates quant params that are appropriately shaped,
+     # have some clipping, and correct config values without checking the
+     # actual values numerically.
+     test_data = np.random.randint(0, 1024, size=(32, 128))
+     tensor_config = qtyping.TensorQuantizationConfig(
+         num_bits=4,
+         symmetric=True,
+         granularity=qtyping.QuantGranularity.BLOCKWISE,
+         block_size=32,
+     )
+     fc_op_info = qtyping.OpInfo(
+         op=self._fc_op,
+         op_name=qtyping.TFLOperationName.FULLY_CONNECTED,
+         subgraph_op_index=self._subgraph_op_index,
+         op_quant_config=qtyping.OpQuantizationConfig(
+             weight_tensor_config=tensor_config,
+         ),
+     )
+     quant_params = octav.get_tensor_quant_params(
+         op_info=fc_op_info,
+         tensor_quant_config=tensor_config,
+         tensor_content=test_data,
+     )
+
+     with self.subTest(name="CheckQuantParamsShapes"):
+       # Check that quant params have appropriate shapes.
+       self.assertEqual(quant_params.zero_point.shape, (32, 4))
+       self.assertEqual(quant_params.scale.shape, (32, 4))
+       self.assertIsNotNone(quant_params.quantized_data)
+       self.assertTupleEqual(
+           cast(np.ndarray, quant_params.quantized_data).shape, test_data.shape
+       )
+
+     scales = np.repeat(quant_params.scale, 32, axis=1)
+     adjusted_test_data = quant_params.quantized_data * scales
+     for i, row in enumerate(test_data):
+       real_max = np.max(np.abs(row))
+       adjusted_max = np.max(np.abs(adjusted_test_data[i]))
+       # Check that some clipping occurred.
+       with self.subTest(name="CheckClipping"):
+         self.assertLess(adjusted_max, real_max)
+
+     with self.subTest(name="CheckQuantParamsValues"):
+       self.assertTrue(np.all(quant_params.zero_point == 0))
+       # See TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM.
+       self.assertEqual(quant_params.quantized_dimension, 1)
+

  if __name__ == "__main__":
    googletest.main()
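
Note: _guess_clipping_with_octav itself is not part of this diff. For background, OCTAV (Sakr et al., 2022) chooses the clipping threshold with a fixed-point iteration; the sketch below is a scalar reading of that formula and only an assumption about the helper's internals (the exponent_divisor=3.0 argument above would correspond to the 4^-B / 3 term for symmetric quantization):

import numpy as np

def octav_clip(x, num_bits, iters=10):
  # Fixed-point iteration for the MSE-optimal clipping scalar s*:
  #   s <- sum(|x|, where |x| > s) /
  #        (4**-num_bits / 3 * count(|x| <= s) + count(|x| > s))
  s = np.abs(x).mean()  # any positive initialization works in practice
  for _ in range(iters):
    over = np.abs(x) > s
    s = np.abs(x[over]).sum() / (
        4.0 ** -num_bits / 3.0 * (~over).sum() + over.sum()
    )
  return s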