ai-edge-torch-nightly 0.3.0.dev20240913__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
- ai_edge_torch/config.py +4 -1
- ai_edge_torch/fx_pass_base.py +101 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/gemma/gemma.py +2 -2
- ai_edge_torch/generative/examples/gemma/gemma2.py +2 -2
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +86 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/phi/phi2.py +2 -2
- ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
- ai_edge_torch/generative/examples/{smallm → smollm}/convert_to_tflite.py +12 -12
- ai_edge_torch/generative/examples/{smallm/smallm.py → smollm/smollm.py} +24 -15
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -1
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
- ai_edge_torch/generative/examples/t5/t5.py +8 -8
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -2
- ai_edge_torch/generative/fx_passes/__init__.py +4 -4
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
- ai_edge_torch/generative/layers/attention.py +7 -0
- ai_edge_torch/generative/layers/builder.py +33 -11
- ai_edge_torch/generative/layers/feed_forward.py +26 -8
- ai_edge_torch/generative/layers/kv_cache.py +4 -4
- ai_edge_torch/generative/layers/model_config.py +24 -15
- ai_edge_torch/generative/quantize/example.py +2 -2
- ai_edge_torch/generative/test/test_model_conversion.py +28 -51
- ai_edge_torch/generative/test/test_model_conversion_large.py +43 -78
- ai_edge_torch/generative/test/test_quantize.py +5 -5
- ai_edge_torch/generative/utilities/loader.py +13 -0
- ai_edge_torch/odml_torch/export.py +40 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +48 -46
- ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
- ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
- /ai_edge_torch/generative/examples/{smallm → openelm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/model_config.py
CHANGED

@@ -30,6 +30,7 @@ class ActivationType(enum.Enum):
   GELU_QUICK = enum.auto()
   GE_GLU = enum.auto()
   RELU = enum.auto()
+  SILU_GLU = enum.auto()


 @enum.unique
@@ -58,6 +59,18 @@ class AttentionType(enum.Enum):
   LOCAL_SLIDING = enum.auto()


+@dataclass
+class NormalizationConfig:
+  """Normalizater parameters."""
+
+  type: NormalizationType = NormalizationType.NONE
+  enable_hlfb: bool = False
+  epsilon: float = 1e-5
+  zero_centered: bool = False
+  # Number of groups used in group normalization.
+  group_num: Optional[float] = None
+
+
 @dataclass
 class AttentionConfig:
   """Attention model's parameters."""
@@ -81,6 +94,14 @@ class AttentionConfig:
   # Whether to use bias with attention output projection.
   output_proj_use_bias: bool = False
   enable_kv_cache: bool = True
+  # The normalization applied to query projection's output.
+  query_norm_config: NormalizationConfig = field(
+      default_factory=NormalizationConfig
+  )
+  # The normalization applied to key projection's output.
+  key_norm_config: NormalizationConfig = field(
+      default_factory=NormalizationConfig
+  )
   relative_attention_num_buckets: int = 0
   relative_attention_max_distance: int = 0
   # Softcap on the output logits.
@@ -94,21 +115,9 @@ class AttentionConfig:
 @dataclass
 class ActivationConfig:
   type: ActivationType = ActivationType.LINEAR
-  # …
-  …
-  …
-
-
-@dataclass
-class NormalizationConfig:
-  """Normalizater parameters."""
-
-  type: NormalizationType = NormalizationType.NONE
-  enable_hlfb: bool = False
-  epsilon: float = 1e-5
-  zero_centered: bool = False
-  # Number of groups used in group normalization.
-  group_num: Optional[float] = None
+  # Whether to GLU gate is the front part instead of the back part of input
+  # when ActivationType is `GE_GLU` or `SILU_GLU`.
+  gate_is_front: bool = False


 @dataclass
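
The gate_is_front flag only matters for the gated activations (GE_GLU, SILU_GLU), which split their input into a gate half and a value half; by default the gate comes from the back half. A minimal standalone sketch of the semantics described by the new comment, in plain PyTorch (an illustration of the convention, not the library's implementation):

import torch
import torch.nn.functional as F

def silu_glu(x: torch.Tensor, gate_is_front: bool = False) -> torch.Tensor:
  # Split the last dimension in half; one half gates the other.
  first, second = x.chunk(2, dim=-1)
  gate, value = (first, second) if gate_is_front else (second, first)
  return F.silu(gate) * value

x = torch.randn(2, 8)
assert silu_glu(x, gate_is_front=True).shape == (2, 4)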
ai_edge_torch/generative/quantize/example.py
CHANGED

@@ -25,9 +25,9 @@ def main():
   config = gemma.get_fake_model_config()
   model = gemma.Gemma(config)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
-  input_pos = torch.arange(0, 10)
+  input_pos = torch.arange(0, 10, dtype=torch.int)

   # Create a quantization recipe to be applied to the model
   quant_config = quant_recipes.full_int8_dynamic_recipe()
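
A note on the dtype changes here (the same change recurs throughout the test diffs below): torch.arange and torch.full with plain Python integers default to int64, so the explicit dtype=torch.int (PyTorch's alias for int32) is what actually switches the sample inputs, and therefore the converted model's input signature, to 32-bit. Quick check in plain PyTorch:

import torch

assert torch.arange(0, 10).dtype == torch.int64  # default is int64
assert torch.arange(0, 10, dtype=torch.int).dtype == torch.int32
assert torch.int == torch.int32  # torch.int is the int32 alias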
ai_edge_torch/generative/test/test_model_conversion.py
CHANGED

@@ -42,15 +42,9 @@ class TestModelConversion(googletest.TestCase):
         )
     )

-  @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
-  )
-  def test_toy_model_with_kv_cache(self):
-    config = toy_model_with_kv_cache.get_model_config()
-    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
-    tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
-        [10], dtype=torch.int64
+  def _test_model_with_kv_cache(self, config, pytorch_model):
+    tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
+        [10], dtype=torch.int
     )
     kv = kv_cache.KVCache.from_model_config(config)

@@ -83,58 +77,32 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_config.Config.use_torch_xla,
      reason="tests with custom ops are not supported on oss",
   )
-  def test_toy_model_with_kv_cache_with_hlfb(self):
+  def test_toy_model_with_kv_cache(self):
     config = toy_model_with_kv_cache.get_model_config()
-    config.enable_hlfb = True
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
-    tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
-        [10], dtype=torch.int64
-    )
-    kv = kv_cache.KVCache.from_model_config(config)
-
-    edge_model = ai_edge_torch.convert(
-        pytorch_model,
-        sample_kwargs={
-            "tokens": tokens,
-            "input_pos": input_pos,
-            "kv_cache": kv,
-        },
-    )
-    edge_model.set_interpreter_builder(
-        self._interpreter_builder(edge_model.tflite_model())
-    )
-
-    self.assertTrue(
-        test_utils.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            tokens,
-            input_pos,
-            kv,
-            signature_name="serving_default",
-            atol=1e-5,
-            rtol=1e-5,
-        )
-    )
+    self._test_model_with_kv_cache(config, pytorch_model)

   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def test_tiny_llama_multisig(self):
-    config = tiny_llama.get_fake_model_config()
-    pytorch_model = tiny_llama.TinyLlama(config).eval()
+  def test_toy_model_with_kv_cache_with_hlfb(self):
+    config = toy_model_with_kv_cache.get_model_config()
+    config.enable_hlfb = True
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
+    self._test_model_with_kv_cache(config, pytorch_model)

+  def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill
     seq_len = 10
-    prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.long, device="cpu")
+    prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.int, device="cpu")
     prompt_token = torch.from_numpy(np.array([1, 2, 3, 4]))
     prefill_tokens[0, : len(prompt_token)] = prompt_token
-    prefill_input_pos = torch.arange(0, seq_len)
+    prefill_input_pos = torch.arange(0, seq_len, dtype=torch.int)

     # decode
-    decode_token = torch.tensor([[1]], dtype=torch.long)
-    decode_input_pos = torch.tensor([5], dtype=torch.int64)
+    decode_token = torch.tensor([[1]], dtype=torch.int)
+    decode_input_pos = torch.tensor([5], dtype=torch.int)

     kv = kv_cache.KVCache.from_model_config(config)

@@ -171,8 +139,8 @@ class TestModelConversion(googletest.TestCase):
             prefill_input_pos,
             kv,
             signature_name="prefill",
-            atol=1e-5,
-            rtol=1e-5,
+            atol=atol,
+            rtol=atol,
         )
     )

@@ -184,11 +152,20 @@ class TestModelConversion(googletest.TestCase):
             decode_input_pos,
             kv,
             signature_name="decode",
-            atol=1e-5,
-            rtol=1e-5,
+            atol=atol,
+            rtol=atol,
         )
     )

+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_tiny_llama_multisig(self):
+    config = tiny_llama.get_fake_model_config()
+    pytorch_model = tiny_llama.TinyLlama(config).eval()
+    self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
+

 if __name__ == "__main__":
   googletest.main()
ai_edge_torch/generative/test/test_model_conversion_large.py
CHANGED

@@ -19,7 +19,9 @@ import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
 from ai_edge_torch.generative.examples.gemma import gemma
 from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
 import numpy as np
@@ -43,28 +45,22 @@ class TestModelConversion(googletest.TestCase):
         )
     )

-  @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
-  )
-  def test_gemma(self):
-    config = gemma.get_fake_model_config()
-    model = gemma.Gemma(config)
-
+  def _test_model(self, config, model, signature_name, atol, rtol):
     idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+    tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
     tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10)
+    input_pos = torch.arange(0, 10, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config)

-    edge_model = ai_edge_torch.convert(
+    edge_model = ai_edge_torch.signature(
+        signature_name,
         model,
         sample_kwargs={
             "tokens": tokens,
             "input_pos": input_pos,
             "kv_cache": kv,
         },
-    )
+    ).convert()
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
@@ -76,9 +72,9 @@ class TestModelConversion(googletest.TestCase):
             tokens,
             input_pos,
             kv,
-            signature_name="serving_default",
-            atol=1e-2,
-            rtol=1e-5,
+            signature_name=signature_name,
+            atol=atol,
+            rtol=rtol,
         )
     )

@@ -86,42 +82,21 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def test_gemma2(self):
-    config = gemma2.get_fake_model_config()
-    model = gemma2.Gemma2(config)
-    model.eval()
-
-    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
-    prefill_tokens[0, :4] = idx
-    prefill_input_pos = torch.arange(0, 10)
-    kv = kv_cache.KVCache.from_model_config(config)
-
-    edge_model = ai_edge_torch.signature(
-        "prefill",
-        model,
-        sample_kwargs={
-            "tokens": prefill_tokens,
-            "input_pos": prefill_input_pos,
-            "kv_cache": kv,
-        },
-    ).convert()
-    edge_model.set_interpreter_builder(
-        self._interpreter_builder(edge_model.tflite_model())
+  def test_gemma(self):
+    config = gemma.get_fake_model_config()
+    pytorch_model = gemma.Gemma(config).eval()
+    self._test_model(
+        config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
     )

-    self.assertTrue(
-        test_utils.compare_tflite_torch(
-            edge_model,
-            model,
-            prefill_tokens,
-            prefill_input_pos,
-            kv,
-            signature_name="prefill",
-            atol=1e-1,
-            rtol=1e-3,
-        )
-    )
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_gemma2(self):
+    config = gemma2.get_fake_model_config()
+    pytorch_model = gemma2.Gemma2(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-1, rtol=1e-3)

   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
@@ -130,37 +105,27 @@ class TestModelConversion(googletest.TestCase):
   def test_phi2(self):
     config = phi2.get_fake_model_config()
     pytorch_model = phi2.Phi2(config).eval()
-
-    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
-    tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10)
-    kv = kv_cache.KVCache.from_model_config(config)
-
-    edge_model = ai_edge_torch.convert(
-        pytorch_model,
-        sample_kwargs={
-            "tokens": tokens,
-            "input_pos": input_pos,
-            "kv_cache": kv,
-        },
-    )
-    edge_model.set_interpreter_builder(
-        self._interpreter_builder(edge_model.tflite_model())
+    self._test_model(
+        config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
     )

-    self.assertTrue(
-        test_utils.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            tokens,
-            input_pos,
-            kv,
-            signature_name="serving_default",
-            atol=1e-3,
-            rtol=1e-3,
-        )
-    )
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_smollm(self):
+    config = smollm.get_fake_model_config()
+    pytorch_model = smollm.SmolLM(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_openelm(self):
+    config = openelm.get_fake_model_config()
+    pytorch_model = openelm.OpenELM(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)


 if __name__ == "__main__":
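
Besides consolidating the per-model boilerplate into _test_model, the diff above switches conversion from ai_edge_torch.convert(...) to ai_edge_torch.signature(name, model, ...).convert(), which converts the model under an explicitly named TFLite signature instead of the default serving_default. A hedged sketch of the call pattern with a stand-in module (the Linear model and output path are illustrative, not from the package):

import torch
import ai_edge_torch

model = torch.nn.Linear(4, 4).eval()
sample_args = (torch.randn(1, 4),)

# Convert under the signature name "prefill" rather than "serving_default".
edge_model = ai_edge_torch.signature("prefill", model, sample_args).convert()
edge_model.export("/tmp/linear_prefill.tflite")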
ai_edge_torch/generative/test/test_quantize.py
CHANGED

@@ -115,8 +115,8 @@ class TestQuantizeConvert(parameterized.TestCase):
   def test_quantize_convert_toy_sizes(self, quant_config):
     config = toy_model.get_model_config()
     pytorch_model = toy_model.ToySingleLayerModel(config)
-    idx = torch.unsqueeze(torch.arange(0, 100), 0)
-    input_pos = torch.arange(0, 100)
+    idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
+    input_pos = torch.arange(0, 100, dtype=torch.int)

     quantized_model = ai_edge_torch.convert(
         pytorch_model, (idx, input_pos), quant_config=quant_config
@@ -131,8 +131,8 @@ class TestQuantizeConvert(parameterized.TestCase):
   def test_quantize_convert_toy_weight_sharing(self):
     config = toy_model.get_model_config()
     pytorch_model = toy_model.ToySingleLayerModelWeightSharing(config)
-    idx = torch.unsqueeze(torch.arange(0, 100), 0)
-    input_pos = torch.arange(0, 100)
+    idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
+    input_pos = torch.arange(0, 100, dtype=torch.int)

     quant_config = quant_recipes.full_int8_dynamic_recipe()
     quantized_model = ai_edge_torch.convert(
@@ -149,7 +149,7 @@ class TestQuantizeConvert(parameterized.TestCase):
     self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-    idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+    idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int64
     )
| 155 155 |  | 
| @@ -101,6 +101,8 @@ class ModelLoader: | |
| 101 101 | 
             
                attn_value_proj: str = None
         | 
| 102 102 | 
             
                attn_fused_qkv_proj: str = None
         | 
| 103 103 | 
             
                attn_output_proj: str = None
         | 
| 104 | 
            +
                attn_query_norm: str = None
         | 
| 105 | 
            +
                attn_key_norm: str = None
         | 
| 104 106 |  | 
| 105 107 | 
             
                ff_up_proj: str = None
         | 
| 106 108 | 
             
                ff_down_proj: str = None
         | 
| @@ -323,6 +325,17 @@ class ModelLoader: | |
| 323 325 | 
             
                        )
         | 
| 324 326 | 
             
                    )
         | 
| 325 327 |  | 
| 328 | 
            +
                if self._names.attn_query_norm is not None:
         | 
| 329 | 
            +
                  attn_query_norm_name = self._names.attn_query_norm.format(idx)
         | 
| 330 | 
            +
                  converted_state[f"{prefix}.atten_func.query_norm.weight"] = state.pop(
         | 
| 331 | 
            +
                      f"{attn_query_norm_name}.weight"
         | 
| 332 | 
            +
                  )
         | 
| 333 | 
            +
                if self._names.attn_key_norm is not None:
         | 
| 334 | 
            +
                  attn_key_norm_name = self._names.attn_key_norm.format(idx)
         | 
| 335 | 
            +
                  converted_state[f"{prefix}.atten_func.key_norm.weight"] = state.pop(
         | 
| 336 | 
            +
                      f"{attn_key_norm_name}.weight"
         | 
| 337 | 
            +
                  )
         | 
| 338 | 
            +
             | 
| 326 339 | 
             
                o_name = self._names.attn_output_proj.format(idx)
         | 
| 327 340 | 
             
                converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
         | 
| 328 341 | 
             
                    state.pop(f"{o_name}.weight")
         | 
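
These loader hooks mirror the new query_norm_config/key_norm_config fields in model_config.py: attn_query_norm and attn_key_norm are per-layer format strings that the loader .format(idx)-s into checkpoint keys. A sketch of the resulting key mapping; the source path is hypothetical, and the transformer_blocks destination prefix is my reading of the loader's naming convention rather than something shown in this diff:

# Hypothetical checkpoint naming; only the "{}".format(idx) pattern is
# taken from the diff above.
attn_query_norm = "transformer.layers.{}.attn.q_norm"

for idx in range(2):
  src = attn_query_norm.format(idx) + ".weight"
  dst = f"transformer_blocks.{idx}.atten_func.query_norm.weight"
  print(src, "->", dst)
# transformer.layers.0.attn.q_norm.weight -> transformer_blocks.0.atten_func.query_norm.weight
# transformer.layers.1.attn.q_norm.weight -> transformer_blocks.1.atten_func.query_norm.weight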
ai_edge_torch/odml_torch/export.py
CHANGED

@@ -223,6 +223,41 @@ class MlirLowered:
     return tf_integration.mlir_to_flatbuffer(self)


+# TODO(b/331481564) Make this a ai_edge_torch FX pass.
+def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
+  """Convert internal constant aten ops' output from int64 to int32.
+
+  Int32 generally has better performance and compatibility than int64 in
+  runtime. This pass converts aten op where the output(s) are int64 constant
+  tensors to return int32 constant tensors.
+
+  Args:
+    exported_program: The exported program to apply the pass.
+  """
+
+  def in_i32(x: int):
+    return -2147483648 <= x <= 2147483647
+
+  def rewrite_arange(node: torch.fx.Node):
+    tensor_meta = node.meta.get("tensor_meta", None)
+    if not tensor_meta:
+      return
+
+    start, end = node.args[:2]
+    if tensor_meta.dtype != torch.int64:
+      return
+    if not (in_i32(start) and in_i32(end)):
+      return
+    op = node.target
+    node.target = lambda *args, **kwargs: op(*args, **kwargs).type(torch.int32)
+
+  graph_module = exported_program.graph_module
+  for node in graph_module.graph.nodes:
+
+    if node.target == torch.ops.aten.arange.start_step:
+      rewrite_arange(node)
+
+
 def exported_program_to_mlir(
     exported_program: torch.export.ExportedProgram,
 ) -> MlirLowered:
@@ -231,6 +266,11 @@ def exported_program_to_mlir(
       lowerings.decompositions()
   )

+  _convert_i64_to_i32(exported_program)
+  exported_program = exported_program.run_decompositions(
+      lowerings.decompositions()
+  )
+
   with export_utils.create_ir_context() as context, ir.Location.unknown():

     module = ir.Module.create()
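
_convert_i64_to_i32 rewrites only aten.arange.start_step nodes whose start and end both fit in int32, swapping the node's target for a wrapper that casts the int64 result. The core trick isolated in plain PyTorch (a standalone illustration, not the FX pass itself):

import torch

def as_i32(op):
  # Wrap an op so its result is cast to int32, as rewrite_arange does with
  # node.target in the pass above.
  return lambda *args, **kwargs: op(*args, **kwargs).type(torch.int32)

arange_i32 = as_i32(torch.ops.aten.arange.start_step)
assert torch.ops.aten.arange.start_step(0, 10, 1).dtype == torch.int64
assert arange_i32(0, 10, 1).dtype == torch.int32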
ai_edge_torch/odml_torch/lowerings/_basic.py
CHANGED

@@ -202,3 +202,47 @@ def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value:
   x, y = utils.broadcast_args_if_needed(x, y)

   return stablehlo.divide(x, y)
+
+
+# Schema:
+#   - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt?
+#       start=None, SymInt? end=None, SymInt step=1) -> Tensor
+# Torch Reference:
+#   - https://pytorch.org/docs/stable/generated/torch.slice_scatter.html
+#   - https://github.com/pytorch/pytorch/blob/18f9331e5deb4c02ae5c206e133a9b4add49bd97/aten/src/ATen/native/TensorShape.cpp#L4002
+@lower(torch.ops.aten.slice_scatter)
+def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
+  start = start or 0
+  end = end or self.type.shape[dim]
+  if start < 0:
+    start = self.type.shape[dim] + start
+  if end < 0:
+    end = self.type.shape[dim] + end
+
+  end = start + step * math.ceil((end - start) / step) - (step - 1)
+
+  padding_low = start
+  padding_high = self.type.shape[dim] - end
+
+  rank = len(self.type.shape)
+  src = stablehlo.pad(
+      src,
+      utils.splat(0, src.type.element_type, []),
+      edge_padding_low=[padding_low if i == dim else 0 for i in range(rank)],
+      edge_padding_high=[padding_high if i == dim else 0 for i in range(rank)],
+      interior_padding=[step - 1 if i == dim else 0 for i in range(rank)],
+  )
+  pred = np.ones(self.type.shape, dtype=np.bool_)
+  pred[*[
+      slice(start, end, step) if i == dim else slice(None, None, None)
+      for i in range(rank)
+  ]] = False
+  pred = stablehlo.constant(
+      ir.DenseElementsAttr.get(
+          np.packbits(pred, bitorder="little"),
+          type=ir.IntegerType.get_signless(1),
+          shape=pred.shape,
+      )
+  )
+  out = stablehlo.select(pred, self, src)
+  return out
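
The lowering implements aten.slice_scatter without an actual scatter: it pads src (interior padding handles step > 1) so its elements line up with the sliced positions of self, then uses stablehlo.select with a precomputed boolean mask that is False exactly at those positions. For reference, the op's semantics in plain PyTorch:

import torch

base = torch.zeros(8)
src = torch.ones(2)
# Scatter src into base at slice(2, 6, 2), i.e. indices 2 and 4.
out = torch.slice_scatter(base, src, dim=0, start=2, end=6, step=2)
print(out)  # tensor([0., 0., 1., 0., 1., 0., 0., 0.])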
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
CHANGED

@@ -203,7 +203,6 @@ lower_by_torch_xla2(torch.ops.aten.sin)
 lower_by_torch_xla2(torch.ops.aten.sinh)
 lower_by_torch_xla2(torch.ops.aten.slice)
 lower_by_torch_xla2(torch.ops.aten.slice_copy)
-lower_by_torch_xla2(torch.ops.aten.slice_scatter)
 lower_by_torch_xla2(torch.ops.aten.sort)
 lower_by_torch_xla2(torch.ops.aten.split)
 lower_by_torch_xla2(torch.ops.aten.split_copy)
{ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240913
+Version: 0.3.0.dev20240914
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI