ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240912__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
- ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
- ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
- ai_edge_torch/generative/examples/smallm/smallm.py +119 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +40 -24
- ai_edge_torch/generative/layers/attention.py +60 -63
- ai_edge_torch/generative/layers/builder.py +4 -2
- ai_edge_torch/generative/layers/kv_cache.py +160 -51
- ai_edge_torch/generative/layers/model_config.py +1 -0
- ai_edge_torch/generative/layers/normalization.py +158 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +72 -34
- ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/RECORD +33 -39
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/top_level.txt +0 -0
@@ -13,32 +13,35 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
"""Example of converting a Gemma2 model to multi-signature tflite model."""
|
17
|
+
|
16
18
|
import os
|
17
|
-
|
19
|
+
import pathlib
|
18
20
|
|
19
21
|
import ai_edge_torch
|
20
22
|
from ai_edge_torch.generative.examples.gemma import gemma2
|
23
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
21
24
|
from ai_edge_torch.generative.quantize import quant_recipes
|
22
25
|
import torch
|
23
26
|
|
24
27
|
|
25
|
-
def
|
28
|
+
def convert_gemma2_to_tflite(
|
26
29
|
checkpoint_path: str,
|
27
30
|
prefill_seq_len: int = 512,
|
28
31
|
kv_cache_max_len: int = 1024,
|
29
32
|
quantize: bool = True,
|
30
33
|
):
|
31
|
-
"""
|
32
|
-
tflite model.
|
34
|
+
"""Converts a Gemma2 2B model to multi-signature tflite model.
|
33
35
|
|
34
36
|
Args:
|
35
|
-
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
37
|
+
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
38
|
+
holding the checkpoint.
|
36
39
|
prefill_seq_len (int, optional): The maximum size of prefill input tensor.
|
37
40
|
Defaults to 512.
|
38
41
|
kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
|
39
42
|
including both prefill and decode. Defaults to 1024.
|
40
|
-
quantize (bool, optional): Whether the model should be quanized.
|
41
|
-
|
43
|
+
quantize (bool, optional): Whether the model should be quanized. Defaults
|
44
|
+
to True.
|
42
45
|
"""
|
43
46
|
pytorch_model = gemma2.build_2b_model(
|
44
47
|
checkpoint_path, kv_cache_max_len=kv_cache_max_len
|
@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
|
|
48
51
|
prefill_input_pos = torch.arange(0, prefill_seq_len)
|
49
52
|
decode_token = torch.tensor([[0]], dtype=torch.long)
|
50
53
|
decode_input_pos = torch.tensor([0], dtype=torch.int64)
|
54
|
+
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
51
55
|
|
52
56
|
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
53
57
|
edge_model = (
|
54
58
|
ai_edge_torch.signature(
|
55
|
-
'prefill',
|
59
|
+
'prefill',
|
60
|
+
pytorch_model,
|
61
|
+
sample_kwargs={
|
62
|
+
'tokens': prefill_tokens,
|
63
|
+
'input_pos': prefill_input_pos,
|
64
|
+
'kv_cache': kv,
|
65
|
+
},
|
66
|
+
)
|
67
|
+
.signature(
|
68
|
+
'decode',
|
69
|
+
pytorch_model,
|
70
|
+
sample_kwargs={
|
71
|
+
'tokens': decode_token,
|
72
|
+
'input_pos': decode_input_pos,
|
73
|
+
'kv_cache': kv,
|
74
|
+
},
|
56
75
|
)
|
57
|
-
.signature('decode', pytorch_model, (decode_token, decode_input_pos))
|
58
76
|
.convert(quant_config=quant_config)
|
59
77
|
)
|
78
|
+
quant_suffix = 'q8' if quantize else 'f32'
|
60
79
|
edge_model.export(
|
61
|
-
f'/tmp/
|
80
|
+
f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
|
62
81
|
)
|
63
82
|
|
64
83
|
|
65
84
|
if __name__ == '__main__':
|
66
|
-
|
67
|
-
|
85
|
+
path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')
|
86
|
+
convert_gemma2_to_tflite(path)
|
@@ -13,11 +13,14 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
"""Example of converting a Gemma model to multi-signature tflite model."""
|
17
|
+
|
16
18
|
import os
|
17
|
-
|
19
|
+
import pathlib
|
18
20
|
|
19
21
|
import ai_edge_torch
|
20
22
|
from ai_edge_torch.generative.examples.gemma import gemma
|
23
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
21
24
|
from ai_edge_torch.generative.quantize import quant_recipes
|
22
25
|
import torch
|
23
26
|
|
@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
|
|
48
51
|
prefill_input_pos = torch.arange(0, prefill_seq_len)
|
49
52
|
decode_token = torch.tensor([[0]], dtype=torch.long)
|
50
53
|
decode_input_pos = torch.tensor([0], dtype=torch.int64)
|
54
|
+
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
51
55
|
|
52
56
|
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
53
57
|
edge_model = (
|
54
58
|
ai_edge_torch.signature(
|
55
|
-
'prefill',
|
59
|
+
'prefill',
|
60
|
+
pytorch_model,
|
61
|
+
sample_kwargs={
|
62
|
+
'tokens': prefill_tokens,
|
63
|
+
'input_pos': prefill_input_pos,
|
64
|
+
'kv_cache': kv,
|
65
|
+
},
|
66
|
+
)
|
67
|
+
.signature(
|
68
|
+
'decode',
|
69
|
+
pytorch_model,
|
70
|
+
sample_kwargs={
|
71
|
+
'tokens': decode_token,
|
72
|
+
'input_pos': decode_input_pos,
|
73
|
+
'kv_cache': kv,
|
74
|
+
},
|
56
75
|
)
|
57
|
-
.signature('decode', pytorch_model, (decode_token, decode_input_pos))
|
58
76
|
.convert(quant_config=quant_config)
|
59
77
|
)
|
78
|
+
quant_suffix = 'q8' if quantize else 'f32'
|
60
79
|
edge_model.export(
|
61
|
-
f'/tmp/
|
80
|
+
f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
|
62
81
|
)
|
63
82
|
|
64
83
|
|
65
84
|
if __name__ == '__main__':
|
66
|
-
|
67
|
-
convert_gemma_to_tflite(
|
85
|
+
path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
|
86
|
+
convert_gemma_to_tflite(path)
|
@@ -12,13 +12,15 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
|
15
|
+
|
16
|
+
"""Example of building a Gemma model."""
|
16
17
|
|
17
18
|
import os
|
18
|
-
|
19
|
+
import pathlib
|
19
20
|
|
20
21
|
from ai_edge_torch.generative.layers import attention
|
21
22
|
from ai_edge_torch.generative.layers import builder
|
23
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
22
24
|
import ai_edge_torch.generative.layers.attention_utils as attn_utils
|
23
25
|
import ai_edge_torch.generative.layers.model_config as cfg
|
24
26
|
import ai_edge_torch.generative.utilities.loader as loading_utils
|
@@ -84,16 +86,22 @@ class Gemma(nn.Module):
|
|
84
86
|
)
|
85
87
|
self.config = config
|
86
88
|
|
87
|
-
# The model's forward function takes in additional k/v cache tensors
|
88
|
-
# and returns the updated k/v cache tensors to the caller.
|
89
|
-
# This can be eliminated if we handle k/v cache updates inside the model itself.
|
90
89
|
@torch.inference_mode
|
91
|
-
def forward(
|
92
|
-
|
90
|
+
def forward(
|
91
|
+
self,
|
92
|
+
tokens: torch.Tensor,
|
93
|
+
input_pos: torch.Tensor,
|
94
|
+
kv_cache: kv_utils.KVCache,
|
95
|
+
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
96
|
+
_, seq_len = tokens.size()
|
93
97
|
assert self.config.max_seq_len >= seq_len, (
|
94
98
|
f"Cannot forward sequence of length {seq_len}, max seq length is only"
|
95
99
|
f" {self.config.max_seq_len}"
|
96
100
|
)
|
101
|
+
assert len(self.transformer_blocks) == len(kv_cache.caches), (
|
102
|
+
"The number of transformer blocks and the number of KV cache entries"
|
103
|
+
" must be the same."
|
104
|
+
)
|
97
105
|
|
98
106
|
cos, sin = self.rope_cache
|
99
107
|
cos = cos.index_select(0, input_pos)
|
@@ -102,15 +110,20 @@ class Gemma(nn.Module):
|
|
102
110
|
mask = mask[:, :, :, : self.config.kv_cache_max]
|
103
111
|
|
104
112
|
# token embeddings of shape (b, t, n_embd)
|
105
|
-
x = self.tok_embedding(
|
113
|
+
x = self.tok_embedding(tokens)
|
106
114
|
x = x * (self.config.embedding_dim**0.5)
|
107
115
|
|
108
|
-
|
109
|
-
|
116
|
+
updated_kv_entires = []
|
117
|
+
for i, block in enumerate(self.transformer_blocks):
|
118
|
+
kv_entry = kv_cache.caches[i] if kv_cache else None
|
119
|
+
x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
|
120
|
+
if kv_entry:
|
121
|
+
updated_kv_entires.append(kv_entry)
|
122
|
+
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
|
110
123
|
|
111
124
|
x = self.final_norm(x)
|
112
|
-
|
113
|
-
return
|
125
|
+
logits = self.lm_head(x) # (b, t, vocab_size)
|
126
|
+
return {"logits": logits, "kv_cache": updated_kv_cache}
|
114
127
|
|
115
128
|
|
116
129
|
def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
@@ -177,25 +190,28 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
|
|
177
190
|
return model
|
178
191
|
|
179
192
|
|
180
|
-
def define_and_run_2b() -> None:
|
193
|
+
def define_and_run_2b(checkpoint_path: str) -> None:
|
181
194
|
"""Instantiates and runs a Gemma 2B model."""
|
182
195
|
|
183
|
-
current_dir = Path(__file__).parent.resolve()
|
196
|
+
current_dir = pathlib.Path(__file__).parent.resolve()
|
184
197
|
gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
|
185
198
|
|
186
199
|
kv_cache_max_len = 1024
|
187
|
-
checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
|
188
200
|
model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
|
189
201
|
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
|
190
202
|
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
|
191
203
|
tokens[0, :4] = idx
|
192
204
|
input_pos = torch.arange(0, kv_cache_max_len)
|
193
|
-
|
205
|
+
kv = kv_utils.KVCache.from_model_config(model.config)
|
206
|
+
output = model.forward(tokens, input_pos, kv)
|
194
207
|
print("comparing with goldens..")
|
195
208
|
assert torch.allclose(
|
196
|
-
gemma_goldens,
|
209
|
+
gemma_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
|
197
210
|
)
|
198
211
|
|
199
212
|
|
200
213
|
if __name__ == "__main__":
|
201
|
-
|
214
|
+
input_checkpoint_path = os.path.join(
|
215
|
+
pathlib.Path.home(), "Downloads/llm_data/gemma-2b"
|
216
|
+
)
|
217
|
+
define_and_run_2b(input_checkpoint_path)
|
@@ -12,14 +12,16 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
|
15
|
+
|
16
|
+
"""Example of building a Gemma2 model."""
|
16
17
|
|
17
18
|
import os
|
18
|
-
|
19
|
+
import pathlib
|
19
20
|
from typing import Optional, Tuple
|
20
21
|
|
21
22
|
from ai_edge_torch.generative.layers import attention
|
22
23
|
from ai_edge_torch.generative.layers import builder
|
24
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
23
25
|
import ai_edge_torch.generative.layers.attention_utils as attn_utils
|
24
26
|
import ai_edge_torch.generative.layers.model_config as cfg
|
25
27
|
import ai_edge_torch.generative.utilities.loader as loading_utils
|
@@ -51,7 +53,8 @@ class Gemma2Block(attention.TransformerBlock):
|
|
51
53
|
rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
52
54
|
mask: Optional[torch.Tensor] = None,
|
53
55
|
input_pos: Optional[torch.Tensor] = None,
|
54
|
-
|
56
|
+
kv_cache: kv_utils.KVCacheEntry = None,
|
57
|
+
) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
|
55
58
|
"""Forward function of the Gemma2Block.
|
56
59
|
|
57
60
|
Exactly the same as TransformerBlock but we call the post-attention norm
|
@@ -62,17 +65,19 @@ class Gemma2Block(attention.TransformerBlock):
|
|
62
65
|
rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
|
63
66
|
mask (torch.Tensor): the optional mask tensor.
|
64
67
|
input_pos (torch.Tensor): the optional input position tensor.
|
68
|
+
kv_cache (KVCacheEntry): the optional kv cache entry.
|
65
69
|
|
66
70
|
Returns:
|
67
|
-
output activation from this transformer block
|
71
|
+
output activation from this transformer block, and updated kv cache (if
|
72
|
+
passed in).
|
68
73
|
"""
|
69
74
|
|
70
75
|
x_norm = self.pre_atten_norm(x)
|
71
|
-
attn_out = self.atten_func(x_norm, rope, mask, input_pos)
|
76
|
+
attn_out, kv = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
|
72
77
|
attn_out_norm = self.post_atten_norm(attn_out)
|
73
78
|
x = x + attn_out_norm
|
74
79
|
output = x + self.ff(x)
|
75
|
-
return output
|
80
|
+
return output, kv
|
76
81
|
|
77
82
|
|
78
83
|
class Gemma2(nn.Module):
|
@@ -138,24 +143,38 @@ class Gemma2(nn.Module):
|
|
138
143
|
return self.mask_cache.index_select(2, input_pos)
|
139
144
|
|
140
145
|
@torch.inference_mode
|
141
|
-
def forward(
|
142
|
-
|
146
|
+
def forward(
|
147
|
+
self,
|
148
|
+
tokens: torch.Tensor,
|
149
|
+
input_pos: torch.Tensor,
|
150
|
+
kv_cache: kv_utils.KVCache,
|
151
|
+
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
152
|
+
_, seq_len = tokens.size()
|
143
153
|
assert self.config.max_seq_len >= seq_len, (
|
144
154
|
f"Cannot forward sequence of length {seq_len}, max seq length is only"
|
145
155
|
f" {self.config.max_seq_len}"
|
146
156
|
)
|
157
|
+
assert len(self.transformer_blocks) == len(kv_cache.caches), (
|
158
|
+
"The number of transformer blocks and the number of KV cache entries"
|
159
|
+
" must be the same."
|
160
|
+
)
|
147
161
|
|
148
162
|
cos, sin = self.rope_cache
|
149
163
|
cos = cos.index_select(0, input_pos)
|
150
164
|
sin = sin.index_select(0, input_pos)
|
151
165
|
|
152
166
|
# token embeddings of shape (b, t, n_embd)
|
153
|
-
x = self.tok_embedding(
|
167
|
+
x = self.tok_embedding(tokens)
|
154
168
|
x = x * (self.config.embedding_dim**0.5)
|
155
169
|
|
170
|
+
updated_kv_entires = []
|
156
171
|
for i, block in enumerate(self.transformer_blocks):
|
157
172
|
mask = self.get_attention_mask(i, input_pos)
|
158
|
-
|
173
|
+
kv_entry = kv_cache.caches[i] if kv_cache else None
|
174
|
+
x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
|
175
|
+
if kv_entry:
|
176
|
+
updated_kv_entires.append(kv_entry)
|
177
|
+
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
|
159
178
|
|
160
179
|
x = self.final_norm(x)
|
161
180
|
res = self.lm_head(x) # (b, t, vocab_size)
|
@@ -163,7 +182,8 @@ class Gemma2(nn.Module):
|
|
163
182
|
res = res / self.config.final_logit_softcap
|
164
183
|
res = torch.tanh(res)
|
165
184
|
res = res * self.config.final_logit_softcap
|
166
|
-
|
185
|
+
|
186
|
+
return {"logits": res, "kv_cache": updated_kv_cache}
|
167
187
|
|
168
188
|
|
169
189
|
def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
@@ -243,14 +263,13 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
|
|
243
263
|
return model
|
244
264
|
|
245
265
|
|
246
|
-
def define_and_run_2b() -> None:
|
266
|
+
def define_and_run_2b(checkpoint_path: str) -> None:
|
247
267
|
"""Instantiates and runs a Gemma2 2B model."""
|
248
268
|
|
249
|
-
current_dir = Path(__file__).parent.resolve()
|
269
|
+
current_dir = pathlib.Path(__file__).parent.resolve()
|
250
270
|
gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
|
251
271
|
print("Running GEMMA 2")
|
252
272
|
kv_cache_max_len = 1024
|
253
|
-
checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma2-2b")
|
254
273
|
model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
|
255
274
|
toks = torch.from_numpy(
|
256
275
|
np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
|
@@ -258,11 +277,13 @@ def define_and_run_2b() -> None:
|
|
258
277
|
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
|
259
278
|
tokens[0, :9] = toks
|
260
279
|
input_pos = torch.arange(0, kv_cache_max_len)
|
261
|
-
|
262
|
-
|
280
|
+
kv = kv_utils.KVCache.from_model_config(model.config)
|
281
|
+
out = model.forward(tokens, input_pos, kv)
|
282
|
+
out_final = out["logits"][0, 8, :]
|
263
283
|
assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)
|
264
284
|
|
265
285
|
|
266
286
|
if __name__ == "__main__":
|
267
287
|
torch.set_printoptions(sci_mode=True)
|
268
|
-
|
288
|
+
path = os.path.join(pathlib.Path.home(), "Downloads/llm_data/gemma2-2b")
|
289
|
+
define_and_run_2b(path)
|
@@ -12,16 +12,15 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
# Please use with caution.
|
15
|
+
|
16
|
+
"""Example of converting a Phi-2 model to multi-signature tflite model."""
|
18
17
|
|
19
18
|
import os
|
20
|
-
|
19
|
+
import pathlib
|
21
20
|
|
22
21
|
import ai_edge_torch
|
23
|
-
from ai_edge_torch.generative.examples.
|
24
|
-
from ai_edge_torch.generative.layers
|
22
|
+
from ai_edge_torch.generative.examples.phi import phi2
|
23
|
+
from ai_edge_torch.generative.layers import kv_cache
|
25
24
|
from ai_edge_torch.generative.quantize import quant_recipes
|
26
25
|
import torch
|
27
26
|
|
@@ -32,9 +31,8 @@ def convert_phi2_to_tflite(
|
|
32
31
|
kv_cache_max_len: int = 1024,
|
33
32
|
quantize: bool = True,
|
34
33
|
):
|
35
|
-
"""
|
34
|
+
"""Converts a Phi-2 model to multi-signature tflite model.
|
36
35
|
|
37
|
-
tflite model.
|
38
36
|
Args:
|
39
37
|
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
40
38
|
holding the checkpoint.
|
@@ -53,7 +51,7 @@ def convert_phi2_to_tflite(
|
|
53
51
|
prefill_input_pos = torch.arange(0, prefill_seq_len)
|
54
52
|
decode_token = torch.tensor([[0]], dtype=torch.long)
|
55
53
|
decode_input_pos = torch.tensor([0], dtype=torch.int64)
|
56
|
-
kv =
|
54
|
+
kv = kv_cache.KVCache.from_model_config(pytorch_model.config)
|
57
55
|
|
58
56
|
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
59
57
|
edge_model = (
|
@@ -77,11 +75,12 @@ def convert_phi2_to_tflite(
|
|
77
75
|
)
|
78
76
|
.convert(quant_config=quant_config)
|
79
77
|
)
|
78
|
+
quant_suffix = 'q8' if quantize else 'f32'
|
80
79
|
edge_model.export(
|
81
|
-
f'/tmp/
|
80
|
+
f'/tmp/phi2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
|
82
81
|
)
|
83
82
|
|
84
83
|
|
85
84
|
if __name__ == '__main__':
|
86
|
-
|
87
|
-
convert_phi2_to_tflite(
|
85
|
+
path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2')
|
86
|
+
convert_phi2_to_tflite(path)
|
@@ -12,26 +12,22 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
# Note: This is an experimental version of phi2 with external KV cache.
|
18
|
-
# Please use with caution.
|
15
|
+
|
16
|
+
"""Example of building a Phi-2 model."""
|
19
17
|
|
20
18
|
import os
|
21
|
-
|
22
|
-
from typing import Tuple
|
19
|
+
import pathlib
|
23
20
|
|
21
|
+
from ai_edge_torch.generative.layers import attention
|
24
22
|
from ai_edge_torch.generative.layers import builder
|
23
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
25
24
|
import ai_edge_torch.generative.layers.attention_utils as attn_utils
|
26
|
-
from ai_edge_torch.generative.layers.experimental import attention
|
27
|
-
from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
|
28
25
|
import ai_edge_torch.generative.layers.model_config as cfg
|
29
26
|
import ai_edge_torch.generative.utilities.loader as loading_utils
|
30
27
|
import numpy as np
|
31
28
|
import torch
|
32
29
|
from torch import nn
|
33
30
|
|
34
|
-
|
35
31
|
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
|
36
32
|
ff_up_proj="model.layers.{}.mlp.fc1",
|
37
33
|
ff_down_proj="model.layers.{}.mlp.fc2",
|
@@ -89,13 +85,17 @@ class Phi2(nn.Module):
|
|
89
85
|
self,
|
90
86
|
tokens: torch.Tensor,
|
91
87
|
input_pos: torch.Tensor,
|
92
|
-
kv_cache: kv_utils.
|
93
|
-
) ->
|
88
|
+
kv_cache: kv_utils.KVCache,
|
89
|
+
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
94
90
|
_, seq_len = tokens.size()
|
95
91
|
assert self.config.max_seq_len >= seq_len, (
|
96
92
|
f"Cannot forward sequence of length {seq_len}, max seq length is only"
|
97
93
|
f" {self.config.max_seq_len}"
|
98
94
|
)
|
95
|
+
assert len(self.transformer_blocks) == len(kv_cache.caches), (
|
96
|
+
"The number of transformer blocks and the number of KV cache entries"
|
97
|
+
" must be the same."
|
98
|
+
)
|
99
99
|
|
100
100
|
cos, sin = self.rope_cache
|
101
101
|
cos = cos.index_select(0, input_pos)
|
@@ -111,11 +111,11 @@ class Phi2(nn.Module):
|
|
111
111
|
x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
|
112
112
|
if kv_entry:
|
113
113
|
updated_kv_entires.append(kv_entry)
|
114
|
-
updated_kv_cache = kv_utils.
|
114
|
+
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
|
115
115
|
|
116
116
|
x = self.final_norm(x)
|
117
|
-
|
118
|
-
return
|
117
|
+
logits = self.lm_head(x) # (b, t, vocab_size)
|
118
|
+
return {"logits": logits, "kv_cache": updated_kv_cache}
|
119
119
|
|
120
120
|
|
121
121
|
def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
@@ -169,39 +169,37 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
|
|
169
169
|
return config
|
170
170
|
|
171
171
|
|
172
|
-
def build_model(
|
173
|
-
checkpoint_path: str, test_model: bool = False, **kwargs
|
174
|
-
) -> nn.Module:
|
172
|
+
def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
|
175
173
|
"""Instantiates the model instance and load checkpoint if provided."""
|
176
|
-
config = (
|
177
|
-
get_fake_model_config(**kwargs)
|
178
|
-
if test_model
|
179
|
-
else get_model_config(**kwargs)
|
180
|
-
)
|
174
|
+
config = get_model_config(**kwargs)
|
181
175
|
model = Phi2(config)
|
182
|
-
|
183
|
-
|
184
|
-
loader.load(model)
|
176
|
+
loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
|
177
|
+
loader.load(model)
|
185
178
|
model.eval()
|
186
179
|
return model
|
187
180
|
|
188
181
|
|
189
|
-
def define_and_run(checkpoint_path: str
|
182
|
+
def define_and_run(checkpoint_path: str) -> None:
|
190
183
|
"""Instantiates and runs a Phi-2 model."""
|
191
184
|
|
185
|
+
current_dir = pathlib.Path(__file__).parent.resolve()
|
186
|
+
phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
|
192
187
|
kv_cache_max_len = 1024
|
193
|
-
model = build_model(
|
194
|
-
checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
|
195
|
-
)
|
188
|
+
model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
|
196
189
|
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
|
197
190
|
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
|
198
191
|
tokens[0, :4] = idx
|
199
192
|
input_pos = torch.arange(0, kv_cache_max_len)
|
200
|
-
kv = kv_utils.
|
201
|
-
|
202
|
-
print(
|
193
|
+
kv = kv_utils.KVCache.from_model_config(model.config)
|
194
|
+
output = model.forward(tokens, input_pos, kv)
|
195
|
+
print("comparing with goldens..")
|
196
|
+
assert torch.allclose(
|
197
|
+
phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
|
198
|
+
)
|
203
199
|
|
204
200
|
|
205
201
|
if __name__ == "__main__":
|
206
|
-
input_checkpoint_path = os.path.join(
|
202
|
+
input_checkpoint_path = os.path.join(
|
203
|
+
pathlib.Path.home(), "Downloads/llm_data/phi2"
|
204
|
+
)
|
207
205
|
define_and_run(input_checkpoint_path)
|
@@ -12,30 +12,27 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
#
|
16
|
-
# Note: This is an experimental version of Gemma with external KV cache.
|
17
|
-
# Please use with caution.
|
18
15
|
|
16
|
+
"""Example of converting SmalLM model to multi-signature tflite model."""
|
19
17
|
|
20
18
|
import os
|
21
|
-
|
19
|
+
import pathlib
|
22
20
|
|
23
21
|
import ai_edge_torch
|
24
|
-
from ai_edge_torch.generative.examples.
|
25
|
-
from ai_edge_torch.generative.layers
|
22
|
+
from ai_edge_torch.generative.examples.smallm import smallm
|
23
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
26
24
|
from ai_edge_torch.generative.quantize import quant_recipes
|
27
25
|
import torch
|
28
26
|
|
29
27
|
|
30
|
-
def
|
28
|
+
def convert_smallm_to_tflite(
|
31
29
|
checkpoint_path: str,
|
32
30
|
prefill_seq_len: int = 512,
|
33
31
|
kv_cache_max_len: int = 1024,
|
34
32
|
quantize: bool = True,
|
35
33
|
):
|
36
|
-
"""
|
34
|
+
"""Converts SmalLM model to multi-signature tflite model.
|
37
35
|
|
38
|
-
tflite model.
|
39
36
|
Args:
|
40
37
|
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
41
38
|
holding the checkpoint.
|
@@ -46,7 +43,7 @@ def convert_gemma_to_tflite(
|
|
46
43
|
quantize (bool, optional): Whether the model should be quanized. Defaults
|
47
44
|
to True.
|
48
45
|
"""
|
49
|
-
pytorch_model =
|
46
|
+
pytorch_model = smallm.build_model(
|
50
47
|
checkpoint_path, kv_cache_max_len=kv_cache_max_len
|
51
48
|
)
|
52
49
|
# Tensors used to trace the model graph during conversion.
|
@@ -54,7 +51,7 @@ def convert_gemma_to_tflite(
|
|
54
51
|
prefill_input_pos = torch.arange(0, prefill_seq_len)
|
55
52
|
decode_token = torch.tensor([[0]], dtype=torch.long)
|
56
53
|
decode_input_pos = torch.tensor([0], dtype=torch.int64)
|
57
|
-
kv = kv_utils.
|
54
|
+
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
58
55
|
|
59
56
|
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
60
57
|
edge_model = (
|
@@ -78,11 +75,12 @@ def convert_gemma_to_tflite(
|
|
78
75
|
)
|
79
76
|
.convert(quant_config=quant_config)
|
80
77
|
)
|
78
|
+
quant_suffix = 'q8' if quantize else 'f32'
|
81
79
|
edge_model.export(
|
82
|
-
f'/tmp/
|
80
|
+
f'/tmp/smallm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
|
83
81
|
)
|
84
82
|
|
85
83
|
|
86
84
|
if __name__ == '__main__':
|
87
|
-
|
88
|
-
|
85
|
+
path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smallm')
|
86
|
+
convert_smallm_to_tflite(path)
|