ai-edge-quantizer-nightly 0.0.1.dev20241210__py3-none-any.whl → 0.0.1.dev20241218__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_quantizer/calibrator.py +2 -1
- ai_edge_quantizer/model_validator.py +16 -9
- ai_edge_quantizer/quantizer.py +9 -4
- ai_edge_quantizer/utils/tfl_interpreter_utils.py +7 -5
- {ai_edge_quantizer_nightly-0.0.1.dev20241210.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20241218.dist-info}/METADATA +1 -1
- {ai_edge_quantizer_nightly-0.0.1.dev20241210.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20241218.dist-info}/RECORD +9 -9
- {ai_edge_quantizer_nightly-0.0.1.dev20241210.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20241218.dist-info}/LICENSE +0 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20241210.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20241218.dist-info}/WHEEL +0 -0
- {ai_edge_quantizer_nightly-0.0.1.dev20241210.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20241218.dist-info}/top_level.txt +0 -0
ai_edge_quantizer/calibrator.py
CHANGED
@@ -41,6 +41,7 @@ class Calibrator:
|
|
41
41
|
def __init__(
|
42
42
|
self,
|
43
43
|
float_tflite: Union[str, bytes],
|
44
|
+
num_threads: int = 16,
|
44
45
|
):
|
45
46
|
self._flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
|
46
47
|
|
@@ -50,7 +51,7 @@ class Calibrator:
|
|
50
51
|
" the model (e.g., if it is already quantized)."
|
51
52
|
)
|
52
53
|
self._tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
|
53
|
-
float_tflite
|
54
|
+
float_tflite, use_xnnpack=True, num_threads=num_threads
|
54
55
|
)
|
55
56
|
# Tensor name to tensor content.
|
56
57
|
self._tensor_content_map: dict[str, Any] = {}
|
ai_edge_quantizer/model_validator.py
CHANGED
@@ -207,7 +207,8 @@ def _setup_validation_interpreter(
|
|
207
207
|
model: bytes,
|
208
208
|
signature_input: dict[str, Any],
|
209
209
|
signature_key: Optional[str],
|
210
|
-
|
210
|
+
use_xnnpack: bool,
|
211
|
+
num_threads: int,
|
211
212
|
) -> tuple[Any, int, dict[str, Any]]:
|
212
213
|
"""Setup the interpreter for validation given a signature key.
|
213
214
|
|
@@ -216,15 +217,15 @@ def _setup_validation_interpreter(
|
|
216
217
|
signature_input: A dictionary of input tensor name and its value.
|
217
218
|
signature_key: The signature key to be used for invoking the models. If the
|
218
219
|
model only has one signature, this can be set to None.
|
219
|
-
|
220
|
-
|
220
|
+
use_xnnpack: Whether to use xnnpack for the interpreter.
|
221
|
+
num_threads: The number of threads to use for the interpreter.
|
221
222
|
|
222
223
|
Returns:
|
223
224
|
A tuple of interpreter, subgraph_index and tensor_name_to_details.
|
224
225
|
"""
|
225
226
|
|
226
227
|
interpreter = utils.create_tfl_interpreter(
|
227
|
-
tflite_model=model,
|
228
|
+
tflite_model=model, use_xnnpack=use_xnnpack, num_threads=num_threads
|
228
229
|
)
|
229
230
|
utils.invoke_interpreter_signature(
|
230
231
|
interpreter, signature_input, signature_key
|
@@ -247,7 +248,8 @@ def compare_model(
|
|
247
248
|
test_data: dict[str, Iterable[dict[str, Any]]],
|
248
249
|
error_metric: str,
|
249
250
|
compare_fn: Callable[[Any, Any], float],
|
250
|
-
|
251
|
+
use_xnnpack: bool = True,
|
252
|
+
num_threads: int = 16,
|
251
253
|
) -> ComparisonResult:
|
252
254
|
"""Compares model tensors over a model signature using the compare_fn.
|
253
255
|
|
@@ -266,8 +268,8 @@ def compare_model(
|
|
266
268
|
compare_fn: a comparison function to be used for calculating the statistics,
|
267
269
|
this function must be taking in two ArrayLike strcuture and output a
|
268
270
|
single float value.
|
269
|
-
|
270
|
-
|
271
|
+
use_xnnpack: Whether to use xnnpack for the interpreter.
|
272
|
+
num_threads: The number of threads to use for the interpreter.
|
271
273
|
|
272
274
|
Returns:
|
273
275
|
A ComparisonResult object.
|
@@ -282,12 +284,17 @@ def compare_model(
|
|
282
284
|
reference_model,
|
283
285
|
signature_input,
|
284
286
|
signature_key,
|
285
|
-
|
287
|
+
use_xnnpack,
|
288
|
+
num_threads,
|
286
289
|
)
|
287
290
|
)
|
288
291
|
targ_interpreter, targ_subgraph_index, targ_tensor_name_to_details = (
|
289
292
|
_setup_validation_interpreter(
|
290
|
-
target_model,
|
293
|
+
target_model,
|
294
|
+
signature_input,
|
295
|
+
signature_key,
|
296
|
+
use_xnnpack,
|
297
|
+
num_threads,
|
291
298
|
)
|
292
299
|
)
|
293
300
|
# Compare the cached tensor values.
|
ai_edge_quantizer/quantizer.py
CHANGED
@@ -216,6 +216,7 @@ class Quantizer:
|
|
216
216
|
self,
|
217
217
|
calibration_data: dict[str, Iterable[_SignatureInput]],
|
218
218
|
previous_calibration_result: Optional[_CalibrationResult] = None,
|
219
|
+
num_threads: int = 16,
|
219
220
|
) -> _CalibrationResult:
|
220
221
|
"""Calibrates the float model (required by static range quantization).
|
221
222
|
|
@@ -223,6 +224,7 @@ class Quantizer:
|
|
223
224
|
calibration_data: Calibration data for a model signature.
|
224
225
|
previous_calibration_result: Previous calibration result to be loaded. The
|
225
226
|
calibration process will be resumed from the previous result.
|
227
|
+
num_threads: Number of threads to use for calibration.
|
226
228
|
|
227
229
|
Returns:
|
228
230
|
Calibration result ({tensor_name: tensor QSVs (e.g.,min/max)}).
|
@@ -233,7 +235,7 @@ class Quantizer:
|
|
233
235
|
if not self.need_calibration:
|
234
236
|
return {}
|
235
237
|
|
236
|
-
calib = calibrator.Calibrator(self.float_model)
|
238
|
+
calib = calibrator.Calibrator(self.float_model, num_threads=num_threads)
|
237
239
|
if previous_calibration_result is not None:
|
238
240
|
calib.load_model_qsvs(previous_calibration_result)
|
239
241
|
calib.calibrate(calibration_data, self._recipe_manager)
|
@@ -297,7 +299,8 @@ class Quantizer:
|
|
297
299
|
self,
|
298
300
|
test_data: Optional[dict[str, Iterable[_SignatureInput]]] = None,
|
299
301
|
error_metrics: str = 'mse',
|
300
|
-
|
302
|
+
use_xnnpack: bool = True,
|
303
|
+
num_threads: int = 16,
|
301
304
|
) -> model_validator.ComparisonResult:
|
302
305
|
"""Numerical validation of the quantized model for a model signature.
|
303
306
|
|
@@ -314,7 +317,8 @@ class Quantizer:
|
|
314
317
|
data that will be used for validation. If set to None, random normal
|
315
318
|
distributed data will be used for all signatures in the model.
|
316
319
|
error_metrics: Error metrics to be used for comparison.
|
317
|
-
|
320
|
+
use_xnnpack: Whether to use the xnnpack library for validation.
|
321
|
+
num_threads: Number of threads to use for validation.
|
318
322
|
|
319
323
|
Returns:
|
320
324
|
The comparison result.
|
@@ -330,7 +334,8 @@ class Quantizer:
|
|
330
334
|
test_data,
|
331
335
|
error_metrics,
|
332
336
|
validation_utils.get_validation_func(error_metrics),
|
333
|
-
|
337
|
+
use_xnnpack=use_xnnpack,
|
338
|
+
num_threads=num_threads,
|
334
339
|
)
|
335
340
|
|
336
341
|
def _get_quantization_params(
|
ai_edge_quantizer/utils/tfl_interpreter_utils.py
CHANGED
@@ -30,15 +30,16 @@ DEFAULT_SIGNATURE_KEY = "serving_default"
|
|
30
30
|
def create_tfl_interpreter(
|
31
31
|
tflite_model: Union[str, bytes],
|
32
32
|
allocate_tensors: bool = True,
|
33
|
-
|
33
|
+
use_xnnpack: bool = True,
|
34
|
+
num_threads: int = 16,
|
34
35
|
) -> tfl.Interpreter:
|
35
36
|
"""Creates a TFLite interpreter from a model file.
|
36
37
|
|
37
38
|
Args:
|
38
39
|
tflite_model: Model file path or bytes.
|
39
40
|
allocate_tensors: Whether to allocate tensors.
|
40
|
-
|
41
|
-
|
41
|
+
use_xnnpack: Whether to use the XNNPACK delegate for the interpreter.
|
42
|
+
num_threads: The number of threads to use for the interpreter.
|
42
43
|
|
43
44
|
Returns:
|
44
45
|
A TFLite interpreter.
|
@@ -47,12 +48,13 @@ def create_tfl_interpreter(
|
|
47
48
|
with gfile.GFile(tflite_model, "rb") as f:
|
48
49
|
tflite_model = f.read()
|
49
50
|
|
50
|
-
if
|
51
|
-
op_resolver = tfl.OpResolverType.
|
51
|
+
if use_xnnpack:
|
52
|
+
op_resolver = tfl.OpResolverType.BUILTIN
|
52
53
|
else:
|
53
54
|
op_resolver = tfl.OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES
|
54
55
|
tflite_interpreter = tfl.Interpreter(
|
55
56
|
model_content=bytes(tflite_model),
|
57
|
+
num_threads=num_threads,
|
56
58
|
experimental_op_resolver_type=op_resolver,
|
57
59
|
experimental_preserve_all_tensors=True,
|
58
60
|
)
|
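{ai_edge_quantizer_nightly-0.0.1.dev20241210.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20241218.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@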
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-quantizer-nightly
|
3
|
-
Version: 0.0.1.dev20241210
|
3
|
+
Version: 0.0.1.dev20241218
|
4
4
|
Summary: A quantizer for advanced developers to quantize converted AI Edge models.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,Quantization,LLMs,GenAI
|
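{ai_edge_quantizer_nightly-0.0.1.dev20241210.dist-info → ai_edge_quantizer_nightly-0.0.1.dev20241218.dist-info}/RECORD
CHANGED
@@ -2,18 +2,18 @@ ai_edge_quantizer/__init__.py,sha256=4pFSkukSwahYyzwqia0yPRyz8TnFQfGRthVJhYpMWas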
|
|
2
2
|
ai_edge_quantizer/algorithm_manager.py,sha256=9nd4Txfl2z-14rFHmL7vqSfnkAQeagCRKyCIQ7ru0_Y,5981
|
3
3
|
ai_edge_quantizer/algorithm_manager_api.py,sha256=u903TG0s1uIDhJqfeJne3CFl8A93phZrwgV2-hwdcXU,9247
|
4
4
|
ai_edge_quantizer/algorithm_manager_api_test.py,sha256=tL_ozYFTsOPX8qGcti0KTz37nVsCxf0SSG5C45SyT-g,7319
|
5
|
-
ai_edge_quantizer/calibrator.py,sha256=
|
5
|
+
ai_edge_quantizer/calibrator.py,sha256=0zAWrSpl_08u6BNLVgG_TQeNcT16wJ-oLeQgznziGoo,11079
|
6
6
|
ai_edge_quantizer/calibrator_test.py,sha256=5DGvKWRRjjU3L5wZoN56AyOVljmxOitwhuBUp6GL_bU,11354
|
7
7
|
ai_edge_quantizer/conftest.py,sha256=SxCz-5LlRD_lQm4hQc4c6IGG7DS8d7IyEWY9gnscPN0,794
|
8
8
|
ai_edge_quantizer/default_policy.py,sha256=TQ9yY8jtrSpMsTBsTyKW6TY-voGH_psvwGZoFglAbiA,9079
|
9
9
|
ai_edge_quantizer/model_modifier.py,sha256=Z8EYtrz4zhCFpzd1zVwl2AetVE3BGBf5OvB2DbVQuds,5850
|
10
10
|
ai_edge_quantizer/model_modifier_test.py,sha256=cJd04SLOG-fQZZNZPcisoBLx3cLtWEwGqUBbLb-pif4,4751
|
11
|
-
ai_edge_quantizer/model_validator.py,sha256=
|
11
|
+
ai_edge_quantizer/model_validator.py,sha256=oZk0b1qGczaEm5erJFm4SbwadDnl7DFhC0bXuxwVgps,12787
|
12
12
|
ai_edge_quantizer/model_validator_test.py,sha256=ctvVmMHvnmFbkG4o8Jaa6kXXRrGHzhYpNylgLSmOboA,12951
|
13
13
|
ai_edge_quantizer/params_generator.py,sha256=FvBub5yM2q98k7wNLgEyRerf8sVIETvGbrFcXFPUPdA,13523
|
14
14
|
ai_edge_quantizer/params_generator_test.py,sha256=d9JwR-yxNJgg1SW-m8sFFPkIRdhgsDwMpVKsBQFL0gg,37658
|
15
15
|
ai_edge_quantizer/qtyping.py,sha256=bue_WfK05QTkQcoyVVWeIxh8LRVGhHMWruXk3cgpFpw,14577
|
16
|
-
ai_edge_quantizer/quantizer.py,sha256=
|
16
|
+
ai_edge_quantizer/quantizer.py,sha256=Gny7WLuRibiIuDtcRn_g8RCD-zAm_fuDG7WmGq5dRx8,13238
|
17
17
|
ai_edge_quantizer/quantizer_test.py,sha256=38oTMJwMmxwPDeqT3eaVbazjtuIUIzMQ3mJNKh_eNQY,20493
|
18
18
|
ai_edge_quantizer/recipe.py,sha256=r5tJiUs-ihZFzeK_jP2sUIUgTqZsL5SWvbUokuIUPDo,2251
|
19
19
|
ai_edge_quantizer/recipe_manager.py,sha256=qcGUD7e7BISKdsY9WH2rdaRR3acmzSA5qMezGNbzlpo,8931
|
@@ -52,12 +52,12 @@ ai_edge_quantizer/utils/calibration_utils_test.py,sha256=Z-AcdTieesWFKyKBb08ZXm4
|
|
52
52
|
ai_edge_quantizer/utils/test_utils.py,sha256=95BDAdjE4Zvd6JZ90fG8FE3wKWE-Lu0ZIE3hQ1B6adI,3616
|
53
53
|
ai_edge_quantizer/utils/tfl_flatbuffer_utils.py,sha256=F6_AkCSv35FAhJX2qel8VTARhGOVwaeo7_mqRZygrpA,10126
|
54
54
|
ai_edge_quantizer/utils/tfl_flatbuffer_utils_test.py,sha256=AbyDxoM62k4ojD8gPdkWo--xe5hlX3t0kobQSA80kuk,7740
|
55
|
-
ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=
|
55
|
+
ai_edge_quantizer/utils/tfl_interpreter_utils.py,sha256=yta7b_VmhVZmntwHK27vqVnie3XRejN459P0uJHbpb8,10431
|
56
56
|
ai_edge_quantizer/utils/tfl_interpreter_utils_test.py,sha256=Op3JxtOqlrjzmYF18jnnstL1k9xiY9kKJ8S2vklKGkc,11327
|
57
57
|
ai_edge_quantizer/utils/validation_utils.py,sha256=oYw33Sg547AqtGw-choPUJmp9SAKkV46J_ddqSsum2Q,3950
|
58
58
|
ai_edge_quantizer/utils/validation_utils_test.py,sha256=V_qNDikPD4OPB-siOLQCWNVWTAu87h2IgNYt7teFd-o,2934
|
59
|
-
ai_edge_quantizer_nightly-0.0.1.dev20241210.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
60
|
-
ai_edge_quantizer_nightly-0.0.1.
|
61
|
-
ai_edge_quantizer_nightly-0.0.1.dev20241210.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
62
|
-
ai_edge_quantizer_nightly-0.0.1.dev20241210.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
|
63
|
-
ai_edge_quantizer_nightly-0.0.1.dev20241210.dist-info/RECORD,,
|
59
|
+
ai_edge_quantizer_nightly-0.0.1.dev20241218.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
60
|
+
ai_edge_quantizer_nightly-0.0.1.dev20241218.dist-info/METADATA,sha256=9YqAIud-Y9td1FerL7kzgo_dyJh7qiBR0f6tEy_aHtg,1484
|
61
|
+
ai_edge_quantizer_nightly-0.0.1.dev20241218.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
62
|
+
ai_edge_quantizer_nightly-0.0.1.dev20241218.dist-info/top_level.txt,sha256=8QTfPnFXNVUhScFLaa-NWZMFWMn72M50DVPubpwWB1g,18
|
63
|
+
ai_edge_quantizer_nightly-0.0.1.dev20241218.dist-info/RECORD,,
|
File without changes
|
File without changes
|