ai-edge-torch-nightly 0.5.0.dev20250413__py3-none-any.whl → 0.5.0.dev20250415__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/utilities/converter.py +1 -1
- ai_edge_torch/generative/utilities/verifier.py +33 -27
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250413.dist-info → ai_edge_torch_nightly-0.5.0.dev20250415.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250413.dist-info → ai_edge_torch_nightly-0.5.0.dev20250415.dist-info}/RECORD +8 -8
- {ai_edge_torch_nightly-0.5.0.dev20250413.dist-info → ai_edge_torch_nightly-0.5.0.dev20250415.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250413.dist-info → ai_edge_torch_nightly-0.5.0.dev20250415.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250413.dist-info → ai_edge_torch_nightly-0.5.0.dev20250415.dist-info}/top_level.txt +0 -0
@@ -181,7 +181,7 @@ def verify_with_input_ids(
|
|
181
181
|
original_model: ModelWrapper,
|
182
182
|
reauthored_model: ReauthoredModelWrapper,
|
183
183
|
input_ids: List[int],
|
184
|
-
kv_cache_max_len: int =
|
184
|
+
kv_cache_max_len: int = 128,
|
185
185
|
rtol: float = 1e-05,
|
186
186
|
atol: float = 1e-05,
|
187
187
|
):
|
@@ -273,6 +273,8 @@ def verify_reauthored_model(
|
|
273
273
|
rtol: float = 1e-05,
|
274
274
|
atol: float = 1e-05,
|
275
275
|
continue_on_failure: bool = False,
|
276
|
+
verify_inputs: bool = True,
|
277
|
+
verify_prompts: bool = True,
|
276
278
|
) -> bool:
|
277
279
|
"""Verifies the reauthored model against the original model.
|
278
280
|
|
@@ -301,33 +303,37 @@ def verify_reauthored_model(
|
|
301
303
|
"""
|
302
304
|
failure_count = 0
|
303
305
|
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
original_model, reauthored_model, input_ids, rtol=rtol, atol=atol
|
306
|
+
if verify_inputs:
|
307
|
+
for input_ids in forward_input_ids:
|
308
|
+
logging.info(
|
309
|
+
"Verifying the reauthored model with input IDs: %s", input_ids
|
309
310
|
)
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
311
|
+
try:
|
312
|
+
verify_with_input_ids(
|
313
|
+
original_model, reauthored_model, input_ids, rtol=rtol, atol=atol
|
314
|
+
)
|
315
|
+
except AssertionError as e:
|
316
|
+
logging.error("*** FAILED *** verify with input IDs: %s", input_ids)
|
317
|
+
failure_count += 1
|
318
|
+
if not continue_on_failure:
|
319
|
+
return False
|
320
|
+
else:
|
321
|
+
logging.info("*** PASSED *** verify with input IDs: %s", input_ids)
|
322
|
+
|
323
|
+
if verify_prompts:
|
324
|
+
for prompts in generate_prompts:
|
325
|
+
logging.info("Verifying the reauthored model with prompts: %s", prompts)
|
326
|
+
try:
|
327
|
+
verify_model_with_prompts(
|
328
|
+
original_model, reauthored_model, tokenizer, prompts, max_new_tokens
|
329
|
+
)
|
330
|
+
except AssertionError as e:
|
331
|
+
logging.error("*** FAILED *** verify with prompts: %s", prompts)
|
332
|
+
failure_count += 1
|
333
|
+
if not continue_on_failure:
|
334
|
+
return False
|
335
|
+
else:
|
336
|
+
logging.info("*** PASSED *** verify with prompts: %s", prompts)
|
331
337
|
|
332
338
|
if failure_count == 0:
|
333
339
|
logging.info("*** PASSED *** verify_reauthored_model")
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.5.0.
|
3
|
+
Version: 0.5.0.dev20250415
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=F_xrSGW8cTj5_LRehy1awbLrOMYd-sOQPyVRXV5mNNI,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=GPDsXhfECjDzOut4vh_d9qWcyfpxobFMBTsC7MyJbM0,5557
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -186,7 +186,7 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=-v2Vj7Qdd3Gy
|
|
186
186
|
ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
|
187
187
|
ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
|
188
188
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
189
|
-
ai_edge_torch/generative/utilities/converter.py,sha256=
|
189
|
+
ai_edge_torch/generative/utilities/converter.py,sha256=swtz69oyMOxSaCEYST_Gzd5sjGZ1qOBAfd_0xl207Nk,9766
|
190
190
|
ai_edge_torch/generative/utilities/export_config.py,sha256=8-795nyd3M34LkGhgW7hwHlJyTc2Oz1iipHK8yBhdFs,1633
|
191
191
|
ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
|
192
192
|
ai_edge_torch/generative/utilities/model_builder.py,sha256=ZYX1TxpFdj573du2QCyHJlFjx4q1m12R74fp4Gwl92A,6343
|
@@ -195,7 +195,7 @@ ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWt
|
|
195
195
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
|
196
196
|
ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
|
197
197
|
ai_edge_torch/generative/utilities/types.py,sha256=gZI9hIPB3XAo4oecKIIoVDfiyibLaSNFhecPFx4VDTM,2913
|
198
|
-
ai_edge_torch/generative/utilities/verifier.py,sha256=
|
198
|
+
ai_edge_torch/generative/utilities/verifier.py,sha256=RSMQ8eda63VHM-5KmquKfogmTPyhGvGnqkoz9i4bppY,12270
|
199
199
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
200
200
|
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=JsVmYrM_JEuN_smMHXUsRlo3Liapp7UyktbPpPARwDk,5386
|
201
201
|
ai_edge_torch/hlfb/mark_pattern/fx_utils.py,sha256=YCtMgu-4w2BQ5fpnlpWC6IauKPf_tVqc7Ff91OTqlSw,1796
|
@@ -245,8 +245,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
|
|
245
245
|
ai_edge_torch/testing/export.py,sha256=dguMa-aEi-WDPnmGBUs2IPdEmt2IVmHOELH19uiJ1uU,3014
|
246
246
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
247
247
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
248
|
-
ai_edge_torch_nightly-0.5.0.
|
249
|
-
ai_edge_torch_nightly-0.5.0.
|
250
|
-
ai_edge_torch_nightly-0.5.0.
|
251
|
-
ai_edge_torch_nightly-0.5.0.
|
252
|
-
ai_edge_torch_nightly-0.5.0.
|
248
|
+
ai_edge_torch_nightly-0.5.0.dev20250415.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
249
|
+
ai_edge_torch_nightly-0.5.0.dev20250415.dist-info/METADATA,sha256=hwDAXpg_N8uBvl134iezzps8jSfwgjs1Q3zrX0As4IM,2051
|
250
|
+
ai_edge_torch_nightly-0.5.0.dev20250415.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
251
|
+
ai_edge_torch_nightly-0.5.0.dev20250415.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
252
|
+
ai_edge_torch_nightly-0.5.0.dev20250415.dist-info/RECORD,,
|
File without changes
|
File without changes
|