ai-edge-quantizer-nightly 0.3.0.dev20250725__py3-none-any.whl → 0.3.0.dev20250726__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -118,6 +118,7 @@ MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT = {
118
118
  _TFLOpName.UNPACK: common_quantize.materialize_unpack,
119
119
  _TFLOpName.DIV: common_quantize.materialize_div,
120
120
  _TFLOpName.BROADCAST_TO: common_quantize.materialize_broadcast_to,
121
+ _TFLOpName.SQRT: common_quantize.materialize_sqrt,
121
122
  }
122
123
  for op_name, materialize_func in MIN_MAX_OP_NAME_MATERIALIZE_FUNC_DICT.items():
123
124
  register_quantized_op(
@@ -262,6 +263,7 @@ _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT = immutabledict({
262
263
  _TFLOpName.UNPACK: common_quantize.materialize_unpack,
263
264
  _TFLOpName.DIV: common_quantize.materialize_div,
264
265
  _TFLOpName.BROADCAST_TO: common_quantize.materialize_broadcast_to,
266
+ _TFLOpName.SQRT: common_quantize.materialize_sqrt,
265
267
  })
266
268
 
267
269
  for op_name, materialize_func in _OCTAV_OP_NAME_MATERIALIZE_FUNC_DICT.items():
@@ -826,6 +826,21 @@ def materialize_broadcast_to(
826
826
  )
827
827
 
828
828
 
829
+ def materialize_sqrt(
830
+ get_tensor_quant_params_fn: qtyping.GetTensorQuantParamsFuncSignature,
831
+ op_info: qtyping.OpInfo,
832
+ graph_info: qtyping.GraphInfo,
833
+ tensor_name_to_qsv: dict[str, Any],
834
+ ) -> list[qtyping.TensorTransformationParams]:
835
+ """Materialize tensors in tfl.sqrt."""
836
+ return common_utils.materialize_standard_op(
837
+ op_info,
838
+ graph_info,
839
+ tensor_name_to_qsv,
840
+ get_tensor_quant_params_fn,
841
+ )
842
+
843
+
829
844
  def _get_tensor_shape_for_blockwise(
830
845
  tensor_shape: Sequence[int], quantized_dim: int, block_size: int
831
846
  ) -> list[int]:
@@ -191,7 +191,8 @@ DEFAULT_JSON_POLICY = """
191
191
  "PACK",
192
192
  "UNPACK",
193
193
  "DIV",
194
- "BROADCAST_TO"
194
+ "BROADCAST_TO",
195
+ "SQRT"
195
196
  ],
196
197
  "static_wi8_ai8": [
197
198
  "ADD",
@@ -231,7 +232,8 @@ DEFAULT_JSON_POLICY = """
231
232
  "PACK",
232
233
  "UNPACK",
233
234
  "DIV",
234
- "BROADCAST_TO"
235
+ "BROADCAST_TO",
236
+ "SQRT"
235
237
  ],
236
238
  "static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
237
239
  "static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT", "EMBEDDING_LOOKUP"],
@@ -71,6 +71,7 @@ class TFLOperationName(str, enum.Enum):
71
71
  UNPACK = 'UNPACK'
72
72
  DIV = 'DIV'
73
73
  BROADCAST_TO = 'BROADCAST_TO'
74
+ SQRT = 'SQRT'
74
75
 
75
76
 
76
77
  class QuantizeMode(enum.Enum):
@@ -18,7 +18,7 @@
18
18
  import inspect as _inspect
19
19
  import os.path as _os_path
20
20
  import sys as _sys
21
- from typing import Union
21
+ from typing import Optional, Union
22
22
 
23
23
  from absl.testing import parameterized
24
24
 
@@ -32,6 +32,7 @@ _OpName = qtyping.TFLOperationName
32
32
  _TensorQuantConfig = qtyping.TensorQuantizationConfig
33
33
  _OpQuantConfig = qtyping.OpQuantizationConfig
34
34
  _AlgorithmName = quantizer.AlgorithmName
35
+ _Numeric = Union[int, float]
35
36
 
36
37
 
37
38
  DEFAULT_ACTIVATION_QUANT_SETTING = _TensorQuantConfig(
@@ -98,9 +99,9 @@ class BaseOpTestCase(parameterized.TestCase):
98
99
  op_name: _OpName,
99
100
  op_config: _OpQuantConfig,
100
101
  num_validation_samples: int = 4,
101
- num_calibration_samples: Union[int, None] = None,
102
+ num_calibration_samples: Optional[int] = None,
102
103
  error_metric: str = 'mse',
103
- int_min_max: Union[tuple[int, int], None] = None,
104
+ min_max_range: Optional[tuple[_Numeric, _Numeric]] = None,
104
105
  ) -> model_validator.ComparisonResult:
