ai-edge-torch-nightly 0.5.0.dev20250508__py3-none-any.whl → 0.5.0.dev20250509__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +3 -1
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +3 -1
- ai_edge_torch/generative/examples/gemma3/decoder.py +2 -3
- ai_edge_torch/generative/utilities/converter.py +7 -3
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250508.dist-info → ai_edge_torch_nightly-0.5.0.dev20250509.dist-info}/METADATA +2 -1
- {ai_edge_torch_nightly-0.5.0.dev20250508.dist-info → ai_edge_torch_nightly-0.5.0.dev20250509.dist-info}/RECORD +10 -10
- {ai_edge_torch_nightly-0.5.0.dev20250508.dist-info → ai_edge_torch_nightly-0.5.0.dev20250509.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250508.dist-info → ai_edge_torch_nightly-0.5.0.dev20250509.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250508.dist-info → ai_edge_torch_nightly-0.5.0.dev20250509.dist-info}/top_level.txt +0 -0
@@ -20,7 +20,9 @@ from ai_edge_torch.generative.examples.gemma import gemma2
|
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
21
21
|
from ai_edge_torch.generative.utilities import export_config
|
22
22
|
|
23
|
-
flags = converter.define_conversion_flags(
|
23
|
+
flags = converter.define_conversion_flags(
|
24
|
+
"gemma2-2b", default_mask_as_input=True, default_transpose_kv_cache=True
|
25
|
+
)
|
24
26
|
|
25
27
|
|
26
28
|
def main(_):
|
@@ -20,7 +20,9 @@ from ai_edge_torch.generative.examples.gemma3 import gemma3
|
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
21
21
|
from ai_edge_torch.generative.utilities import export_config
|
22
22
|
|
23
|
-
flags = converter.define_conversion_flags(
|
23
|
+
flags = converter.define_conversion_flags(
|
24
|
+
'gemma3-1b', default_mask_as_input=True, default_transpose_kv_cache=True
|
25
|
+
)
|
24
26
|
|
25
27
|
_MODEL_SIZE = flags.DEFINE_string(
|
26
28
|
'model_size',
|
@@ -261,7 +261,6 @@ class Decoder(nn.Module):
|
|
261
261
|
pixel_mask = self.build_pixel_mask(image_indices)
|
262
262
|
# RoPE parameters are the same for all blocks. Use the first layer.
|
263
263
|
attn_config = self.config.block_config(0).attn_config
|
264
|
-
n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
|
265
264
|
# Different rotary base for global and local attention
|
266
265
|
# based on attention pattern
|
267
266
|
rope = [
|
@@ -305,7 +304,7 @@ class Decoder(nn.Module):
|
|
305
304
|
if pixel_mask is None:
|
306
305
|
mask = [
|
307
306
|
self.get_local_global_attention_mask(
|
308
|
-
mask,
|
307
|
+
mask[i] if isinstance(mask, list) else mask,
|
309
308
|
self.config.block_config(i).attn_config.attn_type,
|
310
309
|
input_pos,
|
311
310
|
self.config.block_config(i).attn_config.sliding_window_size,
|
@@ -316,7 +315,7 @@ class Decoder(nn.Module):
|
|
316
315
|
pixel_mask = pixel_mask.index_select(2, input_pos)
|
317
316
|
mask = [
|
318
317
|
self.compose_mask(
|
319
|
-
mask[i],
|
318
|
+
mask[i] if isinstance(mask, list) else mask,
|
320
319
|
pixel_mask,
|
321
320
|
self.config.block_config(i).attn_config.attn_type,
|
322
321
|
)
|
@@ -42,7 +42,11 @@ class ExportableModule(torch.nn.Module):
|
|
42
42
|
return self.module(*export_args, **full_kwargs)
|
43
43
|
|
44
44
|
|
45
|
-
def define_conversion_flags(
|
45
|
+
def define_conversion_flags(
|
46
|
+
model_name: str,
|
47
|
+
default_mask_as_input: bool = False,
|
48
|
+
default_transpose_kv_cache: bool = False,
|
49
|
+
):
|
46
50
|
"""Defines common flags used for model conversion."""
|
47
51
|
|
48
52
|
flags.DEFINE_string(
|
@@ -83,13 +87,13 @@ def define_conversion_flags(model_name: str):
|
|
83
87
|
)
|
84
88
|
flags.DEFINE_bool(
|
85
89
|
'mask_as_input',
|
86
|
-
|
90
|
+
default_mask_as_input,
|
87
91
|
'If true, the mask will be passed in as input. Otherwise, mask will be '
|
88
92
|
'built by the model internally.',
|
89
93
|
)
|
90
94
|
flags.DEFINE_bool(
|
91
95
|
'transpose_kv_cache',
|
92
|
-
|
96
|
+
default_transpose_kv_cache,
|
93
97
|
'If true, the model will be converted with transposed KV cache.',
|
94
98
|
)
|
95
99
|
return flags
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.5.0.
|
3
|
+
Version: 0.5.0.dev20250509
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -22,6 +22,7 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
22
22
|
Requires-Python: >=3.10
|
23
23
|
Description-Content-Type: text/markdown
|
24
24
|
License-File: LICENSE
|
25
|
+
Requires-Dist: absl-py
|
25
26
|
Requires-Dist: numpy
|
26
27
|
Requires-Dist: scipy
|
27
28
|
Requires-Dist: safetensors
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=bmd7zA2ryjEwRkVoQtjuMQTzdZ0gresufnMFls2eWNo,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -61,15 +61,15 @@ ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=yhS_i2kR0GJWpWciCt
|
|
61
61
|
ai_edge_torch/generative/examples/deepseek/verify.py,sha256=iYldze-pvZGvPkkqr6zA7EmitPnH9sXkzjNVx353IcE,2403
|
62
62
|
ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
63
63
|
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=RRilUl2Ui08R9gy1Ua0jnaXNCrIJJb-oztgP62G3mX4,1526
|
64
|
-
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=
|
64
|
+
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=9ozSw2-xuf5Wfh1HeLDTP3wJxxUZmrD3An1njJPMpdI,1594
|
65
65
|
ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=6ImjTzJcq6JoKz2Z-z8pjv5BsRu5nUeEsTK3IPs3xgI,3521
|
66
66
|
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=JQLLiHNVBM9jOrZqUF0EmgAwtDD0yTRlmIbLaWM7qTg,11557
|
67
67
|
ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
|
68
68
|
ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=jhiyinOqPt5ZZjEadDRZt_wY5fiLSCpMo54PcxFaL_Q,1789
|
69
69
|
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=n7f2nF6Lin_tDvPs0JVldsuaBzo7pAwi5YAHAhlIxQg,6139
|
70
70
|
ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
71
|
-
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=
|
72
|
-
ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=
|
71
|
+
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=JLXXn2mFEBs4DlHH_O6hpEG9KInJqsCdWy3DrgUjT1c,1827
|
72
|
+
ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=soxNVtN2fns9pQWw55ZND7dJ9RQLkFBtAteDSwZO9oY,15729
|
73
73
|
ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=GACDBI_MsFowR8A3wAWrpzradPYe-AUgB9ZjXaVBG-s,6485
|
74
74
|
ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
|
75
75
|
ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
|
@@ -192,7 +192,7 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=vQWmpzMkJ2hP
|
|
192
192
|
ai_edge_torch/generative/test/test_quantize.py,sha256=TG6vTF9yOZWe2wW7v8-hmuaQoODwJC1Z-2d5xv3zgfI,7389
|
193
193
|
ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
|
194
194
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
195
|
-
ai_edge_torch/generative/utilities/converter.py,sha256=
|
195
|
+
ai_edge_torch/generative/utilities/converter.py,sha256=u-FViLhHrbcO-GYfcGXTY28Kf-N682j-sao7LkdZbJ0,11806
|
196
196
|
ai_edge_torch/generative/utilities/export_config.py,sha256=5IvR3grlMd4mWO5c_Y4x9Fk1b1xa57MzlYNE8XUaN28,2049
|
197
197
|
ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
|
198
198
|
ai_edge_torch/generative/utilities/model_builder.py,sha256=ZYX1TxpFdj573du2QCyHJlFjx4q1m12R74fp4Gwl92A,6343
|
@@ -251,8 +251,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
|
|
251
251
|
ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
|
252
252
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
253
253
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
254
|
-
ai_edge_torch_nightly-0.5.0.
|
255
|
-
ai_edge_torch_nightly-0.5.0.
|
256
|
-
ai_edge_torch_nightly-0.5.0.
|
257
|
-
ai_edge_torch_nightly-0.5.0.
|
258
|
-
ai_edge_torch_nightly-0.5.0.
|
254
|
+
ai_edge_torch_nightly-0.5.0.dev20250509.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
255
|
+
ai_edge_torch_nightly-0.5.0.dev20250509.dist-info/METADATA,sha256=IrEO6k64_TTmFzYdwfbTfwPaNX-Lhcg3pwAV8sGw4Gk,2074
|
256
|
+
ai_edge_torch_nightly-0.5.0.dev20250509.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
257
|
+
ai_edge_torch_nightly-0.5.0.dev20250509.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
258
|
+
ai_edge_torch_nightly-0.5.0.dev20250509.dist-info/RECORD,,
|
File without changes
|
File without changes
|