ai-edge-torch-nightly 0.3.0.dev20241222__py3-none-any.whl → 0.3.0.dev20241223__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -3
- ai_edge_torch/generative/examples/paligemma/decoder.py +2 -0
- ai_edge_torch/generative/examples/paligemma/decoder2.py +2 -0
- ai_edge_torch/generative/examples/paligemma/paligemma.py +2 -2
- ai_edge_torch/generative/examples/paligemma/verify.py +1 -1
- ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion_large.py +28 -9
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241222.dist-info → ai_edge_torch_nightly-0.3.0.dev20241223.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241222.dist-info → ai_edge_torch_nightly-0.3.0.dev20241223.dist-info}/RECORD +13 -13
- {ai_edge_torch_nightly-0.3.0.dev20241222.dist-info → ai_edge_torch_nightly-0.3.0.dev20241223.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241222.dist-info → ai_edge_torch_nightly-0.3.0.dev20241223.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241222.dist-info → ai_edge_torch_nightly-0.3.0.dev20241223.dist-info}/top_level.txt +0 -0
| @@ -29,9 +29,15 @@ from ai_edge_torch.generative.utilities import converter | |
| 29 29 | 
             
            from ai_edge_torch.generative.utilities.model_builder import ExportConfig
         | 
| 30 30 | 
             
            import torch
         | 
| 31 31 |  | 
| 32 | 
            +
            _VERSION = flags.DEFINE_enum(
         | 
| 33 | 
            +
                'version',
         | 
| 34 | 
            +
                '2',
         | 
| 35 | 
            +
                ['1', '2'],
         | 
| 36 | 
            +
                'The version of PaliGemma model to verify.',
         | 
| 37 | 
            +
            )
         | 
