ai-edge-torch-nightly 0.3.0.dev20241002__py3-none-any.whl → 0.3.0.dev20241005__py3-none-any.whl
- ai_edge_torch/generative/examples/gemma/gemma1.py +10 -93
- ai_edge_torch/generative/examples/gemma/gemma2.py +0 -1
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +13 -2
- ai_edge_torch/generative/examples/llama/llama.py +19 -24
- ai_edge_torch/generative/examples/llama/verify.py +18 -3
- ai_edge_torch/generative/examples/openelm/openelm.py +9 -90
- ai_edge_torch/generative/examples/phi/phi2.py +10 -86
- ai_edge_torch/generative/examples/phi/phi3.py +9 -69
- ai_edge_torch/generative/examples/qwen/qwen.py +26 -36
- ai_edge_torch/generative/examples/smollm/smollm.py +10 -30
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -3
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +40 -32
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +11 -101
- ai_edge_torch/generative/layers/model_config.py +6 -0
- ai_edge_torch/generative/test/test_loader.py +2 -1
- ai_edge_torch/generative/test/test_model_conversion.py +39 -17
- ai_edge_torch/generative/test/test_model_conversion_large.py +6 -5
- ai_edge_torch/generative/utilities/model_builder.py +141 -0
- ai_edge_torch/lowertools/translate_recipe.py +2 -2
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/RECORD +25 -26
- ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +0 -68
- ai_edge_torch/generative/examples/llama/verify_3b.py +0 -73
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/gemma1.py

@@ -15,14 +15,9 @@
 
 """Example of building a Gemma1 model."""
 
-from ai_edge_torch.generative.layers import attention
-from ai_edge_torch.generative.layers import builder
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import torch
-from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -38,84 +33,6 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 )
 
 
-class Gemma(nn.Module):
-  """A Gemma model built from the Edge Generative API layers."""
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__()
-
-    # Construct model layers.
-    self.tok_embedding = nn.Embedding(
-        config.vocab_size, config.embedding_dim, padding_idx=0
-    )
-    self.lm_head = nn.Linear(
-        config.embedding_dim,
-        config.vocab_size,
-        bias=config.lm_head_use_bias,
-    )
-    # Gemma re-uses the embedding as the head projection layer.
-    self.lm_head.weight.data = self.tok_embedding.weight.data
-    # Gemma has only one block config.
-    block_config = config.block_config(0)
-    self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(block_config, config)
-        for _ in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    attn_config = block_config.attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
-    self.config = config
-
-  @torch.inference_mode
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    _, seq_len = tokens.size()
-    assert self.config.max_seq_len >= seq_len, (
-        f"Cannot forward sequence of length {seq_len}, max seq length is only"
-        f" {self.config.max_seq_len}"
-    )
-    assert len(self.transformer_blocks) == len(kv_cache.caches), (
-        "The number of transformer blocks and the number of KV cache entries"
-        " must be the same."
-    )
-
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-    x = x * (self.config.embedding_dim**0.5)
-
-    updated_kv_entires = []
-    for i, block in enumerate(self.transformer_blocks):
-      kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-      if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
-
-    x = self.final_norm(x)
-    logits = self.lm_head(x)  # (b, t, vocab_size)
-    return {"logits": logits, "kv_cache": updated_kv_cache}
-
-
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a Gemma 2B model.
 
@@ -154,6 +71,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=18,
       max_seq_len=8192,
       embedding_dim=2048,
