ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240912__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
- ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
- ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
- ai_edge_torch/generative/examples/smallm/smallm.py +119 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +40 -24
- ai_edge_torch/generative/layers/attention.py +60 -63
- ai_edge_torch/generative/layers/builder.py +4 -2
- ai_edge_torch/generative/layers/kv_cache.py +160 -51
- ai_edge_torch/generative/layers/model_config.py +1 -0
- ai_edge_torch/generative/layers/normalization.py +158 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +72 -34
- ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/RECORD +33 -39
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/smallm/smallm.py

@@ -0,0 +1,119 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a SmalLM model."""
+
+import copy
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import numpy as np
+import torch
+from torch import nn
+
+TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
+# SmalLM re-uses the embedding as the head projection layer.
+TENSOR_NAMES.lm_head = None
+
+
+class SmalLM(tiny_llama.TinyLlama):
+  """A SmalLM model built from the Edge Generative API layers.
+
+  SmalLM shares the same architecture as TinyLlama, but with different model
+  sizes.
+  """
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__(config)
+    # SmalLM re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a SmalLM 135M model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a SmalLM model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=9,
+      head_dim=64,
+      num_query_groups=3,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+      intermediate_size=1536,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  config = cfg.ModelConfig(
+      vocab_size=49152,
+      num_layers=30,
+      max_seq_len=2048,
+      embedding_dim=576,
+      kv_cache_max_len=kv_cache_max_len,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  config = get_model_config(**kwargs)
+  model = SmalLM(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # since embedding and lm-head use the same weight, we need to set strict
+  # to False.
+  loader.load(model, strict=False)
+  model.eval()
+  return model
+
+
+def define_and_run(checkpoint_path: str) -> None:
+  """Instantiates and runs a SmalLM model."""
+
+  current_dir = pathlib.Path(__file__).parent.resolve()
+  smallm_goldens = torch.load(current_dir / "smallm_lm_logits.pt")
+  kv_cache_max_len = 1024
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
+  assert torch.allclose(
+      smallm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
+  )
+
+
+if __name__ == "__main__":
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/smallm"
+  )
+  define_and_run(input_checkpoint_path)
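
The new smallm.py follows the same externalized-KV-cache calling convention as the other migrated examples. Below is a minimal usage sketch, not code from the package: the checkpoint path is a placeholder, and it assumes a local directory with SmalLM weights in the layout `build_model` expects.

```python
import torch

from ai_edge_torch.generative.examples.smallm import smallm
from ai_edge_torch.generative.layers import kv_cache as kv_utils

# Placeholder checkpoint location; build_model reuses TinyLlama tensor names
# with lm_head removed, since the head is tied to the embedding.
model = smallm.build_model("/path/to/smallm_checkpoint", kv_cache_max_len=1024)

tokens = torch.zeros((1, 1024), dtype=torch.long)
tokens[0, :4] = torch.tensor([1, 2, 3, 4])
input_pos = torch.arange(0, 1024)

# The cache lives outside the model and is passed in and returned explicitly.
kv = kv_utils.KVCache.from_model_config(model.config)
output = model.forward(tokens, input_pos, kv)
logits, kv = output["logits"], output["kv_cache"]
```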
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py

@@ -12,14 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+
+"""A toy example which has basic transformer block (w/ externalized KV-Cache)."""
+
 from typing import Tuple
 
 import ai_edge_torch
 from ai_edge_torch import lowertools
-from ai_edge_torch.generative.layers
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
 import torch.nn as nn
@@ -27,7 +30,7 @@ import torch.nn as nn
 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class
+class ToyModelWithKVCache(torch.nn.Module):
 
   def __init__(self, config: cfg.ModelConfig) -> None:
     super().__init__()
@@ -36,7 +39,7 @@ class ToyModelWithKV(torch.nn.Module):
     )
     self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
     self.transformer_blocks = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.num_layers)
+        attention.TransformerBlock(config) for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
@@ -57,18 +60,29 @@ class ToyModelWithKV(torch.nn.Module):
     )
     self.config = config
 
-
-
-
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> Tuple[torch.Tensor, kv_utils.KVCache]:
+    x = self.tok_embedding(tokens)
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
     sin = sin.index_select(0, input_pos)
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]
+
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
-
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+
     x = self.final_norm(x)
-
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+    return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}
 
 
 def _export_stablehlo_mlir(model, args):
@@ -89,7 +103,7 @@ def get_model_config() -> cfg.ModelConfig:
   config = cfg.ModelConfig(
       vocab_size=150,
       num_layers=2,
-      max_seq_len=
+      max_seq_len=100,
       embedding_dim=128,
       attn_config=attn_config,
       ff_config=ff_config,
@@ -102,40 +116,59 @@ def get_model_config() -> cfg.ModelConfig:
 
 
 def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-
+  tokens = torch.unsqueeze(torch.arange(0, 100), 0)
   input_pos = torch.arange(0, 100)
-  return
+  return tokens, input_pos
 
 
 def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-
-  input_pos = torch.tensor([10]
-  return
+  tokens = torch.tensor([[1]], dtype=torch.long)
+  input_pos = torch.tensor([10])
+  return tokens, input_pos
 
 
 def define_and_run() -> None:
   dump_mlir = False
 
   config = get_model_config()
-  model =
+  model = ToyModelWithExternalKV(config)
+  model.eval()
   print('running an inference')
-
-
-
+  kv = kv_utils.KVCache.from_model_config(config)
+
+  tokens, input_pos = get_sample_prefill_inputs()
+  decode_token, decode_input_pos = get_sample_decode_inputs()
+  print(model.forward(tokens, input_pos, kv))
 
   if dump_mlir:
-    mlir_text = _export_stablehlo_mlir(model, (
-    with open('/tmp/
+    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
+    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
       f.write(mlir_text)
 
   # Convert model to tflite with 2 signatures (prefill + decode).
   print('converting toy model to tflite with 2 signatures (prefill + decode)')
   edge_model = (
-      ai_edge_torch.signature(
-
+      ai_edge_torch.signature(
+          'prefill',
+          model,
+          sample_kwargs={
+              'tokens': tokens,
+              'input_pos': input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
       .convert()
   )
-  edge_model.export('/tmp/
+  edge_model.export('/tmp/toy_external_kv_cache.tflite')
 
 
 if __name__ == '__main__':
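
Note how the toy model's forward now returns both the logits and the updated cache; callers are expected to thread the cache back in on the next call. A short sketch using the helpers defined in this file (the prefill/decode round trip itself is illustrative, not code from the package):

```python
import torch

from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache as toy
from ai_edge_torch.generative.layers import kv_cache as kv_utils

model = toy.ToyModelWithKVCache(toy.get_model_config())
model.eval()
kv = kv_utils.KVCache.from_model_config(model.config)

# Prefill: run the whole prompt once and keep the returned cache.
tokens, input_pos = toy.get_sample_prefill_inputs()
out = model.forward(tokens, input_pos, kv)
kv = out['kv_cache']

# Decode: pass a single token plus the cache from the previous step.
decode_token, decode_input_pos = toy.get_sample_decode_inputs()
out = model.forward(decode_token, decode_input_pos, kv)
kv = out['kv_cache']
next_token = torch.argmax(out['logits'][:, -1, :], dim=-1)
```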
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py

@@ -13,11 +13,14 @@
 # limitations under the License.
 # ==============================================================================
 
+"""Example of converting TinyLlama model to multi-signature tflite model."""
+
 import os
-
+import pathlib
 
 import ai_edge_torch
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
@@ -48,20 +51,36 @@ def convert_tiny_llama_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
       ai_edge_torch.signature(
-          'prefill',
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
       )
-      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/
+      f'/tmp/tiny_llama_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )
 
 
 if __name__ == '__main__':
-
-  convert_tiny_llama_to_tflite(
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama')
+  convert_tiny_llama_to_tflite(path)
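
The converter now feeds the prefill and decode signatures through sample_kwargs and names the exported file after the quantization mode and sequence lengths. A hedged example invocation follows; the keyword names are inferred from the function body above, since the signature line is truncated in this rendering.

```python
import os
import pathlib

from ai_edge_torch.generative.examples.tiny_llama import convert_to_tflite

checkpoint = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama')

# Assumed keyword names: prefill_seq_len, kv_cache_max_len, quantize.
convert_to_tflite.convert_tiny_llama_to_tflite(
    checkpoint,
    prefill_seq_len=512,
    kv_cache_max_len=1024,
    quantize=True,
)
# With these values, the export call above would write:
#   /tmp/tiny_llama_q8_seq512_ekv1024.tflite
```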
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py

@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+
+"""Example of building a TinyLlama model."""
 
 import os
-
+import pathlib
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -42,7 +44,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 )
 
 
-class
+class TinyLlama(nn.Module):
   """A TinyLlama model built from the Edge Generative API layers."""
 
   def __init__(self, config: cfg.ModelConfig):
@@ -80,16 +82,22 @@ class TinyLLamma(nn.Module):
     )
     self.config = config
 
-  # The model's forward function takes in additional k/v cache tensors
-  # and returns the updated k/v cache tensors to the caller.
-  # This can be eliminated if we handle k/v cache updates inside the model itself.
   @torch.inference_mode
-  def forward(
-
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
@@ -97,16 +105,20 @@ class TinyLLamma(nn.Module):
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
-    #
-    x = self.tok_embedding(
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
 
-
-
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     x = self.final_norm(x)
-
-
-    return res
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -147,8 +159,8 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config() -> cfg.ModelConfig:
-  config = get_model_config()
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config(**kwargs)
   config.vocab_size = 128
   config.num_layers = 2
   config.ff_config.intermediate_size = 64
@@ -157,29 +169,33 @@ def get_fake_model_config() -> cfg.ModelConfig:
 
 def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   config = get_model_config(**kwargs)
-  model =
+  model = TinyLlama(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
   loader.load(model)
+  model.eval()
   return model
 
 
-def define_and_run() -> None:
+def define_and_run(checkpoint_path: str) -> None:
   """Instantiates and runs a TinyLlama model."""
 
-  current_dir = Path(__file__).parent.resolve()
+  current_dir = pathlib.Path(__file__).parent.resolve()
   tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
   kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/tiny_llama")
   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
-
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
   assert torch.allclose(
-      tiny_llama_goldens,
+      tiny_llama_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
   )
 
 
 if __name__ == "__main__":
-
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/tiny_llama"
+  )
+  define_and_run(input_checkpoint_path)
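
With the cache externalized, TinyLlama's forward becomes a function of (tokens, input_pos, kv_cache) that returns fresh logits and a fresh cache. A minimal greedy decode loop built on that contract is sketched below; the loop is a usage illustration, not code from the package, and the checkpoint path is a placeholder.

```python
import torch

from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.layers import kv_cache as kv_utils

# Placeholder checkpoint path.
model = tiny_llama.build_model("/path/to/tiny_llama", kv_cache_max_len=1024)
kv = kv_utils.KVCache.from_model_config(model.config)

# Prefill the prompt, then decode one token at a time, threading the returned
# cache through every call.
prompt = torch.tensor([[1, 2, 3, 4]], dtype=torch.long)
out = model.forward(prompt, torch.arange(0, prompt.shape[1]), kv)
kv = out["kv_cache"]
next_tok = torch.argmax(out["logits"][:, -1, :], dim=-1, keepdim=True)

for step in range(prompt.shape[1], prompt.shape[1] + 8):
  out = model.forward(next_tok, torch.tensor([step]), kv)
  kv = out["kv_cache"]
  next_tok = torch.argmax(out["logits"][:, -1, :], dim=-1, keepdim=True)
```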