ai-edge-quantizer-nightly 0.0.1.dev20241210__py3-none-any.whl → 0.0.1.dev20241218__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -41,6 +41,7 @@ class Calibrator:
41
41
  def __init__(
42
42
  self,
43
43
  float_tflite: Union[str, bytes],
44
+ num_threads: int = 16,
44
45
  ):
45
46
  self._flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
46
47
 
@@ -50,7 +51,7 @@ class Calibrator:
50
51
  " the model (e.g., if it is already quantized)."
51
52
  )
52
53
  self._tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
53
- float_tflite
54
+ float_tflite, use_xnnpack=True, num_threads=num_threads
54
55
  )
55
56
  # Tensor name to tensor content.
56
57
  self._tensor_content_map: dict[str, Any] = {}
@@ -207,7 +207,8 @@ def _setup_validation_interpreter(
207
207
  model: bytes,
208
208
  signature_input: dict[str, Any],
209
209
  signature_key: Optional[str],
210
- use_reference_kernel: bool,
210
+ use_xnnpack: bool,
211
+ num_threads: int,
211
212
  ) -> tuple[Any, int, dict[str, Any]]:
212
213
  """Setup the interpreter for validation given a signature key.
213
214
 
@@ -216,15 +217,15 @@ def _setup_validation_interpreter(
216
217
  signature_input: A dictionary of input tensor name and its value.
217
218
  signature_key: The signature key to be used for invoking the models. If the
218
219
  model only has one signature, this can be set to None.
219
- use_reference_kernel: Whether to use the reference kernel for the
220
- interpreter.
220
+ use_xnnpack: Whether to use xnnpack for the interpreter.
221
+ num_threads: The number of threads to use for the interpreter.
221
222
 
222
223
  Returns:
223
224
  A tuple of interpreter, subgraph_index and tensor_name_to_details.
224
225
  """
225
226
 
226
227
  interpreter = utils.create_tfl_interpreter(
227
- tflite_model=model, use_reference_kernel=use_reference_kernel
228
+ tflite_model=model, use_xnnpack=use_xnnpack, num_threads=num_threads
228
229
  )
229
230
  utils.invoke_interpreter_signature(
230
231
  interpreter, signature_input, signature_key
@@ -247,7 +248,8 @@ def compare_model(
247
248
  test_data: dict[str, Iterable[dict[str, Any]]],
248
249
  error_metric: str,
249
250
  compare_fn: Callable[[Any, Any], float],
250
- use_reference_kernel: bool = False,
251
+ use_xnnpack: bool = True,
252
+ num_threads: int = 16,
251
253
  ) -> ComparisonResult:
252
254
  """Compares model tensors over a model signature using the compare_fn.
253
255
 
@@ -266,8 +268,8 @@ def compare_model(
266
268
  compare_fn: a comparison function to be used for calculating the statistics,
267
269
  this function must be taking in two ArrayLike structure and output a
268
270
  single float value.
269
- use_reference_kernel: Whether to use the reference kernel for the
270
- interpreter.
271
+ use_xnnpack: Whether to use xnnpack for the interpreter.
272
+ num_threads: The number of threads to use for the interpreter.
271
273
 
272
274
  Returns:
273
275
  A ComparisonResult object.
@@ -282,12 +284,17 @@ def compare_model(
282
284
  reference_model,
283
285
  signature_input,
284
286
  signature_key,
285
- use_reference_kernel,
287
+ use_xnnpack,
288
+ num_threads,
286
289
  )
287
290
  )
288
291
  targ_interpreter, targ_subgraph_index, targ_tensor_name_to_details = (
289
292
  _setup_validation_interpreter(
290
- target_model, signature_input, signature_key, use_reference_kernel
293
+ target_model,
294
+ signature_input,
295
+ signature_key,
296
+ use_xnnpack,
297
+ num_threads,
291
298
  )
292
299
  )
293
300
  # Compare the cached tensor values.
@@ -216,6 +216,7 @@ class Quantizer:
216
216
  self,
217
217
  calibration_data: dict[str, Iterable[_SignatureInput]],
218
218
  previous_calibration_result: Optional[_CalibrationResult] = None,
219
+ num_threads: int = 16,
219
220
  ) -> _CalibrationResult:
220
221
  """Calibrates the float model (required by static range quantization).
221
222
 
@@ -223,6 +224,7 @@ class Quantizer:
223
224
  calibration_data: Calibration data for a model signature.
224
225
  previous_calibration_result: Previous calibration result to be loaded. The
225
226
  calibration process will be resumed from the previous result.
227
+ num_threads: Number of threads to use for calibration.
226
228
 
227
229
  Returns:
228
230
  Calibration result ({tensor_name: tensor QSVs (e.g.,min/max)}).
@@ -233,7 +235,7 @@ class Quantizer:
233
235
  if not self.need_calibration:
234
236
  return {}
235
237
 
236
- calib = calibrator.Calibrator(self.float_model)
238
+ calib = calibrator.Calibrator(self.float_model, num_threads=num_threads)
237
239
  if previous_calibration_result is not None:
238
240
  calib.load_model_qsvs(previous_calibration_result)
239
241
  calib.calibrate(calibration_data, self._recipe_manager)
@@ -297,7 +299,8 @@ class Quantizer:
297
299
  self,
298
300
  test_data: Optional[dict[str, Iterable[_SignatureInput]]] = None,
299
301
  error_metrics: str = 'mse',
300
- use_reference_kernel: bool = False,
302
+ use_xnnpack: bool = True,
303
+ num_threads: int = 16,
301
304
  ) -> model_validator.ComparisonResult:
302
305
  """Numerical validation of the quantized model for a model signature.
303
306
 
@@ -314,7 +317,8 @@ class Quantizer:
314
317
  data that will be used for validation. If set to None, random normal
315
318
  distributed data will be used for all signatures in the model.
316
319
  error_metrics: Error metrics to be used for comparison.
317
- use_reference_kernel: Whether to use the reference kernel for validation.
320
+ use_xnnpack: Whether to use the xnnpack library for validation.
321
+ num_threads: Number of threads to use for validation.
318
322
 
319
323
  Returns:
320
324
  The comparison result.
@@ -330,7 +334,8 @@ class Quantizer:
330
334
  test_data,
331
335
  error_metrics,
332
336
  validation_utils.get_validation_func(error_metrics),
333
- use_reference_kernel=use_reference_kernel,
337
+ use_xnnpack=use_xnnpack,
338
+ num_threads=num_threads,
334
339
  )
335
340
 
336
341
  def _get_quantization_params(
@@ -30,15 +30,16 @@ DEFAULT_SIGNATURE_KEY = "serving_default"
30
30
  def create_tfl_interpreter(
31
31
  tflite_model: Union[str, bytes],
32
32
  allocate_tensors: bool = True,
33
- use_reference_kernel: bool = False,
33
+ use_xnnpack: bool = True,
34
+ num_threads: int = 16,
34
35
  ) -> tfl.Interpreter:
35
36
  """Creates a TFLite interpreter from a model file.
36
37
 
37
38
  Args:
38
39
  tflite_model: Model file path or bytes.
39
40
  allocate_tensors: Whether to allocate tensors.
40
- use_reference_kernel: Whether to use the reference kernel for the
41
- interpreter.
41
+ use_xnnpack: Whether to use the XNNPACK delegate for the interpreter.
42
+ num_threads: The number of threads to use for the interpreter.
42
43
 
43
44
  Returns:
44
45
  A TFLite interpreter.
