ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240911__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
 - ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
 - ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
 - ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
 - ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
 - ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
 - ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
 - ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
 - ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +38 -22
 - ai_edge_torch/generative/layers/attention.py +60 -63
 - ai_edge_torch/generative/layers/kv_cache.py +160 -51
 - ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
 - ai_edge_torch/generative/test/test_model_conversion.py +71 -33
 - ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
 - ai_edge_torch/generative/test/utils.py +54 -0
 - ai_edge_torch/version.py +1 -1
 - {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/METADATA +1 -1
 - {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/RECORD +22 -32
 - ai_edge_torch/generative/examples/experimental/gemma/__init__.py +0 -14
 - ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +0 -88
 - ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
 - ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
 - ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
 - ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
 - ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
 - ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
 - ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
 - ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
 - ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
 - /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
 - {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/LICENSE +0 -0
 - {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/WHEEL +0 -0
 - {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/top_level.txt +0 -0
 
| 
         @@ -13,32 +13,35 @@ 
     | 
|
| 
       13 
13 
     | 
    
         
             
            # limitations under the License.
         
     | 
| 
       14 
14 
     | 
    
         
             
            # ==============================================================================
         
     | 
| 
       15 
15 
     | 
    
         | 
| 
      
 16 
     | 
    
         
            +
            """Example of converting a Gemma2 model to multi-signature tflite model."""
         
     | 
| 
      
 17 
     | 
    
         
            +
             
     | 
| 
       16 
18 
     | 
    
         
             
            import os
         
     | 
| 
       17 
     | 
    
         
            -
             
     | 
| 
      
 19 
     | 
    
         
            +
            import pathlib
         
     | 
| 
       18 
20 
     | 
    
         | 
| 
       19 
21 
     | 
    
         
             
            import ai_edge_torch
         
     | 
| 
       20 
22 
     | 
    
         
             
            from ai_edge_torch.generative.examples.gemma import gemma2
         
     | 
| 
      
 23 
     | 
    
         
            +
            from ai_edge_torch.generative.layers import kv_cache as kv_utils
         
     | 
| 
       21 
24 
     | 
    
         
             
            from ai_edge_torch.generative.quantize import quant_recipes
         
     | 
| 
       22 
25 
     | 
    
         
             
            import torch
         
     | 
| 
       23 
26 
     | 
    
         | 
| 
       24 
27 
     | 
    
         | 
| 
       25 
     | 
    
         
            -
            def  
     | 
| 
      
 28 
     | 
    
         
            +
            def convert_gemma2_to_tflite(
         
     | 
| 
       26 
29 
     | 
    
         
             
                checkpoint_path: str,
         
     | 
| 
       27 
30 
     | 
    
         
             
                prefill_seq_len: int = 512,
         
     | 
| 
       28 
31 
     | 
    
         
             
                kv_cache_max_len: int = 1024,
         
     | 
| 
       29 
32 
     | 
    
         
             
                quantize: bool = True,
         
     | 
| 
       30 
33 
     | 
    
         
             
            ):
         
     | 
| 
       31 
     | 
    
         
            -
              """ 
     | 
| 
       32 
     | 
    
         
            -
              tflite model.
         
     | 
| 
      
 34 
     | 
    
         
            +
              """Converts a Gemma2 2B model to multi-signature tflite model.
         
     | 
| 
       33 
35 
     | 
    
         | 
| 
       34 
36 
     | 
    
         
             
              Args:
         
     | 
| 
       35 
     | 
    
         
            -
                  checkpoint_path (str): The filepath to the model checkpoint, or directory 
     | 
| 
      
 37 
     | 
    
         
            +
                  checkpoint_path (str): The filepath to the model checkpoint, or directory
         
     | 
| 
      
 38 
     | 
    
         
            +
                    holding the checkpoint.
         
     | 
| 
       36 
39 
     | 
    
         
             
                  prefill_seq_len (int, optional): The maximum size of prefill input tensor.
         
     | 
| 
       37 
40 
     | 
    
         
             
                    Defaults to 512.
         
     | 
| 
       38 
41 
     | 
    
         
             
                  kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
         
     | 
| 
       39 
42 
     | 
    
         
             
                    including both prefill and decode. Defaults to 1024.
         
     | 
| 
       40 
     | 
    
         
            -
                  quantize (bool, optional): Whether the model should be quanized.
         
     | 
| 
       41 
     | 
    
         
            -
                     
     | 
| 
      
 43 
     | 
    
         
            +
                  quantize (bool, optional): Whether the model should be quanized. Defaults
         
     | 
| 
      
 44 
     | 
    
         
            +
                    to True.
         
     | 
| 
       42 
45 
     | 
    
         
             
              """
         
     | 
| 
       43 
46 
     | 
    
         
             
              pytorch_model = gemma2.build_2b_model(
         
     | 
| 
       44 
47 
     | 
    
         
             
                  checkpoint_path, kv_cache_max_len=kv_cache_max_len
         
     | 
| 
         @@ -48,20 +51,36 @@ def convert_gemma_to_tflite( 
     | 
|
| 
       48 
51 
     | 
    
         
             
              prefill_input_pos = torch.arange(0, prefill_seq_len)
         
     | 
| 
       49 
52 
     | 
    
         
             
              decode_token = torch.tensor([[0]], dtype=torch.long)
         
     | 
| 
       50 
53 
     | 
    
         
             
              decode_input_pos = torch.tensor([0], dtype=torch.int64)
         
     | 
| 
      
 54 
     | 
    
         
            +
              kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
         
     | 
| 
       51 
55 
     | 
    
         | 
| 
       52 
56 
     | 
    
         
             
              quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
         
     | 
| 
       53 
57 
     | 
    
         
             
              edge_model = (
         
     | 
| 
       54 
58 
     | 
    
         
             
                  ai_edge_torch.signature(
         
     | 
| 
       55 
     | 
    
         
            -
                      'prefill', 
     | 
| 
      
 59 
     | 
    
         
            +
                      'prefill',
         
     | 
| 
      
 60 
     | 
    
         
            +
                      pytorch_model,
         
     | 
| 
      
 61 
     | 
    
         
            +
                      sample_kwargs={
         
     | 
| 
      
 62 
     | 
    
         
            +
                          'tokens': prefill_tokens,
         
     | 
| 
      
 63 
     | 
    
         
            +
                          'input_pos': prefill_input_pos,
         
     | 
| 
      
 64 
     | 
    
         
            +
                          'kv_cache': kv,
         
     | 
| 
      
 65 
     | 
    
         
            +
                      },
         
     | 
| 
      
 66 
     | 
    
         
            +
                  )
         
     | 
| 
      
 67 
     | 
    
         
            +
                  .signature(
         
     | 
| 
      
 68 
     | 
    
         
            +
                      'decode',
         
     | 
| 
      
 69 
     | 
    
         
            +
                      pytorch_model,
         
     | 
| 
      
 70 
     | 
    
         
            +
                      sample_kwargs={
         
     | 
| 
      
 71 
     | 
    
         
            +
                          'tokens': decode_token,
         
     | 
| 
      
 72 
     | 
    
         
            +
                          'input_pos': decode_input_pos,
         
     | 
| 
      
 73 
     | 
    
         
            +
                          'kv_cache': kv,
         
     | 
| 
      
 74 
     | 
    
         
            +
                      },
         
     | 
| 
       56 
75 
     | 
    
         
             
                  )
         
     | 
| 
       57 
     | 
    
         
            -
                  .signature('decode', pytorch_model, (decode_token, decode_input_pos))
         
     | 
| 
       58 
76 
     | 
    
         
             
                  .convert(quant_config=quant_config)
         
     | 
| 
       59 
77 
     | 
    
         
             
              )
         
     | 
| 
      
 78 
     | 
    
         
            +
              quant_suffix = 'q8' if quantize else 'f32'
         
     | 
| 
       60 
79 
     | 
    
         
             
              edge_model.export(
         
     | 
| 
       61 
     | 
    
         
            -
                  f'/tmp/ 
     | 
| 
      
 80 
     | 
    
         
            +
                  f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
         
     | 
| 
       62 
81 
     | 
    
         
             
              )
         
     | 
| 
       63 
82 
     | 
    
         | 
| 
       64 
83 
     | 
    
         | 
| 
       65 
84 
     | 
    
         
             
            if __name__ == '__main__':
         
     | 
| 
       66 
     | 
    
         
            -
               
     | 
| 
       67 
     | 
    
         
            -
               
     | 
| 
      
 85 
     | 
    
         
            +
              path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')
         
     | 
| 
      
 86 
     | 
    
         
            +
              convert_gemma2_to_tflite(path)
         
     | 
| 
         @@ -13,11 +13,14 @@ 
     | 
|
| 
       13 
13 
     | 
    
         
             
            # limitations under the License.
         
     | 
| 
       14 
14 
     | 
    
         
             
            # ==============================================================================
         
     | 
| 
       15 
15 
     | 
    
         | 
| 
      
 16 
     | 
    
         
            +
            """Example of converting a Gemma model to multi-signature tflite model."""
         
     | 
| 
      
 17 
     | 
    
         
            +
             
     | 
| 
       16 
18 
     | 
    
         
             
            import os
         
     | 
| 
       17 
     | 
    
         
            -
             
     | 
| 
      
 19 
     | 
    
         
            +
            import pathlib
         
     | 
| 
       18 
20 
     | 
    
         | 
| 
       19 
21 
     | 
    
         
             
            import ai_edge_torch
         
     | 
| 
       20 
22 
     | 
    
         
             
            from ai_edge_torch.generative.examples.gemma import gemma
         
     | 
| 
      
 23 
     | 
    
         
            +
            from ai_edge_torch.generative.layers import kv_cache as kv_utils
         
     | 
| 
       21 
24 
     | 
    
         
             
            from ai_edge_torch.generative.quantize import quant_recipes
         
     | 
| 
       22 
25 
     | 
    
         
             
            import torch
         
     | 
| 
       23 
26 
     | 
    
         | 
| 
         @@ -48,20 +51,36 @@ def convert_gemma_to_tflite( 
     | 
|
| 
       48 
51 
     | 
    
         
             
              prefill_input_pos = torch.arange(0, prefill_seq_len)
         
     | 
| 
       49 
52 
     | 
    
         
             
              decode_token = torch.tensor([[0]], dtype=torch.long)
         
     | 
| 
       50 
53 
     | 
    
         
             
              decode_input_pos = torch.tensor([0], dtype=torch.int64)
         
     | 
| 
      
 54 
     | 
    
         
            +
              kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
         
     | 
| 
       51 
55 
     | 
    
         | 
| 
       52 
56 
     | 
    
         
             
              quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
         
     | 
| 
       53 
57 
     | 
    
         
             
              edge_model = (
         
     | 
| 
       54 
58 
     | 
    
         
             
                  ai_edge_torch.signature(
         
     | 
| 
       55 
     | 
    
         
            -
                      'prefill', 
     | 
| 
      
 59 
     | 
    
         
            +
                      'prefill',
         
     | 
| 
      
 60 
     | 
    
         
            +
                      pytorch_model,
         
     | 
| 
      
 61 
     | 
    
         
            +
                      sample_kwargs={
         
     | 
| 
      
 62 
     | 
    
         
            +
                          'tokens': prefill_tokens,
         
     | 
| 
      
 63 
     | 
    
         
            +
                          'input_pos': prefill_input_pos,
         
     | 
| 
      
 64 
     | 
    
         
            +
                          'kv_cache': kv,
         
     | 
| 
      
 65 
     | 
    
         
            +
                      },
         
     | 
| 
      
 66 
     | 
    
         
            +
                  )
         
     | 
| 
      
 67 
     | 
    
         
            +
                  .signature(
         
     | 
| 
      
 68 
     | 
    
         
            +
                      'decode',
         
     | 
| 
      
 69 
     | 
    
         
            +
                      pytorch_model,
         
     | 
| 
      
 70 
     | 
    
         
            +
                      sample_kwargs={
         
     | 
| 
      
 71 
     | 
    
         
            +
                          'tokens': decode_token,
         
     | 
| 
      
 72 
     | 
    
         
            +
                          'input_pos': decode_input_pos,
         
     | 
| 
      
 73 
     | 
    
         
            +
                          'kv_cache': kv,
         
     | 
| 
      
 74 
     | 
    
         
            +
                      },
         
     | 
| 
       56 
75 
     | 
    
         
             
                  )
         
     | 
| 
       57 
     | 
    
         
            -
                  .signature('decode', pytorch_model, (decode_token, decode_input_pos))
         
     | 
| 
       58 
76 
     | 
    
         
             
                  .convert(quant_config=quant_config)
         
     | 
| 
       59 
77 
     | 
    
         
             
              )
         
     | 
| 
      
 78 
     | 
    
         
            +
              quant_suffix = 'q8' if quantize else 'f32'
         
     | 
| 
       60 
79 
     | 
    
         
             
              edge_model.export(
         
     | 
| 
       61 
     | 
    
         
            -
                  f'/tmp/ 
     | 
| 
      
 80 
     | 
    
         
            +
                  f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
         
     | 
| 
       62 
81 
     | 
    
         
             
              )
         
     | 
| 
       63 
82 
     | 
    
         | 
| 
       64 
83 
     | 
    
         | 
| 
       65 
84 
     | 
    
         
             
            if __name__ == '__main__':
         
     | 
| 
       66 
     | 
    
         
            -
               
     | 
| 
       67 
     | 
    
         
            -
              convert_gemma_to_tflite( 
     | 
| 
      
 85 
     | 
    
         
            +
              path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
         
     | 
| 
      
 86 
     | 
    
         
            +
              convert_gemma_to_tflite(path)
         
     | 
| 
         @@ -12,13 +12,15 @@ 
     | 
|
| 
       12 
12 
     | 
    
         
             
            # See the License for the specific language governing permissions and
         
     | 
| 
       13 
13 
     | 
    
         
             
            # limitations under the License.
         
     | 
| 
       14 
14 
     | 
    
         
             
            # ==============================================================================
         
     | 
| 
       15 
     | 
    
         
            -
             
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
            """Example of building a Gemma model."""
         
     | 
| 
       16 
17 
     | 
    
         | 
| 
       17 
18 
     | 
    
         
             
            import os
         
     | 
| 
       18 
     | 
    
         
            -
             
     | 
| 
      
 19 
     | 
    
         
            +
            import pathlib
         
     | 
| 
       19 
20 
     | 
    
         | 
| 
       20 
21 
     | 
    
         
             
            from ai_edge_torch.generative.layers import attention
         
     | 
| 
       21 
22 
     | 
    
         
             
            from ai_edge_torch.generative.layers import builder
         
     | 
| 
      
 23 
     | 
    
         
            +
            from ai_edge_torch.generative.layers import kv_cache as kv_utils
         
     | 
| 
       22 
24 
     | 
    
         
             
            import ai_edge_torch.generative.layers.attention_utils as attn_utils
         
     | 
| 
       23 
25 
     | 
    
         
             
            import ai_edge_torch.generative.layers.model_config as cfg
         
     | 
| 
       24 
26 
     | 
    
         
             
            import ai_edge_torch.generative.utilities.loader as loading_utils
         
     | 
| 
         @@ -84,16 +86,22 @@ class Gemma(nn.Module): 
     | 
|
| 
       84 
86 
     | 
    
         
             
                )
         
     | 
| 
       85 
87 
     | 
    
         
             
                self.config = config
         
     | 
| 
       86 
88 
     | 
    
         | 
| 
       87 
     | 
    
         
            -
              # The model's forward function takes in additional k/v cache tensors
         
     | 
| 
       88 
     | 
    
         
            -
              # and returns the updated k/v cache tensors to the caller.
         
     | 
| 
       89 
     | 
    
         
            -
              # This can be eliminated if we handle k/v cache updates inside the model itself.
         
     | 
| 
       90 
89 
     | 
    
         
             
              @torch.inference_mode
         
     | 
| 
       91 
     | 
    
         
            -
              def forward( 
     | 
| 
       92 
     | 
    
         
            -
             
     | 
| 
      
 90 
     | 
    
         
            +
              def forward(
         
     | 
| 
      
 91 
     | 
    
         
            +
                  self,
         
     | 
| 
      
 92 
     | 
    
         
            +
                  tokens: torch.Tensor,
         
     | 
| 
      
 93 
     | 
    
         
            +
                  input_pos: torch.Tensor,
         
     | 
| 
      
 94 
     | 
    
         
            +
                  kv_cache: kv_utils.KVCache,
         
     | 
| 
      
 95 
     | 
    
         
            +
              ) -> dict[torch.Tensor, kv_utils.KVCache]:
         
     | 
| 
      
 96 
     | 
    
         
            +
                _, seq_len = tokens.size()
         
     | 
| 
       93 
97 
     | 
    
         
             
                assert self.config.max_seq_len >= seq_len, (
         
     | 
| 
       94 
98 
     | 
    
         
             
                    f"Cannot forward sequence of length {seq_len}, max seq length is only"
         
     | 
| 
       95 
99 
     | 
    
         
             
                    f" {self.config.max_seq_len}"
         
     | 
| 
       96 
100 
     | 
    
         
             
                )
         
     | 
| 
      
 101 
     | 
    
         
            +
                assert len(self.transformer_blocks) == len(kv_cache.caches), (
         
     | 
| 
      
 102 
     | 
    
         
            +
                    "The number of transformer blocks and the number of KV cache entries"
         
     | 
| 
      
 103 
     | 
    
         
            +
                    " must be the same."
         
     | 
| 
      
 104 
     | 
    
         
            +
                )
         
     | 
| 
       97 
105 
     | 
    
         | 
| 
       98 
106 
     | 
    
         
             
                cos, sin = self.rope_cache
         
     | 
| 
       99 
107 
     | 
    
         
             
                cos = cos.index_select(0, input_pos)
         
     | 
| 
         @@ -102,15 +110,20 @@ class Gemma(nn.Module): 
     | 
|
| 
       102 
110 
     | 
    
         
             
                mask = mask[:, :, :, : self.config.kv_cache_max]
         
     | 
| 
       103 
111 
     | 
    
         | 
| 
       104 
112 
     | 
    
         
             
                # token embeddings of shape (b, t, n_embd)
         
     | 
| 
       105 
     | 
    
         
            -
                x = self.tok_embedding( 
     | 
| 
      
 113 
     | 
    
         
            +
                x = self.tok_embedding(tokens)
         
     | 
| 
       106 
114 
     | 
    
         
             
                x = x * (self.config.embedding_dim**0.5)
         
     | 
| 
       107 
115 
     | 
    
         | 
| 
       108 
     | 
    
         
            -
                 
     | 
| 
       109 
     | 
    
         
            -
             
     | 
| 
      
 116 
     | 
    
         
            +
                updated_kv_entires = []
         
     | 
| 
      
 117 
     | 
    
         
            +
                for i, block in enumerate(self.transformer_blocks):
         
     | 
| 
      
 118 
     | 
    
         
            +
                  kv_entry = kv_cache.caches[i] if kv_cache else None
         
     | 
| 
      
 119 
     | 
    
         
            +
                  x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
         
     | 
| 
      
 120 
     | 
    
         
            +
                  if kv_entry:
         
     | 
| 
      
 121 
     | 
    
         
            +
                    updated_kv_entires.append(kv_entry)
         
     | 
| 
      
 122 
     | 
    
         
            +
                updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
         
     | 
| 
       110 
123 
     | 
    
         | 
| 
       111 
124 
     | 
    
         
             
                x = self.final_norm(x)
         
     | 
| 
       112 
     | 
    
         
            -
                 
     | 
| 
       113 
     | 
    
         
            -
                return  
     | 
| 
      
 125 
     | 
    
         
            +
                logits = self.lm_head(x)  # (b, t, vocab_size)
         
     | 
| 
      
 126 
     | 
    
         
            +
                return {"logits": logits, "kv_cache": updated_kv_cache}
         
     | 
| 
       114 
127 
     | 
    
         | 
| 
       115 
128 
     | 
    
         | 
| 
       116 
129 
     | 
    
         
             
            def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
         
     | 
| 
         @@ -177,25 +190,28 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module: 
     | 
|
| 
       177 
190 
     | 
    
         
             
              return model
         
     | 
| 
       178 
191 
     | 
    
         | 
| 
       179 
192 
     | 
    
         | 
| 
       180 
     | 
    
         
            -
            def define_and_run_2b() -> None:
         
     | 
| 
      
 193 
     | 
    
         
            +
            def define_and_run_2b(checkpoint_path: str) -> None:
         
     | 
| 
       181 
194 
     | 
    
         
             
              """Instantiates and runs a Gemma 2B model."""
         
     | 
| 
       182 
195 
     | 
    
         | 
| 
       183 
     | 
    
         
            -
              current_dir = Path(__file__).parent.resolve()
         
     | 
| 
      
 196 
     | 
    
         
            +
              current_dir = pathlib.Path(__file__).parent.resolve()
         
     | 
| 
       184 
197 
     | 
    
         
             
              gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
         
     | 
| 
       185 
198 
     | 
    
         | 
| 
       186 
199 
     | 
    
         
             
              kv_cache_max_len = 1024
         
     | 
| 
       187 
     | 
    
         
            -
              checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
         
     | 
| 
       188 
200 
     | 
    
         
             
              model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
         
     | 
| 
       189 
201 
     | 
    
         
             
              idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
         
     | 
| 
       190 
202 
     | 
    
         
             
              tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
         
     | 
| 
       191 
203 
     | 
    
         
             
              tokens[0, :4] = idx
         
     | 
| 
       192 
204 
     | 
    
         
             
              input_pos = torch.arange(0, kv_cache_max_len)
         
     | 
| 
       193 
     | 
    
         
            -
               
     | 
| 
      
 205 
     | 
    
         
            +
              kv = kv_utils.KVCache.from_model_config(model.config)
         
     | 
| 
      
 206 
     | 
    
         
            +
              output = model.forward(tokens, input_pos, kv)
         
     | 
| 
       194 
207 
     | 
    
         
             
              print("comparing with goldens..")
         
     | 
| 
       195 
208 
     | 
    
         
             
              assert torch.allclose(
         
     | 
| 
       196 
     | 
    
         
            -
                  gemma_goldens,  
     | 
| 
      
 209 
     | 
    
         
            +
                  gemma_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
         
     | 
| 
       197 
210 
     | 
    
         
             
              )
         
     | 
| 
       198 
211 
     | 
    
         | 
| 
       199 
212 
     | 
    
         | 
| 
       200 
213 
     | 
    
         
             
            if __name__ == "__main__":
         
     | 
| 
       201 
     | 
    
         
            -
               
     | 
| 
      
 214 
     | 
    
         
            +
              input_checkpoint_path = os.path.join(
         
     | 
| 
      
 215 
     | 
    
         
            +
                  pathlib.Path.home(), "Downloads/llm_data/gemma-2b"
         
     | 
| 
      
 216 
     | 
    
         
            +
              )
         
     | 
| 
      
 217 
     | 
    
         
            +
              define_and_run_2b(input_checkpoint_path)
         
     | 
| 
         @@ -12,14 +12,16 @@ 
     | 
|
| 
       12 
12 
     | 
    
         
             
            # See the License for the specific language governing permissions and
         
     | 
| 
       13 
13 
     | 
    
         
             
            # limitations under the License.
         
     | 
| 
       14 
14 
     | 
    
         
             
            # ==============================================================================
         
     | 
| 
       15 
     | 
    
         
            -
             
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
            """Example of building a Gemma2 model."""
         
     | 
| 
       16 
17 
     | 
    
         | 
| 
       17 
18 
     | 
    
         
             
            import os
         
     | 
| 
       18 
     | 
    
         
            -
             
     | 
| 
      
 19 
     | 
    
         
            +
            import pathlib
         
     | 
| 
       19 
20 
     | 
    
         
             
            from typing import Optional, Tuple
         
     | 
| 
       20 
21 
     | 
    
         | 
| 
       21 
22 
     | 
    
         
             
            from ai_edge_torch.generative.layers import attention
         
     | 
| 
       22 
23 
     | 
    
         
             
            from ai_edge_torch.generative.layers import builder
         
     | 
| 
      
 24 
     | 
    
         
            +
            from ai_edge_torch.generative.layers import kv_cache as kv_utils
         
     | 
| 
       23 
25 
     | 
    
         
             
            import ai_edge_torch.generative.layers.attention_utils as attn_utils
         
     | 
| 
       24 
26 
     | 
    
         
             
            import ai_edge_torch.generative.layers.model_config as cfg
         
     | 
| 
       25 
27 
     | 
    
         
             
            import ai_edge_torch.generative.utilities.loader as loading_utils
         
     | 
| 
         @@ -51,7 +53,8 @@ class Gemma2Block(attention.TransformerBlock): 
     | 
|
| 
       51 
53 
     | 
    
         
             
                  rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         
     | 
| 
       52 
54 
     | 
    
         
             
                  mask: Optional[torch.Tensor] = None,
         
     | 
| 
       53 
55 
     | 
    
         
             
                  input_pos: Optional[torch.Tensor] = None,
         
     | 
| 
       54 
     | 
    
         
            -
             
     | 
| 
      
 56 
     | 
    
         
            +
                  kv_cache: kv_utils.KVCacheEntry = None,
         
     | 
| 
      
 57 
     | 
    
         
            +
              ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
         
     | 
| 
       55 
58 
     | 
    
         
             
                """Forward function of the Gemma2Block.
         
     | 
| 
       56 
59 
     | 
    
         | 
| 
       57 
60 
     | 
    
         
             
                Exactly the same as TransformerBlock but we call the post-attention norm
         
     | 
| 
         @@ -62,17 +65,19 @@ class Gemma2Block(attention.TransformerBlock): 
     | 
|
| 
       62 
65 
     | 
    
         
             
                  rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
         
     | 
| 
       63 
66 
     | 
    
         
             
                  mask (torch.Tensor): the optional mask tensor.
         
     | 
| 
       64 
67 
     | 
    
         
             
                  input_pos (torch.Tensor): the optional input position tensor.
         
     | 
| 
      
 68 
     | 
    
         
            +
                  kv_cache (KVCacheEntry): the optional kv cache entry.
         
     | 
| 
       65 
69 
     | 
    
         | 
| 
       66 
70 
     | 
    
         
             
                Returns:
         
     | 
| 
       67 
     | 
    
         
            -
                  output activation from this transformer block 
     | 
| 
      
 71 
     | 
    
         
            +
                  output activation from this transformer block, and updated kv cache (if
         
     | 
| 
      
 72 
     | 
    
         
            +
                  passed in).
         
     | 
| 
       68 
73 
     | 
    
         
             
                """
         
     | 
| 
       69 
74 
     | 
    
         | 
| 
       70 
75 
     | 
    
         
             
                x_norm = self.pre_atten_norm(x)
         
     | 
| 
       71 
     | 
    
         
            -
                attn_out = self.atten_func(x_norm, rope, mask, input_pos)
         
     | 
| 
      
 76 
     | 
    
         
            +
                attn_out, kv = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
         
     | 
| 
       72 
77 
     | 
    
         
             
                attn_out_norm = self.post_atten_norm(attn_out)
         
     | 
| 
       73 
78 
     | 
    
         
             
                x = x + attn_out_norm
         
     | 
| 
       74 
79 
     | 
    
         
             
                output = x + self.ff(x)
         
     | 
| 
       75 
     | 
    
         
            -
                return output
         
     | 
| 
      
 80 
     | 
    
         
            +
                return output, kv
         
     | 
| 
       76 
81 
     | 
    
         | 
| 
       77 
82 
     | 
    
         | 
| 
       78 
83 
     | 
    
         
             
            class Gemma2(nn.Module):
         
     | 
| 
         @@ -138,24 +143,38 @@ class Gemma2(nn.Module): 
     | 
|
| 
       138 
143 
     | 
    
         
             
                return self.mask_cache.index_select(2, input_pos)
         
     | 
| 
       139 
144 
     | 
    
         | 
| 
       140 
145 
     | 
    
         
             
              @torch.inference_mode
         
     | 
| 
       141 
     | 
    
         
            -
              def forward( 
     | 
| 
       142 
     | 
    
         
            -
             
     | 
| 
      
 146 
     | 
    
         
            +
              def forward(
         
     | 
| 
      
 147 
     | 
    
         
            +
                  self,
         
     | 
| 
      
 148 
     | 
    
         
            +
                  tokens: torch.Tensor,
         
     | 
| 
      
 149 
     | 
    
         
            +
                  input_pos: torch.Tensor,
         
     | 
| 
      
 150 
     | 
    
         
            +
                  kv_cache: kv_utils.KVCache,
         
     | 
| 
      
 151 
     | 
    
         
            +
              ) -> dict[torch.Tensor, kv_utils.KVCache]:
         
     | 
| 
      
 152 
     | 
    
         
            +
                _, seq_len = tokens.size()
         
     | 
| 
       143 
153 
     | 
    
         
             
                assert self.config.max_seq_len >= seq_len, (
         
     | 
| 
       144 
154 
     | 
    
         
             
                    f"Cannot forward sequence of length {seq_len}, max seq length is only"
         
     | 
| 
       145 
155 
     | 
    
         
             
                    f" {self.config.max_seq_len}"
         
     | 
| 
       146 
156 
     | 
    
         
             
                )
         
     | 
| 
      
 157 
     | 
    
         
            +
                assert len(self.transformer_blocks) == len(kv_cache.caches), (
         
     | 
| 
      
 158 
     | 
    
         
            +
                    "The number of transformer blocks and the number of KV cache entries"
         
     | 
| 
      
 159 
     | 
    
         
            +
                    " must be the same."
         
     | 
| 
      
 160 
     | 
    
         
            +
                )
         
     | 
| 
       147 
161 
     | 
    
         | 
| 
       148 
162 
     | 
    
         
             
                cos, sin = self.rope_cache
         
     | 
| 
       149 
163 
     | 
    
         
             
                cos = cos.index_select(0, input_pos)
         
     | 
| 
       150 
164 
     | 
    
         
             
                sin = sin.index_select(0, input_pos)
         
     | 
| 
       151 
165 
     | 
    
         | 
| 
       152 
166 
     | 
    
         
             
                # token embeddings of shape (b, t, n_embd)
         
     | 
| 
       153 
     | 
    
         
            -
                x = self.tok_embedding( 
     | 
| 
      
 167 
     | 
    
         
            +
                x = self.tok_embedding(tokens)
         
     | 
| 
       154 
168 
     | 
    
         
             
                x = x * (self.config.embedding_dim**0.5)
         
     | 
| 
       155 
169 
     | 
    
         | 
| 
      
 170 
     | 
    
         
            +
                updated_kv_entires = []
         
     | 
| 
       156 
171 
     | 
    
         
             
                for i, block in enumerate(self.transformer_blocks):
         
     | 
| 
       157 
172 
     | 
    
         
             
                  mask = self.get_attention_mask(i, input_pos)
         
     | 
| 
       158 
     | 
    
         
            -
                   
     | 
| 
      
 173 
     | 
    
         
            +
                  kv_entry = kv_cache.caches[i] if kv_cache else None
         
     | 
| 
      
 174 
     | 
    
         
            +
                  x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
         
     | 
| 
      
 175 
     | 
    
         
            +
                  if kv_entry:
         
     | 
| 
      
 176 
     | 
    
         
            +
                    updated_kv_entires.append(kv_entry)
         
     | 
| 
      
 177 
     | 
    
         
            +
                updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
         
     | 
| 
       159 
178 
     | 
    
         | 
| 
       160 
179 
     | 
    
         
             
                x = self.final_norm(x)
         
     | 
| 
       161 
180 
     | 
    
         
             
                res = self.lm_head(x)  # (b, t, vocab_size)
         
     | 
| 
         @@ -163,7 +182,8 @@ class Gemma2(nn.Module): 
     | 
|
| 
       163 
182 
     | 
    
         
             
                  res = res / self.config.final_logit_softcap
         
     | 
| 
       164 
183 
     | 
    
         
             
                  res = torch.tanh(res)
         
     | 
| 
       165 
184 
     | 
    
         
             
                  res = res * self.config.final_logit_softcap
         
     | 
| 
       166 
     | 
    
         
            -
             
     | 
| 
      
 185 
     | 
    
         
            +
             
     | 
| 
      
 186 
     | 
    
         
            +
                return {"logits": res, "kv_cache": updated_kv_cache}
         
     | 
| 
       167 
187 
     | 
    
         | 
| 
       168 
188 
     | 
    
         | 
| 
       169 
189 
     | 
    
         
             
            def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
         
     | 
| 
         @@ -243,14 +263,13 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module: 
     | 
|
| 
       243 
263 
     | 
    
         
             
              return model
         
     | 
| 
       244 
264 
     | 
    
         | 
| 
       245 
265 
     | 
    
         | 
| 
       246 
     | 
    
         
            -
            def define_and_run_2b() -> None:
         
     | 
| 
      
 266 
     | 
    
         
            +
            def define_and_run_2b(checkpoint_path: str) -> None:
         
     | 
| 
       247 
267 
     | 
    
         
             
              """Instantiates and runs a Gemma2 2B model."""
         
     | 
| 
       248 
268 
     | 
    
         | 
| 
       249 
     | 
    
         
            -
              current_dir = Path(__file__).parent.resolve()
         
     | 
| 
      
 269 
     | 
    
         
            +
              current_dir = pathlib.Path(__file__).parent.resolve()
         
     | 
| 
       250 
270 
     | 
    
         
             
              gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
         
     | 
| 
       251 
271 
     | 
    
         
             
              print("Running GEMMA 2")
         
     | 
| 
       252 
272 
     | 
    
         
             
              kv_cache_max_len = 1024
         
     | 
| 
       253 
     | 
    
         
            -
              checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma2-2b")
         
     | 
| 
       254 
273 
     | 
    
         
             
              model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
         
     | 
| 
       255 
274 
     | 
    
         
             
              toks = torch.from_numpy(
         
     | 
| 
       256 
275 
     | 
    
         
             
                  np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
         
     | 
| 
         @@ -258,11 +277,13 @@ def define_and_run_2b() -> None: 
     | 
|
| 
       258 
277 
     | 
    
         
             
              tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
         
     | 
| 
       259 
278 
     | 
    
         
             
              tokens[0, :9] = toks
         
     | 
| 
       260 
279 
     | 
    
         
             
              input_pos = torch.arange(0, kv_cache_max_len)
         
     | 
| 
       261 
     | 
    
         
            -
               
     | 
| 
       262 
     | 
    
         
            -
               
     | 
| 
      
 280 
     | 
    
         
            +
              kv = kv_utils.KVCache.from_model_config(model.config)
         
     | 
| 
      
 281 
     | 
    
         
            +
              out = model.forward(tokens, input_pos, kv)
         
     | 
| 
      
 282 
     | 
    
         
            +
              out_final = out["logits"][0, 8, :]
         
     | 
| 
       263 
283 
     | 
    
         
             
              assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)
         
     | 
| 
       264 
284 
     | 
    
         | 
| 
       265 
285 
     | 
    
         | 
| 
       266 
286 
     | 
    
         
             
            if __name__ == "__main__":
         
     | 
| 
       267 
287 
     | 
    
         
             
              torch.set_printoptions(sci_mode=True)
         
     | 
| 
       268 
     | 
    
         
            -
               
     | 
| 
      
 288 
     | 
    
         
            +
              path = os.path.join(pathlib.Path.home(), "Downloads/llm_data/gemma2-2b")
         
     | 
| 
      
 289 
     | 
    
         
            +
              define_and_run_2b(path)
         
     | 
| 
         @@ -12,16 +12,15 @@ 
     | 
|
| 
       12 
12 
     | 
    
         
             
            # See the License for the specific language governing permissions and
         
     | 
| 
       13 
13 
     | 
    
         
             
            # limitations under the License.
         
     | 
| 
       14 
14 
     | 
    
         
             
            # ==============================================================================
         
     | 
| 
       15 
     | 
    
         
            -
             
     | 
| 
       16 
     | 
    
         
            -
             
     | 
| 
       17 
     | 
    
         
            -
            # Please use with caution.
         
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
            """Example of converting a Phi-2 model to multi-signature tflite model."""
         
     | 
| 
       18 
17 
     | 
    
         | 
| 
       19 
18 
     | 
    
         
             
            import os
         
     | 
| 
       20 
     | 
    
         
            -
             
     | 
| 
      
 19 
     | 
    
         
            +
            import pathlib
         
     | 
| 
       21 
20 
     | 
    
         | 
| 
       22 
21 
     | 
    
         
             
            import ai_edge_torch
         
     | 
| 
       23 
     | 
    
         
            -
            from ai_edge_torch.generative.examples. 
     | 
| 
       24 
     | 
    
         
            -
            from ai_edge_torch.generative.layers 
     | 
| 
      
 22 
     | 
    
         
            +
            from ai_edge_torch.generative.examples.phi import phi2
         
     | 
| 
      
 23 
     | 
    
         
            +
            from ai_edge_torch.generative.layers import kv_cache
         
     | 
| 
       25 
24 
     | 
    
         
             
            from ai_edge_torch.generative.quantize import quant_recipes
         
     | 
| 
       26 
25 
     | 
    
         
             
            import torch
         
     | 
| 
       27 
26 
     | 
    
         | 
| 
         @@ -32,9 +31,8 @@ def convert_phi2_to_tflite( 
     | 
|
| 
       32 
31 
     | 
    
         
             
                kv_cache_max_len: int = 1024,
         
     | 
| 
       33 
32 
     | 
    
         
             
                quantize: bool = True,
         
     | 
| 
       34 
33 
     | 
    
         
             
            ):
         
     | 
| 
       35 
     | 
    
         
            -
              """ 
     | 
| 
      
 34 
     | 
    
         
            +
              """Converts a Phi-2 model to multi-signature tflite model.
         
     | 
| 
       36 
35 
     | 
    
         | 
| 
       37 
     | 
    
         
            -
              tflite model.
         
     | 
| 
       38 
36 
     | 
    
         
             
              Args:
         
     | 
| 
       39 
37 
     | 
    
         
             
                  checkpoint_path (str): The filepath to the model checkpoint, or directory
         
     | 
| 
       40 
38 
     | 
    
         
             
                    holding the checkpoint.
         
     | 
| 
         @@ -53,7 +51,7 @@ def convert_phi2_to_tflite( 
     | 
|
| 
       53 
51 
     | 
    
         
             
              prefill_input_pos = torch.arange(0, prefill_seq_len)
         
     | 
| 
       54 
52 
     | 
    
         
             
              decode_token = torch.tensor([[0]], dtype=torch.long)
         
     | 
| 
       55 
53 
     | 
    
         
             
              decode_input_pos = torch.tensor([0], dtype=torch.int64)
         
     | 
| 
       56 
     | 
    
         
            -
              kv =  
     | 
| 
      
 54 
     | 
    
         
            +
              kv = kv_cache.KVCache.from_model_config(pytorch_model.config)
         
     | 
| 
       57 
55 
     | 
    
         | 
| 
       58 
56 
     | 
    
         
             
              quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
         
     | 
| 
       59 
57 
     | 
    
         
             
              edge_model = (
         
     | 
| 
         @@ -77,11 +75,12 @@ def convert_phi2_to_tflite( 
     | 
|
| 
       77 
75 
     | 
    
         
             
                  )
         
     | 
| 
       78 
76 
     | 
    
         
             
                  .convert(quant_config=quant_config)
         
     | 
| 
       79 
77 
     | 
    
         
             
              )
         
     | 
| 
      
 78 
     | 
    
         
            +
              quant_suffix = 'q8' if quantize else 'f32'
         
     | 
| 
       80 
79 
     | 
    
         
             
              edge_model.export(
         
     | 
| 
       81 
     | 
    
         
            -
                  f'/tmp/ 
     | 
| 
      
 80 
     | 
    
         
            +
                  f'/tmp/phi2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
         
     | 
| 
       82 
81 
     | 
    
         
             
              )
         
     | 
| 
       83 
82 
     | 
    
         | 
| 
       84 
83 
     | 
    
         | 
| 
       85 
84 
     | 
    
         
             
            if __name__ == '__main__':
         
     | 
| 
       86 
     | 
    
         
            -
               
     | 
| 
       87 
     | 
    
         
            -
              convert_phi2_to_tflite( 
     | 
| 
      
 85 
     | 
    
         
            +
              path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2')
         
     | 
| 
      
 86 
     | 
    
         
            +
              convert_phi2_to_tflite(path)
         
     | 
| 
         @@ -12,26 +12,22 @@ 
     | 
|
| 
       12 
12 
     | 
    
         
             
            # See the License for the specific language governing permissions and
         
     | 
| 
       13 
13 
     | 
    
         
             
            # limitations under the License.
         
     | 
| 
       14 
14 
     | 
    
         
             
            # ==============================================================================
         
     | 
| 
       15 
     | 
    
         
            -
             
     | 
| 
       16 
     | 
    
         
            -
             
     | 
| 
       17 
     | 
    
         
            -
            # Note: This is an experimental version of phi2 with external KV cache.
         
     | 
| 
       18 
     | 
    
         
            -
            # Please use with caution.
         
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
            """Example of building a Phi-2 model."""
         
     | 
| 
       19 
17 
     | 
    
         | 
| 
       20 
18 
     | 
    
         
             
            import os
         
     | 
| 
       21 
     | 
    
         
            -
             
     | 
| 
       22 
     | 
    
         
            -
            from typing import Tuple
         
     | 
| 
      
 19 
     | 
    
         
            +
            import pathlib
         
     | 
| 
       23 
20 
     | 
    
         | 
| 
      
 21 
     | 
    
         
            +
            from ai_edge_torch.generative.layers import attention
         
     | 
| 
       24 
22 
     | 
    
         
             
            from ai_edge_torch.generative.layers import builder
         
     | 
| 
      
 23 
     | 
    
         
            +
            from ai_edge_torch.generative.layers import kv_cache as kv_utils
         
     | 
| 
       25 
24 
     | 
    
         
             
            import ai_edge_torch.generative.layers.attention_utils as attn_utils
         
     | 
| 
       26 
     | 
    
         
            -
            from ai_edge_torch.generative.layers.experimental import attention
         
     | 
| 
       27 
     | 
    
         
            -
            from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
         
     | 
| 
       28 
25 
     | 
    
         
             
            import ai_edge_torch.generative.layers.model_config as cfg
         
     | 
| 
       29 
26 
     | 
    
         
             
            import ai_edge_torch.generative.utilities.loader as loading_utils
         
     | 
| 
       30 
27 
     | 
    
         
             
            import numpy as np
         
     | 
| 
       31 
28 
     | 
    
         
             
            import torch
         
     | 
| 
       32 
29 
     | 
    
         
             
            from torch import nn
         
     | 
| 
       33 
30 
     | 
    
         | 
| 
       34 
     | 
    
         
            -
             
     | 
| 
       35 
31 
     | 
    
         
             
            TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
         
     | 
| 
       36 
32 
     | 
    
         
             
                ff_up_proj="model.layers.{}.mlp.fc1",
         
     | 
| 
       37 
33 
     | 
    
         
             
                ff_down_proj="model.layers.{}.mlp.fc2",
         
     | 
| 
         @@ -89,13 +85,17 @@ class Phi2(nn.Module): 
     | 
|
| 
       89 
85 
     | 
    
         
             
                  self,
         
     | 
| 
       90 
86 
     | 
    
         
             
                  tokens: torch.Tensor,
         
     | 
| 
       91 
87 
     | 
    
         
             
                  input_pos: torch.Tensor,
         
     | 
| 
       92 
     | 
    
         
            -
                  kv_cache: kv_utils. 
     | 
| 
       93 
     | 
    
         
            -
              ) ->  
     | 
| 
      
 88 
     | 
    
         
            +
                  kv_cache: kv_utils.KVCache,
         
     | 
| 
      
 89 
     | 
    
         
            +
              ) -> dict[torch.Tensor, kv_utils.KVCache]:
         
     | 
| 
       94 
90 
     | 
    
         
             
                _, seq_len = tokens.size()
         
     | 
| 
       95 
91 
     | 
    
         
             
                assert self.config.max_seq_len >= seq_len, (
         
     | 
| 
       96 
92 
     | 
    
         
             
                    f"Cannot forward sequence of length {seq_len}, max seq length is only"
         
     | 
| 
       97 
93 
     | 
    
         
             
                    f" {self.config.max_seq_len}"
         
     | 
| 
       98 
94 
     | 
    
         
             
                )
         
     | 
| 
      
 95 
     | 
    
         
            +
                assert len(self.transformer_blocks) == len(kv_cache.caches), (
         
     | 
| 
      
 96 
     | 
    
         
            +
                    "The number of transformer blocks and the number of KV cache entries"
         
     | 
| 
      
 97 
     | 
    
         
            +
                    " must be the same."
         
     | 
| 
      
 98 
     | 
    
         
            +
                )
         
     | 
| 
       99 
99 
     | 
    
         | 
| 
       100 
100 
     | 
    
         
             
                cos, sin = self.rope_cache
         
     | 
| 
       101 
101 
     | 
    
         
             
                cos = cos.index_select(0, input_pos)
         
     | 
| 
         @@ -111,11 +111,11 @@ class Phi2(nn.Module): 
     | 
|
| 
       111 
111 
     | 
    
         
             
                  x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
         
     | 
| 
       112 
112 
     | 
    
         
             
                  if kv_entry:
         
     | 
| 
       113 
113 
     | 
    
         
             
                    updated_kv_entires.append(kv_entry)
         
     | 
| 
       114 
     | 
    
         
            -
                updated_kv_cache = kv_utils. 
     | 
| 
      
 114 
     | 
    
         
            +
                updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
         
     | 
| 
       115 
115 
     | 
    
         | 
| 
       116 
116 
     | 
    
         
             
                x = self.final_norm(x)
         
     | 
| 
       117 
     | 
    
         
            -
                 
     | 
| 
       118 
     | 
    
         
            -
                return  
     | 
| 
      
 117 
     | 
    
         
            +
                logits = self.lm_head(x)  # (b, t, vocab_size)
         
     | 
| 
      
 118 
     | 
    
         
            +
                return {"logits": logits, "kv_cache": updated_kv_cache}
         
     | 
| 
       119 
119 
     | 
    
         | 
| 
       120 
120 
     | 
    
         | 
| 
       121 
121 
     | 
    
         
             
            def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
         
     | 
| 
         @@ -169,39 +169,37 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig: 
     | 
|
| 
       169 
169 
     | 
    
         
             
              return config
         
     | 
| 
       170 
170 
     | 
    
         | 
| 
       171 
171 
     | 
    
         | 
| 
       172 
     | 
    
         
            -
            def build_model(
         
     | 
| 
       173 
     | 
    
         
            -
                checkpoint_path: str, test_model: bool = False, **kwargs
         
     | 
| 
       174 
     | 
    
         
            -
            ) -> nn.Module:
         
     | 
| 
      
 172 
     | 
    
         
            +
            def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
         
     | 
| 
       175 
173 
     | 
    
         
             
              """Instantiates the model instance and load checkpoint if provided."""
         
     | 
| 
       176 
     | 
    
         
            -
              config = (
         
     | 
| 
       177 
     | 
    
         
            -
                  get_fake_model_config(**kwargs)
         
     | 
| 
       178 
     | 
    
         
            -
                  if test_model
         
     | 
| 
       179 
     | 
    
         
            -
                  else get_model_config(**kwargs)
         
     | 
| 
       180 
     | 
    
         
            -
              )
         
     | 
| 
      
 174 
     | 
    
         
            +
              config = get_model_config(**kwargs)
         
     | 
| 
       181 
175 
     | 
    
         
             
              model = Phi2(config)
         
     | 
| 
       182 
     | 
    
         
            -
               
     | 
| 
       183 
     | 
    
         
            -
             
     | 
| 
       184 
     | 
    
         
            -
                loader.load(model)
         
     | 
| 
      
 176 
     | 
    
         
            +
              loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
         
     | 
| 
      
 177 
     | 
    
         
            +
              loader.load(model)
         
     | 
| 
       185 
178 
     | 
    
         
             
              model.eval()
         
     | 
| 
       186 
179 
     | 
    
         
             
              return model
         
     | 
| 
       187 
180 
     | 
    
         | 
| 
       188 
181 
     | 
    
         | 
| 
       189 
     | 
    
         
            -
            def define_and_run(checkpoint_path: str 
     | 
| 
      
 182 
     | 
    
         
            +
            def define_and_run(checkpoint_path: str) -> None:
         
     | 
| 
       190 
183 
     | 
    
         
             
              """Instantiates and runs a Phi-2 model."""
         
     | 
| 
       191 
184 
     | 
    
         | 
| 
      
 185 
     | 
    
         
            +
              current_dir = pathlib.Path(__file__).parent.resolve()
         
     | 
| 
      
 186 
     | 
    
         
            +
              phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
         
     | 
| 
       192 
187 
     | 
    
         
             
              kv_cache_max_len = 1024
         
     | 
| 
       193 
     | 
    
         
            -
              model = build_model(
         
     | 
| 
       194 
     | 
    
         
            -
                  checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
         
     | 
| 
       195 
     | 
    
         
            -
              )
         
     | 
| 
      
 188 
     | 
    
         
            +
              model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
         
     | 
| 
       196 
189 
     | 
    
         
             
              idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
         
     | 
| 
       197 
190 
     | 
    
         
             
              tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
         
     | 
| 
       198 
191 
     | 
    
         
             
              tokens[0, :4] = idx
         
     | 
| 
       199 
192 
     | 
    
         
             
              input_pos = torch.arange(0, kv_cache_max_len)
         
     | 
| 
       200 
     | 
    
         
            -
              kv = kv_utils. 
     | 
| 
       201 
     | 
    
         
            -
               
     | 
| 
       202 
     | 
    
         
            -
              print( 
     | 
| 
      
 193 
     | 
    
         
            +
              kv = kv_utils.KVCache.from_model_config(model.config)
         
     | 
| 
      
 194 
     | 
    
         
            +
              output = model.forward(tokens, input_pos, kv)
         
     | 
| 
      
 195 
     | 
    
         
            +
              print("comparing with goldens..")
         
     | 
| 
      
 196 
     | 
    
         
            +
              assert torch.allclose(
         
     | 
| 
      
 197 
     | 
    
         
            +
                  phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
         
     | 
| 
      
 198 
     | 
    
         
            +
              )
         
     | 
| 
       203 
199 
     | 
    
         | 
| 
       204 
200 
     | 
    
         | 
| 
       205 
201 
     | 
    
         
             
            if __name__ == "__main__":
         
     | 
| 
       206 
     | 
    
         
            -
              input_checkpoint_path = os.path.join( 
     | 
| 
      
 202 
     | 
    
         
            +
              input_checkpoint_path = os.path.join(
         
     | 
| 
      
 203 
     | 
    
         
            +
                  pathlib.Path.home(), "Downloads/llm_data/phi2"
         
     | 
| 
      
 204 
     | 
    
         
            +
              )
         
     | 
| 
       207 
205 
     | 
    
         
             
              define_and_run(input_checkpoint_path)
         
     |