ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240911__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
- ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +38 -22
- ai_edge_torch/generative/layers/attention.py +60 -63
- ai_edge_torch/generative/layers/kv_cache.py +160 -51
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
- ai_edge_torch/generative/test/test_model_conversion.py +71 -33
- ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/RECORD +22 -32
- ai_edge_torch/generative/examples/experimental/gemma/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +0 -88
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/top_level.txt +0 -0
--- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py (0.3.0.dev20240910)
+++ ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py (0.3.0.dev20240911)
@@ -12,14 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+
+"""A toy example which has basic transformer block (w/ externalized KV-Cache)."""
+
 from typing import Tuple

 import ai_edge_torch
 from ai_edge_torch import lowertools
-from ai_edge_torch.generative.layers
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
 import torch.nn as nn
@@ -27,7 +30,7 @@ import torch.nn as nn
 RoPECache = Tuple[torch.Tensor, torch.Tensor]


-class ToyModelWithKV(torch.nn.Module):
+class ToyModelWithKVCache(torch.nn.Module):

   def __init__(self, config: cfg.ModelConfig) -> None:
     super().__init__()
@@ -36,7 +39,7 @@ class ToyModelWithKV(torch.nn.Module):
     )
     self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
     self.transformer_blocks = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.num_layers)
+        attention.TransformerBlock(config) for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
@@ -57,18 +60,29 @@ class ToyModelWithKV(torch.nn.Module):
     )
     self.config = config

-
-
-
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> Tuple[torch.Tensor, kv_utils.KVCache]:
+    x = self.tok_embedding(tokens)
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
     sin = sin.index_select(0, input_pos)
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]
+
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
-
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+
     x = self.final_norm(x)
-
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+    return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}


 def _export_stablehlo_mlir(model, args):
@@ -89,7 +103,7 @@ def get_model_config() -> cfg.ModelConfig:
   config = cfg.ModelConfig(
       vocab_size=150,
       num_layers=2,
-      max_seq_len=
+      max_seq_len=100,
       embedding_dim=128,
       attn_config=attn_config,
       ff_config=ff_config,
@@ -102,40 +116,59 @@ def get_model_config() -> cfg.ModelConfig:


 def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-
+  tokens = torch.unsqueeze(torch.arange(0, 100), 0)
   input_pos = torch.arange(0, 100)
-  return
+  return tokens, input_pos


 def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-
-  input_pos = torch.tensor([10]
-  return
+  tokens = torch.tensor([[1]], dtype=torch.long)
+  input_pos = torch.tensor([10])
+  return tokens, input_pos


 def define_and_run() -> None:
   dump_mlir = False

   config = get_model_config()
-  model =
+  model = ToyModelWithExternalKV(config)
+  model.eval()
   print('running an inference')
-
-
-
+  kv = kv_utils.KVCache.from_model_config(config)
+
+  tokens, input_pos = get_sample_prefill_inputs()
+  decode_token, decode_input_pos = get_sample_decode_inputs()
+  print(model.forward(tokens, input_pos, kv))

   if dump_mlir:
-    mlir_text = _export_stablehlo_mlir(model, (
-    with open('/tmp/
+    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
+    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
       f.write(mlir_text)

   # Convert model to tflite with 2 signatures (prefill + decode).
   print('converting toy model to tflite with 2 signatures (prefill + decode)')
   edge_model = (
-      ai_edge_torch.signature(
-
+      ai_edge_torch.signature(
+          'prefill',
+          model,
+          sample_kwargs={
+              'tokens': tokens,
+              'input_pos': input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
       .convert()
   )
-  edge_model.export('/tmp/
+  edge_model.export('/tmp/toy_external_kv_cache.tflite')


 if __name__ == '__main__':
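The rewritten toy example above captures this release's new calling convention for the generative layers: the KV cache is no longer internal module state but an explicit input and output of `forward`. Below is a minimal sketch of driving it, using only the entry points visible in this diff (the `get_sample_*` helpers, `KVCache.from_model_config`, and the `{'logits', 'kv_cache'}` output dict); it instantiates the `ToyModelWithKVCache` class defined above:

```python
# A minimal sketch of the externalized-KV calling convention, assuming only
# the APIs visible in this diff: KVCache.from_model_config(), the
# forward(tokens, input_pos, kv_cache) signature, and its dict output.
from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import torch

config = toy_model_with_kv_cache.get_model_config()
model = toy_model_with_kv_cache.ToyModelWithKVCache(config)
model.eval()

# Prefill: process the whole prompt once; the cache comes back updated.
tokens, input_pos = toy_model_with_kv_cache.get_sample_prefill_inputs()
kv = kv_utils.KVCache.from_model_config(config)
out = model.forward(tokens, input_pos, kv)
kv = out['kv_cache']  # functional update: thread the returned cache forward

# Decode: one token at a time, reusing the prefilled cache.
decode_token, decode_pos = toy_model_with_kv_cache.get_sample_decode_inputs()
out = model.forward(decode_token, decode_pos, kv)
print(out['logits'].shape)  # expected (1, 1, vocab_size)
```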
--- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py (0.3.0.dev20240910)
+++ ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py (0.3.0.dev20240911)
@@ -13,11 +13,14 @@
 # limitations under the License.
 # ==============================================================================

+"""Example of converting TinyLlama model to multi-signature tflite model."""
+
 import os
-
+import pathlib

 import ai_edge_torch
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch

@@ -48,20 +51,36 @@ def convert_tiny_llama_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
       ai_edge_torch.signature(
-          'prefill',
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
       )
-      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/
+      f'/tmp/tiny_llama_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )


 if __name__ == '__main__':
-
-  convert_tiny_llama_to_tflite(
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama')
+  convert_tiny_llama_to_tflite(path)
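Once exported with named `prefill` and `decode` signatures as above, the `.tflite` file can be driven per signature. The sketch below assumes TensorFlow's standard signature-runner API and an example path built from the new `{quant}_seq{N}_ekv{M}` naming scheme; the exact flattened KV-cache input tensor names are not shown in this diff, so inspect `get_signature_list()` rather than hard-coding them:

```python
# Hedged sketch: exercising the two signatures of the exported model.
# The model path is an example of the new naming scheme; the flattened
# kv_cache input names are implementation-defined, so list them first.
import tensorflow as tf

interpreter = tf.lite.Interpreter(
    model_path='/tmp/tiny_llama_f32_seq512_ekv1024.tflite'
)
print(interpreter.get_signature_list())  # shows 'prefill'/'decode' inputs

prefill_runner = interpreter.get_signature_runner('prefill')
decode_runner = interpreter.get_signature_runner('decode')
```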
--- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py (0.3.0.dev20240910)
+++ ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py (0.3.0.dev20240911)
@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+
+"""Example of building a TinyLlama model."""

 import os
-
+import pathlib

 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -80,16 +82,22 @@ class TinyLLamma(nn.Module):
     )
     self.config = config

-  # The model's forward function takes in additional k/v cache tensors
-  # and returns the updated k/v cache tensors to the caller.
-  # This can be eliminated if we handle k/v cache updates inside the model itself.
   @torch.inference_mode
-  def forward(
-
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )

     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
@@ -97,16 +105,20 @@ class TinyLLamma(nn.Module):
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.kv_cache_max]

-    #
-    x = self.tok_embedding(
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)

-
-
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

     x = self.final_norm(x)
-
-
-    return res
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}


 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -147,8 +159,8 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config


-def get_fake_model_config() -> cfg.ModelConfig:
-  config = get_model_config()
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config(**kwargs)
   config.vocab_size = 128
   config.num_layers = 2
   config.ff_config.intermediate_size = 64
@@ -160,26 +172,30 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   model = TinyLLamma(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
   loader.load(model)
+  model.eval()
   return model


-def define_and_run() -> None:
+def define_and_run(checkpoint_path: str) -> None:
   """Instantiates and runs a TinyLlama model."""

-  current_dir = Path(__file__).parent.resolve()
+  current_dir = pathlib.Path(__file__).parent.resolve()
   tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
   kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/tiny_llama")
   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
-
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
   assert torch.allclose(
-      tiny_llama_goldens,
+      tiny_llama_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
   )


 if __name__ == "__main__":
-
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/tiny_llama"
+  )
+  define_and_run(input_checkpoint_path)
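`get_fake_model_config(**kwargs)` now forwards overrides to `get_model_config`; `kv_cache_max_len` is the only `get_model_config` parameter visible in this diff, so that is what the sketch below passes. This is a hedged illustration of how a test might build a small model with a custom cache size and exercise the forward contract shown above:

```python
# Sketch of using the new get_fake_model_config(**kwargs) pass-through.
# kv_cache_max_len is the only get_model_config parameter visible in this
# diff; everything else here is an assumption about a typical test setup.
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import torch

config = tiny_llama.get_fake_model_config(kv_cache_max_len=128)
model = tiny_llama.TinyLLamma(config)
model.eval()

tokens = torch.zeros((1, 10), dtype=torch.long)
input_pos = torch.arange(0, 10)
kv = kv_utils.KVCache.from_model_config(config)
output = model.forward(tokens, input_pos, kv)
assert output["logits"].shape == (1, 10, config.vocab_size)
```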
--- ai_edge_torch/generative/layers/attention.py (0.3.0.dev20240910)
+++ ai_edge_torch/generative/layers/attention.py (0.3.0.dev20240911)
@@ -12,16 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Common building blocks for Attention layer.

-
+"""Common building blocks for Attention layer."""

-import
-
+from typing import Optional, Tuple, Union
+
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
-from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention  # NOQA
-from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb  # NOQA
 import torch
 from torch import nn

@@ -62,7 +62,6 @@ class TransformerBlock(nn.Module):
       config (cfg.ModelConfig): the configuration object for this transformer
         block.
     """
-
     super().__init__()
     self.pre_atten_norm = builder.build_norm(
         config.embedding_dim, config.pre_attention_norm_config
@@ -71,7 +70,6 @@ class TransformerBlock(nn.Module):
         config.batch_size,
         config.embedding_dim,
         config.attn_config,
-        config.kv_cache_max,
         config.enable_hlfb,
     )
     self.post_atten_norm = builder.build_norm(
@@ -86,7 +84,8 @@ class TransformerBlock(nn.Module):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
      mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
-
+      kv_cache: kv_utils.KVCacheEntry = None,
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the TransformerBlock.

     Args:
@@ -94,24 +93,34 @@ class TransformerBlock(nn.Module):
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): the optional kv cache entry.

     Returns:
-      output activation from this transformer block
+      output activation from this transformer block, and updated kv cache (if
+      passed in).
     """
-
+    kv = None
     if self.config.parallel_residual:
       x_norm = self.pre_atten_norm(x)
-
+      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      if kv_cache is None:
+        attn_out = atten_func_out
+      else:
+        attn_out, kv = atten_func_out
       ff_out = self.ff(x_norm)
       output = x + attn_out + ff_out
     else:
       x_norm = self.pre_atten_norm(x)
-
+      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      if kv_cache is None:
+        attn_out = atten_func_out
+      else:
+        attn_out, kv = atten_func_out
       x = x + attn_out
       x_norm = self.post_atten_norm(x)
       output = x + self.ff(x_norm)

-    return output
+    return output if kv is None else (output, kv)


 class CausalSelfAttention(nn.Module):
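Because `kv_cache` defaults to `None`, `TransformerBlock.forward` keeps its old tensor-only return for cache-less callers and returns an `(output, kv)` pair only when an entry is passed in, which is exactly the branching shown above. A short call-site sketch of coping with that Union return (all names here are placeholders; `block` is an `attention.TransformerBlock`); the diff for this file continues below:

```python
# Placeholder names throughout; `block` is an attention.TransformerBlock.
x = block(x, rope, mask, input_pos)  # no cache passed -> a plain tensor

x, kv_entry = block(x, rope, mask, input_pos, kv_entry)  # cache -> a 2-tuple
```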
@@ -121,7 +130,6 @@ class CausalSelfAttention(nn.Module):
       batch_size: int,
       dim: int,
       config: cfg.AttentionConfig,
-      kv_cache_max: int,
       enable_hlfb: bool,
   ) -> None:
     """Initialize an instance of CausalSelfAttention.
@@ -130,8 +138,6 @@ class CausalSelfAttention(nn.Module):
       batch_size (int): batch size of the input tensor.
       dim (int): causal attention's input/output dimmension.
       config (cfg.AttentionConfig): attention specific configurations.
-      kv_cache_max (int): determines the size of the KV Cache buffer, if
-        enabled.
       enable_hlfb (bool): whether hlfb is enabled or not.
     """
     super().__init__()
@@ -147,21 +153,13 @@ class CausalSelfAttention(nn.Module):
     self.output_projection = nn.Linear(
         output_shape, dim, bias=config.output_proj_use_bias
     )
-
-
-
-
-
-
-
-        config.head_dim,
-        enable_hlfb,
-    )
-
-    if enable_hlfb:
-      self.sdpa_func = scaled_dot_product_attention_with_hlfb
-    else:
-      self.sdpa_func = scaled_dot_product_attention
+    self.config = config
+    self.enable_hlfb = enable_hlfb
+    self.sdpa_func = (
+        sdpa.scaled_dot_product_attention_with_hlfb
+        if enable_hlfb
+        else sdpa.scaled_dot_product_attention
+    )

   def forward(
       self,
@@ -169,7 +167,8 @@ class CausalSelfAttention(nn.Module):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
-
+      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the CausalSelfAttention layer, which can support

     MQA, GQA and MHA.
@@ -179,9 +178,11 @@ class CausalSelfAttention(nn.Module):
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.

     Returns:
-      output activation from this self attention layer
+      output activation from this self attention layer, and the updated
+      KV Cach Entry (if passed in).
     """
     # Batch size, sequence length, embedding dimensionality.
     B, T, E = x.size()
@@ -224,9 +225,11 @@ class CausalSelfAttention(nn.Module):
     n_elem = int(self.config.rotary_percentage * self.config.head_dim)
     q, k = _embed_rope(q, k, n_elem, rope)

-    if
-
-
+    if kv_cache is not None:
+      kv_cache = kv_utils.update(
+          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
+      )
+      k, v = kv_cache.k_cache, kv_cache.v_cache

     y = self.sdpa_func(
         q,
@@ -240,7 +243,7 @@ class CausalSelfAttention(nn.Module):

     # Compute the output projection.
     y = self.output_projection(y)
-    return y
+    return y if kv_cache is None else (y, kv_cache)


 class SelfAttention(CausalSelfAttention):
@@ -251,16 +254,19 @@ class SelfAttention(CausalSelfAttention):
       x: torch.Tensor,
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       input_pos: Optional[torch.Tensor] = None,
-
+      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the SelfAttention layer, which can support MQA, GQA and MHA.

     Args:
       x (torch.Tensor): the input tensor.
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.

     Returns:
-      output activation from this self attention layer
+      output activation from this self attention layer, and the updated
+      KV Cach Entry (if passed in).
     """
     B, T, _ = x.size()
     return super().forward(
@@ -279,9 +285,8 @@ class CrossAttention(nn.Module):
       query_dim: int,
       cross_dim: int,
       config: cfg.AttentionConfig,
-      kv_cache_max: int,
       enable_hlfb: bool,
-  )
+  ):
     """Initialize an instance of CrossAttention.

     Args:
@@ -289,8 +294,6 @@ class CrossAttention(nn.Module):
       query_dim (int): query tensor's dimension.
       cross_dim (int): cross attention's dimensions, for key and value tensors.
       config (cfg.AttentionConfig): attention specific configurations.
-      kv_cache_max (int): determines the size of the KV Cache buffer, if
-        enabled.
       enable_hlfb (bool): whether hlfb is enabled or not.
     """
     super().__init__()
@@ -309,21 +312,11 @@ class CrossAttention(nn.Module):
         query_dim, query_dim, bias=config.output_proj_use_bias
     )

-    self.
-
-
-
-
-        kv_cache_max,
-        config.num_query_groups,
-        self.config.head_dim,
-        enable_hlfb,
-    )
-
-    if enable_hlfb:
-      self.sdpa_func = scaled_dot_product_attention_with_hlfb
-    else:
-      self.sdpa_func = scaled_dot_product_attention
+    self.sdpa_func = (
+        sdpa.scaled_dot_product_attention_with_hlfb
+        if enable_hlfb
+        else sdpa.scaled_dot_product_attention
+    )

   def forward(
       self,
@@ -332,6 +325,7 @@ class CrossAttention(nn.Module):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
+      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
   ):
     """Forward function of the CrossAttention layer.

@@ -342,6 +336,7 @@ class CrossAttention(nn.Module):
       mask (torch.Tensor): the optional mask tensor can be broadcaseted to shape
         [B, n_heads, target_seq_len, source_seq_len].
       input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.

     Returns:
       output activation from this cross attention layer.
@@ -363,9 +358,11 @@ class CrossAttention(nn.Module):
     n_elem = int(self.config.rotary_percentage * self.config.head_dim)
     q, k = _embed_rope(q, k, n_elem, rope)

-    if
-
-
+    if kv_cache is not None:
+      kv_cache = kv_utils.update(
+          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
+      )
+      k, v = kv_cache.k_cache, kv_cache.v_cache
     if mask is None:
       mask = torch.zeros(
           (batch_size, 1, target_seq_len, source_seq_len), dtype=torch.float32
@@ -375,4 +372,4 @@ class CrossAttention(nn.Module):

     # Compute the output projection.
     y = self.output_projection(y)
-    return y
+    return y if kv_cache is None else (y, kv_cache)
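The new `kv_cache.py` internals are not part of this section (only its +160/-51 line counts appear in the file list), but the attention layers above pin down its contract: `kv_utils.update(entry, input_pos, k, v, enable_hlfb=...)` returns a fresh `KVCacheEntry` whose `k_cache`/`v_cache` hold the written slices. Below is a stand-alone illustration of that functional-update pattern; the field names and call shape mirror the diff, while the tensor layout and out-of-place `index_copy` mechanics are assumptions, and the `enable_hlfb` graph-annotation path is omitted:

```python
# Illustration only, not the library's implementation: field names (k_cache,
# v_cache) and the update() call shape mirror this diff; the (B, S, H, D)
# layout and index_copy choice are assumptions; enable_hlfb is omitted.
import dataclasses

import torch


@dataclasses.dataclass
class KVCacheEntry:
  """One layer's cache; updates return a new entry instead of mutating."""

  k_cache: torch.Tensor  # assumed (batch, kv_cache_max, num_kv_heads, head_dim)
  v_cache: torch.Tensor


def update(
    entry: KVCacheEntry,
    input_pos: torch.Tensor,
    k_slice: torch.Tensor,
    v_slice: torch.Tensor,
) -> KVCacheEntry:
  # Write the new k/v slices at input_pos along the sequence axis (dim 1)
  # without mutating the caller's entry; index_copy is the out-of-place form.
  k = entry.k_cache.index_copy(1, input_pos, k_slice)
  v = entry.v_cache.index_copy(1, input_pos, v_slice)
  return KVCacheEntry(k, v)


# Usage sketch: a single-position decode step on a tiny cache.
entry = KVCacheEntry(torch.zeros(1, 16, 1, 4), torch.zeros(1, 16, 1, 4))
entry = update(
    entry, torch.tensor([3]), torch.ones(1, 1, 1, 4), torch.ones(1, 1, 1, 4)
)
```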