ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240911__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
 - ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
 - ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
 - ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
 - ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
 - ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
 - ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
 - ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
 - ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +38 -22
 - ai_edge_torch/generative/layers/attention.py +60 -63
 - ai_edge_torch/generative/layers/kv_cache.py +160 -51
 - ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
 - ai_edge_torch/generative/test/test_model_conversion.py +71 -33
 - ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
 - ai_edge_torch/generative/test/utils.py +54 -0
 - ai_edge_torch/version.py +1 -1
 - {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/METADATA +1 -1
 - {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/RECORD +22 -32
 - ai_edge_torch/generative/examples/experimental/gemma/__init__.py +0 -14
 - ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +0 -88
 - ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
 - ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
 - ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
 - ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
 - ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
 - ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
 - ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
 - ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
 - ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
 - /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
 - {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/LICENSE +0 -0
 - {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/WHEEL +0 -0
 - {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/top_level.txt +0 -0
 
| 
         @@ -1,205 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            # Copyright 2024 The AI Edge Torch Authors.
         
     | 
| 
       2 
     | 
    
         
            -
            #
         
     | 
| 
       3 
     | 
    
         
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 
       4 
     | 
    
         
            -
            # you may not use this file except in compliance with the License.
         
     | 
| 
       5 
     | 
    
         
            -
            # You may obtain a copy of the License at
         
     | 
| 
       6 
     | 
    
         
            -
            #
         
     | 
| 
       7 
     | 
    
         
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 
       8 
     | 
    
         
            -
            #
         
     | 
| 
       9 
     | 
    
         
            -
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 
       10 
     | 
    
         
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 
       11 
     | 
    
         
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 
       12 
     | 
    
         
            -
            # See the License for the specific language governing permissions and
         
     | 
| 
       13 
     | 
    
         
            -
            # limitations under the License.
         
     | 
| 
       14 
     | 
    
         
            -
            # ==============================================================================
         
     | 
| 
       15 
     | 
    
         
            -
            # Example of building a TinyLlama model from the Edge Generative API layers.
         
     | 
| 
       16 
     | 
    
         
            -
            #
         
     | 
| 
       17 
     | 
    
         
            -
            # Note: This is an experimental version of TinyLlama with external KV cache.
         
     | 
| 
       18 
     | 
    
         
            -
            # Please use with caution.
         
     | 
| 
       19 
     | 
    
         
            -
             
     | 
| 
       20 
     | 
    
         
            -
            import os
         
     | 
| 
       21 
     | 
    
         
            -
            from pathlib import Path
         
     | 
| 
       22 
     | 
    
         
            -
            from typing import Tuple
         
     | 
| 
       23 
     | 
    
         
            -
             
     | 
| 
       24 
     | 
    
         
            -
            from ai_edge_torch.generative.layers import builder
         
     | 
| 
       25 
     | 
    
         
            -
            import ai_edge_torch.generative.layers.attention_utils as attn_utils
         
     | 
| 
       26 
     | 
    
         
            -
            from ai_edge_torch.generative.layers.experimental import attention
         
     | 
| 
       27 
     | 
    
         
            -
            from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
         
     | 
| 
       28 
     | 
    
         
            -
            import ai_edge_torch.generative.layers.model_config as cfg
         
     | 
| 
       29 
     | 
    
         
            -
            import ai_edge_torch.generative.utilities.loader as loading_utils
         
     | 
| 
       30 
     | 
    
         
            -
            import numpy as np
         
     | 
| 
       31 
     | 
    
         
            -
            import torch
         
     | 
| 
       32 
     | 
    
         
            -
            from torch import nn
         
     | 
| 
       33 
     | 
    
         
            -
             
     | 
| 
       34 
     | 
    
         
            -
             
     | 
| 
       35 
     | 
    
         
            -
            TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
         
     | 
| 
       36 
     | 
    
         
            -
                ff_up_proj="model.layers.{}.mlp.up_proj",
         
     | 
| 
       37 
     | 
    
         
            -
                ff_down_proj="model.layers.{}.mlp.down_proj",
         
     | 
| 
       38 
     | 
    
         
            -
                ff_gate_proj="model.layers.{}.mlp.gate_proj",
         
     | 
| 
       39 
     | 
    
         
            -
                attn_query_proj="model.layers.{}.self_attn.q_proj",
         
     | 
| 
       40 
     | 
    
         
            -
                attn_key_proj="model.layers.{}.self_attn.k_proj",
         
     | 
| 
       41 
     | 
    
         
            -
                attn_value_proj="model.layers.{}.self_attn.v_proj",
         
     | 
| 
       42 
     | 
    
         
            -
                attn_output_proj="model.layers.{}.self_attn.o_proj",
         
     | 
| 
       43 
     | 
    
         
            -
                pre_attn_norm="model.layers.{}.input_layernorm",
         
     | 
| 
       44 
     | 
    
         
            -
                post_attn_norm="model.layers.{}.post_attention_layernorm",
         
     | 
| 
       45 
     | 
    
         
            -
                embedding="model.embed_tokens",
         
     | 
| 
       46 
     | 
    
         
            -
                final_norm="model.norm",
         
     | 
| 
       47 
     | 
    
         
            -
                lm_head="lm_head",
         
     | 
| 
       48 
     | 
    
         
            -
            )
         
     | 
| 
       49 
     | 
    
         
            -
             
     | 
| 
       50 
     | 
    
         
            -
             
     | 
| 
       51 
     | 
    
         
            -
            class TinyLLamma(nn.Module):
         
     | 
| 
       52 
     | 
    
         
            -
              """A TinyLlama model built from the Edge Generative API layers."""
         
     | 
| 
       53 
     | 
    
         
            -
             
     | 
| 
       54 
     | 
    
         
            -
              def __init__(self, config: cfg.ModelConfig):
         
     | 
| 
       55 
     | 
    
         
            -
                super().__init__()
         
     | 
| 
       56 
     | 
    
         
            -
             
     | 
| 
       57 
     | 
    
         
            -
                self.config = config
         
     | 
| 
       58 
     | 
    
         
            -
                # Construct model layers.
         
     | 
| 
       59 
     | 
    
         
            -
                self.lm_head = nn.Linear(
         
     | 
| 
       60 
     | 
    
         
            -
                    config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
         
     | 
| 
       61 
     | 
    
         
            -
                )
         
     | 
| 
       62 
     | 
    
         
            -
                self.tok_embedding = nn.Embedding(
         
     | 
| 
       63 
     | 
    
         
            -
                    config.vocab_size, config.embedding_dim, padding_idx=0
         
     | 
| 
       64 
     | 
    
         
            -
                )
         
     | 
| 
       65 
     | 
    
         
            -
                self.transformer_blocks = nn.ModuleList(
         
     | 
| 
       66 
     | 
    
         
            -
                    attention.TransformerBlock(config) for _ in range(config.num_layers)
         
     | 
| 
       67 
     | 
    
         
            -
                )
         
     | 
| 
       68 
     | 
    
         
            -
                self.final_norm = builder.build_norm(
         
     | 
| 
       69 
     | 
    
         
            -
                    config.embedding_dim,
         
     | 
| 
       70 
     | 
    
         
            -
                    config.final_norm_config,
         
     | 
| 
       71 
     | 
    
         
            -
                )
         
     | 
| 
       72 
     | 
    
         
            -
                self.rope_cache = attn_utils.build_rope_cache(
         
     | 
| 
       73 
     | 
    
         
            -
                    size=config.kv_cache_max,
         
     | 
| 
       74 
     | 
    
         
            -
                    dim=int(
         
     | 
| 
       75 
     | 
    
         
            -
                        config.attn_config.rotary_percentage * config.attn_config.head_dim
         
     | 
| 
       76 
     | 
    
         
            -
                    ),
         
     | 
| 
       77 
     | 
    
         
            -
                    base=10_000,
         
     | 
| 
       78 
     | 
    
         
            -
                    condense_ratio=1,
         
     | 
| 
       79 
     | 
    
         
            -
                    dtype=torch.float32,
         
     | 
| 
       80 
     | 
    
         
            -
                    device=torch.device("cpu"),
         
     | 
| 
       81 
     | 
    
         
            -
                )
         
     | 
| 
       82 
     | 
    
         
            -
                self.mask_cache = attn_utils.build_causal_mask_cache(
         
     | 
| 
       83 
     | 
    
         
            -
                    size=config.kv_cache_max,
         
     | 
| 
       84 
     | 
    
         
            -
                    dtype=torch.float32,
         
     | 
| 
       85 
     | 
    
         
            -
                    device=torch.device("cpu"),
         
     | 
| 
       86 
     | 
    
         
            -
                )
         
     | 
| 
       87 
     | 
    
         
            -
                self.config = config
         
     | 
| 
       88 
     | 
    
         
            -
             
     | 
| 
       89 
     | 
    
         
            -
              @torch.inference_mode
         
     | 
| 
       90 
     | 
    
         
            -
              def forward(
         
     | 
| 
       91 
     | 
    
         
            -
                  self,
         
     | 
| 
       92 
     | 
    
         
            -
                  tokens: torch.Tensor,
         
     | 
| 
       93 
     | 
    
         
            -
                  input_pos: torch.Tensor,
         
     | 
| 
       94 
     | 
    
         
            -
                  kv_cache: kv_utils.EKVCache,
         
     | 
| 
       95 
     | 
    
         
            -
              ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
         
     | 
| 
       96 
     | 
    
         
            -
                _, seq_len = tokens.size()
         
     | 
| 
       97 
     | 
    
         
            -
                assert self.config.max_seq_len >= seq_len, (
         
     | 
| 
       98 
     | 
    
         
            -
                    f"Cannot forward sequence of length {seq_len}, max seq length is only"
         
     | 
| 
       99 
     | 
    
         
            -
                    f" {self.config.max_seq_len}"
         
     | 
| 
       100 
     | 
    
         
            -
                )
         
     | 
| 
       101 
     | 
    
         
            -
             
     | 
| 
       102 
     | 
    
         
            -
                cos, sin = self.rope_cache
         
     | 
| 
       103 
     | 
    
         
            -
                cos = cos.index_select(0, input_pos)
         
     | 
| 
       104 
     | 
    
         
            -
                sin = sin.index_select(0, input_pos)
         
     | 
| 
       105 
     | 
    
         
            -
                mask = self.mask_cache.index_select(2, input_pos)
         
     | 
| 
       106 
     | 
    
         
            -
                mask = mask[:, :, :, : self.config.kv_cache_max]
         
     | 
| 
       107 
     | 
    
         
            -
             
     | 
| 
       108 
     | 
    
         
            -
                # token embeddings of shape (b, t, n_embd)
         
     | 
| 
       109 
     | 
    
         
            -
                x = self.tok_embedding(tokens)
         
     | 
| 
       110 
     | 
    
         
            -
             
     | 
| 
       111 
     | 
    
         
            -
                updated_kv_entires = []
         
     | 
| 
       112 
     | 
    
         
            -
                for i, block in enumerate(self.transformer_blocks):
         
     | 
| 
       113 
     | 
    
         
            -
                  kv_entry = kv_cache.caches[i] if kv_cache else None
         
     | 
| 
       114 
     | 
    
         
            -
                  x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
         
     | 
| 
       115 
     | 
    
         
            -
                  if kv_entry:
         
     | 
| 
       116 
     | 
    
         
            -
                    updated_kv_entires.append(kv_entry)
         
     | 
| 
       117 
     | 
    
         
            -
                updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
         
     | 
| 
       118 
     | 
    
         
            -
             
     | 
| 
       119 
     | 
    
         
            -
                x = self.final_norm(x)
         
     | 
| 
       120 
     | 
    
         
            -
                res = self.lm_head(x)  # (b, t, vocab_size)
         
     | 
| 
       121 
     | 
    
         
            -
                return res, updated_kv_cache
         
     | 
| 
       122 
     | 
    
         
            -
             
     | 
| 
       123 
     | 
    
         
            -
             
     | 
| 
       124 
     | 
    
         
            -
            def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
         
     | 
| 
       125 
     | 
    
         
            -
              """Returns the model config for a TinyLlama model.
         
     | 
| 
       126 
     | 
    
         
            -
             
     | 
| 
       127 
     | 
    
         
            -
              Args:
         
     | 
| 
       128 
     | 
    
         
            -
                kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
         
     | 
| 
       129 
     | 
    
         
            -
                  is 1024.
         
     | 
| 
       130 
     | 
    
         
            -
             
     | 
| 
       131 
     | 
    
         
            -
              Returns:
         
     | 
| 
       132 
     | 
    
         
            -
                The model config for a TinyLlama model.
         
     | 
| 
       133 
     | 
    
         
            -
              """
         
     | 
| 
       134 
     | 
    
         
            -
              attn_config = cfg.AttentionConfig(
         
     | 
| 
       135 
     | 
    
         
            -
                  num_heads=32,
         
     | 
| 
       136 
     | 
    
         
            -
                  head_dim=64,
         
     | 
| 
       137 
     | 
    
         
            -
                  num_query_groups=4,
         
     | 
| 
       138 
     | 
    
         
            -
                  rotary_percentage=1.0,
         
     | 
| 
       139 
     | 
    
         
            -
              )
         
     | 
| 
       140 
     | 
    
         
            -
              ff_config = cfg.FeedForwardConfig(
         
     | 
| 
       141 
     | 
    
         
            -
                  type=cfg.FeedForwardType.GATED,
         
     | 
| 
       142 
     | 
    
         
            -
                  activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
         
     | 
| 
       143 
     | 
    
         
            -
                  intermediate_size=5632,
         
     | 
| 
       144 
     | 
    
         
            -
              )
         
     | 
| 
       145 
     | 
    
         
            -
              norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
         
     | 
| 
       146 
     | 
    
         
            -
              config = cfg.ModelConfig(
         
     | 
| 
       147 
     | 
    
         
            -
                  vocab_size=32000,
         
     | 
| 
       148 
     | 
    
         
            -
                  num_layers=22,
         
     | 
| 
       149 
     | 
    
         
            -
                  max_seq_len=2048,
         
     | 
| 
       150 
     | 
    
         
            -
                  embedding_dim=2048,
         
     | 
| 
       151 
     | 
    
         
            -
                  kv_cache_max_len=kv_cache_max_len,
         
     | 
| 
       152 
     | 
    
         
            -
                  attn_config=attn_config,
         
     | 
| 
       153 
     | 
    
         
            -
                  ff_config=ff_config,
         
     | 
| 
       154 
     | 
    
         
            -
                  pre_attention_norm_config=norm_config,
         
     | 
| 
       155 
     | 
    
         
            -
                  post_attention_norm_config=norm_config,
         
     | 
| 
       156 
     | 
    
         
            -
                  final_norm_config=norm_config,
         
     | 
| 
       157 
     | 
    
         
            -
                  enable_hlfb=True,
         
     | 
| 
       158 
     | 
    
         
            -
              )
         
     | 
| 
       159 
     | 
    
         
            -
              return config
         
     | 
| 
       160 
     | 
    
         
            -
             
     | 
| 
       161 
     | 
    
         
            -
             
     | 
| 
       162 
     | 
    
         
            -
            def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
         
     | 
| 
       163 
     | 
    
         
            -
              config = get_model_config(**kwargs)
         
     | 
| 
       164 
     | 
    
         
            -
              config.vocab_size = 128
         
     | 
| 
       165 
     | 
    
         
            -
              config.num_layers = 2
         
     | 
| 
       166 
     | 
    
         
            -
              config.ff_config.intermediate_size = 256
         
     | 
| 
       167 
     | 
    
         
            -
              return config
         
     | 
| 
       168 
     | 
    
         
            -
             
     | 
| 
       169 
     | 
    
         
            -
             
     | 
| 
       170 
     | 
    
         
            -
            def build_model(
         
     | 
| 
       171 
     | 
    
         
            -
                checkpoint_path: str, test_model: bool = False, **kwargs
         
     | 
| 
       172 
     | 
    
         
            -
            ) -> nn.Module:
         
     | 
| 
       173 
     | 
    
         
            -
              """Instantiates the model instance and load checkpoint if provided."""
         
     | 
| 
       174 
     | 
    
         
            -
              config = (
         
     | 
| 
       175 
     | 
    
         
            -
                  get_fake_model_config(**kwargs)
         
     | 
| 
       176 
     | 
    
         
            -
                  if test_model
         
     | 
| 
       177 
     | 
    
         
            -
                  else get_model_config(**kwargs)
         
     | 
| 
       178 
     | 
    
         
            -
              )
         
     | 
| 
       179 
     | 
    
         
            -
              model = TinyLLamma(config)
         
     | 
| 
       180 
     | 
    
         
            -
              if checkpoint_path is not None:
         
     | 
| 
       181 
     | 
    
         
            -
                loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
         
     | 
| 
       182 
     | 
    
         
            -
                loader.load(model)
         
     | 
| 
       183 
     | 
    
         
            -
              model.eval()
         
     | 
| 
       184 
     | 
    
         
            -
              return model
         
     | 
| 
       185 
     | 
    
         
            -
             
     | 
| 
       186 
     | 
    
         
            -
             
     | 
| 
       187 
     | 
    
         
            -
            def define_and_run(checkpoint_path: str, test_model: bool = False) -> None:
         
     | 
| 
       188 
     | 
    
         
            -
              """Instantiates and runs a TinyLlama model."""
         
     | 
| 
       189 
     | 
    
         
            -
             
     | 
| 
       190 
     | 
    
         
            -
              kv_cache_max_len = 1024
         
     | 
| 
       191 
     | 
    
         
            -
              model = build_model(
         
     | 
| 
       192 
     | 
    
         
            -
                  checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
         
     | 
| 
       193 
     | 
    
         
            -
              )
         
     | 
| 
       194 
     | 
    
         
            -
              idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
         
     | 
| 
       195 
     | 
    
         
            -
              tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
         
     | 
| 
       196 
     | 
    
         
            -
              tokens[0, :4] = idx
         
     | 
| 
       197 
     | 
    
         
            -
              input_pos = torch.arange(0, kv_cache_max_len)
         
     | 
| 
       198 
     | 
    
         
            -
              kv = kv_utils.EKVCache.from_model_config(model.config)
         
     | 
| 
       199 
     | 
    
         
            -
              print("running an inference")
         
     | 
| 
       200 
     | 
    
         
            -
              print(model.forward(tokens, input_pos, kv))
         
     | 
| 
       201 
     | 
    
         
            -
             
     | 
| 
       202 
     | 
    
         
            -
             
     | 
| 
       203 
     | 
    
         
            -
            if __name__ == "__main__":
         
     | 
| 
       204 
     | 
    
         
            -
              input_checkpoint_path = os.path.join(Path.home(), "Downloads/tiny_llama")
         
     | 
| 
       205 
     | 
    
         
            -
              define_and_run(input_checkpoint_path)
         
     | 
| 
         @@ -1,14 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            # Copyright 2024 The AI Edge Torch Authors.
         
     | 
| 
       2 
     | 
    
         
            -
            #
         
     | 
| 
       3 
     | 
    
         
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 
       4 
     | 
    
         
            -
            # you may not use this file except in compliance with the License.
         
     | 
| 
       5 
     | 
    
         
            -
            # You may obtain a copy of the License at
         
     | 
| 
       6 
     | 
    
         
            -
            #
         
     | 
| 
       7 
     | 
    
         
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 
       8 
     | 
    
         
            -
            #
         
     | 
| 
       9 
     | 
    
         
            -
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 
       10 
     | 
    
         
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 
       11 
     | 
    
         
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 
       12 
     | 
    
         
            -
            # See the License for the specific language governing permissions and
         
     | 
| 
       13 
     | 
    
         
            -
            # limitations under the License.
         
     | 
| 
       14 
     | 
    
         
            -
            # ==============================================================================
         
     | 
| 
         @@ -1,67 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            # Copyright 2024 The AI Edge Torch Authors.
         
     | 
| 
       2 
     | 
    
         
            -
            #
         
     | 
| 
       3 
     | 
    
         
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 
       4 
     | 
    
         
            -
            # you may not use this file except in compliance with the License.
         
     | 
| 
       5 
     | 
    
         
            -
            # You may obtain a copy of the License at
         
     | 
| 
       6 
     | 
    
         
            -
            #
         
     | 
| 
       7 
     | 
    
         
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 
       8 
     | 
    
         
            -
            #
         
     | 
| 
       9 
     | 
    
         
            -
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 
       10 
     | 
    
         
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 
       11 
     | 
    
         
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 
       12 
     | 
    
         
            -
            # See the License for the specific language governing permissions and
         
     | 
| 
       13 
     | 
    
         
            -
            # limitations under the License.
         
     | 
| 
       14 
     | 
    
         
            -
            # ==============================================================================
         
     | 
| 
       15 
     | 
    
         
            -
             
     | 
| 
       16 
     | 
    
         
            -
            import os
         
     | 
| 
       17 
     | 
    
         
            -
            from pathlib import Path
         
     | 
| 
       18 
     | 
    
         
            -
             
     | 
| 
       19 
     | 
    
         
            -
            import ai_edge_torch
         
     | 
| 
       20 
     | 
    
         
            -
            from ai_edge_torch.generative.examples.phi2 import phi2
         
     | 
| 
       21 
     | 
    
         
            -
            from ai_edge_torch.generative.quantize import quant_recipes
         
     | 
| 
       22 
     | 
    
         
            -
            import torch
         
     | 
| 
       23 
     | 
    
         
            -
             
     | 
| 
       24 
     | 
    
         
            -
             
     | 
| 
       25 
     | 
    
         
            -
            def convert_phi2_to_tflite(
                checkpoint_path: str,
                prefill_seq_len: int = 512,
                kv_cache_max_len: int = 1024,
                quantize: bool = True,
                output_dir: str = '/tmp',
            ):
              """Converts a Phi-2 model to a multi-signature tflite model.

              The exported model has a `prefill` signature, traced with a
              `(1, prefill_seq_len)` token tensor, and a `decode` signature, traced with
              a `(1, 1)` token tensor.

              Args:
                  checkpoint_path (str): The filepath to the model checkpoint, or directory
                    holding the checkpoint.
                  prefill_seq_len (int, optional): The maximum size of prefill input tensor.
                    Defaults to 512.
                  kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
                    including both prefill and decode. Defaults to 1024.
                  quantize (bool, optional): Whether the model should be quantized. Defaults
                    to True.
                  output_dir (str, optional): Directory the `.tflite` file is written to.
                    Defaults to '/tmp'.
              """
              pytorch_model = phi2.build_model(
                  checkpoint_path, kv_cache_max_len=kv_cache_max_len
              )
              # Tensors used to trace the model graph during conversion.
              prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
              prefill_input_pos = torch.arange(0, prefill_seq_len)
              decode_token = torch.tensor([[0]], dtype=torch.long)
              decode_input_pos = torch.tensor([0], dtype=torch.int64)

              # Dynamic int8 quantization is applied only when requested.
              quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
              edge_model = (
                  ai_edge_torch.signature(
                      'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
                  )
                  .signature('decode', pytorch_model, (decode_token, decode_input_pos))
                  .convert(quant_config=quant_config)
              )
              # The file name records the prefill length and KV cache size it was
              # traced with.
              edge_model.export(
                  os.path.join(
                      output_dir,
                      f'phi2_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite',
                  )
              )
         
     | 
| 
       63 
     | 
    
         
            -
             
     | 
| 
       64 
     | 
    
         
            -
             
     | 
| 
       65 
     | 
    
         
            -
            if __name__ == '__main__':
              # Script entry point: convert the checkpoint found under
              # ~/Downloads/llm_data/phi2 using the default conversion settings.
              checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/phi2')
              convert_phi2_to_tflite(checkpoint_path)
         
     | 
| 
         @@ -1,189 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            # Copyright 2024 The AI Edge Torch Authors.
         
     | 
| 
       2 
     | 
    
         
            -
            #
         
     | 
| 
       3 
     | 
    
         
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 
       4 
     | 
    
         
            -
            # you may not use this file except in compliance with the License.
         
     | 
| 
       5 
     | 
    
         
            -
            # You may obtain a copy of the License at
         
     | 
| 
       6 
     | 
    
         
            -
            #
         
     | 
| 
       7 
     | 
    
         
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 
       8 
     | 
    
         
            -
            #
         
     | 
| 
       9 
     | 
    
         
            -
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 
       10 
     | 
    
         
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 
       11 
     | 
    
         
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 
       12 
     | 
    
         
            -
            # See the License for the specific language governing permissions and
         
     | 
| 
       13 
     | 
    
         
            -
            # limitations under the License.
         
     | 
| 
       14 
     | 
    
         
            -
            # ==============================================================================
         
     | 
| 
       15 
     | 
    
         
            -
            # Example of building phi-2 model from the Edge Generative API layers.
         
     | 
| 
       16 
     | 
    
         
            -
             
     | 
| 
       17 
     | 
    
         
            -
             
     | 
| 
       18 
     | 
    
         
            -
            import os
         
     | 
| 
       19 
     | 
    
         
            -
            from pathlib import Path
         
     | 
| 
       20 
     | 
    
         
            -
             
     | 
| 
       21 
     | 
    
         
            -
            from ai_edge_torch.generative.layers import attention
         
     | 
| 
       22 
     | 
    
         
            -
            from ai_edge_torch.generative.layers import builder
         
     | 
| 
       23 
     | 
    
         
            -
            import ai_edge_torch.generative.layers.attention_utils as attn_utils
         
     | 
| 
       24 
     | 
    
         
            -
            import ai_edge_torch.generative.layers.model_config as cfg
         
     | 
| 
       25 
     | 
    
         
            -
            import ai_edge_torch.generative.utilities.loader as loading_utils
         
     | 
| 
       26 
     | 
    
         
            -
            import numpy as np
         
     | 
| 
       27 
     | 
    
         
            -
            import torch
         
     | 
| 
       28 
     | 
    
         
            -
            from torch import nn
         
     | 
| 
       29 
     | 
    
         
            -
             
     | 
| 
       30 
     | 
    
         
            -
            # Names of the Phi-2 checkpoint tensors for each model component. Passed to
            # `loading_utils.ModelLoader` by `build_model` to load weights into `Phi2`.
            # NOTE(review): the `{}` in the per-layer names is presumably filled with the
            # layer index by the loader; confirm against ModelLoader.
            TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
                ff_up_proj="model.layers.{}.mlp.fc1",
                ff_down_proj="model.layers.{}.mlp.fc2",
                attn_query_proj="model.layers.{}.self_attn.q_proj",
                attn_key_proj="model.layers.{}.self_attn.k_proj",
                attn_value_proj="model.layers.{}.self_attn.v_proj",
                attn_output_proj="model.layers.{}.self_attn.dense",
                pre_attn_norm="model.layers.{}.input_layernorm",
                embedding="model.embed_tokens",
                final_norm="model.final_layernorm",
                lm_head="lm_head",
            )
         
     | 
| 
       42 
     | 
    
         
            -
             
     | 
| 
       43 
     | 
    
         
            -
             
     | 
| 
       44 
     | 
    
         
            -
            class Phi2(nn.Module):
              """A Phi-2 model built from the Edge Generative API layers."""

              def __init__(self, config: cfg.ModelConfig):
                """Builds the model layers and the precomputed RoPE / causal-mask caches.

                Args:
                  config (cfg.ModelConfig): The model configuration.
                """
                super().__init__()

                self.config = config
                # Construct model layers.
                self.lm_head = nn.Linear(
                    config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
                )
                self.tok_embedding = nn.Embedding(
                    config.vocab_size, config.embedding_dim, padding_idx=0
                )
                self.transformer_blocks = nn.ModuleList(
                    attention.TransformerBlock(config) for _ in range(config.num_layers)
                )
                self.final_norm = builder.build_norm(
                    config.embedding_dim,
                    config.final_norm_config,
                )
                # The RoPE cache is built for int(rotary_percentage * head_dim) dims.
                self.rope_cache = attn_utils.build_rope_cache(
                    size=config.kv_cache_max,
                    dim=int(
                        config.attn_config.rotary_percentage * config.attn_config.head_dim
                    ),
                    base=10_000,
                    condense_ratio=1,
                    dtype=torch.float32,
                    device=torch.device("cpu"),
                )
                self.mask_cache = attn_utils.build_causal_mask_cache(
                    size=config.kv_cache_max,
                    dtype=torch.float32,
                    device=torch.device("cpu"),
                )

              # forward takes only token ids and their positions and returns logits; no
              # k/v cache tensors are passed in or returned. NOTE(review): the KV cache is
              # presumably managed inside attention.TransformerBlock (it receives
              # input_pos); confirm there.
              @torch.inference_mode
              def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
                """Runs the model on `idx` at positions `input_pos`.

                Args:
                  idx (torch.Tensor): 2-D tensor of token ids; dim 1 is the sequence
                    length, which must not exceed `config.max_seq_len`.
                  input_pos (torch.Tensor): Positions of the tokens, used to select rows
                    of the RoPE and causal-mask caches.

                Returns:
                  torch.Tensor: Output of `lm_head`, (b, t, vocab_size).
                """
                _, seq_len = idx.size()
                assert self.config.max_seq_len >= seq_len, (
                    f"Cannot forward sequence of length {seq_len}, max seq length is only"
                    f" {self.config.max_seq_len}"
                )

                # Select the rotary embeddings and mask rows for the given positions.
                cos, sin = self.rope_cache
                cos = cos.index_select(0, input_pos)
                sin = sin.index_select(0, input_pos)
                mask = self.mask_cache.index_select(2, input_pos)
                mask = mask[:, :, :, : self.config.kv_cache_max]

                # forward the model itself
                x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)

                for block in self.transformer_blocks:
                  x = block(x, (cos, sin), mask, input_pos)

                x = self.final_norm(x)
                res = self.lm_head(x)  # (b, t, vocab_size)
                return res
         
     | 
| 
       108 
     | 
    
         
            -
             
     | 
| 
       109 
     | 
    
         
            -
             
     | 
| 
       110 
     | 
    
         
            -
            def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
              """Builds the configuration of a Phi-2 model.

              Args:
                kv_cache_max_len (int): Maximum sequence length held by the KV cache.
                  Defaults to 1024.

              Returns:
                cfg.ModelConfig: The Phi-2 model configuration.
              """
              # One LayerNorm config object is shared by the pre-attention and final norms.
              layer_norm = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
              return cfg.ModelConfig(
                  vocab_size=51200,
                  num_layers=32,
                  max_seq_len=2048,
                  kv_cache_max_len=kv_cache_max_len,
                  embedding_dim=2560,
                  attn_config=cfg.AttentionConfig(
                      num_heads=32,
                      head_dim=80,
                      num_query_groups=32,
                      rotary_percentage=0.4,
                      qkv_use_bias=True,
                      output_proj_use_bias=True,
                  ),
                  ff_config=cfg.FeedForwardConfig(
                      type=cfg.FeedForwardType.SEQUENTIAL,
                      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
                      intermediate_size=10240,
                      use_bias=True,
                  ),
                  pre_attention_norm_config=layer_norm,
                  final_norm_config=layer_norm,
                  parallel_residual=True,
                  lm_head_use_bias=True,
                  enable_hlfb=True,
              )
         
     | 
| 
       150 
     | 
    
         
            -
             
     | 
| 
       151 
     | 
    
         
            -
             
     | 
| 
       152 
     | 
    
         
            -
            def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
              """Returns a shrunken Phi-2 configuration (presumably for tests).

              Starts from `get_model_config` and reduces the vocabulary size, layer count
              and feed-forward width; `max_seq_len` becomes twice `kv_cache_max_len`.
              """
              fake = get_model_config(kv_cache_max_len)
              fake.vocab_size = 128
              fake.num_layers = 2
              fake.max_seq_len = kv_cache_max_len * 2
              fake.ff_config.intermediate_size = 128
              return fake
         
     | 
| 
       159 
     | 
    
         
            -
             
     | 
| 
       160 
     | 
    
         
            -
             
     | 
| 
       161 
     | 
    
         
            -
            def build_model(checkpoint_path, **kwargs) -> nn.Module:
              """Creates a Phi-2 model and loads its weights from a checkpoint.

              Args:
                checkpoint_path: Checkpoint file, or directory holding the checkpoint;
                  handed to `loading_utils.ModelLoader` together with `TENSOR_NAMES`.
                **kwargs: Forwarded to `get_model_config` (e.g. `kv_cache_max_len`).

              Returns:
                nn.Module: The `Phi2` model with the checkpoint weights loaded.
              """
              model = Phi2(get_model_config(**kwargs))
              loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES).load(model)
              return model
         
     | 
| 
       167 
     | 
    
         
            -
             
     | 
| 
       168 
     | 
    
         
            -
             
     | 
| 
       169 
     | 
    
         
            -
            def define_and_run() -> None:
              """Instantiates a Phi-2 model and checks its logits against stored goldens."""
              goldens_file = Path(__file__).parent.resolve() / "phi2_lm_logits.pt"
              phi2_goldens = torch.load(goldens_file)

              kv_cache_max_len = 1024
              model = build_model(
                  os.path.join(Path.home(), "Downloads/llm_data/phi2"),
                  kv_cache_max_len=kv_cache_max_len,
              )

              prompt = torch.from_numpy(np.array([[1, 2, 3, 4]]))
              prompt_len = prompt.shape[1]
              # Zero-pad the prompt out to the full KV cache length.
              tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
              tokens[0, :prompt_len] = prompt
              input_pos = torch.arange(0, kv_cache_max_len)
              lm_logits = model.forward(tokens, input_pos)

              print("comparing with goldens..")
              # The goldens are compared with the logits at the last prompt token.
              last_prompt_logits = lm_logits[0, prompt_len - 1, :]
              assert torch.allclose(phi2_goldens, last_prompt_logits, atol=1e-05)
         
     | 
| 
       186 
     | 
    
         
            -
             
     | 
| 
       187 
     | 
    
         
            -
             
     | 
| 
       188 
     | 
    
         
            -
            if __name__ == "__main__":
              # Script entry point: build the model and run the golden-logit check.
              define_and_run()
         
     | 
| 
         @@ -1,176 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            # Copyright 2024 The AI Edge Torch Authors.
         
     | 
| 
       2 
     | 
    
         
            -
            #
         
     | 
| 
       3 
     | 
    
         
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 
       4 
     | 
    
         
            -
            # you may not use this file except in compliance with the License.
         
     | 
| 
       5 
     | 
    
         
            -
            # You may obtain a copy of the License at
         
     | 
| 
       6 
     | 
    
         
            -
            #
         
     | 
| 
       7 
     | 
    
         
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 
       8 
     | 
    
         
            -
            #
         
     | 
| 
       9 
     | 
    
         
            -
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 
       10 
     | 
    
         
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 
       11 
     | 
    
         
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 
       12 
     | 
    
         
            -
            # See the License for the specific language governing permissions and
         
     | 
| 
       13 
     | 
    
         
            -
            # limitations under the License.
         
     | 
| 
       14 
     | 
    
         
            -
            # ==============================================================================
         
     | 
| 
       15 
     | 
    
         
            -
            # A toy example which has basic transformer block (w/ externalized KV-Cache).
         
     | 
| 
       16 
     | 
    
         
            -
             
     | 
| 
       17 
     | 
    
         
            -
            from typing import Tuple
         
     | 
| 
       18 
     | 
    
         
            -
             
     | 
| 
       19 
     | 
    
         
            -
            import ai_edge_torch
         
     | 
| 
       20 
     | 
    
         
            -
            from ai_edge_torch import lowertools
         
     | 
| 
       21 
     | 
    
         
            -
            import ai_edge_torch.generative.layers.attention_utils as attn_utils
         
     | 
| 
       22 
     | 
    
         
            -
            import ai_edge_torch.generative.layers.builder as builder
         
     | 
| 
       23 
     | 
    
         
            -
            from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
         
     | 
| 
       24 
     | 
    
         
            -
            from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock  # NOQA
         
     | 
| 
       25 
     | 
    
         
            -
            import ai_edge_torch.generative.layers.model_config as cfg
         
     | 
| 
       26 
     | 
    
         
            -
            import torch
         
     | 
| 
       27 
     | 
    
         
            -
            import torch.nn as nn
         
     | 
| 
       28 
     | 
    
         
            -
             
     | 
| 
       29 
     | 
    
         
            -
            # Type alias for the (cos, sin) tensor pair held in a model's rope_cache.
            RoPECache = Tuple[torch.Tensor, torch.Tensor]
         
     | 
| 
       30 
     | 
    
         
            -
             
     | 
| 
       31 
     | 
    
         
            -
             
     | 
| 
       32 
     | 
    
         
            -
            class ToyModelWithExternalKV(torch.nn.Module):
         
     | 
| 
       33 
     | 
    
         
            -
             
     | 
| 
       34 
     | 
    
         
            -
              def __init__(self, config: cfg.ModelConfig) -> None:
         
     | 
| 
       35 
     | 
    
         
            -
                super().__init__()
         
     | 
| 
       36 
     | 
    
         
            -
                self.lm_head = nn.Linear(
         
     | 
| 
       37 
     | 
    
         
            -
                    config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
         
     | 
| 
       38 
     | 
    
         
            -
                )
         
     | 
| 
       39 
     | 
    
         
            -
                self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
         
     | 
| 
       40 
     | 
    
         
            -
                self.transformer_blocks = nn.ModuleList(
         
     | 
| 
       41 
     | 
    
         
            -
                    TransformerBlock(config) for _ in range(config.num_layers)
         
     | 
| 
       42 
     | 
    
         
            -
                )
         
     | 
| 
       43 
     | 
    
         
            -
                self.final_norm = builder.build_norm(
         
     | 
| 
       44 
     | 
    
         
            -
                    config.embedding_dim,
         
     | 
| 
       45 
     | 
    
         
            -
                    config.final_norm_config,
         
     | 
| 
       46 
     | 
    
         
            -
                )
         
     | 
| 
       47 
     | 
    
         
            -
                self.rope_cache = attn_utils.build_rope_cache(
         
     | 
| 
       48 
     | 
    
         
            -
                    size=config.max_seq_len,
         
     | 
| 
       49 
     | 
    
         
            -
                    dim=int(
         
     | 
| 
       50 
     | 
    
         
            -
                        config.attn_config.rotary_percentage * config.attn_config.head_dim
         
     | 
| 
       51 
     | 
    
         
            -
                    ),
         
     | 
| 
       52 
     | 
    
         
            -
                    base=10_000,
         
     | 
| 
       53 
     | 
    
         
            -
                    condense_ratio=1,
         
     | 
| 
       54 
     | 
    
         
            -
                    dtype=torch.float32,
         
     | 
| 
       55 
     | 
    
         
            -
                    device=torch.device('cpu'),
         
     | 
| 
       56 
     | 
    
         
            -
                )
         
     | 
| 
       57 
     | 
    
         
            -
                self.mask_cache = attn_utils.build_causal_mask_cache(
         
     | 
| 
       58 
     | 
    
         
            -
                    size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
         
     | 
| 
       59 
     | 
    
         
            -
                )
         
     | 
| 
       60 
     | 
    
         
            -
                self.config = config
         
     | 
| 
       61 
     | 
    
         
            -
             
     | 
| 
       62 
     | 
    
         
            -
              def forward(
         
     | 
| 
       63 
     | 
    
         
            -
                  self,
         
     | 
| 
       64 
     | 
    
         
            -
                  tokens: torch.Tensor,
         
     | 
| 
       65 
     | 
    
         
            -
                  input_pos: torch.Tensor,
         
     | 
| 
       66 
     | 
    
         
            -
                  kv_cache: kv_utils.EKVCache,
         
     | 
| 
       67 
     | 
    
         
            -
              ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
         
     | 
| 
       68 
     | 
    
         
            -
                x = self.tok_embedding(tokens)
         
     | 
| 
       69 
     | 
    
         
            -
                cos, sin = self.rope_cache
         
     | 
| 
       70 
     | 
    
         
            -
                cos = cos.index_select(0, input_pos)
         
     | 
| 
       71 
     | 
    
         
            -
                sin = sin.index_select(0, input_pos)
         
     | 
| 
       72 
     | 
    
         
            -
                mask = self.mask_cache.index_select(2, input_pos)
         
     | 
| 
       73 
     | 
    
         
            -
                mask = mask[:, :, :, : self.config.max_seq_len]
         
     | 
| 
       74 
     | 
    
         
            -
             
     | 
| 
       75 
     | 
    
         
            -
                updated_kv_entires = []
         
     | 
| 
       76 
     | 
    
         
            -
                for i, block in enumerate(self.transformer_blocks):
         
     | 
| 
       77 
     | 
    
         
            -
                  kv_entry = kv_cache.caches[i] if kv_cache else None
         
     | 
| 
       78 
     | 
    
         
            -
                  x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
         
     | 
| 
       79 
     | 
    
         
            -
                  if kv_entry:
         
     | 
| 
       80 
     | 
    
         
            -
                    updated_kv_entires.append(kv_entry)
         
     | 
| 
       81 
     | 
    
         
            -
             
     | 
| 
       82 
     | 
    
         
            -
                x = self.final_norm(x)
         
     | 
| 
       83 
     | 
    
         
            -
                updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
         
     | 
| 
       84 
     | 
    
         
            -
                return self.lm_head(x), updated_kv_cache
         
     | 
| 
       85 
     | 
    
         
            -
             
     | 
| 
       86 
     | 
    
         
            -
             
     | 
| 
       87 
     | 
    
         
            -
            def _export_stablehlo_mlir(model, args):
         
     | 
| 
       88 
     | 
    
         
            -
              ep = torch.export.export(model, args)
         
     | 
| 
       89 
     | 
    
         
            -
              return lowertools.exported_program_to_mlir_text(ep)
         
     | 
| 
       90 
     | 
    
         
            -
             
     | 
| 
       91 
     | 
    
         
            -
             
     | 
| 
       92 
     | 
    
         
            -
            def get_model_config() -> cfg.ModelConfig:
         
     | 
| 
       93 
     | 
    
         
            -
              attn_config = cfg.AttentionConfig(
         
     | 
| 
       94 
     | 
    
         
            -
                  num_heads=32, head_dim=4, num_query_groups=4, rotary_percentage=1.0
         
     | 
| 
       95 
     | 
    
         
            -
              )
         
     | 
| 
       96 
     | 
    
         
            -
              ff_config = cfg.FeedForwardConfig(
         
     | 
| 
       97 
     | 
    
         
            -
                  type=cfg.FeedForwardType.GATED,
         
     | 
| 
       98 
     | 
    
         
            -
                  activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
         
     | 
| 
       99 
     | 
    
         
            -
                  intermediate_size=256,
         
     | 
| 
       100 
     | 
    
         
            -
              )
         
     | 
| 
       101 
     | 
    
         
            -
              norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
         
     | 
| 
       102 
     | 
    
         
            -
              config = cfg.ModelConfig(
         
     | 
| 
       103 
     | 
    
         
            -
                  vocab_size=150,
         
     | 
| 
       104 
     | 
    
         
            -
                  num_layers=2,
         
     | 
| 
       105 
     | 
    
         
            -
                  max_seq_len=100,
         
     | 
| 
       106 
     | 
    
         
            -
                  embedding_dim=128,
         
     | 
| 
       107 
     | 
    
         
            -
                  attn_config=attn_config,
         
     | 
| 
       108 
     | 
    
         
            -
                  ff_config=ff_config,
         
     | 
| 
       109 
     | 
    
         
            -
                  pre_attention_norm_config=norm_config,
         
     | 
| 
       110 
     | 
    
         
            -
                  post_attention_norm_config=norm_config,
         
     | 
| 
       111 
     | 
    
         
            -
                  final_norm_config=norm_config,
         
     | 
| 
       112 
     | 
    
         
            -
                  enable_hlfb=True,
         
     | 
| 
       113 
     | 
    
         
            -
              )
         
     | 
| 
       114 
     | 
    
         
            -
              return config
         
     | 
| 
       115 
     | 
    
         
            -
             
     | 
| 
       116 
     | 
    
         
            -
             
     | 
| 
       117 
     | 
    
         
            -
            def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
         
     | 
| 
       118 
     | 
    
         
            -
              tokens = torch.unsqueeze(torch.arange(0, 100), 0)
         
     | 
| 
       119 
     | 
    
         
            -
              input_pos = torch.arange(0, 100)
         
     | 
| 
       120 
     | 
    
         
            -
              return tokens, input_pos
         
     | 
| 
       121 
     | 
    
         
            -
             
     | 
| 
       122 
     | 
    
         
            -
             
     | 
| 
       123 
     | 
    
         
            -
            def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
         
     | 
| 
       124 
     | 
    
         
            -
              tokens = torch.tensor([[1]], dtype=torch.long)
         
     | 
| 
       125 
     | 
    
         
            -
              input_pos = torch.tensor([10])
         
     | 
| 
       126 
     | 
    
         
            -
              return tokens, input_pos
         
     | 
| 
       127 
     | 
    
         
            -
             
     | 
| 
       128 
     | 
    
         
            -
             
     | 
| 
       129 
     | 
    
         
            -
            def define_and_run() -> None:
         
     | 
| 
       130 
     | 
    
         
            -
              dump_mlir = False
         
     | 
| 
       131 
     | 
    
         
            -
             
     | 
| 
       132 
     | 
    
         
            -
              config = get_model_config()
         
     | 
| 
       133 
     | 
    
         
            -
              model = ToyModelWithExternalKV(config)
         
     | 
| 
       134 
     | 
    
         
            -
              model.eval()
         
     | 
| 
       135 
     | 
    
         
            -
              print('running an inference')
         
     | 
| 
       136 
     | 
    
         
            -
              kv = kv_utils.EKVCache.from_model_config(config)
         
     | 
| 
       137 
     | 
    
         
            -
             
     | 
| 
       138 
     | 
    
         
            -
              tokens, input_pos = get_sample_prefill_inputs()
         
     | 
| 
       139 
     | 
    
         
            -
              decode_token, decode_input_pos = get_sample_decode_inputs()
         
     | 
| 
       140 
     | 
    
         
            -
              print(model.forward(tokens, input_pos, kv))
         
     | 
| 
       141 
     | 
    
         
            -
             
     | 
| 
       142 
     | 
    
         
            -
              if dump_mlir:
         
     | 
| 
       143 
     | 
    
         
            -
                mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
         
     | 
| 
       144 
     | 
    
         
            -
                with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
         
     | 
| 
       145 
     | 
    
         
            -
                  f.write(mlir_text)
         
     | 
| 
       146 
     | 
    
         
            -
             
     | 
| 
       147 
     | 
    
         
            -
              # Convert model to tflite with 2 signatures (prefill + decode).
         
     | 
| 
       148 
     | 
    
         
            -
              # TODO(b/344014416): currently conversion will fail, because we generate int64 index
         
     | 
| 
       149 
     | 
    
         
            -
              # in dynamic update slice op.
         
     | 
| 
       150 
     | 
    
         
            -
              print('converting toy model to tflite with 2 signatures (prefill + decode)')
         
     | 
| 
       151 
     | 
    
         
            -
              edge_model = (
         
     | 
| 
       152 
     | 
    
         
            -
                  ai_edge_torch.signature(
         
     | 
| 
       153 
     | 
    
         
            -
                      'prefill',
         
     | 
| 
       154 
     | 
    
         
            -
                      model,
         
     | 
| 
       155 
     | 
    
         
            -
                      sample_kwargs={
         
     | 
| 
       156 
     | 
    
         
            -
                          'tokens': tokens,
         
     | 
| 
       157 
     | 
    
         
            -
                          'input_pos': input_pos,
         
     | 
| 
       158 
     | 
    
         
            -
                          'kv_cache': kv,
         
     | 
| 
       159 
     | 
    
         
            -
                      },
         
     | 
| 
       160 
     | 
    
         
            -
                  )
         
     | 
| 
       161 
     | 
    
         
            -
                  .signature(
         
     | 
| 
       162 
     | 
    
         
            -
                      'decode',
         
     | 
| 
       163 
     | 
    
         
            -
                      model,
         
     | 
| 
       164 
     | 
    
         
            -
                      sample_kwargs={
         
     | 
| 
       165 
     | 
    
         
            -
                          'tokens': decode_token,
         
     | 
| 
       166 
     | 
    
         
            -
                          'input_pos': decode_input_pos,
         
     | 
| 
       167 
     | 
    
         
            -
                          'kv_cache': kv,
         
     | 
| 
       168 
     | 
    
         
            -
                      },
         
     | 
| 
       169 
     | 
    
         
            -
                  )
         
     | 
| 
       170 
     | 
    
         
            -
                  .convert()
         
     | 
| 
       171 
     | 
    
         
            -
              )
         
     | 
| 
       172 
     | 
    
         
            -
              edge_model.export('/tmp/toy_external_kv_cache.tflite')
         
     | 
| 
       173 
     | 
    
         
            -
             
     | 
| 
       174 
     | 
    
         
            -
             
     | 
| 
       175 
     | 
    
         
            -
            if __name__ == '__main__':
         
     | 
| 
       176 
     | 
    
         
            -
              define_and_run()
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     |