ai-edge-torch-nightly 0.3.0.dev20250218__py3-none-any.whl → 0.4.0.dev20250220__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -29,39 +29,48 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/amd-llama-135m'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
-    '/tmp/',
-    'The tflite file path to export.',
-)
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
-    'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
-)
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
     1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
+    '/tmp/',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'deepseek',
+    'The prefix of the output tflite model name.',
+)
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
+)
 _QUANTIZE = flags.DEFINE_bool(
     'quantize',
     True,
     'Whether the model should be quantized.',
 )
-
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 def main(_):
   pytorch_model = amd_llama_135m.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'amd-llama-135m_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
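The refactor above replaces the single tflite_path/prefill_seq_len pair with output_path plus output_name_prefix, a multi-valued prefill_seq_lens flag, and an optional lora_ranks flag, and drops the hand-built output filename in favor of passing those values straight to converter.convert_to_tflite. For readers unfamiliar with absl's multi-valued flags, here is a minimal standalone sketch of how they parse: only the flag names, defaults, and help strings are taken from the diff above; the script itself is illustrative and not part of this package.

# demo_flags.py -- standalone sketch, not part of ai-edge-torch.
from absl import app
from absl import flags

_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
    'prefill_seq_lens',
    (8, 64, 128, 256, 512, 1024),
    'List of the maximum sizes of prefill input tensors.',
)
_LORA_RANKS = flags.DEFINE_multi_integer(
    'lora_ranks',
    None,
    'If set, the model will be converted with the provided list of LoRA ranks.',
)


def main(_):
  # Each repeated occurrence appends one value, e.g.
  #   python demo_flags.py --prefill_seq_lens=256 --prefill_seq_lens=1024
  # yields [256, 1024]; with no occurrences the default list above is used.
  print('prefill_seq_lens:', _PREFILL_SEQ_LENS.value)
  # lora_ranks stays None unless --lora_ranks is passed at least once.
  print('lora_ranks:', _LORA_RANKS.value)


if __name__ == '__main__':
  app.run(main)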
 
@@ -51,7 +51,7 @@ def main(_):
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = amd_llama_135m.build_model(reauthored_checkpoint)
+  reauthored_model = amd_llama_135m.build_model(str(reauthored_checkpoint))
 
   logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
@@ -49,7 +49,7 @@ def main(_):
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = deepseek.build_model(reauthored_checkpoint)
+  reauthored_model = deepseek.build_model(str(reauthored_checkpoint))
 
   logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
@@ -50,7 +50,7 @@ def main(_):
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = openelm.build_model(reauthored_checkpoint)
+  reauthored_model = openelm.build_model(str(reauthored_checkpoint))
 
   tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
   logging.info("Loading the tokenizer from: %s", tokenizer_checkpoint)
@@ -70,7 +70,7 @@ def main(_):
     cached_config_file = transformers.utils.cached_file(
         checkpoint, transformers.utils.CONFIG_NAME
     )
-    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+    reauthored_checkpoint = str(pathlib.Path(cached_config_file).parent)
   else:
     checkpoint = kagglehub.model_download(_CHECKPOINT[_VERSION.value])
     reauthored_checkpoint = checkpoint
@@ -67,7 +67,7 @@ def main(_):
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = qwen_vl.build_model(reauthored_checkpoint)
+  reauthored_model = qwen_vl.build_model(str(reauthored_checkpoint))
 
   logging.info("Loading the processor from: %s", checkpoint)
   processor = transformers.AutoProcessor.from_pretrained(checkpoint)
@@ -51,7 +51,7 @@ def main(_):
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = tiny_llama.build_model(reauthored_checkpoint)
+  reauthored_model = tiny_llama.build_model(str(reauthored_checkpoint))
 
   logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
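Each of the verify.py hunks above makes the same small change: the pathlib.Path pointing at the cached checkpoint directory is wrapped in str() before being handed to build_model. Here is a minimal illustration, independent of these scripts, of why that conversion can matter when a callee assumes a plain string path; load_config is a hypothetical stand-in for the loader, not the real API.

import pathlib


def load_config(checkpoint_path: str) -> str:
  # Hypothetical stand-in for a loader that assumes a plain str and applies
  # str-only operations to the path it receives.
  return checkpoint_path.rstrip('/') + '/config.json'


reauthored_checkpoint = pathlib.Path('/tmp/hf_cache/model_dir')
# load_config(reauthored_checkpoint) would raise AttributeError, because
# pathlib.Path objects have no rstrip(); converting to str first avoids that.
print(load_config(str(reauthored_checkpoint)))  # /tmp/hf_cache/model_dir/config.json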
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20250218"
+__version__ = "0.4.0.dev20250220"
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20250218
+Version: 0.4.0.dev20250220
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=88WhbJBrSg611BztpwT9yaLRJGmwzdTMLnhbe5bfiqs,706
+ai_edge_torch/version.py,sha256=ilOEeUufy7PWKexgXSKLjrAG_xeF7RhNEx-8y-4eSYQ,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -47,12 +47,12 @@ ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf
 ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/amd_llama_135m/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=urNif89PyCXbdXT5spOeDvdM5luJ-a5HaXHM86v4JnU,2766
-ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=Oqlg5ZoUuG2aU3067QaPpmEXWOdB8GEq7u_NWoBpoB4,2337
-ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=-9Nb9D818YSJR3olVtBwoLNeMMD5qE58YBnsA67hlHg,2421
+ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=K1cRwdxJWhZ4g97GwI_HwAwU5m5TTIEsjMuGtAlAen8,2563
+ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=o13NkFlBgawBsjdJup05VMUjAPvDRAmig6VyEkX8q6U,2426
 ai_edge_torch/generative/examples/deepseek/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
 ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=I5eA-XfFdHjYwDsLIjn23T2e-IgnSCQ129-5DOU8j44,2532
 ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=AOAJ7ltXwY5IbmcCP2nVHW9FmRwexzfNxnoDlR-sW9c,2885
-ai_edge_torch/generative/examples/deepseek/verify.py,sha256=sDYBhmE_CeZw5iLIQ7rJNGLjhcTyKUQGdg7_QQBh9WM,2398
+ai_edge_torch/generative/examples/deepseek/verify.py,sha256=iYldze-pvZGvPkkqr6zA7EmitPnH9sXkzjNVx353IcE,2403
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=8HJi0cutxPstafVNs2LfBKdUzufVucje1Vrfjw_RS_g,2527
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=MX8fZhJJPZ5IoMiNHX0tLkRpHYqVuh4qhW0rkeIfmYw,2529
@@ -71,14 +71,14 @@ ai_edge_torch/generative/examples/moonshine/moonshine.py,sha256=nZ2b8u4TmsB5sgdC
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=pyxRGgMxrnzBvYHkW0F9mgAabM19b-FDT6PT6j_-n2U,2528
 ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2ukiJbQOTIUGuMEZvmwZbt3n0,4556
-ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
+ai_edge_torch/generative/examples/openelm/verify.py,sha256=4W26ZtPF5Cb9mpHYuRM4b2QB_4W76zf4WV36KzexVjs,2446
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=APQymtr3n2k6-e8wvn3kVrli0qiElduYIkHeahcoSA0,2743
 ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=z658dW_D0Iqvo6xnh4vG7_o17-Fufndyis8Rq5yafJY,5439
 ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=GZa0Ou_DvOijB2nTL_jRvGbn0_dvJPosQAPf47yqicw,5988
 ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=SvuR97sjkBtfkerH7Hu1UXB8kCFLpEATNbPfCbNAyfo,5614
 ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=x1mgRtVLxkCTvlkPow3y7ADoGTjUh5uc5pF46mxatLw,6099
-ai_edge_torch/generative/examples/paligemma/verify.py,sha256=HLcu1fWMtFFFONAqVW94rOBqq4XvFHtatX3JFGOsfZw,5345
+ai_edge_torch/generative/examples/paligemma/verify.py,sha256=zrCNz_QSQU6BbaFtx-J-MqxXWcNlsAlquaHpKodsyW4,5350
 ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
 ai_edge_torch/generative/examples/paligemma/verify_decoder2.py,sha256=tm-UfLr0YeBRVcQsWLBOMWI9JUzHmtPEbYK2vpITpqY,2534
 ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=vNm-wTT8BD6zbX6GocfP1QrVoHl0zSvuVxoXN36eeiU,3540
@@ -98,7 +98,7 @@ ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=MXK75-Upoq
 ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=0x4iDg2cBe3PFnjVce3nj7g2rjagGHcKqRCfbASNxA8,4402
 ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=nHzBe_YSPnUe1d5i09v4bePQomVifzJNeUjRfprmxC0,14878
 ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py,sha256=rcYHkpO-NbF4F1Da7q2xNiTng9NHiLx59HyuOgQX5W0,7753
-ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=cKinMEDXauR5yKxtNTQk1RvwIHUG8-FOkmAie18sukY,5039
+ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=JUwHoC_zvcC3RC3wZ3e3e6fGmrhbgdoztjK8HGSUG8I,5044
 ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=xPWoOBLh2eK12KEhELLYymfL7xvc0chmYC98c6x37oo,2602
 ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=PZ392nDoJG2OmHZ_7Jet3Zu1JkN6QErxKcDc7a-PPds,3126
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -132,7 +132,7 @@ ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=VU0c5pgvrUtaTboT1xuDBGjpKOM85aqtaB_hYfSBuEk,2544
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
-ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299fpQNXYAudBbZoDQp9934xcvg50,2426
+ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=LRu6PSw7Lqu6HGbv1tO2i0nUCqe-VkRgboA10VZ7KNg,2431
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=4rFrppMRKlTwwZeX1ON_cdp4yUqoTOES161IZQkJF6c,1143
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -230,8 +230,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20250218.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20250218.dist-info/METADATA,sha256=1ya0dobgUlh9c8RakbcYUWWzM3enRRN9GXiQzE7XM1A,1966
-ai_edge_torch_nightly-0.3.0.dev20250218.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.3.0.dev20250218.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20250218.dist-info/RECORD,,
+ai_edge_torch_nightly-0.4.0.dev20250220.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.4.0.dev20250220.dist-info/METADATA,sha256=HZAorXR_re4XSg6gaflqg3YRkdfZcuq7Yxh9K1f1nc4,1966
+ai_edge_torch_nightly-0.4.0.dev20250220.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.4.0.dev20250220.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.4.0.dev20250220.dist-info/RECORD,,