ai-edge-quantizer-nightly 0.0.1.dev20250302__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. ai_edge_quantizer/algorithm_manager.py +224 -0
  2. ai_edge_quantizer/algorithm_manager_api_test.py +7 -0
  3. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +643 -20
  5. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  6. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +29 -35
  7. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py +35 -12
  8. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  11. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  12. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +54 -168
  13. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +54 -17
  14. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +188 -0
  15. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +240 -0
  16. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +260 -13
  17. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +152 -5
  18. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -54
  19. ai_edge_quantizer/calibrator.py +58 -94
  20. ai_edge_quantizer/calibrator_test.py +5 -74
  21. ai_edge_quantizer/default_policy.py +108 -16
  22. ai_edge_quantizer/model_modifier.py +132 -8
  23. ai_edge_quantizer/model_modifier_test.py +81 -1
  24. ai_edge_quantizer/model_validator.py +38 -10
  25. ai_edge_quantizer/model_validator_test.py +2 -1
  26. ai_edge_quantizer/params_generator.py +230 -47
  27. ai_edge_quantizer/params_generator_test.py +366 -261
  28. ai_edge_quantizer/qtyping.py +92 -6
  29. ai_edge_quantizer/quantizer.py +167 -23
  30. ai_edge_quantizer/quantizer_test.py +288 -26
  31. ai_edge_quantizer/recipe.py +156 -21
  32. ai_edge_quantizer/recipe_manager.py +158 -1
  33. ai_edge_quantizer/recipe_manager_test.py +146 -32
  34. ai_edge_quantizer/recipe_test.py +93 -17
  35. ai_edge_quantizer/transformation_instruction_generator.py +313 -46
  36. ai_edge_quantizer/transformation_instruction_generator_test.py +449 -27
  37. ai_edge_quantizer/transformation_performer.py +112 -58
  38. ai_edge_quantizer/transformation_performer_test.py +176 -4
  39. ai_edge_quantizer/transformations/duplicate_buffer.py +46 -0
  40. ai_edge_quantizer/transformations/duplicate_buffer_test.py +106 -0
  41. ai_edge_quantizer/transformations/duplicate_tensor.py +62 -0
  42. ai_edge_quantizer/transformations/duplicate_tensor_test.py +131 -0
  43. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  44. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  45. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  46. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  47. ai_edge_quantizer/transformations/quantize_tensor.py +24 -44
  48. ai_edge_quantizer/transformations/quantize_tensor_test.py +3 -2
  49. ai_edge_quantizer/transformations/transformation_utils.py +157 -11
  50. ai_edge_quantizer/transformations/transformation_utils_test.py +96 -2
  51. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  52. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  53. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  54. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  55. ai_edge_quantizer/utils/test_utils.py +191 -58
  56. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +96 -50
  57. ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py +20 -0
  58. ai_edge_quantizer/utils/tfl_interpreter_utils.py +138 -5
  59. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  60. ai_edge_quantizer/utils/validation_utils.py +114 -4
  61. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  62. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +13 -3
  63. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  64. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  65. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  66. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  67. ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info/RECORD +0 -67
  68. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  69. {ai_edge_quantizer_nightly-0.0.1.dev20250302.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
@@ -31,8 +31,7 @@ _TensorQuantConfig = qtyping.TensorQuantizationConfig
31
31
 
32
32
 
33
33
  class CommonQuantizeTest(parameterized.TestCase):
34
- """Tests for general quantize functions.
35
- """
34
+ """Tests for general quantize functions."""
36
35
 
37
36
  def setUp(self):
38
37
  super().setUp()
@@ -69,6 +68,34 @@ class CommonQuantizeTest(parameterized.TestCase):
69
68
  default_policy.DEFAULT_CONFIG_CHECK_POLICY,
70
69
  )
71
70
 
