ai-edge-torch-nightly 0.7.0.dev20251006__py3-none-any.whl → 0.7.0.dev20251008__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -143,9 +143,23 @@ def define_conversion_flags(
143
143
  '`prefill_seq_lens` as the maximum of kv_cache size and prefill lengths '
144
144
  'in the graph.',
145
145
  )
146
+ flags.DEFINE_bool(
147
+ 'export_gpu_dynamic_shape_verifications',
148
+ False,
149
+ 'If true, the conversion script will export signatures used only for '
150
+ 'verification of GPU dynamic shapes.',
151
+ )
146
152
  return flags
147
153
 
148
154
 
155
+ # Context length for verifying GPU dynamic shapes.
156
+ _CONTEXT_LENGTH_TO_VERIFY_MAGIC_NUMBERS = 1280
157
+ # Long prefill length for verifying GPU dynamic shapes.
158
+ _LONG_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS = 1024
159
+ # Short prefill length for verifying GPU dynamic shapes.
160
+ _SHORT_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS = 64
161
+
162
+
149
163
  def is_magic_number_(num: int) -> bool:
150
164
  """Returns true if the number is a magic number, i.e. prime number > 10."""
151
165
  if num < 10:
@@ -263,6 +277,10 @@ def convert_to_tflite(
263
277
  config: cfg.ModelConfig = None,
264
278
  lora_ranks: Optional[list[int]] = None,
265
279
  export_config: ExportConfig = None,
280
+ extra_model: torch.nn.Module = None,
281
+ extra_prefill_seq_lens: list[int] = None,
282
+ extra_kv_cache_max_len: int = 0,
283
+ extra_signature_prefix: str = '',
266
284
  ):
267
285
  """Converts a nn.Module model to multi-signature tflite model.
268
286
 
@@ -315,6 +333,15 @@ def convert_to_tflite(
315
333
  no LoRA signatures will be added.
316
334
  export_config (ExportConfig, optional): The export configuration. If None,
317
335
  it uses the default export configuration.
336
+ extra_model (torch.nn.Module, optional): PyTorch model to export in
337
+ addition to the pytorch_model. This model can have different
338
+ prefill_seq_lens and kv_cache_max_len.
339
+ extra_prefill_seq_lens (list[int], optional): The prefill sequence
340
+ lengths for extra_model. Meaningful only when extra_model is not None.
341
+ extra_kv_cache_max_len (int, optional): The maximum size of KV cache
342
+ buffer for extra_model. Meaningful only when extra_model is not None.
343
+ extra_signature_prefix (str, optional): The prefix of the extra model
344
+ signatures. Meaningful only when extra_model is not None.
318
345
  """
319
346
  # pylint: disable=protected-access
320
347
  torch._dynamo.config.cache_size_limit = 64
@@ -353,32 +380,51 @@ def convert_to_tflite(
353
380
  )
354
381
  output_file = os.path.join(output_path, output_filename)
355
382
 
