ai-edge-torch-nightly 0.3.0.dev20240909-py3-none-any.whl → 0.3.0.dev20240911-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/test/test_convert.py +35 -13
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
- ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +38 -22
- ai_edge_torch/generative/layers/attention.py +60 -63
- ai_edge_torch/generative/layers/kv_cache.py +160 -51
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
- ai_edge_torch/generative/test/test_model_conversion.py +71 -33
- ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/odml_torch/lowerings/_convolution.py +196 -74
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/RECORD +25 -35
- ai_edge_torch/generative/examples/experimental/gemma/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +0 -88
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py

@@ -12,26 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-
-# Note: This is an experimental version of phi2 with external KV cache.
-# Please use with caution.
+
+"""Example of building a Phi-2 model."""
 
 import os
-
-from typing import Tuple
+import pathlib
 
+from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-from ai_edge_torch.generative.layers.experimental import attention
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
 from torch import nn
 
-
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.fc1",
     ff_down_proj="model.layers.{}.mlp.fc2",
@@ -89,13 +85,17 @@ class Phi2(nn.Module):
       self,
       tokens: torch.Tensor,
      input_pos: torch.Tensor,
-      kv_cache: kv_utils.
-  ) ->
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
     _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
@@ -111,11 +111,11 @@
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
         updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     x = self.final_norm(x)
-
-    return
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -169,39 +169,37 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-    checkpoint_path: str, test_model: bool = False, **kwargs
-) -> nn.Module:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   """Instantiates the model instance and load checkpoint if provided."""
-  config = (
-      get_fake_model_config(**kwargs)
-      if test_model
-      else get_model_config(**kwargs)
-  )
+  config = get_model_config(**kwargs)
   model = Phi2(config)
-
-
-  loader.load(model)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  loader.load(model)
   model.eval()
   return model
 
 
-def define_and_run(checkpoint_path: str
+def define_and_run(checkpoint_path: str) -> None:
   """Instantiates and runs a Phi-2 model."""
 
+  current_dir = pathlib.Path(__file__).parent.resolve()
+  phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
   kv_cache_max_len = 1024
-  model = build_model(
-      checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
-  )
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
-  kv = kv_utils.
-
-  print(
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
+  print("comparing with goldens..")
+  assert torch.allclose(
+      phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
+  )
 
 
 if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/phi2"
+  )
   define_and_run(input_checkpoint_path)
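For context, a minimal sketch (not part of the package) of how the refactored phi2.py is meant to be driven after this change: the caller allocates a `KVCache` from the model config and passes it through `forward`, which now returns a dict of logits and the updated cache. The checkpoint path is a placeholder and the greedy next-token pick is purely illustrative.

```python
# Illustrative sketch only. Assumes a local Phi-2 checkpoint; the path is a
# placeholder.
import pathlib

import torch

from ai_edge_torch.generative.examples.phi import phi2
from ai_edge_torch.generative.layers import kv_cache as kv_utils

checkpoint = pathlib.Path.home() / "Downloads/llm_data/phi2"
model = phi2.build_model(str(checkpoint), kv_cache_max_len=1024)

# A 4-token prompt padded out to the full cache length, as in define_and_run.
tokens = torch.full((1, 1024), 0, dtype=torch.long)
tokens[0, :4] = torch.tensor([1, 2, 3, 4])
input_pos = torch.arange(0, 1024)

# The cache is now allocated by the caller and threaded through forward(),
# which returns the logits together with the updated cache.
kv = kv_utils.KVCache.from_model_config(model.config)
output = model.forward(tokens, input_pos, kv)
logits, kv = output["logits"], output["kv_cache"]
next_token = torch.argmax(logits[0, 3, :])  # greedy pick after the prompt
```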
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py

@@ -12,14 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+
+"""A toy example which has basic transformer block (w/ externalized KV-Cache)."""
+
 from typing import Tuple
 
 import ai_edge_torch
 from ai_edge_torch import lowertools
-from ai_edge_torch.generative.layers
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
 import torch.nn as nn
@@ -27,7 +30,7 @@ import torch.nn as nn
 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class
+class ToyModelWithKVCache(torch.nn.Module):
 
   def __init__(self, config: cfg.ModelConfig) -> None:
     super().__init__()
@@ -36,7 +39,7 @@ class ToyModelWithKV(torch.nn.Module):
     )
     self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
     self.transformer_blocks = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.num_layers)
+        attention.TransformerBlock(config) for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
@@ -57,18 +60,29 @@ class ToyModelWithKV(torch.nn.Module):
     )
     self.config = config
 
-
-
-
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> Tuple[torch.Tensor, kv_utils.KVCache]:
+    x = self.tok_embedding(tokens)
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
     sin = sin.index_select(0, input_pos)
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]
+
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
-
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+
     x = self.final_norm(x)
-
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+    return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}
 
 
 def _export_stablehlo_mlir(model, args):
@@ -89,7 +103,7 @@ def get_model_config() -> cfg.ModelConfig:
   config = cfg.ModelConfig(
       vocab_size=150,
       num_layers=2,
-      max_seq_len=
+      max_seq_len=100,
       embedding_dim=128,
       attn_config=attn_config,
       ff_config=ff_config,
@@ -102,40 +116,59 @@ def get_model_config() -> cfg.ModelConfig:
 
 
 def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-
+  tokens = torch.unsqueeze(torch.arange(0, 100), 0)
   input_pos = torch.arange(0, 100)
-  return
+  return tokens, input_pos
 
 
 def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-
-  input_pos = torch.tensor([10]
-  return
+  tokens = torch.tensor([[1]], dtype=torch.long)
+  input_pos = torch.tensor([10])
+  return tokens, input_pos
 
 
 def define_and_run() -> None:
   dump_mlir = False
 
   config = get_model_config()
-  model =
+  model = ToyModelWithExternalKV(config)
+  model.eval()
   print('running an inference')
-
-
-
+  kv = kv_utils.KVCache.from_model_config(config)
+
+  tokens, input_pos = get_sample_prefill_inputs()
+  decode_token, decode_input_pos = get_sample_decode_inputs()
+  print(model.forward(tokens, input_pos, kv))
 
   if dump_mlir:
-    mlir_text = _export_stablehlo_mlir(model, (
-    with open('/tmp/
+    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
+    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
      f.write(mlir_text)
 
   # Convert model to tflite with 2 signatures (prefill + decode).
   print('converting toy model to tflite with 2 signatures (prefill + decode)')
   edge_model = (
-      ai_edge_torch.signature(
-
+      ai_edge_torch.signature(
+          'prefill',
+          model,
+          sample_kwargs={
+              'tokens': tokens,
+              'input_pos': input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
      .convert()
   )
-  edge_model.export('/tmp/
+  edge_model.export('/tmp/toy_external_kv_cache.tflite')
 
 
 if __name__ == '__main__':
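A sketch (not part of the diff) of the calling pattern the externalized cache enables for this toy model: the updated `KVCache` returned by one `forward` call is fed into the next. It assumes the helpers defined above are in scope; the prompt length, positions, and greedy decode step are arbitrary choices for illustration, and the weights are random.

```python
# Illustrative sketch only. Assumes the toy module's helpers above
# (get_model_config, ToyModelWithKVCache) are importable in the current scope.
import torch

from ai_edge_torch.generative.layers import kv_cache as kv_utils

config = get_model_config()
model = ToyModelWithKVCache(config)
model.eval()

kv = kv_utils.KVCache.from_model_config(config)

# Prefill a short prompt at positions 0..7.
prompt = torch.arange(0, 8).unsqueeze(0)
out = model.forward(prompt, torch.arange(0, 8), kv)
kv = out['kv_cache']  # carry the updated cache forward

# Decode a single token at position 8, reusing the prefilled cache.
next_token = torch.argmax(out['logits'][:, -1, :], dim=-1, keepdim=True)
out = model.forward(next_token, torch.tensor([8]), kv)
kv = out['kv_cache']
```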
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py

@@ -13,11 +13,14 @@
 # limitations under the License.
 # ==============================================================================
 
+"""Example of converting TinyLlama model to multi-signature tflite model."""
+
 import os
-
+import pathlib
 
 import ai_edge_torch
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
@@ -48,20 +51,36 @@ def convert_tiny_llama_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
       ai_edge_torch.signature(
-          'prefill',
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
       )
-      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/
+      f'/tmp/tiny_llama_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )
 
 
 if __name__ == '__main__':
-
-  convert_tiny_llama_to_tflite(
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama')
+  convert_tiny_llama_to_tflite(path)
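A hedged usage sketch for the updated converter. The keyword names mirror the variables used in the hunk (`prefill_seq_len`, `kv_cache_max_len`, `quantize`), but the full signature and its defaults are truncated in this diff, so treat them as assumptions; the checkpoint path is a placeholder.

```python
# Illustrative sketch only. Keyword names follow the variables used above;
# their presence in the public signature is an assumption.
import os
import pathlib

from ai_edge_torch.generative.examples.tiny_llama import convert_to_tflite

checkpoint = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama')
convert_to_tflite.convert_tiny_llama_to_tflite(
    checkpoint,
    prefill_seq_len=512,
    kv_cache_max_len=1024,
    quantize=True,  # expected to write /tmp/tiny_llama_q8_seq512_ekv1024.tflite
)
```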
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py

@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+
+"""Example of building a TinyLlama model."""
 
 import os
-
+import pathlib
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -80,16 +82,22 @@ class TinyLLamma(nn.Module):
     )
     self.config = config
 
-  # The model's forward function takes in additional k/v cache tensors
-  # and returns the updated k/v cache tensors to the caller.
-  # This can be eliminated if we handle k/v cache updates inside the model itself.
   @torch.inference_mode
-  def forward(
-
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
@@ -97,16 +105,20 @@ class TinyLLamma(nn.Module):
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
-    #
-    x = self.tok_embedding(
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
 
-
-
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     x = self.final_norm(x)
-
-
-    return res
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -147,8 +159,8 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config() -> cfg.ModelConfig:
-  config = get_model_config()
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config(**kwargs)
   config.vocab_size = 128
   config.num_layers = 2
   config.ff_config.intermediate_size = 64
@@ -160,26 +172,30 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   model = TinyLLamma(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
   loader.load(model)
+  model.eval()
   return model
 
 
-def define_and_run() -> None:
+def define_and_run(checkpoint_path: str) -> None:
   """Instantiates and runs a TinyLlama model."""
 
-  current_dir = Path(__file__).parent.resolve()
+  current_dir = pathlib.Path(__file__).parent.resolve()
   tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
   kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/tiny_llama")
   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
-
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
   assert torch.allclose(
-      tiny_llama_goldens,
+      tiny_llama_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
   )
 
 
 if __name__ == "__main__":
-
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/tiny_llama"
+  )
+  define_and_run(input_checkpoint_path)
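A checkpoint-free smoke-test sketch (not part of the diff) of the new `forward` contract, using the `get_fake_model_config(**kwargs)` helper and randomly initialized weights; the sequence length and printouts are illustrative.

```python
# Illustrative sketch only. Runs the refactored TinyLlama forward() without a
# checkpoint, using the fake test config and random weights.
import torch

from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.layers import kv_cache as kv_utils

config = tiny_llama.get_fake_model_config(kv_cache_max_len=128)
model = tiny_llama.TinyLLamma(config)
model.eval()

tokens = torch.zeros((1, 16), dtype=torch.long)
input_pos = torch.arange(0, 16)
kv = kv_utils.KVCache.from_model_config(config)

output = model.forward(tokens, input_pos, kv)
print(output["logits"].shape)           # (1, 16, vocab_size)
print(len(output["kv_cache"].caches))   # one KV entry per transformer block
```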