71
+ def test_reshape_data_for_blockwise_raises_error_when_quantized_dim_not_divisible_by_block_size(
72
+ self,
73
+ ):
74
+ tensor_data = np.ones((24, 128), dtype=np.float32)
75
+ block_size = 256
76
+ quantized_dim = 1
77
+ with self.assertRaisesWithPredicateMatch(
78
+ ValueError,
79
+ lambda err: (
80
+ "Tensor quantization dimension must be divisible by block"
81
+ " size for blockwise quantization."
82
+ )
83
+ in str(err),
84
+ ):
85
+ common_quantize._reshape_data_for_blockwise(
86
+ tensor_data, quantized_dim, block_size
87
+ )
88
+
89
+ def test_reshape_data_for_blockwise_returns_correct_values(self):
90
+ tensor_data = np.ones((24, 128), dtype=np.float32)
91
+ block_size = 32
92
+ quantized_dim = 1
93
+ new_tensor_data, reduce_dim = common_quantize._reshape_data_for_blockwise(
94
+ tensor_data, quantized_dim, block_size
95
+ )
96
+ self.assertEqual(new_tensor_data.shape, (24, 4, 32))
97
+ self.assertEqual(reduce_dim, 2)
98
+
72
99
 
73
100
  if __name__ == "__main__":
74
101
  googletest.main()
@@ -70,17 +70,17 @@ def _get_scale(arr: np.ndarray, min_scale: float) -> float:
70
70
  return min_scale
71
71
 
72
72
 
