ai-edge-quantizer-nightly 0.1.0.dev20250415__py3-none-any.whl → 0.5.0.dev20260103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ai_edge_quantizer/algorithm_manager.py +158 -0
  2. ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py +2 -2
  3. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py +489 -53
  4. ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py +29 -2
  5. ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +4 -6
  6. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py +414 -0
  7. ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py +440 -0
  8. ai_edge_quantizer/algorithms/uniform_quantize/mse.py +127 -0
  9. ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py +195 -0
  10. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py +48 -42
  11. ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py +53 -14
  12. ai_edge_quantizer/algorithms/uniform_quantize/octav.py +32 -18
  13. ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py +92 -38
  14. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py +248 -13
  15. ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py +126 -6
  16. ai_edge_quantizer/algorithms/utils/common_utils.py +142 -53
  17. ai_edge_quantizer/calibrator.py +11 -60
  18. ai_edge_quantizer/calibrator_test.py +4 -73
  19. ai_edge_quantizer/default_policy.py +61 -26
  20. ai_edge_quantizer/model_modifier.py +97 -7
  21. ai_edge_quantizer/model_modifier_test.py +81 -1
  22. ai_edge_quantizer/model_validator.py +31 -8
  23. ai_edge_quantizer/params_generator.py +17 -10
  24. ai_edge_quantizer/params_generator_test.py +2 -7
  25. ai_edge_quantizer/qtyping.py +86 -6
  26. ai_edge_quantizer/quantizer.py +166 -21
  27. ai_edge_quantizer/quantizer_test.py +284 -16
  28. ai_edge_quantizer/recipe.py +154 -42
  29. ai_edge_quantizer/recipe_manager.py +158 -1
  30. ai_edge_quantizer/recipe_manager_test.py +146 -32
  31. ai_edge_quantizer/recipe_test.py +93 -17
  32. ai_edge_quantizer/transformation_instruction_generator.py +118 -13
  33. ai_edge_quantizer/transformation_instruction_generator_test.py +163 -27
  34. ai_edge_quantizer/transformation_performer.py +55 -25
  35. ai_edge_quantizer/transformation_performer_test.py +127 -5
  36. ai_edge_quantizer/transformations/duplicate_buffer.py +2 -1
  37. ai_edge_quantizer/transformations/duplicate_tensor.py +1 -0
  38. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +299 -0
  39. ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py +244 -0
  40. ai_edge_quantizer/transformations/insert_hadamard_rotation.py +186 -0
  41. ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py +200 -0
  42. ai_edge_quantizer/transformations/quantize_tensor.py +17 -32
  43. ai_edge_quantizer/transformations/quantize_tensor_test.py +1 -1
  44. ai_edge_quantizer/transformations/transformation_utils.py +129 -6
  45. ai_edge_quantizer/transformations/transformation_utils_test.py +65 -3
  46. ai_edge_quantizer/utils/calibration_utils.py +263 -1
  47. ai_edge_quantizer/utils/calibration_utils_test.py +173 -3
  48. ai_edge_quantizer/utils/constrained_ops_utils.py +111 -0
  49. ai_edge_quantizer/utils/constrained_ops_utils_test.py +50 -0
  50. ai_edge_quantizer/utils/test_utils.py +75 -2
  51. ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +39 -6
  52. ai_edge_quantizer/utils/tfl_interpreter_utils.py +87 -15
  53. ai_edge_quantizer/utils/tfl_interpreter_utils_test.py +29 -2
  54. ai_edge_quantizer/utils/validation_utils.py +114 -4
  55. ai_edge_quantizer/utils/validation_utils_test.py +80 -0
  56. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA +14 -4
  57. ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD +81 -0
  58. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL +1 -1
  59. ai_edge_quantizer/transformations/emulated_subchannel.py +0 -363
  60. ai_edge_quantizer/transformations/emulated_subchannel_test.py +0 -212
  61. ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info/RECORD +0 -73
  62. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses}/LICENSE +0 -0
  63. {ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/utils/tfl_interpreter_utils.py
@@ -17,21 +17,25 @@
 
 from typing import Any, Optional, Union
 
+import ml_dtypes
 import numpy as np
 
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor
 from ai_edge_litert import interpreter as tfl  # pylint: disable=g-direct-tensorflow-import
-from tensorflow.python.platform import gfile  # pylint: disable=g-direct-tensorflow-import
+import os  # tensorflow.python.platform.gfile # pylint: disable=g-direct-tensorflow-import
 
 DEFAULT_SIGNATURE_KEY = "serving_default"
 
+_Numeric = Union[int, float]
+
 
 def create_tfl_interpreter(
     tflite_model: Union[str, bytes],
     allocate_tensors: bool = True,
     use_xnnpack: bool = True,
     num_threads: int = 16,
+    preserve_all_tensors: bool = True,
 ) -> tfl.Interpreter:
   """Creates a TFLite interpreter from a model file.
 
@@ -40,12 +44,14 @@ def create_tfl_interpreter(
     allocate_tensors: Whether to allocate tensors.
     use_xnnpack: Whether to use the XNNPACK delegate for the interpreter.
     num_threads: The number of threads to use for the interpreter.
+    preserve_all_tensors: Whether to preserve all tensors. If False, only input
+      and output tensors are preserved.
 
   Returns:
     A TFLite interpreter.
   """
   if isinstance(tflite_model, str):
-    with gfile.GFile(tflite_model, "rb") as f:
+    with open(tflite_model, "rb") as f:
       tflite_model = f.read()
 
   if use_xnnpack:
@@ -56,7 +62,7 @@
       model_content=bytes(tflite_model),
       num_threads=num_threads,
       experimental_op_resolver_type=op_resolver,
-      experimental_preserve_all_tensors=True,
+      experimental_preserve_all_tensors=preserve_all_tensors,
   )
   if allocate_tensors:
     tflite_interpreter.allocate_tensors()
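
The preserve_all_tensors flag is plumbed straight into LiteRT's experimental_preserve_all_tensors option, so callers that only read inputs and outputs can skip retaining intermediate tensors. A minimal usage sketch (the model path is hypothetical):

    from ai_edge_quantizer.utils import tfl_interpreter_utils

    # Default (True): every intermediate tensor stays readable after invoke.
    interp = tfl_interpreter_utils.create_tfl_interpreter("model.tflite")

    # False: only input/output tensors are kept; reading intermediates later
    # fails with "Tensor data is null." (covered by a new test below).
    lean_interp = tfl_interpreter_utils.create_tfl_interpreter(
        "model.tflite", preserve_all_tensors=False
    )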
@@ -319,10 +325,51 @@ def get_signature_main_subgraph_index(
   return signature_runner._subgraph_index  # pylint:disable=protected-access
 
 
-def create_random_normal_dataset(
+def _create_random_normal(
+    rng: np.random.Generator,
+    shape: tuple[int, ...],
+    dtype: np.dtype,
+) -> dict[str, Any]:
+  """Creates a random normal dataset sample for given input details."""
+  return rng.normal(size=shape).astype(dtype)
+
+
+def _create_random_uniform(
+    rng: np.random.Generator,
+    shape: tuple[int, ...],
+    dtype: np.dtype,
+    min_value: float = 0.0,
+    max_value: float = 1.0,
+) -> dict[str, Any]:
+  """Creates a random uniform dataset sample for given input details."""
+  return rng.uniform(min_value, max_value, size=shape).astype(dtype)
+
+
+def _create_random_integers(
+    rng: np.random.Generator,
+    shape: tuple[int, ...],
+    dtype: np.dtype,
+    min_value: int = 0,
+    max_value: int = 1024,
+) -> dict[str, Any]:
+  """Creates a random integer dataset sample for given input details."""
+  return rng.integers(min_value, max_value, size=shape, dtype=dtype)
+
+
+def _create_random_bool(
+    rng: np.random.Generator,
+    shape: tuple[int, ...],
+    dtype: np.dtype,
+) -> dict[str, Any]:
+  """Creates a random bool dataset sample for given input details."""
+  return rng.choice([True, False], size=shape, replace=True).astype(dtype)
+
+
+def create_random_dataset(
     input_details: dict[str, Any],
     num_samples: int,
     random_seed: Union[int, np._typing.ArrayLike],
+    min_max_range: Optional[tuple[_Numeric, _Numeric]] = None,
 ) -> list[dict[str, Any]]:
   """Creates a random normal dataset for given input details.
 
@@ -330,6 +377,7 @@ def create_random_normal_dataset(
     input_details: A dictionary of input details.
     num_samples: The number of samples to generate.
     random_seed: The random seed to use.
+    min_max_range: The min and max of the input range.
 
   Returns:
     A list of dictionaries, each containing a sample of input data (for all
@@ -340,9 +388,28 @@ def create_random_normal_dataset(
   for _ in range(num_samples):
     input_data = {}
     for arg_name, input_tensor in input_details.items():
-      new_data = rng.normal(size=input_tensor["shape"]).astype(
-          input_tensor["dtype"]
-      )
+      dtype = input_tensor["dtype"]
+      shape = input_tensor["shape"]
+      if dtype in (np.int32, np.int64):
+        if min_max_range is None:
+          new_data = _create_random_integers(rng, shape, dtype)
+        else:
+          min_value, max_value = min_max_range
+          new_data = _create_random_integers(
+              rng, shape, dtype, min_value, max_value
+          )
+      elif dtype in (np.float32, ml_dtypes.bfloat16):
+        if min_max_range is None:
+          new_data = _create_random_normal(rng, shape, dtype)
+        else:
+          min_value, max_value = min_max_range
+          new_data = _create_random_uniform(
+              rng, shape, dtype, min_value, max_value
+          )
+      elif dtype == np.bool:
+        new_data = _create_random_bool(rng, shape, dtype)
+      else:
+        raise ValueError(f"Unsupported dtype: {input_tensor['dtype']}")
       input_data[arg_name] = new_data
     dataset.append(input_data)
   return dataset
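
The renamed create_random_dataset now dispatches on input dtype: int32/int64 inputs draw from rng.integers (default range [0, 1024)), float32/bfloat16 inputs draw from a standard normal unless min_max_range switches them to a uniform draw, and bool inputs draw random choices. A hedged sketch of calling it directly, with a hypothetical input_details dict shaped like a signature runner's get_input_details() output:

    import numpy as np
    from ai_edge_quantizer.utils import tfl_interpreter_utils

    input_details = {
        "tokens": {"shape": (1, 8), "dtype": np.int32},
        "cache_0": {"shape": (1, 8, 4), "dtype": np.float32},
    }
    samples = tfl_interpreter_utils.create_random_dataset(
        input_details, num_samples=2, random_seed=0
    )
    # samples[0]["tokens"] is int32 in [0, 1024); samples[0]["cache_0"] is
    # float32 from a standard normal distribution.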
@@ -352,18 +419,20 @@ def create_random_normal_input_data(
     tflite_model: Union[str, bytes],
     num_samples: int = 4,
     random_seed: int = 666,
+    min_max_range: Optional[tuple[_Numeric, _Numeric]] = None,
 ) -> dict[str, list[dict[str, Any]]]:
-  """create random dataset following random distribution for signature runner.
+  """Creates a random normal dataset for a signature runner.
 
   Args:
-    tflite_model: TFLite model path or bytearray
-    num_samples: number of input samples to be generated
-    random_seed: random seed to be used for function
+    tflite_model: TFLite model path or bytearray.
+    num_samples: Number of input samples to be generated.
+    random_seed: Random seed to be used for function.
+    min_max_range: The min and max of the input range.
 
   Returns:
-    a list of inputs to the given interpreter, for a single interpreter we may
+    A list of inputs to the given interpreter, for a single interpreter we may
       have multiple signatures so each set of inputs is also represented as
-      list
+      list.
   """
   tfl_interpreter = create_tfl_interpreter(tflite_model)
   signature_defs = tfl_interpreter.get_signature_list()
@@ -372,7 +441,10 @@ def create_random_normal_input_data(
   for signature_key in signature_keys:
     signature_runner = tfl_interpreter.get_signature_runner(signature_key)
     input_details = signature_runner.get_input_details()
-    test_data[signature_key] = create_random_normal_dataset(
-        input_details, num_samples, random_seed
+    test_data[signature_key] = create_random_dataset(
+        input_details,
+        num_samples,
+        random_seed,
+        min_max_range,
     )
   return test_data
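
create_random_normal_input_data keeps its public name but forwards the new min_max_range, so one call covers signatures that mix integer and float inputs. A sketch under the same hypothetical-model assumption:

    test_data = tfl_interpreter_utils.create_random_normal_input_data(
        "toy_model.tflite", num_samples=4, min_max_range=(0, 128)
    )
    # Integer inputs come from rng.integers(0, 128) and float inputs from
    # rng.uniform(0, 128); test_data maps each signature key to a list of
    # per-input dicts.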
ai_edge_quantizer/utils/tfl_interpreter_utils_test.py
@@ -19,7 +19,6 @@ from tensorflow.python.platform import googletest
 from ai_edge_quantizer.utils import test_utils
 from ai_edge_quantizer.utils import tfl_interpreter_utils
 
-
 TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("../tests/models")
 
 
@@ -91,6 +90,16 @@ class TflUtilsSingleSignatureModelTest(googletest.TestCase):
     ]
     self.assertEqual(tuple(average_pool_res.shape), (1, 14, 14, 8))
 
+  def test_get_tensor_name_to_content_map_fails_no_preserve_all_tensors(self):
+    tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
+        self._test_model_path, preserve_all_tensors=False
+    )
+    tfl_interpreter_utils.invoke_interpreter_once(
+        tfl_interpreter, [self._input_data]
+    )
+    with self.assertRaisesRegex(ValueError, "Tensor data is null."):
+      tfl_interpreter_utils.get_tensor_name_to_content_map(tfl_interpreter)
+
   def test_is_tensor_quantized(self):
     tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
         self._test_model_path
@@ -159,7 +168,6 @@ class TflUtilsQuantizedModelTest(googletest.TestCase):
     signature_output = tfl_interpreter_utils.invoke_interpreter_signature(
         tfl_interpreter, self._signature_input_data
     )
-    print(signature_output)
     self.assertEqual(tuple(signature_output["dense_1"].shape), (1, 10))
 
     # Assert the input data is not modified in-place b/353340272.
@@ -328,5 +336,24 @@ class TflUtilsMultiSignatureModelTest(googletest.TestCase):
     self.assertEqual(multiply_output_content, [20.0])
 
 
+class TflUtilsIntegerInputModelTest(googletest.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    np.random.seed(0)
+    self._test_model_path = os.path.join(
+        TEST_DATA_PREFIX_PATH, "toy_model_with_kv_cache_multi_signature.tflite"
+    )
+
+  def test_random_integer_input_data(self):
+    test_data = tfl_interpreter_utils.create_random_normal_input_data(
+        self._test_model_path
+    )
+    self.assertEqual(test_data["signature_1"][0]["cache_0"].dtype, np.float32)
+    self.assertEqual(test_data["signature_1"][0]["cache_1"].dtype, np.float32)
+    self.assertEqual(test_data["signature_1"][0]["positions"].dtype, np.int32)
+    self.assertEqual(test_data["signature_1"][0]["tokens"].dtype, np.int32)
+
+
 if __name__ == "__main__":
   googletest.main()
ai_edge_quantizer/utils/validation_utils.py
@@ -32,12 +32,18 @@ def get_validation_func(
     a validation function
 
   Raises:
-    Value error if the function name is not supported
+    ValueError: if the function name is not supported
   """
   if func_name == "mse":
     return mean_squared_difference
   elif func_name == "median_diff_ratio":
     return median_diff_ratio
+  elif func_name == "cosine_similarity":
+    return cosine_similarity
+  elif func_name == "kl_divergence":
+    return kl_divergence
+  elif func_name == "snr":
+    return signal_to_noise_ratio
   else:
     raise ValueError(f"Validation function {func_name} not supported")
 
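
Strings remain the dispatch interface, so the three new metrics slot into existing validation flows by name alone. For instance (a hedged sketch; the arrays are illustrative):

    from ai_edge_quantizer.utils import validation_utils

    snr_fn = validation_utils.get_validation_func("snr")
    # Noisy signal [2, 3, 4] vs. clean signal [1, 2, 3]: prints ~4.6667.
    print(snr_fn([2, 3, 4], [1, 2, 3]))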
@@ -58,7 +64,7 @@ def mean_squared_difference(
     a float value representing the MSD between data1 & 2
 
   Raises:
-    Value error if the two inputs don't have the same number of elements
+    ValueError: if the two inputs don't have the same number of elements
   """
   data1, data2 = _preprocess_same_size_arrays(data1, data2)
   # special handling for tensor of size 0
@@ -87,7 +93,7 @@ def median_diff_ratio(
     a float value representing the median diff ratio between data1 & 2
 
   Raises:
-    Value error if the two inputs don't have the same number of elements
+    ValueError: if the two inputs don't have the same number of elements
   """
   data1, data2 = _preprocess_same_size_arrays(data1, data2)
   # special handling for tensor of size 0
@@ -99,6 +105,110 @@ def median_diff_ratio(
   return median_ratio
 
 
+def cosine_similarity(
+    data1: np._typing.ArrayLike,
+    data2: np._typing.ArrayLike,
+) -> float:
+  """Calculates the cosine similarity between data1 & data2.
+
+  ref: https://en.wikipedia.org/wiki/Cosine_similarity
+
+  Args:
+    data1: input data to be used for comparison
+    data2: input data to be used for comparison, data1 & 2 must be of the same
+      shape
+
+  Returns:
+    a float value representing the cosine similarity between data1 & 2
+
+  Raises:
+    ValueError: if the two inputs don't have the same number of elements
+  """
+  data1, data2 = _preprocess_same_size_arrays(data1, data2)
+  # special handling for tensor of size 0
+  if data1.size == 0:
+    return float(0)
+  norm_data1 = np.linalg.norm(data1)
+  norm_data2 = np.linalg.norm(data2)
+  # special handling for tensor of length 0
+  if norm_data1 == 0 and norm_data2 == 0:
+    return 1.0
+  if norm_data1 == 0 or norm_data2 == 0:
+    return 0.0
+  return np.dot(data1, data2) / (norm_data1 * norm_data2)
+
+
+def kl_divergence(
+    data1: np._typing.ArrayLike,
+    data2: np._typing.ArrayLike,
+    epsilon: float = 1e-9,
+) -> float:
+  """Calculates the KL divergence between data1 & data2.
+
+  KL(data2 || data1) = sum(data2 * log(data2 / data1)).
+  data2 is treated as the true distribution P, and data1 as the
+  approximated distribution Q.
+  Non-positive values in data1 and data2 are clipped to 0 before
+  KL divergence calculation. Epsilon is added to avoid log(0) and
+  division by zero.
+
+  Args:
+    data1: input data to be used for comparison (distribution Q)
+    data2: input data to be used for comparison (distribution P), data1 & 2 must
+      be of the same shape
+    epsilon: small value to avoid log(0) and division by zero.
+
+  Returns:
+    A float value representing the KL divergence between data1 & 2.
+
+  Raises:
+    ValueError: if the two inputs don't have the same number of elements.
+  """
+  data1, data2 = _preprocess_same_size_arrays(data1, data2)
+  # special handling for tensor of size 0
+  if data1.size == 0:
+    return float(0)
+
+  p = np.maximum(0, data2)
+  q = np.maximum(0, data1)
+
+  return float(np.sum(p * np.log((p + epsilon) / (q + epsilon))))
+
+
+def signal_to_noise_ratio(
+    noisy_signal: np._typing.ArrayLike,
+    signal: np._typing.ArrayLike,
+    epsilon: float = 1e-9,
+) -> float:
+  """Calculates the signal to noise ratio between noisy_signal & signal.
+
+  SNR = P_signal / P_noise, where signal is treated as the clean signal and
+  noisy_signal-signal is treated as the noise samples.
+  P_signal = mean(signal^2)
+  P_noise = mean((noisy_signal-signal)^2) = mse(noisy_signal, signal)
+
+  Args:
+    noisy_signal: Input data to be used for comparison (e.g. noisy signal).
+    signal: Input data to be used for comparison (e.g. clean signal),
+      noisy_signal & signal must be of the same shape.
+    epsilon: Small value to avoid division by zero.
+
+  Returns:
+    A float value representing the SNR between noisy_signal & signal.
+
+  Raises:
+    ValueError: If the two inputs don't have the same number of elements.
+  """
+  noisy_signal, signal = _preprocess_same_size_arrays(noisy_signal, signal)
+  if signal.size == 0:
+    return float(0)
+
+  mse = mean_squared_difference(noisy_signal, signal)
+  signal_power = float(np.square(signal).mean())
+  snr = signal_power / (mse + epsilon)
+  return snr
+
+
 def _preprocess_same_size_arrays(
     data1: np._typing.ArrayLike, data2: np._typing.ArrayLike
 ) -> Tuple[np.ndarray, np.ndarray]:
@@ -113,7 +223,7 @@ def _preprocess_same_size_arrays(
     a tuple of the preprocessed data1 & 2
 
   Raises:
-    Value error if the two inputs don't have the same number of elements
+    ValueError: if the two inputs don't have the same number of elements
   """
   data1 = np.array(data1, dtype=np.float32).flatten()
   data2 = np.array(data2, dtype=np.float32).flatten()
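
The constants asserted in the test file below follow directly from these definitions; a quick sanity check, with the values worked out by hand from the formulas above:

    from ai_edge_quantizer.utils import validation_utils

    # Inputs flatten to 4-vectors: dot = 25, norms = sqrt(46) and sqrt(18),
    # so the result is 25 / sqrt(46 * 18) ~= 0.86881.
    print(validation_utils.cosine_similarity([[1, 2], [4, 5]], [[1, 3], [2, 2]]))

    # P = [0.1, 0.9], Q = [0.5, 0.5]: 0.1*ln(0.2) + 0.9*ln(1.8) ~= 0.3681.
    print(validation_utils.kl_divergence([0.5, 0.5], [0.1, 0.9]))

    # Noise = [1, 1, 1] so mse = 1; signal power = (1 + 4 + 9) / 3 = 14 / 3,
    # giving SNR ~= 4.6667.
    print(validation_utils.signal_to_noise_ratio([2, 3, 4], [1, 2, 3]))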
ai_edge_quantizer/utils/validation_utils_test.py
@@ -82,6 +82,86 @@ class ValidationUtilTest(googletest.TestCase):
     result = validation_utils.median_diff_ratio(data1, data2)
     self.assertEqual(result, 0)
 
+  def test_cosine_similarity(self):
+    data1 = [1, 2, 3]
+    data2 = [1, 2, 3]
+    result = validation_utils.cosine_similarity(data1, data2)
+    self.assertAlmostEqual(result, 1.0, 6)
+
+  def test_cosine_similarity_perpendicular(self):
+    data1 = [1, 0, 0]
+    data2 = [0, 1, 0]
+    result = validation_utils.cosine_similarity(data1, data2)
+    self.assertAlmostEqual(result, 0.0, 6)
+
+  def test_cosine_similarity_multidim(self):
+    data1 = [[1, 2], [4, 5]]
+    data2 = [[1, 3], [2, 2]]
+    result = validation_utils.cosine_similarity(data1, data2)
+    self.assertAlmostEqual(result, 0.86881, 6)
+
+  def test_cosine_similarity_0d(self):
+    data1 = []
+    data2 = []
+    result = validation_utils.cosine_similarity(data1, data2)
+    self.assertEqual(result, 0)
+
+  def test_kl_divergence(self):
+    data1 = [0.5, 0.5]
+    data2 = [0.1, 0.9]
+    result = validation_utils.kl_divergence(data1, data2)
+    self.assertAlmostEqual(result, 0.36808, 4)
+
+  def test_kl_divergence_zero_in_q(self):
+    data1 = [0, 1]
+    data2 = [1, 0]
+    result = validation_utils.kl_divergence(data1, data2)
+    self.assertAlmostEqual(result, 20.7232658, 4)
+
+  def test_kl_divergence_negative_values(self):
+    data1 = [-1, 1]
+    data2 = [1, -1]
+    result = validation_utils.kl_divergence(data1, data2)
+    self.assertAlmostEqual(result, 20.7232658, 4)
+
+  def test_kl_divergence_0d(self):
+    data1 = []
+    data2 = []
+    result = validation_utils.kl_divergence(data1, data2)
+    self.assertEqual(result, 0)
+
+  def test_get_validation_func_kl_divergence(self):
+    func = validation_utils.get_validation_func("kl_divergence")
+    self.assertEqual(func, validation_utils.kl_divergence)
+
+  def test_signal_to_noise_ratio_0d(self):
+    data1 = []
+    data2 = []
+    result = validation_utils.signal_to_noise_ratio(data1, data2)
+    self.assertEqual(result, 0)
+
+  def test_signal_to_noise_ratio_identical(self):
+    data1 = [1, 2, 3]
+    data2 = [1, 2, 3]
+    result = validation_utils.signal_to_noise_ratio(data1, data2)
+    self.assertGreater(result, 1e8)  # mse=0, so snr should be large
+
+  def test_signal_to_noise_ratio_with_noise(self):
+    data1 = [2, 3, 4]
+    data2 = [1, 2, 3]
+    result = validation_utils.signal_to_noise_ratio(data1, data2)
+    self.assertAlmostEqual(result, 14 / 3, places=5)
+
+  def test_signal_to_noise_ratio_simple(self):
+    data1 = [1, 1]
+    data2 = [1, 0]
+    result = validation_utils.signal_to_noise_ratio(data1, data2)
+    self.assertAlmostEqual(result, 1.0, places=5)
+
+  def test_get_validation_func_snr(self):
+    func = validation_utils.get_validation_func("snr")
+    self.assertEqual(func, validation_utils.signal_to_noise_ratio)
+
 
 if __name__ == "__main__":
   googletest.main()
{ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/METADATA
@@ -1,6 +1,6 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.4
 Name: ai-edge-quantizer-nightly
-Version: 0.1.0.dev20250415
+Version: 0.5.0.dev20260103
 Summary: A quantizer for advanced developers to quantize converted AI Edge models.
 Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
 Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
@@ -24,10 +24,20 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
 Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
+Requires-Dist: absl-py
 Requires-Dist: immutabledict
 Requires-Dist: numpy
-Requires-Dist: tf-nightly>=2.17.0.dev20240509
-Requires-Dist: ai-edge-litert>=1.2.0
+Requires-Dist: ml_dtypes
+Requires-Dist: ai-edge-litert-nightly
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: keywords
+Dynamic: license-file
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
 
 It aims to facilitate advanced users to strive for optimal performance on
 resource demanding models (e.g., GenAI models).
ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD
@@ -0,0 +1,81 @@
+ai_edge_quantizer/__init__.py,sha256=4pFSkukSwahYyzwqia0yPRyz8TnFQfGRthVJhYpMWas,793
+ai_edge_quantizer/algorithm_manager.py,sha256=ZJcmIBREZ7maqxQbMkvwGaQhTxYWFHrdiqNF6c53Jb8,16846
+ai_edge_quantizer/algorithm_manager_api.py,sha256=u903TG0s1uIDhJqfeJne3CFl8A93phZrwgV2-hwdcXU,9247
+ai_edge_quantizer/algorithm_manager_api_test.py,sha256=w6bSONvXkX6bzXAGc0-7b6gNDt9oz9ieq97KP8Sg_JU,7666
+ai_edge_quantizer/calibrator.py,sha256=nkHUmxdWy16Vw3EOD3B_7EkGiX8V-XJRXXFynweGfG8,9744
+ai_edge_quantizer/calibrator_test.py,sha256=c2ZCjl7PQYU9KtAovpDO9JX8sClgaLGO0P7oqoL6rP0,8830
+ai_edge_quantizer/conftest.py,sha256=SxCz-5LlRD_lQm4hQc4c6IGG7DS8d7IyEWY9gnscPN0,794
+ai_edge_quantizer/default_policy.py,sha256=ou__mTzh6hcrO2-_ZHxhOZjbVLAwNfCzckxcyISYRMc,11431
+ai_edge_quantizer/model_modifier.py,sha256=RxzfB1UULxLZlFEtgvFu0WrdTo7SLofc52KZchV_2vQ,10421
+ai_edge_quantizer/model_modifier_test.py,sha256=5vUCodVNk9GPcecjGwovV0677vD0BUZjfq9PGOnMEmM,7227
+ai_edge_quantizer/model_validator.py,sha256=mU6MLMvNQK7fxEJmh11H44OGnkUof0CVP6kYjb_du2A,13931
+ai_edge_quantizer/model_validator_test.py,sha256=EeqOP_mrZsnZ3rug756s0ryDDqd2KgIDld5Lm_gDuWY,13020
+ai_edge_quantizer/params_generator.py,sha256=-tbXB6crutiFhmLFEMe_-sxGylsvgd_cRZQ2fB67bNE,20436
+ai_edge_quantizer/params_generator_test.py,sha256=gJlq_qCPC0dWkbkyCpQiqAsmCYoWYxtxM2xYMEkrr3g,40436
+ai_edge_quantizer/qtyping.py,sha256=RPJTlcculzgx_QxAU6I_TS6JnJYTlqnx2WfxnLKK1dg,18081
+ai_edge_quantizer/quantizer.py,sha256=dgBkHR1VXuXzwKKdv7D39OL2z0ASp30xbN0vwFUX31M,19125
+ai_edge_quantizer/quantizer_test.py,sha256=6gcOLsZO-XW9VoKmcf_9CalG-_2lSUAe_fcmH2zHcoU,30167
+ai_edge_quantizer/recipe.py,sha256=MEkfQ2Sg3KAE9LAORHWcbjYNPg06EUbwc1d-VspQA2U,6461
+ai_edge_quantizer/recipe_manager.py,sha256=OcnrY8Qj_kjDIXx71RX1MHw5qND89N-DKuMRajfGMEg,15205
+ai_edge_quantizer/recipe_manager_test.py,sha256=pLEnLX8zwfZu9LcZoU0a8QpxNr8IFwbGdxp-hlYEwU4,37050
+ai_edge_quantizer/recipe_test.py,sha256=QisyaTol8JRZFcGOGyee7QRCvqj5VbF4guKWdIoMUOE,6213
+ai_edge_quantizer/transformation_instruction_generator.py,sha256=YmjtOFqc4ajGzvHEWTyIUIom0I0uJtxt4Uc9nxzmw2A,31852
+ai_edge_quantizer/transformation_instruction_generator_test.py,sha256=KW5-WoTTo9IqLEVnWxVC8ut8eWLi_91xfKgGqVQ9QDk,54635
+ai_edge_quantizer/transformation_performer.py,sha256=mFsig0E5Isy7cnG1wMO2jzBn3Wql8fElM_PSpaL8okw,13354
+ai_edge_quantizer/transformation_performer_test.py,sha256=xk6A3LStCyPclN51--9uO7XjSxNfZmpdfvrzOL0maNM,20349
+ai_edge_quantizer/algorithms/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
+ai_edge_quantizer/algorithms/nonlinear_quantize/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
+ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting.py,sha256=Bs9CK7wZAw6jNaZ8xEtbwO2vM34VYXNZSMVWvxJo9nw,9297
+ai_edge_quantizer/algorithms/nonlinear_quantize/float_casting_test.py,sha256=EqIHGEZ1LgUrTN7zf880RuAzEv3Qy7kgh5ivObJGHSo,22646
+ai_edge_quantizer/algorithms/uniform_quantize/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
+ai_edge_quantizer/algorithms/uniform_quantize/common_quantize.py,sha256=_ItPnd2TXj95QimjiPaJOKcyfW_C5emIougygNTaZxA,42072
+ai_edge_quantizer/algorithms/uniform_quantize/common_quantize_test.py,sha256=GGf_n3wIeg3GB_eGsmyNJ0fTcxgpeMMbugTMRONK6TQ,3553
+ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py,sha256=VjBDxGxjITHJc7xJABqBbZt6_qhobtZAl2gnVQrYJgc,8652
+ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery_test.py,sha256=sT5eX5TLZEHTtPfnSkCPDlS0sQxlTFWbCsbvOuj--yY,8889
+ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation.py,sha256=qxt9CPDcidVWIxp5nSWPN2hKKj1XZcsOOLBd2SYIvW0,14572
+ai_edge_quantizer/algorithms/uniform_quantize/hadamard_rotation_test.py,sha256=1ejj5WS3GZwFk3qpsPiPS8jcmVS1-e7zRmvj2Nj8fKw,15440
+ai_edge_quantizer/algorithms/uniform_quantize/mse.py,sha256=EP5yPw6khAhTo6VNTPXEE2aGKLfNnqz8COeJnTKaGWs,4641
+ai_edge_quantizer/algorithms/uniform_quantize/mse_test.py,sha256=-E1LIlxadckspltdgBWTiUzsiwbawSubndavHhWLt1g,7145
+ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py,sha256=NCLKwM8Teu2yI-Qd36e8KfqZWIqtHeAg_gMD7Z_sqNE,8988
+ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py,sha256=Eqa4OUqoCGywbHz-HxJ9dWRj9BKlVzJPuIhVzvrpdLM,8925
+ai_edge_quantizer/algorithms/uniform_quantize/octav.py,sha256=-n-QZyp9y8WCy5FPSpXZXHfOA-p-RLvfSaCzAfhHiHI,7040
+ai_edge_quantizer/algorithms/uniform_quantize/octav_test.py,sha256=6m2U-9JdNei0XzOORg2gt87TJdD0XHZ-z5h9c4g_TB4,9120
+ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor.py,sha256=qZxTj3B-tqNTLCViwuJj285YncvwjWeay2QKWd8nr6A,20420
+ai_edge_quantizer/algorithms/uniform_quantize/uniform_quantize_tensor_test.py,sha256=eTrrc8AGaSf1Ytp5gsRONAZ94PHFJUTd4dGi5ZnKZjU,16038
+ai_edge_quantizer/algorithms/utils/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
+ai_edge_quantizer/algorithms/utils/common_utils.py,sha256=M3VZsdLC4jCPfSI_aGAY4XjiHvoXtR-UyPZdZdz8GD0,38082
+ai_edge_quantizer/algorithms/utils/common_utils_test.py,sha256=zqapGEfYhjQWe9cNGPLmdbwtEUUYQRhlO_kNe0cXX6E,18104
+ai_edge_quantizer/transformations/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
+ai_edge_quantizer/transformations/dequant_insert.py,sha256=sL1LHFVzBDSd9jgrzlHz38LWU0bwmVX7iBkaNcui0ts,3566
+ai_edge_quantizer/transformations/dequant_insert_test.py,sha256=NJ18PnG71_AvUPz3Cr_TmG6URMeBfa7IiDDyddfTkKQ,10830
+ai_edge_quantizer/transformations/duplicate_buffer.py,sha256=TvTHbm24IiICNkWOlvR2UpJKMU-88puNFycDYc0_ehQ,1774
+ai_edge_quantizer/transformations/duplicate_buffer_test.py,sha256=YYWl3Q5WF60s8T8pLzzA8TCSxz-i7dqc03dJt1LtMw4,3880
+ai_edge_quantizer/transformations/duplicate_tensor.py,sha256=WKhf2LIAL0MnZe88b6942A37lvHXe1cFjUDqE5VNmvU,2490
+ai_edge_quantizer/transformations/duplicate_tensor_test.py,sha256=s-RqSxNBMfVJyCunXz2eb7-KA6UiBmbOmL7phLslENQ,5056
+ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py,sha256=kBQQZ0VKUsV8hId3_a4K-812wNrZHU4LK2c1Rt_D_XA,11084
+ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation_test.py,sha256=Z9Nr5e5aEeEMahhhizFyOkAMEXkEg1EKYZ_bGb5Vbvw,8993
+ai_edge_quantizer/transformations/insert_hadamard_rotation.py,sha256=5D5WwrJCE6hQoANbMwa6YGBbjcG5HcL_rkkoXIAIW9w,6883
+ai_edge_quantizer/transformations/insert_hadamard_rotation_test.py,sha256=iV1p3nZfHUATV2YRoBOYurnu3pLy8n3aFppLWGQOPdA,7268
+ai_edge_quantizer/transformations/quant_insert.py,sha256=jn6HsJaV-sqBiFPY-Aqbd64t8zgcYVkEkZI375x_FWY,3958
+ai_edge_quantizer/transformations/quant_insert_test.py,sha256=X9ptPDvJCFkR5tejKnD1SlHFGPazQTW-wNNMV9MEAuw,10107
+ai_edge_quantizer/transformations/quantize_tensor.py,sha256=VGTVpZWla9R-LPfhTzH1NVAp2soOqDF_duIm8ez_z3Y,7264
+ai_edge_quantizer/transformations/quantize_tensor_test.py,sha256=CD7OboBcIQxQY8OaRd5ISC1JcwQW726P_vneY4LKVpA,9117
+ai_edge_quantizer/transformations/transformation_utils.py,sha256=IKrtXJNH0msiTcI7KXkCYn2EkzmbZKWMMX_r5PMEx2U,8857
+ai_edge_quantizer/transformations/transformation_utils_test.py,sha256=MWgq29t7rvxRQIfi4ny9IoODFCTcbpjnIwoCL40zDKk,8698
+ai_edge_quantizer/utils/__init__.py,sha256=lpq1g2ayg3lCPLy79t2VicYcnGKw64FfYIj1V7J-4m8,676
+ai_edge_quantizer/utils/calibration_utils.py,sha256=dFDsjc3CXaDFNbCMyoPrMVubd3EDtG0ZwIY3Tmbb0sw,11506
+ai_edge_quantizer/utils/calibration_utils_test.py,sha256=jod4iokZkG00y9JrYaFzVvg4JwiA6mX8_whAMkNyoEc,9334
+ai_edge_quantizer/utils/constrained_ops_utils.py,sha256=z0sm1R9anRRVgdgI23XQKwDRcdARdpTo_6UBDB_lHXE,4502
+ai_edge_quantizer/utils/constrained_ops_utils_test.py,sha256=zmMIAS1WIvYK1Z9ZMMxYovIGtxfek-jvfZqrois1ahE,1756
+ai_edge_quantizer/utils/test_utils.py,sha256=a4Nk-wbeB09dFjTDZiA0K67d26j5DD0UDH_GIVmVG_4,8685
+ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=3mngikx_lF-qKBc5KxGX-5kELH_XGKpeGjwUyR5dfZI,12167
+ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py,sha256=K1SbK8q92qYVtiVj0I0GtugsPTkpIpEKv9zakvFV_Sc,8555
+ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=ptdlC3WVUE9aBznT7kZQ0ZOk3EKgOBQdMDAaCdGedIM,15093
+ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=EPOXbmXqbt3tAewo3BQQjh2mjuxrrFit5tkF0wUVYHU,12471
+ai_edge_quantizer/utils/validation_utils.py,sha256=Mr0D6X-pTDLODFAnCX3IlqdV1OL02tlq0ZjHbqx8nzg,7439
+ai_edge_quantizer/utils/validation_utils_test.py,sha256=T8K5mCWeMcihND2KS_dHvCJUU9lEdG2sD95EgPkaX3w,5584
+ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/METADATA,sha256=gJG1WQM6zOnsWl8Fp3axH3YydN3szFTS6qZ05biDrXw,1729
+ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
+ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
+ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info/RECORD,,
{ai_edge_quantizer_nightly-0.1.0.dev20250415.dist-info → ai_edge_quantizer_nightly-0.5.0.dev20260103.dist-info}/WHEEL
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: bdist_wheel (0.45.1)
+Generator: setuptools (79.0.1)
 Root-Is-Purelib: true
 Tag: py3-none-any
 