@@ -47,12 +48,13 @@ def create_tfl_interpreter(
47
48
  with gfile.GFile(tflite_model, "rb") as f:
48
49
  tflite_model = f.read()
49
50
 
50
- if use_reference_kernel:
51
- op_resolver = tfl.OpResolverType.BUILTIN_REF
51
+ if use_xnnpack:
52
+ op_resolver = tfl.OpResolverType.BUILTIN
52
53
  else:
53
54
  op_resolver = tfl.OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES
54
55
  tflite_interpreter = tfl.Interpreter(
55
56
  model_content=bytes(tflite_model),
57
+ num_threads=num_threads,
56
58
  experimental_op_resolver_type=op_resolver,
57
59
  experimental_preserve_all_tensors=True,
58
60
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-quantizer-nightly
3
- Version: 0.0.1.dev20241210
3
+ Version: 0.0.1.dev20241218
4
4
  Summary: A quantizer for advanced developers to quantize converted AI Edge models.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
6
6
  Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
@@ -2,18 +2,18 @@ ai_edge_quantizer/__init__.py,sha256=4pFSkukSwahYyzwqia0yPRyz8TnFQfGRthVJhYpMWas
2
2
  ai_edge_quantizer/algorithm_manager.py,sha256=9nd4Txfl2z-14rFHmL7vqSfnkAQeagCRKyCIQ7ru0_Y,5981
3
3
  ai_edge_quantizer/algorithm_manager_api.py,sha256=u903TG0s1uIDhJqfeJne3CFl8A93phZrwgV2-hwdcXU,9247
4
4
  ai_edge_quantizer/algorithm_manager_api_test.py,sha256=tL_ozYFTsOPX8qGcti0KTz37nVsCxf0SSG5C45SyT-g,7319
5
- ai_edge_quantizer/calibrator.py,sha256=BSu0DPzVhAgFFA0JsHZtawPFlr0YPirRxItuA9SerNg,11007
5
+ ai_edge_quantizer/calibrator.py,sha256=0zAWrSpl_08u6BNLVgG_TQeNcT16wJ-oLeQgznziGoo,11079
6
6
  ai_edge_quantizer/calibrator_test.py,sha256=5DGvKWRRjjU3L5wZoN56AyOVljmxOitwhuBUp6GL_bU,11354
7
7
  ai_edge_quantizer/conftest.py,sha256=SxCz-5LlRD_lQm4hQc4c6IGG7DS8d7IyEWY9gnscPN0,794
8
8
  ai_edge_quantizer/default_policy.py,sha256=TQ9yY8jtrSpMsTBsTyKW6TY-voGH_psvwGZoFglAbiA,9079
9
9
  ai_edge_quantizer/model_modifier.py,sha256=Z8EYtrz4zhCFpzd1zVwl2AetVE3BGBf5OvB2DbVQuds,5850
10
10
  ai_edge_quantizer/model_modifier_test.py,sha256=cJd04SLOG-fQZZNZPcisoBLx3cLtWEwGqUBbLb-pif4,4751
11
- ai_edge_quantizer/model_validator.py,sha256=QvlG1TewSBo9FMwzDYPFGqR4mOa_Xhn21wi2OFAvbCI,12593
11
+ ai_edge_quantizer/model_validator.py,sha256=oZk0b1qGczaEm5erJFm4SbwadDnl7DFhC0bXuxwVgps,12787
12
12
  ai_edge_quantizer/model_validator_test.py,sha256=ctvVmMHvnmFbkG4o8Jaa6kXXRrGHzhYpNylgLSmOboA,12951
13
13
  ai_edge_quantizer/params_generator.py,sha256=FvBub5yM2q98k7wNLgEyRerf8sVIETvGbrFcXFPUPdA,13523
14
14
  ai_edge_quantizer/params_generator_test.py,sha256=d9JwR-yxNJgg1SW-m8sFFPkIRdhgsDwMpVKsBQFL0gg,37658
15
15
  ai_edge_quantizer/qtyping.py,sha256=bue_WfK05QTkQcoyVVWeIxh8LRVGhHMWruXk3cgpFpw,14577
16
- ai_edge_quantizer/quantizer.py,sha256=OYfSo06JcoursXbJBRfHQbR2-Pa4sHnZB2n9od9OzEY,13039
16
+ ai_edge_quantizer/quantizer.py,sha256=Gny7WLuRibiIuDtcRn_g8RCD-zAm_fuDG7WmGq5dRx8,13238
17
17
  ai_edge_quantizer/quantizer_test.py,sha256=38oTMJwMmxwPDeqT3eaVbazjtuIUIzMQ3mJNKh_eNQY,20493
18
18
  ai_edge_quantizer/recipe.py,sha256=r5tJiUs-ihZFzeK_jP2sUIUgTqZsL5SWvbUokuIUPDo,2251
19
19
  ai_edge_quantizer/recipe_manager.py,sha256=qcGUD7e7BISKdsY9WH2rdaRR3acmzSA5qMezGNbzlpo,8931
@@ -52,12 +52,12 @@ ai_edge_quantizer/utils/calibration_utils_test.py,sha256=Z-AcdTieesWFKyKBb08ZXm4
52
52
  ai_edge_quantizer/utils/test_utils.py,sha256=95BDAdjE4Zvd6JZ90fG8FE3wKWE-Lu0ZIE3hQ1B6adI,3616
53
53
  ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=F6_AkCSv35FAhJX2qel8VTARhGOVwaeo7_mqRZygrpA,10126
54
54
  ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py,sha256=AbyDxoM62k4ojD8gPdkWo--xe5hlX3t0kobQSA80kuk,7740
55
- ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=GzrsaL3fkOXN5iPRJv7lqhNISY6lnrBVTotWDHzI5m8,10344
55
+ ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=yta7b_VmhVZmntwHK27vqVnie3XRejN459P0uJHbpb8,10431
56
56
  ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=Op3JxtOqlrjzmYF18jnnstL1k9xiY9kKJ8S2vklKGkc,11327
57
57
  ai_edge_quantizer/utils/validation_utils.py,sha256=oYw33Sg547AqtGw-choPUJmp9SAKkV46J_ddqSsum2Q,3950
58
58
  ai_edge_quantizer/utils/validation_utils_test.py,sha256=V_qNDikPD4OPB-siOLQCWNVWTAu87h2IgNYt7teFd-o,2934
59
- ai_edge_quantizer_nightly-0.0.1.dev20241210.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
60
- ai_edge_quantizer_nightly-0.0.1.dev20241210.dist-info/METADATA,sha256=o1xW7CHGdW5K9XN9eUik8kRTwfVhjiKMJijm4Sewl4M,1484
61
- ai_edge_quantizer_nightly-0.0.1.dev20241210.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
62
- ai_edge_quantizer_nightly-0.0.1.dev20241210.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
63
- ai_edge_quantizer_nightly-0.0.1.dev20241210.dist-info/RECORD,,
59
+ ai_edge_quantizer_nightly-0.0.1.dev20241218.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
60
+ ai_edge_quantizer_nightly-0.0.1.dev20241218.dist-info/METADATA,sha256=9YqAIud-Y9td1FerL7kzgo_dyJh7qiBR0f6tEy_aHtg,1484
61
+ ai_edge_quantizer_nightly-0.0.1.dev20241218.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
62
+ ai_edge_quantizer_nightly-0.0.1.dev20241218.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
63
+ ai_edge_quantizer_nightly-0.0.1.dev20241218.dist-info/RECORD,,