73
- def get_zp_scale_from_2d_dequantized_symmetric_weights(
73
+ def get_zp_scale_from_dequantized_symmetric_weights(
74
74
  dequant_vals: np.ndarray,
75
75
  quantized_dimension: Optional[int] = None,
76
76
  min_scale: float = 1e-9,
77
77
  ) -> tuple[np.ndarray, np.ndarray]:
78
- """Calculates scale and zero point from 2D dequantized, symmetric weights.
78
+ """Calculates scale and zero point from dequantized and symmetric weights.
79
79
 
80
80
  Handles both per-tensor and per-channel (axis) quantization.
81
81
 
82
82
  Args:
83
- dequant_vals: The 2D dequantized weight values (numpy array).
83
+ dequant_vals: The dequantized weight values (numpy array).
84
84
  quantized_dimension: The dimension along which quantization was performed
85
85
  (0 or 1), or None for per-tensor quantization.
86
86
  min_scale: The minimum allowed scale value.
@@ -91,15 +91,9 @@ def get_zp_scale_from_2d_dequantized_symmetric_weights(
91
91
  - scales: Scales (scalar for per-tensor, array for per-channel).
92
92
 
93
93
  Raises:
94
- ValueError: If `dequant_vals` is not 2D, or if
95
- `quantized_dimension` is not 0, 1, or None.
94
+ ValueError: If `quantized_dimension` is not 0, 1, or None.
96
95
  """
97
96
 
98
- if dequant_vals.ndim != 2:
99
- raise ValueError(
100
- f"Only 2D weights are supported. Got {dequant_vals.ndim} dimensions."
101
- )
102
-
103
97
  if quantized_dimension not in (0, 1, None):
104
98
  raise ValueError(
105
99
  f"quantized_dimension must be 0, 1, or None. Got {quantized_dimension}"
@@ -112,23 +106,26 @@ def get_zp_scale_from_2d_dequantized_symmetric_weights(
112
106
  # Per-tensor quantization: One scale for the entire tensor.
113
107
  scales = _get_scale(dequant_vals.flatten(), min_scale)
114
108
  scales = np.array([[scales]])
115
-
116
109
  else:
117
110
  # Per-channel quantization: A scale for each slice along the dimension.
118
- scales = []
119
- for i in range(dequant_vals.shape[quantized_dimension]):
120
- if quantized_dimension == 0:
121
- vec = dequant_vals[i, :]
122
- else: # quantized_dimension == 1
123
- vec = dequant_vals[:, i]
124
- scales.append(_get_scale(vec, min_scale))
125
-
126
- # Reshape for correct broadcasting.
127
- scales = (
128
- np.array(scales).reshape(-1, 1)
129
- if quantized_dimension == 0
130
- else np.array(scales).reshape(1, -1)
111
+ # Create a broadcasted array for per-channel scales. It should have the same
112
+ # number of dimensions as the input, with 1 in all dimensions except for the
113
+ # quantized dimension, which retains its original size.
114
+ scales = np.empty(
115
+ tuple(
116
+ [
117
+ 1
118
+ if i != quantized_dimension
119
+ else dequant_vals.shape[quantized_dimension]
120
+ for i in range(dequant_vals.ndim)
121
+ ]
122
+ )
131
123
  )
124
+ for i in range(dequant_vals.shape[quantized_dimension]):
125
+ slices = [slice(None)] * dequant_vals.ndim
126
+ slices[quantized_dimension] = i
127
+ vec = dequant_vals[tuple(slices)]
128
+ scales[tuple(slices)] = _get_scale(vec, min_scale)
132
129
 
133
130
  zero_points = np.zeros_like(scales, dtype=np.int32)
134
131
  return zero_points, scales
@@ -153,7 +150,7 @@ def get_tensor_quant_params(
153
150
 
154
151
  Raises:
155
152
  ValueError: If the quantization granularity is blockwise, or if the tensor
156
- is not a 2D symmetric weight tensor.
153
+ is not a symmetric weight tensor.
157
154
  """
158
155
  # Fallback to naive_min_max_quantize.py for non-weight tensors.
159
156
  if tensor_content is None:
@@ -161,24 +158,21 @@ def get_tensor_quant_params(
161
158
  op_info, tensor_quant_config, tensor_content, tensor_qsv
162
159
  )
163
160
 
164
- if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
161
+ if uniform_quantize_tensor.is_blockwise(tensor_quant_config.granularity):
165
162
  raise ValueError(
166
163
  "Blockwise quantization is not supported for dequantized weight"
167
164
  " recovery."
168
165
  )
169
- if tensor_content.ndim != 2 or not tensor_quant_config.symmetric:
166
+ if not tensor_quant_config.symmetric:
170
167
  raise ValueError(
171
- "Only 2D symmetric weights are supported for dequantized weight"
172
- " recovery."
168
+ "Only symmetric weights are supported for dequantized weight recovery."
173
169
  )
174
170
 
175
- quantized_dim = None
176
- if tensor_quant_config.granularity == qtyping.QuantGranularity.CHANNELWISE:
177
- quantized_dim = common_utils.get_weight_quantized_dim(
178
- op_info, tensor_content
179
- )
171
+ quantized_dim = common_utils.get_weight_quantized_dim(
172
+ op_info, tensor_content, tensor_quant_config.granularity
173
+ )
180
174
 
181
- zp, scale = get_zp_scale_from_2d_dequantized_symmetric_weights(
175
+ zp, scale = get_zp_scale_from_dequantized_symmetric_weights(
182
176
  dequant_vals=tensor_content,
183
177
  quantized_dimension=quantized_dim,
184
178
  )
@@ -62,7 +62,7 @@ class DequantizedWeightRecoveryTest(parameterized.TestCase):
62
62
  ):
63
63
  dequant_vals = scale * self._dummy_quantized_weights
64
64
  zp, recovered_scale = (
65
- dequantized_weight_recovery.get_zp_scale_from_2d_dequantized_symmetric_weights(
65
+ dequantized_weight_recovery.get_zp_scale_from_dequantized_symmetric_weights(
66
66
  dequant_vals, quantized_dimension
67
67
  )
68
68
  )
@@ -72,17 +72,40 @@ class DequantizedWeightRecoveryTest(parameterized.TestCase):
72
72
  self.assertEqual(np.sum(zp), 0)
73
73
  self.assertEqual(zp.shape, scale.shape)
74
74
 
75
- def test_tensor_zp_scale_from_2d_dequantized_symmetric_weights_raises_error_for_non_2d_weights(
76
- self,
75
+ @parameterized.named_parameters(
76
+ dict(
77
+ testcase_name="per-tensor-recovery",
78
+ quantized_dimension=None,
79
+ scale=np.array([0.1875]).reshape(1, 1),
80
+ ),
81
+ dict(
82
+ testcase_name="channel0-recovery",
83
+ quantized_dimension=0,
84
+ scale=np.array([0.1875, 1e-4, 12.3]).reshape(3, 1, 1),
85
+ ),
86
+ dict(
87
+ testcase_name="channel1-recovery",
88
+ quantized_dimension=1,
89
+ scale=np.array([0.003, 1.234]).reshape(1, 2, 1),
90
+ ),
91
+ )
92
+ def test_tensor_zp_scale_from_3d_dequantized_symmetric_weights_success(
93
+ self, quantized_dimension, scale
77
94
  ):
78
- weights_3d = self._dummy_quantized_weights.reshape(1, 3, 4)
79
- weights_3d = weights_3d * 1.02
80
- with self.assertRaisesRegex(
81
- ValueError, "Only 2D weights are supported. Got 3 dimensions."
82
- ):
83
- dequantized_weight_recovery.get_zp_scale_from_2d_dequantized_symmetric_weights(
84
- weights_3d, quantized_dimension=None
85
- )
95
+ dequant_vals = scale * self._dummy_quantized_weights.reshape(3, 2, 2)
96
+ zp, recovered_scale = (
97
+ dequantized_weight_recovery.get_zp_scale_from_dequantized_symmetric_weights(
98
+ dequant_vals, quantized_dimension
99
+ )
100
+ )
101
+ with self.subTest("shapes_match"):
102
+ self.assertEqual(recovered_scale.shape, scale.shape)
103
+ self.assertEqual(zp.shape, scale.shape)
104
+ with self.subTest("scale_value_match"):
105
+ self.assertSequenceAlmostEqual(recovered_scale.flatten(), scale.flatten())
106
+ with self.subTest("zp_is_zero"):
107
+ # Zero point should be zero for symmetric quantization.
108
+ self.assertEqual(np.sum(zp), 0)
86
109
 
87
110
  @parameterized.named_parameters(
88
111
  dict(testcase_name="negative_dimension", quantized_dimension=-1),
@@ -95,7 +118,7 @@ class DequantizedWeightRecoveryTest(parameterized.TestCase):
95
118
  with self.assertRaisesRegex(
96
119
  ValueError, "quantized_dimension must be 0, 1, or None. Got"
97
120
  ):
98
- dequantized_weight_recovery.get_zp_scale_from_2d_dequantized_symmetric_weights(
121
+ dequantized_weight_recovery.get_zp_scale_from_dequantized_symmetric_weights(
99
122
  dequant_vals, quantized_dimension
100
123
  )
101
124
 
@@ -0,0 +1,414 @@
1
+ # Copyright 2024 The AI Edge Quantizer Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Implements the Hadamard Rotation quantization."""
17
+
18
+ from typing import Any, Optional
19
+ import numpy as np
20
+ from ai_edge_quantizer import qtyping
21
+ from ai_edge_quantizer.algorithms.uniform_quantize import octav
22
+ from ai_edge_quantizer.algorithms.utils import common_utils
23
+ from ai_edge_quantizer.utils import tfl_flatbuffer_utils
24
+
25
+
26
# Algorithm name for the Hadamard rotation variant whose name indicates it is
# inserted as a single custom op.
CUSTOM_OP_ALGORITHM_KEY: str = "HADAMARD_ROTATION"
# Algorithm name for the Hadamard rotation variant whose name indicates it is
# inserted as decomposed ops.
DECOMPOSED_ALGORITHM_KEY: str = "DECOMPOSED_HADAMARD_ROTATION"
28
+
29
+
30
def _make_hadamard_matrix(size: int) -> np.ndarray:
  """Generates a normalized Hadamard matrix of the given size.

  The matrix is built with the Sylvester construction: repeated Kronecker
  products of the 2x2 Hadamard matrix, starting from the 1x1 matrix [[1]].
  The result is scaled by 1/sqrt(size), which makes it orthonormal
  (`h @ h.T == I`).

  Args:
    size: The size of the Hadamard matrix. Must be a positive power of 2
      (1 == 2**0 is allowed). This represents a single dimension. E.g. if size
      is 4, then the Hadamard matrix is a 4x4 matrix.

  Returns:
    The normalized `size x size` Hadamard matrix.

  Raises:
    ValueError: If the size is not a power of 2.
  """
  if size <= 0 or (size & (size - 1)) != 0:
    raise ValueError("Hadamard matrix size must be a power of 2. ")
  h2 = np.array([[1, 1], [1, -1]])
  # Start from the 1x1 Hadamard matrix so that size == 1 yields [[1.0]].
  # Starting from h2 would return a 2x2 matrix for size == 1, which breaks
  # callers that rotate odd-sized dimensions (largest power-of-2 factor 1).
  h = np.array([[1]])
  current_size = 1
  while current_size < size:
    h = np.kron(h, h2)
    current_size *= 2
  return h / np.sqrt(size)
52
+
53
+
54
def _rotate_with_diagonal_hadamard(
    tensor_content: np.ndarray,
    axis: int,
):
  """Rotates `tensor_content` along `axis` with a block-diagonal Hadamard matrix.

  The rotated (last) dimension is split into contiguous chunks of
  `hadamard_size` elements, where `hadamard_size` is the largest power of 2
  that divides the dimension. Each chunk is multiplied by the same normalized
  Hadamard matrix. This is equivalent to multiplying by a block-diagonal
  matrix with the Hadamard matrix tiled along its diagonal.

  Args:
    tensor_content: The float array to rotate.
    axis: The axis of the tensor to rotate. Must be the last axis.

  Returns:
    A tuple `(rotated, hadamard_size, random_vector)`:
      rotated: the rotated array, with the same shape as `tensor_content`.
      hadamard_size: the size of the Hadamard block that was used.
      random_vector: an int8 vector of ones of length `hadamard_size`. No
        random sign flipping is applied here.

  Raises:
    ValueError: If the axis is not the last axis of tensor_content (to support
      other axes, the matrix multiplication needs to be generalized), or if
      the rotated dimension is empty.
  """
  if axis != tensor_content.ndim - 1:
    raise ValueError(
        "Hadamard rotation is only supported for tensors with quantized"
        " dimension 0 (rotate last dimension)."
    )
  dim_size = tensor_content.shape[axis]
  if dim_size == 0:
    # gcd(0, 2**30) == 2**30, which would attempt to build a 2**30 x 2**30
    # matrix; fail fast instead.
    raise ValueError(
        "Hadamard rotation requires a non-empty rotated dimension."
    )

  # Use the largest power of 2 that is a factor of the dimension and then
  # tile this Hadamard matrix along the diagonal. 2**30 is just a large power
  # of 2 to calculate this factor.
  hadamard_size = int(np.gcd(dim_size, 2**30))
  random_vector = np.ones(hadamard_size, dtype=np.int8)

  # Use a canonical Hadamard matrix.
  hadamard = _make_hadamard_matrix(hadamard_size)
  # Since `axis` is the last axis, each row of this 2-D view holds one chunk of
  # `hadamard_size` consecutive elements along the rotated dimension.
  chunks = tensor_content.reshape(-1, hadamard_size)
  # Same computation as (hadamard @ chunks^T)^T. Plain `.T` is used instead of
  # `ndarray.mT` (NumPy >= 2.0 only) because both operands are 2-D.
  w_rotated = np.matmul(hadamard, chunks.T).T
  return w_rotated.reshape(tensor_content.shape), hadamard_size, random_vector
92
+
93
+
94
def get_tensor_quant_params(
    op_info: qtyping.OpInfo,
    tensor_quant_config: qtyping.TensorQuantizationConfig,
    tensor_content: Optional[np.ndarray] = None,
    tensor_qsv: Optional[dict[str, Any]] = None,
) -> qtyping.UniformQuantParams:
  """Returns the quantization parameters for a weight tensor.

  This function rotates the tensor along its last axis with a block-diagonal
  Hadamard matrix and then quantizes the rotated tensor with OCTAV.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    tensor_quant_config: The quantization config for the tensor.
    tensor_content: The content of the weight tensor. Must not be None: only
      weight tensors are supported.
    tensor_qsv: Must be None. Static quantization (calibrated statistics) is
      not supported by this algorithm.

  Returns:
    The OCTAV quantization parameters of the rotated tensor, with the
    `hadamard` field filled in with the rotation's size and binary vector.

  Raises:
    ValueError: If `tensor_content` is None (non-weight tensor).
    ValueError: If `tensor_qsv` is provided (static quantization).
    ValueError: If `tensor_content` has rank < 2.
    Any error raised by `octav.get_tensor_quant_params` for configs it does
    not support also propagates.
  """
  if tensor_content is None:
    raise ValueError("Hadamard rotation is only supported for weight tensors.")

  if tensor_qsv is not None:
    raise ValueError(
        "Hadamard rotation is not supported for static quantization."
    )

  if tensor_content.ndim < 2:
    raise ValueError(
        "Hadamard rotation is only supported for tensors with rank >= 2."
    )

  # The rotation is applied along the last axis. Per the original design note,
  # only a quantized dimension of 0 (or 1 for blockwise) is expected, so the
  # last axis is the reduction axis.
  reduce_axis = tensor_content.ndim - 1

  # Rotate the tensor with a Hadamard matrix.
  w_rotated, hadamard_size, random_vector = _rotate_with_diagonal_hadamard(
      tensor_content, axis=reduce_axis
  )

  # Get the quantized values of the rotated tensor.
  qparams = octav.get_tensor_quant_params(
      op_info, tensor_quant_config, w_rotated, tensor_qsv
  )

  return qtyping.UniformQuantParams(
      quantized_dimension=qparams.quantized_dimension,
      num_bits=qparams.num_bits,
      scale=qparams.scale,
      zero_point=qparams.zero_point,
      symmetric=qparams.symmetric,
      quantized_data=qparams.quantized_data,
      block_size=qparams.block_size,
      hadamard=qtyping.UniformQuantParams.HadamardRotationParams(
          random_binary_vector=random_vector,
          hadamard_size=hadamard_size,
      ),
  )
159
+
160
+
161
def _materialize_fully_connected(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    is_decomposed: bool = False,
    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
) -> list[qtyping.TensorTransformationParams]:
  """Materializes the fully_connected op for Hadamard rotation quantization.

  The weight (input 1) is rotated and quantized. A Hadamard rotation is
  inserted on the activation input (input 0). The bias (input 2, only if
  present) and the output are left unquantized.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    is_decomposed: Whether to use decomposed Hadamard rotation ops or a custom
      op.
    tensor_name_to_qsv: Unused by this function (presumably kept so the
      signature matches other materialize functions).

  Returns:
    Transformation parameters for the tensors associated with the op, in the
    order: input, weight, bias (only if present), output.

  Raises:
    ValueError: If the weight tensor quantization config is missing.
  """
  if op_info.op_quant_config.weight_tensor_config is None:
    raise ValueError(
        "Weight tensor quantization config is not provided for Hadamard"
        " Rotation quantization."
    )

  op_tensor_params = []
  op_inputs = op_info.op.inputs

  # Materialize weight.
  weight_tensor_index = 1
  weight_tensor = graph_info.subgraph_tensors[op_inputs[weight_tensor_index]]
  tensor_data = tfl_flatbuffer_utils.get_tensor_data(
      weight_tensor, graph_info.buffers
  )
  # quant_params contains the rotated and quantized weights done by
  # get_tensor_quant_params().
  quant_params = get_tensor_quant_params(
      op_info,
      op_info.op_quant_config.weight_tensor_config,
      tensor_data,
      None,
  )
  transformations = [qtyping.QuantTransformation.QUANTIZE_TENSOR]
  op2tensor_params = qtyping.OpToTensorParams(
      subgraph_op_id=op_info.subgraph_op_index,
      parameters=quant_params,
      transformations=transformations,
  )
  weight_transformation_params = qtyping.TensorTransformationParams(
      tensor_name=tfl_flatbuffer_utils.get_tensor_name(weight_tensor),
      consumers=[op2tensor_params],
  )

  # Materialize input. A hadamard rotation op should be inserted on the input
  # tensor to do the inverse of the weight's transformation.
  input_tensor_index = 0
  input_tensor = graph_info.subgraph_tensors[op_inputs[input_tensor_index]]
  transformations = [
      qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION
      if is_decomposed
      else qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
  ]
  op2tensor_params = qtyping.OpToTensorParams(
      subgraph_op_id=op_info.subgraph_op_index,
      parameters=quant_params,
      transformations=transformations,
  )
  input_transformation_params = qtyping.TensorTransformationParams(
      tensor_name=tfl_flatbuffer_utils.get_tensor_name(input_tensor),
      consumers=[op2tensor_params],
  )
  op_tensor_params.append(input_transformation_params)
  op_tensor_params.append(weight_transformation_params)

  # Materialize bias. Since static quantization is not supported, we do not
  # quantize the bias tensor. The bias is an optional input. TFLite marks an
  # omitted optional input with index -1, and the input may also be missing
  # entirely. Indexing subgraph_tensors with -1 would silently select the
  # subgraph's last tensor, so only materialize the bias when it exists.
  bias_tensor_index = 2
  if (
      len(op_inputs) > bias_tensor_index
      and op_inputs[bias_tensor_index] >= 0
  ):
    bias_tensor = graph_info.subgraph_tensors[op_inputs[bias_tensor_index]]
    no_quant_tensor_params = qtyping.OpToTensorParams(
        subgraph_op_id=op_info.subgraph_op_index,
        transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
    )
    bias_transformation_params = qtyping.TensorTransformationParams(
        tensor_name=tfl_flatbuffer_utils.get_tensor_name(bias_tensor),
        consumers=[no_quant_tensor_params],
    )
    op_tensor_params.append(bias_transformation_params)

  # Materialize output. Since static quantization is not supported, we do not
  # quantize the output tensor.
  output_tensor_index = 0
  output_tensor = graph_info.subgraph_tensors[
      op_info.op.outputs[output_tensor_index]
  ]
  no_quant_tensor_params = qtyping.OpToTensorParams(
      subgraph_op_id=op_info.subgraph_op_index,
      transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
  )
  output_transformation_params = qtyping.TensorTransformationParams(
      tensor_name=tfl_flatbuffer_utils.get_tensor_name(output_tensor),
      producer=no_quant_tensor_params,
  )
  op_tensor_params.append(output_transformation_params)

  return op_tensor_params
271
+
272
+
273
def materialize_fully_connected_custom_op(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
) -> list[qtyping.TensorTransformationParams]:
  """Materializes fully_connected with the rotation inserted as a custom op.

  Thin wrapper around `_materialize_fully_connected` with
  `is_decomposed=False`.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Forwarded unchanged to the shared implementation.

  Returns:
    The tensor transformation parameters produced for the op.
  """
  use_decomposed_ops = False
  return _materialize_fully_connected(
      op_info,
      graph_info,
      use_decomposed_ops,
      tensor_name_to_qsv=tensor_name_to_qsv,
  )
284
+
285
+
286
def materialize_fully_connected_decomposed(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
) -> list[qtyping.TensorTransformationParams]:
  """Materializes fully_connected with the rotation inserted as decomposed ops.

  Thin wrapper around `_materialize_fully_connected` with
  `is_decomposed=True`.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Forwarded unchanged to the shared implementation.

  Returns:
    The tensor transformation parameters produced for the op.
  """
  use_decomposed_ops = True
  return _materialize_fully_connected(
      op_info,
      graph_info,
      use_decomposed_ops,
      tensor_name_to_qsv=tensor_name_to_qsv,
  )
297
+
298
+
299
def _materialize_embedding_lookup(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    is_decomposed: bool = False,
    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
) -> list[qtyping.TensorTransformationParams]:
  """Materializes the embedding_lookup op for Hadamard rotation quantization.

  The lookup indices (input 0) are not quantized. The embedding table
  (input 1) is rotated and quantized. A Hadamard rotation is inserted on the
  output to apply the inverse of the table's transformation.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    is_decomposed: Whether to use decomposed Hadamard rotation ops or a custom
      op.
    tensor_name_to_qsv: Unused by this function (presumably kept so the
      signature matches other materialize functions).

  Returns:
    Transformation parameters for the tensors associated with the op, in the
    order: lookup indices, embedding table, output.

  Raises:
    ValueError: If the weight tensor quantization config is missing.
  """
  # Same precondition as the fully_connected path: the embedding table is
  # quantized with the weight tensor config, so it must be present.
  if op_info.op_quant_config.weight_tensor_config is None:
    raise ValueError(
        "Weight tensor quantization config is not provided for Hadamard"
        " Rotation quantization."
    )

  op_tensor_params = []

  # Materialize lookup.
  lookup_tensor_index = 0
  lookup_tensor = graph_info.subgraph_tensors[
      op_info.op.inputs[lookup_tensor_index]
  ]
  transformations = [
      qtyping.QuantTransformation.NO_QUANTIZE,
  ]
  op2tensor_params = qtyping.OpToTensorParams(
      subgraph_op_id=op_info.subgraph_op_index,
      parameters=None,
      transformations=transformations,
  )
  lookup_transformation_params = qtyping.TensorTransformationParams(
      tensor_name=tfl_flatbuffer_utils.get_tensor_name(lookup_tensor),
      consumers=[op2tensor_params],
  )
  op_tensor_params.append(lookup_transformation_params)

  # Materialize embedding. The embedding table should be rotated and then
  # quantized.
  embedding_tensor_index = 1
  embedding_tensor = graph_info.subgraph_tensors[
      op_info.op.inputs[embedding_tensor_index]
  ]
  tensor_data = tfl_flatbuffer_utils.get_tensor_data(
      embedding_tensor, graph_info.buffers
  )
  quant_params = get_tensor_quant_params(
      op_info,
      op_info.op_quant_config.weight_tensor_config,
      tensor_data,
      None,
  )
  transformations = [qtyping.QuantTransformation.QUANTIZE_TENSOR]
  op2tensor_params = qtyping.OpToTensorParams(
      subgraph_op_id=op_info.subgraph_op_index,
      parameters=quant_params,
      transformations=transformations,
  )
  weight_transformation_params = qtyping.TensorTransformationParams(
      tensor_name=tfl_flatbuffer_utils.get_tensor_name(embedding_tensor),
      consumers=[op2tensor_params],
  )
  op_tensor_params.append(weight_transformation_params)

  # Materialize output. A hadamard rotation op should be inserted on the output
  # tensor to do the inverse of the embedding's transformation.
  output_tensor_index = 0
  output_tensor = graph_info.subgraph_tensors[
      op_info.op.outputs[output_tensor_index]
  ]
  transformations = [
      qtyping.QuantTransformation.INSERT_DECOMPOSED_HADAMARD_ROTATION
      if is_decomposed
      else qtyping.QuantTransformation.INSERT_HADAMARD_ROTATION,
  ]
  op2tensor_params = qtyping.OpToTensorParams(
      subgraph_op_id=op_info.subgraph_op_index,
      parameters=quant_params,
      transformations=transformations,
  )
  output_transformation_params = qtyping.TensorTransformationParams(
      tensor_name=tfl_flatbuffer_utils.get_tensor_name(output_tensor),
      producer=op2tensor_params,
  )
  op_tensor_params.append(output_transformation_params)

  return op_tensor_params
389
+
390
+
391
def materialize_embedding_lookup_custom_op(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
) -> list[qtyping.TensorTransformationParams]:
  """Materializes embedding_lookup with the rotation inserted as a custom op.

  Thin wrapper around `_materialize_embedding_lookup` with
  `is_decomposed=False`.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Forwarded unchanged to the shared implementation.

  Returns:
    The tensor transformation parameters produced for the op.
  """
  use_decomposed_ops = False
  return _materialize_embedding_lookup(
      op_info,
      graph_info,
      use_decomposed_ops,
      tensor_name_to_qsv=tensor_name_to_qsv,
  )
402
+
403
+
404
def materialize_embedding_lookup_decomposed(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: Optional[dict[str, Any]] = None,  # pylint: disable=unused-argument
) -> list[qtyping.TensorTransformationParams]:
  """Materializes embedding_lookup with the rotation inserted as decomposed ops.

  Thin wrapper around `_materialize_embedding_lookup` with
  `is_decomposed=True`.

  Args:
    op_info: Aggregated information about the op (e.g., quantization config).
    graph_info: Graph information needed to perform quantization for the op.
    tensor_name_to_qsv: Forwarded unchanged to the shared implementation.

  Returns:
    The tensor transformation parameters produced for the op.
  """
  use_decomposed_ops = True
  return _materialize_embedding_lookup(
      op_info,
      graph_info,
      use_decomposed_ops,
      tensor_name_to_qsv=tensor_name_to_qsv,
  )