ai-edge-quantizer-nightly 0.3.0.dev20250609__py3-none-any.whl → 0.3.0.dev20250611__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
@@ -108,6 +108,10 @@ MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT = {
108
108
  ),
109
109
  _TFLOpName.STABLEHLO_COMPOSITE: common_quantize.materialize_composite,
110
110
  _TFLOpName.PAD: common_quantize.materialize_pad,
111
+ _TFLOpName.SQUARED_DIFFERENCE: (
112
+ common_quantize.materialize_squared_difference
113
+ ),
114
+ _TFLOpName.MAX_POOL_2D: common_quantize.materialize_max_pool_2d,
111
115
  }
112
116
  for op_name, materialize_func in MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT.items():
113
117
  register_quantized_op(
@@ -242,6 +246,10 @@ _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
242
246
  ),
243
247
  _TFLOpName.STABLEHLO_COMPOSITE: common_quantize.materialize_composite,
244
248
  _TFLOpName.PAD: common_quantize.materialize_pad,
249
+ _TFLOpName.SQUARED_DIFFERENCE: (
250
+ common_quantize.materialize_squared_difference
251
+ ),
252
+ _TFLOpName.MAX_POOL_2D: common_quantize.materialize_max_pool_2d,
245
253
  })
246
254
 
247
255
  for op_name, materialize_func in _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT.items():
@@ -697,6 +697,37 @@ def materialize_pad(
697
697
  )
698
698
 
699
699
 
700
+ def materialize_squared_difference(
701
+ get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
702
+ op_info: qtyping.OpInfo,
703
+ graph_info: qtyping.GraphInfo,
704
+ tensor_name_to_qsv: dict[str, Any],
705
+ ) -> list[qtyping.TensorTransformationParams]:
706
+ """Materialize tensors in tfl.squared_difference."""
707
+ return common_utils.materialize_standard_op(
708
+ op_info,
709
+ graph_info,
710
+ tensor_name_to_qsv,
711
+ get_tensor_quant_params_fn,
712
+ )
713
+
714
+
715
+ def materialize_max_pool_2d(
716
+ get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
717
+ op_info: qtyping.OpInfo,
718
+ graph_info: qtyping.GraphInfo,
719
+ tensor_name_to_qsv: dict[str, Any],
720
+ ) -> list[qtyping.TensorTransformationParams]:
721
+ """Materialize tensors in tfl.max_pool_2d."""
722
+ return common_utils.materialize_standard_op(
723
+ op_info,
724
+ graph_info,
725
+ tensor_name_to_qsv,
726
+ get_tensor_quant_params_fn,
727
+ constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
728
+ )
729
+
730
+
700
731
  def _get_tensor_shape_for_blockwise(
701
732
  tensor_shape: Sequence[int], quantized_dim: int, block_size: int
702
733
  ) -> list[int]:
@@ -184,7 +184,8 @@ DEFAULT_JSON_POLICY = """
184
184
  "DYNAMIC_UPDATE_SLICE",
185
185
  "SELECT_V2",
186
186
  "STABLEHLO_COMPOSITE",
187
- "PAD"
187
+ "PAD",
188
+ "MAX_POOL_2D"
188
189
  ],
189
190
  "static_wi8_ai8": [
190
191
  "ADD",
@@ -216,7 +217,9 @@ DEFAULT_JSON_POLICY = """
216
217
  "DYNAMIC_UPDATE_SLICE",
217
218
  "SELECT_V2",
218
219
  "STABLEHLO_COMPOSITE",
219
- "PAD"
220
+ "PAD",
221
+ "SQUARED_DIFFERENCE",
222
+ "MAX_POOL_2D"
220
223
  ],
221
224
  "static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
222
225
  "static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
@@ -63,6 +63,8 @@ class TFLOperationName(str, enum.Enum):
63
63
  DYNAMIC_UPDATE_SLICE = 'DYNAMIC_UPDATE_SLICE'
64
64
  STABLEHLO_COMPOSITE = 'STABLEHLO_COMPOSITE'
65
65
  PAD = 'PAD'
