ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240912__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in the public registry.
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
- ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
- ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
- ai_edge_torch/generative/examples/smallm/smallm.py +119 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +40 -24
- ai_edge_torch/generative/layers/attention.py +60 -63
- ai_edge_torch/generative/layers/builder.py +4 -2
- ai_edge_torch/generative/layers/kv_cache.py +160 -51
- ai_edge_torch/generative/layers/model_config.py +1 -0
- ai_edge_torch/generative/layers/normalization.py +158 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +72 -34
- ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/RECORD +33 -39
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py

```diff
@@ -13,32 +13,35 @@
 # limitations under the License.
 # ==============================================================================
 
+"""Example of converting a Gemma2 model to multi-signature tflite model."""
+
 import os
-
+import pathlib
 
 import ai_edge_torch
 from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
 
-def convert_gemma_to_tflite(
+def convert_gemma2_to_tflite(
     checkpoint_path: str,
     prefill_seq_len: int = 512,
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """
-  tflite model.
+  """Converts a Gemma2 2B model to multi-signature tflite model.
 
   Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
+    checkpoint_path (str): The filepath to the model checkpoint, or directory
+      holding the checkpoint.
     prefill_seq_len (int, optional): The maximum size of prefill input tensor.
       Defaults to 512.
     kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
       including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized.
-
+    quantize (bool, optional): Whether the model should be quanized. Defaults
+      to True.
   """
   pytorch_model = gemma2.build_2b_model(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
       ai_edge_torch.signature(
-          'prefill',
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
       )
-      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/
+      f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )
 
 
 if __name__ == '__main__':
-
-
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')
+  convert_gemma2_to_tflite(path)
```
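For context, the conversion script above now builds each signature from keyword samples plus an explicit `KVCache`, rather than a positional tuple. A minimal sketch of driving the new entry point (the checkpoint location is illustrative, not mandated by the package):

```python
import pathlib

from ai_edge_torch.generative.examples.gemma import convert_gemma2_to_tflite

# Illustrative checkpoint location; point this at a real Gemma2 2B checkpoint.
checkpoint = str(pathlib.Path.home() / "Downloads/llm_data/gemma2-2b")

# Writes /tmp/gemma2_q8_seq512_ekv1024.tflite with 'prefill' and 'decode'
# signatures, each taking 'tokens', 'input_pos', and an explicit 'kv_cache'.
convert_gemma2_to_tflite.convert_gemma2_to_tflite(
    checkpoint,
    prefill_seq_len=512,
    kv_cache_max_len=1024,
    quantize=True,
)
```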
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py

```diff
@@ -13,11 +13,14 @@
 # limitations under the License.
 # ==============================================================================
 
+"""Example of converting a Gemma model to multi-signature tflite model."""
+
 import os
-
+import pathlib
 
 import ai_edge_torch
 from ai_edge_torch.generative.examples.gemma import gemma
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
       ai_edge_torch.signature(
-          'prefill',
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
       )
-      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/
+      f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )
 
 
 if __name__ == '__main__':
-
-  convert_gemma_to_tflite(
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
+  convert_gemma_to_tflite(path)
```
ai_edge_torch/generative/examples/gemma/gemma.py

```diff
@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+
+"""Example of building a Gemma model."""
 
 import os
-
+import pathlib
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -84,16 +86,22 @@ class Gemma(nn.Module):
     )
     self.config = config
 
-    # The model's forward function takes in additional k/v cache tensors
-    # and returns the updated k/v cache tensors to the caller.
-    # This can be eliminated if we handle k/v cache updates inside the model itself.
   @torch.inference_mode
-  def forward(
-
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
@@ -102,15 +110,20 @@ class Gemma(nn.Module):
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
     # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(
+    x = self.tok_embedding(tokens)
     x = x * (self.config.embedding_dim**0.5)
 
-
-
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     x = self.final_norm(x)
-
-    return
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -177,25 +190,28 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model
 
 
-def define_and_run_2b() -> None:
+def define_and_run_2b(checkpoint_path: str) -> None:
   """Instantiates and runs a Gemma 2B model."""
 
-  current_dir = Path(__file__).parent.resolve()
+  current_dir = pathlib.Path(__file__).parent.resolve()
   gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
 
   kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
-
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
   print("comparing with goldens..")
   assert torch.allclose(
-      gemma_goldens,
+      gemma_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
   )
 
 
 if __name__ == "__main__":
-
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/gemma-2b"
+  )
+  define_and_run_2b(input_checkpoint_path)
```
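The `Gemma.forward` change above is the pattern this release applies across the generative examples: the KV cache is no longer hidden model state but an explicit input and output, returned in a dict alongside the logits. A minimal sketch of driving a model under the new contract (the checkpoint path and decode budget are illustrative):

```python
import torch

from ai_edge_torch.generative.examples.gemma import gemma
from ai_edge_torch.generative.layers import kv_cache as kv_utils

# Illustrative checkpoint path.
model = gemma.build_2b_model("path/to/gemma-2b", kv_cache_max_len=1024)

# Prefill: run the whole prompt once; the updated cache comes back in the
# output dict rather than mutating state on the model.
prompt = torch.tensor([[1, 2, 3, 4]], dtype=torch.long)
kv = kv_utils.KVCache.from_model_config(model.config)
out = model.forward(prompt, torch.arange(0, prompt.shape[1]), kv)

# Decode: feed one token at a time, threading the returned cache back in.
token = out["logits"][:, -1:].argmax(dim=-1)
for pos in range(prompt.shape[1], 16):  # illustrative decode budget
  out = model.forward(token, torch.tensor([pos]), out["kv_cache"])
  token = out["logits"][:, -1:].argmax(dim=-1)
```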
ai_edge_torch/generative/examples/gemma/gemma2.py

```diff
@@ -12,14 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+
+"""Example of building a Gemma2 model."""
 
 import os
-
+import pathlib
 from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -51,7 +53,8 @@ class Gemma2Block(attention.TransformerBlock):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
-
+      kv_cache: kv_utils.KVCacheEntry = None,
+  ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
     """Forward function of the Gemma2Block.
 
     Exactly the same as TransformerBlock but we call the post-attention norm
@@ -62,17 +65,19 @@ class Gemma2Block(attention.TransformerBlock):
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): the optional kv cache entry.
 
     Returns:
-      output activation from this transformer block
+      output activation from this transformer block, and updated kv cache (if
+      passed in).
     """
 
     x_norm = self.pre_atten_norm(x)
-    attn_out = self.atten_func(x_norm, rope, mask, input_pos)
+    attn_out, kv = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
     attn_out_norm = self.post_atten_norm(attn_out)
     x = x + attn_out_norm
     output = x + self.ff(x)
-    return output
+    return output, kv
 
 
 class Gemma2(nn.Module):
@@ -138,24 +143,38 @@ class Gemma2(nn.Module):
     return self.mask_cache.index_select(2, input_pos)
 
   @torch.inference_mode
-  def forward(
-
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
     sin = sin.index_select(0, input_pos)
 
     # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(
+    x = self.tok_embedding(tokens)
     x = x * (self.config.embedding_dim**0.5)
 
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
       mask = self.get_attention_mask(i, input_pos)
-
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     x = self.final_norm(x)
     res = self.lm_head(x)  # (b, t, vocab_size)
@@ -163,7 +182,8 @@ class Gemma2(nn.Module):
       res = res / self.config.final_logit_softcap
       res = torch.tanh(res)
       res = res * self.config.final_logit_softcap
-
+
+    return {"logits": res, "kv_cache": updated_kv_cache}
 
 
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -243,14 +263,13 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model
 
 
-def define_and_run_2b() -> None:
+def define_and_run_2b(checkpoint_path: str) -> None:
   """Instantiates and runs a Gemma2 2B model."""
 
-  current_dir = Path(__file__).parent.resolve()
+  current_dir = pathlib.Path(__file__).parent.resolve()
   gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
   print("Running GEMMA 2")
   kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma2-2b")
   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   toks = torch.from_numpy(
       np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
@@ -258,11 +277,13 @@ def define_and_run_2b() -> None:
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :9] = toks
   input_pos = torch.arange(0, kv_cache_max_len)
-
-
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  out = model.forward(tokens, input_pos, kv)
+  out_final = out["logits"][0, 8, :]
   assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)
 
 
 if __name__ == "__main__":
   torch.set_printoptions(sci_mode=True)
-
+  path = os.path.join(pathlib.Path.home(), "Downloads/llm_data/gemma2-2b")
+  define_and_run_2b(path)
```
ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py

```diff
@@ -12,16 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-
-# Please use with caution.
+
+"""Example of converting a Phi-2 model to multi-signature tflite model."""
 
 import os
-
+import pathlib
 
 import ai_edge_torch
-from ai_edge_torch.generative.examples.
-from ai_edge_torch.generative.layers
+from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
@@ -32,9 +31,8 @@ def convert_phi2_to_tflite(
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """
+  """Converts a Phi-2 model to multi-signature tflite model.
 
-  tflite model.
   Args:
     checkpoint_path (str): The filepath to the model checkpoint, or directory
       holding the checkpoint.
@@ -53,7 +51,7 @@ def convert_phi2_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
-  kv =
+  kv = kv_cache.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
@@ -77,11 +75,12 @@ def convert_phi2_to_tflite(
       )
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/
+      f'/tmp/phi2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )
 
 
 if __name__ == '__main__':
-
-  convert_phi2_to_tflite(
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2')
+  convert_phi2_to_tflite(path)
```
ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py

```diff
@@ -12,26 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-
-# Note: This is an experimental version of phi2 with external KV cache.
-# Please use with caution.
+
+"""Example of building a Phi-2 model."""
 
 import os
-
-from typing import Tuple
+import pathlib
 
+from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-from ai_edge_torch.generative.layers.experimental import attention
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
 from torch import nn
 
-
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.fc1",
     ff_down_proj="model.layers.{}.mlp.fc2",
@@ -89,13 +85,17 @@ class Phi2(nn.Module):
       self,
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
-      kv_cache: kv_utils.
-  ) ->
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
     _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
@@ -111,11 +111,11 @@ class Phi2(nn.Module):
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
         updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     x = self.final_norm(x)
-
-    return
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -169,39 +169,37 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-    checkpoint_path: str, test_model: bool = False, **kwargs
-) -> nn.Module:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   """Instantiates the model instance and load checkpoint if provided."""
-  config = (
-      get_fake_model_config(**kwargs)
-      if test_model
-      else get_model_config(**kwargs)
-  )
+  config = get_model_config(**kwargs)
   model = Phi2(config)
-
-
-  loader.load(model)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  loader.load(model)
   model.eval()
   return model
 
 
-def define_and_run(checkpoint_path: str
+def define_and_run(checkpoint_path: str) -> None:
   """Instantiates and runs a Phi-2 model."""
 
+  current_dir = pathlib.Path(__file__).parent.resolve()
+  phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
   kv_cache_max_len = 1024
-  model = build_model(
-      checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
-  )
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
-  kv = kv_utils.
-
-  print(
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
+  print("comparing with goldens..")
+  assert torch.allclose(
+      phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
+  )
 
 
 if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/phi2"
+  )
   define_and_run(input_checkpoint_path)
```
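One behavioral note from the file above: `build_model` no longer takes a `test_model` flag, so callers always get the real Phi-2 config (`get_fake_model_config` remains in the module for tests that want a small config). A sketch of the updated call, with an illustrative checkpoint path:

```python
from ai_edge_torch.generative.examples.phi import phi2

# test_model= is no longer accepted; pass the checkpoint and config kwargs only.
model = phi2.build_model("path/to/phi2", kv_cache_max_len=1024)
```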
ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py

```diff
@@ -12,30 +12,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-#
-# Note: This is an experimental version of Gemma with external KV cache.
-# Please use with caution.
 
+"""Example of converting SmalLM model to multi-signature tflite model."""
 
 import os
-
+import pathlib
 
 import ai_edge_torch
-from ai_edge_torch.generative.examples.
-from ai_edge_torch.generative.layers
+from ai_edge_torch.generative.examples.smallm import smallm
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
 
-def
+def convert_smallm_to_tflite(
     checkpoint_path: str,
     prefill_seq_len: int = 512,
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """
+  """Converts SmalLM model to multi-signature tflite model.
 
-  tflite model.
   Args:
     checkpoint_path (str): The filepath to the model checkpoint, or directory
       holding the checkpoint.
@@ -46,7 +43,7 @@ def convert_gemma_to_tflite(
     quantize (bool, optional): Whether the model should be quanized. Defaults
       to True.
   """
-  pytorch_model =
+  pytorch_model = smallm.build_model(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
@@ -54,7 +51,7 @@ def convert_gemma_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
-  kv = kv_utils.
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
@@ -78,11 +75,12 @@ def convert_gemma_to_tflite(
       )
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/
+      f'/tmp/smallm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )
 
 
 if __name__ == '__main__':
-
-
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smallm')
+  convert_smallm_to_tflite(path)
```