ai-edge-quantizer-nightly 0.0.1.dev20250115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/__init__.py +19 -0
  2. ai_edge_quantizer/algorithm_manager.py +167 -0
  3. ai_edge_quantizer/algorithm_manager_api.py +271 -0
  4. ai_edge_quantizer/algorithm_manager_api_test.py +210 -0
  5. ai_edge_quantizer/algorithms/__init__.py +15 -0
  6. ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py +15 -0
  7. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py +273 -0
  8. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +664 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/__init__.py +15 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +666 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +184 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +371 -0
  13. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +357 -0
  14. ai_edge_quantizer/algorithms/utils/__init__.py +15 -0
  15. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py +1067 -0
  16. ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py +512 -0
  17. ai_edge_quantizer/calibrator.py +288 -0
  18. ai_edge_quantizer/calibrator_test.py +297 -0
  19. ai_edge_quantizer/conftest.py +22 -0
  20. ai_edge_quantizer/default_policy.py +310 -0
  21. ai_edge_quantizer/model_modifier.py +176 -0
  22. ai_edge_quantizer/model_modifier_test.py +130 -0
  23. ai_edge_quantizer/model_validator.py +357 -0
  24. ai_edge_quantizer/model_validator_test.py +354 -0
  25. ai_edge_quantizer/params_generator.py +361 -0
  26. ai_edge_quantizer/params_generator_test.py +1041 -0
  27. ai_edge_quantizer/qtyping.py +483 -0
  28. ai_edge_quantizer/quantizer.py +372 -0
  29. ai_edge_quantizer/quantizer_test.py +532 -0
  30. ai_edge_quantizer/recipe.py +67 -0
  31. ai_edge_quantizer/recipe_manager.py +245 -0
  32. ai_edge_quantizer/recipe_manager_test.py +815 -0
  33. ai_edge_quantizer/recipe_test.py +97 -0
  34. ai_edge_quantizer/transformation_instruction_generator.py +584 -0
  35. ai_edge_quantizer/transformation_instruction_generator_test.py +1082 -0
  36. ai_edge_quantizer/transformation_performer.py +278 -0
  37. ai_edge_quantizer/transformation_performer_test.py +344 -0
  38. ai_edge_quantizer/transformations/__init__.py +15 -0
  39. ai_edge_quantizer/transformations/dequant_insert.py +87 -0
  40. ai_edge_quantizer/transformations/dequant_insert_test.py +304 -0
  41. ai_edge_quantizer/transformations/emulated_subchannel.py +363 -0
  42. ai_edge_quantizer/transformations/emulated_subchannel_test.py +212 -0
  43. ai_edge_quantizer/transformations/quant_insert.py +100 -0
  44. ai_edge_quantizer/transformations/quant_insert_test.py +284 -0
  45. ai_edge_quantizer/transformations/quantize_tensor.py +156 -0
  46. ai_edge_quantizer/transformations/quantize_tensor_test.py +227 -0
  47. ai_edge_quantizer/transformations/transformation_utils.py +132 -0
  48. ai_edge_quantizer/transformations/transformation_utils_test.py +162 -0
  49. ai_edge_quantizer/utils/__init__.py +15 -0
  50. ai_edge_quantizer/utils/calibration_utils.py +86 -0
  51. ai_edge_quantizer/utils/calibration_utils_test.py +77 -0
  52. ai_edge_quantizer/utils/test_utils.py +107 -0
  53. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +317 -0
  54. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +200 -0
  55. ai_edge_quantizer/utils/tfl_interpreter_utils.py +312 -0
  56. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +332 -0
  57. ai_edge_quantizer/utils/validation_utils.py +125 -0
  58. ai_edge_quantizer/utils/validation_utils_test.py +87 -0
  59. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/LICENSE +201 -0
  60. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/METADATA +32 -0
  61. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/RECORD +63 -0
  62. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/WHEEL +5 -0
  63. ai_edge_quantizer_nightly-0.0.1.dev20250115.dist-info/top_level.txt +1 -0
@@ -0,0 +1,184 @@
+ # Copyright 2024 The AI Edge Quantizer Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import os
+
+ from absl.testing import parameterized
+ import numpy as np
+
+ from tensorflow.python.platform import googletest
+ from ai_edge_quantizer import default_policy
+ from ai_edge_quantizer import qtyping
+ from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
+ from ai_edge_quantizer.utils import test_utils
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+
+ _TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("../../tests/models")
+ _TFLOpName = qtyping.TFLOperationName
+ _TensorQuantConfig = qtyping.TensorQuantizationConfig
+
+
+ class NaiveMinMaxQuantizeTest(parameterized.TestCase):
+   """Tests for general functions in the naive min-max quantize algorithm.
+
+   See naive_min_max_quantize_op_tests for op-specific tests.
+   """
+
+   def setUp(self):
+     super().setUp()
+     np.random.seed(666)
+     self._test_model_path = os.path.join(
+         _TEST_DATA_PREFIX_PATH, "conv_fc_mnist.tflite"
+     )
+     self._test_model = tfl_flatbuffer_utils.read_model(self._test_model_path)
+     # The test model has one subgraph for now.
+     self._graph_info = qtyping.GraphInfo(
+         subgraph_tensors=self._test_model.subgraphs[0].tensors,
+         buffers=self._test_model.buffers,
+     )
+     self._tensor_name_to_qsv = {}
+
+   @parameterized.parameters(
+       (qtyping.QuantGranularity.TENSORWISE),
+       (qtyping.QuantGranularity.CHANNELWISE),
+   )
+   def test_init_qsvs(self, granularity):
+     # Read from Model Explorer.
+     subgraph0 = self._test_model.subgraphs[0]
+     subgraph_op_index = 3
+     fc_op = subgraph0.operators[subgraph_op_index]
+     op_info = qtyping.OpInfo(
+         op=fc_op,
+         op_name=_TFLOpName.FULLY_CONNECTED,
+         subgraph_op_index=subgraph_op_index,
+         op_quant_config=qtyping.OpQuantizationConfig(
+             weight_tensor_config=_TensorQuantConfig(
+                 8,
+                 symmetric=True,
+                 granularity=granularity,
+             ),
+         ),
+     )
+
+     initial_qsvs = naive_min_max_quantize.init_qsvs(
+         op_info,
+         self._graph_info,
+     )
+     self.assertIn("sequential/flatten/Reshape", initial_qsvs)
+     input_tensor_qsv = initial_qsvs["sequential/flatten/Reshape"]
+     self.assertEmpty(input_tensor_qsv)
+     self.assertIn(
+         "sequential/dense/MatMul;sequential/dense/Relu;sequential/dense/BiasAdd",
+         initial_qsvs,
+     )
+     output_tensor_qsv = initial_qsvs[
+         "sequential/dense/MatMul;sequential/dense/Relu;sequential/dense/BiasAdd"
+     ]
+     self.assertEmpty(output_tensor_qsv)
+
+     self.assertIn("arith.constant1", initial_qsvs)
+     weight_tensor_qsv = initial_qsvs["arith.constant1"]
+     if granularity is qtyping.QuantGranularity.CHANNELWISE:
+       mins_maxs_shape = (32, 1)
+     else:
+       mins_maxs_shape = (1, 1)
+     self.assertTupleEqual(weight_tensor_qsv["min"].shape, mins_maxs_shape)
+     self.assertTupleEqual(weight_tensor_qsv["max"].shape, mins_maxs_shape)
+
+     self.assertIn("arith.constant2", initial_qsvs)
+     bias_tensor_qsv = initial_qsvs["arith.constant2"]
+     if granularity is qtyping.QuantGranularity.CHANNELWISE:
+       mins_maxs_shape = (32,)
+     else:
+       mins_maxs_shape = (1,)
+     self.assertTupleEqual(bias_tensor_qsv["min"].shape, mins_maxs_shape)
+     self.assertTupleEqual(bias_tensor_qsv["max"].shape, mins_maxs_shape)
+
+     initial_qsvs = naive_min_max_quantize.init_qsvs(
+         op_info,
+         self._graph_info,
+         inputs_to_ignore=[0],
+         outputs_to_ignore=[0],
+     )
+     self.assertNotIn("sequential/flatten/Reshape", initial_qsvs)
+     self.assertNotIn(
+         "sequential/dense/MatMul;sequential/dense/Relu;sequential/dense/BiasAdd",
+         initial_qsvs,
+     )
+
+   def test_min_max_calibrate(self):
+     # Sample input/output data for the fc op.
+     tensor_content_map = {
+         "sequential/flatten/Reshape": np.array([[1, 2, 3, 4, 5]]),
+         "sequential/dense/MatMul;sequential/dense/Relu;sequential/dense/BiasAdd": np.array(
+             [[6, 7, 8, 9, 10]]
+         ),
+     }
+     # Read from Model Explorer.
+     subgraph0 = self._test_model.subgraphs[0]
+     fc_op = subgraph0.operators[3]
+     # Ignore 1 (weight) and 2 (bias) in inputs.
+     op_qsvs = naive_min_max_quantize.min_max_calibrate(
+         fc_op,
+         self._graph_info,
+         tensor_content_map,
+         [1, 2],
+         [],
+     )
+     self.assertIn("sequential/flatten/Reshape", op_qsvs)
+     input_tensor_qsv = op_qsvs["sequential/flatten/Reshape"]
+     self.assertTupleEqual(input_tensor_qsv["min"].shape, (1, 1))
+     self.assertEqual(input_tensor_qsv["min"], np.array([[1]]))
+     self.assertTupleEqual(input_tensor_qsv["max"].shape, (1, 1))
+     self.assertEqual(input_tensor_qsv["max"], np.array([[5]]))
+     self.assertIn(
+         "sequential/dense/MatMul;sequential/dense/Relu;sequential/dense/BiasAdd",
+         op_qsvs,
+     )
+     output_tensor_qsv = op_qsvs[
+         "sequential/dense/MatMul;sequential/dense/Relu;sequential/dense/BiasAdd"
+     ]
+     self.assertTupleEqual(output_tensor_qsv["min"].shape, (1, 1))
+     self.assertEqual(output_tensor_qsv["min"], np.array([[6]]))
+     self.assertTupleEqual(output_tensor_qsv["max"].shape, (1, 1))
+     self.assertEqual(output_tensor_qsv["max"], np.array([[10]]))
+     # Weight and bias are excluded.
+     self.assertNotIn("arith.constant1", op_qsvs)
+     self.assertNotIn("arith.constant2", op_qsvs)
+
+   def test_check_op_quantization_config_with_negative_min_weight_elements_raises_error(
+       self,
+   ):
+     op_quant_config = qtyping.OpQuantizationConfig(
+         weight_tensor_config=_TensorQuantConfig(
+             num_bits=8,
+             granularity=qtyping.QuantGranularity.CHANNELWISE,
+         ),
+         compute_precision=qtyping.ComputePrecision.INTEGER,  # DRQ.
+         min_weight_elements=-1,
+     )
+     with self.assertRaisesWithPredicateMatch(
+         ValueError,
+         lambda err: "min_weight_elements must be non-negative" in str(err),
+     ):
+       naive_min_max_quantize.check_op_quantization_config(
+           _TFLOpName.FULLY_CONNECTED,
+           op_quant_config,
+           default_policy.DEFAULT_CONFIG_CHECK_POLICY,
+       )
+
+
+ if __name__ == "__main__":
+   googletest.main()
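
The QSVs ("quantization statistical values") asserted above are plain per-tensor dicts of min/max arrays whose shapes track the chosen granularity. A minimal numpy-only sketch of the tensorwise case, with a hypothetical helper name that is not part of the package:

    import numpy as np

    def tensorwise_qsv(tensor: np.ndarray) -> dict[str, np.ndarray]:
      # keepdims preserves rank, giving min/max of shape (1, 1) for a 2D
      # tensor -- the shape test_min_max_calibrate asserts.
      return {
          "min": np.min(tensor, keepdims=True),
          "max": np.max(tensor, keepdims=True),
      }

    qsv = tensorwise_qsv(np.array([[1, 2, 3, 4, 5]]))
    assert qsv["min"][0, 0] == 1 and qsv["max"][0, 0] == 5
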
@@ -0,0 +1,371 @@
+ # Copyright 2024 The AI Edge Quantizer Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Uniform quantize in tensor level."""
+
+ import dataclasses
+ import numpy as np
+ from ai_edge_quantizer import qtyping
+
+
+ @dataclasses.dataclass(frozen=True)
+ class IntType:
+   num_bits: int
+   signed: bool
+
+
+ def get_quantized_range(qtype: IntType) -> tuple[float, float]:
+   """Calculates range of the quantized type."""
+   if qtype.signed:
+     qmax = 2 ** (qtype.num_bits - 1) - 1
+     qmin = -(2 ** (qtype.num_bits - 1))
+   else:
+     qmax = (2**qtype.num_bits) - 1
+     qmin = 0
+   return float(qmin), float(qmax)
+
+
+ def _round_and_clip(
+     tensor: np.ndarray, qtype: IntType, narrow: bool
+ ) -> np.ndarray:
+   """Round and clip the tensor to the given type, but don't cast it."""
+   qmin, qmax = get_quantized_range(qtype)
+   if narrow:
+     if qtype.signed:
+       return np.clip(
+           np.rint(tensor),
+           qmin + 1,
+           qmax,
+       )
+     else:
+       raise ValueError("Unsigned data type should not have narrow range.")
+   else:
+     return np.clip(np.rint(tensor), qmin, qmax)
+
+
+ def assign_quantized_type(tensor: np.ndarray, qtype: IntType) -> np.ndarray:
+   """Cast the tensor to the quantized type."""
+   if qtype.num_bits <= 8:
+     qtype = np.int8 if qtype.signed else np.uint8
+   elif qtype.num_bits <= 16:
+     qtype = np.int16 if qtype.signed else np.uint16
+   elif qtype.num_bits <= 32:
+     qtype = np.int32 if qtype.signed else np.uint32
+   else:
+     qtype = np.int64 if qtype.signed else np.uint64
+   return tensor.astype(qtype)
+
+
+ def fix_quantization_params_rank(
+     tensor_data: np.ndarray,
+     quantization_params: qtyping.UniformQuantParams,
+ ) -> qtyping.UniformQuantParams:
+   """Fix the rank of quantization parameters (scale/zero points).
+
+   Scale and zero points need to be the same rank as tensor_data to avoid
+   ambiguous broadcasting.
+
+   Args:
+     tensor_data: The tensor to be quantized.
+     quantization_params: The quantization parameters.
+
+   Returns:
+     quantization_params with broadcasted scales and zero_points.
+   """
+   scales, zero_points = (
+       quantization_params.scale,
+       quantization_params.zero_point,
+   )
+   if tensor_data.ndim == scales.ndim:
+     return quantization_params
+
+   if tensor_data.ndim == 0:
+     # Scalar tensor requires scalar scale and zero_point.
+     if scales.size != 1 or zero_points.size != 1:
+       raise ValueError(
+           "Scale and zero_point must contain single element for scalar tensor."
+           f" Got scale: {scales}, zero_point: {zero_points}"
+       )
+     scales = np.array(scales.item())
+     zero_points = np.array(zero_points.item())
+   else:
+     dims = [
+         dim
+         for dim in range(tensor_data.ndim)
+         if dim != quantization_params.quantized_dimension
+     ]
+     scales = np.expand_dims(scales, axis=dims)
+     zero_points = np.expand_dims(zero_points, axis=dims)
+
+   return qtyping.UniformQuantParams(
+       scale=scales,
+       zero_point=zero_points,
+       num_bits=quantization_params.num_bits,
+       symmetric=quantization_params.symmetric,
+       quantized_dimension=quantization_params.quantized_dimension,
+       quantized_data=quantization_params.quantized_data,
+   )
+
+
+ def uniform_quantize_for_emulated_subchannel(
+     tensor_data: np.ndarray,
+     quantization_params: qtyping.UniformQuantParams,
+     block_size: int,
+ ) -> np.ndarray:
+   """Uniform quantize a tensor for emulated subchannel.
+
+   Emulation involves reshaping the tensor and quantizing values along a
+   different axis. Hence, we use a different quantization function.
+
+   Args:
+     tensor_data: The tensor to be quantized.
+     quantization_params: The quantization parameters.
+     block_size: The block size of the emulated subchannel.
+
+   Returns:
+     The quantized tensor.
+   """
+   scales, zero_points = (
+       quantization_params.scale,
+       quantization_params.zero_point,
+   )
+   transposed_and_reshaped_tensor = np.reshape(
+       np.transpose(tensor_data, (1, 0)),
+       (
+           1,
+           int(tensor_data.shape[1] / block_size),
+           block_size,
+           tensor_data.shape[0],
+       ),
+   )
+   inverse_scales = 1.0 / scales
+   qtype = IntType(quantization_params.num_bits, signed=True)
+   # Symmetric means narrow range (e.g., -127 to 127)
+   narrow_range = quantization_params.symmetric
+   required_dtype = np.signedinteger if qtype.signed else np.unsignedinteger
+   if not np.issubdtype(zero_points.dtype, required_dtype):
+     raise ValueError(
+         f"zero_points need to be {required_dtype}."
+         f" But the actual type is {zero_points.dtype}."
+     )
+   ret = (
+       np.multiply(transposed_and_reshaped_tensor, inverse_scales) + zero_points
+   )
+   ret = _round_and_clip(ret, qtype, narrow_range)
+   ret = assign_quantized_type(ret, qtype)
+   return ret
+
+
+ def uniform_quantize(
+     tensor_data: np.ndarray,
+     quantization_params: qtyping.UniformQuantParams,
+ ):
+   """Uniform quantize a tensor.
+
+   Args:
+     tensor_data: The tensor to be quantized.
+     quantization_params: The quantization parameters.
+
+   Returns:
+     The quantized tensor.
+   """
+   # Quant params in the flatbuffer are flattened; expand the rank to match
+   # the tensor rank and avoid ambiguous broadcasting.
+   quantization_params = fix_quantization_params_rank(
+       tensor_data, quantization_params
+   )
+   _is_valid_quantization_params(tensor_data, quantization_params)
+   scales, zero_points = (
+       quantization_params.scale,
+       quantization_params.zero_point,
+   )
+   inverse_scales = 1.0 / scales
+   # TODO: b/332574603 - support unsigned data type.
+   qtype = IntType(quantization_params.num_bits, signed=True)
+   # Symmetric means narrow range (e.g., -127 to 127)
+   narrow_range = quantization_params.symmetric
+   required_dtype = np.signedinteger if qtype.signed else np.unsignedinteger
+   if not np.issubdtype(zero_points.dtype, required_dtype):
+     raise ValueError(
+         f"zero_points need to be {required_dtype}."
+         f" But the actual type is {zero_points.dtype}."
+     )
+   ret = np.multiply(tensor_data, inverse_scales) + zero_points
+   ret = _round_and_clip(ret, qtype, narrow_range)
+   ret = assign_quantized_type(ret, qtype)
+   return ret
+
+
+ def uniform_dequantize(
+     tensor_data: np.ndarray,
+     quantization_params: qtyping.UniformQuantParams,
+ ):
+   """Uniform dequantize a tensor.
+
+   Args:
+     tensor_data: The tensor to be dequantized.
+     quantization_params: The quantization parameters.
+
+   Returns:
+     The dequantized tensor.
+   """
+   # Quant params in the flatbuffer are flattened; expand the rank to match
+   # the tensor rank and avoid ambiguous broadcasting.
+   quantization_params = fix_quantization_params_rank(
+       tensor_data, quantization_params
+   )
+   _is_valid_quantization_params(tensor_data, quantization_params)
+   return np.multiply(
+       tensor_data - quantization_params.zero_point, quantization_params.scale
+   )
+
+
+ def symmetric_quantize_bias_tensor(
+     bias_content: np.ndarray,
+     input_tensor_quant_params: qtyping.UniformQuantParams,
+     weight_tensor_quant_params: qtyping.UniformQuantParams,
+ ) -> qtyping.UniformQuantParams:
+   """Quantize bias tensor (symmetrically, i.e., zero_point = 0).
+
+   We quantize bias to a much higher bit width, e.g., int32 for int8 weights. We
+   can afford the cost of being symmetric all the time. This configuration fits
+   TFL kernel designs.
+
+   Args:
+     bias_content: The bias content.
+     input_tensor_quant_params: The quantization parameters of input tensor.
+     weight_tensor_quant_params: The quantization parameters of weight tensor.
+
+   Returns:
+     The quantized bias tensor.
+   """
+   input_tensor_scale = input_tensor_quant_params.scale
+   weight_tensor_scale = weight_tensor_quant_params.scale
+   # Bias is always 1D, make sure the scale has 1D shape as well.
+   effective_output_scale = np.squeeze(input_tensor_scale * weight_tensor_scale)
+   # Squeeze can produce scalar, but we want 1D tensor.
+   if not effective_output_scale.shape:
+     effective_output_scale = np.expand_dims(effective_output_scale, axis=0)
+
+   # symmetric
+   bias_zp = np.zeros_like(effective_output_scale, dtype=np.int32)
+   bias_number_bits = 64 if input_tensor_quant_params.num_bits == 16 else 32
+   symmetric = True
+   quantized_dimension = None if len(effective_output_scale) == 1 else 0
+   bias_quant_params = qtyping.UniformQuantParams(
+       scale=effective_output_scale,
+       zero_point=bias_zp,
+       num_bits=bias_number_bits,
+       symmetric=symmetric,
+       quantized_dimension=quantized_dimension,
+   )
+
+   quantized_vars = uniform_quantize(bias_content, bias_quant_params)
+
+   # UniformQuantParams is frozen dataclass, need to recreate.
+   return qtyping.UniformQuantParams(
+       scale=effective_output_scale,
+       zero_point=bias_zp,
+       num_bits=bias_number_bits,
+       quantized_dimension=quantized_dimension,
+       symmetric=symmetric,
+       quantized_data=quantized_vars,
+   )
+
+
+ def tensor_zp_scale_from_min_max(
+     min_value, max_value, num_bits: int, symmetric: bool
+ ):
+   """Get zero point and scale from min and max value.
+
+   Args:
+     min_value: The minimum value of the tensor (channel-wise supported).
+     max_value: The maximum value of the tensor (channel-wise supported).
+     num_bits: The number of bits of the tensor.
+     symmetric: Whether the tensor is symmetric.
+
+   Returns:
+     The zero point and scale of the tensor.
+   """
+   # TODO: b/332574603 - support unsigned data type.
+   qtype = IntType(
+       num_bits,
+       signed=True,
+   )
+   qmin, qmax = get_quantized_range(qtype)
+   min_bound = 1e-4  # 1e-6 precision for int8 and 1e-8 for int16.
+
+   if symmetric:
+     bound = np.maximum(np.abs(min_value), np.abs(max_value))
+     bound = np.maximum(bound, min_bound)
+     if not qtype.signed:
+       half_q = (qmax - 1) / 2
+       scale = bound / half_q
+       zp = np.ones_like(scale) * (half_q + 1)
+     else:
+       scale = bound / qmax
+       zp = np.zeros_like(scale, dtype=np.int32)
+
+   else:
+     # Include 0 to the range to support zero-padding.
+     # See: https://arxiv.org/pdf/1712.05877.pdf
+     # This ensures bound_min <= 0 <= bound_max.
+     bound_max = np.maximum(max_value, np.zeros_like(max_value))
+     bound_min = np.minimum(min_value, np.zeros_like(min_value))
+     bound = np.maximum(bound_max - bound_min, min_bound)
+     scale = bound / (qmax - qmin)
+     zp = qmin - bound_min / scale
+     zp = np.rint(zp)
+
+   # It's safe to cast zp to qtype without clipping because we can infer
+   # qmin <= zp <= qmax from bound_min <= 0 <= bound_max.
+   zp = assign_quantized_type(zp, qtype)
+   return zp, scale
+
+
+ def _is_valid_quantization_params(
+     tensor_data: np.ndarray,
+     quantization_params: qtyping.UniformQuantParams,
+ ) -> None:
+   """Checks if the quantization parameters are valid.
+
+   Valid quantization parameters require:
+   1. scale and zero point have the same shape (TFL Runtime requirement).
+   2. scale and zero point have the same rank as the tensor content (avoid
+      ambiguous broadcasting).
+
+   Args:
+     tensor_data: The tensor to be quantized.
+     quantization_params: The quantization parameters.
+
+   Raises:
+     ValueError: If the quantization parameters are invalid.
+ """
356
+ if quantization_params.scale.shape != quantization_params.zero_point.shape:
357
+ raise ValueError(
358
+ "scale and zero_point must have the same shape. Got"
359
+ f" {quantization_params.scale.shape} and"
360
+ f" {quantization_params.zero_point.shape}"
361
+ )
362
+
363
+ tensor_rank = tensor_data.ndim
364
+ scale_rank = quantization_params.scale.ndim
365
+ zero_point_rank = quantization_params.zero_point.ndim
366
+ if (tensor_rank != scale_rank) or (tensor_rank != zero_point_rank):
367
+ raise ValueError(
368
+ f"Ranks of scales ({scale_rank}) and zps"
369
+ f" ({zero_point_rank}) must be the same as the tensor rank"
370
+ f" ({tensor_rank})."
371
+ )
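
To make the asymmetric branch of tensor_zp_scale_from_min_max concrete, here is the arithmetic for int8 (qmin = -128, qmax = 127) with min = -1.0 and max = 3.0; the numbers simply mirror the code above:

    import numpy as np

    bound_min, bound_max = -1.0, 3.0            # the range already straddles 0
    bound = max(bound_max - bound_min, 1e-4)    # 4.0
    scale = bound / (127.0 - (-128.0))          # 4 / 255 ~= 0.015686
    zp = np.rint(-128.0 - bound_min / scale)    # rint(-64.25) = -64.0
    # zp is then cast to int8; qmin <= -64 <= qmax, so no clipping is needed.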
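
A hedged usage sketch of uniform_quantize / uniform_dequantize built from those parameters. The UniformQuantParams field names are taken from the constructions in this file, and the import path from the file list above; note that scale and zero_point carry the same rank as the tensor, and zero_point uses a signed integer dtype, as the validation above requires:

    import numpy as np
    from ai_edge_quantizer import qtyping
    from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor

    data = np.array([[-1.0, 0.5, 3.0]], dtype=np.float32)
    params = qtyping.UniformQuantParams(
        scale=np.array([[4.0 / 255.0]]),                 # from the example above
        zero_point=np.array([[-64]], dtype=np.int32),
        num_bits=8,
        symmetric=False,
        quantized_dimension=None,
    )
    q = uniform_quantize_tensor.uniform_quantize(data, params)    # int8 values
    deq = uniform_quantize_tensor.uniform_dequantize(q, params)   # ~= data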
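
Finally, the bias path in symmetric_quantize_bias_tensor reduces to multiplying the input and weight scales, then quantizing symmetrically at the higher bit width. A small per-channel illustration with invented values:

    import numpy as np

    input_scale = np.array([0.02])               # per-tensor input scale
    weight_scale = np.array([0.001, 0.002])      # per-channel weight scales
    bias = np.array([0.05, -0.01])
    # Effective bias scale is input_scale * weight_scale, kept 1D like the bias.
    effective_scale = np.squeeze(input_scale * weight_scale)   # [2e-05, 4e-05]
    q_bias = np.rint(bias / effective_scale).astype(np.int32)  # [2500, -250]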