| 32 38 | 
             
            _CHECKPOINT_PATH = flags.DEFINE_string(
         | 
| 33 39 | 
             
                'checkpoint_path',
         | 
| 34 | 
            -
                os.path.join(pathlib.Path.home(), 'Downloads/llm_data/ | 
| 40 | 
            +
                os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
         | 
| 35 41 | 
             
                'The path to the model checkpoint, or directory holding the checkpoint.',
         | 
| 36 42 | 
             
            )
         | 
| 37 43 | 
             
            _TFLITE_PATH = flags.DEFINE_string(
         | 
| @@ -63,10 +69,12 @@ _QUANTIZE = flags.DEFINE_bool( | |
| 63 69 |  | 
| 64 70 | 
             
            def main(_):
         | 
| 65 71 | 
             
              pytorch_model = paligemma.build_model(
         | 
| 66 | 
            -
                  _CHECKPOINT_PATH.value, | 
| 72 | 
            +
                  _CHECKPOINT_PATH.value,
         | 
| 73 | 
            +
                  version=int(_VERSION.value),
         | 
| 74 | 
            +
                  kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
         | 
| 67 75 | 
             
              )
         | 
| 68 76 | 
             
              quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
         | 
| 69 | 
            -
              output_filename = f' | 
| 77 | 
            +
              output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
         | 
| 70 78 | 
             
              converter.convert_to_tflite(
         | 
| 71 79 | 
             
                  pytorch_model,
         | 
| 72 80 | 
             
                  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
         | 
| @@ -137,6 +137,8 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig: | |
| 137 137 | 
             
              config.vocab_size = 128
         | 
| 138 138 | 
             
              config.num_layers = 2
         | 
| 139 139 | 
             
              config.max_seq_len = 2 * kv_cache_max_len
         | 
| 140 | 
            +
              config.embedding_dim = 128
         | 
| 141 | 
            +
              config.embedding_scale = 128**0.5
         | 
| 140 142 | 
             
              return config
         | 
| 141 143 |  | 
| 142 144 |  | 
| @@ -160,6 +160,8 @@ def get_fake_decoder2_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig: | |
| 160 160 | 
             
              config.vocab_size = 128
         | 
| 161 161 | 
             
              config.num_layers = 2
         | 
| 162 162 | 
             
              config.max_seq_len = 2 * kv_cache_max_len
         | 
| 163 | 
            +
              config.embedding_dim = 128
         | 
| 164 | 
            +
              config.embedding_scale = 128**0.5
         | 
| 163 165 | 
             
              return config
         | 
| 164 166 |  | 
| 165 167 |  | 
| @@ -136,8 +136,8 @@ def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig: | |
| 136 136 | 
             
              return PaliGemmaConfig(
         | 
| 137 137 | 
             
                  image_encoder_config=image_encoder.get_fake_image_encoder_config(),
         | 
| 138 138 | 
             
                  decoder_config=get_decoder_config(**kwargs),
         | 
| 139 | 
            -
                  image_token_id= | 
| 140 | 
            -
                  image_projection_scale= | 
| 139 | 
            +
                  image_token_id=127,
         | 
| 140 | 
            +
                  image_projection_scale=128**0.5,
         | 
| 141 141 | 
             
                  image_projection_use_bias=True,
         | 
| 142 142 | 
             
              )
         | 
| 143 143 |  | 
| @@ -21,6 +21,8 @@ from ai_edge_torch.generative.examples.gemma import gemma1 | |
| 21 21 | 
             
            from ai_edge_torch.generative.examples.gemma import gemma2
         | 
| 22 22 | 
             
            from ai_edge_torch.generative.examples.llama import llama
         | 
| 23 23 | 
             
            from ai_edge_torch.generative.examples.openelm import openelm
         | 
| 24 | 
            +
            from ai_edge_torch.generative.examples.paligemma import decoder
         | 
| 25 | 
            +
            from ai_edge_torch.generative.examples.paligemma import decoder2
         | 
| 24 26 | 
             
            from ai_edge_torch.generative.examples.paligemma import paligemma
         | 
| 25 27 | 
             
            from ai_edge_torch.generative.examples.phi import phi2
         | 
| 26 28 | 
             
            from ai_edge_torch.generative.examples.phi import phi3
         | 
| @@ -171,13 +173,9 @@ class TestModelConversion(googletest.TestCase): | |
| 171 173 | 
             
                pytorch_model = amd_llama_135m.AmdLlama(config).eval()
         | 
| 172 174 | 
             
                self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
         | 
| 173 175 |  | 
| 174 | 
            -
               | 
| 175 | 
            -
             | 
| 176 | 
            -
             | 
| 177 | 
            -
              )
         | 
| 178 | 
            -
              def disabled_test_paligemma(self):
         | 
| 179 | 
            -
                config = paligemma.get_fake_model_config()
         | 
| 180 | 
            -
                pytorch_model = paligemma.PaliGemma(config).eval()
         | 
| 176 | 
            +
              def _test_paligemma_model(self, decoder_class, decoder_config, atol, rtol):
         | 
| 177 | 
            +
                config = paligemma.get_fake_model_config(decoder_config)
         | 
| 178 | 
            +
                pytorch_model = paligemma.PaliGemma(config, decoder_class).eval()
         | 
| 181 179 |  | 
| 182 180 | 
             
                image_embedding_config = config.image_encoder_config.image_embedding
         | 
| 183 181 | 
             
                num_patches = (
         | 
| @@ -215,11 +213,32 @@ class TestModelConversion(googletest.TestCase): | |
| 215 213 | 
             
                        kv,
         | 
| 216 214 | 
             
                        pixel_values=pixel_values,
         | 
| 217 215 | 
             
                        signature_name="prefill_pixel",
         | 
| 218 | 
            -
                        atol= | 
| 219 | 
            -
                        rtol= | 
| 216 | 
            +
                        atol=atol,
         | 
| 217 | 
            +
                        rtol=rtol,
         | 
| 220 218 | 
             
                    )
         | 
| 221 219 | 
             
                )
         | 
| 222 220 |  | 
| 221 | 
            +
              @googletest.skipIf(
         | 
| 222 | 
            +
                  ai_edge_torch.config.in_oss,
         | 
| 223 | 
            +
                  reason="tests with custom ops are not supported in oss",
         | 
| 224 | 
            +
              )
         | 
| 225 | 
            +
              def disabled_test_paligemma1(self):
         | 
| 226 | 
            +
                self._test_paligemma_model(
         | 
| 227 | 
            +
                    decoder.Decoder, decoder.get_fake_decoder_config, atol=1e-3, rtol=1e-5
         | 
| 228 | 
            +
                )
         | 
| 229 | 
            +
             | 
| 230 | 
            +
              @googletest.skipIf(
         | 
| 231 | 
            +
                  ai_edge_torch.config.in_oss,
         | 
| 232 | 
            +
                  reason="tests with custom ops are not supported in oss",
         | 
| 233 | 
            +
              )
         | 
| 234 | 
            +
              def disabled_test_paligemma2(self):
         | 
| 235 | 
            +
                self._test_paligemma_model(
         | 
| 236 | 
            +
                    decoder2.Decoder2,
         | 
| 237 | 
            +
                    decoder2.get_fake_decoder2_config,
         | 
| 238 | 
            +
                    atol=1e-3,
         | 
| 239 | 
            +
                    rtol=1e-5,
         | 
| 240 | 
            +
                )
         | 
| 241 | 
            +
             | 
| 223 242 | 
             
              @googletest.skipIf(
         | 
| 224 243 | 
             
                  ai_edge_torch.config.in_oss,
         | 
| 225 244 | 
             
                  reason="tests with custom ops are not supported in oss",
         | 
    
        ai_edge_torch/version.py
    CHANGED
    
    
| @@ -1,6 +1,6 @@ | |
| 1 1 | 
             
            Metadata-Version: 2.1
         | 
| 2 2 | 
             
            Name: ai-edge-torch-nightly
         | 
| 3 | 
            -
            Version: 0.3.0. | 
| 3 | 
            +
            Version: 0.3.0.dev20241223
         | 
| 4 4 | 
             
            Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
         | 
| 5 5 | 
             
            Home-page: https://github.com/google-ai-edge/ai-edge-torch
         | 
| 6 6 | 
             
            Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
         | 
| @@ -3,7 +3,7 @@ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614 | |
| 3 3 | 
             
            ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
         | 
| 4 4 | 
             
            ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
         | 
| 5 5 | 
             
            ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
         | 
| 6 | 
            -
            ai_edge_torch/version.py,sha256= | 
| 6 | 
            +
            ai_edge_torch/version.py,sha256=lt932yLxGWAwtUdRXBC2DJS7c4fJ4v36-tuSXCugwsc,706
         | 
| 7 7 | 
             
            ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         | 
| 8 8 | 
             
            ai_edge_torch/_convert/conversion.py,sha256=SzbR16V2JEfkCjjPwRVAFUbFnzu-_1iHPKgGT9Yz7gQ,5678
         | 
| 9 9 | 
             
            ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
         | 
| @@ -63,15 +63,15 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=nji1oDgf6x | |
| 63 63 | 
             
            ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2ukiJbQOTIUGuMEZvmwZbt3n0,4556
         | 
| 64 64 | 
             
            ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
         | 
| 65 65 | 
             
            ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         | 
| 66 | 
            -
            ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256= | 
| 67 | 
            -
            ai_edge_torch/generative/examples/paligemma/decoder.py,sha256= | 
| 68 | 
            -
            ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256= | 
| 66 | 
            +
            ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=LFCcnkmOksySDa_5bLBzoGMijYdFVjXIMidUlyzAbNk,2996
         | 
| 67 | 
            +
            ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=amN96oBMTPolOFvGa47vG92AZ-BNLm8j0bBYd-IrMvI,5407
         | 
| 68 | 
            +
            ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=0V_CX0Pn5Fj_-koOGjc_Av2KMSAaVjAlD-G8P6FBGyY,6385
         | 
| 69 69 | 
             
            ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
         | 
| 70 | 
            -
            ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256= | 
| 71 | 
            -
            ai_edge_torch/generative/examples/paligemma/verify.py,sha256= | 
| 70 | 
            +
            ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=FwGlFHl9zktGDxnoOpEtbS6NYN5RyzcOXH7lvNUCwEU,6257
         | 
| 71 | 
            +
            ai_edge_torch/generative/examples/paligemma/verify.py,sha256=KT3Ruy40tSESxQuy-Sw01NAI3zId1BZr6Bp7FZj1wZk,5622
         | 
| 72 72 | 
             
            ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
         | 
| 73 73 | 
             
            ai_edge_torch/generative/examples/paligemma/verify_decoder2.py,sha256=tm-UfLr0YeBRVcQsWLBOMWI9JUzHmtPEbYK2vpITpqY,2534
         | 
| 74 | 
            -
            ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256= | 
| 74 | 
            +
            ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=vNm-wTT8BD6zbX6GocfP1QrVoHl0zSvuVxoXN36eeiU,3540
         | 
| 75 75 | 
             
            ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         | 
| 76 76 | 
             
            ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=cD8rtwgYeGrXB9sYVV_D1AB8Up1AWNS-1XtrRlyzE5o,2296
         | 
| 77 77 | 
             
            ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=G1i_ybDCTBaOD1OOCTk6jqOf__xYYZvhXcxY8MXhPHw,2294
         | 
| @@ -142,7 +142,7 @@ ai_edge_torch/generative/test/test_custom_dus.py,sha256=gxG78CcTpXF3iLzDR15Rlz1e | |
| 142 142 | 
             
            ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOrypy4IM5YjC4p-6dgCMM,3793
         | 
| 143 143 | 
             
            ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
         | 
| 144 144 | 
             
            ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
         | 
| 145 | 
            -
            ai_edge_torch/generative/test/test_model_conversion_large.py,sha256= | 
| 145 | 
            +
            ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=NctnggTSFh0XEQbTu55diZ35rFD2QIARO-8PzLktRWg,12165
         | 
| 146 146 | 
             
            ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
         | 
| 147 147 | 
             
            ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
         | 
| 148 148 | 
             
            ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
         | 
| @@ -203,8 +203,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9 | |
| 203 203 | 
             
            ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         | 
| 204 204 | 
             
            ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
         | 
| 205 205 | 
             
            ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
         | 
| 206 | 
            -
            ai_edge_torch_nightly-0.3.0. | 
| 207 | 
            -
            ai_edge_torch_nightly-0.3.0. | 
| 208 | 
            -
            ai_edge_torch_nightly-0.3.0. | 
| 209 | 
            -
            ai_edge_torch_nightly-0.3.0. | 
| 210 | 
            -
            ai_edge_torch_nightly-0.3.0. | 
| 206 | 
            +
            ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
         | 
| 207 | 
            +
            ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/METADATA,sha256=EfUQ_LF_l-OEQVb9-qgcqC67LwOgGvW2p1DDD7QFqp0,1966
         | 
| 208 | 
            +
            ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
         | 
| 209 | 
            +
            ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
         | 
| 210 | 
            +
            ai_edge_torch_nightly-0.3.0.dev20241223.dist-info/RECORD,,
         | 
| 
            File without changes
         | 
| 
            File without changes
         |