ai-edge-torch-nightly 0.3.0.dev20241002__py3-none-any.whl → 0.3.0.dev20241003__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (24)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +10 -93
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +0 -1
  3. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +13 -2
  4. ai_edge_torch/generative/examples/llama/llama.py +19 -24
  5. ai_edge_torch/generative/examples/llama/verify.py +18 -3
  6. ai_edge_torch/generative/examples/openelm/openelm.py +9 -90
  7. ai_edge_torch/generative/examples/phi/phi2.py +10 -86
  8. ai_edge_torch/generative/examples/phi/phi3.py +9 -69
  9. ai_edge_torch/generative/examples/qwen/qwen.py +26 -36
  10. ai_edge_torch/generative/examples/smollm/smollm.py +10 -30
  11. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +11 -101
  12. ai_edge_torch/generative/layers/model_config.py +6 -0
  13. ai_edge_torch/generative/test/test_loader.py +2 -1
  14. ai_edge_torch/generative/test/test_model_conversion.py +2 -1
  15. ai_edge_torch/generative/test/test_model_conversion_large.py +6 -5
  16. ai_edge_torch/generative/utilities/model_builder.py +141 -0
  17. ai_edge_torch/version.py +1 -1
  18. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/METADATA +1 -1
  19. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/RECORD +22 -23
  20. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +0 -68
  21. ai_edge_torch/generative/examples/llama/verify_3b.py +0 -73
  22. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/LICENSE +0 -0
  23. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/WHEEL +0 -0
  24. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/gemma1.py

@@ -15,14 +15,9 @@
 
 """Example of building a Gemma1 model."""
 
-from ai_edge_torch.generative.layers import attention
-from ai_edge_torch.generative.layers import builder
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import torch
-from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -38,84 +33,6 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 )
 
 
-class Gemma(nn.Module):
-  """A Gemma model built from the Edge Generative API layers."""
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__()
-
-    # Construct model layers.
-    self.tok_embedding = nn.Embedding(
-        config.vocab_size, config.embedding_dim, padding_idx=0
-    )
-    self.lm_head = nn.Linear(
-        config.embedding_dim,
-        config.vocab_size,
-        bias=config.lm_head_use_bias,
-    )
-    # Gemma re-uses the embedding as the head projection layer.
-    self.lm_head.weight.data = self.tok_embedding.weight.data
-    # Gemma has only one block config.
-    block_config = config.block_config(0)
-    self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(block_config, config)
-        for _ in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    attn_config = block_config.attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
-    self.config = config
-
-  @torch.inference_mode
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    _, seq_len = tokens.size()
-    assert self.config.max_seq_len >= seq_len, (
-        f"Cannot forward sequence of length {seq_len}, max seq length is only"
-        f" {self.config.max_seq_len}"
-    )
-    assert len(self.transformer_blocks) == len(kv_cache.caches), (
-        "The number of transformer blocks and the number of KV cache entries"
-        " must be the same."
-    )
-
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-    x = x * (self.config.embedding_dim**0.5)
-
-    updated_kv_entires = []
-    for i, block in enumerate(self.transformer_blocks):
-      kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-      if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
-
-    x = self.final_norm(x)
-    logits = self.lm_head(x)  # (b, t, vocab_size)
-    return {"logits": logits, "kv_cache": updated_kv_cache}
-
-
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a Gemma 2B model.
 
@@ -154,6 +71,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=18,
       max_seq_len=8192,
       embedding_dim=2048,
+      embedding_scale=2048**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
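
Note: the deleted Gemma forward() above scaled the token embeddings inline (x = x * embedding_dim**0.5); with the shared DecoderOnlyModel that scaling is now driven by the new embedding_scale config field (one of the six lines added to model_config.py in this release). A minimal sketch of the idea, assuming the shared model reads the field in its forward pass; the actual handling lives in the new model_builder.py, which this diff does not show:

    # Hypothetical sketch, not the actual model_builder.py code.
    x = self.tok_embedding(tokens)
    if self.config.embedding_scale is not None:
      x = x * self.config.embedding_scale  # Gemma1 sets this to 2048**0.5.
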
@@ -173,12 +91,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  config = get_model_config_2b(**kwargs)
-  model = Gemma(config)
-  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  # Since embedding and lm-head use the same weight, we need to set strict
-  # to False.
-  loader.load(model, strict=False)
-  model.eval()
-  return model
+def build_2b_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config_2b(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
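
The same hand-written class and builder boilerplate is deleted from the OpenELM, Phi-2, and Phi-3.5 examples below (and, judging by the file list above, from SmolLM, Qwen, and TinyLlama as well) in favor of the new shared utility ai_edge_torch/generative/utilities/model_builder.py (+141 lines, not included in this diff). Based on the per-model code it replaces, the shared builder presumably looks roughly like the following; treat this as a sketch, not the actual contents of model_builder.py:

    # Hedged sketch of the new shared helper, inferred from the removed
    # per-model builders; the real implementation is in model_builder.py.
    def build_decoder_only_model(
        checkpoint_path: str,
        config: cfg.ModelConfig,
        tensor_names: loading_utils.ModelLoader.TensorNames,
    ) -> DecoderOnlyModel:
      model = DecoderOnlyModel(config)
      loader = loading_utils.ModelLoader(checkpoint_path, tensor_names)
      # strict=False when the LM head shares its weight with the embedding,
      # since the checkpoint then has no separate lm_head tensor.
      loader.load(model, strict=not config.lm_head_share_weight_with_embedding)
      model.eval()
      return model
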
ai_edge_torch/generative/examples/gemma/gemma2.py

@@ -15,7 +15,6 @@
 
 """Example of building a Gemma2 model."""
 
-import os
 from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
ai_edge_torch/generative/examples/llama/convert_to_tflite.py

@@ -23,6 +23,12 @@ from absl import flags
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.utilities import converter
 
+_MODEL_SIZE = flags.DEFINE_enum(
+    'model_size',
+    '1b',
+    ['1b', '3b'],
+    'The size of the model to verify.',
+)
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
@@ -49,13 +55,18 @@ _QUANTIZE = flags.DEFINE_bool(
     'Whether the model should be quantized.',
 )
 
+_BUILDER = {
+    '1b': llama.build_1b_model,
+    '3b': llama.build_3b_model,
+}
+
 
 def main(_):
-  pytorch_model = llama.build_model(
+  pytorch_model = _BUILDER[_MODEL_SIZE.value](
      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'llama_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
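
With the separate convert_3b_to_tflite.py script removed (file 20 in the list above), both Llama 3.2 sizes are now converted through this one script, selected via the new model_size flag. A hypothetical invocation; only model_size, checkpoint_path, and quantize are confirmed by this hunk, and the remaining names are inferred from the flag variables it references:

    python convert_to_tflite.py \
        --model_size=3b \
        --checkpoint_path=/path/to/llama_3.2_3b \
        --quantize=true
    # Expected output name per the new pattern:
    # llama_3b_q8_seq<prefill_seq_len>_ekv<kv_cache_max_len>.tflite
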
ai_edge_torch/generative/examples/llama/llama.py

@@ -15,19 +15,15 @@
 
 """Example of building Llama 3.2 models."""
 
-import copy
 import math
 from typing import Tuple
 
-from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
-from torch import nn
 
-TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
-# SmolLM re-uses the embedding as the head projection layer.
-TENSOR_NAMES.lm_head = None
+TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
 def _build_llama3_rope_cache(
@@ -93,7 +89,7 @@ def _build_llama3_rope_cache(
   return cos, sin
 
 
-class Llama(tiny_llama.TinyLlama):
+class Llama(model_builder.DecoderOnlyModel):
   """A Llama model built from the Edge Generative API layers.
 
   Llama 3.2 shares the same architecture as TinyLlama except ROPE calculation.
@@ -101,9 +97,6 @@ class Llama(tiny_llama.TinyLlama):
 
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
-    # Llama 3.2 re-uses the embedding as the head projection layer.
-    self.lm_head.weight.data = self.tok_embedding.weight.data
-    # Llama has only one block config.
     attn_config = self.config.block_config(0).attn_config
     self.rope_cache = _build_llama3_rope_cache(
         size=self.config.kv_cache_max,
@@ -119,7 +112,7 @@ class Llama(tiny_llama.TinyLlama):
     )
 
 
-def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a Llama 3.2-1B model.
 
   Args:
@@ -163,7 +156,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a Llama 3.2-3B model."""
-  config = get_model_config(kv_cache_max_len)
+  config = get_1b_model_config(kv_cache_max_len)
   # Llama 3.2 has only one block config.
   attn_config = config.block_config(0).attn_config
   attn_config.num_heads = 24
@@ -174,7 +167,7 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 
 def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
-  config = get_model_config(**kwargs)
+  config = get_1b_model_config(**kwargs)
   config.vocab_size = 128
   config.num_layers = 2
   # SmolLM has only one block config.
@@ -182,8 +175,9 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  config = get_model_config(**kwargs)
+def _build_model(
+    checkpoint_path: str, config: cfg.ModelConfig
+) -> model_builder.DecoderOnlyModel:
   model = Llama(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
   # Since embedding and lm-head use the same weight, we need to set strict
@@ -193,12 +187,13 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model
 
 
-def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  config = get_3b_model_config(**kwargs)
-  model = Llama(config)
-  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  # Since embedding and lm-head use the same weight, we need to set strict
-  # to False.
-  loader.load(model, strict=False)
-  model.eval()
-  return model
+def build_1b_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return _build_model(checkpoint_path, get_1b_model_config(**kwargs))
+
+
+def build_3b_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
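
Llama keeps its own subclass because it overrides the ROPE cache, so it routes through a private _build_model helper rather than model_builder.build_decoder_only_model; the old build_model entry point is replaced by explicitly named per-size builders. Example call sites, mirroring the _BUILDER dict added in convert_to_tflite.py (checkpoint paths are placeholders):

    model_1b = llama.build_1b_model("/path/to/llama_3.2_1b", kv_cache_max_len=1024)
    model_3b = llama.build_3b_model("/path/to/llama_3.2_3b", kv_cache_max_len=1024)
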
ai_edge_torch/generative/examples/llama/verify.py

@@ -25,7 +25,12 @@ from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
-
+_MODEL_SIZE = flags.DEFINE_enum(
+    "model_size",
+    "1b",
+    ["1b", "3b"],
+    "The size of the model to verify.",
+)
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "What is the meaning of life?",
@@ -37,9 +42,19 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
     "The maximum size of the generated tokens.",
 )
 
+_CHECKPOINT = {
+    "1b": "meta-llama/Llama-3.2-1B-Instruct",
+    "3b": "meta-llama/Llama-3.2-3B-Instruct",
+}
+
+_BUILDER = {
+    "1b": llama.build_1b_model,
+    "3b": llama.build_3b_model,
+}
+
 
 def main(_):
-  checkpoint = "meta-llama/Llama-3.2-1B-Instruct"
+  checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
@@ -49,7 +64,7 @@ def main(_):
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = llama.build_model(reauthored_checkpoint)
+  reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
 
   logging.info("Loading the tokenizer from: %s", checkpoint)
   # Llama tokenizer_config.json sets a fast tokenizer class explicitly,
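
The standalone verify_3b.py script is removed in the same way (file 21 above); both sizes are verified from this single script. A hypothetical invocation, using only the flags shown in this diff (model_size, prompts, max_new_tokens):

    python verify.py --model_size=3b --prompts="What is the meaning of life?"
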
ai_edge_torch/generative/examples/openelm/openelm.py

@@ -15,14 +15,9 @@
 
 """Example of building an OpenELM model."""
 
-from ai_edge_torch.generative.layers import attention
-from ai_edge_torch.generative.layers import builder
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import torch
-from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="transformer.layers.{}.ffn.proj_1",
@@ -39,81 +34,6 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 )
 
 
-class OpenELM(nn.Module):
-  """An OpenELM model built from the Edge Generative API layers."""
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__()
-
-    # Construct model layers.
-    self.tok_embedding = nn.Embedding(
-        config.vocab_size, config.embedding_dim, padding_idx=0
-    )
-    self.lm_head = nn.Linear(
-        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
-    )
-    # OpenELM re-uses the embedding as the head projection layer.
-    self.lm_head.weight.data = self.tok_embedding.weight.data
-    self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(config.block_config(idx), config)
-        for idx in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    # OpenELM has same hyper parameters for rotary_percentage and head_dim for
-    # each layer block. Use the first block.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
-    self.config = config
-
-  @torch.inference_mode
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    _, seq_len = tokens.size()
-    assert self.config.max_seq_len >= seq_len, (
-        f"Cannot forward sequence of length {seq_len}, max seq length is only"
-        f" {self.config.max_seq_len}"
-    )
-    assert len(self.transformer_blocks) == len(kv_cache.caches), (
-        "The number of transformer blocks and the number of KV cache entries"
-        " must be the same."
-    )
-
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-
-    updated_kv_entires = []
-    for i, block in enumerate(self.transformer_blocks):
-      kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-      if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
-
-    x = self.final_norm(x)
-    logits = self.lm_head(x)  # (b, t, vocab_size)
-    return {"logits": logits, "kv_cache": updated_kv_cache}
-
-
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for an OpenELM model.
 
@@ -191,12 +111,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  config = get_model_config(**kwargs)
-  model = OpenELM(config)
-  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  # Since embedding and lm-head use the same weight, we need to set strict
-  # to False.
-  loader.load(model, strict=False)
-  model.eval()
-  return model
+def build_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
ai_edge_torch/generative/examples/phi/phi2.py

@@ -15,14 +15,9 @@
 
 """Example of building a Phi-2 model."""
 
-from ai_edge_torch.generative.layers import attention
-from ai_edge_torch.generative.layers import builder
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import torch
-from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.fc1",
@@ -38,78 +33,6 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 )
 
 
-class Phi2(nn.Module):
-  """A Phi-2 model built from the Edge Generative API layers."""
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__()
-
-    # Construct model layers.
-    self.lm_head = nn.Linear(
-        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
-    )
-    self.tok_embedding = nn.Embedding(
-        config.vocab_size, config.embedding_dim, padding_idx=0
-    )
-    # Phi-2 has only one block config.
-    block_config = config.block_config(0)
-    self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(block_config, config)
-        for _ in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    attn_config = block_config.attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
-    self.config = config
-
-  @torch.inference_mode
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    _, seq_len = tokens.size()
-    assert self.config.max_seq_len >= seq_len, (
-        f"Cannot forward sequence of length {seq_len}, max seq length is only"
-        f" {self.config.max_seq_len}"
-    )
-    assert len(self.transformer_blocks) == len(kv_cache.caches), (
-        "The number of transformer blocks and the number of KV cache entries"
-        " must be the same."
-    )
-
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
-
-    x = self.tok_embedding(tokens)
-
-    updated_kv_entires = []
-    for i, block in enumerate(self.transformer_blocks):
-      kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-      if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
-
-    x = self.final_norm(x)
-    logits = self.lm_head(x)  # (b, t, vocab_size)
-    return {"logits": logits, "kv_cache": updated_kv_cache}
-
-
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a Phi-2 model.
 
@@ -154,6 +77,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_use_bias=True,
+      lm_head_share_weight_with_embedding=False,
      enable_hlfb=True,
   )
   return config
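
In the deleted per-model classes, weight tying was expressed in code (Gemma, OpenELM, and Llama copy tok_embedding.weight into lm_head) while Phi-2 keeps a separate output projection. With one shared DecoderOnlyModel, that distinction moves into the config: the new lm_head_share_weight_with_embedding field presumably defaults to tied weights, so Phi-2 and Phi-3.5 opt out explicitly here. A hedged sketch of the idea; the actual handling is in the unshown model_builder.py:

    # Hypothetical sketch: tie the output projection to the embedding only
    # when the config asks for it.
    if config.lm_head_share_weight_with_embedding:
      self.lm_head.weight.data = self.tok_embedding.weight.data
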
@@ -169,11 +93,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  """Instantiates the model instance and load checkpoint if provided."""
-  config = get_model_config(**kwargs)
-  model = Phi2(config)
-  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  loader.load(model)
-  model.eval()
-  return model
+def build_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
ai_edge_torch/generative/examples/phi/phi3.py

@@ -18,14 +18,10 @@
 import math
 from typing import Tuple
 
-from ai_edge_torch.generative.layers import attention
-from ai_edge_torch.generative.layers import builder
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
-from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.gate_up_proj",
@@ -137,32 +133,14 @@ def _build_rope_cache(
   return cos, sin
 
 
-class Phi3_5Mini(nn.Module):
+class Phi3_5Mini(model_builder.DecoderOnlyModel):
   """A Phi-3.5 model built from the Edge Generative API layers."""
 
   def __init__(self, config: cfg.ModelConfig):
-    super().__init__()
-
-    # Construct model layers.
-    self.lm_head = nn.Linear(
-        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
-    )
-    self.tok_embedding = nn.Embedding(
-        config.vocab_size, config.embedding_dim, padding_idx=0
-    )
-    # Phi-3.5 has only one block config.
-    block_config = config.block_config(0)
-    self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(block_config, config)
-        for _ in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    attn_config = block_config.attn_config
+    super().__init__(config)
+    attn_config = self.config.block_config(0).attn_config
     self.rope_cache = _build_rope_cache(
-        size=config.kv_cache_max,
+        size=self.config.kv_cache_max,
        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
        base=attn_config.rotary_base,
        condense_ratio=1,
@@ -173,47 +151,6 @@ class Phi3_5Mini(nn.Module):
            1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
        ),
    )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
-    self.config = config
-
-  @torch.inference_mode
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    _, seq_len = tokens.size()
-    assert self.config.max_seq_len >= seq_len, (
-        f"Cannot forward sequence of length {seq_len}, max seq length is only"
-        f" {self.config.max_seq_len}"
-    )
-    assert len(self.transformer_blocks) == len(kv_cache.caches), (
-        "The number of transformer blocks and the number of KV cache entries"
-        " must be the same."
-    )
-
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
-
-    x = self.tok_embedding(tokens)
-
-    updated_kv_entires = []
-    for i, block in enumerate(self.transformer_blocks):
-      kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-      if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
-
-    x = self.final_norm(x)
-    logits = self.lm_head(x)  # (b, t, vocab_size)
-    return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -254,6 +191,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       embedding_dim=3072,
       block_configs=block_config,
       final_norm_config=norm_config,
+      lm_head_share_weight_with_embedding=False,
      enable_hlfb=True,
   )
   return config
@@ -269,7 +207,9 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
   """Instantiates the model instance and load checkpoint if provided."""
   config = get_model_config(**kwargs)
   model = Phi3_5Mini(config)