ai-edge-torch-nightly 0.3.0.dev20240909__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/_convert/test/test_convert.py +35 -13
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
- ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
- ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
- ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
- ai_edge_torch/generative/examples/t5/t5.py +35 -22
- ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
- ai_edge_torch/generative/layers/attention.py +77 -73
- ai_edge_torch/generative/layers/builder.py +5 -3
- ai_edge_torch/generative/layers/kv_cache.py +163 -51
- ai_edge_torch/generative/layers/model_config.py +38 -19
- ai_edge_torch/generative/layers/normalization.py +158 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +72 -34
- ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/generative/utilities/loader.py +15 -15
- ai_edge_torch/generative/utilities/t5_loader.py +21 -20
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_convolution.py +196 -74
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +41 -47
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0
@@ -25,6 +25,7 @@ from ai_edge_torch.testing import model_coverage
|
|
25
25
|
import numpy as np
|
26
26
|
import tensorflow as tf
|
27
27
|
import torch
|
28
|
+
from torch import nn
|
28
29
|
import torchvision
|
29
30
|
|
30
31
|
from absl.testing import absltest as googletest
|
@@ -51,7 +52,7 @@ class TestConvert(googletest.TestCase):
|
|
51
52
|
def test_convert_add(self):
|
52
53
|
"""Tests conversion of a simple Add module."""
|
53
54
|
|
54
|
-
class Add(
|
55
|
+
class Add(nn.Module):
|
55
56
|
|
56
57
|
def forward(self, a, b):
|
57
58
|
return a + b
|
@@ -70,7 +71,7 @@ class TestConvert(googletest.TestCase):
|
|
70
71
|
def test_convert_dot_add(self):
|
71
72
|
"""Tests conversion of a matrix multiplication followed by an add."""
|
72
73
|
|
73
|
-
class DotAdd(
|
74
|
+
class DotAdd(nn.Module):
|
74
75
|
|
75
76
|
def forward(self, a, b, c):
|
76
77
|
return a @ b + c
|
@@ -99,7 +100,7 @@ class TestConvert(googletest.TestCase):
|
|
99
100
|
def test_signature_args_ordering(self):
|
100
101
|
"""Tests conversion of a model with more than 10 arguments."""
|
101
102
|
|
102
|
-
class AddChainWith11Args(
|
103
|
+
class AddChainWith11Args(nn.Module):
|
103
104
|
"""A model with 11 arguments."""
|
104
105
|
|
105
106
|
def forward(
|
@@ -152,7 +153,7 @@ class TestConvert(googletest.TestCase):
|
|
152
153
|
def test_multi_output_model(self):
|
153
154
|
"""Tests conversion of a model that returns multiple outputs."""
|
154
155
|
|
155
|
-
class BasicAddModelWithMultipleOutputs(
|
156
|
+
class BasicAddModelWithMultipleOutputs(nn.Module):
|
156
157
|
"""A model that returns multiple outputs."""
|
157
158
|
|
158
159
|
def forward(self, arg0, arg1):
|
@@ -176,7 +177,7 @@ class TestConvert(googletest.TestCase):
|
|
176
177
|
def test_12_outputs_model(self):
|
177
178
|
"""Tests conversion of a model that returns more than 10 outputs."""
|
178
179
|
|
179
|
-
class BasicAddModelWithMultipleOutputs(
|
180
|
+
class BasicAddModelWithMultipleOutputs(nn.Module):
|
180
181
|
"""A model that returns multiple outputs."""
|
181
182
|
|
182
183
|
def forward(self, arg0, arg1):
|
@@ -245,7 +246,7 @@ class TestConvert(googletest.TestCase):
|
|
245
246
|
def test_convert_add_converter_flags(self):
|
246
247
|
"""Tests conversion of an add module setting a tflite converter flag."""
|
247
248
|
|
248
|
-
class Add(
|
249
|
+
class Add(nn.Module):
|
249
250
|
|
250
251
|
def forward(self, a, b):
|
251
252
|
return a + b
|
@@ -267,6 +268,27 @@ class TestConvert(googletest.TestCase):
|
|
267
268
|
)
|
268
269
|
self.assertTrue(os.path.isdir(ir_dump_path))
|
269
270
|
|
271
|
+
def test_convert_conv_transpose_batch_norm(self):
|
272
|
+
"""Tests conversion of a model with ConvTranspose2d and BatchNorm2d."""
|
273
|
+
|
274
|
+
channels = 2
|
275
|
+
size = 2
|
276
|
+
torch_model = nn.Sequential(
|
277
|
+
nn.ConvTranspose2d(
|
278
|
+
channels, channels, 1, stride=2, dilation=1, bias=False
|
279
|
+
),
|
280
|
+
nn.BatchNorm2d(channels),
|
281
|
+
)
|
282
|
+
|
283
|
+
torch_model.eval()
|
284
|
+
sample_input = (torch.rand(1, channels, size, size),)
|
285
|
+
edge_model = ai_edge_torch.convert(torch_model, sample_input)
|
286
|
+
|
287
|
+
result = model_coverage.compare_tflite_torch(
|
288
|
+
edge_model, torch_model, sample_input
|
289
|
+
)
|
290
|
+
self.assertTrue(result)
|
291
|
+
|
270
292
|
@googletest.skipIf(
|
271
293
|
not config.Config.use_torch_xla,
|
272
294
|
reason="Shape polymorphism is not yet support with odml_torch.",
|
@@ -274,7 +296,7 @@ class TestConvert(googletest.TestCase):
|
|
274
296
|
def test_convert_model_with_dynamic_batch(self):
|
275
297
|
"""Test converting a simple model with dynamic batch size."""
|
276
298
|
|
277
|
-
class SampleModel(
|
299
|
+
class SampleModel(nn.Module):
|
278
300
|
|
279
301
|
def __init__(self):
|
280
302
|
super().__init__()
|
@@ -304,7 +326,7 @@ class TestConvert(googletest.TestCase):
|
|
304
326
|
def test_convert_model_with_kwargs(self):
|
305
327
|
"""Test converting a simple model with sample_kwargs."""
|
306
328
|
|
307
|
-
class SampleModel(
|
329
|
+
class SampleModel(nn.Module):
|
308
330
|
|
309
331
|
def forward(self, x, y):
|
310
332
|
return x + y
|
@@ -323,7 +345,7 @@ class TestConvert(googletest.TestCase):
|
|
323
345
|
def test_convert_model_with_args_kwargs(self):
|
324
346
|
"""Test converting a simple model with both sample_args and sample_kwargs."""
|
325
347
|
|
326
|
-
class SampleModel(
|
348
|
+
class SampleModel(nn.Module):
|
327
349
|
|
328
350
|
def forward(self, x, y):
|
329
351
|
return x + y
|
@@ -343,7 +365,7 @@ class TestConvert(googletest.TestCase):
|
|
343
365
|
def test_convert_model_with_args_nested_kwargs_1(self):
|
344
366
|
"""Test converting a simple model with both sample_args and nested sample_kwargs."""
|
345
367
|
|
346
|
-
class SampleModel(
|
368
|
+
class SampleModel(nn.Module):
|
347
369
|
|
348
370
|
def forward(self, x: torch.Tensor, y: torch.Tensor, z: TestContainer1):
|
349
371
|
return x + y + z.data_1 + z.data_2[0] + z.data_2[1]
|
@@ -370,7 +392,7 @@ class TestConvert(googletest.TestCase):
|
|
370
392
|
def test_convert_model_with_args_nested_kwargs_2(self):
|
371
393
|
"""Test converting a simple model with both sample_args and nested sample_kwargs."""
|
372
394
|
|
373
|
-
class SampleModel(
|
395
|
+
class SampleModel(nn.Module):
|
374
396
|
|
375
397
|
def forward(self, x, y, z):
|
376
398
|
return x + y + z.data_1 + z.data_2[0][0] + z.data_2[1]
|
@@ -397,7 +419,7 @@ class TestConvert(googletest.TestCase):
|
|
397
419
|
def test_convert_model_with_args_nested_kwargs_3(self):
|
398
420
|
"""Test converting a simple model with both sample_args and nested sample_kwargs."""
|
399
421
|
|
400
|
-
class SampleModel(
|
422
|
+
class SampleModel(nn.Module):
|
401
423
|
|
402
424
|
def forward(self, x, y, z):
|
403
425
|
return x + y + z.data_1 + z.data_2[0]["foo"] + z.data_2[1]
|
@@ -424,7 +446,7 @@ class TestConvert(googletest.TestCase):
|
|
424
446
|
def test_convert_model_non_flat_output_dict(self):
|
425
447
|
"""Test converting a model with non-flat output structure."""
|
426
448
|
|
427
|
-
class SampleModel(
|
449
|
+
class SampleModel(nn.Module):
|
428
450
|
|
429
451
|
def forward(self, x, y, z):
|
430
452
|
return {"x": x, "y": TestContainer1(data_1=y, data_2=[y, z])}
|
@@ -13,32 +13,35 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
"""Example of converting a Gemma2 model to multi-signature tflite model."""
|
17
|
+
|
16
18
|
import os
|
17
|
-
|
19
|
+
import pathlib
|
18
20
|
|
19
21
|
import ai_edge_torch
|
20
22
|
from ai_edge_torch.generative.examples.gemma import gemma2
|
23
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
21
24
|
from ai_edge_torch.generative.quantize import quant_recipes
|
22
25
|
import torch
|
23
26
|
|
24
27
|
|
25
|
-
def
|
28
|
+
def convert_gemma2_to_tflite(
|
26
29
|
checkpoint_path: str,
|
27
30
|
prefill_seq_len: int = 512,
|
28
31
|
kv_cache_max_len: int = 1024,
|
29
32
|
quantize: bool = True,
|
30
33
|
):
|
31
|
-
"""
|
32
|
-
tflite model.
|
34
|
+
"""Converts a Gemma2 2B model to multi-signature tflite model.
|
33
35
|
|
34
36
|
Args:
|
35
|
-
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
37
|
+
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
38
|
+
holding the checkpoint.
|
36
39
|
prefill_seq_len (int, optional): The maximum size of prefill input tensor.
|
37
40
|
Defaults to 512.
|
38
41
|
kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
|
39
42
|
including both prefill and decode. Defaults to 1024.
|
40
|
-
quantize (bool, optional): Whether the model should be quanized.
|
41
|
-
|
43
|
+
quantize (bool, optional): Whether the model should be quanized. Defaults
|
44
|
+
to True.
|
42
45
|
"""
|
43
46
|
pytorch_model = gemma2.build_2b_model(
|
44
47
|
checkpoint_path, kv_cache_max_len=kv_cache_max_len
|
@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
|
|
48
51
|
prefill_input_pos = torch.arange(0, prefill_seq_len)
|
49
52
|
decode_token = torch.tensor([[0]], dtype=torch.long)
|
50
53
|
decode_input_pos = torch.tensor([0], dtype=torch.int64)
|
54
|
+
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
51
55
|
|
52
56
|
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
53
57
|
edge_model = (
|
54
58
|
ai_edge_torch.signature(
|
55
|
-
'prefill',
|
59
|
+
'prefill',
|
60
|
+
pytorch_model,
|
61
|
+
sample_kwargs={
|
62
|
+
'tokens': prefill_tokens,
|
63
|
+
'input_pos': prefill_input_pos,
|
64
|
+
'kv_cache': kv,
|
65
|
+
},
|
66
|
+
)
|
67
|
+
.signature(
|
68
|
+
'decode',
|
69
|
+
pytorch_model,
|
70
|
+
sample_kwargs={
|
71
|
+
'tokens': decode_token,
|
72
|
+
'input_pos': decode_input_pos,
|
73
|
+
'kv_cache': kv,
|
74
|
+
},
|
56
75
|
)
|
57
|
-
.signature('decode', pytorch_model, (decode_token, decode_input_pos))
|
58
76
|
.convert(quant_config=quant_config)
|
59
77
|
)
|
78
|
+
quant_suffix = 'q8' if quantize else 'f32'
|
60
79
|
edge_model.export(
|
61
|
-
f'/tmp/
|
80
|
+
f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
|
62
81
|
)
|
63
82
|
|
64
83
|
|
65
84
|
if __name__ == '__main__':
|
66
|
-
|
67
|
-
|
85
|
+
path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')
|
86
|
+
convert_gemma2_to_tflite(path)
|
@@ -13,11 +13,14 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
"""Example of converting a Gemma model to multi-signature tflite model."""
|
17
|
+
|
16
18
|
import os
|
17
|
-
|
19
|
+
import pathlib
|
18
20
|
|
19
21
|
import ai_edge_torch
|
20
22
|
from ai_edge_torch.generative.examples.gemma import gemma
|
23
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
21
24
|
from ai_edge_torch.generative.quantize import quant_recipes
|
22
25
|
import torch
|
23
26
|
|
@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
|
|
48
51
|
prefill_input_pos = torch.arange(0, prefill_seq_len)
|
49
52
|
decode_token = torch.tensor([[0]], dtype=torch.long)
|
50
53
|
decode_input_pos = torch.tensor([0], dtype=torch.int64)
|
54
|
+
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
51
55
|
|
52
56
|
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
53
57
|
edge_model = (
|
54
58
|
ai_edge_torch.signature(
|
55
|
-
'prefill',
|
59
|
+
'prefill',
|
60
|
+
pytorch_model,
|
61
|
+
sample_kwargs={
|
62
|
+
'tokens': prefill_tokens,
|
63
|
+
'input_pos': prefill_input_pos,
|
64
|
+
'kv_cache': kv,
|
65
|
+
},
|
66
|
+
)
|
67
|
+
.signature(
|
68
|
+
'decode',
|
69
|
+
pytorch_model,
|
70
|
+
sample_kwargs={
|
71
|
+
'tokens': decode_token,
|
72
|
+
'input_pos': decode_input_pos,
|
73
|
+
'kv_cache': kv,
|
74
|
+
},
|
56
75
|
)
|
57
|
-
.signature('decode', pytorch_model, (decode_token, decode_input_pos))
|
58
76
|
.convert(quant_config=quant_config)
|
59
77
|
)
|
78
|
+
quant_suffix = 'q8' if quantize else 'f32'
|
60
79
|
edge_model.export(
|
61
|
-
f'/tmp/
|
80
|
+
f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
|
62
81
|
)
|
63
82
|
|
64
83
|
|
65
84
|
if __name__ == '__main__':
|
66
|
-
|
67
|
-
convert_gemma_to_tflite(
|
85
|
+
path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
|
86
|
+
convert_gemma_to_tflite(path)
|
@@ -12,13 +12,15 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
|
15
|
+
|
16
|
+
"""Example of building a Gemma model."""
|
16
17
|
|
17
18
|
import os
|
18
|
-
|
19
|
+
import pathlib
|
19
20
|
|
20
21
|
from ai_edge_torch.generative.layers import attention
|
21
22
|
from ai_edge_torch.generative.layers import builder
|
23
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
22
24
|
import ai_edge_torch.generative.layers.attention_utils as attn_utils
|
23
25
|
import ai_edge_torch.generative.layers.model_config as cfg
|
24
26
|
import ai_edge_torch.generative.utilities.loader as loading_utils
|
@@ -48,7 +50,6 @@ class Gemma(nn.Module):
|
|
48
50
|
def __init__(self, config: cfg.ModelConfig):
|
49
51
|
super().__init__()
|
50
52
|
|
51
|
-
self.config = config
|
52
53
|
# Construct model layers.
|
53
54
|
self.tok_embedding = nn.Embedding(
|
54
55
|
config.vocab_size, config.embedding_dim, padding_idx=0
|
@@ -60,18 +61,20 @@ class Gemma(nn.Module):
|
|
60
61
|
)
|
61
62
|
# Gemma re-uses the embedding as the head projection layer.
|
62
63
|
self.lm_head.weight.data = self.tok_embedding.weight.data
|
64
|
+
# Gemma has only one block config.
|
65
|
+
block_config = config.block_config(0)
|
63
66
|
self.transformer_blocks = nn.ModuleList(
|
64
|
-
attention.TransformerBlock(
|
67
|
+
attention.TransformerBlock(block_config, config)
|
68
|
+
for _ in range(config.num_layers)
|
65
69
|
)
|
66
70
|
self.final_norm = builder.build_norm(
|
67
71
|
config.embedding_dim,
|
68
72
|
config.final_norm_config,
|
69
73
|
)
|
74
|
+
attn_config = block_config.attn_config
|
70
75
|
self.rope_cache = attn_utils.build_rope_cache(
|
71
76
|
size=config.kv_cache_max,
|
72
|
-
dim=int(
|
73
|
-
config.attn_config.rotary_percentage * config.attn_config.head_dim
|
74
|
-
),
|
77
|
+
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
|
75
78
|
base=10_000,
|
76
79
|
condense_ratio=1,
|
77
80
|
dtype=torch.float32,
|
@@ -84,16 +87,22 @@ class Gemma(nn.Module):
|
|
84
87
|
)
|
85
88
|
self.config = config
|
86
89
|
|
87
|
-
# The model's forward function takes in additional k/v cache tensors
|
88
|
-
# and returns the updated k/v cache tensors to the caller.
|
89
|
-
# This can be eliminated if we handle k/v cache updates inside the model itself.
|
90
90
|
@torch.inference_mode
|
91
|
-
def forward(
|
92
|
-
|
91
|
+
def forward(
|
92
|
+
self,
|
93
|
+
tokens: torch.Tensor,
|
94
|
+
input_pos: torch.Tensor,
|
95
|
+
kv_cache: kv_utils.KVCache,
|
96
|
+
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
97
|
+
_, seq_len = tokens.size()
|
93
98
|
assert self.config.max_seq_len >= seq_len, (
|
94
99
|
f"Cannot forward sequence of length {seq_len}, max seq length is only"
|
95
100
|
f" {self.config.max_seq_len}"
|
96
101
|
)
|
102
|
+
assert len(self.transformer_blocks) == len(kv_cache.caches), (
|
103
|
+
"The number of transformer blocks and the number of KV cache entries"
|
104
|
+
" must be the same."
|
105
|
+
)
|
97
106
|
|
98
107
|
cos, sin = self.rope_cache
|
99
108
|
cos = cos.index_select(0, input_pos)
|
@@ -102,15 +111,20 @@ class Gemma(nn.Module):
|
|
102
111
|
mask = mask[:, :, :, : self.config.kv_cache_max]
|
103
112
|
|
104
113
|
# token embeddings of shape (b, t, n_embd)
|
105
|
-
x = self.tok_embedding(
|
114
|
+
x = self.tok_embedding(tokens)
|
106
115
|
x = x * (self.config.embedding_dim**0.5)
|
107
116
|
|
108
|
-
|
109
|
-
|
117
|
+
updated_kv_entires = []
|
118
|
+
for i, block in enumerate(self.transformer_blocks):
|
119
|
+
kv_entry = kv_cache.caches[i] if kv_cache else None
|
120
|
+
x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
|
121
|
+
if kv_entry:
|
122
|
+
updated_kv_entires.append(kv_entry)
|
123
|
+
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
|
110
124
|
|
111
125
|
x = self.final_norm(x)
|
112
|
-
|
113
|
-
return
|
126
|
+
logits = self.lm_head(x) # (b, t, vocab_size)
|
127
|
+
return {"logits": logits, "kv_cache": updated_kv_cache}
|
114
128
|
|
115
129
|
|
116
130
|
def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
@@ -139,18 +153,20 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
139
153
|
epsilon=1e-6,
|
140
154
|
zero_centered=True,
|
141
155
|
)
|
156
|
+
block_config = cfg.TransformerBlockConfig(
|
157
|
+
attn_config=attn_config,
|
158
|
+
ff_config=ff_config,
|
159
|
+
pre_attention_norm_config=norm_config,
|
160
|
+
post_attention_norm_config=norm_config,
|
161
|
+
)
|
142
162
|
config = cfg.ModelConfig(
|
143
163
|
vocab_size=256000,
|
144
164
|
num_layers=18,
|
145
165
|
max_seq_len=8192,
|
146
166
|
embedding_dim=2048,
|
147
167
|
kv_cache_max_len=kv_cache_max_len,
|
148
|
-
|
149
|
-
ff_config=ff_config,
|
150
|
-
pre_attention_norm_config=norm_config,
|
151
|
-
post_attention_norm_config=norm_config,
|
168
|
+
block_configs=block_config,
|
152
169
|
final_norm_config=norm_config,
|
153
|
-
parallel_residual=False,
|
154
170
|
lm_head_use_bias=False,
|
155
171
|
enable_hlfb=True,
|
156
172
|
)
|
@@ -159,7 +175,8 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
159
175
|
|
160
176
|
def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
|
161
177
|
config = get_model_config_2b(kv_cache_max_len)
|
162
|
-
config.
|
178
|
+
# Gemma has only one block config.
|
179
|
+
config.block_config(0).ff_config.intermediate_size = 128
|
163
180
|
config.vocab_size = 128
|
164
181
|
config.num_layers = 2
|
165
182
|
config.max_seq_len = 2 * kv_cache_max_len
|
@@ -170,32 +187,35 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
|
|
170
187
|
config = get_model_config_2b(**kwargs)
|
171
188
|
model = Gemma(config)
|
172
189
|
loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
|
173
|
-
#
|
190
|
+
# Since embedding and lm-head use the same weight, we need to set strict
|
174
191
|
# to False.
|
175
192
|
loader.load(model, strict=False)
|
176
193
|
model.eval()
|
177
194
|
return model
|
178
195
|
|
179
196
|
|
180
|
-
def define_and_run_2b() -> None:
|
197
|
+
def define_and_run_2b(checkpoint_path: str) -> None:
|
181
198
|
"""Instantiates and runs a Gemma 2B model."""
|
182
199
|
|
183
|
-
current_dir = Path(__file__).parent.resolve()
|
200
|
+
current_dir = pathlib.Path(__file__).parent.resolve()
|
184
201
|
gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
|
185
202
|
|
186
203
|
kv_cache_max_len = 1024
|
187
|
-
checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
|
188
204
|
model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
|
189
205
|
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
|
190
206
|
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
|
191
207
|
tokens[0, :4] = idx
|
192
208
|
input_pos = torch.arange(0, kv_cache_max_len)
|
193
|
-
|
209
|
+
kv = kv_utils.KVCache.from_model_config(model.config)
|
210
|
+
output = model.forward(tokens, input_pos, kv)
|
194
211
|
print("comparing with goldens..")
|
195
212
|
assert torch.allclose(
|
196
|
-
gemma_goldens,
|
213
|
+
gemma_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
|
197
214
|
)
|
198
215
|
|
199
216
|
|
200
217
|
if __name__ == "__main__":
|
201
|
-
|
218
|
+
input_checkpoint_path = os.path.join(
|
219
|
+
pathlib.Path.home(), "Downloads/llm_data/gemma-2b"
|
220
|
+
)
|
221
|
+
define_and_run_2b(input_checkpoint_path)
|