ai-edge-torch-nightly 0.3.0.dev20241222__py3-none-any.whl → 0.3.0.dev20241223__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -29,9 +29,15 @@ from ai_edge_torch.generative.utilities import converter
29
29
  from ai_edge_torch.generative.utilities.model_builder import ExportConfig
30
30
  import torch
31
31
 
32
+ _VERSION = flags.DEFINE_enum(
33
+ 'version',
34
+ '2',
35
+ ['1', '2'],
36
+ 'The version of PaliGemma model to verify.',
37
+ )
32
38
  _CHECKPOINT_PATH = flags.DEFINE_string(
33
39
  'checkpoint_path',
34
- os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma-3b-224'),
40
+ os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
35
41
  'The path to the model checkpoint, or directory holding the checkpoint.',
36
42
  )
37
43
  _TFLITE_PATH = flags.DEFINE_string(
@@ -63,10 +69,12 @@ _QUANTIZE = flags.DEFINE_bool(
63
69
 
64
70
  def main(_):
65
71
  pytorch_model = paligemma.build_model(
66
- _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
72
+ _CHECKPOINT_PATH.value,
73
+ version=int(_VERSION.value),
74
+ kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
67
75
  )
68
76
  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
69
- output_filename = f'paligemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
77
+ output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
70
78
  converter.convert_to_tflite(
71
79
  pytorch_model,
72
80
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
@@ -137,6 +137,8 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
137
137
  config.vocab_size = 128
138
138
  config.num_layers = 2
139
139
  config.max_seq_len = 2 * kv_cache_max_len
140
+ config.embedding_dim = 128
141
+ config.embedding_scale = 128**0.5
140
142
  return config
141
143
 
142
144
 
@@ -160,6 +160,8 @@ def get_fake_decoder2_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
160
160
  config.vocab_size = 128
161
161
  config.num_layers = 2
162
162
  config.max_seq_len = 2 * kv_cache_max_len
163
+ config.embedding_dim = 128
164
+ config.embedding_scale = 128**0.5
163
165
  return config
164
166
 
165
167
 
@@ -136,8 +136,8 @@ def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
136
136
  return PaliGemmaConfig(
137
137
  image_encoder_config=image_encoder.get_fake_image_encoder_config(),
138
138
  decoder_config=get_decoder_config(**kwargs),
139
- image_token_id=257152,
140
- image_projection_scale=2048**0.5,
139
+ image_token_id=127,
140
+ image_projection_scale=128**0.5,
141
141
  image_projection_use_bias=True,
142
142
  )
143
143
 
@@ -30,7 +30,7 @@ import transformers
30
30
 
31
31
  _VERSION = flags.DEFINE_enum(
32
32
  "version",
33
- "1",
33
+ "2",
34
34
  ["1", "2"],
35
35
  "The version of PaliGemma model to verify.",
36
36
  )
@@ -28,7 +28,7 @@ import transformers
28
28
 
29
29
  _VERSION = flags.DEFINE_enum(
30
30
  "version",
31
- "1",
31
+ "2",
32
32
  ["1", "2"],
33
33
  "The version of PaliGemma vision model to verify.",
34
34
  )
@@ -21,6 +21,8 @@ from ai_edge_torch.generative.examples.gemma import gemma1
21
21
  from ai_edge_torch.generative.examples.gemma import gemma2
22
22
  from ai_edge_torch.generative.examples.llama import llama
23
23
  from ai_edge_torch.generative.examples.openelm import openelm
24
+ from ai_edge_torch.generative.examples.paligemma import decoder
25
+ from ai_edge_torch.generative.examples.paligemma import decoder2
24
26
  from ai_edge_torch.generative.examples.paligemma import paligemma
25
27
  from ai_edge_torch.generative.examples.phi import phi2
26
28
  from ai_edge_torch.generative.examples.phi import phi3
@@ -171,13 +173,9 @@ class TestModelConversion(googletest.TestCase):
171
173
  pytorch_model = amd_llama_135m.AmdLlama(config).eval()
172
174
  self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
173
175
 
174
- @googletest.skipIf(
175
- ai_edge_torch.config.in_oss,
176
- reason="tests with custom ops are not supported in oss",
177
- )
178
- def disabled_test_paligemma(self):
179
- config = paligemma.get_fake_model_config()
180
- pytorch_model = paligemma.PaliGemma(config).eval()
176
+ def _test_paligemma_model(self, decoder_class, decoder_config, atol, rtol):
177
+ config = paligemma.get_fake_model_config(decoder_config)
178
+ pytorch_model = paligemma.PaliGemma(config, decoder_class).eval()
181
179
 
182
180
  image_embedding_config = config.image_encoder_config.image_embedding
183
181
  num_patches = (
@@ -215,11 +213,32 @@ class TestModelConversion(googletest.TestCase):
215
213
  kv,
216
214
  pixel_values=pixel_values,
217
215
  signature_name="prefill_pixel",
218
- atol=1e-3,
219
- rtol=1e-5,
216
+ atol=atol,
217
+ rtol=rtol,
220
218
  )
221
219
  )
222
220
 
221
+ @googletest.skipIf(
222
+ ai_edge_torch.config.in_oss,
223
+ reason="tests with custom ops are not supported in oss",
224
+ )
225
+ def disabled_test_paligemma1(self):
226
+ self._test_paligemma_model(
227
+ decoder.Decoder, decoder.get_fake_decoder_config, atol=1e-3, rtol=1e-5
228
+ )
229
+
230
+ @googletest.skipIf(
231
+ ai_edge_torch.config.in_oss,
232
+ reason="tests with custom ops are not supported in oss",
233
+ )
234
+ def disabled_test_paligemma2(self):
235
+ self._test_paligemma_model(
236
+ decoder2.Decoder2,
237
+ decoder2.get_fake_decoder2_config,
238
+ atol=1e-3,
239
+ rtol=1e-5,
240
+ )
241
+
223
242
  @googletest.skipIf(
224
243
  ai_edge_torch.config.in_oss,
225
244
  reason="tests with custom ops are not supported in oss",
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241222"
16
+ __version__ = "0.3.0.dev20241223"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241222
3
+ Version: 0.3.0.dev20241223
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=PKEPravHVUIDugudfMDzqU57wXbpvrsY94puBM6FS-c,706
6
+ ai_edge_torch/version.py,sha256=lt932yLxGWAwtUdRXBC2DJS7c4fJ4v36-tuSXCugwsc,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=SzbR16V2JEfkCjjPwRVAFUbFnzu-_1iHPKgGT9Yz7gQ,5678
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -63,15 +63,15 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=nji1oDgf6x
63
63
  ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2ukiJbQOTIUGuMEZvmwZbt3n0,4556
64
64
  ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
65
65
  ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
66
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=rPFqcsv8RHvjmgfBW9OL6EKxMtVX-ySjBsMP4N8FErk,2816
67
- ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=DDVFHGqRbJgnLT4XJRYJ-MAp2-xPnI4fAUGSYVNMprc,5342
68
- ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=EjNbZwXM_T_0FXgHUAtLupihPsNlhPWeOop3IJ10Wzg,6320
66
+ ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=LFCcnkmOksySDa_5bLBzoGMijYdFVjXIMidUlyzAbNk,2996
67
+ ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=amN96oBMTPolOFvGa47vG92AZ-BNLm8j0bBYd-IrMvI,5407
68
+ ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=0V_CX0Pn5Fj_-koOGjc_Av2KMSAaVjAlD-G8P6FBGyY,6385
69
69
  ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
70
- ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=nDBFu_uVzdARH06BU6xRerVdjahSCm39nQcYigJVoHE,6261
71
- ai_edge_torch/generative/examples/paligemma/verify.py,sha256=__RUyh0L5Td2jbm1xGnSldbfpKHtxyXAh2h06KVGxLA,5622
70
+ ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=FwGlFHl9zktGDxnoOpEtbS6NYN5RyzcOXH7lvNUCwEU,6257
71
+ ai_edge_torch/generative/examples/paligemma/verify.py,sha256=KT3Ruy40tSESxQuy-Sw01NAI3zId1BZr6Bp7FZj1wZk,5622
72
72
  ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
73
73
  ai_edge_torch/generative/examples/paligemma/verify_decoder2.py,sha256=tm-UfLr0YeBRVcQsWLBOMWI9JUzHmtPEbYK2vpITpqY,2534
74
- ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=qaROQSjgs0DtVOX4KS5kPmlDrBFn0yJr83_kWIN8NzM,3540
74
+ ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=vNm-wTT8BD6zbX6GocfP1QrVoHl0zSvuVxoXN36eeiU,3540
75
75
  ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
76
76
  ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=cD8rtwgYeGrXB9sYVV_D1AB8Up1AWNS-1XtrRlyzE5o,2296
77
77
  ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=G1i_ybDCTBaOD1OOCTk6jqOf__xYYZvhXcxY8MXhPHw,2294
@@ -142,7 +142,7 @@ ai_edge_torch/generative/test/test_custom_dus.py,sha256=gxG78CcTpXF3iLzDR15Rlz1e
142
142
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOrypy4IM5YjC4p-6dgCMM,3793
143
143
  ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
144
144
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
145
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=mVuax3MPRmuNjnDRKXqtc9YmswCy7MnhD1CHADK-3nk,11501
145
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=NctnggTSFh0XEQbTu55diZ35rFD2QIARO-8PzLktRWg,12165
146
146
  ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
147
147
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
148
148
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -203,8 +203,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
203
203
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
204
204
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
205
205
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
206
- ai_edge_torch_nightly-0.3.0.dev20241222.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
207
- ai_edge_torch_nightly-0.3.0.dev20241222.dist-info/METADATA,sha256=0-7zfD8burp8x7iTlCrOe2JO8BZuV1zcRmMrcsGFjVk,1966
208
- ai_edge_torch_nightly-0.3.0.dev20241222.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
209
- ai_edge_torch_nightly-0.3.0.dev20241222.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
210
- ai_edge_torch_nightly-0.3.0.dev20241222.dist-info/RECORD,,
206
+ ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
207
+ ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/METADATA,sha256=EfUQ_LF_l-OEQVb9-qgcqC67LwOgGvW2p1DDD7QFqp0,1966
208
+ ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
209
+ ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
210
+ ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/RECORD,,