ai-edge-torch-nightly 0.3.0.dev20241002__py3-none-any.whl → 0.3.0.dev20241003__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/gemma1.py +10 -93
- ai_edge_torch/generative/examples/gemma/gemma2.py +0 -1
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +13 -2
- ai_edge_torch/generative/examples/llama/llama.py +19 -24
- ai_edge_torch/generative/examples/llama/verify.py +18 -3
- ai_edge_torch/generative/examples/openelm/openelm.py +9 -90
- ai_edge_torch/generative/examples/phi/phi2.py +10 -86
- ai_edge_torch/generative/examples/phi/phi3.py +9 -69
- ai_edge_torch/generative/examples/qwen/qwen.py +26 -36
- ai_edge_torch/generative/examples/smollm/smollm.py +10 -30
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +11 -101
- ai_edge_torch/generative/layers/model_config.py +6 -0
- ai_edge_torch/generative/test/test_loader.py +2 -1
- ai_edge_torch/generative/test/test_model_conversion.py +2 -1
- ai_edge_torch/generative/test/test_model_conversion_large.py +6 -5
- ai_edge_torch/generative/utilities/model_builder.py +141 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/RECORD +22 -23
- ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +0 -68
- ai_edge_torch/generative/examples/llama/verify_3b.py +0 -73
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241003.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/gemma1.py
@@ -15,14 +15,9 @@
 
 """Example of building a Gemma1 model."""
 
-from ai_edge_torch.generative.layers import attention
-from ai_edge_torch.generative.layers import builder
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import torch
-from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -38,84 +33,6 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 )
 
 
-class Gemma(nn.Module):
-  """A Gemma model built from the Edge Generative API layers."""
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__()
-
-    # Construct model layers.
-    self.tok_embedding = nn.Embedding(
-        config.vocab_size, config.embedding_dim, padding_idx=0
-    )
-    self.lm_head = nn.Linear(
-        config.embedding_dim,
-        config.vocab_size,
-        bias=config.lm_head_use_bias,
-    )
-    # Gemma re-uses the embedding as the head projection layer.
-    self.lm_head.weight.data = self.tok_embedding.weight.data
-    # Gemma has only one block config.
-    block_config = config.block_config(0)
-    self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(block_config, config)
-        for _ in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    attn_config = block_config.attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
-    self.config = config
-
-  @torch.inference_mode
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    _, seq_len = tokens.size()
-    assert self.config.max_seq_len >= seq_len, (
-        f"Cannot forward sequence of length {seq_len}, max seq length is only"
-        f" {self.config.max_seq_len}"
-    )
-    assert len(self.transformer_blocks) == len(kv_cache.caches), (
-        "The number of transformer blocks and the number of KV cache entries"
-        " must be the same."
-    )
-
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-    x = x * (self.config.embedding_dim**0.5)
-
-    updated_kv_entires = []
-    for i, block in enumerate(self.transformer_blocks):
-      kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-      if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
-
-    x = self.final_norm(x)
-    logits = self.lm_head(x)  # (b, t, vocab_size)
-    return {"logits": logits, "kv_cache": updated_kv_cache}
-
-
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a Gemma 2B model.
 
@@ -154,6 +71,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=18,
       max_seq_len=8192,
       embedding_dim=2048,
+      embedding_scale=2048**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
@@ -173,12 +91,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_2b_model(
-
-
-
-
-
-
-
-  return model
+def build_2b_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config_2b(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
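The boilerplate removed here (token embedding, lm_head, transformer blocks, RoPE and mask caches, and the shared forward pass) moves into the new ai_edge_torch/generative/utilities/model_builder.py (+141 lines), whose contents this diff does not render. A minimal usage sketch of the refactored entry point, assuming the consolidated DecoderOnlyModel keeps the same forward signature as the removed Gemma.forward above; the checkpoint path is a placeholder:

import torch
from ai_edge_torch.generative.examples.gemma import gemma1
from ai_edge_torch.generative.layers import kv_cache as kv_utils

# build_2b_model now returns a model_builder.DecoderOnlyModel loaded from the
# checkpoint (placeholder path below).
model = gemma1.build_2b_model("/tmp/gemma-2b", kv_cache_max_len=1024)

# Assumed to mirror the removed forward(): token ids, their positions, and a
# KV cache go in; logits and the updated cache come out.
tokens = torch.randint(1, 100, (1, 8), dtype=torch.int)
input_pos = torch.arange(0, 8, dtype=torch.int)
kv = kv_utils.KVCache.from_model_config(model.config)
output = model.forward(tokens, input_pos, kv)
logits, kv = output["logits"], output["kv_cache"]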
ai_edge_torch/generative/examples/llama/convert_to_tflite.py
@@ -23,6 +23,12 @@ from absl import flags
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.utilities import converter
 
+_MODEL_SIZE = flags.DEFINE_enum(
+    'model_size',
+    '1b',
+    ['1b', '3b'],
+    'The size of the model to verify.',
+)
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
@@ -49,13 +55,18 @@ _QUANTIZE = flags.DEFINE_bool(
     'Whether the model should be quantized.',
 )
 
+_BUILDER = {
+    '1b': llama.build_1b_model,
+    '3b': llama.build_3b_model,
+}
+
 
 def main(_):
-  pytorch_model =
+  pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'llama_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
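With the new model_size flag, one conversion script covers both Llama 3.2 sizes (the separate convert_3b_to_tflite.py is deleted in this release, per the file list above), and the output filename now embeds the size. Roughly what the flag dispatch resolves to for the 3B variant; the checkpoint path and flag values below are placeholders:

from ai_edge_torch.generative.examples.llama import llama

# Equivalent of passing --model_size=3b to the script above.
pytorch_model = llama.build_3b_model(
    '/tmp/llama_3b_ckpt', kv_cache_max_len=1280
)
# The resulting file then follows the updated naming pattern:
# llama_3b_<quant_suffix>_seq<prefill_len>_ekv<kv_len>.tflite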
ai_edge_torch/generative/examples/llama/llama.py
@@ -15,19 +15,15 @@
 
 """Example of building Llama 3.2 models."""
 
-import copy
 import math
 from typing import Tuple
 
-from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
-from torch import nn
 
-TENSOR_NAMES =
-# SmolLM re-uses the embedding as the head projection layer.
-TENSOR_NAMES.lm_head = None
+TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
 def _build_llama3_rope_cache(
@@ -93,7 +89,7 @@ def _build_llama3_rope_cache(
   return cos, sin
 
 
-class Llama(
+class Llama(model_builder.DecoderOnlyModel):
   """A Llama model built from the Edge Generative API layers.
 
   Llama 3.2 shares the same architecture as TinyLlama except ROPE calculation.
@@ -101,9 +97,6 @@ class Llama(tiny_llama.TinyLlama):
 
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
-    # Llama 3.2 re-uses the embedding as the head projection layer.
-    self.lm_head.weight.data = self.tok_embedding.weight.data
-    # Llama has only one block config.
     attn_config = self.config.block_config(0).attn_config
     self.rope_cache = _build_llama3_rope_cache(
         size=self.config.kv_cache_max,
@@ -119,7 +112,7 @@ class Llama(tiny_llama.TinyLlama):
     )
 
 
-def
+def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a Llama 3.2-1B model.
 
   Args:
@@ -163,7 +156,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a Llama 3.2-3B model."""
-  config =
+  config = get_1b_model_config(kv_cache_max_len)
   # Llama 3.2 has only one block config.
   attn_config = config.block_config(0).attn_config
   attn_config.num_heads = 24
@@ -174,7 +167,7 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 
 def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
-  config =
+  config = get_1b_model_config(**kwargs)
   config.vocab_size = 128
   config.num_layers = 2
   # SmolLM has only one block config.
@@ -182,8 +175,9 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def
-
+def _build_model(
+    checkpoint_path: str, config: cfg.ModelConfig
+) -> model_builder.DecoderOnlyModel:
   model = Llama(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
   # Since embedding and lm-head use the same weight, we need to set strict
@@ -193,12 +187,13 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model
 
 
-def
-
-
-
-
-
-
-
-
+def build_1b_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return _build_model(checkpoint_path, get_1b_model_config(**kwargs))
+
+
+def build_3b_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
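After this rework, Llama only overrides RoPE-cache construction on top of model_builder.DecoderOnlyModel, and the 1B and 3B variants differ only in their ModelConfig. A short sketch of the resulting public API; the checkpoint paths are placeholders:

from ai_edge_torch.generative.examples.llama import llama

# Both entry points go through the shared _build_model() helper.
llama_1b = llama.build_1b_model("/tmp/llama_1b_ckpt", kv_cache_max_len=1024)
llama_3b = llama.build_3b_model("/tmp/llama_3b_ckpt", kv_cache_max_len=1024)

# The 3B config is derived from the 1B config and widened, e.g.
# attn_config.num_heads = 24 as shown in the hunk above.
config_3b = llama.get_3b_model_config(kv_cache_max_len=1024)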
ai_edge_torch/generative/examples/llama/verify.py
@@ -25,7 +25,12 @@ from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
-
+_MODEL_SIZE = flags.DEFINE_enum(
+    "model_size",
+    "1b",
+    ["1b", "3b"],
+    "The size of the model to verify.",
+)
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "What is the meaning of life?",
@@ -37,9 +42,19 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
     "The maximum size of the generated tokens.",
 )
 
+_CHECKPOINT = {
+    "1b": "meta-llama/Llama-3.2-1B-Instruct",
+    "3b": "meta-llama/Llama-3.2-3B-Instruct",
+}
+
+_BUILDER = {
+    "1b": llama.build_1b_model,
+    "3b": llama.build_3b_model,
+}
+
 
 def main(_):
-  checkpoint =
+  checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
@@ -49,7 +64,7 @@ def main(_):
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model =
+  reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
 
   logging.info("Loading the tokenizer from: %s", checkpoint)
   # Llama tokenizer_config.json sets a fast tokenizer class explicitly,
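In verify.py the same size flag now selects both the upstream Hugging Face checkpoint and the matching reauthored builder, replacing the separate verify_3b.py. An illustrative resolution for the 3B case; the local cache path stands in for pathlib.Path(cached_config_file).parent in the script:

from ai_edge_torch.generative.examples.llama import llama

model_size = "3b"
checkpoint = {
    "1b": "meta-llama/Llama-3.2-1B-Instruct",
    "3b": "meta-llama/Llama-3.2-3B-Instruct",
}[model_size]
builder = {"1b": llama.build_1b_model, "3b": llama.build_3b_model}[model_size]
# The original model is loaded from `checkpoint` via transformers, while the
# reauthored model is built from the locally cached checkpoint directory.
reauthored_model = builder("/tmp/hf_cache/Llama-3.2-3B-Instruct")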
ai_edge_torch/generative/examples/openelm/openelm.py
@@ -15,14 +15,9 @@
 
 """Example of building an OpenELM model."""
 
-from ai_edge_torch.generative.layers import attention
-from ai_edge_torch.generative.layers import builder
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import torch
-from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="transformer.layers.{}.ffn.proj_1",
@@ -39,81 +34,6 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 )
 
 
-class OpenELM(nn.Module):
-  """An OpenELM model built from the Edge Generative API layers."""
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__()
-
-    # Construct model layers.
-    self.tok_embedding = nn.Embedding(
-        config.vocab_size, config.embedding_dim, padding_idx=0
-    )
-    self.lm_head = nn.Linear(
-        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
-    )
-    # OpenELM re-uses the embedding as the head projection layer.
-    self.lm_head.weight.data = self.tok_embedding.weight.data
-    self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(config.block_config(idx), config)
-        for idx in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    # OpenELM has same hyper parameters for rotary_percentage and head_dim for
-    # each layer block. Use the first block.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
-    self.config = config
-
-  @torch.inference_mode
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    _, seq_len = tokens.size()
-    assert self.config.max_seq_len >= seq_len, (
-        f"Cannot forward sequence of length {seq_len}, max seq length is only"
-        f" {self.config.max_seq_len}"
-    )
-    assert len(self.transformer_blocks) == len(kv_cache.caches), (
-        "The number of transformer blocks and the number of KV cache entries"
-        " must be the same."
-    )
-
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-
-    updated_kv_entires = []
-    for i, block in enumerate(self.transformer_blocks):
-      kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-      if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
-
-    x = self.final_norm(x)
-    logits = self.lm_head(x)  # (b, t, vocab_size)
-    return {"logits": logits, "kv_cache": updated_kv_cache}
-
-
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for an OpenELM model.
 
@@ -191,12 +111,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-
-
-
-
-
-
-
-  return model
+def build_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
ai_edge_torch/generative/examples/phi/phi2.py
@@ -15,14 +15,9 @@
 
 """Example of building a Phi-2 model."""
 
-from ai_edge_torch.generative.layers import attention
-from ai_edge_torch.generative.layers import builder
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import torch
-from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.fc1",
@@ -38,78 +33,6 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 )
 
 
-class Phi2(nn.Module):
-  """A Phi-2 model built from the Edge Generative API layers."""
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__()
-
-    # Construct model layers.
-    self.lm_head = nn.Linear(
-        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
-    )
-    self.tok_embedding = nn.Embedding(
-        config.vocab_size, config.embedding_dim, padding_idx=0
-    )
-    # Phi-2 has only one block config.
-    block_config = config.block_config(0)
-    self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(block_config, config)
-        for _ in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    attn_config = block_config.attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
-    self.config = config
-
-  @torch.inference_mode
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    _, seq_len = tokens.size()
-    assert self.config.max_seq_len >= seq_len, (
-        f"Cannot forward sequence of length {seq_len}, max seq length is only"
-        f" {self.config.max_seq_len}"
-    )
-    assert len(self.transformer_blocks) == len(kv_cache.caches), (
-        "The number of transformer blocks and the number of KV cache entries"
-        " must be the same."
-    )
-
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
-
-    x = self.tok_embedding(tokens)
-
-    updated_kv_entires = []
-    for i, block in enumerate(self.transformer_blocks):
-      kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-      if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
-
-    x = self.final_norm(x)
-    logits = self.lm_head(x)  # (b, t, vocab_size)
-    return {"logits": logits, "kv_cache": updated_kv_cache}
-
-
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a Phi-2 model.
 
@@ -154,6 +77,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_use_bias=True,
+      lm_head_share_weight_with_embedding=False,
      enable_hlfb=True,
   )
   return config
@@ -169,11 +93,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-
-
-
-
-
-
-
+def build_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
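Phi-2 keeps a separate head projection (with bias), so its config now opts out of the embedding/lm_head weight sharing that the consolidated model presumably applies by default; the new lm_head_share_weight_with_embedding field is part of the +6 lines added to model_config.py, which this diff does not render. A hedged sketch of what that flag likely controls, mirroring the "re-uses the embedding as the head projection layer" lines removed from Gemma and OpenELM above; maybe_tie_lm_head is a hypothetical helper, not the actual model_builder implementation:

from torch import nn


def maybe_tie_lm_head(
    lm_head: nn.Linear, tok_embedding: nn.Embedding, share_weight: bool
) -> None:
  # Hypothetical: shares the embedding weights with the output projection,
  # as the removed per-model code did unconditionally for Gemma and OpenELM.
  if share_weight:
    lm_head.weight.data = tok_embedding.weight.data


emb = nn.Embedding(51200, 2560)           # Phi-2-like dimensions, illustrative
head = nn.Linear(2560, 51200, bias=True)  # Phi-2 sets lm_head_use_bias=True
maybe_tie_lm_head(head, emb, share_weight=False)  # matches the new config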
ai_edge_torch/generative/examples/phi/phi3.py
@@ -18,14 +18,10 @@
 import math
 from typing import Tuple
 
-from ai_edge_torch.generative.layers import attention
-from ai_edge_torch.generative.layers import builder
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
-from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.gate_up_proj",
@@ -137,32 +133,14 @@ def _build_rope_cache(
   return cos, sin
 
 
-class Phi3_5Mini(
+class Phi3_5Mini(model_builder.DecoderOnlyModel):
   """A Phi-3.5 model built from the Edge Generative API layers."""
 
   def __init__(self, config: cfg.ModelConfig):
-    super().__init__()
-
-    # Construct model layers.
-    self.lm_head = nn.Linear(
-        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
-    )
-    self.tok_embedding = nn.Embedding(
-        config.vocab_size, config.embedding_dim, padding_idx=0
-    )
-    # Phi-3.5 has only one block config.
-    block_config = config.block_config(0)
-    self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(block_config, config)
-        for _ in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    attn_config = block_config.attn_config
+    super().__init__(config)
+    attn_config = self.config.block_config(0).attn_config
     self.rope_cache = _build_rope_cache(
-        size=config.kv_cache_max,
+        size=self.config.kv_cache_max,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=attn_config.rotary_base,
         condense_ratio=1,
@@ -173,47 +151,6 @@ class Phi3_5Mini(nn.Module):
             1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
         ),
     )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
-    self.config = config
-
-  @torch.inference_mode
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    _, seq_len = tokens.size()
-    assert self.config.max_seq_len >= seq_len, (
-        f"Cannot forward sequence of length {seq_len}, max seq length is only"
-        f" {self.config.max_seq_len}"
-    )
-    assert len(self.transformer_blocks) == len(kv_cache.caches), (
-        "The number of transformer blocks and the number of KV cache entries"
-        " must be the same."
-    )
-
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
-
-    x = self.tok_embedding(tokens)
-
-    updated_kv_entires = []
-    for i, block in enumerate(self.transformer_blocks):
-      kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-      if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
-
-    x = self.final_norm(x)
-    logits = self.lm_head(x)  # (b, t, vocab_size)
-    return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -254,6 +191,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       embedding_dim=3072,
       block_configs=block_config,
      final_norm_config=norm_config,
+      lm_head_share_weight_with_embedding=False,
       enable_hlfb=True,
   )
   return config
@@ -269,7 +207,9 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
+def build_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
   """Instantiates the model instance and load checkpoint if provided."""
   config = get_model_config(**kwargs)
   model = Phi3_5Mini(config)