105
106
  """Quantizes and validates the given model with the given configurations.
106
107
 
@@ -113,7 +114,7 @@ class BaseOpTestCase(parameterized.TestCase):
113
114
  num_calibration_samples: The number of samples to use for calibration. If
114
115
  None then it will be set to num_validation_samples * 8.
115
116
  error_metric: The error error_metric to use for validation.
116
- int_min_max: The min and max of the integer input range.
117
+ min_max_range: The min and max of the input range.
117
118
 
118
119
  Returns:
119
120
  The comparison result of the validation.
@@ -131,7 +132,7 @@ class BaseOpTestCase(parameterized.TestCase):
131
132
  calibration_data = tfl_interpreter_utils.create_random_normal_input_data(
132
133
  quantizer_instance.float_model,
133
134
  num_samples=num_calibration_samples,
134
- int_min_max=int_min_max,
135
+ min_max_range=min_max_range,
135
136
  )
136
137
  calibration_result = quantizer_instance.calibrate(calibration_data)
137
138
  quantization_result = quantizer_instance.quantize(calibration_result)
@@ -140,7 +141,7 @@ class BaseOpTestCase(parameterized.TestCase):
140
141
  test_data = tfl_interpreter_utils.create_random_normal_input_data(
141
142
  quantization_result.quantized_model,
142
143
  num_samples=num_validation_samples,
143
- int_min_max=int_min_max,
144
+ min_max_range=min_max_range,
144
145
  )
145
146
  return quantizer_instance.validate(test_data, error_metric)
146
147
 
@@ -190,7 +191,7 @@ class BaseOpTestCase(parameterized.TestCase):
190
191
  expected_model_size_reduction: float,
191
192
  weight_tolerance: float = 1e-4,
192
193
  output_tolerance: float = 1e-4,
193
- int_min_max: Union[tuple[int, int], None] = None,
194
+ min_max_range: Optional[tuple[_Numeric, _Numeric]] = None,
194
195
  ):
195
196
  """Check if the quantization is successful and the result is valid."""
196
197
  validation_result = self.quantize_and_validate(
@@ -198,7 +199,7 @@ class BaseOpTestCase(parameterized.TestCase):
198
199
  algorithm_key=algorithm_key,
199
200
  op_name=op_name,
200
201
  op_config=op_config,
201
- int_min_max=int_min_max,
202
+ min_max_range=min_max_range,
202
203
  )
203
204
  with self.subTest(name='ModelSizeReduction'):
204
205
  self.assert_model_size_reduction_above_min_pct(
@@ -220,9 +221,9 @@ class BaseOpTestCase(parameterized.TestCase):
220
221
  op_name: _OpName,
221
222
  op_config: _OpQuantConfig,
222
223
  num_validation_samples: int = 4,
223
- num_calibration_samples: Union[int, None] = None,
224
+ num_calibration_samples: Optional[int] = None,
224
225
  output_tolerance: float = 1e-4,
225
- int_min_max: Union[tuple[int, int], None] = None,
226
+ min_max_range: Optional[tuple[_Numeric, _Numeric]] = None,
226
227
  ):
227
228
  """Checks if the output errors after quantization are within the tolerance."""
228
229
  validation_result = self.quantize_and_validate(
@@ -232,7 +233,7 @@ class BaseOpTestCase(parameterized.TestCase):
232
233
  num_calibration_samples=num_calibration_samples,
233
234
  op_name=op_name,
234
235
  op_config=op_config,
235
- int_min_max=int_min_max,
236
+ min_max_range=min_max_range,
236
237
  )
237
238
  self.assert_output_errors_below_tolerance(
238
239
  validation_result, output_tolerance
@@ -65,6 +65,7 @@ TFL_OP_NAME_TO_CODE = immutabledict.immutabledict({
65
65
  _TFLOpName.UNPACK: schema.BuiltinOperator.UNPACK,
66
66
  _TFLOpName.DIV: schema.BuiltinOperator.DIV,
67
67
  _TFLOpName.BROADCAST_TO: schema.BuiltinOperator.BROADCAST_TO,
68
+ _TFLOpName.SQRT: schema.BuiltinOperator.SQRT,
68
69
  })
69
70
 
70
71
  TFL_OP_CODE_TO_NAME = immutabledict.immutabledict(
@@ -27,6 +27,8 @@ from tensorflow.python.platform import gfile # pylint: disable=g-direct-tensorf
27
27
 
28
28
  DEFAULT_SIGNATURE_KEY = "serving_default"
29
29
 
30
+ _Numeric = Union[int, float]
31
+
30
32
 
31
33
  def create_tfl_interpreter(
32
34
  tflite_model: Union[str, bytes],
@@ -329,6 +331,17 @@ def _create_random_normal(
329
331
  return rng.normal(size=shape).astype(dtype)
330
332
 
331
333
 
334
+ def _create_random_uniform(
335
+ rng: np.random.Generator,
336
+ shape: tuple[int, ...],
337
+ dtype: np.dtype,
338
+ min_value: float = 0.0,
339
+ max_value: float = 1.0,
340
+ ) -> dict[str, Any]:
341
+ """Creates a random uniform dataset sample for given input details."""
342
+ return rng.uniform(min_value, max_value, size=shape).astype(dtype)
343
+
344
+
332
345
  def _create_random_integers(
333
346
  rng: np.random.Generator,
334
347
  shape: tuple[int, ...],
@@ -353,7 +366,7 @@ def create_random_dataset(
353
366
  input_details: dict[str, Any],
354
367
  num_samples: int,
355
368
  random_seed: Union[int, np._typing.ArrayLike],
356
- int_min_max: Union[tuple[int, int], None] = None,
369
+ min_max_range: Optional[tuple[_Numeric, _Numeric]] = None,
357
370
  ) -> list[dict[str, Any]]:
358
371
  """Creates a random normal dataset for given input details.
