ai-edge-torch-nightly 0.5.0.dev20250416__py3-none-any.whl → 0.5.0.dev20250418__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma3/gemma3.py +1 -1
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +1 -0
- ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +7 -0
- ai_edge_torch/generative/utilities/converter.py +28 -14
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250416.dist-info → ai_edge_torch_nightly-0.5.0.dev20250418.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250416.dist-info → ai_edge_torch_nightly-0.5.0.dev20250418.dist-info}/RECORD +10 -10
- {ai_edge_torch_nightly-0.5.0.dev20250416.dist-info → ai_edge_torch_nightly-0.5.0.dev20250418.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250416.dist-info → ai_edge_torch_nightly-0.5.0.dev20250418.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250416.dist-info → ai_edge_torch_nightly-0.5.0.dev20250418.dist-info}/top_level.txt +0 -0
@@ -154,7 +154,7 @@ class Gemma3MM(nn.Module):
|
|
154
154
|
def get_fake_model_config(**kwargs) -> Gemma3MMConfig:
|
155
155
|
return Gemma3MMConfig(
|
156
156
|
image_encoder_config=image_encoder.get_fake_image_encoder_config(),
|
157
|
-
decoder_config=decoder.
|
157
|
+
decoder_config=decoder.get_fake_decoder_config_1b(**kwargs),
|
158
158
|
image_token_id=127,
|
159
159
|
image_projection_scale=128**0.5,
|
160
160
|
image_projection_use_bias=False,
|
@@ -48,6 +48,7 @@ def main(_):
|
|
48
48
|
pixel_values_size=torch.Size(
|
49
49
|
[1, config.channels, config.image_size, config.image_size]
|
50
50
|
),
|
51
|
+
pixel_seq_len=(config.image_size // config.patch_size) ** 2,
|
51
52
|
quantize=flags.FLAGS.quantize,
|
52
53
|
config=pytorch_model.config.decoder_config,
|
53
54
|
export_config=ExportConfig(),
|
@@ -43,6 +43,9 @@ def main(_):
|
|
43
43
|
)
|
44
44
|
|
45
45
|
grid_thw = pytorch_model.image_encoder.get_grid_thw()
|
46
|
+
spatial_merge_size = (
|
47
|
+
pytorch_model.config.image_encoder_config.spatial_merge_size
|
48
|
+
)
|
46
49
|
converter.convert_to_tflite(
|
47
50
|
pytorch_model,
|
48
51
|
output_path=flags.FLAGS.output_path,
|
@@ -51,6 +54,10 @@ def main(_):
|
|
51
54
|
pixel_values_size=(
|
52
55
|
pytorch_model.image_encoder.get_pixel_values_size(grid_thw)
|
53
56
|
),
|
57
|
+
pixel_seq_len=(
|
58
|
+
(grid_thw[0][1] // spatial_merge_size)
|
59
|
+
* (grid_thw[0][2] // spatial_merge_size)
|
60
|
+
),
|
54
61
|
quantize=flags.FLAGS.quantize,
|
55
62
|
config=pytorch_model.config.decoder_config,
|
56
63
|
export_config=ExportConfig(),
|
@@ -57,7 +57,7 @@ def define_conversion_flags(model_name: str):
|
|
57
57
|
)
|
58
58
|
flags.DEFINE_string(
|
59
59
|
'output_name_prefix',
|
60
|
-
|
60
|
+
model_name,
|
61
61
|
'The prefix of the output tflite model name.',
|
62
62
|
)
|
63
63
|
flags.DEFINE_multi_integer(
|
@@ -91,6 +91,7 @@ def convert_to_tflite(
|
|
91
91
|
output_name_prefix: str,
|
92
92
|
prefill_seq_len: Union[int, list[int]],
|
93
93
|
pixel_values_size: torch.Size = None,
|
94
|
+
pixel_seq_len: int = 0,
|
94
95
|
quantize: bool = True,
|
95
96
|
config: cfg.ModelConfig = None,
|
96
97
|
lora_ranks: Optional[list[int]] = None,
|
@@ -133,12 +134,18 @@ def convert_to_tflite(
|
|
133
134
|
use. If a list, the model will have multiple prefill signatures.
|
134
135
|
pixel_values_size (torch.Size, optional): The size of pixel values to pass
|
135
136
|
to the model. If None, the model is not expected to take pixel values.
|
137
|
+
pixel_seq_len (int, optional): The length of pixel tokens, or pixel
|
138
|
+
embeddings generated by the image encoder with pixel values. The actual
|
139
|
+
prefill length is prefill_seq_len plus pixel_seq_len when pixel
|
140
|
+
values are passed.
|
136
141
|
quantize (bool, optional): Whether the model should be quantized. Defaults
|
137
142
|
to True.
|
138
143
|
config (cfg.ModelConfig, optional): The model config used to configure KV
|
139
144
|
cache. If None, it uses the config of the pytorch_model.
|
140
145
|
lora_ranks (list[int], optional): The ranks of the LORA layers. If None,
|
141
146
|
no LoRA signatures will be added.
|
147
|
+
export_config (ExportConfig, optional): The export configuration. If None,
|
148
|
+
it uses the default export configuration.
|
142
149
|
"""
|
143
150
|
# pylint: disable=protected-access
|
144
151
|
torch._dynamo.config.cache_size_limit = 64
|
@@ -173,6 +180,7 @@ def convert_to_tflite(
|
|
173
180
|
output_file,
|
174
181
|
prefill_seq_lens,
|
175
182
|
pixel_values_size,
|
183
|
+
pixel_seq_len,
|
176
184
|
quantize,
|
177
185
|
config,
|
178
186
|
loras,
|
@@ -185,6 +193,7 @@ def _export_helper(
|
|
185
193
|
output_file: str,
|
186
194
|
prefill_seq_lens: list[int],
|
187
195
|
pixel_values_size: torch.Size,
|
196
|
+
pixel_seq_len: int,
|
188
197
|
quantize: bool,
|
189
198
|
config: cfg.ModelConfig,
|
190
199
|
loras: list[None | lora_utils.LoRA],
|
@@ -197,11 +206,18 @@ def _export_helper(
|
|
197
206
|
prefill_tokens_list.append(torch.full((1, seq_len), 0, dtype=torch.int))
|
198
207
|
prefill_input_pos_list.append(torch.arange(0, seq_len, dtype=torch.int))
|
199
208
|
|
200
|
-
prefill_pixel_values =
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
209
|
+
prefill_pixel_values = None
|
210
|
+
prefill_tokens_list_with_pixel = []
|
211
|
+
prefill_input_pos_list_with_pixel = []
|
212
|
+
if pixel_values_size is not None:
|
213
|
+
prefill_pixel_values = torch.full(pixel_values_size, 0, dtype=torch.float32)
|
214
|
+
for seq_len in prefill_seq_lens:
|
215
|
+
prefill_tokens_list_with_pixel.append(
|
216
|
+
torch.full((1, seq_len + pixel_seq_len), 0, dtype=torch.int)
|
217
|
+
)
|
218
|
+
prefill_input_pos_list_with_pixel.append(
|
219
|
+
torch.arange(0, seq_len + pixel_seq_len, dtype=torch.int)
|
220
|
+
)
|
205
221
|
|
206
222
|
if export_config.prefill_mask is None:
|
207
223
|
prefill_masks = None
|
@@ -238,13 +254,11 @@ def _export_helper(
|
|
238
254
|
for lora in loras:
|
239
255
|
for i in range(len(prefill_seq_lens)):
|
240
256
|
prefill_seq_len = prefill_seq_lens[i]
|
241
|
-
prefill_tokens = prefill_tokens_list[i]
|
242
|
-
prefill_input_pos = prefill_input_pos_list[i]
|
243
257
|
prefill_signature_name = f'prefill_{prefill_seq_len}'
|
244
258
|
|
245
259
|
sample_kwargs = {
|
246
|
-
'tokens':
|
247
|
-
'input_pos':
|
260
|
+
'tokens': prefill_tokens_list[i],
|
261
|
+
'input_pos': prefill_input_pos_list[i],
|
248
262
|
'kv_cache': prefill_kv,
|
249
263
|
}
|
250
264
|
if prefill_masks is not None:
|
@@ -261,13 +275,13 @@ def _export_helper(
|
|
261
275
|
)
|
262
276
|
|
263
277
|
if prefill_pixel_values is not None:
|
278
|
+
sample_kwargs['tokens'] = prefill_tokens_list_with_pixel[i]
|
279
|
+
sample_kwargs['input_pos'] = prefill_input_pos_list_with_pixel[i]
|
280
|
+
sample_kwargs['pixel_values'] = prefill_pixel_values
|
264
281
|
converter.add_signature(
|
265
282
|
prefill_signature_name + '_pixel',
|
266
283
|
mod,
|
267
|
-
sample_kwargs=
|
268
|
-
**sample_kwargs,
|
269
|
-
'pixel_values': prefill_pixel_values,
|
270
|
-
},
|
284
|
+
sample_kwargs=sample_kwargs,
|
271
285
|
)
|
272
286
|
|
273
287
|
sample_kwargs = {
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.5.0.
|
3
|
+
Version: 0.5.0.dev20250418
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=RijgVzFc62dFf_A0Pr5jwDo3wtofAoxUcfBieTnHvIw,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=GPDsXhfECjDzOut4vh_d9qWcyfpxobFMBTsC7MyJbM0,5557
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -68,7 +68,7 @@ ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw
|
|
68
68
|
ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
69
69
|
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=szssSBrIUYdNIoU7LHdAq7wCqgjaY6qbV8yvTgg796Q,2945
|
70
70
|
ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=n6ZQfqNEHuOhY7Pu21bb8Eax8yn2Sx5osTKJKmhonXY,15659
|
71
|
-
ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=
|
71
|
+
ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=GACDBI_MsFowR8A3wAWrpzradPYe-AUgB9ZjXaVBG-s,6485
|
72
72
|
ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
|
73
73
|
ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
|
74
74
|
ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=nEv0qQ0l6gSXKxP5mNwkd2lRGxpFfD4e7FNV3V76zhw,8915
|
@@ -84,7 +84,7 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=Hgp31zIQdJ
|
|
84
84
|
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2ukiJbQOTIUGuMEZvmwZbt3n0,4556
|
85
85
|
ai_edge_torch/generative/examples/openelm/verify.py,sha256=4W26ZtPF5Cb9mpHYuRM4b2QB_4W76zf4WV36KzexVjs,2446
|
86
86
|
ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
87
|
-
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=
|
87
|
+
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=7HHXkC-IIu7ieBvBI4RlXs_oITz7R8a6YVYQskAs_Uk,2023
|
88
88
|
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=G1dwtWp_v77AI3uyIY-8g6qRP2tRH3CIKjJTeYNqFPU,5511
|
89
89
|
ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=Z-SKdb0dd8uWT1d-FRwFx5-tJEqpdrQwiIZnFRhOtVo,6060
|
90
90
|
ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=SvuR97sjkBtfkerH7Hu1UXB8kCFLpEATNbPfCbNAyfo,5614
|
@@ -108,7 +108,7 @@ ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=om3lXL1RnA87P
|
|
108
108
|
ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-LKzFCvWvFLKhJjnASo,4199
|
109
109
|
ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
|
110
110
|
ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
111
|
-
ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=
|
111
|
+
ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=yVebRatt2SLCsGvrYTBXOM-0S2REhkpikHTyy5MCjUw,2222
|
112
112
|
ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=7RFM25tDj_b0FkpSv8RUWir8K8v9p2jMtwZmP4VAUhw,4474
|
113
113
|
ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=nHzBe_YSPnUe1d5i09v4bePQomVifzJNeUjRfprmxC0,14878
|
114
114
|
ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py,sha256=mfLFrT8NPEPh9CqlJYHwh-I2y6ST7hH_vEmbZYartHQ,7764
|
@@ -186,7 +186,7 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=-v2Vj7Qdd3Gy
|
|
186
186
|
ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
|
187
187
|
ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
|
188
188
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
189
|
-
ai_edge_torch/generative/utilities/converter.py,sha256=
|
189
|
+
ai_edge_torch/generative/utilities/converter.py,sha256=LtBHjnslhL-uf4sDRoC8JIbbUD73g0QW3FiWsHUdV1g,10631
|
190
190
|
ai_edge_torch/generative/utilities/export_config.py,sha256=8-795nyd3M34LkGhgW7hwHlJyTc2Oz1iipHK8yBhdFs,1633
|
191
191
|
ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
|
192
192
|
ai_edge_torch/generative/utilities/model_builder.py,sha256=ZYX1TxpFdj573du2QCyHJlFjx4q1m12R74fp4Gwl92A,6343
|
@@ -245,8 +245,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
|
|
245
245
|
ai_edge_torch/testing/export.py,sha256=dguMa-aEi-WDPnmGBUs2IPdEmt2IVmHOELH19uiJ1uU,3014
|
246
246
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
247
247
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
248
|
-
ai_edge_torch_nightly-0.5.0.
|
249
|
-
ai_edge_torch_nightly-0.5.0.
|
250
|
-
ai_edge_torch_nightly-0.5.0.
|
251
|
-
ai_edge_torch_nightly-0.5.0.
|
252
|
-
ai_edge_torch_nightly-0.5.0.
|
248
|
+
ai_edge_torch_nightly-0.5.0.dev20250418.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
249
|
+
ai_edge_torch_nightly-0.5.0.dev20250418.dist-info/METADATA,sha256=cmJPI-Zen5YAjRWQPoVwvcAwVZJGz3Jz0OuNAHs5498,2051
|
250
|
+
ai_edge_torch_nightly-0.5.0.dev20250418.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
251
|
+
ai_edge_torch_nightly-0.5.0.dev20250418.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
252
|
+
ai_edge_torch_nightly-0.5.0.dev20250418.dist-info/RECORD,,
|
File without changes
|
File without changes
|