ai-edge-torch-nightly 0.3.0.dev20240909__py3-none-any.whl → 0.3.0.dev20240911__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/test/test_convert.py +35 -13
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
- ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +38 -22
- ai_edge_torch/generative/layers/attention.py +60 -63
- ai_edge_torch/generative/layers/kv_cache.py +160 -51
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
- ai_edge_torch/generative/test/test_model_conversion.py +71 -33
- ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/odml_torch/lowerings/_convolution.py +196 -74
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/RECORD +25 -35
- ai_edge_torch/generative/examples/experimental/gemma/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +0 -88
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/test/test_convert.py

```diff
@@ -25,6 +25,7 @@ from ai_edge_torch.testing import model_coverage
 import numpy as np
 import tensorflow as tf
 import torch
+from torch import nn
 import torchvision

 from absl.testing import absltest as googletest
```

```diff
@@ -51,7 +52,7 @@ class TestConvert(googletest.TestCase):
   def test_convert_add(self):
     """Tests conversion of a simple Add module."""

-    class Add(
+    class Add(nn.Module):

       def forward(self, a, b):
         return a + b
```

```diff
@@ -70,7 +71,7 @@ class TestConvert(googletest.TestCase):
   def test_convert_dot_add(self):
     """Tests conversion of a matrix multiplication followed by an add."""

-    class DotAdd(
+    class DotAdd(nn.Module):

       def forward(self, a, b, c):
         return a @ b + c
```

```diff
@@ -99,7 +100,7 @@ class TestConvert(googletest.TestCase):
   def test_signature_args_ordering(self):
     """Tests conversion of a model with more than 10 arguments."""

-    class AddChainWith11Args(
+    class AddChainWith11Args(nn.Module):
       """A model with 11 arguments."""

       def forward(
```

```diff
@@ -152,7 +153,7 @@ class TestConvert(googletest.TestCase):
   def test_multi_output_model(self):
     """Tests conversion of a model that returns multiple outputs."""

-    class BasicAddModelWithMultipleOutputs(
+    class BasicAddModelWithMultipleOutputs(nn.Module):
       """A model that returns multiple outputs."""

       def forward(self, arg0, arg1):
```

```diff
@@ -176,7 +177,7 @@ class TestConvert(googletest.TestCase):
   def test_12_outputs_model(self):
     """Tests conversion of a model that returns more than 10 outputs."""

-    class BasicAddModelWithMultipleOutputs(
+    class BasicAddModelWithMultipleOutputs(nn.Module):
       """A model that returns multiple outputs."""

       def forward(self, arg0, arg1):
```

```diff
@@ -245,7 +246,7 @@ class TestConvert(googletest.TestCase):
   def test_convert_add_converter_flags(self):
     """Tests conversion of an add module setting a tflite converter flag."""

-    class Add(
+    class Add(nn.Module):

       def forward(self, a, b):
         return a + b
```

```diff
@@ -267,6 +268,27 @@ class TestConvert(googletest.TestCase):
     )
     self.assertTrue(os.path.isdir(ir_dump_path))

+  def test_convert_conv_transpose_batch_norm(self):
+    """Tests conversion of a model with ConvTranspose2d and BatchNorm2d."""
+
+    channels = 2
+    size = 2
+    torch_model = nn.Sequential(
+        nn.ConvTranspose2d(
+            channels, channels, 1, stride=2, dilation=1, bias=False
+        ),
+        nn.BatchNorm2d(channels),
+    )
+
+    torch_model.eval()
+    sample_input = (torch.rand(1, channels, size, size),)
+    edge_model = ai_edge_torch.convert(torch_model, sample_input)
+
+    result = model_coverage.compare_tflite_torch(
+        edge_model, torch_model, sample_input
+    )
+    self.assertTrue(result)
+
   @googletest.skipIf(
       not config.Config.use_torch_xla,
       reason="Shape polymorphism is not yet support with odml_torch.",
```

```diff
@@ -274,7 +296,7 @@ class TestConvert(googletest.TestCase):
   def test_convert_model_with_dynamic_batch(self):
     """Test converting a simple model with dynamic batch size."""

-    class SampleModel(
+    class SampleModel(nn.Module):

       def __init__(self):
         super().__init__()
```

```diff
@@ -304,7 +326,7 @@ class TestConvert(googletest.TestCase):
   def test_convert_model_with_kwargs(self):
     """Test converting a simple model with sample_kwargs."""

-    class SampleModel(
+    class SampleModel(nn.Module):

       def forward(self, x, y):
         return x + y
```

```diff
@@ -323,7 +345,7 @@ class TestConvert(googletest.TestCase):
   def test_convert_model_with_args_kwargs(self):
     """Test converting a simple model with both sample_args and sample_kwargs."""

-    class SampleModel(
+    class SampleModel(nn.Module):

       def forward(self, x, y):
         return x + y
```

```diff
@@ -343,7 +365,7 @@ class TestConvert(googletest.TestCase):
   def test_convert_model_with_args_nested_kwargs_1(self):
     """Test converting a simple model with both sample_args and nested sample_kwargs."""

-    class SampleModel(
+    class SampleModel(nn.Module):

       def forward(self, x: torch.Tensor, y: torch.Tensor, z: TestContainer1):
         return x + y + z.data_1 + z.data_2[0] + z.data_2[1]
```

```diff
@@ -370,7 +392,7 @@ class TestConvert(googletest.TestCase):
   def test_convert_model_with_args_nested_kwargs_2(self):
     """Test converting a simple model with both sample_args and nested sample_kwargs."""

-    class SampleModel(
+    class SampleModel(nn.Module):

       def forward(self, x, y, z):
         return x + y + z.data_1 + z.data_2[0][0] + z.data_2[1]
```

```diff
@@ -397,7 +419,7 @@ class TestConvert(googletest.TestCase):
   def test_convert_model_with_args_nested_kwargs_3(self):
     """Test converting a simple model with both sample_args and nested sample_kwargs."""

-    class SampleModel(
+    class SampleModel(nn.Module):

       def forward(self, x, y, z):
         return x + y + z.data_1 + z.data_2[0]["foo"] + z.data_2[1]
```

```diff
@@ -424,7 +446,7 @@ class TestConvert(googletest.TestCase):
   def test_convert_model_non_flat_output_dict(self):
     """Test converting a model with non-flat output structure."""

-    class SampleModel(
+    class SampleModel(nn.Module):

       def forward(self, x, y, z):
         return {"x": x, "y": TestContainer1(data_1=y, data_2=[y, z])}
```

ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py

```diff
@@ -13,32 +13,35 @@
 # limitations under the License.
 # ==============================================================================

+"""Example of converting a Gemma2 model to multi-signature tflite model."""
+
 import os
-
+import pathlib

 import ai_edge_torch
 from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch


-def
+def convert_gemma2_to_tflite(
     checkpoint_path: str,
     prefill_seq_len: int = 512,
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """
-  tflite model.
+  """Converts a Gemma2 2B model to multi-signature tflite model.

   Args:
-      checkpoint_path (str): The filepath to the model checkpoint, or directory
+      checkpoint_path (str): The filepath to the model checkpoint, or directory
+        holding the checkpoint.
       prefill_seq_len (int, optional): The maximum size of prefill input tensor.
         Defaults to 512.
       kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
         including both prefill and decode. Defaults to 1024.
-      quantize (bool, optional): Whether the model should be quanized.
-
+      quantize (bool, optional): Whether the model should be quanized. Defaults
+        to True.
   """
   pytorch_model = gemma2.build_2b_model(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
```

```diff
@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
       ai_edge_torch.signature(
-          'prefill',
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
       )
-      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/
+      f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )


 if __name__ == '__main__':
-
-
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')
+  convert_gemma2_to_tflite(path)
```
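
With the defaults above (`quantize=True`, `prefill_seq_len=512`, `kv_cache_max_len=1024`), the export path resolves to `/tmp/gemma2_q8_seq512_ekv1024.tflite`. A hedged sketch for inspecting the two registered signatures with the stock TensorFlow Lite interpreter; the `KVCache` argument appears to be flattened into individual tensors during conversion, so the kv tensor names are best read from the runner's input details rather than assumed:

```python
import tensorflow as tf

# Path assumes the default arguments of convert_gemma2_to_tflite above.
interpreter = tf.lite.Interpreter(
    model_path="/tmp/gemma2_q8_seq512_ekv1024.tflite"
)

# The two entries registered by the chained .signature(...) calls.
print(interpreter.get_signature_list())  # expected: 'prefill' and 'decode'

# Input names and shapes, including the flattened kv-cache tensors.
prefill = interpreter.get_signature_runner("prefill")
for name, detail in prefill.get_input_details().items():
  print(name, detail["shape"], detail["dtype"])
```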

ai_edge_torch/generative/examples/gemma/convert_to_tflite.py

```diff
@@ -13,11 +13,14 @@
 # limitations under the License.
 # ==============================================================================

+"""Example of converting a Gemma model to multi-signature tflite model."""
+
 import os
-
+import pathlib

 import ai_edge_torch
 from ai_edge_torch.generative.examples.gemma import gemma
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch

```

```diff
@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
       ai_edge_torch.signature(
-          'prefill',
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
       )
-      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/
+      f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )


 if __name__ == '__main__':
-
-  convert_gemma_to_tflite(
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
+  convert_gemma_to_tflite(path)
```

ai_edge_torch/generative/examples/gemma/gemma.py

```diff
@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+
+"""Example of building a Gemma model."""

 import os
-
+import pathlib

 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
```

```diff
@@ -84,16 +86,22 @@ class Gemma(nn.Module):
     )
     self.config = config

-  # The model's forward function takes in additional k/v cache tensors
-  # and returns the updated k/v cache tensors to the caller.
-  # This can be eliminated if we handle k/v cache updates inside the model itself.
   @torch.inference_mode
-  def forward(
-
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )

     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
```

```diff
@@ -102,15 +110,20 @@ class Gemma(nn.Module):
     mask = mask[:, :, :, : self.config.kv_cache_max]

     # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(
+    x = self.tok_embedding(tokens)
     x = x * (self.config.embedding_dim**0.5)

-
-
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

     x = self.final_norm(x)
-
-    return
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}


 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
```

```diff
@@ -177,25 +190,28 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model


-def define_and_run_2b() -> None:
+def define_and_run_2b(checkpoint_path: str) -> None:
   """Instantiates and runs a Gemma 2B model."""

-  current_dir = Path(__file__).parent.resolve()
+  current_dir = pathlib.Path(__file__).parent.resolve()
   gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")

   kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
-
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
   print("comparing with goldens..")
   assert torch.allclose(
-      gemma_goldens,
+      gemma_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
   )


 if __name__ == "__main__":
-
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/gemma-2b"
+  )
+  define_and_run_2b(input_checkpoint_path)
```
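
Taken together, the gemma.py hunks above replace module-internal k/v buffers with a functional cache: `forward` now takes a `KVCache`, rebuilds it from the per-block entries, and returns it next to the logits. A minimal greedy-decode sketch of that calling convention; the checkpoint path, prompt ids, and the eight-step loop are illustrative assumptions, while `build_2b_model`, `KVCache.from_model_config`, and the `forward` contract come from the diff:

```python
import torch

from ai_edge_torch.generative.examples.gemma import gemma
from ai_edge_torch.generative.layers import kv_cache as kv_utils

# Assumption: a Gemma 2B checkpoint is available locally (path is illustrative).
model = gemma.build_2b_model("/path/to/gemma-2b", kv_cache_max_len=1024)

prompt = torch.tensor([[1, 2, 3, 4]], dtype=torch.long)  # illustrative ids
kv = kv_utils.KVCache.from_model_config(model.config)

# Prefill: one call over the whole prompt. The updated cache comes back in
# the output dict instead of being stored on the module.
out = model.forward(prompt, torch.arange(0, prompt.shape[1]), kv)
kv = out["kv_cache"]
next_tok = out["logits"][:, -1, :].argmax(dim=-1, keepdim=True)

# Decode: one token per step, threading the cache through every call.
for pos in range(prompt.shape[1], prompt.shape[1] + 8):
  out = model.forward(next_tok, torch.tensor([pos]), kv)
  kv = out["kv_cache"]
  next_tok = out["logits"][:, -1, :].argmax(dim=-1, keepdim=True)
```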

ai_edge_torch/generative/examples/gemma/gemma2.py

```diff
@@ -12,14 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+
+"""Example of building a Gemma2 model."""

 import os
-
+import pathlib
 from typing import Optional, Tuple

 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
```

```diff
@@ -51,7 +53,8 @@ class Gemma2Block(attention.TransformerBlock):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
-
+      kv_cache: kv_utils.KVCacheEntry = None,
+  ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
     """Forward function of the Gemma2Block.

     Exactly the same as TransformerBlock but we call the post-attention norm
```

```diff
@@ -62,17 +65,19 @@ class Gemma2Block(attention.TransformerBlock):
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): the optional kv cache entry.

     Returns:
-      output activation from this transformer block
+      output activation from this transformer block, and updated kv cache (if
+      passed in).
     """

     x_norm = self.pre_atten_norm(x)
-    attn_out = self.atten_func(x_norm, rope, mask, input_pos)
+    attn_out, kv = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
     attn_out_norm = self.post_atten_norm(attn_out)
     x = x + attn_out_norm
     output = x + self.ff(x)
-    return output
+    return output, kv


 class Gemma2(nn.Module):
```

```diff
@@ -138,24 +143,38 @@ class Gemma2(nn.Module):
     return self.mask_cache.index_select(2, input_pos)

   @torch.inference_mode
-  def forward(
-
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )

     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
     sin = sin.index_select(0, input_pos)

     # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(
+    x = self.tok_embedding(tokens)
     x = x * (self.config.embedding_dim**0.5)

+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
       mask = self.get_attention_mask(i, input_pos)
-
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

     x = self.final_norm(x)
     res = self.lm_head(x)  # (b, t, vocab_size)
```

```diff
@@ -163,7 +182,8 @@ class Gemma2(nn.Module):
       res = res / self.config.final_logit_softcap
       res = torch.tanh(res)
       res = res * self.config.final_logit_softcap
-
+
+    return {"logits": res, "kv_cache": updated_kv_cache}


 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
```
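
A note for readers skimming the hunk above: the three `res` lines are Gemma 2's final logit soft-capping, which bounds the logits to (-final_logit_softcap, +final_logit_softcap) while staying roughly linear near zero. A standalone equivalent of that computation:

```python
import torch


def soft_cap(logits: torch.Tensor, cap: float) -> torch.Tensor:
  # Same three steps as the hunk: divide by the cap, tanh, rescale.
  return cap * torch.tanh(logits / cap)
```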

```diff
@@ -243,14 +263,13 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model


-def define_and_run_2b() -> None:
+def define_and_run_2b(checkpoint_path: str) -> None:
   """Instantiates and runs a Gemma2 2B model."""

-  current_dir = Path(__file__).parent.resolve()
+  current_dir = pathlib.Path(__file__).parent.resolve()
   gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
   print("Running GEMMA 2")
   kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma2-2b")
   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   toks = torch.from_numpy(
       np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
```

```diff
@@ -258,11 +277,13 @@ def define_and_run_2b() -> None:
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :9] = toks
   input_pos = torch.arange(0, kv_cache_max_len)
-
-
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  out = model.forward(tokens, input_pos, kv)
+  out_final = out["logits"][0, 8, :]
   assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)


 if __name__ == "__main__":
   torch.set_printoptions(sci_mode=True)
-
+  path = os.path.join(pathlib.Path.home(), "Downloads/llm_data/gemma2-2b")
+  define_and_run_2b(path)
```

ai_edge_torch/generative/examples/phi/convert_to_tflite.py

```diff
@@ -12,16 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-
-# Please use with caution.
+
+"""Example of converting a Phi-2 model to multi-signature tflite model."""

 import os
-
+import pathlib

 import ai_edge_torch
-from ai_edge_torch.generative.examples.
-from ai_edge_torch.generative.layers
+from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch

```
|
@@ -32,9 +31,8 @@ def convert_phi2_to_tflite(
|
|
32
31
|
kv_cache_max_len: int = 1024,
|
33
32
|
quantize: bool = True,
|
34
33
|
):
|
35
|
-
"""
|
34
|
+
"""Converts a Phi-2 model to multi-signature tflite model.
|
36
35
|
|
37
|
-
tflite model.
|
38
36
|
Args:
|
39
37
|
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
40
38
|
holding the checkpoint.
|

```diff
@@ -53,7 +51,7 @@ def convert_phi2_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
-  kv =
+  kv = kv_cache.KVCache.from_model_config(pytorch_model.config)

   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
```

```diff
@@ -77,11 +75,12 @@ def convert_phi2_to_tflite(
       )
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/
+      f'/tmp/phi2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )


 if __name__ == '__main__':
-
-  convert_phi2_to_tflite(
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2')
+  convert_phi2_to_tflite(path)
```