ai-edge-torch-nightly 0.3.0.dev20241222__py3-none-any.whl → 0.3.0.dev20241223__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -3
- ai_edge_torch/generative/examples/paligemma/decoder.py +2 -0
- ai_edge_torch/generative/examples/paligemma/decoder2.py +2 -0
- ai_edge_torch/generative/examples/paligemma/paligemma.py +2 -2
- ai_edge_torch/generative/examples/paligemma/verify.py +1 -1
- ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion_large.py +28 -9
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241222.dist-info → ai_edge_torch_nightly-0.3.0.dev20241223.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241222.dist-info → ai_edge_torch_nightly-0.3.0.dev20241223.dist-info}/RECORD +13 -13
- {ai_edge_torch_nightly-0.3.0.dev20241222.dist-info → ai_edge_torch_nightly-0.3.0.dev20241223.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241222.dist-info → ai_edge_torch_nightly-0.3.0.dev20241223.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241222.dist-info → ai_edge_torch_nightly-0.3.0.dev20241223.dist-info}/top_level.txt +0 -0
@@ -29,9 +29,15 @@ from ai_edge_torch.generative.utilities import converter
|
|
29
29
|
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
|
30
30
|
import torch
|
31
31
|
|
32
|
+
_VERSION = flags.DEFINE_enum(
|
33
|
+
'version',
|
34
|
+
'2',
|
35
|
+
['1', '2'],
|
36
|
+
'The version of PaliGemma model to verify.',
|
37
|
+
)
|
32
38
|
_CHECKPOINT_PATH = flags.DEFINE_string(
|
33
39
|
'checkpoint_path',
|
34
|
-
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/
|
40
|
+
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
|
35
41
|
'The path to the model checkpoint, or directory holding the checkpoint.',
|
36
42
|
)
|
37
43
|
_TFLITE_PATH = flags.DEFINE_string(
|
@@ -63,10 +69,12 @@ _QUANTIZE = flags.DEFINE_bool(
|
|
63
69
|
|
64
70
|
def main(_):
|
65
71
|
pytorch_model = paligemma.build_model(
|
66
|
-
_CHECKPOINT_PATH.value,
|
72
|
+
_CHECKPOINT_PATH.value,
|
73
|
+
version=int(_VERSION.value),
|
74
|
+
kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
|
67
75
|
)
|
68
76
|
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
|
69
|
-
output_filename = f'
|
77
|
+
output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
|
70
78
|
converter.convert_to_tflite(
|
71
79
|
pytorch_model,
|
72
80
|
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
|
@@ -137,6 +137,8 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
|
|
137
137
|
config.vocab_size = 128
|
138
138
|
config.num_layers = 2
|
139
139
|
config.max_seq_len = 2 * kv_cache_max_len
|
140
|
+
config.embedding_dim = 128
|
141
|
+
config.embedding_scale = 128**0.5
|
140
142
|
return config
|
141
143
|
|
142
144
|
|
@@ -160,6 +160,8 @@ def get_fake_decoder2_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
|
|
160
160
|
config.vocab_size = 128
|
161
161
|
config.num_layers = 2
|
162
162
|
config.max_seq_len = 2 * kv_cache_max_len
|
163
|
+
config.embedding_dim = 128
|
164
|
+
config.embedding_scale = 128**0.5
|
163
165
|
return config
|
164
166
|
|
165
167
|
|
@@ -136,8 +136,8 @@ def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
|
|
136
136
|
return PaliGemmaConfig(
|
137
137
|
image_encoder_config=image_encoder.get_fake_image_encoder_config(),
|
138
138
|
decoder_config=get_decoder_config(**kwargs),
|
139
|
-
image_token_id=
|
140
|
-
image_projection_scale=
|
139
|
+
image_token_id=127,
|
140
|
+
image_projection_scale=128**0.5,
|
141
141
|
image_projection_use_bias=True,
|
142
142
|
)
|
143
143
|
|
@@ -21,6 +21,8 @@ from ai_edge_torch.generative.examples.gemma import gemma1
|
|
21
21
|
from ai_edge_torch.generative.examples.gemma import gemma2
|
22
22
|
from ai_edge_torch.generative.examples.llama import llama
|
23
23
|
from ai_edge_torch.generative.examples.openelm import openelm
|
24
|
+
from ai_edge_torch.generative.examples.paligemma import decoder
|
25
|
+
from ai_edge_torch.generative.examples.paligemma import decoder2
|
24
26
|
from ai_edge_torch.generative.examples.paligemma import paligemma
|
25
27
|
from ai_edge_torch.generative.examples.phi import phi2
|
26
28
|
from ai_edge_torch.generative.examples.phi import phi3
|
@@ -171,13 +173,9 @@ class TestModelConversion(googletest.TestCase):
|
|
171
173
|
pytorch_model = amd_llama_135m.AmdLlama(config).eval()
|
172
174
|
self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
|
173
175
|
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
)
|
178
|
-
def disabled_test_paligemma(self):
|
179
|
-
config = paligemma.get_fake_model_config()
|
180
|
-
pytorch_model = paligemma.PaliGemma(config).eval()
|
176
|
+
def _test_paligemma_model(self, decoder_class, decoder_config, atol, rtol):
|
177
|
+
config = paligemma.get_fake_model_config(decoder_config)
|
178
|
+
pytorch_model = paligemma.PaliGemma(config, decoder_class).eval()
|
181
179
|
|
182
180
|
image_embedding_config = config.image_encoder_config.image_embedding
|
183
181
|
num_patches = (
|
@@ -215,11 +213,32 @@ class TestModelConversion(googletest.TestCase):
|
|
215
213
|
kv,
|
216
214
|
pixel_values=pixel_values,
|
217
215
|
signature_name="prefill_pixel",
|
218
|
-
atol=
|
219
|
-
rtol=
|
216
|
+
atol=atol,
|
217
|
+
rtol=rtol,
|
220
218
|
)
|
221
219
|
)
|
222
220
|
|
221
|
+
@googletest.skipIf(
|
222
|
+
ai_edge_torch.config.in_oss,
|
223
|
+
reason="tests with custom ops are not supported in oss",
|
224
|
+
)
|
225
|
+
def disabled_test_paligemma1(self):
|
226
|
+
self._test_paligemma_model(
|
227
|
+
decoder.Decoder, decoder.get_fake_decoder_config, atol=1e-3, rtol=1e-5
|
228
|
+
)
|
229
|
+
|
230
|
+
@googletest.skipIf(
|
231
|
+
ai_edge_torch.config.in_oss,
|
232
|
+
reason="tests with custom ops are not supported in oss",
|
233
|
+
)
|
234
|
+
def disabled_test_paligemma2(self):
|
235
|
+
self._test_paligemma_model(
|
236
|
+
decoder2.Decoder2,
|
237
|
+
decoder2.get_fake_decoder2_config,
|
238
|
+
atol=1e-3,
|
239
|
+
rtol=1e-5,
|
240
|
+
)
|
241
|
+
|
223
242
|
@googletest.skipIf(
|
224
243
|
ai_edge_torch.config.in_oss,
|
225
244
|
reason="tests with custom ops are not supported in oss",
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20241223
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,7 +3,7 @@ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=lt932yLxGWAwtUdRXBC2DJS7c4fJ4v36-tuSXCugwsc,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=SzbR16V2JEfkCjjPwRVAFUbFnzu-_1iHPKgGT9Yz7gQ,5678
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -63,15 +63,15 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=nji1oDgf6x
|
|
63
63
|
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2ukiJbQOTIUGuMEZvmwZbt3n0,4556
|
64
64
|
ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
|
65
65
|
ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
66
|
-
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=
|
67
|
-
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=
|
68
|
-
ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=
|
66
|
+
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=LFCcnkmOksySDa_5bLBzoGMijYdFVjXIMidUlyzAbNk,2996
|
67
|
+
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=amN96oBMTPolOFvGa47vG92AZ-BNLm8j0bBYd-IrMvI,5407
|
68
|
+
ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=0V_CX0Pn5Fj_-koOGjc_Av2KMSAaVjAlD-G8P6FBGyY,6385
|
69
69
|
ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
|
70
|
-
ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=
|
71
|
-
ai_edge_torch/generative/examples/paligemma/verify.py,sha256=
|
70
|
+
ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=FwGlFHl9zktGDxnoOpEtbS6NYN5RyzcOXH7lvNUCwEU,6257
|
71
|
+
ai_edge_torch/generative/examples/paligemma/verify.py,sha256=KT3Ruy40tSESxQuy-Sw01NAI3zId1BZr6Bp7FZj1wZk,5622
|
72
72
|
ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
|
73
73
|
ai_edge_torch/generative/examples/paligemma/verify_decoder2.py,sha256=tm-UfLr0YeBRVcQsWLBOMWI9JUzHmtPEbYK2vpITpqY,2534
|
74
|
-
ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=
|
74
|
+
ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=vNm-wTT8BD6zbX6GocfP1QrVoHl0zSvuVxoXN36eeiU,3540
|
75
75
|
ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
76
76
|
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=cD8rtwgYeGrXB9sYVV_D1AB8Up1AWNS-1XtrRlyzE5o,2296
|
77
77
|
ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=G1i_ybDCTBaOD1OOCTk6jqOf__xYYZvhXcxY8MXhPHw,2294
|
@@ -142,7 +142,7 @@ ai_edge_torch/generative/test/test_custom_dus.py,sha256=gxG78CcTpXF3iLzDR15Rlz1e
|
|
142
142
|
ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOrypy4IM5YjC4p-6dgCMM,3793
|
143
143
|
ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
|
144
144
|
ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
|
145
|
-
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
|
145
|
+
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=NctnggTSFh0XEQbTu55diZ35rFD2QIARO-8PzLktRWg,12165
|
146
146
|
ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
|
147
147
|
ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
|
148
148
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
@@ -203,8 +203,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
203
203
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
204
204
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
205
205
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
206
|
-
ai_edge_torch_nightly-0.3.0.
|
207
|
-
ai_edge_torch_nightly-0.3.0.
|
208
|
-
ai_edge_torch_nightly-0.3.0.
|
209
|
-
ai_edge_torch_nightly-0.3.0.
|
210
|
-
ai_edge_torch_nightly-0.3.0.
|
206
|
+
ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
207
|
+
ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/METADATA,sha256=EfUQ_LF_l-OEQVb9-qgcqC67LwOgGvW2p1DDD7QFqp0,1966
|
208
|
+
ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
209
|
+
ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
210
|
+
ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/RECORD,,
|
File without changes
|
File without changes
|