359
372
 
@@ -361,7 +374,7 @@ def create_random_dataset(
361
374
  input_details: A dictionary of input details.
362
375
  num_samples: The number of samples to generate.
363
376
  random_seed: The random seed to use.
364
- int_min_max: The min and max of the integer input range.
377
+ min_max_range: The min and max of the input range.
365
378
 
366
379
  Returns:
367
380
  A list of dictionaries, each containing a sample of input data (for all
@@ -375,15 +388,21 @@ def create_random_dataset(
375
388
  dtype = input_tensor["dtype"]
376
389
  shape = input_tensor["shape"]
377
390
  if dtype in (np.int32, np.int64):
378
- if int_min_max is None:
391
+ if min_max_range is None:
379
392
  new_data = _create_random_integers(rng, shape, dtype)
380
393
  else:
381
- min_value, max_value = int_min_max
394
+ min_value, max_value = min_max_range
382
395
  new_data = _create_random_integers(
383
396
  rng, shape, dtype, min_value, max_value
384
397
  )
385
398
  elif dtype in (np.float32, ml_dtypes.bfloat16):
386
- new_data = _create_random_normal(rng, shape, dtype)
399
+ if min_max_range is None:
400
+ new_data = _create_random_normal(rng, shape, dtype)
401
+ else:
402
+ min_value, max_value = min_max_range
403
+ new_data = _create_random_uniform(
404
+ rng, shape, dtype, min_value, max_value
405
+ )
387
406
  elif dtype == np.bool:
388
407
  new_data = _create_random_bool(rng, shape, dtype)
389
408
  else:
@@ -397,7 +416,7 @@ def create_random_normal_input_data(
397
416
  tflite_model: Union[str, bytes],
398
417
  num_samples: int = 4,
399
418
  random_seed: int = 666,
400
- int_min_max: Union[tuple[int, int], None] = None,
419
+ min_max_range: Optional[tuple[_Numeric, _Numeric]] = None,
401
420
  ) -> dict[str, list[dict[str, Any]]]:
402
421
  """Creates a random normal dataset for a signature runner.
403
422
 
@@ -405,7 +424,7 @@ def create_random_normal_input_data(
405
424
  tflite_model: TFLite model path or bytearray.
406
425
  num_samples: Number of input samples to be generated.
407
426
  random_seed: Random seed to be used for function.
408
- int_min_max: The min and max of the integer input range.
427
+ min_max_range: The min and max of the input range.
409
428
 
410
429
  Returns:
411
430
  A list of inputs to the given interpreter, for a single interpreter we may
@@ -420,6 +439,9 @@ def create_random_normal_input_data(
420
439
  signature_runner = tfl_interpreter.get_signature_runner(signature_key)
421
440
  input_details = signature_runner.get_input_details()
422
441
  test_data[signature_key] = create_random_dataset(
423
- input_details, num_samples, random_seed, int_min_max
442
+ input_details,
443
+ num_samples,
444
+ random_seed,
445
+ min_max_range,
424
446
  )
425
447
  return test_data
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-quantizer-nightly
3
- Version: 0.3.0.dev20250725
3
+ Version: 0.3.0.dev20250726
4
4
  Summary: A quantizer for advanced developers to quantize converted AI Edge models.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
6
6
  Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
@@ -1,18 +1,18 @@
1
1
  ai_edge_quantizer/__init__.py,sha256=4pFSkukSwahYyzwqia0yPRyz8TnFQfGRthVJhYpMWas,793
2
- ai_edge_quantizer/algorithm_manager.py,sha256=_m0LZvZkJ1Yntsoc-G--KUIWIE3ojSdpgRuoB5zTNfw,12822
2
+ ai_edge_quantizer/algorithm_manager.py,sha256=wgC3g7hHvEM1fXARQsT3UgR5YLJqdQ4BLIPicn_bTvM,12932
3
3
  ai_edge_quantizer/algorithm_manager_api.py,sha256=u903TG0s1uIDhJqfeJne3CFl8A93phZrwgV2-hwdcXU,9247
4
4
  ai_edge_quantizer/algorithm_manager_api_test.py,sha256=w6bSONvXkX6bzXAGc0-7b6gNDt9oz9ieq97KP8Sg_JU,7666
5
5
  ai_edge_quantizer/calibrator.py,sha256=Sms7_AIHPH9G5xFaz5Ef3a5gPhxuIWQI8d2LUM8C96I,12071
6
6
  ai_edge_quantizer/calibrator_test.py,sha256=ejKc5YC7id8J1Ll9HAYCzMnKzxd0FUENSD06zkSSV0c,11900
7
7
  ai_edge_quantizer/conftest.py,sha256=SxCz-5LlRD_lQm4hQc4c6IGG7DS8d7IyEWY9gnscPN0,794
8
- ai_edge_quantizer/default_policy.py,sha256=ntINf9s_CMVsrJRxpi9boP8lKK6omqO6cQaLKoOpOvo,11410
8
+ ai_edge_quantizer/default_policy.py,sha256=djOEFPStjcDLoqNwK4RH_lfWJmdCLLixhCLwa3mN8pQ,11438
9
9
  ai_edge_quantizer/model_modifier.py,sha256=teGa8I6kGvn6TQY6Xv53YFIc_pQEhNvM9Zb4bvhezyw,7110
10
10
  ai_edge_quantizer/model_modifier_test.py,sha256=cJd04SLOG-fQZZNZPcisoBLx3cLtWEwGqUBbLb-pif4,4751
11
11
  ai_edge_quantizer/model_validator.py,sha256=Hj0_5o-Oa3dSlJ3ryVjRhvsyelHNyek1GrtG9buMczg,13153
12
12
  ai_edge_quantizer/model_validator_test.py,sha256=EeqOP_mrZsnZ3rug756s0ryDDqd2KgIDld5Lm_gDuWY,13020
13
13
  ai_edge_quantizer/params_generator.py,sha256=hcgMHJlERZERUyIAEi6AHJcLJ8gsKIBAEojzFFz-tqk,20098
14
14
  ai_edge_quantizer/params_generator_test.py,sha256=RDYoRZDJfEZRtjlTAU2kZ_4t3JHOqEHxfJX9V4ETAhg,40597
15
- ai_edge_quantizer/qtyping.py,sha256=1XCcdbTzNutOc8CoImk3DPIikmS93K-5E1AA9IE_i2g,16686
15
+ ai_edge_quantizer/qtyping.py,sha256=8sCBPI3IuIHaT4NrMZrLH6Hp_fkrVP8NEQC5zvHG-UU,16702
16
16
  ai_edge_quantizer/quantizer.py,sha256=g3DMqFMrMpt9jQttCE0WcdNbMtk0JZnmN5MmCHrNdyM,13202
17
17
  ai_edge_quantizer/quantizer_test.py,sha256=K_HBA56JkFI3HL8VLWCqGEfC0ISh5ldMKoNyBdGRAJg,20368
18
18
  ai_edge_quantizer/recipe.py,sha256=FR0uJceumZrnle2VRSOQZ1uXup4S1cTYKRH-N53mWRo,2919
@@ -28,7 +28,7 @@ ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py,sha256=lpq1g2ayg3lCP
28
28
  ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py,sha256=Bs9CK7wZAw6jNaZ8xEtbwO2vM34VYXNZSMVWvxJo9nw,9297
29
29
  ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py,sha256=EqIHGEZ1LgUrTN7zf880RuAzEv3Qy7kgh5ivObJGHSo,22646
30
30
  ai_edge_quantizer/algorithms/uniform_quantize/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
31
- ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=onqXt1Tng0bVTSdKod7fLci9bdXiiZwh8vQIg3ipm9c,32804
31
+ ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=BUrGoC1TU6hD0QzqjblIs56Il7PKRfaz6s1G6nTXoio,33239
32
32
  ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py,sha256=GGf_n3wIeg3GB_eGsmyNJ0fTcxgpeMMbugTMRONK6TQ,3553
33
33
  ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py,sha256=BDdn_uBZakfHyzdMJPKadsOqxqyC-s6W2ZzFH99L4fE,8652
34
34
  ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py,sha256=sT5eX5TLZEHTtPfnSkCPDlS0sQxlTFWbCsbvOuj--yY,8889
@@ -61,15 +61,15 @@ ai_edge_quantizer/transformations/transformation_utils_test.py,sha256=MWgq29t7rv
61
61
  ai_edge_quantizer/utils/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
62
62
  ai_edge_quantizer/utils/calibration_utils.py,sha256=e3dG7Nm94Ix0hkTWTWPUhEG6a8QR_cAM3PSwblfJV5g,15106
63
63
  ai_edge_quantizer/utils/calibration_utils_test.py,sha256=4BlksXl7b4yptL8xPR67hmJCnjhN9V10a2PunzfHrUE,9372
64
- ai_edge_quantizer/utils/test_utils.py,sha256=spqUmSNciOKPQHCBkHE7Zo34eMFq_BfBCAnMT3jAulU,8615
65
- ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=EVbj8wtZNywuFNxLvXBqxDVwFS_QX3V_q8TuZCVJMUI,11108
64
+ ai_edge_quantizer/utils/test_utils.py,sha256=a4Nk-wbeB09dFjTDZiA0K67d26j5DD0UDH_GIVmVG_4,8685
65
+ ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=LPk8yWBjLt_saKobjAvtBR9q_Ets6-3HrfMxPt064Ig,11158
66
66
  ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py,sha256=K1SbK8q92qYVtiVj0I0GtugsPTkpIpEKv9zakvFV_Sc,8555
67
- ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=SKSu1nqhGGzVei_DxmzXK-bbOE7G1vKnPDc5skce-yY,14322
67
+ ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=EoVjI_hplX_Rml3hfRsGmQOihexmizeJqt4SQcET9aA,14925
68
68
  ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=6fjkM-rycZ95L4yfvlr0TN6RlrhfPzxNUYrZaYO_F0A,12013
69
69
  ai_edge_quantizer/utils/validation_utils.py,sha256=oYw33Sg547AqtGw-choPUJmp9SAKkV46J_ddqSsum2Q,3950
70
70
  ai_edge_quantizer/utils/validation_utils_test.py,sha256=V_qNDikPD4OPB-siOLQCWNVWTAu87h2IgNYt7teFd-o,2934
71
- ai_edge_quantizer_nightly-0.3.0.dev20250725.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
72
- ai_edge_quantizer_nightly-0.3.0.dev20250725.dist-info/METADATA,sha256=twgxIe5unlBMZeDiKWldv1SHXyASfppOy7aXVwmf0VM,1528
73
- ai_edge_quantizer_nightly-0.3.0.dev20250725.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
74
- ai_edge_quantizer_nightly-0.3.0.dev20250725.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
75
- ai_edge_quantizer_nightly-0.3.0.dev20250725.dist-info/RECORD,,
71
+ ai_edge_quantizer_nightly-0.3.0.dev20250726.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
72
+ ai_edge_quantizer_nightly-0.3.0.dev20250726.dist-info/METADATA,sha256=e6TL7vVEzY0RsL4EadFawcW8PMwPQPQO6_16mmkqqjw,1528
73
+ ai_edge_quantizer_nightly-0.3.0.dev20250726.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
74
+ ai_edge_quantizer_nightly-0.3.0.dev20250726.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
75
+ ai_edge_quantizer_nightly-0.3.0.dev20250726.dist-info/RECORD,,