ai-edge-torch-nightly 0.3.0.dev20241222__py3-none-any.whl → 0.3.0.dev20241223__py3-none-any.whl

@@ -29,9 +29,15 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
 
+_VERSION = flags.DEFINE_enum(
+    'version',
+    '2',
+    ['1', '2'],
+    'The version of PaliGemma model to verify.',
+)
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma-3b-224'),
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
 _TFLITE_PATH = flags.DEFINE_string(
@@ -63,10 +69,12 @@ _QUANTIZE = flags.DEFINE_bool(
 
 
 def main(_):
   pytorch_model = paligemma.build_model(
-      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+      _CHECKPOINT_PATH.value,
+      version=int(_VERSION.value),
+      kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'paligemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
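With the version folded into the output name, converted artifacts are now distinguishable per PaliGemma generation. A worked example of the new filename template (not part of the diff; the prefill and KV-cache flag defaults are not shown in this hunk, so those values are assumptions):

    version = '2'               # _VERSION.value
    quant_suffix = 'q8'         # _QUANTIZE.value is True
    prefill_seq_len = 1024      # assumed _PREFILL_SEQ_LEN.value
    kv_cache_max_len = 1280     # assumed _KV_CACHE_MAX_LEN.value
    output_filename = (
        f'paligemma{version}_{quant_suffix}_seq{prefill_seq_len}'
        f'_ekv{kv_cache_max_len}.tflite'
    )
    # -> 'paligemma2_q8_seq1024_ekv1280.tflite'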
@@ -137,6 +137,8 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
+  config.embedding_dim = 128
+  config.embedding_scale = 128**0.5
   return config
 
 
@@ -160,6 +160,8 @@ def get_fake_decoder2_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
+  config.embedding_dim = 128
+  config.embedding_scale = 128**0.5
   return config
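Both fake decoder configs now pin embedding_dim to 128 and set embedding_scale to 128**0.5, following the Gemma convention of multiplying token embeddings by the square root of the embedding width. A minimal sketch of that scaling, assuming a plain embedding lookup (the decoder's actual forward pass is not part of this diff):

    import math
    import torch

    embedding_dim = 128
    embedding_scale = math.sqrt(embedding_dim)  # 128**0.5, as in the fake configs

    # Hypothetical lookup: scale raw embeddings before they enter the transformer blocks.
    embedder = torch.nn.Embedding(num_embeddings=128, embedding_dim=embedding_dim)
    tokens = torch.tensor([[1, 2, 3]])
    hidden = embedder(tokens) * embedding_scale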
@@ -136,8 +136,8 @@ def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_fake_image_encoder_config(),
       decoder_config=get_decoder_config(**kwargs),
-      image_token_id=257152,
-      image_projection_scale=2048**0.5,
+      image_token_id=127,
+      image_projection_scale=128**0.5,
       image_projection_use_bias=True,
   )
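The fake PaliGemma config is now sized to the fake decoder rather than the real checkpoint: the placeholder image token id has to fall inside the fake vocabulary of 128, and the projection scale tracks the fake embedding width instead of the real model's 2048. A small illustration of that relationship (not library code):

    fake_vocab_size = 128
    fake_embedding_dim = 128

    image_token_id = fake_vocab_size - 1              # 127, stays inside the fake vocab
    image_projection_scale = fake_embedding_dim**0.5  # 128**0.5, matches embedding_dim above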
@@ -30,7 +30,7 @@ import transformers
 
 _VERSION = flags.DEFINE_enum(
     "version",
-    "1",
+    "2",
     ["1", "2"],
     "The version of PaliGemma model to verify.",
 )
@@ -28,7 +28,7 @@ import transformers
 
 _VERSION = flags.DEFINE_enum(
     "version",
-    "1",
+    "2",
     ["1", "2"],
     "The version of PaliGemma vision model to verify.",
 )
@@ -21,6 +21,8 @@ from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.examples.paligemma import decoder
+from ai_edge_torch.generative.examples.paligemma import decoder2
 from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.examples.phi import phi3
@@ -171,13 +173,9 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = amd_llama_135m.AmdLlama(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
-  def disabled_test_paligemma(self):
-    config = paligemma.get_fake_model_config()
-    pytorch_model = paligemma.PaliGemma(config).eval()
+  def _test_paligemma_model(self, decoder_class, decoder_config, atol, rtol):
+    config = paligemma.get_fake_model_config(decoder_config)
+    pytorch_model = paligemma.PaliGemma(config, decoder_class).eval()
 
     image_embedding_config = config.image_encoder_config.image_embedding
     num_patches = (
@@ -215,11 +213,32 @@ class TestModelConversion(googletest.TestCase):
             kv,
             pixel_values=pixel_values,
             signature_name="prefill_pixel",
-            atol=1e-3,
-            rtol=1e-5,
+            atol=atol,
+            rtol=rtol,
         )
     )
 
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def disabled_test_paligemma1(self):
+    self._test_paligemma_model(
+        decoder.Decoder, decoder.get_fake_decoder_config, atol=1e-3, rtol=1e-5
+    )
+
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def disabled_test_paligemma2(self):
+    self._test_paligemma_model(
+        decoder2.Decoder2,
+        decoder2.get_fake_decoder2_config,
+        atol=1e-3,
+        rtol=1e-5,
+    )
+
   @googletest.skipIf(
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20241222"
+__version__ = "0.3.0.dev20241223"
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241222
+Version: 0.3.0.dev20241223
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=PKEPravHVUIDugudfMDzqU57wXbpvrsY94puBM6FS-c,706
+ai_edge_torch/version.py,sha256=lt932yLxGWAwtUdRXBC2DJS7c4fJ4v36-tuSXCugwsc,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=SzbR16V2JEfkCjjPwRVAFUbFnzu-_1iHPKgGT9Yz7gQ,5678
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -63,15 +63,15 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=nji1oDgf6x
 ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2ukiJbQOTIUGuMEZvmwZbt3n0,4556
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=rPFqcsv8RHvjmgfBW9OL6EKxMtVX-ySjBsMP4N8FErk,2816
-ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=DDVFHGqRbJgnLT4XJRYJ-MAp2-xPnI4fAUGSYVNMprc,5342
-ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=EjNbZwXM_T_0FXgHUAtLupihPsNlhPWeOop3IJ10Wzg,6320
+ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=LFCcnkmOksySDa_5bLBzoGMijYdFVjXIMidUlyzAbNk,2996
+ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=amN96oBMTPolOFvGa47vG92AZ-BNLm8j0bBYd-IrMvI,5407
+ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=0V_CX0Pn5Fj_-koOGjc_Av2KMSAaVjAlD-G8P6FBGyY,6385
 ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
-ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=nDBFu_uVzdARH06BU6xRerVdjahSCm39nQcYigJVoHE,6261
-ai_edge_torch/generative/examples/paligemma/verify.py,sha256=__RUyh0L5Td2jbm1xGnSldbfpKHtxyXAh2h06KVGxLA,5622
+ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=FwGlFHl9zktGDxnoOpEtbS6NYN5RyzcOXH7lvNUCwEU,6257
+ai_edge_torch/generative/examples/paligemma/verify.py,sha256=KT3Ruy40tSESxQuy-Sw01NAI3zId1BZr6Bp7FZj1wZk,5622
 ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
 ai_edge_torch/generative/examples/paligemma/verify_decoder2.py,sha256=tm-UfLr0YeBRVcQsWLBOMWI9JUzHmtPEbYK2vpITpqY,2534
-ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=qaROQSjgs0DtVOX4KS5kPmlDrBFn0yJr83_kWIN8NzM,3540
+ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=vNm-wTT8BD6zbX6GocfP1QrVoHl0zSvuVxoXN36eeiU,3540
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=cD8rtwgYeGrXB9sYVV_D1AB8Up1AWNS-1XtrRlyzE5o,2296
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=G1i_ybDCTBaOD1OOCTk6jqOf__xYYZvhXcxY8MXhPHw,2294
@@ -142,7 +142,7 @@ ai_edge_torch/generative/test/test_custom_dus.py,sha256=gxG78CcTpXF3iLzDR15Rlz1e
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOrypy4IM5YjC4p-6dgCMM,3793
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=mVuax3MPRmuNjnDRKXqtc9YmswCy7MnhD1CHADK-3nk,11501
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=NctnggTSFh0XEQbTu55diZ35rFD2QIARO-8PzLktRWg,12165
 ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -203,8 +203,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241222.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241222.dist-info/METADATA,sha256=0-7zfD8burp8x7iTlCrOe2JO8BZuV1zcRmMrcsGFjVk,1966
-ai_edge_torch_nightly-0.3.0.dev20241222.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.3.0.dev20241222.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241222.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/METADATA,sha256=EfUQ_LF_l-OEQVb9-qgcqC67LwOgGvW2p1DDD7QFqp0,1966
+ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/RECORD,,