ai-edge-torch-nightly 0.3.0.dev20240909__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/test/test_convert.py +35 -13
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
- ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
- ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
- ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
- ai_edge_torch/generative/examples/t5/t5.py +35 -22
- ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
- ai_edge_torch/generative/layers/attention.py +77 -73
- ai_edge_torch/generative/layers/builder.py +5 -3
- ai_edge_torch/generative/layers/kv_cache.py +163 -51
- ai_edge_torch/generative/layers/model_config.py +38 -19
- ai_edge_torch/generative/layers/normalization.py +158 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +72 -34
- ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/generative/utilities/loader.py +15 -15
- ai_edge_torch/generative/utilities/t5_loader.py +21 -20
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_convolution.py +196 -74
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +41 -47
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/gemma2.py

@@ -12,14 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+
+"""Example of building a Gemma2 model."""

 import os
-
+import pathlib
 from typing import Optional, Tuple

 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -51,7 +53,8 @@ class Gemma2Block(attention.TransformerBlock):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
-
+      kv_cache: kv_utils.KVCacheEntry = None,
+  ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
     """Forward function of the Gemma2Block.

     Exactly the same as TransformerBlock but we call the post-attention norm
@@ -62,17 +65,19 @@ class Gemma2Block(attention.TransformerBlock):
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): the optional kv cache entry.

     Returns:
-      output activation from this transformer block
+      output activation from this transformer block, and updated kv cache (if
+        passed in).
     """

     x_norm = self.pre_atten_norm(x)
-    attn_out = self.atten_func(x_norm, rope, mask, input_pos)
+    attn_out, kv = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
     attn_out_norm = self.post_atten_norm(attn_out)
     x = x + attn_out_norm
     output = x + self.ff(x)
-    return output
+    return output, kv


 class Gemma2(nn.Module):
@@ -81,7 +86,6 @@ class Gemma2(nn.Module):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()

-    self.config = config
     # Construct model layers.
     self.tok_embedding = nn.Embedding(
         config.vocab_size, config.embedding_dim, padding_idx=0
@@ -91,20 +95,22 @@ class Gemma2(nn.Module):
         config.vocab_size,
         bias=config.lm_head_use_bias,
     )
-    #
+    # Gemma2 re-uses the embedding as the head projection layer.
     self.lm_head.weight.data = self.tok_embedding.weight.data
     self.transformer_blocks = nn.ModuleList(
-        Gemma2Block(config)
+        Gemma2Block(config.block_config(idx), config)
+        for idx in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    # Gemma2 has same hyper parameters for each layer except for attention
+    # types. Use the first layer.
+    attn_config = config.block_config(0).attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -115,47 +121,56 @@ class Gemma2(nn.Module):
         dtype=torch.float32,
         device=torch.device("cpu"),
     )
-
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
-        window_size=
+        window_size=attn_config.sliding_window_size,
         dtype=torch.float32,
         device=torch.device("cpu"),
     )
-
     self.config = config

   def get_attention_mask(
-      self,
+      self, attn_type: cfg.AttentionType, input_pos: torch.Tensor
   ) -> torch.Tensor:
-    if
-
-        self.config.attn_config.attn_types[idx]
-        == cfg.AttentionType.LOCAL_SLIDING
-    ):
-      return self.sliding_window_mask_cache.index_select(2, input_pos)
-
+    if attn_type == cfg.AttentionType.LOCAL_SLIDING:
+      return self.sliding_window_mask_cache.index_select(2, input_pos)
     return self.mask_cache.index_select(2, input_pos)

   @torch.inference_mode
-  def forward(
-
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )

     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
     sin = sin.index_select(0, input_pos)

     # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(
+    x = self.tok_embedding(tokens)
     x = x * (self.config.embedding_dim**0.5)

