ai-edge-torch-nightly 0.5.0.dev20250508__py3-none-any.whl → 0.5.0.dev20250509__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -20,7 +20,9 @@ from ai_edge_torch.generative.examples.gemma import gemma2
20
20
  from ai_edge_torch.generative.utilities import converter
21
21
  from ai_edge_torch.generative.utilities import export_config
22
22
 
23
- flags = converter.define_conversion_flags("gemma2-2b")
23
+ flags = converter.define_conversion_flags(
24
+ "gemma2-2b", default_mask_as_input=True, default_transpose_kv_cache=True
25
+ )
24
26
 
25
27
 
26
28
  def main(_):
@@ -20,7 +20,9 @@ from ai_edge_torch.generative.examples.gemma3 import gemma3
20
20
  from ai_edge_torch.generative.utilities import converter
21
21
  from ai_edge_torch.generative.utilities import export_config
22
22
 
23
- flags = converter.define_conversion_flags('gemma3-1b')
23
+ flags = converter.define_conversion_flags(
24
+ 'gemma3-1b', default_mask_as_input=True, default_transpose_kv_cache=True
25
+ )
24
26
 
25
27
  _MODEL_SIZE = flags.DEFINE_string(
26
28
  'model_size',
@@ -261,7 +261,6 @@ class Decoder(nn.Module):
261
261
  pixel_mask = self.build_pixel_mask(image_indices)
262
262
  # RoPE parameters are the same for all blocks. Use the first layer.
263
263
  attn_config = self.config.block_config(0).attn_config
264
- n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
265
264
  # Different rotary base for global and local attention
266
265
  # based on attention pattern
267
266
  rope = [
@@ -305,7 +304,7 @@ class Decoder(nn.Module):
305
304
  if pixel_mask is None:
306
305
  mask = [
307
306
  self.get_local_global_attention_mask(
308
- mask,
307
+ mask[i] if isinstance(mask, list) else mask,
309
308
  self.config.block_config(i).attn_config.attn_type,
310
309
  input_pos,
311
310
  self.config.block_config(i).attn_config.sliding_window_size,
@@ -316,7 +315,7 @@ class Decoder(nn.Module):
316
315
  pixel_mask = pixel_mask.index_select(2, input_pos)
317
316
  mask = [
318
317
  self.compose_mask(
319
- mask[i],
318
+ mask[i] if isinstance(mask, list) else mask,
320
319
  pixel_mask,
321
320
  self.config.block_config(i).attn_config.attn_type,
322
321
  )
@@ -42,7 +42,11 @@ class ExportableModule(torch.nn.Module):
42
42
  return self.module(*export_args, **full_kwargs)
43
43
 
44
44
 
45
- def define_conversion_flags(model_name: str):
45
+ def define_conversion_flags(
46
+ model_name: str,
47
+ default_mask_as_input: bool = False,
48
+ default_transpose_kv_cache: bool = False,
49
+ ):
46
50
  """Defines common flags used for model conversion."""
47
51
 
48
52
  flags.DEFINE_string(
@@ -83,13 +87,13 @@ def define_conversion_flags(model_name: str):
83
87
  )
84
88
  flags.DEFINE_bool(
85
89
  'mask_as_input',
86
- False,
90
+ default_mask_as_input,
87
91
  'If true, the mask will be passed in as input. Otherwise, mask will be '
88
92
  'built by the model internally.',
89
93
  )
90
94
  flags.DEFINE_bool(
91
95
  'transpose_kv_cache',
92
- False,
96
+ default_transpose_kv_cache,
93
97
  'If true, the model will be converted with transposed KV cache.',
94
98
  )
95
99
  return flags
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.5.0.dev20250508"
16
+ __version__ = "0.5.0.dev20250509"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.5.0.dev20250508
3
+ Version: 0.5.0.dev20250509
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -22,6 +22,7 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
22
22
  Requires-Python: >=3.10
23
23
  Description-Content-Type: text/markdown
24
24
  License-File: LICENSE
25
+ Requires-Dist: absl-py
25
26
  Requires-Dist: numpy
26
27
  Requires-Dist: scipy
27
28
  Requires-Dist: safetensors
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
5
- ai_edge_torch/version.py,sha256=7lrbHHeWyBpqJdwFYYooOGJss4Rvg3UAdFSo9K0uzek,706
5
+ ai_edge_torch/version.py,sha256=bmd7zA2ryjEwRkVoQtjuMQTzdZ0gresufnMFls2eWNo,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -61,15 +61,15 @@ ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=yhS_i2kR0GJWpWciCt
61
61
  ai_edge_torch/generative/examples/deepseek/verify.py,sha256=iYldze-pvZGvPkkqr6zA7EmitPnH9sXkzjNVx353IcE,2403
62
62
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
63
63
  ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=RRilUl2Ui08R9gy1Ua0jnaXNCrIJJb-oztgP62G3mX4,1526
64
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=7IlF-4NEfZAzIfkOUHR-HeCSLSUGEu7wnO52UtERCa4,1527
64
+ ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=9ozSw2-xuf5Wfh1HeLDTP3wJxxUZmrD3An1njJPMpdI,1594
65
65
  ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=6ImjTzJcq6JoKz2Z-z8pjv5BsRu5nUeEsTK3IPs3xgI,3521
66
66
  ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=JQLLiHNVBM9jOrZqUF0EmgAwtDD0yTRlmIbLaWM7qTg,11557
67
67
  ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
68
68
  ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=jhiyinOqPt5ZZjEadDRZt_wY5fiLSCpMo54PcxFaL_Q,1789
69
69
  ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=n7f2nF6Lin_tDvPs0JVldsuaBzo7pAwi5YAHAhlIxQg,6139
70
70
  ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
71
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=MjkQDVynaw9C5z9ODzKfb85xW5JfxHUWBJ_Aco05FHo,1760
72
- ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=xGxeNKQvgyrENmUQMu0uKymL3qthvbdoxdMbAzwiLz0,15725
71
+ ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=JLXXn2mFEBs4DlHH_O6hpEG9KInJqsCdWy3DrgUjT1c,1827
72
+ ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=soxNVtN2fns9pQWw55ZND7dJ9RQLkFBtAteDSwZO9oY,15729
73
73
  ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=GACDBI_MsFowR8A3wAWrpzradPYe-AUgB9ZjXaVBG-s,6485
74
74
  ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
75
75
  ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
@@ -192,7 +192,7 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=vQWmpzMkJ2hP
192
192
  ai_edge_torch/generative/test/test_quantize.py,sha256=TG6vTF9yOZWe2wW7v8-hmuaQoODwJC1Z-2d5xv3zgfI,7389
193
193
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
194
194
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
195
- ai_edge_torch/generative/utilities/converter.py,sha256=d0JOWN5l2vbvt8RzFFiRoulkWiejyEZ21xKv5LdLIyc,11675
195
+ ai_edge_torch/generative/utilities/converter.py,sha256=u-FViLhHrbcO-GYfcGXTY28Kf-N682j-sao7LkdZbJ0,11806
196
196
  ai_edge_torch/generative/utilities/export_config.py,sha256=5IvR3grlMd4mWO5c_Y4x9Fk1b1xa57MzlYNE8XUaN28,2049
197
197
  ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
198
198
  ai_edge_torch/generative/utilities/model_builder.py,sha256=ZYX1TxpFdj573du2QCyHJlFjx4q1m12R74fp4Gwl92A,6343
@@ -251,8 +251,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
251
251
  ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
252
252
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
253
253
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
254
- ai_edge_torch_nightly-0.5.0.dev20250508.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
255
- ai_edge_torch_nightly-0.5.0.dev20250508.dist-info/METADATA,sha256=GGDJl2Fya8gLr9RIfSLCmm1K1xA3qzBrrEOy1hwR2dQ,2051
256
- ai_edge_torch_nightly-0.5.0.dev20250508.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
257
- ai_edge_torch_nightly-0.5.0.dev20250508.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
258
- ai_edge_torch_nightly-0.5.0.dev20250508.dist-info/RECORD,,
254
+ ai_edge_torch_nightly-0.5.0.dev20250509.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
255
+ ai_edge_torch_nightly-0.5.0.dev20250509.dist-info/METADATA,sha256=IrEO6k64_TTmFzYdwfbTfwPaNX-Lhcg3pwAV8sGw4Gk,2074
256
+ ai_edge_torch_nightly-0.5.0.dev20250509.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
257
+ ai_edge_torch_nightly-0.5.0.dev20250509.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
258
+ ai_edge_torch_nightly-0.5.0.dev20250509.dist-info/RECORD,,