ai-edge-torch-nightly 0.3.0.dev20240909__py3-none-any.whl → 0.3.0.dev20240911__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/test/test_convert.py +35 -13
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
- ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +38 -22
- ai_edge_torch/generative/layers/attention.py +60 -63
- ai_edge_torch/generative/layers/kv_cache.py +160 -51
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
- ai_edge_torch/generative/test/test_model_conversion.py +71 -33
- ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/odml_torch/lowerings/_convolution.py +196 -74
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/RECORD +25 -35
- ai_edge_torch/generative/examples/experimental/gemma/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +0 -88
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/attention.py

@@ -12,16 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Common building blocks for Attention layer.

-
+"""Common building blocks for Attention layer."""

-import
-
+from typing import Optional, Tuple, Union
+
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
-from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention # NOQA
-from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb # NOQA
 import torch
 from torch import nn

@@ -62,7 +62,6 @@ class TransformerBlock(nn.Module):
       config (cfg.ModelConfig): the configuration object for this transformer
         block.
     """
-
     super().__init__()
     self.pre_atten_norm = builder.build_norm(
         config.embedding_dim, config.pre_attention_norm_config
@@ -71,7 +70,6 @@ class TransformerBlock(nn.Module):
         config.batch_size,
         config.embedding_dim,
         config.attn_config,
-        config.kv_cache_max,
         config.enable_hlfb,
     )
     self.post_atten_norm = builder.build_norm(
@@ -86,7 +84,8 @@ class TransformerBlock(nn.Module):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
      mask: Optional[torch.Tensor] = None,
      input_pos: Optional[torch.Tensor] = None,
-
+      kv_cache: kv_utils.KVCacheEntry = None,
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the TransformerBlock.

     Args:
@@ -94,24 +93,34 @@
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): the optional kv cache entry.

     Returns:
-      output activation from this transformer block
+      output activation from this transformer block, and updated kv cache (if
+      passed in).
     """
-
+    kv = None
     if self.config.parallel_residual:
       x_norm = self.pre_atten_norm(x)
-
+      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      if kv_cache is None:
+        attn_out = atten_func_out
+      else:
+        attn_out, kv = atten_func_out
       ff_out = self.ff(x_norm)
       output = x + attn_out + ff_out
     else:
       x_norm = self.pre_atten_norm(x)
-
+      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      if kv_cache is None:
+        attn_out = atten_func_out
+      else:
+        attn_out, kv = atten_func_out
       x = x + attn_out
       x_norm = self.post_atten_norm(x)
       output = x + self.ff(x_norm)

-    return output
+    return output if kv is None else (output, kv)


 class CausalSelfAttention(nn.Module):
@@ -121,7 +130,6 @@ class CausalSelfAttention(nn.Module):
       batch_size: int,
       dim: int,
       config: cfg.AttentionConfig,
-      kv_cache_max: int,
       enable_hlfb: bool,
   ) -> None:
     """Initialize an instance of CausalSelfAttention.
@@ -130,8 +138,6 @@ class CausalSelfAttention(nn.Module):
       batch_size (int): batch size of the input tensor.
       dim (int): causal attention's input/output dimmension.
       config (cfg.AttentionConfig): attention specific configurations.
-      kv_cache_max (int): determines the size of the KV Cache buffer, if
-        enabled.
       enable_hlfb (bool): whether hlfb is enabled or not.
     """
     super().__init__()
@@ -147,21 +153,13 @@ class CausalSelfAttention(nn.Module):
     self.output_projection = nn.Linear(
         output_shape, dim, bias=config.output_proj_use_bias
     )
-
-
-
-
-
-
-
-        config.head_dim,
-        enable_hlfb,
-    )
-
-    if enable_hlfb:
-      self.sdpa_func = scaled_dot_product_attention_with_hlfb
-    else:
-      self.sdpa_func = scaled_dot_product_attention
+    self.config = config
+    self.enable_hlfb = enable_hlfb
+    self.sdpa_func = (
+        sdpa.scaled_dot_product_attention_with_hlfb
+        if enable_hlfb
+        else sdpa.scaled_dot_product_attention
+    )

   def forward(
       self,
@@ -169,7 +167,8 @@ class CausalSelfAttention(nn.Module):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
-
+      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the CausalSelfAttention layer, which can support

     MQA, GQA and MHA.
@@ -179,9 +178,11 @@ class CausalSelfAttention(nn.Module):
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.

     Returns:
-      output activation from this self attention layer
+      output activation from this self attention layer, and the updated
+      KV Cach Entry (if passed in).
     """
     # Batch size, sequence length, embedding dimensionality.
     B, T, E = x.size()
@@ -224,9 +225,11 @@ class CausalSelfAttention(nn.Module):
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
       q, k = _embed_rope(q, k, n_elem, rope)

-    if
-
-
+    if kv_cache is not None:
+      kv_cache = kv_utils.update(
+          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
+      )
+      k, v = kv_cache.k_cache, kv_cache.v_cache

     y = self.sdpa_func(
         q,
@@ -240,7 +243,7 @@ class CausalSelfAttention(nn.Module):

     # Compute the output projection.
     y = self.output_projection(y)
-    return y
+    return y if kv_cache is None else (y, kv_cache)


 class SelfAttention(CausalSelfAttention):
@@ -251,16 +254,19 @@ class SelfAttention(CausalSelfAttention):
       x: torch.Tensor,
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       input_pos: Optional[torch.Tensor] = None,
-
+      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the SelfAttention layer, which can support MQA, GQA and MHA.

     Args:
       x (torch.Tensor): the input tensor.
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.

     Returns:
-      output activation from this self attention layer
+      output activation from this self attention layer, and the updated
+      KV Cach Entry (if passed in).
     """
     B, T, _ = x.size()
     return super().forward(
@@ -279,9 +285,8 @@ class CrossAttention(nn.Module):
       query_dim: int,
       cross_dim: int,
       config: cfg.AttentionConfig,
-      kv_cache_max: int,
       enable_hlfb: bool,
-  )
+  ):
     """Initialize an instance of CrossAttention.

     Args:
@@ -289,8 +294,6 @@ class CrossAttention(nn.Module):
       query_dim (int): query tensor's dimension.
       cross_dim (int): cross attention's dimensions, for key and value tensors.
       config (cfg.AttentionConfig): attention specific configurations.
-      kv_cache_max (int): determines the size of the KV Cache buffer, if
-        enabled.
       enable_hlfb (bool): whether hlfb is enabled or not.
     """
     super().__init__()
@@ -309,21 +312,11 @@ class CrossAttention(nn.Module):
         query_dim, query_dim, bias=config.output_proj_use_bias
     )

-    self.
-
-
-
-
-        kv_cache_max,
-        config.num_query_groups,
-        self.config.head_dim,
-        enable_hlfb,
-    )
-
-    if enable_hlfb:
-      self.sdpa_func = scaled_dot_product_attention_with_hlfb
-    else:
-      self.sdpa_func = scaled_dot_product_attention
+    self.sdpa_func = (
+        sdpa.scaled_dot_product_attention_with_hlfb
+        if enable_hlfb
+        else sdpa.scaled_dot_product_attention
+    )

   def forward(
       self,
@@ -332,6 +325,7 @@ class CrossAttention(nn.Module):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
+      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
   ):
     """Forward function of the CrossAttention layer.

@@ -342,6 +336,7 @@ class CrossAttention(nn.Module):
       mask (torch.Tensor): the optional mask tensor can be broadcaseted to shape
         [B, n_heads, target_seq_len, source_seq_len].
       input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.

     Returns:
       output activation from this cross attention layer.
@@ -363,9 +358,11 @@ class CrossAttention(nn.Module):
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
       q, k = _embed_rope(q, k, n_elem, rope)

-    if
-
-
+    if kv_cache is not None:
+      kv_cache = kv_utils.update(
+          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
+      )
+      k, v = kv_cache.k_cache, kv_cache.v_cache
     if mask is None:
       mask = torch.zeros(
           (batch_size, 1, target_seq_len, source_seq_len), dtype=torch.float32
@@ -375,4 +372,4 @@ class CrossAttention(nn.Module):

     # Compute the output projection.
     y = self.output_projection(y)
-    return y
+    return y if kv_cache is None else (y, kv_cache)
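With this change the attention layers no longer own KV buffers: callers pass a kv_utils.KVCacheEntry into forward and, when they do, get back an (output, updated_entry) tuple instead of a bare tensor. A minimal sketch of that caller-side convention, using a hypothetical ToyAttention stand-in (not part of the package) and illustrative shapes:

import torch
from torch import nn

from ai_edge_torch.generative.layers import kv_cache as kv_utils


class ToyAttention(nn.Module):
  """Hypothetical stand-in that mimics the new return convention."""

  def forward(self, x, rope=None, mask=None, input_pos=None, kv_cache=None):
    y = x  # A real layer would compute attention here.
    # Plain tensor out when no cache is passed, (tensor, entry) otherwise.
    return y if kv_cache is None else (y, kv_cache)


layer = ToyAttention()
x = torch.zeros(1, 4, 8)

y = layer(x)  # No cache entry: output tensor only, as before.

entry = kv_utils.KVCacheEntry(
    k_cache=torch.zeros(1, 16, 1, 4), v_cache=torch.zeros(1, 16, 1, 4)
)
y, entry = layer(x, input_pos=torch.tensor([0]), kv_cache=entry)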
ai_edge_torch/generative/layers/kv_cache.py

@@ -12,72 +12,181 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# `nn.Module` which implements a KV cache.

-
+"""Utility functions for externalized KV Cache."""
+
+import dataclasses
+from typing import List, Tuple
+
+from ai_edge_torch import hlfb
+from ai_edge_torch.generative.layers import model_config
 import torch
-
+import torch.utils._pytree as pytree


-
+@dataclasses.dataclass
+class KVCacheEntry:
+  """A single cache entry that includes K and V caches.

-
-
-
-    """Initializes the KVCache layer.
+  The chaches are built based on the provided config with the shape of
+  (batch_size=1, kv_cache_max, num_query_groups, head_dim).
+  """

-
-
-      kv_cache_max (int): the max length of KV cache.
-      n_heads (int): number of kv heads.
-      head_dim (int): the head dimension size.
-      enable_hlfb (bool): whether hlfb is enabled or not.
-    """
-    super().__init__()
-    cache_shape = (batch_size, kv_cache_max, n_heads, head_dim)
-    self.register_buffer("k_cache", torch.zeros(cache_shape), persistent=False)
-    self.register_buffer("v_cache", torch.zeros(cache_shape), persistent=False)
-    self.enable_hlfb = enable_hlfb
-    self.kv_cache_max = kv_cache_max
+  k_cache: torch.Tensor
+  v_cache: torch.Tensor

-
-
+  @classmethod
+  def from_model_config(
+      cls,
+      config: model_config.ModelConfig,
+      dtype: torch.dtype = torch.float32,
+      device: torch.device = None,
+  ) -> "KVCacheEntry":
+    """Build an instance of the class based on model config."""
+    shape = (
+        1,  # Batch dimmension.
+        config.kv_cache_max,
+        config.attn_config.num_query_groups,
+        config.attn_config.head_dim,
+    )
+    k = torch.zeros(shape, dtype=dtype, device=device)
+    v = torch.zeros(shape, dtype=dtype, device=device)
+    obj = cls(k_cache=k, v_cache=v)
+    return obj

-    Args:
-      input_pos (torch.Tensor): the input position.
-      k_val (torch.Tensor): the new `key` value.
-      v_val (torch.Tensor): the new `value` value.

-
-
-
-    if self.enable_hlfb:
-      return self.update_cache_with_hlfb(input_pos, k_val, v_val)
+@dataclasses.dataclass
+class KVCache:
+  """A utility class for holding KV cache entries per layer."""

-
-    updated_v = self.v_cache.index_copy_(1, input_pos, v_val)
-    # Here we need a clone otherwise dynamo export will fail.
-    return torch.clone(updated_k), torch.clone(updated_v)
+  caches: Tuple[KVCacheEntry, ...]

-
-
+  @classmethod
+  def from_model_config(
+      cls,
+      config: model_config.ModelConfig,
+      dtype: torch.dtype = torch.float32,
+      device: torch.device = None,
+  ) -> "KVCache":
+    """Build an instance of the class based on model config.

     Args:
-
-
-
+      config (ModelConfig): Model config used for building the cache.
+      dtype (torch.dtype, optional): The data type of the cache tensor.
+        Defaults to torch.float32.
+      device (torch.device, optional): The device placement of the cache
+        tensors. Defaults to None.

     Returns:
-
+      KVCache: The created cache object.
     """
+    caches = [
+        KVCacheEntry.from_model_config(config, dtype, device)
+        for _ in range(config.num_layers)
+    ]
+    obj = cls(caches=tuple(caches))
+    return obj

-
-
-    )
-
-
+  def flatten(self) -> List[torch.Tensor]:
+    """Flatten the cache entries into a list of tensors with order k_i, v_i."""
+    flattened, _ = _flatten_kvc(self)
+    return flattened
+
+
+def _flatten_kvc(kvc: KVCache) -> Tuple[List[str], List[str]]:
+  flattened = []
+  flat_names = []
+  none_names = []
+  for i, kv_entry in enumerate(kvc.caches):
+    flattened.append(kv_entry.k_cache)
+    flat_names.append(f"k_{i}")
+    flattened.append(kv_entry.v_cache)
+    flat_names.append(f"v_{i}")
+  return flattened, [flat_names, none_names]
+
+
+def _flatten_kvc_with_keys(kvc: KVCache) -> Tuple[List, List]:
+  flattened, (flat_names, none_names) = _flatten_kvc(kvc)
+  return [
+      (pytree.MappingKey(k), v) for k, v in zip(flat_names, flattened)
+  ], flat_names
+
+
+def _unflatten_kvc(
+    values: List[torch.Tensor], context: Tuple[List, List]
+) -> KVCache:
+  assert len(values) % 2 == 0, "Found odd number of K and V entries."
+  num_layers = len(values) // 2
+  flat_names = context[0]
+  kv_entries = []
+  for i in range(num_layers):
+    k_cache_idx = flat_names.index(f"k_{i}")
+    v_cache_idx = flat_names.index(f"v_{i}")
+    kv_entries.append(
+        KVCacheEntry(k_cache=values[k_cache_idx], v_cache=values[v_cache_idx])
     )
-
-
-
-
+  obj = KVCache(tuple(kv_entries))
+  return obj
+
+
+pytree.register_pytree_node(
+    KVCache,
+    _flatten_kvc,
+    _unflatten_kvc,
+    flatten_with_keys_fn=_flatten_kvc_with_keys,
+    serialized_type_name="",
+)
+
+
+def update(
+    cache: KVCacheEntry,
+    input_pos: torch.Tensor,
+    k_slice: torch.Tensor,
+    v_slice: torch.Tensor,
+    enable_hlfb: bool = True,
+) -> KVCacheEntry:
+  """Out of place update of Cache buffer.
+
+  Args:
+    cache (KVCacheEntry): The original cache buffer.
+    input_pos (torch.Tensor): The update slice positions.
+    k_slice (torch.Tensor): The K slice to be updated in the new cache.
+    v_slice (torch.Tensor): The V slice to be updated in the new cache.
+    enable_hlfb (bool, optional): Whether the op is annotated for export with
+      High Level Function Boundary. Defaults to True.
+
+  Returns:
+    KVCacheEntry: The updated KVCache entry based on the passed inputs.
+  """
+  update_func = _update_kv_hlfb_impl if enable_hlfb else _update_kv_base_impl
+  return update_func(cache, input_pos, k_slice, v_slice)
+
+
+def _update_kv_base_impl(
+    cache: KVCacheEntry,
+    input_pos: torch.Tensor,
+    k_slice: torch.Tensor,
+    v_slice: torch.Tensor,
+) -> KVCacheEntry:
+  """Update the cache buffer without High Level Function Boundary annotation."""
+  k = cache.k_cache.index_copy(1, input_pos, k_slice)
+  v = cache.v_cache.index_copy(1, input_pos, v_slice)
+  updated_cache = KVCacheEntry(k, v)
+  return updated_cache
+
+
+def _update_kv_hlfb_impl(
+    cache: KVCacheEntry,
+    input_pos: torch.Tensor,
+    k_slice: torch.Tensor,
+    v_slice: torch.Tensor,
+) -> KVCacheEntry:
+  """Update the cache buffer with High Level Function Boundary annotation."""
+  builder = hlfb.StableHLOCompositeBuilder(name="odml.update_external_kv_cache")
+  k_cache, v_cache, input_pos, k_slice, v_slice = builder.mark_inputs(
+      cache.k_cache, cache.v_cache, input_pos, k_slice, v_slice
+  )
+  k = k_cache.index_copy(1, input_pos, k_slice)
+  v = v_cache.index_copy(1, input_pos, v_slice)
+  k, v = builder.mark_outputs(k, v)
+  return KVCacheEntry(k, v)
ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py}

@@ -12,19 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# A suite of tests to validate experimental external KV Cache layers and models.

-
-
-from ai_edge_torch.generative.
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
+"""A suite of tests to validate KV Cache layer."""
+
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch

 from absl.testing import absltest as googletest


-class
+class TestKVLayers(googletest.TestCase):

   def _get_test_config(
       self, num_layers, head_dim, num_query_groups, kv_cache_max_len
@@ -54,7 +52,7 @@ class TestExternalKVLayers(googletest.TestCase):
         num_query_groups=NUM_QG,
         kv_cache_max_len=KV_LEN,
     )
-    kv = kv_utils.
+    kv = kv_utils.KVCache.from_model_config(config)
     entry = kv.caches[0]
     # single-slice update
     input_pos = torch.tensor([1])
@@ -88,14 +86,14 @@ class TestExternalKVLayers(googletest.TestCase):
   def test_serialization(self):
     class TestModel(torch.nn.Module):

-      def forward(self, kv: kv_utils.
+      def forward(self, kv: kv_utils.KVCache) -> kv_utils.KVCache:
         updated_kv_entries = [
             kv_utils.KVCacheEntry(
                 torch.zeros_like(entry.k_cache), torch.zeros_like(entry.v_cache)
             )
             for entry in kv.caches
         ]
-        return kv_utils.
+        return kv_utils.KVCache(updated_kv_entries)

     N = 1
     HEAD_DIM = 2
@@ -107,7 +105,7 @@ class TestExternalKVLayers(googletest.TestCase):
         num_query_groups=NUM_QG,
         kv_cache_max_len=KV_LEN,
     )
-    kv = kv_utils.
+    kv = kv_utils.KVCache.from_model_config(config)
     model = TestModel()
     exported_program = torch.export.export(model, (kv,))
     input_specs = exported_program.graph_signature.input_specs
@@ -116,17 +114,5 @@ class TestExternalKVLayers(googletest.TestCase):
     self.assertEqual(input_specs[1].arg.name, "kv_v_0")


-class TestExternalKVModels(googletest.TestCase):
-
-  def test_can_build_gemma(self):
-    gemma.define_and_run_2b(checkpoint_path=None, test_model=True)
-
-  def test_can_build_phi2(self):
-    phi2.define_and_run(checkpoint_path=None, test_model=True)
-
-  def test_can_build_tinyllama(self):
-    tiny_llama.define_and_run(checkpoint_path=None, test_model=True)
-
-
 if __name__ == "__main__":
   googletest.main()
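The serialization test above works because KVCache is registered as a pytree node, so torch.export sees one k_i/v_i tensor leaf per layer. A short sketch of that flattening behavior, with illustrative shapes and a single layer:

import torch
import torch.utils._pytree as pytree

from ai_edge_torch.generative.layers import kv_cache as kv_utils

entry = kv_utils.KVCacheEntry(
    k_cache=torch.zeros(1, 16, 1, 4), v_cache=torch.zeros(1, 16, 1, 4)
)
kv = kv_utils.KVCache(caches=(entry,))

flat = kv.flatten()  # [k_0, v_0] tensors, in layer order.

# The registered pytree node round-trips through flatten/unflatten, which is
# what lets a KVCache be passed directly as a torch.export input.
leaves, spec = pytree.tree_flatten(kv)
rebuilt = pytree.tree_unflatten(leaves, spec)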