356
- _export_helper(
383
+ converter = converter_utils.Converter()
384
+ _add_signatures(
385
+ converter,
357
386
  pytorch_model,
358
- output_file,
359
387
  prefill_seq_lens,
360
388
  kv_cache_max_len,
361
389
  pixel_values_size,
362
390
  pixel_seq_len,
363
- quantize,
364
391
  config,
365
392
  loras,
366
393
  export_config,
367
394
  )
395
+
396
+ if extra_model is not None and extra_prefill_seq_lens:
397
+ _add_signatures(
398
+ converter,
399
+ extra_model,
400
+ extra_prefill_seq_lens,
401
+ extra_kv_cache_max_len,
402
+ pixel_values_size,
403
+ pixel_seq_len,
404
+ config,
405
+ loras,
406
+ export_config,
407
+ signature_prefix=extra_signature_prefix,
408
+ )
409
+
410
+ edge_model = converter.convert(
411
+ quant_config=get_quant_recipe_from_flag(quantize, config),
412
+ )
413
+ edge_model.export(output_file)
368
414
  return output_file
369
415
 
370
416
 
371
- def _export_helper(
417
+ def _add_signatures(
418
+ converter: converter_utils.Converter,
372
419
  pytorch_model: torch.nn.Module,
373
- output_file: str,
374
420
  prefill_seq_lens: list[int],
375
421
  kv_cache_max_len: int,
376
422
  pixel_values_size: torch.Size,
377
423
  pixel_seq_len: int,
378
- quantize: str,
379
424
  config: cfg.ModelConfig,
380
425
  loras: list[None | lora_utils.LoRA],
381
426
  export_config: ExportConfig,
427
+ signature_prefix: str = '',
382
428
  ):
383
429
  """Helper function to export a model to tflite."""
384
430
  prefill_tokens_list = []
@@ -423,17 +469,14 @@ def _export_helper(
423
469
  kv_layout=export_config.kvcache_layout,
424
470
  )
425
471
 
426
- quant_config = get_quant_recipe_from_flag(quantize, config)
427
-
428
472
  # For export, we create a module that captures any non-exportable,
429
473
  # arugments, e.g. the generation config object.
430
474
  mod = ExportableModule(pytorch_model, export_config=export_config).eval()
431
475
 
432
- converter = converter_utils.Converter()
433
476
  for lora in loras:
434
477
  for i in range(len(prefill_seq_lens)):
435
478
  prefill_seq_len = prefill_seq_lens[i]
436
- prefill_signature_name = f'prefill_{prefill_seq_len}'
479
+ prefill_signature_name = f'{signature_prefix}prefill_{prefill_seq_len}'
437
480
 
438
481
  sample_kwargs = {
439
482
  'tokens': prefill_tokens_list[i],
@@ -488,17 +531,15 @@ def _export_helper(
488
531
  if lora is not None:
489
532
  sample_kwargs['lora'] = lora
490
533
 
534
+ decode_signature_name = f'{signature_prefix}decode'
535
+ if lora is not None:
536
+ decode_signature_name += f'_lora_r{lora.get_rank()}'
491
537
  converter.add_signature(
492
- 'decode' if lora is None else f'decode_lora_r{lora.get_rank()}',
538
+ decode_signature_name,
493
539
  mod,
494
540
  sample_kwargs=sample_kwargs,
495
541
  )
496
542
 
497
- edge_model = converter.convert(
498
- quant_config=quant_config,
499
- )
500
- edge_model.export(output_file)
501
-
502
543
 
503
544
  def build_and_convert_to_tflite_from_flags(
504
545
  model_builder: Callable[
@@ -521,11 +562,36 @@ def build_and_convert_to_tflite_from_flags(
521
562
  get_mask_cache_size_from_flags(),
522
563
  )
523
564
 
565
+ # Extra model for GPU dynamic shape verification if needed.
566
+ extra_model = None
567
+ extra_prefill_seq_lens = None
568
+ extra_kv_cache_max_len = 0
524
569
  if flags.FLAGS.gpu_dynamic_shapes:
525
570
  prefill_seq_lens = [
526
571
  get_magic_number_for(l) for l in flags.FLAGS.prefill_seq_lens
527
572
  ]
528
573
  kv_cache_max_len = get_magic_number_for(flags.FLAGS.kv_cache_max_len)
574
+
575
+ if flags.FLAGS.export_gpu_dynamic_shape_verifications:
576
+ extra_kv_cache_max_len = _CONTEXT_LENGTH_TO_VERIFY_MAGIC_NUMBERS
577
+ if extra_kv_cache_max_len > flags.FLAGS.kv_cache_max_len:
578
+ extra_kv_cache_max_len = flags.FLAGS.kv_cache_max_len
579
+ extra_model = model_builder(
580
+ checkpoint_path,
581
+ loader.maybe_get_custom_loader(
582
+ checkpoint_path, flags.FLAGS.custom_checkpoint_loader
583
+ ),
584
+ extra_kv_cache_max_len,
585
+ )
586
+ extra_prefill_seq_lens = []
587
+ if extra_kv_cache_max_len > _SHORT_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS:
588
+ extra_prefill_seq_lens.append(
589
+ _SHORT_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS
590
+ )
591
+ if extra_kv_cache_max_len > _LONG_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS:
592
+ extra_prefill_seq_lens.append(
593
+ _LONG_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS
594
+ )
529
595
  else:
530
596
  prefill_seq_lens = flags.FLAGS.prefill_seq_lens
531
597
  kv_cache_max_len = flags.FLAGS.kv_cache_max_len
@@ -539,6 +605,10 @@ def build_and_convert_to_tflite_from_flags(
539
605
  quantize=flags.FLAGS.quantize,
540
606
  lora_ranks=flags.FLAGS.lora_ranks,
541
607
  export_config=export_config_lib.get_from_flags(),
608
+ extra_model=extra_model,
609
+ extra_prefill_seq_lens=extra_prefill_seq_lens,
610
+ extra_kv_cache_max_len=extra_kv_cache_max_len,
611
+ extra_signature_prefix='test_' if extra_model is not None else '',
542
612
  )
543
613
 
544
614
 
ai_edge_torch/version.py CHANGED
@@ -15,4 +15,4 @@
15
15
 
16
16
  # The next version of ai-edge-torch.
17
17
  # The minor version code should be bumped after every release.
18
- __version__ = "0.7.0.dev20251006"
18
+ __version__ = "0.7.0.dev20251008"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.7.0.dev20251006
3
+ Version: 0.7.0.dev20251008
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
5
- ai_edge_torch/version.py,sha256=NWXVGUPY6DtGiDOywYxBG91TAP5aYrDt8Ayzv2kwLhs,806
5
+ ai_edge_torch/version.py,sha256=Jd2ZmbryaZTSc314Yj8KLDdZImrRPNAWsBVxJ18z8dk,806
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -208,7 +208,7 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=NkEwrjO8vIcd
208
208
  ai_edge_torch/generative/test/test_quantize.py,sha256=kKJ01wscTC2t_Ylr7huO5gNKES01gm3dT1gx52z15PA,7356
209
209
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
210
210
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
211
- ai_edge_torch/generative/utilities/converter.py,sha256=Bt-48O1wf-7YcGVof53eVKI7wJwNvrc1Bv5zE3JuFdk,20093
211
+ ai_edge_torch/generative/utilities/converter.py,sha256=d8pehTq6EzEdVR8ioL2b1ECGTR4G1K1fczc9amu_Oyk,23106
212
212
  ai_edge_torch/generative/utilities/export_config.py,sha256=5B15nYyqf96kjjYlHfPctUfsIdsBsh1f8rxKitJpwKQ,2384
213
213
  ai_edge_torch/generative/utilities/litertlm_builder.py,sha256=0cNuaqhc7cQcAa4NRalUXyoPQUQC9O3-aHAJEDV1Mps,4265
214
214
  ai_edge_torch/generative/utilities/loader.py,sha256=drgKBmNibuc3PCdc0kU0pVcp2Nt1_mjLYh67RyXOn7U,15952
@@ -270,8 +270,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
270
270
  ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
271
271
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
272
272
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
273
- ai_edge_torch_nightly-0.7.0.dev20251006.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
274
- ai_edge_torch_nightly-0.7.0.dev20251006.dist-info/METADATA,sha256=9-6TbSzTZaYh69AVVx-R794sgQjC01l9C0jLlWmTNOg,2074
275
- ai_edge_torch_nightly-0.7.0.dev20251006.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
276
- ai_edge_torch_nightly-0.7.0.dev20251006.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
277
- ai_edge_torch_nightly-0.7.0.dev20251006.dist-info/RECORD,,
273
+ ai_edge_torch_nightly-0.7.0.dev20251008.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
274
+ ai_edge_torch_nightly-0.7.0.dev20251008.dist-info/METADATA,sha256=Xg4GCLMHL1FhKweyTtY2OAUdPGHFwpGlpZw-bUvs3FY,2074
275
+ ai_edge_torch_nightly-0.7.0.dev20251008.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
276
+ ai_edge_torch_nightly-0.7.0.dev20251008.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
277
+ ai_edge_torch_nightly-0.7.0.dev20251008.dist-info/RECORD,,