ai-edge-quantizer-nightly 0.1.0.dev20250512__py3-none-any.whl → 0.1.0.dev20250514__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the respective public registries.
- ai_edge_quantizer/algorithm_manager.py +34 -0
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +37 -12
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
- ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +3 -5
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +357 -0
- ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +265 -0
- ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +7 -31
- ai_edge_quantizer/algorithms/uniform_quantize/octav.py +27 -17
- ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +93 -38
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +133 -3
- ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +11 -2
- ai_edge_quantizer/algorithms/utils/common_utils.py +21 -8
- ai_edge_quantizer/default_policy.py +4 -2
- ai_edge_quantizer/params_generator.py +1 -0
- ai_edge_quantizer/qtyping.py +34 -1
- ai_edge_quantizer/transformation_performer.py +5 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation.py +209 -0
- ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
- ai_edge_quantizer/utils/test_utils.py +33 -0
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +1 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/METADATA +1 -1
- {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/RECORD +25 -21
- {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/WHEEL +0 -0
- {ai_edge_quantizer_nightly-0.1.0.dev20250512.dist-info → ai_edge_quantizer_nightly-0.1.0.dev20250514.dist-info}/top_level.txt +0 -0
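The headline change in this release is a new Hadamard rotation quantization algorithm (hadamard_rotation.py, the insert_hadamard_rotation transformation, and the INSERT_HADAMARD_ROTATION transformation enum used by the tests below). As a quick orientation before the raw diff, here is a minimal sketch of how the new weight-quantization entry point is exercised; it simply mirrors the new hadamard_rotation_test.py shown below, and the model path and op index are assumptions carried over from that test rather than a documented public recipe API.

import numpy as np

from ai_edge_quantizer import qtyping
from ai_edge_quantizer.algorithms.uniform_quantize import hadamard_rotation
from ai_edge_quantizer.utils import tfl_flatbuffer_utils

# Assumed model path and op index, copied from hadamard_rotation_test.py;
# op 3 of conv_fc_mnist.tflite is a FULLY_CONNECTED op in that test model.
model = tfl_flatbuffer_utils.read_model("tests/models/conv_fc_mnist.tflite")
subgraph = model.subgraphs[0]
fc_op = subgraph.operators[3]

op_info = qtyping.OpInfo(
    op=fc_op,
    op_name=qtyping.TFLOperationName.FULLY_CONNECTED,
    subgraph_op_index=3,
    op_quant_config=qtyping.OpQuantizationConfig(
        weight_tensor_config=qtyping.TensorQuantizationConfig(
            num_bits=8,
            symmetric=True,
            granularity=qtyping.QuantGranularity.CHANNELWISE,
        ),
    ),
)

# Read the FC weight buffer as a float32 array.
weight_tensor = subgraph.tensors[fc_op.inputs[1]]
weights = np.frombuffer(
    model.buffers[weight_tensor.buffer].data, dtype=np.float32
).reshape(weight_tensor.shape)

# Rotate-then-quantize the weights; the last argument (the QSV map) must be
# None because this is a weight-only path (see test_raise_qsv_set below).
qparams = hadamard_rotation.get_tensor_quant_params(
    op_info,
    op_info.op_quant_config.weight_tensor_config,
    weights,
    None,
)
print(qparams.num_bits, qparams.hadamard.hadamard_size)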
ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py (new file)
@@ -0,0 +1,265 @@
+# Copyright 2024 The AI Edge Quantizer Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Test Hadamard rotation materialization."""
+
+import os
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.platform import googletest
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.algorithms.uniform_quantize import hadamard_rotation
+from ai_edge_quantizer.utils import test_utils
+from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+
+_TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("../../tests/models")
+_TFLOpName = qtyping.TFLOperationName
+_TensorQuantConfig = qtyping.TensorQuantizationConfig
+
+
+class HadamardRotationFullyConnectedTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    np.random.seed(888)
+    self._test_model_path = os.path.join(
+        _TEST_DATA_PREFIX_PATH, "conv_fc_mnist.tflite"
+    )
+    self._test_model = tfl_flatbuffer_utils.read_model(self._test_model_path)
+    self._graph_info = qtyping.GraphInfo(
+        subgraph_tensors=self._test_model.subgraphs[0].tensors,
+        buffers=self._test_model.buffers,
+    )
+    self._tensor_name_to_qsv = None
+    self._subgraph = self._test_model.subgraphs[0]
+    fc_subgraph_op_index = 3
+    self._fc_op = self._subgraph.operators[fc_subgraph_op_index]
+    self._fc_buffer_id = self._subgraph.tensors[self._fc_op.inputs[1]].buffer
+    self._op_info = qtyping.OpInfo(
+        op=self._fc_op,
+        op_name=_TFLOpName.FULLY_CONNECTED,
+        subgraph_op_index=fc_subgraph_op_index,
+        op_quant_config=qtyping.OpQuantizationConfig(
+            weight_tensor_config=_TensorQuantConfig(
+                num_bits=8,
+                symmetric=True,
+                granularity=qtyping.QuantGranularity.CHANNELWISE,
+            ),
+        ),
+    )
+
+  def test_materialize_fully_connected_basic(self):
+    params = hadamard_rotation.materialize_fully_connected(
+        self._op_info, self._graph_info, self._tensor_name_to_qsv
+    )
+    fc_input = params[0]
+    weight = params[1]
+    bias = params[2]
+    output = params[3]
+
+    self.assertLen(params, 4)
+    self.assertIsNone(fc_input.producer)
+    self.assertIsNotNone(fc_input.consumers)
+    self.assertIsNone(weight.producer)
+    self.assertIsNotNone(weight.consumers)
+    self.assertIsNone(bias.producer)
+    self.assertIsNotNone(bias.consumers)
+    self.assertIsNotNone(output.producer)
+    self.assertIsNone(output.consumers)
+    self.assertEqual(
+        fc_input.consumers[0].transformations,
+        [qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION],
+    )
+    self.assertEqual(
+        weight.consumers[0].transformations,
+        [qtyping.QuantTransformation.QUANTIZE_TENSOR],
+    )
+    self.assertEqual(
+        bias.consumers[0].transformations,
+        [qtyping.QuantTransformation.NO_QUANTIZE],
+    )
+    if output.producer is not None:
+      self.assertEqual(
+          output.producer.transformations,
+          [qtyping.QuantTransformation.NO_QUANTIZE],
+      )
+
+  def test_get_tensor_quant_params_basic(self):
+    input_tensor = self._subgraph.tensors[self._fc_op.inputs[1]]
+    buffer = self._graph_info.buffers[self._fc_buffer_id]
+    np_buffer = np.frombuffer(buffer.data, dtype=np.float32).reshape(
+        input_tensor.shape
+    )
+    qparams = hadamard_rotation.get_tensor_quant_params(
+        self._op_info,
+        self._op_info.op_quant_config.weight_tensor_config,
+        np_buffer,
+        self._tensor_name_to_qsv,
+    )
+    self.assertEqual(qparams.num_bits, 8)
+    self.assertEqual(qparams.zero_point.all(), 0)
+    self.assertEqual(qparams.symmetric, True)
+    self.assertIsNotNone(qparams.quantized_data)
+    self.assertEqual(qparams.block_size, 0)
+    self.assertIsNotNone(qparams.hadamard)
+    if qparams.hadamard is not None:
+      self.assertEqual(qparams.hadamard.hadamard_size, 32)
+
+  def test_get_tensor_quant_params_golden_1(self):
+    test_data = np.ones((6, 6))
+    # expected:
+    # [[127 0 127 0 127 0]
+    #  [127 0 127 0 127 0]
+    #  [127 0 127 0 127 0]
+    #  [127 0 127 0 127 0]
+    #  [127 0 127 0 127 0]
+    #  [127 0 127 0 127 0]]
+    expected = np.tile([127, 0], [6, 3])
+    qparams = hadamard_rotation.get_tensor_quant_params(
+        self._op_info,
+        self._op_info.op_quant_config.weight_tensor_config,
+        test_data,
+        self._tensor_name_to_qsv,
+    )
+    self.assertIsNotNone(qparams.quantized_data)
+    np.testing.assert_array_equal(
+        np.array(qparams.quantized_data), expected
+    )
+
+  def test_get_tensor_quant_params_golden_2(self):
+    # test_data:
+    # [[1 2 1 2 1 2]
+    #  [3 4 3 4 3 4]
+    #  [1 2 1 2 1 2]
+    #  [3 4 3 4 3 4]
+    #  [1 2 1 2 1 2]
+    #  [3 4 3 4 3 4]]
+    test_data = np.tile([[1, 2], [3, 4]], [3, 3])
+    # expected:
+    # [[127 -42 127 -42 127 -42]
+    #  [127 -18 127 -18 127 -18]
+    #  [127 -42 127 -42 127 -42]
+    #  [127 -18 127 -18 127 -18]
+    #  [127 -42 127 -42 127 -42]
+    #  [127 -18 127 -18 127 -18]]
+    expected = np.tile([[127, -42], [127, -18]], [3, 3])
+    qparams = hadamard_rotation.get_tensor_quant_params(
+        self._op_info,
+        self._op_info.op_quant_config.weight_tensor_config,
+        test_data,
+        self._tensor_name_to_qsv,
+    )
+    self.assertIsNotNone(qparams.quantized_data)
+    np.testing.assert_array_equal(
+        np.array(qparams.quantized_data), expected
+    )
+
+  def test_raise_missing_tensor_content(self):
+    with self.assertRaisesWithPredicateMatch(
+        ValueError, lambda err: "weight tensor" in str(err)
+    ):
+      hadamard_rotation.get_tensor_quant_params(
+          self._op_info,
+          self._op_info.op_quant_config.weight_tensor_config,
+          None,
+          self._tensor_name_to_qsv,
+      )
+
+  def test_raise_qsv_set(self):
+    with self.assertRaisesWithPredicateMatch(
+        ValueError, lambda err: "static quantization" in str(err)
+    ):
+      hadamard_rotation.get_tensor_quant_params(
+          self._op_info,
+          self._op_info.op_quant_config.weight_tensor_config,
+          self._graph_info.buffers[self._fc_buffer_id],
+          self._graph_info.buffers[self._fc_buffer_id],
+      )
+
+  def test_raise_non_2d_constant(self):
+    with self.assertRaisesWithPredicateMatch(
+        ValueError, lambda err: "2D tensors" in str(err)
+    ):
+      hadamard_rotation.get_tensor_quant_params(
+          self._op_info,
+          self._op_info.op_quant_config.weight_tensor_config,
+          np.array([1.0, 2.0, 3.0]),
+          self._tensor_name_to_qsv,
+      )
+
+
+class HadamardRotationEmbeddingLookupTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    np.random.seed(888)
+    self._test_model_path = os.path.join(
+        _TEST_DATA_PREFIX_PATH, "embedding_lookup.tflite"
+    )
+    self._test_model = tfl_flatbuffer_utils.read_model(self._test_model_path)
+    self._graph_info = qtyping.GraphInfo(
+        subgraph_tensors=self._test_model.subgraphs[0].tensors,
+        buffers=self._test_model.buffers,
+    )
+    self._tensor_name_to_qsv = None
+
+  def test_materialize_embedding_lookup_basic(self):
+    subgraph = self._test_model.subgraphs[0]
+    embedding_subgraph_op_index = 0
+    embedding_op = subgraph.operators[embedding_subgraph_op_index]
+    op_info = qtyping.OpInfo(
+        op=embedding_op,
+        op_name=_TFLOpName.EMBEDDING_LOOKUP,
+        subgraph_op_index=embedding_subgraph_op_index,
+        op_quant_config=qtyping.OpQuantizationConfig(
+            weight_tensor_config=_TensorQuantConfig(
+                num_bits=8,
+                symmetric=True,
+                granularity=qtyping.QuantGranularity.CHANNELWISE,
+            ),
+        ),
+    )
+    params = hadamard_rotation.materialize_embedding_lookup(
+        op_info, self._graph_info, self._tensor_name_to_qsv
+    )
+    self.assertLen(params, 3)
+    lookup = params[0]
+    value = params[1]
+    output = params[2]
+    self.assertIsNone(lookup.producer)
+    self.assertIsNotNone(lookup.consumers)
+    self.assertIsNone(value.producer)
+    self.assertIsNotNone(value.consumers)
+    self.assertIsNotNone(output.producer)
+    self.assertIsNone(output.consumers)
+    self.assertEqual(
+        lookup.consumers[0].transformations,
+        [qtyping.QuantTransformation.NO_QUANTIZE],
+    )
+    self.assertEqual(
+        value.consumers[0].transformations,
+        [qtyping.QuantTransformation.QUANTIZE_TENSOR],
+    )
+    if output.producer is not None:
+      self.assertEqual(
+          output.producer.transformations,
+          [qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION],
+      )
+
+
+if __name__ == "__main__":
+  googletest.main()
ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py
@@ -16,7 +16,6 @@
 """Performs naive min/max uniform quantization."""
 
 from typing import Any, Optional
-import ml_dtypes
 import numpy as np
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer.algorithms.uniform_quantize import common_quantize
@@ -75,35 +74,17 @@ def get_tensor_quant_params(
         " the ParamsGenerator."
     )
   clipping_values = None
-  # Blockwise quantization uses float16 scale, with 7 bit mantissa,
-  # so the maximum representable value is 65280.
-  if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
-    clipping_values = np.broadcast_to(
-        np.array(65280), tensor_min_max["min"].shape
-    )
   zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
       tensor_min_max["min"],
       tensor_min_max["max"],
       tensor_quant_config.num_bits,
       tensor_quant_config.symmetric,
+      tensor_quant_config.granularity,
      clipping_values,
   )
-
-
-
-      scale.astype(ml_dtypes.bfloat16).astype(np.float16).astype(np.float32)
-  )
-  quantized_dim = None
-  if tensor_quant_config.granularity == qtyping.QuantGranularity.CHANNELWISE:
-    quantized_dim = common_utils.get_weight_quantized_dim(
-        op_info, tensor_content
-    )
-  elif tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
-    quantized_dim = (
-        tfl_flatbuffer_utils.TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM[
-            op_info.op_name
-        ]
-    )
+  quantized_dim = common_utils.get_weight_quantized_dim(
+      op_info, tensor_content, tensor_quant_config.granularity
+  )
   quant_params = qtyping.UniformQuantParams(
       scale=scale,
       zero_point=zp,
@@ -115,15 +96,10 @@ def get_tensor_quant_params(
   if tensor_content is None:
     return quant_params
 
-  # The reshaping for blockwise quantization is unique hence we do this here
-  # to avoid unexpected broadcast behavior downstream.
-  if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
-    quant_params = common_quantize.broadcast_scale_zp_for_blockwise(
-        tensor_content, quant_params
-    )
-
   quantized_vars = uniform_quantize_tensor.uniform_quantize(
-      tensor_content,
+      tensor_content,
+      quant_params,
+      tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE,
   )
   # Update with quantized values.
   return qtyping.UniformQuantParams(
ai_edge_quantizer/algorithms/uniform_quantize/octav.py
@@ -102,21 +102,13 @@ def get_tensor_quant_params(
       op_info, tensor_quant_config, tensor_content, tensor_qsv
   )
 
-  if (
-      tensor_quant_config.granularity != qtyping.QuantGranularity.CHANNELWISE
-      and tensor_quant_config.granularity != qtyping.QuantGranularity.TENSORWISE
-  ):
-    raise ValueError(
-        f"Unsupported granularity: {tensor_quant_config.granularity}."
-    )
-
   if not tensor_quant_config.symmetric:
     raise ValueError(
         f"Unsupported symmetry: {tensor_quant_config.symmetric}. OCTAV"
         " supports symmetric quantization only for now."
     )
 
-  if tensor_qsv
+  if not tensor_qsv:
     # We need min/max to calculate quantization parameters, which
     # should be collected during the calibration process. However,
     # weight-only and DRQ do not require calibration, thus it is
@@ -136,25 +128,41 @@ def get_tensor_quant_params(
         " the ParamsGenerator."
     )
 
-  quantized_dim =
-
-
-
+  quantized_dim = common_utils.get_weight_quantized_dim(
+      op_info, tensor_content, tensor_quant_config.granularity
+  )
+  if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
+    reshaped_data, reduce_dims = (
+        uniform_quantize_tensor.reshape_data_for_blockwise(
+            tensor_content,
+            op_info.op_name,
+            tensor_quant_config.block_size,
+        )
+    )
+  else:
+    reshaped_data = tensor_content
+    reduce_dims = common_utils.get_reduce_dims(
+        quantized_dim, tensor_content.shape
    )
-
  clipping_constants = _guess_clipping_with_octav(
-
+      reshaped_data,
      tensor_quant_config.num_bits,
-
+      reduce_dims,
      max_iterations=10,
      exponent_divisor=3.0 if tensor_quant_config.symmetric else 12.0,
  )
+  # We created a new dimension in order to reduce properly for blockwise
+  # quantization, so we need to reshape the clipping constants back to the
+  # min/max shape for the next step.
+  if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
+    clipping_constants = clipping_constants.reshape(tensor_min_max["min"].shape)
 
   zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
       tensor_min_max["min"],
       tensor_min_max["max"],
       tensor_quant_config.num_bits,
       tensor_quant_config.symmetric,
+      tensor_quant_config.granularity,
       clipping_constants,
   )
 
@@ -168,7 +176,9 @@ def get_tensor_quant_params(
   )
 
   quantized_vars = uniform_quantize_tensor.uniform_quantize(
-      tensor_content,
+      tensor_content,
+      quant_params,
+      tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE,
   )
 
   return dataclasses.replace(quant_params, quantized_data=quantized_vars)
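The octav.py change above extends the OCTAV clipping-constant search to blockwise granularity by reshaping the weights so that each block gets its own reduction axis. As an illustration of what that grouping means, here is a standalone numpy sketch of the concept only (not the library's reshape_data_for_blockwise implementation): a (32, 128) weight matrix with block_size=32 along dimension 1 yields one scale per block, i.e. a (32, 4) scale tensor, matching the shapes asserted in the new blockwise test below.

import numpy as np

# Conceptual sketch: group a (32, 128) matrix into blocks of 32 along dim 1
# and compute one symmetric int4 scale per block.
weights = np.random.randn(32, 128).astype(np.float32)
block_size = 32
num_blocks = weights.shape[1] // block_size                 # 4 blocks per row
blocked = weights.reshape(32, num_blocks, block_size)       # shape (32, 4, 32)
max_q = 2 ** (4 - 1) - 1                                    # 7 for symmetric int4
scales = np.abs(blocked).max(axis=-1) / max_q               # shape (32, 4)
quantized = np.round(blocked / scales[..., np.newaxis]).reshape(weights.shape)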
ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py
@@ -44,33 +44,17 @@ class OctavQuantizeTest(parameterized.TestCase):
     )
     self._tensor_name_to_qsv = {}
     subgraph0 = self._test_model.subgraphs[0]
-
-
+    self._subgraph_op_index = 3
+    self._fc_op = subgraph0.operators[self._subgraph_op_index]
     self._fc_op_info = qtyping.OpInfo(
-        op=
+        op=self._fc_op,
         op_name=qtyping.TFLOperationName.FULLY_CONNECTED,
-        subgraph_op_index=
+        subgraph_op_index=self._subgraph_op_index,
         op_quant_config=qtyping.OpQuantizationConfig(
             weight_tensor_config=None,
         ),
     )
 
-  def test_get_tensor_quant_params_unsupported_granularity_assert(self):
-    err_msg = "Unsupported granularity"
-    test_data = np.array([[-7, 7], [4, -4], [4, -4], [7, 7]])
-    with self.assertRaisesWithPredicateMatch(
-        ValueError, lambda err: err_msg in str(err)
-    ):
-      _ = octav.get_tensor_quant_params(
-          op_info=self._fc_op_info,
-          tensor_quant_config=qtyping.TensorQuantizationConfig(
-              num_bits=4,
-              symmetric=True,
-              granularity=qtyping.QuantGranularity.BLOCKWISE,
-          ),
-          tensor_content=test_data,
-      )
-
   def test_get_tensor_quant_params_unsupported_symmetry(self):
     err_msg = "Unsupported symmetry"
     test_data = np.array([[-7, 7], [4, -4], [4, -4], [7, 7]])
@@ -117,13 +101,22 @@ class OctavQuantizeTest(parameterized.TestCase):
         [25, -30, 50, -75, 1e5, -125],
         [50, -60, 70, -80, 90, -100],
     ])
-
-
-
-
-
-
+    tensor_config = qtyping.TensorQuantizationConfig(
+        num_bits=4,
+        symmetric=True,
+        granularity=qtyping.QuantGranularity.TENSORWISE,
+    )
+    fc_op_info = qtyping.OpInfo(
+        op=self._fc_op,
+        op_name=qtyping.TFLOperationName.FULLY_CONNECTED,
+        subgraph_op_index=self._subgraph_op_index,
+        op_quant_config=qtyping.OpQuantizationConfig(
+            weight_tensor_config=tensor_config,
         ),
+    )
+    quant_params = octav.get_tensor_quant_params(
+        op_info=fc_op_info,
+        tensor_quant_config=tensor_config,
        tensor_content=test_data,
    )
    adjusted_test_data = quant_params.quantized_data * quant_params.scale
@@ -131,10 +124,10 @@ class OctavQuantizeTest(parameterized.TestCase):
     adjusted_max = np.max(np.abs(adjusted_test_data))
 
     # Check that some clipping occurred.
-    with self.subTest(name="
+    with self.subTest(name="CheckClipping"):
       self.assertLess(adjusted_max, real_max)
 
-    with self.subTest(name="
+    with self.subTest(name="CheckQuantParamsShapes"):
       self.assertEqual(quant_params.zero_point.shape, (1, 1))
       self.assertEqual(quant_params.scale.shape, (1, 1))
       self.assertIsNone(quant_params.quantized_dimension)
@@ -143,33 +136,47 @@ class OctavQuantizeTest(parameterized.TestCase):
           cast(np.ndarray, quant_params.quantized_data).shape, test_data.shape
       )
 
-    with self.subTest(name="
+    with self.subTest(name="CheckQuantParamsValues"):
       self.assertTrue(np.all(quant_params.zero_point == 0))
 
   def test_get_tensor_quant_params_sanity_channelwise(self):
+    # Test that the call generates quant params that are appropriately shaped,
+    # have some clipping, and correct config values without checking the
+    # actual values numerically.
     test_data = np.array([
         [-1e5, 25, -50, 75, -100, 125],
         [25, -30, 50, -75, 1e5, -125],
         [50, -60, 70, -80, 90, -100],
     ])
-
-
-
-
-
-
+    tensor_config = qtyping.TensorQuantizationConfig(
+        num_bits=4,
+        symmetric=True,
+        granularity=qtyping.QuantGranularity.CHANNELWISE,
+    )
+    fc_op_info = qtyping.OpInfo(
+        op=self._fc_op,
+        op_name=qtyping.TFLOperationName.FULLY_CONNECTED,
+        subgraph_op_index=self._subgraph_op_index,
+        op_quant_config=qtyping.OpQuantizationConfig(
+            weight_tensor_config=tensor_config,
         ),
+    )
+    quant_params = octav.get_tensor_quant_params(
+        op_info=fc_op_info,
+        tensor_quant_config=tensor_config,
        tensor_content=test_data,
    )
+    # Dequantize output to compare with the original test data.
    adjusted_test_data = quant_params.quantized_data * quant_params.scale
+
    for i, row in enumerate(test_data):
      real_max = np.max(np.abs(row))
      adjusted_max = np.max(np.abs(adjusted_test_data[i]))
      # Check that some clipping occurred.
-      with self.subTest(name="
+      with self.subTest(name="CheckClipping"):
        self.assertLess(adjusted_max, real_max)
 
-      with self.subTest(name="
+      with self.subTest(name="CheckQuantParamsShapes"):
        self.assertEqual(quant_params.zero_point.shape, (test_data.shape[0], 1))
        self.assertEqual(quant_params.scale.shape, (test_data.shape[0], 1))
        self.assertIsNotNone(quant_params.quantized_data)
@@ -177,10 +184,58 @@ class OctavQuantizeTest(parameterized.TestCase):
          cast(np.ndarray, quant_params.quantized_data).shape, test_data.shape
      )
 
-      with self.subTest(name="
+      with self.subTest(name="CheckQuantParamsValues"):
        self.assertTrue(np.all(quant_params.zero_point == 0))
        self.assertEqual(quant_params.quantized_dimension, 0)
 
+  def test_get_tensor_quant_params_sanity_blockwise(self):
+    # Test that the call generates quant params that are appropriately shaped,
+    # have some clipping, and correct config values without checking the
+    # actual values numerically.
+    test_data = np.random.randint(0, 1024, size=(32, 128))
+    tensor_config = qtyping.TensorQuantizationConfig(
+        num_bits=4,
+        symmetric=True,
+        granularity=qtyping.QuantGranularity.BLOCKWISE,
+        block_size=32,
+    )
+    fc_op_info = qtyping.OpInfo(
+        op=self._fc_op,
+        op_name=qtyping.TFLOperationName.FULLY_CONNECTED,
+        subgraph_op_index=self._subgraph_op_index,
+        op_quant_config=qtyping.OpQuantizationConfig(
+            weight_tensor_config=tensor_config,
+        ),
+    )
+    quant_params = octav.get_tensor_quant_params(
+        op_info=fc_op_info,
+        tensor_quant_config=tensor_config,
+        tensor_content=test_data,
+    )
+
+    with self.subTest(name="CheckQuantParamsShapes"):
+      # Check that quant params have appropriate shapes.
+      self.assertEqual(quant_params.zero_point.shape, (32, 4))
+      self.assertEqual(quant_params.scale.shape, (32, 4))
+      self.assertIsNotNone(quant_params.quantized_data)
+      self.assertTupleEqual(
+          cast(np.ndarray, quant_params.quantized_data).shape, test_data.shape
+      )
+
+    scales = np.repeat(quant_params.scale, 32, axis=1)
+    adjusted_test_data = quant_params.quantized_data * scales
+    for i, row in enumerate(test_data):
+      real_max = np.max(np.abs(row))
+      adjusted_max = np.max(np.abs(adjusted_test_data[i]))
+      # Check that some clipping occurred.
+      with self.subTest(name="CheckClipping"):
+        self.assertLess(adjusted_max, real_max)
+
+    with self.subTest(name="CheckQuantParamsValues"):
+      self.assertTrue(np.all(quant_params.zero_point == 0))
+      # See TFL_OP_TO_BLOCKWISE_WEIGHT_QUANTIZED_DIM.
+      self.assertEqual(quant_params.quantized_dimension, 1)
+
 
 if __name__ == "__main__":
   googletest.main()