ai-edge-torch-nightly 0.7.0.dev20251007__py3-none-any.whl → 0.7.0.dev20251008__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/utilities/converter.py +86 -16
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.7.0.dev20251008.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.7.0.dev20251008.dist-info}/RECORD +7 -7
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.7.0.dev20251008.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.7.0.dev20251008.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.7.0.dev20251008.dist-info}/top_level.txt +0 -0
@@ -143,9 +143,23 @@ def define_conversion_flags(
|
|
143
143
|
'`prefill_seq_lens` as the maximum of kv_cache size and prefill lengths '
|
144
144
|
'in the graph.',
|
145
145
|
)
|
146
|
+
flags.DEFINE_bool(
|
147
|
+
'export_gpu_dynamic_shape_verifications',
|
148
|
+
False,
|
149
|
+
'If true, the conversion script will export signatures used only for '
|
150
|
+
'verification of GPU dynamic shapes.',
|
151
|
+
)
|
146
152
|
return flags
|
147
153
|
|
148
154
|
|
155
|
+
# Context length for verifying GPU dynamic shapes.
|
156
|
+
_CONTEXT_LENGTH_TO_VERIFY_MAGIC_NUMBERS = 1280
|
157
|
+
# Long prefill length for verifying GPU dynamic shapes.
|
158
|
+
_LONG_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS = 1024
|
159
|
+
# Short prefill length for verifying GPU dynamic shapes.
|
160
|
+
_SHORT_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS = 64
|
161
|
+
|
162
|
+
|
149
163
|
def is_magic_number_(num: int) -> bool:
|
150
164
|
"""Returns true if the number is a magic number, i.e. prime number > 10."""
|
151
165
|
if num < 10:
|
@@ -263,6 +277,10 @@ def convert_to_tflite(
|
|
263
277
|
config: cfg.ModelConfig = None,
|
264
278
|
lora_ranks: Optional[list[int]] = None,
|
265
279
|
export_config: ExportConfig = None,
|
280
|
+
extra_model: torch.nn.Module = None,
|
281
|
+
extra_prefill_seq_lens: list[int] = None,
|
282
|
+
extra_kv_cache_max_len: int = 0,
|
283
|
+
extra_signature_prefix: str = '',
|
266
284
|
):
|
267
285
|
"""Converts a nn.Module model to multi-signature tflite model.
|
268
286
|
|
@@ -315,6 +333,15 @@ def convert_to_tflite(
|
|
315
333
|
no LoRA signatures will be added.
|
316
334
|
export_config (ExportConfig, optional): The export configuration. If None,
|
317
335
|
it uses the default export configuration.
|
336
|
+
extra_model (torch.nn.Module, optional): PyTorch model to export in
|
337
|
+
addition to the pytorch_model. This model can have different
|
338
|
+
prefill_seq_lens and kv_cache_max_len.
|
339
|
+
extra_prefill_seq_lens (list[int], optional): The prefill sequence
|
340
|
+
lengths for extra_model. Meaningful only when extra_model is not None.
|
341
|
+
extra_kv_cache_max_len (int, optional): The maximum size of KV cache
|
342
|
+
buffer for extra_model. Meaningful only when extra_model is not None.
|
343
|
+
extra_signature_prefix (str, optional): The prefix of the extra model
|
344
|
+
signatures. Meaningful only when extra_model is not None.
|
318
345
|
"""
|
319
346
|
# pylint: disable=protected-access
|
320
347
|
torch._dynamo.config.cache_size_limit = 64
|
@@ -353,32 +380,51 @@ def convert_to_tflite(
|
|
353
380
|
)
|
354
381
|
output_file = os.path.join(output_path, output_filename)
|
355
382
|
|
356
|
-
|
383
|
+
converter = converter_utils.Converter()
|
384
|
+
_add_signatures(
|
385
|
+
converter,
|
357
386
|
pytorch_model,
|
358
|
-
output_file,
|
359
387
|
prefill_seq_lens,
|
360
388
|
kv_cache_max_len,
|
361
389
|
pixel_values_size,
|
362
390
|
pixel_seq_len,
|
363
|
-
quantize,
|
364
391
|
config,
|
365
392
|
loras,
|
366
393
|
export_config,
|
367
394
|
)
|
395
|
+
|
396
|
+
if extra_model is not None and extra_prefill_seq_lens:
|
397
|
+
_add_signatures(
|
398
|
+
converter,
|
399
|
+
extra_model,
|
400
|
+
extra_prefill_seq_lens,
|
401
|
+
extra_kv_cache_max_len,
|
402
|
+
pixel_values_size,
|
403
|
+
pixel_seq_len,
|
404
|
+
config,
|
405
|
+
loras,
|
406
|
+
export_config,
|
407
|
+
signature_prefix=extra_signature_prefix,
|
408
|
+
)
|
409
|
+
|
410
|
+
edge_model = converter.convert(
|
411
|
+
quant_config=get_quant_recipe_from_flag(quantize, config),
|
412
|
+
)
|
413
|
+
edge_model.export(output_file)
|
368
414
|
return output_file
|
369
415
|
|
370
416
|
|
371
|
-
def _export_helper(
|
417
|
+
def _add_signatures(
|
418
|
+
converter: converter_utils.Converter,
|
372
419
|
pytorch_model: torch.nn.Module,
|
373
|
-
output_file: str,
|
374
420
|
prefill_seq_lens: list[int],
|
375
421
|
kv_cache_max_len: int,
|
376
422
|
pixel_values_size: torch.Size,
|
377
423
|
pixel_seq_len: int,
|
378
|
-
quantize: str,
|
379
424
|
config: cfg.ModelConfig,
|
380
425
|
loras: list[None | lora_utils.LoRA],
|
381
426
|
export_config: ExportConfig,
|
427
|
+
signature_prefix: str = '',
|
382
428
|
):
|
383
429
|
"""Helper function to export a model to tflite."""
|
384
430
|
prefill_tokens_list = []
|
@@ -423,17 +469,14 @@ def _export_helper(
|
|
423
469
|
kv_layout=export_config.kvcache_layout,
|
424
470
|
)
|
425
471
|
|
426
|
-
quant_config = get_quant_recipe_from_flag(quantize, config)
|
427
|
-
|
428
472
|
# For export, we create a module that captures any non-exportable,
|
429
473
|
# arugments, e.g. the generation config object.
|
430
474
|
mod = ExportableModule(pytorch_model, export_config=export_config).eval()
|
431
475
|
|
432
|
-
converter = converter_utils.Converter()
|
433
476
|
for lora in loras:
|
434
477
|
for i in range(len(prefill_seq_lens)):
|
435
478
|
prefill_seq_len = prefill_seq_lens[i]
|
436
|
-
prefill_signature_name = f'prefill_{prefill_seq_len}'
|
479
|
+
prefill_signature_name = f'{signature_prefix}prefill_{prefill_seq_len}'
|
437
480
|
|
438
481
|
sample_kwargs = {
|
439
482
|
'tokens': prefill_tokens_list[i],
|
@@ -488,17 +531,15 @@ def _export_helper(
|
|
488
531
|
if lora is not None:
|
489
532
|
sample_kwargs['lora'] = lora
|
490
533
|
|
534
|
+
decode_signature_name = f'{signature_prefix}decode'
|
535
|
+
if lora is not None:
|
536
|
+
decode_signature_name += f'_lora_r{lora.get_rank()}'
|
491
537
|
converter.add_signature(
|
492
|
-
|
538
|
+
decode_signature_name,
|
493
539
|
mod,
|
494
540
|
sample_kwargs=sample_kwargs,
|
495
541
|
)
|
496
542
|
|
497
|
-
edge_model = converter.convert(
|
498
|
-
quant_config=quant_config,
|
499
|
-
)
|
500
|
-
edge_model.export(output_file)
|
501
|
-
|
502
543
|
|
503
544
|
def build_and_convert_to_tflite_from_flags(
|
504
545
|
model_builder: Callable[
|
@@ -521,11 +562,36 @@ def build_and_convert_to_tflite_from_flags(
|
|
521
562
|
get_mask_cache_size_from_flags(),
|
522
563
|
)
|
523
564
|
|
565
|
+
# Extra model for GPU dynamic shape verification if needed.
|
566
|
+
extra_model = None
|
567
|
+
extra_prefill_seq_lens = None
|
568
|
+
extra_kv_cache_max_len = 0
|
524
569
|
if flags.FLAGS.gpu_dynamic_shapes:
|
525
570
|
prefill_seq_lens = [
|
526
571
|
get_magic_number_for(l) for l in flags.FLAGS.prefill_seq_lens
|
527
572
|
]
|
528
573
|
kv_cache_max_len = get_magic_number_for(flags.FLAGS.kv_cache_max_len)
|
574
|
+
|
575
|
+
if flags.FLAGS.export_gpu_dynamic_shape_verifications:
|
576
|
+
extra_kv_cache_max_len = _CONTEXT_LENGTH_TO_VERIFY_MAGIC_NUMBERS
|
577
|
+
if extra_kv_cache_max_len > flags.FLAGS.kv_cache_max_len:
|
578
|
+
extra_kv_cache_max_len = flags.FLAGS.kv_cache_max_len
|
579
|
+
extra_model = model_builder(
|
580
|
+
checkpoint_path,
|
581
|
+
loader.maybe_get_custom_loader(
|
582
|
+
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
583
|
+
),
|
584
|
+
extra_kv_cache_max_len,
|
585
|
+
)
|
586
|
+
extra_prefill_seq_lens = []
|
587
|
+
if extra_kv_cache_max_len > _SHORT_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS:
|
588
|
+
extra_prefill_seq_lens.append(
|
589
|
+
_SHORT_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS
|
590
|
+
)
|
591
|
+
if extra_kv_cache_max_len > _LONG_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS:
|
592
|
+
extra_prefill_seq_lens.append(
|
593
|
+
_LONG_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS
|
594
|
+
)
|
529
595
|
else:
|
530
596
|
prefill_seq_lens = flags.FLAGS.prefill_seq_lens
|
531
597
|
kv_cache_max_len = flags.FLAGS.kv_cache_max_len
|
@@ -539,6 +605,10 @@ def build_and_convert_to_tflite_from_flags(
|
|
539
605
|
quantize=flags.FLAGS.quantize,
|
540
606
|
lora_ranks=flags.FLAGS.lora_ranks,
|
541
607
|
export_config=export_config_lib.get_from_flags(),
|
608
|
+
extra_model=extra_model,
|
609
|
+
extra_prefill_seq_lens=extra_prefill_seq_lens,
|
610
|
+
extra_kv_cache_max_len=extra_kv_cache_max_len,
|
611
|
+
extra_signature_prefix='test_' if extra_model is not None else '',
|
542
612
|
)
|
543
613
|
|
544
614
|
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.7.0.dev20251007
|
3
|
+
Version: 0.7.0.dev20251008
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=Jd2ZmbryaZTSc314Yj8KLDdZImrRPNAWsBVxJ18z8dk,806
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -208,7 +208,7 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=NkEwrjO8vIcd
|
|
208
208
|
ai_edge_torch/generative/test/test_quantize.py,sha256=kKJ01wscTC2t_Ylr7huO5gNKES01gm3dT1gx52z15PA,7356
|
209
209
|
ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
|
210
210
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
211
|
-
ai_edge_torch/generative/utilities/converter.py,sha256=
|
211
|
+
ai_edge_torch/generative/utilities/converter.py,sha256=d8pehTq6EzEdVR8ioL2b1ECGTR4G1K1fczc9amu_Oyk,23106
|
212
212
|
ai_edge_torch/generative/utilities/export_config.py,sha256=5B15nYyqf96kjjYlHfPctUfsIdsBsh1f8rxKitJpwKQ,2384
|
213
213
|
ai_edge_torch/generative/utilities/litertlm_builder.py,sha256=0cNuaqhc7cQcAa4NRalUXyoPQUQC9O3-aHAJEDV1Mps,4265
|
214
214
|
ai_edge_torch/generative/utilities/loader.py,sha256=drgKBmNibuc3PCdc0kU0pVcp2Nt1_mjLYh67RyXOn7U,15952
|
@@ -270,8 +270,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
|
|
270
270
|
ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
|
271
271
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
272
272
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
273
|
-
ai_edge_torch_nightly-0.7.0.dev20251007.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
274
|
-
ai_edge_torch_nightly-0.7.0.
|
275
|
-
ai_edge_torch_nightly-0.7.0.dev20251007.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
276
|
-
ai_edge_torch_nightly-0.7.0.dev20251007.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
277
|
-
ai_edge_torch_nightly-0.7.0.dev20251007.dist-info/RECORD,,
|
273
|
+
ai_edge_torch_nightly-0.7.0.dev20251008.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
274
|
+
ai_edge_torch_nightly-0.7.0.dev20251008.dist-info/METADATA,sha256=Xg4GCLMHL1FhKweyTtY2OAUdPGHFwpGlpZw-bUvs3FY,2074
|
275
|
+
ai_edge_torch_nightly-0.7.0.dev20251008.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
276
|
+
ai_edge_torch_nightly-0.7.0.dev20251008.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
277
|
+
ai_edge_torch_nightly-0.7.0.dev20251008.dist-info/RECORD,,
|
File without changes
|
File without changes
|