+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
-      mask = self.get_attention_mask(
-
+      mask = self.get_attention_mask(
+          block.config.attn_config.attn_type, input_pos
+      )
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

     x = self.final_norm(x)
     res = self.lm_head(x)  # (b, t, vocab_size)
@@ -163,7 +178,8 @@ class Gemma2(nn.Module):
       res = res / self.config.final_logit_softcap
       res = torch.tanh(res)
       res = res * self.config.final_logit_softcap
-
+
+    return {"logits": res, "kv_cache": updated_kv_cache}


 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -176,18 +192,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   Returns:
     The model config for a Gemma 2B model.
   """
-  attn_config = cfg.AttentionConfig(
-      num_heads=8,
-      head_dim=256,
-      num_query_groups=4,
-      rotary_percentage=1.0,
-      qkv_transpose_before_split=True,
-      logit_softcap=50.0,
-      sliding_window_size=4096,
-      attn_types=[cfg.AttentionType.GLOBAL, cfg.AttentionType.LOCAL_SLIDING]
-      * 13,
-  )
-
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-6,
@@ -200,18 +204,38 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_ff_norm_config=norm_config,
       post_ff_norm_config=norm_config,
   )
+
+  def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
+    attn_config = cfg.AttentionConfig(
+        num_heads=8,
+        head_dim=256,
+        num_query_groups=4,
+        rotary_percentage=1.0,
+        qkv_transpose_before_split=True,
+        logit_softcap=50.0,
+        sliding_window_size=4096,
+        attn_type=(
+            cfg.AttentionType.GLOBAL
+            if idx % 2 == 0
+            else cfg.AttentionType.LOCAL_SLIDING
+        ),
+    )
+    return cfg.TransformerBlockConfig(
+        attn_config=attn_config,
+        ff_config=ff_config,
+        pre_attention_norm_config=norm_config,
+        post_attention_norm_config=norm_config,
+    )
+
+  num_layers = 26
   config = cfg.ModelConfig(
       vocab_size=256000,
-      num_layers=
+      num_layers=num_layers,
       max_seq_len=8192,
       embedding_dim=2304,
       kv_cache_max_len=kv_cache_max_len,
-
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
+      block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
-      parallel_residual=False,
       lm_head_use_bias=False,
       enable_hlfb=True,
       final_logit_softcap=30.0,
@@ -221,14 +245,16 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:

 def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config = get_model_config_2b(kv_cache_max_len)
-  config.attn_config.num_heads = 4
-  config.attn_config.head_dim = 64
-  config.attn_config.sliding_window_size = 64
-  config.ff_config.intermediate_size = 128
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
+  config.block_configs = config.block_configs[: config.num_layers]
+  for block_config in config.block_configs:
+    block_config.attn_config.num_heads = 4
+    block_config.attn_config.head_dim = 64
+    block_config.attn_config.sliding_window_size = 64
+    block_config.ff_config.intermediate_size = 128
   return config


@@ -236,21 +262,20 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   config = get_model_config_2b(**kwargs)
   model = Gemma2(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  #
+  # Since embedding and lm-head use the same weight, we need to set strict
   # to False.
   loader.load(model, strict=False)
   model.eval()
   return model


-def define_and_run_2b() -> None:
+def define_and_run_2b(checkpoint_path: str) -> None:
   """Instantiates and runs a Gemma2 2B model."""

-  current_dir = Path(__file__).parent.resolve()
+  current_dir = pathlib.Path(__file__).parent.resolve()
   gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
   print("Running GEMMA 2")
   kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma2-2b")
   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   toks = torch.from_numpy(
       np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
@@ -258,11 +283,13 @@ def define_and_run_2b() -> None:
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :9] = toks
   input_pos = torch.arange(0, kv_cache_max_len)
-
-
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  out = model.forward(tokens, input_pos, kv)
+  out_final = out["logits"][0, 8, :]
   assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)


 if __name__ == "__main__":
   torch.set_printoptions(sci_mode=True)
-
+  path = os.path.join(pathlib.Path.home(), "Downloads/llm_data/gemma2-2b")
+  define_and_run_2b(path)
ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py

@@ -12,16 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-
-# Please use with caution.
+
+"""Example of converting a Phi-2 model to multi-signature tflite model."""

 import os
-
+import pathlib

 import ai_edge_torch
-from ai_edge_torch.generative.examples.
-from ai_edge_torch.generative.layers
+from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch

@@ -32,9 +31,8 @@ def convert_phi2_to_tflite(
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """
+  """Converts a Phi-2 model to multi-signature tflite model.

-  tflite model.
   Args:
     checkpoint_path (str): The filepath to the model checkpoint, or directory
       holding the checkpoint.
@@ -53,7 +51,7 @@ def convert_phi2_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
-  kv =
+  kv = kv_cache.KVCache.from_model_config(pytorch_model.config)

   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
@@ -77,11 +75,12 @@ def convert_phi2_to_tflite(
       )
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/
+      f'/tmp/phi2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )


 if __name__ == '__main__':
-
-  convert_phi2_to_tflite(
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2')
+  convert_phi2_to_tflite(path)
ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py

@@ -12,26 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-
-# Note: This is an experimental version of phi2 with external KV cache.
-# Please use with caution.
+
+"""Example of building a Phi-2 model."""

 import os
-
-from typing import Tuple
+import pathlib

+from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-from ai_edge_torch.generative.layers.experimental import attention
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
 from torch import nn

-
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.fc1",
     ff_down_proj="model.layers.{}.mlp.fc2",
@@ -52,7 +48,6 @@ class Phi2(nn.Module):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()

-    self.config = config
     # Construct model layers.
     self.lm_head = nn.Linear(
         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
@@ -60,18 +55,20 @@ class Phi2(nn.Module):
     self.tok_embedding = nn.Embedding(
         config.vocab_size, config.embedding_dim, padding_idx=0
     )
+    # Phi-2 has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(
+        attention.TransformerBlock(block_config, config)
+        for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    attn_config = block_config.attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -89,13 +86,17 @@ class Phi2(nn.Module):
       self,
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
-      kv_cache: kv_utils.
-  ) ->
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
     _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )

     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
@@ -111,11 +112,11 @@ class Phi2(nn.Module):
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
         updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

     x = self.final_norm(x)
-
-    return
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}


 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -143,17 +144,20 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       use_bias=True,
   )
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      parallel_residual=True,
+  )
   config = cfg.ModelConfig(
       vocab_size=51200,
       num_layers=32,
       max_seq_len=2048,
       kv_cache_max_len=kv_cache_max_len,
       embedding_dim=2560,
-
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
+      block_configs=block_config,
       final_norm_config=norm_config,
-      parallel_residual=True,
       lm_head_use_bias=True,
       enable_hlfb=True,
   )
@@ -165,43 +169,42 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
-  config.
+  # Phi-2 has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
   return config


-def build_model(
-    checkpoint_path: str, test_model: bool = False, **kwargs
-) -> nn.Module:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   """Instantiates the model instance and load checkpoint if provided."""
-  config = (
-      get_fake_model_config(**kwargs)
-      if test_model
-      else get_model_config(**kwargs)
-  )
+  config = get_model_config(**kwargs)
   model = Phi2(config)
-
-
-  loader.load(model)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  loader.load(model)
   model.eval()
   return model


-def define_and_run(checkpoint_path: str
+def define_and_run(checkpoint_path: str) -> None:
   """Instantiates and runs a Phi-2 model."""

+  current_dir = pathlib.Path(__file__).parent.resolve()
+  phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
   kv_cache_max_len = 1024
-  model = build_model(
-      checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
-  )
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
-  kv = kv_utils.
-
-  print(
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
+  print("comparing with goldens..")
+  assert torch.allclose(
+      phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
+  )


 if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/phi2"
+  )
   define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py

@@ -12,30 +12,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-#
-# Note: This is an experimental version of Gemma with external KV cache.
-# Please use with caution.

+"""Example of converting SmalLM model to multi-signature tflite model."""

 import os
-
+import pathlib

 import ai_edge_torch
-from ai_edge_torch.generative.examples.
-from ai_edge_torch.generative.layers
+from ai_edge_torch.generative.examples.smallm import smallm
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch


-def
+def convert_smallm_to_tflite(
     checkpoint_path: str,
     prefill_seq_len: int = 512,
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """
+  """Converts SmalLM model to multi-signature tflite model.

-  tflite model.
   Args:
     checkpoint_path (str): The filepath to the model checkpoint, or directory
       holding the checkpoint.
@@ -46,7 +43,7 @@ def convert_gemma_to_tflite(
     quantize (bool, optional): Whether the model should be quanized. Defaults
       to True.
   """
-  pytorch_model =
+  pytorch_model = smallm.build_model(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
@@ -54,7 +51,7 @@ def convert_gemma_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
-  kv = kv_utils.
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
@@ -78,11 +75,12 @@ def convert_gemma_to_tflite(
       )
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/
+      f'/tmp/smallm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )


 if __name__ == '__main__':
-
-
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smallm')
+  convert_smallm_to_tflite(path)