ai-edge-torch-nightly 0.3.0.dev20241222__py3-none-any.whl → 0.3.0.dev20241223__py3-none-any.whl
This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -3
- ai_edge_torch/generative/examples/paligemma/decoder.py +2 -0
- ai_edge_torch/generative/examples/paligemma/decoder2.py +2 -0
- ai_edge_torch/generative/examples/paligemma/paligemma.py +2 -2
- ai_edge_torch/generative/examples/paligemma/verify.py +1 -1
- ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion_large.py +28 -9
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241222.dist-info → ai_edge_torch_nightly-0.3.0.dev20241223.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241222.dist-info → ai_edge_torch_nightly-0.3.0.dev20241223.dist-info}/RECORD +13 -13
- {ai_edge_torch_nightly-0.3.0.dev20241222.dist-info → ai_edge_torch_nightly-0.3.0.dev20241223.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241222.dist-info → ai_edge_torch_nightly-0.3.0.dev20241223.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241222.dist-info → ai_edge_torch_nightly-0.3.0.dev20241223.dist-info}/top_level.txt +0 -0
@@ -29,9 +29,15 @@ from ai_edge_torch.generative.utilities import converter
|
|
29
29
|
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
|
30
30
|
import torch
|
31
31
|
|
32
|
+
_VERSION = flags.DEFINE_enum(
|
33
|
+
'version',
|
34
|
+
'2',
|
35
|
+
['1', '2'],
|
36
|
+
'The version of PaliGemma model to verify.',
|
37
|
+
)
|
32
38
|
_CHECKPOINT_PATH = flags.DEFINE_string(
|
33
39
|
'checkpoint_path',
|
34
|
-
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/
|
40
|
+
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
|
35
41
|
'The path to the model checkpoint, or directory holding the checkpoint.',
|
36
42
|
)
|
37
43
|
_TFLITE_PATH = flags.DEFINE_string(
|
@@ -63,10 +69,12 @@ _QUANTIZE = flags.DEFINE_bool(
|
|
63
69
|
|
64
70
|
def main(_):
|
65
71
|
pytorch_model = paligemma.build_model(
|
66
|
-
_CHECKPOINT_PATH.value,
|
72
|
+
_CHECKPOINT_PATH.value,
|
73
|
+
version=int(_VERSION.value),
|
74
|
+
kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
|
67
75
|
)
|
68
76
|
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
|
69
|
-
output_filename = f'
|
77
|
+
output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
|
70
78
|
converter.convert_to_tflite(
|
71
79
|
pytorch_model,
|
72
80
|
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
|
@@ -137,6 +137,8 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
|
|
137
137
|
config.vocab_size = 128
|
138
138
|
config.num_layers = 2
|
139
139
|
config.max_seq_len = 2 * kv_cache_max_len
|
140
|
+
config.embedding_dim = 128
|
141
|
+
config.embedding_scale = 128**0.5
|
140
142
|
return config
|
141
143
|
|
142
144
|
|
@@ -160,6 +160,8 @@ def get_fake_decoder2_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
|
|
160
160
|
config.vocab_size = 128
|
161
161
|
config.num_layers = 2
|
162
162
|
config.max_seq_len = 2 * kv_cache_max_len
|
163
|
+
config.embedding_dim = 128
|
164
|
+
config.embedding_scale = 128**0.5
|
163
165
|
return config
|
164
166
|
|
165
167
|
|
@@ -136,8 +136,8 @@ def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
|
|
136
136
|
return PaliGemmaConfig(
|
137
137
|
image_encoder_config=image_encoder.get_fake_image_encoder_config(),
|
138
138
|
decoder_config=get_decoder_config(**kwargs),
|
139
|
-
image_token_id=
|
140
|
-
image_projection_scale=
|
139
|
+
image_token_id=127,
|
140
|
+
image_projection_scale=128**0.5,
|
141
141
|
image_projection_use_bias=True,
|
142
142
|
)
|
143
143
|
|
@@ -21,6 +21,8 @@ from ai_edge_torch.generative.examples.gemma import gemma1
|
|
21
21
|
from ai_edge_torch.generative.examples.gemma import gemma2
|
22
22
|
from ai_edge_torch.generative.examples.llama import llama
|
23
23
|
from ai_edge_torch.generative.examples.openelm import openelm
|
24
|
+
from ai_edge_torch.generative.examples.paligemma import decoder
|
25
|
+
from ai_edge_torch.generative.examples.paligemma import decoder2
|
24
26
|
from ai_edge_torch.generative.examples.paligemma import paligemma
|
25
27
|
from ai_edge_torch.generative.examples.phi import phi2
|
26
28
|
from ai_edge_torch.generative.examples.phi import phi3
|
@@ -171,13 +173,9 @@ class TestModelConversion(googletest.TestCase):
|
|
171
173
|
pytorch_model = amd_llama_135m.AmdLlama(config).eval()
|
172
174
|
self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
|
173
175
|
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
)
|
178
|
-
def disabled_test_paligemma(self):
|
179
|
-
config = paligemma.get_fake_model_config()
|
180
|
-
pytorch_model = paligemma.PaliGemma(config).eval()
|
176
|
+
def _test_paligemma_model(self, decoder_class, decoder_config, atol, rtol):
|
177
|
+
config = paligemma.get_fake_model_config(decoder_config)
|
178
|
+
pytorch_model = paligemma.PaliGemma(config, decoder_class).eval()
|
181
179
|
|
182
180
|
image_embedding_config = config.image_encoder_config.image_embedding
|
183
181
|
num_patches = (
|
@@ -215,11 +213,32 @@ class TestModelConversion(googletest.TestCase):
|
|
215
213
|
kv,
|
216
214
|
pixel_values=pixel_values,
|
217
215
|
signature_name="prefill_pixel",
|
218
|
-
atol=
|
219
|
-
rtol=
|
216
|
+
atol=atol,
|
217
|
+
rtol=rtol,
|
220
218
|
)
|
221
219
|
)
|
222
220
|
|
221
|
+
@googletest.skipIf(
|
222
|
+
ai_edge_torch.config.in_oss,
|
223
|
+
reason="tests with custom ops are not supported in oss",
|
224
|
+
)
|
225
|
+
def disabled_test_paligemma1(self):
|
226
|
+
self._test_paligemma_model(
|
227
|
+
decoder.Decoder, decoder.get_fake_decoder_config, atol=1e-3, rtol=1e-5
|
228
|
+
)
|
229
|
+
|
230
|
+
@googletest.skipIf(
|
231
|
+
ai_edge_torch.config.in_oss,
|
232
|
+
reason="tests with custom ops are not supported in oss",
|
233
|
+
)
|
234
|
+
def disabled_test_paligemma2(self):
|
235
|
+
self._test_paligemma_model(
|
236
|
+
decoder2.Decoder2,
|
237
|
+
decoder2.get_fake_decoder2_config,
|
238
|
+
atol=1e-3,
|
239
|
+
rtol=1e-5,
|
240
|
+
)
|
241
|
+
|
223
242
|
@googletest.skipIf(
|
224
243
|
ai_edge_torch.config.in_oss,
|
225
244
|
reason="tests with custom ops are not supported in oss",
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.dev20241222
|
3
|
+
Version: 0.3.0.dev20241223
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,7 +3,7 @@ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=lt932yLxGWAwtUdRXBC2DJS7c4fJ4v36-tuSXCugwsc,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=SzbR16V2JEfkCjjPwRVAFUbFnzu-_1iHPKgGT9Yz7gQ,5678
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -63,15 +63,15 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=nji1oDgf6x
|
|
63
63
|
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2ukiJbQOTIUGuMEZvmwZbt3n0,4556
|
64
64
|
ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
|
65
65
|
ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
66
|
-
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=
|
67
|
-
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=
|
68
|
-
ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=
|
66
|
+
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=LFCcnkmOksySDa_5bLBzoGMijYdFVjXIMidUlyzAbNk,2996
|
67
|
+
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=amN96oBMTPolOFvGa47vG92AZ-BNLm8j0bBYd-IrMvI,5407
|
68
|
+
ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=0V_CX0Pn5Fj_-koOGjc_Av2KMSAaVjAlD-G8P6FBGyY,6385
|
69
69
|
ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
|
70
|
-
ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=
|
71
|
-
ai_edge_torch/generative/examples/paligemma/verify.py,sha256=
|
70
|
+
ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=FwGlFHl9zktGDxnoOpEtbS6NYN5RyzcOXH7lvNUCwEU,6257
|
71
|
+
ai_edge_torch/generative/examples/paligemma/verify.py,sha256=KT3Ruy40tSESxQuy-Sw01NAI3zId1BZr6Bp7FZj1wZk,5622
|
72
72
|
ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
|
73
73
|
ai_edge_torch/generative/examples/paligemma/verify_decoder2.py,sha256=tm-UfLr0YeBRVcQsWLBOMWI9JUzHmtPEbYK2vpITpqY,2534
|
74
|
-
ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=
|
74
|
+
ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=vNm-wTT8BD6zbX6GocfP1QrVoHl0zSvuVxoXN36eeiU,3540
|
75
75
|
ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
76
76
|
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=cD8rtwgYeGrXB9sYVV_D1AB8Up1AWNS-1XtrRlyzE5o,2296
|
77
77
|
ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=G1i_ybDCTBaOD1OOCTk6jqOf__xYYZvhXcxY8MXhPHw,2294
|
@@ -142,7 +142,7 @@ ai_edge_torch/generative/test/test_custom_dus.py,sha256=gxG78CcTpXF3iLzDR15Rlz1e
|
|
142
142
|
ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOrypy4IM5YjC4p-6dgCMM,3793
|
143
143
|
ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
|
144
144
|
ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
|
145
|
-
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
|
145
|
+
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=NctnggTSFh0XEQbTu55diZ35rFD2QIARO-8PzLktRWg,12165
|
146
146
|
ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
|
147
147
|
ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
|
148
148
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
@@ -203,8 +203,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
203
203
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
204
204
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
205
205
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
206
|
-
ai_edge_torch_nightly-0.3.0.
|
207
|
-
ai_edge_torch_nightly-0.3.0.
|
208
|
-
ai_edge_torch_nightly-0.3.0.
|
209
|
-
ai_edge_torch_nightly-0.3.0.
|
210
|
-
ai_edge_torch_nightly-0.3.0.
|
206
|
+
ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
207
|
+
ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/METADATA,sha256=EfUQ_LF_l-OEQVb9-qgcqC67LwOgGvW2p1DDD7QFqp0,1966
|
208
|
+
ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
209
|
+
ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
210
|
+
ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/RECORD,,
|
File without changes
|
File without changes
|