66
+ SQUARED_DIFFERENCE = 'SQUARED_DIFFERENCE'
67
+ MAX_POOL_2D = 'MAX_POOL_2D'
66
68
 
67
69
 
68
70
  class QuantizeMode(enum.Enum):
@@ -18,6 +18,7 @@
18
18
  import inspect as _inspect
19
19
  import os.path as _os_path
20
20
  import sys as _sys
21
+ from typing import Union
21
22
 
22
23
  from absl.testing import parameterized
23
24
 
@@ -97,6 +98,7 @@ class BaseOpTestCase(parameterized.TestCase):
97
98
  op_name: _OpName,
98
99
  op_config: _OpQuantConfig,
99
100
  num_validation_samples: int = 4,
101
+ num_calibration_samples: Union[int, None] = None,
100
102
  error_metric: str = 'mse',
101
103
  ) -> model_validator.ComparisonResult:
102
104
  """Quantizes and validates the given model with the given configurations.
@@ -107,6 +109,8 @@ class BaseOpTestCase(parameterized.TestCase):
107
109
  op_name: The name of the operation to be quantized.
108
110
  op_config: The configuration for the operation to be quantized.
109
111
  num_validation_samples: The number of samples to use for validation.
112
+ num_calibration_samples: The number of samples to use for calibration. If
113
+ None then it will be set to num_validation_samples * 8.
110
114
  error_metric: The error error_metric to use for validation.
111
115
 
112
116
  Returns:
@@ -120,8 +124,11 @@ class BaseOpTestCase(parameterized.TestCase):
120
124
  op_config=op_config,
121
125
  )
122
126
  if quantizer_instance.need_calibration:
127
+ if num_calibration_samples is None:
128
+ num_calibration_samples = num_validation_samples * 8
123
129
  calibration_data = tfl_interpreter_utils.create_random_normal_input_data(
124
- quantizer_instance.float_model, num_samples=num_validation_samples * 8
130
+ quantizer_instance.float_model,
131
+ num_samples=num_calibration_samples,
125
132
  )
126
133
  calibration_result = quantizer_instance.calibrate(calibration_data)
127
134
  quantization_result = quantizer_instance.quantize(calibration_result)
@@ -198,3 +205,26 @@ class BaseOpTestCase(parameterized.TestCase):
198
205
  self.assert_output_errors_below_tolerance(
199
206
  validation_result, output_tolerance
200
207
  )
208
+
209
+ def assert_quantization_accuracy(
210
+ self,
211
+ algorithm_key: _AlgorithmName,
212
+ model_path: str,
213
+ op_name: _OpName,
214
+ op_config: _OpQuantConfig,
215
+ num_validation_samples: int = 4,
216
+ num_calibration_samples: Union[int, None] = None,
217
+ output_tolerance: float = 1e-4,
218
+ ):
219
+ """Check if the output errors after quantization are within the tolerance."""
220
+ validation_result = self.quantize_and_validate(
221
+ model_path=model_path,
222
+ algorithm_key=algorithm_key,
223
+ num_validation_samples=num_validation_samples,
224
+ num_calibration_samples=num_calibration_samples,
225
+ op_name=op_name,
226
+ op_config=op_config,
227
+ )
228
+ self.assert_output_errors_below_tolerance(
229
+ validation_result, output_tolerance
230
+ )
@@ -57,6 +57,8 @@ TFL_OP_NAME_TO_CODE = immutabledict.immutabledict({
57
57
  schema.BuiltinOperator.DYNAMIC_UPDATE_SLICE
58
58
  ),
59
59
  _TFLOpName.PAD: schema.BuiltinOperator.PAD,
60
+ _TFLOpName.SQUARED_DIFFERENCE: schema.BuiltinOperator.SQUARED_DIFFERENCE,
61
+ _TFLOpName.MAX_POOL_2D: schema.BuiltinOperator.MAX_POOL_2D,
60
62
  })
61
63
 
62
64
  TFL_OP_CODE_TO_NAME = immutabledict.immutabledict(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-quantizer-nightly
3
- Version: 0.3.0.dev20250609
3
+ Version: 0.3.0.dev20250611
4
4
  Summary: A quantizer for advanced developers to quantize converted AI Edge models.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
6
6
  Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
@@ -1,18 +1,18 @@
1
1
  ai_edge_quantizer/__init__.py,sha256=4pFSkukSwahYyzwqia0yPRyz8TnFQfGRthVJhYpMWas,793
2
- ai_edge_quantizer/algorithm_manager.py,sha256=p-wX2ksIV1hbWEQz-uUnbNMVgDJrsIiIOU2ZYX2ZrTM,11726
2
+ ai_edge_quantizer/algorithm_manager.py,sha256=lfCazb2b0Q4L3of0cTWkF5lMr3AD6LWW1ekmFoEGB_4,12062
3
3
  ai_edge_quantizer/algorithm_manager_api.py,sha256=u903TG0s1uIDhJqfeJne3CFl8A93phZrwgV2-hwdcXU,9247
4
4
  ai_edge_quantizer/algorithm_manager_api_test.py,sha256=w6bSONvXkX6bzXAGc0-7b6gNDt9oz9ieq97KP8Sg_JU,7666
5
5
  ai_edge_quantizer/calibrator.py,sha256=-_jX_KkfIepkQAwxxDrZjvPO1JsoSjHXVy1DPc1iFjM,12068
6
6
  ai_edge_quantizer/calibrator_test.py,sha256=C_oWOaRugPKYX74jF-eRFH-k6nGOdA8I9_uPiocaOuE,11900
7
7
  ai_edge_quantizer/conftest.py,sha256=SxCz-5LlRD_lQm4hQc4c6IGG7DS8d7IyEWY9gnscPN0,794
8
- ai_edge_quantizer/default_policy.py,sha256=zNTeiI_eP5-dLL3P_VWIQB3RzXBrb06peJKngLnSSFY,11125
8
+ ai_edge_quantizer/default_policy.py,sha256=nKtghUjTQ8QS9CgLRwQb3iB2eZOyQv0FqyISlcgzSH4,11195
9
9
  ai_edge_quantizer/model_modifier.py,sha256=teGa8I6kGvn6TQY6Xv53YFIc_pQEhNvM9Zb4bvhezyw,7110
10
10
  ai_edge_quantizer/model_modifier_test.py,sha256=cJd04SLOG-fQZZNZPcisoBLx3cLtWEwGqUBbLb-pif4,4751
11
11
  ai_edge_quantizer/model_validator.py,sha256=Hj0_5o-Oa3dSlJ3ryVjRhvsyelHNyek1GrtG9buMczg,13153
12
12
  ai_edge_quantizer/model_validator_test.py,sha256=EeqOP_mrZsnZ3rug756s0ryDDqd2KgIDld5Lm_gDuWY,13020
13
13
  ai_edge_quantizer/params_generator.py,sha256=j1BV2cGFLlQmUY6aoW5uglYqf77b9ytN8oZ1gh6o0mM,20096
14
14
  ai_edge_quantizer/params_generator_test.py,sha256=RDYoRZDJfEZRtjlTAU2kZ_4t3JHOqEHxfJX9V4ETAhg,40597
15
- ai_edge_quantizer/qtyping.py,sha256=LKn9w53wmw3gPO0E4DKOhj8gkx9efjXMoipGnsJyGiU,16453
15
+ ai_edge_quantizer/qtyping.py,sha256=0Dwz6LHQG8LhZMhVAo_h6ieZ_gcfkJl2yJcsGf17YYs,16527
16
16
  ai_edge_quantizer/quantizer.py,sha256=g3DMqFMrMpt9jQttCE0WcdNbMtk0JZnmN5MmCHrNdyM,13202
17
17
  ai_edge_quantizer/quantizer_test.py,sha256=K_HBA56JkFI3HL8VLWCqGEfC0ISh5ldMKoNyBdGRAJg,20368
18
18
  ai_edge_quantizer/recipe.py,sha256=FR0uJceumZrnle2VRSOQZ1uXup4S1cTYKRH-N53mWRo,2919
@@ -28,7 +28,7 @@ ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py,sha256=lpq1g2ayg3lCP
28
28
  ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py,sha256=Bs9CK7wZAw6jNaZ8xEtbwO2vM34VYXNZSMVWvxJo9nw,9297
29
29
  ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py,sha256=EqIHGEZ1LgUrTN7zf880RuAzEv3Qy7kgh5ivObJGHSo,22646
30
30
  ai_edge_quantizer/algorithms/uniform_quantize/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
31
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=NpZ-JvZt2OhpTqH7Z81YYVjzOX_pHoDCt8rr3VIXJUY,28665
31
+ ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=ofDoiZhOKjF7Tm-v0a4xsLSvytjfvMALXLDcuwcKNK0,29634
32
32
  ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py,sha256=GGf_n3wIeg3GB_eGsmyNJ0fTcxgpeMMbugTMRONK6TQ,3553
33
33
  ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py,sha256=BDdn_uBZakfHyzdMJPKadsOqxqyC-s6W2ZzFH99L4fE,8652
34
34
  ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py,sha256=sT5eX5TLZEHTtPfnSkCPDlS0sQxlTFWbCsbvOuj--yY,8889
@@ -63,15 +63,15 @@ ai_edge_quantizer/transformations/transformation_utils_test.py,sha256=MWgq29t7rv
63
63
  ai_edge_quantizer/utils/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
64
64
  ai_edge_quantizer/utils/calibration_utils.py,sha256=1Fj9MIO6aLZIRgyd4axvZN4S_O64nB_-Miu1WP664js,2536
65
65
  ai_edge_quantizer/utils/calibration_utils_test.py,sha256=Z-AcdTieesWFKyKBb08ZXm4Mgu6cvJ4bg2-MJ7hLD10,2856
66
- ai_edge_quantizer/utils/test_utils.py,sha256=fXwQ353P7tSy7W4Hs6YskIbCLLaBYGA724hMMbcqCUk,7129
67
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=zNlR_SJAkDi-EX63O3pNpFLVqSktysScZKgKk1XT3c8,10616
66
+ ai_edge_quantizer/utils/test_utils.py,sha256=Y2pdMvn1k4gmqDo3noJfzx3fJcDHX_1hcsP6oiIz65Y,8240
67
+ ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=Yy1u53FzRBFx-fr1TqoycWMZwAlAl0b2IB4MmGV1xJA,10758
68
68
  ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py,sha256=K1SbK8q92qYVtiVj0I0GtugsPTkpIpEKv9zakvFV_Sc,8555
69
69
  ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=EtOv6cpKM_F0uv2bWuSXylYmTeXT6zUc182pw4sdYSI,13889
70
70
  ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=6fjkM-rycZ95L4yfvlr0TN6RlrhfPzxNUYrZaYO_F0A,12013
71
71
  ai_edge_quantizer/utils/validation_utils.py,sha256=oYw33Sg547AqtGw-choPUJmp9SAKkV46J_ddqSsum2Q,3950
72
72
  ai_edge_quantizer/utils/validation_utils_test.py,sha256=V_qNDikPD4OPB-siOLQCWNVWTAu87h2IgNYt7teFd-o,2934
73
- ai_edge_quantizer_nightly-0.3.0.dev20250609.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
74
- ai_edge_quantizer_nightly-0.3.0.dev20250609.dist-info/METADATA,sha256=Gri7IYe99Je5aFDyvlEUJZxYhVdVfeIh2QBwZvu2f_0,1528
75
- ai_edge_quantizer_nightly-0.3.0.dev20250609.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
76
- ai_edge_quantizer_nightly-0.3.0.dev20250609.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
77
- ai_edge_quantizer_nightly-0.3.0.dev20250609.dist-info/RECORD,,
73
+ ai_edge_quantizer_nightly-0.3.0.dev20250611.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
74
+ ai_edge_quantizer_nightly-0.3.0.dev20250611.dist-info/METADATA,sha256=FPK-WqVTMEz-w5yycBejT4oRBxMY4fiYH-AAL6Pf4-w,1528
75
+ ai_edge_quantizer_nightly-0.3.0.dev20250611.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
76
+ ai_edge_quantizer_nightly-0.3.0.dev20250611.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
77
+ ai_edge_quantizer_nightly-0.3.0.dev20250611.dist-info/RECORD,,