ai-edge-torch-nightly 0.5.0.dev20250520__py3-none-any.whl → 0.6.0.dev20250521__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -350,10 +350,10 @@ def _export_helper(
350
350
  )
351
351
 
352
352
  prefill_masks = None
353
- if flags.FLAGS.mask_as_input:
353
+ if export_config.mask_as_input:
354
354
  prefill_masks = _build_mask(
355
- flags.FLAGS.prefill_seq_lens,
356
- flags.FLAGS.kv_cache_max_len,
355
+ prefill_seq_lens,
356
+ config.kv_cache_max_len,
357
357
  config.causal_mask_value,
358
358
  )
359
359
  if not isinstance(prefill_masks, list):
@@ -424,7 +424,7 @@ def _export_helper(
424
424
  'input_pos': decode_input_pos,
425
425
  'kv_cache': decode_kv,
426
426
  }
427
- if flags.FLAGS.mask_as_input:
427
+ if export_config.mask_as_input:
428
428
  # Note that the decode mask is not a correct causal mask, but it is okay
429
429
  # for the conversion purpose because only the shape matters in conversion.
430
430
  # A correct causal mask of decode for a given token position of decode, it
@@ -433,7 +433,7 @@ def _export_helper(
433
433
  # torch.triu(mask, diagonal=decode_position).unsqueeze(0).unsqueeze(0)
434
434
  #
435
435
  sample_kwargs['mask'] = _build_mask(
436
- 1, flags.FLAGS.kv_cache_max_len, config.causal_mask_value
436
+ 1, config.kv_cache_max_len, config.causal_mask_value
437
437
  )
438
438
  if lora is not None:
439
439
  sample_kwargs['lora'] = lora
@@ -43,6 +43,9 @@ class ExportConfig:
43
43
  kvcache_cls: type = kv_utils.KVCache
44
44
  # The batch size of the decode signature.
45
45
  decode_batch_size: int = 1
46
+ # If true, the mask will be passed in as input. Otherwise, mask will be
47
+ # built by the model internally.
48
+ mask_as_input: bool = False
46
49
 
47
50
 
48
51
  def get_from_flags() -> ExportConfig:
@@ -51,5 +54,7 @@ def get_from_flags() -> ExportConfig:
51
54
 
52
55
  if flags.FLAGS.transpose_kv_cache:
53
56
  export_config.kvcache_layout = kv_utils.KV_LAYOUT_TRANSPOSED
57
+ if flags.FLAGS.mask_as_input:
58
+ export_config.mask_as_input = flags.FLAGS.mask_as_input
54
59
 
55
60
  return export_config
@@ -19,8 +19,8 @@ import os
19
19
  from typing import Callable, Dict, List, Optional, Tuple
20
20
 
21
21
  from ai_edge_torch.generative.layers import model_config
22
- import safetensors
23
22
  from safetensors import safe_open
23
+ from safetensors.torch import load_file
24
24
  import torch
25
25
 
26
26
 
@@ -47,7 +47,7 @@ def get_custom_loader(
47
47
 
48
48
  if checkpoint_format:
49
49
  if checkpoint_format == "safetensors":
50
- return safetensors.torch.load_file
50
+ return load_file
51
51
  if checkpoint_format == "pt":
52
52
  return lambda path: torch.load(path, weights_only=True)
53
53
  raise ValueError(f"Unsupported checkpoint format: {checkpoint_format}")
@@ -55,7 +55,7 @@ def get_custom_loader(
55
55
  if os.path.splitext(checkpoint_path)[1] in [".bin", ".pt", ".ckpt"]:
56
56
  return lambda path: torch.load(path, weights_only=True)
57
57
  if checkpoint_path.endswith(".safetensors"):
58
- return safetensors.torch.load_file
58
+ return load_file
59
59
  raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")
60
60
 
61
61
 
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.5.0.dev20250520"
16
+ __version__ = "0.6.0.dev20250521"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.5.0.dev20250520
3
+ Version: 0.6.0.dev20250521
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
5
- ai_edge_torch/version.py,sha256=4yV1q9jK9Zr0i0SQM4PpfioywnIovZWfNODUFlxFS-I,706
5
+ ai_edge_torch/version.py,sha256=lmyCstaeVZjTAbBP4s9Z02tpX00ynyLPsymBY2tCe4A,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -205,9 +205,9 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=vQWmpzMkJ2hP
205
205
  ai_edge_torch/generative/test/test_quantize.py,sha256=kKJ01wscTC2t_Ylr7huO5gNKES01gm3dT1gx52z15PA,7356
206
206
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
207
207
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
208
- ai_edge_torch/generative/utilities/converter.py,sha256=VRI960xo86g6lGLc_II3vDovFMa2DGIxnAZgE2GfSiM,15530
209
- ai_edge_torch/generative/utilities/export_config.py,sha256=5IvR3grlMd4mWO5c_Y4x9Fk1b1xa57MzlYNE8XUaN28,2049
210
- ai_edge_torch/generative/utilities/loader.py,sha256=ODAdOnwQXscVPiUM6ssFWqDtD-Hl-h814X1EH1c0tuw,15969
208
+ ai_edge_torch/generative/utilities/converter.py,sha256=DuoPb8Uhbxa32uUvr6grV5lssmUJdx298QwYz8cG_1Y,15512
209
+ ai_edge_torch/generative/utilities/export_config.py,sha256=qjkEbjcvi2AgQikZS5qfgR95Z5z9pm07KX-RN5ibfNE,2280
210
+ ai_edge_torch/generative/utilities/loader.py,sha256=oGgEc2tHRsVqSN3mgvcngXQrpV0a7cwTpJ3LmMVnyF0,15954
211
211
  ai_edge_torch/generative/utilities/model_builder.py,sha256=tBfOcsI_NcneggHqkCSydYN3ZgmkzPc6nW0AJrA81wI,6461
212
212
  ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
213
213
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
@@ -264,8 +264,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
264
264
  ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
265
265
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
266
266
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
267
- ai_edge_torch_nightly-0.5.0.dev20250520.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
268
- ai_edge_torch_nightly-0.5.0.dev20250520.dist-info/METADATA,sha256=P4YDKSZOCPj-hx7bnU6EWLvigx-dpgHd_cIARQd4Fss,2074
269
- ai_edge_torch_nightly-0.5.0.dev20250520.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
270
- ai_edge_torch_nightly-0.5.0.dev20250520.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
271
- ai_edge_torch_nightly-0.5.0.dev20250520.dist-info/RECORD,,
267
+ ai_edge_torch_nightly-0.6.0.dev20250521.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
268
+ ai_edge_torch_nightly-0.6.0.dev20250521.dist-info/METADATA,sha256=_UC8q7Xe3xMUCwKKbF4CJ5hewK9PLIJ26ksKCAeWjik,2074
269
+ ai_edge_torch_nightly-0.6.0.dev20250521.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
270
+ ai_edge_torch_nightly-0.6.0.dev20250521.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
271
+ ai_edge_torch_nightly-0.6.0.dev20250521.dist-info/RECORD,,