+      embedding_scale=2048**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
@@ -173,12 +91,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_2b_model(
-
-
-
-
-
-
-
-  return model
+def build_2b_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config_2b(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
ai_edge_torch/generative/examples/llama/convert_to_tflite.py

@@ -23,6 +23,12 @@ from absl import flags
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.utilities import converter
 
+_MODEL_SIZE = flags.DEFINE_enum(
+    'model_size',
+    '1b',
+    ['1b', '3b'],
+    'The size of the model to verify.',
+)
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
@@ -49,13 +55,18 @@ _QUANTIZE = flags.DEFINE_bool(
     'Whether the model should be quantized.',
 )
 
+_BUILDER = {
+    '1b': llama.build_1b_model,
+    '3b': llama.build_3b_model,
+}
+
 
 def main(_):
-  pytorch_model =
+  pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'llama_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
ai_edge_torch/generative/examples/llama/llama.py

@@ -15,19 +15,15 @@
 
 """Example of building Llama 3.2 models."""
 
-import copy
 import math
 from typing import Tuple
 
-from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
-from torch import nn
 
-TENSOR_NAMES =
-# SmolLM re-uses the embedding as the head projection layer.
-TENSOR_NAMES.lm_head = None
+TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
 def _build_llama3_rope_cache(
@@ -93,7 +89,7 @@ def _build_llama3_rope_cache(
   return cos, sin
 
 
-class Llama(tiny_llama.TinyLlama):
+class Llama(model_builder.DecoderOnlyModel):
   """A Llama model built from the Edge Generative API layers.
 
   Llama 3.2 shares the same architecture as TinyLlama except ROPE calculation.
@@ -101,9 +97,6 @@ class Llama(tiny_llama.TinyLlama):
 
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
-    # Llama 3.2 re-uses the embedding as the head projection layer.
-    self.lm_head.weight.data = self.tok_embedding.weight.data
-    # Llama has only one block config.
     attn_config = self.config.block_config(0).attn_config
     self.rope_cache = _build_llama3_rope_cache(
         size=self.config.kv_cache_max,
@@ -119,7 +112,7 @@ class Llama(tiny_llama.TinyLlama):
     )
 
 
-def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a Llama 3.2-1B model.
 
   Args:
@@ -163,7 +156,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a Llama 3.2-3B model."""
-  config =
+  config = get_1b_model_config(kv_cache_max_len)
   # Llama 3.2 has only one block config.
   attn_config = config.block_config(0).attn_config
   attn_config.num_heads = 24
@@ -174,7 +167,7 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 
 def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
-  config =
+  config = get_1b_model_config(**kwargs)
   config.vocab_size = 128
   config.num_layers = 2
   # SmolLM has only one block config.
@@ -182,8 +175,9 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
-
+def _build_model(
+    checkpoint_path: str, config: cfg.ModelConfig
+) -> model_builder.DecoderOnlyModel:
   model = Llama(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
   # Since embedding and lm-head use the same weight, we need to set strict
@@ -193,12 +187,13 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model
 
 
-def
-
-
-
-
-
-
-
-
+def build_1b_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return _build_model(checkpoint_path, get_1b_model_config(**kwargs))
+
+
+def build_3b_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
ai_edge_torch/generative/examples/llama/verify.py

@@ -25,7 +25,12 @@ from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
-
+_MODEL_SIZE = flags.DEFINE_enum(
+    "model_size",
+    "1b",
+    ["1b", "3b"],
+    "The size of the model to verify.",
+)
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "What is the meaning of life?",
@@ -37,9 +42,19 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
     "The maximum size of the generated tokens.",
 )
 
+_CHECKPOINT = {
+    "1b": "meta-llama/Llama-3.2-1B-Instruct",
+    "3b": "meta-llama/Llama-3.2-3B-Instruct",
+}
+
+_BUILDER = {
+    "1b": llama.build_1b_model,
+    "3b": llama.build_3b_model,
+}
+
 
 def main(_):
-  checkpoint =
+  checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
@@ -49,7 +64,7 @@ def main(_):
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model =
+  reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
 
   logging.info("Loading the tokenizer from: %s", checkpoint)
   # Llama tokenizer_config.json sets a fast tokenizer class explicitly,
ai_edge_torch/generative/examples/openelm/openelm.py

@@ -15,14 +15,9 @@
 
 """Example of building an OpenELM model."""
 
-from ai_edge_torch.generative.layers import attention
-from ai_edge_torch.generative.layers import builder
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import torch
-from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="transformer.layers.{}.ffn.proj_1",
@@ -39,81 +34,6 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 )
 
 
-class OpenELM(nn.Module):
-  """An OpenELM model built from the Edge Generative API layers."""
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__()
-
-    # Construct model layers.
-    self.tok_embedding = nn.Embedding(
-        config.vocab_size, config.embedding_dim, padding_idx=0
-    )
-    self.lm_head = nn.Linear(
-        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
-    )
-    # OpenELM re-uses the embedding as the head projection layer.
-    self.lm_head.weight.data = self.tok_embedding.weight.data
-    self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(config.block_config(idx), config)
-        for idx in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    # OpenELM has same hyper parameters for rotary_percentage and head_dim for
-    # each layer block. Use the first block.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
-    self.config = config
-
-  @torch.inference_mode
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    _, seq_len = tokens.size()
-    assert self.config.max_seq_len >= seq_len, (
-        f"Cannot forward sequence of length {seq_len}, max seq length is only"
-        f" {self.config.max_seq_len}"
-    )
-    assert len(self.transformer_blocks) == len(kv_cache.caches), (
-        "The number of transformer blocks and the number of KV cache entries"
-        " must be the same."
-    )
-
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-
-    updated_kv_entires = []
-    for i, block in enumerate(self.transformer_blocks):
-      kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-      if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
-
-    x = self.final_norm(x)
-    logits = self.lm_head(x)  # (b, t, vocab_size)
-    return {"logits": logits, "kv_cache": updated_kv_cache}
-
-
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for an OpenELM model.
 
@@ -191,12 +111,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-
-
-
-
-
-
-
-  return model
+def build_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
ai_edge_torch/generative/examples/phi/phi2.py

@@ -15,14 +15,9 @@
 
 """Example of building a Phi-2 model."""
 
-from ai_edge_torch.generative.layers import attention
-from ai_edge_torch.generative.layers import builder
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import torch
-from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.fc1",
@@ -38,78 +33,6 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 )
 
 
-class Phi2(nn.Module):
-  """A Phi-2 model built from the Edge Generative API layers."""
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__()
-
-    # Construct model layers.
-    self.lm_head = nn.Linear(
-        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
-    )
-    self.tok_embedding = nn.Embedding(
-        config.vocab_size, config.embedding_dim, padding_idx=0
-    )
-    # Phi-2 has only one block config.
-    block_config = config.block_config(0)
-    self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(block_config, config)
-        for _ in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    attn_config = block_config.attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
-    self.config = config
-
-  @torch.inference_mode
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    _, seq_len = tokens.size()
-    assert self.config.max_seq_len >= seq_len, (
-        f"Cannot forward sequence of length {seq_len}, max seq length is only"
-        f" {self.config.max_seq_len}"
-    )
-    assert len(self.transformer_blocks) == len(kv_cache.caches), (
-        "The number of transformer blocks and the number of KV cache entries"
-        " must be the same."
-    )
-
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
-
-    x = self.tok_embedding(tokens)
-
-    updated_kv_entires = []
-    for i, block in enumerate(self.transformer_blocks):
-      kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-      if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
-
-    x = self.final_norm(x)
-    logits = self.lm_head(x)  # (b, t, vocab_size)
-    return {"logits": logits, "kv_cache": updated_kv_cache}
-
-
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a Phi-2 model.
 
@@ -154,6 +77,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_use_bias=True,
+      lm_head_share_weight_with_embedding=False,
       enable_hlfb=True,
   )
   return config
@@ -169,11 +93,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-
-
-
-
-
-
-
+def build_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
ai_edge_torch/generative/examples/phi/phi3.py

@@ -18,14 +18,10 @@
 import math
 from typing import Tuple
 
-from ai_edge_torch.generative.layers import attention
-from ai_edge_torch.generative.layers import builder
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
-from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.gate_up_proj",
@@ -137,32 +133,14 @@ def _build_rope_cache(
   return cos, sin
 
 
-class Phi3_5Mini(nn.Module):
+class Phi3_5Mini(model_builder.DecoderOnlyModel):
   """A Phi-3.5 model built from the Edge Generative API layers."""
 
   def __init__(self, config: cfg.ModelConfig):
-    super().__init__()
-
-    # Construct model layers.
-    self.lm_head = nn.Linear(
-        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
-    )
-    self.tok_embedding = nn.Embedding(
-        config.vocab_size, config.embedding_dim, padding_idx=0
-    )
-    # Phi-3.5 has only one block config.
-    block_config = config.block_config(0)
-    self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(block_config, config)
-        for _ in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    attn_config = block_config.attn_config
+    super().__init__(config)
+    attn_config = self.config.block_config(0).attn_config
     self.rope_cache = _build_rope_cache(
-        size=config.kv_cache_max,
+        size=self.config.kv_cache_max,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=attn_config.rotary_base,
         condense_ratio=1,
@@ -173,47 +151,6 @@ class Phi3_5Mini(nn.Module):
             1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
         ),
     )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
-    self.config = config
-
-  @torch.inference_mode
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    _, seq_len = tokens.size()
-    assert self.config.max_seq_len >= seq_len, (
-        f"Cannot forward sequence of length {seq_len}, max seq length is only"
-        f" {self.config.max_seq_len}"
-    )
-    assert len(self.transformer_blocks) == len(kv_cache.caches), (
-        "The number of transformer blocks and the number of KV cache entries"
-        " must be the same."
-    )
-
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
-
-    x = self.tok_embedding(tokens)
-
-    updated_kv_entires = []
-    for i, block in enumerate(self.transformer_blocks):
-      kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-      if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
-
-    x = self.final_norm(x)
-    logits = self.lm_head(x)  # (b, t, vocab_size)
-    return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -254,6 +191,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       embedding_dim=3072,
       block_configs=block_config,
      final_norm_config=norm_config,
+      lm_head_share_weight_with_embedding=False,
       enable_hlfb=True,
   )
   return config
@@ -269,7 +207,9 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
+def build_model(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
   """Instantiates the model instance and load checkpoint if provided."""
   config = get_model_config(**kwargs)
   model = Phi3_5Mini(config)