ai-edge-torch-nightly 0.3.0.dev20241002__py3-none-any.whl → 0.3.0.dev20241003__py3-none-any.whl

Files changed (24)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +10 -93
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +0 -1
  3. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +13 -2
  4. ai_edge_torch/generative/examples/llama/llama.py +19 -24
  5. ai_edge_torch/generative/examples/llama/verify.py +18 -3
  6. ai_edge_torch/generative/examples/openelm/openelm.py +9 -90
  7. ai_edge_torch/generative/examples/phi/phi2.py +10 -86
  8. ai_edge_torch/generative/examples/phi/phi3.py +9 -69
  9. ai_edge_torch/generative/examples/qwen/qwen.py +26 -36
  10. ai_edge_torch/generative/examples/smollm/smollm.py +10 -30
  11. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +11 -101
  12. ai_edge_torch/generative/layers/model_config.py +6 -0
  13. ai_edge_torch/generative/test/test_loader.py +2 -1
  14. ai_edge_torch/generative/test/test_model_conversion.py +2 -1
  15. ai_edge_torch/generative/test/test_model_conversion_large.py +6 -5
  16. ai_edge_torch/generative/utilities/model_builder.py +141 -0
  17. ai_edge_torch/version.py +1 -1
  18. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/METADATA +1 -1
  19. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/RECORD +22 -23
  20. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +0 -68
  21. ai_edge_torch/generative/examples/llama/verify_3b.py +0 -73
  22. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/LICENSE +0 -0
  23. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/WHEEL +0 -0
  24. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/examples/gemma/gemma1.py
@@ -15,14 +15,9 @@
 
  """Example of building a Gemma1 model."""
 
- from ai_edge_torch.generative.layers import attention
- from ai_edge_torch.generative.layers import builder
- from ai_edge_torch.generative.layers import kv_cache as kv_utils
- import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
+ from ai_edge_torch.generative.utilities import model_builder
  import ai_edge_torch.generative.utilities.loader as loading_utils
- import torch
- from torch import nn
 
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -38,84 +33,6 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  )
 
 
- class Gemma(nn.Module):
- """A Gemma model built from the Edge Generative API layers."""
-
- def __init__(self, config: cfg.ModelConfig):
- super().__init__()
-
- # Construct model layers.
- self.tok_embedding = nn.Embedding(
- config.vocab_size, config.embedding_dim, padding_idx=0
- )
- self.lm_head = nn.Linear(
- config.embedding_dim,
- config.vocab_size,
- bias=config.lm_head_use_bias,
- )
- # Gemma re-uses the embedding as the head projection layer.
- self.lm_head.weight.data = self.tok_embedding.weight.data
- # Gemma has only one block config.
- block_config = config.block_config(0)
- self.transformer_blocks = nn.ModuleList(
- attention.TransformerBlock(block_config, config)
- for _ in range(config.num_layers)
- )
- self.final_norm = builder.build_norm(
- config.embedding_dim,
- config.final_norm_config,
- )
- attn_config = block_config.attn_config
- self.rope_cache = attn_utils.build_rope_cache(
- size=config.kv_cache_max,
- dim=int(attn_config.rotary_percentage * attn_config.head_dim),
- base=attn_config.rotary_base,
- )
- self.mask_cache = attn_utils.build_causal_mask_cache(
- size=config.kv_cache_max,
- )
- self.config = config
-
- @torch.inference_mode
- def forward(
- self,
- tokens: torch.Tensor,
- input_pos: torch.Tensor,
- kv_cache: kv_utils.KVCache,
- ) -> dict[torch.Tensor, kv_utils.KVCache]:
- _, seq_len = tokens.size()
- assert self.config.max_seq_len >= seq_len, (
- f"Cannot forward sequence of length {seq_len}, max seq length is only"
- f" {self.config.max_seq_len}"
- )
- assert len(self.transformer_blocks) == len(kv_cache.caches), (
- "The number of transformer blocks and the number of KV cache entries"
- " must be the same."
- )
-
- cos, sin = self.rope_cache
- cos = cos.index_select(0, input_pos)
- sin = sin.index_select(0, input_pos)
- mask = self.mask_cache.index_select(2, input_pos)
- mask = mask[:, :, :, : self.config.kv_cache_max]
-
- # token embeddings of shape (b, t, n_embd)
- x = self.tok_embedding(tokens)
- x = x * (self.config.embedding_dim**0.5)
-
- updated_kv_entires = []
- for i, block in enumerate(self.transformer_blocks):
- kv_entry = kv_cache.caches[i] if kv_cache else None
- x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
- if kv_entry:
- updated_kv_entires.append(kv_entry)
- updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
-
- x = self.final_norm(x)
- logits = self.lm_head(x) # (b, t, vocab_size)
- return {"logits": logits, "kv_cache": updated_kv_cache}
-
-
  def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  """Returns the model config for a Gemma 2B model.
 
@@ -154,6 +71,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  num_layers=18,
  max_seq_len=8192,
  embedding_dim=2048,
+ embedding_scale=2048**0.5,
  kv_cache_max_len=kv_cache_max_len,
  block_configs=block_config,
  final_norm_config=norm_config,
@@ -173,12 +91,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
  return config
 
 
- def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
- config = get_model_config_2b(**kwargs)
- model = Gemma(config)
- loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
- # Since embedding and lm-head use the same weight, we need to set strict
- # to False.
- loader.load(model, strict=False)
- model.eval()
- return model
+ def build_2b_model(
+ checkpoint_path: str, **kwargs
+ ) -> model_builder.DecoderOnlyModel:
+ return model_builder.build_decoder_only_model(
+ checkpoint_path=checkpoint_path,
+ config=get_model_config_2b(**kwargs),
+ tensor_names=TENSOR_NAMES,
+ )
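
Note: the hand-written Gemma nn.Module is gone; build_2b_model now returns the generic model_builder.DecoderOnlyModel that model_builder.build_decoder_only_model constructs from the Gemma config and TENSOR_NAMES. A minimal usage sketch, with a placeholder checkpoint path and illustrative kwargs (not taken from the diff):

    from ai_edge_torch.generative.examples.gemma import gemma1

    # Returns a model_builder.DecoderOnlyModel instead of the removed Gemma class.
    model = gemma1.build_2b_model(
        "/path/to/gemma-2b-checkpoint",  # placeholder checkpoint directory
        kv_cache_max_len=1024,
    )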

ai_edge_torch/generative/examples/gemma/gemma2.py
@@ -15,7 +15,6 @@
 
  """Example of building a Gemma2 model."""
 
- import os
  from typing import Optional, Tuple
 
  from ai_edge_torch.generative.layers import attention

ai_edge_torch/generative/examples/llama/convert_to_tflite.py
@@ -23,6 +23,12 @@ from absl import flags
  from ai_edge_torch.generative.examples.llama import llama
  from ai_edge_torch.generative.utilities import converter
 
+ _MODEL_SIZE = flags.DEFINE_enum(
+ 'model_size',
+ '1b',
+ ['1b', '3b'],
+ 'The size of the model to verify.',
+ )
  _CHECKPOINT_PATH = flags.DEFINE_string(
  'checkpoint_path',
  os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
@@ -49,13 +55,18 @@ _QUANTIZE = flags.DEFINE_bool(
  'Whether the model should be quantized.',
  )
 
+ _BUILDER = {
+ '1b': llama.build_1b_model,
+ '3b': llama.build_3b_model,
+ }
+
 
  def main(_):
- pytorch_model = llama.build_model(
+ pytorch_model = _BUILDER[_MODEL_SIZE.value](
  _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
  )
  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
- output_filename = f'llama_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+ output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
  converter.convert_to_tflite(
  pytorch_model,
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
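
Note: with the new --model_size flag, the exported file name now records the model size. A small, self-contained sketch of the naming scheme using example flag values (the values are illustrative, not from the diff):

    # Example: --model_size=3b with quantization enabled and example sequence lengths.
    model_size, quant_suffix, prefill_seq_len, kv_cache_max_len = "3b", "q8", 512, 1024
    output_filename = (
        f"llama_{model_size}_{quant_suffix}_seq{prefill_seq_len}"
        f"_ekv{kv_cache_max_len}.tflite"
    )
    print(output_filename)  # -> llama_3b_q8_seq512_ekv1024.tflite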

ai_edge_torch/generative/examples/llama/llama.py
@@ -15,19 +15,15 @@
 
  """Example of building Llama 3.2 models."""
 
- import copy
  import math
  from typing import Tuple
 
- from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
  import ai_edge_torch.generative.layers.model_config as cfg
+ from ai_edge_torch.generative.utilities import model_builder
  import ai_edge_torch.generative.utilities.loader as loading_utils
  import torch
- from torch import nn
 
- TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
- # SmolLM re-uses the embedding as the head projection layer.
- TENSOR_NAMES.lm_head = None
+ TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
  def _build_llama3_rope_cache(
@@ -93,7 +89,7 @@ def _build_llama3_rope_cache(
  return cos, sin
 
 
- class Llama(tiny_llama.TinyLlama):
+ class Llama(model_builder.DecoderOnlyModel):
  """A Llama model built from the Edge Generative API layers.
 
  Llama 3.2 shares the same architecture as TinyLlama except ROPE calculation.
@@ -101,9 +97,6 @@ class Llama(tiny_llama.TinyLlama):
 
  def __init__(self, config: cfg.ModelConfig):
  super().__init__(config)
- # Llama 3.2 re-uses the embedding as the head projection layer.
- self.lm_head.weight.data = self.tok_embedding.weight.data
- # Llama has only one block config.
  attn_config = self.config.block_config(0).attn_config
  self.rope_cache = _build_llama3_rope_cache(
  size=self.config.kv_cache_max,
@@ -119,7 +112,7 @@ class Llama(tiny_llama.TinyLlama):
  )
 
 
- def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  """Returns the model config for a Llama 3.2-1B model.
 
  Args:
@@ -163,7 +156,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
  def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  """Returns the model config for a Llama 3.2-3B model."""
- config = get_model_config(kv_cache_max_len)
+ config = get_1b_model_config(kv_cache_max_len)
  # Llama 3.2 has only one block config.
  attn_config = config.block_config(0).attn_config
  attn_config.num_heads = 24
@@ -174,7 +167,7 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 
  def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
- config = get_model_config(**kwargs)
+ config = get_1b_model_config(**kwargs)
  config.vocab_size = 128
  config.num_layers = 2
  # SmolLM has only one block config.
@@ -182,8 +175,9 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
  return config
 
 
- def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
- config = get_model_config(**kwargs)
+ def _build_model(
+ checkpoint_path: str, config: cfg.ModelConfig
+ ) -> model_builder.DecoderOnlyModel:
  model = Llama(config)
  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
  # Since embedding and lm-head use the same weight, we need to set strict
@@ -193,12 +187,13 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
  return model
 
 
- def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
- config = get_3b_model_config(**kwargs)
- model = Llama(config)
- loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
- # Since embedding and lm-head use the same weight, we need to set strict
- # to False.
- loader.load(model, strict=False)
- model.eval()
- return model
+ def build_1b_model(
+ checkpoint_path: str, **kwargs
+ ) -> model_builder.DecoderOnlyModel:
+ return _build_model(checkpoint_path, get_1b_model_config(**kwargs))
+
+
+ def build_3b_model(
+ checkpoint_path: str, **kwargs
+ ) -> model_builder.DecoderOnlyModel:
+ return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
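
Note: both entry points now route through the shared _build_model helper and return a model_builder.DecoderOnlyModel (the Llama subclass). A usage sketch, with placeholder checkpoint paths:

    from ai_edge_torch.generative.examples.llama import llama

    # kv_cache_max_len is forwarded to get_1b_model_config / get_3b_model_config.
    model_1b = llama.build_1b_model("/path/to/llama-3.2-1b", kv_cache_max_len=1024)
    model_3b = llama.build_3b_model("/path/to/llama-3.2-3b", kv_cache_max_len=1024)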

ai_edge_torch/generative/examples/llama/verify.py
@@ -25,7 +25,12 @@ from ai_edge_torch.generative.utilities import transformers_verifier
  from ai_edge_torch.generative.utilities import verifier
  import transformers
 
-
+ _MODEL_SIZE = flags.DEFINE_enum(
+ "model_size",
+ "1b",
+ ["1b", "3b"],
+ "The size of the model to verify.",
+ )
  _PROMPTS = flags.DEFINE_multi_string(
  "prompts",
  "What is the meaning of life?",
@@ -37,9 +42,19 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
  "The maximum size of the generated tokens.",
  )
 
+ _CHECKPOINT = {
+ "1b": "meta-llama/Llama-3.2-1B-Instruct",
+ "3b": "meta-llama/Llama-3.2-3B-Instruct",
+ }
+
+ _BUILDER = {
+ "1b": llama.build_1b_model,
+ "3b": llama.build_3b_model,
+ }
+
 
  def main(_):
- checkpoint = "meta-llama/Llama-3.2-1B-Instruct"
+ checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
  logging.info("Loading the original model from: %s", checkpoint)
  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
@@ -49,7 +64,7 @@ def main(_):
  )
  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
- reauthored_model = llama.build_model(reauthored_checkpoint)
+ reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
 
  logging.info("Loading the tokenizer from: %s", checkpoint)
  # Llama tokenizer_config.json sets a fast tokenizer class explicitly,

ai_edge_torch/generative/examples/openelm/openelm.py
@@ -15,14 +15,9 @@
 
  """Example of building an OpenELM model."""
 
- from ai_edge_torch.generative.layers import attention
- from ai_edge_torch.generative.layers import builder
- from ai_edge_torch.generative.layers import kv_cache as kv_utils
- import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
+ from ai_edge_torch.generative.utilities import model_builder
  import ai_edge_torch.generative.utilities.loader as loading_utils
- import torch
- from torch import nn
 
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  ff_up_proj="transformer.layers.{}.ffn.proj_1",
@@ -39,81 +34,6 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  )
 
 
- class OpenELM(nn.Module):
- """An OpenELM model built from the Edge Generative API layers."""
-
- def __init__(self, config: cfg.ModelConfig):
- super().__init__()
-
- # Construct model layers.
- self.tok_embedding = nn.Embedding(
- config.vocab_size, config.embedding_dim, padding_idx=0
- )
- self.lm_head = nn.Linear(
- config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
- )
- # OpenELM re-uses the embedding as the head projection layer.
- self.lm_head.weight.data = self.tok_embedding.weight.data
- self.transformer_blocks = nn.ModuleList(
- attention.TransformerBlock(config.block_config(idx), config)
- for idx in range(config.num_layers)
- )
- self.final_norm = builder.build_norm(
- config.embedding_dim,
- config.final_norm_config,
- )
- # OpenELM has same hyper parameters for rotary_percentage and head_dim for
- # each layer block. Use the first block.
- attn_config = config.block_config(0).attn_config
- self.rope_cache = attn_utils.build_rope_cache(
- size=config.kv_cache_max,
- dim=int(attn_config.rotary_percentage * attn_config.head_dim),
- base=attn_config.rotary_base,
- )
- self.mask_cache = attn_utils.build_causal_mask_cache(
- size=config.kv_cache_max,
- )
- self.config = config
-
- @torch.inference_mode
- def forward(
- self,
- tokens: torch.Tensor,
- input_pos: torch.Tensor,
- kv_cache: kv_utils.KVCache,
- ) -> dict[torch.Tensor, kv_utils.KVCache]:
- _, seq_len = tokens.size()
- assert self.config.max_seq_len >= seq_len, (
- f"Cannot forward sequence of length {seq_len}, max seq length is only"
- f" {self.config.max_seq_len}"
- )
- assert len(self.transformer_blocks) == len(kv_cache.caches), (
- "The number of transformer blocks and the number of KV cache entries"
- " must be the same."
- )
-
- cos, sin = self.rope_cache
- cos = cos.index_select(0, input_pos)
- sin = sin.index_select(0, input_pos)
- mask = self.mask_cache.index_select(2, input_pos)
- mask = mask[:, :, :, : self.config.kv_cache_max]
-
- # token embeddings of shape (b, t, n_embd)
- x = self.tok_embedding(tokens)
-
- updated_kv_entires = []
- for i, block in enumerate(self.transformer_blocks):
- kv_entry = kv_cache.caches[i] if kv_cache else None
- x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
- if kv_entry:
- updated_kv_entires.append(kv_entry)
- updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
-
- x = self.final_norm(x)
- logits = self.lm_head(x) # (b, t, vocab_size)
- return {"logits": logits, "kv_cache": updated_kv_cache}
-
-
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  """Returns the model config for an OpenELM model.
 
@@ -191,12 +111,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
  return config
 
 
- def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
- config = get_model_config(**kwargs)
- model = OpenELM(config)
- loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
- # Since embedding and lm-head use the same weight, we need to set strict
- # to False.
- loader.load(model, strict=False)
- model.eval()
- return model
+ def build_model(
+ checkpoint_path: str, **kwargs
+ ) -> model_builder.DecoderOnlyModel:
+ return model_builder.build_decoder_only_model(
+ checkpoint_path=checkpoint_path,
+ config=get_model_config(**kwargs),
+ tensor_names=TENSOR_NAMES,
+ )

ai_edge_torch/generative/examples/phi/phi2.py
@@ -15,14 +15,9 @@
 
  """Example of building a Phi-2 model."""
 
- from ai_edge_torch.generative.layers import attention
- from ai_edge_torch.generative.layers import builder
- from ai_edge_torch.generative.layers import kv_cache as kv_utils
- import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
+ from ai_edge_torch.generative.utilities import model_builder
  import ai_edge_torch.generative.utilities.loader as loading_utils
- import torch
- from torch import nn
 
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  ff_up_proj="model.layers.{}.mlp.fc1",
@@ -38,78 +33,6 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  )
 
 
- class Phi2(nn.Module):
- """A Phi-2 model built from the Edge Generative API layers."""
-
- def __init__(self, config: cfg.ModelConfig):
- super().__init__()
-
- # Construct model layers.
- self.lm_head = nn.Linear(
- config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
- )
- self.tok_embedding = nn.Embedding(
- config.vocab_size, config.embedding_dim, padding_idx=0
- )
- # Phi-2 has only one block config.
- block_config = config.block_config(0)
- self.transformer_blocks = nn.ModuleList(
- attention.TransformerBlock(block_config, config)
- for _ in range(config.num_layers)
- )
- self.final_norm = builder.build_norm(
- config.embedding_dim,
- config.final_norm_config,
- )
- attn_config = block_config.attn_config
- self.rope_cache = attn_utils.build_rope_cache(
- size=config.kv_cache_max,
- dim=int(attn_config.rotary_percentage * attn_config.head_dim),
- base=attn_config.rotary_base,
- )
- self.mask_cache = attn_utils.build_causal_mask_cache(
- size=config.kv_cache_max,
- )
- self.config = config
-
- @torch.inference_mode
- def forward(
- self,
- tokens: torch.Tensor,
- input_pos: torch.Tensor,
- kv_cache: kv_utils.KVCache,
- ) -> dict[torch.Tensor, kv_utils.KVCache]:
- _, seq_len = tokens.size()
- assert self.config.max_seq_len >= seq_len, (
- f"Cannot forward sequence of length {seq_len}, max seq length is only"
- f" {self.config.max_seq_len}"
- )
- assert len(self.transformer_blocks) == len(kv_cache.caches), (
- "The number of transformer blocks and the number of KV cache entries"
- " must be the same."
- )
-
- cos, sin = self.rope_cache
- cos = cos.index_select(0, input_pos)
- sin = sin.index_select(0, input_pos)
- mask = self.mask_cache.index_select(2, input_pos)
- mask = mask[:, :, :, : self.config.kv_cache_max]
-
- x = self.tok_embedding(tokens)
-
- updated_kv_entires = []
- for i, block in enumerate(self.transformer_blocks):
- kv_entry = kv_cache.caches[i] if kv_cache else None
- x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
- if kv_entry:
- updated_kv_entires.append(kv_entry)
- updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
-
- x = self.final_norm(x)
- logits = self.lm_head(x) # (b, t, vocab_size)
- return {"logits": logits, "kv_cache": updated_kv_cache}
-
-
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  """Returns the model config for a Phi-2 model.
 
@@ -154,6 +77,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  block_configs=block_config,
  final_norm_config=norm_config,
  lm_head_use_bias=True,
+ lm_head_share_weight_with_embedding=False,
  enable_hlfb=True,
  )
  return config
@@ -169,11 +93,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
  return config
 
 
- def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
- """Instantiates the model instance and load checkpoint if provided."""
- config = get_model_config(**kwargs)
- model = Phi2(config)
- loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
- loader.load(model)
- model.eval()
- return model
+ def build_model(
+ checkpoint_path: str, **kwargs
+ ) -> model_builder.DecoderOnlyModel:
+ return model_builder.build_decoder_only_model(
+ checkpoint_path=checkpoint_path,
+ config=get_model_config(**kwargs),
+ tensor_names=TENSOR_NAMES,
+ )

ai_edge_torch/generative/examples/phi/phi3.py
@@ -18,14 +18,10 @@
  import math
  from typing import Tuple
 
- from ai_edge_torch.generative.layers import attention
- from ai_edge_torch.generative.layers import builder
- from ai_edge_torch.generative.layers import kv_cache as kv_utils
- import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
+ from ai_edge_torch.generative.utilities import model_builder
  import ai_edge_torch.generative.utilities.loader as loading_utils
  import torch
- from torch import nn
 
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  ff_up_proj="model.layers.{}.mlp.gate_up_proj",
@@ -137,32 +133,14 @@ def _build_rope_cache(
  return cos, sin
 
 
- class Phi3_5Mini(nn.Module):
+ class Phi3_5Mini(model_builder.DecoderOnlyModel):
  """A Phi-3.5 model built from the Edge Generative API layers."""
 
  def __init__(self, config: cfg.ModelConfig):
- super().__init__()
-
- # Construct model layers.
- self.lm_head = nn.Linear(
- config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
- )
- self.tok_embedding = nn.Embedding(
- config.vocab_size, config.embedding_dim, padding_idx=0
- )
- # Phi-3.5 has only one block config.
- block_config = config.block_config(0)
- self.transformer_blocks = nn.ModuleList(
- attention.TransformerBlock(block_config, config)
- for _ in range(config.num_layers)
- )
- self.final_norm = builder.build_norm(
- config.embedding_dim,
- config.final_norm_config,
- )
- attn_config = block_config.attn_config
+ super().__init__(config)
+ attn_config = self.config.block_config(0).attn_config
  self.rope_cache = _build_rope_cache(
- size=config.kv_cache_max,
+ size=self.config.kv_cache_max,
  dim=int(attn_config.rotary_percentage * attn_config.head_dim),
  base=attn_config.rotary_base,
  condense_ratio=1,
@@ -173,47 +151,6 @@ class Phi3_5Mini(nn.Module):
  1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
  ),
  )
- self.mask_cache = attn_utils.build_causal_mask_cache(
- size=config.kv_cache_max,
- )
- self.config = config
-
- @torch.inference_mode
- def forward(
- self,
- tokens: torch.Tensor,
- input_pos: torch.Tensor,
- kv_cache: kv_utils.KVCache,
- ) -> dict[torch.Tensor, kv_utils.KVCache]:
- _, seq_len = tokens.size()
- assert self.config.max_seq_len >= seq_len, (
- f"Cannot forward sequence of length {seq_len}, max seq length is only"
- f" {self.config.max_seq_len}"
- )
- assert len(self.transformer_blocks) == len(kv_cache.caches), (
- "The number of transformer blocks and the number of KV cache entries"
- " must be the same."
- )
-
- cos, sin = self.rope_cache
- cos = cos.index_select(0, input_pos)
- sin = sin.index_select(0, input_pos)
- mask = self.mask_cache.index_select(2, input_pos)
- mask = mask[:, :, :, : self.config.kv_cache_max]
-
- x = self.tok_embedding(tokens)
-
- updated_kv_entires = []
- for i, block in enumerate(self.transformer_blocks):
- kv_entry = kv_cache.caches[i] if kv_cache else None
- x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
- if kv_entry:
- updated_kv_entires.append(kv_entry)
- updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
-
- x = self.final_norm(x)
- logits = self.lm_head(x) # (b, t, vocab_size)
- return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -254,6 +191,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  embedding_dim=3072,
  block_configs=block_config,
  final_norm_config=norm_config,
+ lm_head_share_weight_with_embedding=False,
  enable_hlfb=True,
  )
  return config
@@ -269,7 +207,9 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
  return config
 
 
- def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+ def build_model(
+ checkpoint_path: str, **kwargs
+ ) -> model_builder.DecoderOnlyModel:
  """Instantiates the model instance and load checkpoint if provided."""
  config = get_model_config(**kwargs)
  model = Phi3_5Mini(config)