ai-edge-torch-nightly 0.5.0.dev20250408__py3-none-any.whl → 0.5.0.dev20250410__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- ai_edge_torch/_convert/conversion.py +1 -1
- ai_edge_torch/_convert/fx_passes/__init__.py +1 -0
- ai_edge_torch/_convert/fx_passes/cast_inputs_bf16_to_f32_pass.py +50 -0
- ai_edge_torch/_convert/test/test_convert.py +21 -0
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +2 -2
- ai_edge_torch/generative/examples/gemma3/decoder.py +8 -9
- ai_edge_torch/generative/examples/gemma3/verify_util.py +4 -2
- ai_edge_torch/generative/layers/experimental/attention.py +10 -40
- ai_edge_torch/generative/layers/experimental/kv_cache.py +13 -283
- ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +6 -10
- ai_edge_torch/generative/layers/experimental/types.py +3 -0
- ai_edge_torch/generative/layers/kv_cache.py +81 -14
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py +124 -0
- ai_edge_torch/generative/test/test_kv_cache.py +12 -19
- ai_edge_torch/generative/utilities/converter.py +8 -3
- ai_edge_torch/generative/utilities/export_config.py +3 -1
- ai_edge_torch/lowertools/odml_torch_utils.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +19 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/odml_torch/lowerings/utils.py +1 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250408.dist-info → ai_edge_torch_nightly-0.5.0.dev20250410.dist-info}/METADATA +4 -2
- {ai_edge_torch_nightly-0.5.0.dev20250408.dist-info → ai_edge_torch_nightly-0.5.0.dev20250410.dist-info}/RECORD +26 -24
- {ai_edge_torch_nightly-0.5.0.dev20250408.dist-info → ai_edge_torch_nightly-0.5.0.dev20250410.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250408.dist-info → ai_edge_torch_nightly-0.5.0.dev20250410.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250408.dist-info → ai_edge_torch_nightly-0.5.0.dev20250410.dist-info}/top_level.txt +0 -0
--- a/ai_edge_torch/_convert/conversion.py
+++ b/ai_edge_torch/_convert/conversion.py
@@ -40,8 +40,8 @@ def _run_convert_passes(
       fx_passes.OptimizeLayoutTransposesPass(),
       fx_passes.CanonicalizePass(),
       fx_passes.BuildAtenCompositePass(),
-      fx_passes.CanonicalizePass(),
       fx_passes.RemoveNonUserOutputsPass(),
+      fx_passes.CastInputsBf16ToF32Pass(),
       fx_passes.CanonicalizePass(),
   ]
 
--- a/ai_edge_torch/_convert/fx_passes/__init__.py
+++ b/ai_edge_torch/_convert/fx_passes/__init__.py
@@ -17,6 +17,7 @@ from typing import Sequence, Union
 
 from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
 from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
+from ai_edge_torch._convert.fx_passes.cast_inputs_bf16_to_f32_pass import CastInputsBf16ToF32Pass
 from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
 from ai_edge_torch._convert.fx_passes.remove_non_user_outputs_pass import RemoveNonUserOutputsPass
--- /dev/null
+++ b/ai_edge_torch/_convert/fx_passes/cast_inputs_bf16_to_f32_pass.py
@@ -0,0 +1,50 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Pass to cast all inputs with torch.bfloat16 type to torch.float32."""
+
+
+from ai_edge_torch import fx_infra
+import torch
+
+
+def cast_f32(x):
+  return x.to(torch.float32)
+
+
+class CastInputsBf16ToF32Pass(fx_infra.ExportedProgramPassBase):
+  """This pass casts all inputs with torch.bfloat16 type to torch.float32."""
+
+  def call(self, exported_program: torch.export.ExportedProgram):
+    modified = False
+    for node in exported_program.graph.nodes:
+      if (
+          node.op == "placeholder"
+          and node.meta.get("val").dtype == torch.bfloat16
+      ):
+        if not node.users:
+          continue
+
+        modified = True
+        user = next(iter(node.users))
+        with exported_program.graph.inserting_before(user):
+          cast_node = exported_program.graph.call_function(
+              cast_f32,
+              (node,),
+          )
+          node.replace_all_uses_with(cast_node)
+          cast_node.replace_input_with(cast_node, node)
+
+    exported_program.graph_module.recompile()
+    return fx_infra.ExportedProgramPassResult(exported_program, modified)
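The new pass wraps every live bfloat16 placeholder in a `cast_f32` call before its first user, so bf16 survives in the model signature while downstream ops compute in float32. A minimal sketch of invoking it directly, assuming the pass is called via `.call()` as defined above and that the returned `ExportedProgramPassResult` exposes the two values passed to its constructor:

```python
import torch
from ai_edge_torch._convert import fx_passes


class AddOne(torch.nn.Module):

  def forward(self, x):
    return x + 1.0


# Export with a bf16 example input; the placeholder keeps dtype bfloat16.
ep = torch.export.export(
    AddOne().eval(), (torch.randn(4, 4).to(torch.bfloat16),)
)
result = fx_passes.CastInputsBf16ToF32Pass().call(ep)
# Every use of the bf16 placeholder now flows through a cast_f32 node
# (assumed result fields, mirroring the constructor above:
# result.exported_program, result.modified).
```

In normal use none of this is needed: `_run_convert_passes` above schedules the pass automatically during `ai_edge_torch.convert`.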
--- a/ai_edge_torch/_convert/test/test_convert.py
+++ b/ai_edge_torch/_convert/test/test_convert.py
@@ -553,6 +553,27 @@ class TestConvert(googletest.TestCase):
       self.fail(f"PT2E conversion failed: {err}")
     # pylint: enable=broad-except
 
+  def test_convert_model_with_bfloat16_inputs(self):
+    """Test converting a simple model with torch.bfloat16 input.
+
+    bf16 inputs would remain in converted model signature but be casted to f32
+    right after the model inputs.
+    """
+
+    class SampleModel(nn.Module):
+
+      def forward(self, x: torch.Tensor):
+        return (x + 1) * 1.2
+
+    model = SampleModel().eval()
+    args = (torch.randn(10, 10).to(torch.bfloat16),)
+    # pylint: disable=broad-except
+    try:
+      ai_edge_torch.convert(model, args)
+    except Exception as err:
+      self.fail(f"Conversion failed with bloat16 inputs: {err}")
+    # pylint: enable=broad-except
+
 
 if __name__ == "__main__":
   googletest.main()
--- a/ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py
+++ b/ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py
@@ -17,7 +17,7 @@
 
 from absl import app
 from ai_edge_torch.generative.examples.gemma3 import gemma3
-from ai_edge_torch.generative.layers
+from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 import torch
@@ -58,7 +58,7 @@ def _create_export_config(
   )
   decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
   export_config.decode_mask = decode_mask
-  export_config.
+  export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
   return export_config
 
 
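With the experimental cache module folded into `generative.layers.kv_cache`, the Gemma 3 example now opts into the transposed cache layout through the export config. A minimal sketch, assuming `ExportConfig` can be constructed with defaults as its use here suggests:

```python
from ai_edge_torch.generative.layers import kv_cache
from ai_edge_torch.generative.utilities import export_config as export_config_lib

config = export_config_lib.ExportConfig()
# Transposed layout: K stored as (batch, num_query_groups, seq, head_dim)
# and V as (batch, num_query_groups, head_dim, seq); see the removed
# KVCacheEntryTransposed shapes in the kv_cache.py hunk further down.
config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
```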
--- a/ai_edge_torch/generative/examples/gemma3/decoder.py
+++ b/ai_edge_torch/generative/examples/gemma3/decoder.py
@@ -18,9 +18,9 @@
 from typing import List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 from ai_edge_torch.generative.layers.experimental import attention
-from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import export_config as export_cfg
@@ -81,8 +81,8 @@ class DecoderBlock(attention.TransformerBlock):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
-      kv_cache: kv_utils.
-  ) -> Tuple[torch.Tensor, Optional[kv_utils.
+      kv_cache: kv_utils.KVCacheEntry = None,
+  ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
     """Forward function of the Gemma3Block.
 
     Exactly the same as TransformerBlock but we call the post-attention norm
@@ -241,13 +241,12 @@ class Decoder(nn.Module):
       self,
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
-      kv_cache: kv_utils.
+      kv_cache: kv_utils.KVCache,
       input_embeds: Optional[torch.Tensor] = None,
       mask: Optional[torch.Tensor] = None,
       image_indices: Optional[torch.Tensor] = None,
       export_config: Optional[export_cfg.ExportConfig] = None,
-  ) -> dict[torch.Tensor, kv_utils.
-
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
     pixel_mask = None
     if input_embeds is None:
       # token embeddings of shape (b, t, n_embd)
@@ -287,10 +286,10 @@ class Decoder(nn.Module):
       rope: List[Tuple[torch.Tensor, torch.Tensor]],
       mask: torch.Tensor | List[torch.Tensor],
       input_pos: torch.Tensor,
-      kv_cache: kv_utils.
+      kv_cache: kv_utils.KVCache,
       pixel_mask: Optional[torch.Tensor] = None,
       export_config: Optional[export_cfg.ExportConfig] = None,
-  ) -> dict[torch.Tensor, kv_utils.
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
     """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
@@ -326,7 +325,7 @@ class Decoder(nn.Module):
       x, kv_entry = block(x, rope[i], mask_entry, input_pos, kv_entry)
       if kv_entry:
         updated_kv_entries.append(kv_entry)
-    updated_kv_cache = kv_utils.
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
     if export_config is not None:
       if (
           torch.numel(input_pos) > 1
--- a/ai_edge_torch/generative/examples/gemma3/verify_util.py
+++ b/ai_edge_torch/generative/examples/gemma3/verify_util.py
@@ -20,8 +20,8 @@ import os
 from typing import List, Optional, Tuple
 
 from ai_edge_torch.generative.examples.gemma3 import gemma3
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
 from ai_edge_torch.generative.utilities.experimental import verifier
 from gemma import config as gemma_config
 from gemma import model as gemma_model
@@ -94,7 +94,9 @@ class UnifiedGemma3Wrapper(verifier.ReauthoredModelWrapper):
 
   def _init_kv_cache(self):
     """Returns an initialized KV cache."""
-    return kv_utils.
+    return kv_utils.KVCache.from_model_config(
+        self.model.model.config, kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED
+    )
 
   def forward(
       self, tokens: torch.Tensor, pixel_values: torch.Tensor = None
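The unified `KVCache.from_model_config` takes the layout as a keyword, replacing the removed `KVCacheTransposed` class. A sketch, where `model_config` stands in for whatever `ModelConfig` instance the caller has:

```python
from ai_edge_torch.generative.layers import kv_cache as kv_utils

# model_config: placeholder for a generative-layers ModelConfig instance.
cache = kv_utils.KVCache.from_model_config(
    model_config, kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED
)
```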
--- a/ai_edge_torch/generative/layers/experimental/attention.py
+++ b/ai_edge_torch/generative/layers/experimental/attention.py
@@ -22,9 +22,9 @@ at any time.
 from typing import Optional, Tuple, Union
 
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import lora as lora_utils
-from ai_edge_torch.generative.layers
-from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
+from ai_edge_torch.generative.layers import sdpa_with_kv_update
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import torch
@@ -69,9 +69,9 @@ class TransformerBlock(nn.Module):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
-      kv_cache: kv_utils.
+      kv_cache: kv_utils.KVCacheEntry = None,
       lora: Optional[lora_utils.LoRAEntry] = None,
-  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the TransformerBlock.
 
     Args:
@@ -79,7 +79,7 @@ class TransformerBlock(nn.Module):
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (
+      kv_cache (KVCacheEntry): the optional kv cache entry.
       lora (LoRAEntry): the optional lora entry.
 
     Returns:
@@ -146,7 +146,6 @@ class CausalSelfAttention(nn.Module):
     self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
     self.config = config
     self.enable_hlfb = enable_hlfb
-    self.sdpa_func = sdpa.scaled_dot_product_attention
 
   def forward(
       self,
@@ -154,9 +153,9 @@ class CausalSelfAttention(nn.Module):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
-      kv_cache: Optional[kv_utils.
+      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
       lora: Optional[lora_utils.LoRAEntry] = None,
-  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the CausalSelfAttention layer, which can support
 
     MQA, GQA and MHA.
@@ -166,8 +165,7 @@ class CausalSelfAttention(nn.Module):
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (
-        module.
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
       lora (LoRAEntry): the optional lora entry.
 
     Returns:
@@ -221,36 +219,8 @@ class CausalSelfAttention(nn.Module):
       cos, sin = rope
      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
 
-
-
-    g = n // self.config.num_query_groups
-    # btnh -> bnth -> b(kg)th -> 1(bk)(gt)h
-    q = q.permute(0, 2, 1, 3).reshape(
-        1, b * self.config.num_query_groups, g * T, h
-    )
-
-    k = k.permute(0, 2, 1, 3).reshape(
-        1, -1, T, self.config.head_dim
-    )  # 1, bk, s, h
-    v = v.permute(0, 2, 3, 1).reshape(
-        1, -1, self.config.head_dim, T
-    )  # 1, bk, h, s
-
-    if kv_cache is not None:
-      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
-      k, v = kv_cache.k_cache, kv_cache.v_cache
-
-    sdpa_out = self.sdpa_func(
-        kv_cache,
-        q,
-        k,
-        v,
-        self.config.head_dim,
-        mask=mask,
-        softcap=self.config.logit_softcap,
-    )  # 1, bk, gt, h
-    sdpa_out = (
-        sdpa_out.reshape(B, -1, T, h).permute(0, 2, 1, 3).reshape(B, T, -1)
+    sdpa_out, kv_cache = sdpa_with_kv_update.sdpa_with_kv_update(
+        q, k, v, kv_cache, input_pos, mask, self.config
     )
 
     # Compute the output projection.
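The deleted block is the grouped-query packing for the transposed cache layout, now hidden behind the new `sdpa_with_kv_update` module listed at the top of this diff. A self-contained toy that replays the packing with hypothetical sizes, to make the shape comment `btnh -> bnth -> b(kg)th -> 1(bk)(gt)h` concrete:

```python
import torch

# Hypothetical sizes: batch 2, seq 4, 8 query heads, head_dim 16,
# 2 KV groups, so g = 4 query heads share each KV group.
b, T, n, h, groups = 2, 4, 8, 16, 2
g = n // groups

q = torch.randn(b, T, n, h)                                # b, t, n, h
packed = q.permute(0, 2, 1, 3).reshape(1, b * groups, g * T, h)
assert packed.shape == (1, b * groups, g * T, h)           # 1, (bk), (gt), h
```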
--- a/ai_edge_torch/generative/layers/experimental/kv_cache.py
+++ b/ai_edge_torch/generative/layers/experimental/kv_cache.py
@@ -18,303 +18,33 @@
 This is an experimental implementation and is subject to change at any time.
 """
 
-import dataclasses
-import functools
-from typing import Any, List, Tuple, Type
-from ai_edge_torch.generative.layers import model_config
-from ai_edge_torch.generative.layers.experimental import types
 from ai_edge_torch.generative.custom_ops import dynamic_update_slice as dus_utils
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import torch
-import torch.utils._pytree as pytree
-
-
-@dataclasses.dataclass
-class KVCacheEntryBase:
-  """A single cache entry that includes K and V caches.
-
-  The chaches are built based on the provided config with the shape of
-  (batch_size, kv_cache_max, num_query_groups, head_dim).
-  """
-
-  k_cache: torch.Tensor
-  v_cache: torch.Tensor
-
-  @classmethod
-  def _from_model_config(
-      cls,
-      k_shape: Tuple[int, ...],
-      v_shape: Tuple[int, ...],
-      dtype: torch.dtype = torch.float32,
-      device: torch.device = None,
-  ):
-    """Build an instance of the class based on model config."""
-    k = torch.zeros(k_shape, dtype=dtype, device=device)
-    v = torch.zeros(v_shape, dtype=dtype, device=device)
-    obj = cls(k_cache=k, v_cache=v)
-    return obj
-
-  @classmethod
-  def from_model_config(
-      cls,
-      kv_cache_max: int,
-      config: model_config.AttentionConfig,
-      dtype: torch.dtype = torch.float32,
-      device: torch.device = None,
-      batch_size: int = 1,
-  ):
-    """Build an instance of the class based on model config."""
-    shape = (batch_size, kv_cache_max, config.num_query_groups, config.head_dim)
-    return cls._from_model_config(shape, shape, dtype, device)
-
-
-@dataclasses.dataclass
-class KVCacheEntryBTNH(KVCacheEntryBase):
-  k_type = types.BTNH()
-  v_type = types.BTNH()
-
-
-@dataclasses.dataclass
-class KVCacheEntryTransposed(KVCacheEntryBase):
-
-  k_type = types.BNTH()
-  v_type = types.BNHT()
-
-  @classmethod
-  def from_model_config(
-      cls,
-      kv_cache_max: int,
-      config: model_config.AttentionConfig,
-      dtype: torch.dtype = torch.float32,
-      device: torch.device = None,
-      batch_size: int = 1,
-  ):
-    """Build an instance of the class based on model config."""
-    k_shape = (
-        batch_size,
-        config.num_query_groups,
-        kv_cache_max,
-        config.head_dim,
-    )  # b, k, s, h
-    v_shape = (
-        batch_size,
-        config.num_query_groups,
-        config.head_dim,
-        kv_cache_max,
-    )  # b, k, h, s
-    return cls._from_model_config(k_shape, v_shape, dtype, device)
-
-
-def _flatten_kv_entry(
-    kv_e: KVCacheEntryBase,
-) -> Tuple[List[torch.Tensor], Any]:
-  return ([kv_e.k_cache, kv_e.v_cache], None)
-
-
-def _unflatten_kv_entry(
-    kv_entry_ty: Type[KVCacheEntryBase],
-    values: List[torch.Tensor],
-    unused_context: Any,
-) -> KVCacheEntryBase:
-  return kv_entry_ty(*values)
-
-
-pytree.register_pytree_node(
-    KVCacheEntryTransposed,
-    _flatten_kv_entry,
-    functools.partial(_unflatten_kv_entry, KVCacheEntryTransposed),
-    serialized_type_name="",
-)
-
-pytree.register_pytree_node(
-    KVCacheEntryBase,
-    _flatten_kv_entry,
-    functools.partial(_unflatten_kv_entry, KVCacheEntryBase),
-    serialized_type_name="",
-)
-
-
-@dataclasses.dataclass
-class KVCacheBase:
-  """A utility class for holding KV cache entries per layer."""
-
-  caches: Tuple[KVCacheEntryBase, ...]
-
-  @classmethod
-  def _from_model_config(
-      cls,
-      kv_entry_cls,
-      config: model_config.ModelConfig,
-      dtype: torch.dtype = torch.float32,
-      device: torch.device = None,
-      batch_size: int = 1,
-  ):
-    caches = [
-        kv_entry_cls.from_model_config(
-            config.kv_cache_max,
-            config.block_config(idx).attn_config,
-            dtype,
-            device,
-            batch_size,
-        )
-        for idx in range(config.num_layers)
-    ]
-    obj = cls(caches=tuple(caches))
-    return obj
-
-  @classmethod
-  def from_model_config(
-      cls,
-      config: model_config.ModelConfig,
-      dtype: torch.dtype = torch.float32,
-      device: torch.device = None,
-      batch_size: int = 1,
-  ):
-    """Build an instance of the class based on model config.
-
-    Args:
-      config (ModelConfig): Model config used for building the cache.
-      dtype (torch.dtype, optional): The data type of the cache tensor.
-        Defaults to torch.float32.
-      device (torch.device, optional): The device placement of the cache
-        tensors. Defaults to None.
-      batch_size (int, optional): The batch size of the cache tensors.
-        Defaults to 1.
-
-    Returns:
-      KVCacheBase: The created cache object.
-    """
-    assert batch_size == 1, "Batch size must be 1 for KV Cache."
-    return cls._from_model_config(
-        KVCacheEntryBase,
-        config=config,
-        dtype=dtype,
-        device=device,
-        batch_size=batch_size,
-    )
-
-  def flatten(self) -> List[torch.Tensor]:
-    """Flatten the cache entries into a list of tensors with order k_i, v_i."""
-    flattened, _ = _flatten_kvc(self)
-    return flattened
-
-
-@dataclasses.dataclass
-class KVCacheBTNH(KVCacheBase):
-
-  @classmethod
-  def from_model_config(
-      cls,
-      config: model_config.ModelConfig,
-      dtype: torch.dtype = torch.float32,
-      device: torch.device = None,
-      batch_size: int = 1,
-  ):
-    return cls._from_model_config(
-        KVCacheEntryBTNH,
-        config=config,
-        dtype=dtype,
-        device=device,
-        batch_size=batch_size,
-    )
-
-
-@dataclasses.dataclass
-class KVCacheTransposed(KVCacheBase):
-
-  @classmethod
-  def from_model_config(
-      cls,
-      config: model_config.ModelConfig,
-      dtype: torch.dtype = torch.float32,
-      device: torch.device = None,
-      batch_size: int = 1,
-  ):
-    return cls._from_model_config(
-        KVCacheEntryTransposed,
-        config=config,
-        dtype=dtype,
-        device=device,
-        batch_size=batch_size,
-    )
-
-
-def _flatten_kvc(kvc: KVCacheBase) -> Tuple[List[str], List[str]]:
-  flattened = []
-  flat_names = []
-  none_names = []
-  for i, kv_entry in enumerate(kvc.caches):
-    flattened.append(kv_entry.k_cache)
-    flat_names.append(f"k_{i}")
-    flattened.append(kv_entry.v_cache)
-    flat_names.append(f"v_{i}")
-  return flattened, [flat_names, none_names]
-
-
-def _flatten_kvc_with_keys(kvc: KVCacheBase) -> Tuple[List, List]:
-  flattened, (flat_names, none_names) = _flatten_kvc(kvc)
-  return [
-      (pytree.MappingKey(k), v) for k, v in zip(flat_names, flattened)
-  ], flat_names
-
-
-def _unflatten_kvc(
-    kv_ty: Type[KVCacheBase],
-    kv_entry_type: Type[KVCacheEntryBase],
-    values: List[torch.Tensor],
-    context: Tuple[List, List],
-) -> KVCacheBase:
-  assert len(values) % 2 == 0, "Found odd number of K and V entries."
-  num_layers = len(values) // 2
-  flat_names = context[0]
-  kv_entries = []
-  for i in range(num_layers):
-    k_cache_idx = flat_names.index(f"k_{i}")
-    v_cache_idx = flat_names.index(f"v_{i}")
-    kv_entries.append(
-        kv_entry_type(k_cache=values[k_cache_idx], v_cache=values[v_cache_idx])
-    )
-  obj = kv_ty(tuple(kv_entries))
-  return obj
-
-
-pytree.register_pytree_node(
-    KVCacheTransposed,
-    _flatten_kvc,
-    functools.partial(
-        _unflatten_kvc, KVCacheTransposed, KVCacheEntryTransposed
-    ),
-    flatten_with_keys_fn=_flatten_kvc_with_keys,
-    serialized_type_name="",
-)
-
-pytree.register_pytree_node(
-    KVCacheBase,
-    _flatten_kvc,
-    functools.partial(_unflatten_kvc, KVCacheBase, KVCacheEntryBase),
-    flatten_with_keys_fn=_flatten_kvc_with_keys,
-    serialized_type_name="",
-)
 
 
 def update(
-    cache:
+    cache: kv_utils.KVCacheEntry,
     input_pos: torch.Tensor,
     k_slice: torch.Tensor,
     v_slice: torch.Tensor,
-) ->
+) -> kv_utils.KVCacheEntry:
   """Out of place update of Cache buffer.
 
   Args:
-    cache (
+    cache (kv_utils.KVCacheEntry): The original cache buffer.
     input_pos (torch.Tensor): The update slice positions.
     k_slice (torch.Tensor): The K slice to be updated in the new cache.
     v_slice (torch.Tensor): The V slice to be updated in the new cache.
 
   Returns:
-
+    kv_utils.KVCacheEntry: The updated KVCacheBase entry based on the passed
     inputs.
   """
-
+  assert (
+      cache.kv_layout == kv_utils.KV_LAYOUT_TRANSPOSED
+  ), "KV entry must have transposed layout."
+  update_kv_cache = _update_kv_impl_transposed
   return update_kv_cache(cache, input_pos, k_slice, v_slice)
 
 
@@ -338,12 +68,12 @@ def _get_slice_indices(
   return slice_indices
 
 
-def
-    cache:
+def _update_kv_impl_transposed(
+    cache: kv_utils.KVCacheEntry,
     input_pos: torch.Tensor,
     k_slice: torch.Tensor,
     v_slice: torch.Tensor,
-) ->
+) -> kv_utils.KVCacheEntry:
   """Update the cache buffer with High Level Function Boundary annotation."""
   cache_dim = 4
   k_ts_idx = 2
@@ -357,4 +87,4 @@ def _update_kv_impl(
   v = dus_utils.dynamic_update_slice(
       cache.v_cache, v_slice, [x for x in v_slice_indices]
   )
-  return
+  return kv_utils.KVCacheEntry(k, v, cache.kv_layout)
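After the trim, the experimental module keeps only the transposed out-of-place `update` built on `dynamic_update_slice`; entry construction, layouts, and pytree registration now come from the unified `generative.layers.kv_cache`. A usage sketch with placeholder sizes, assuming `KVCacheEntry` accepts `(k_cache, v_cache, kv_layout)` positionally, as the `return` statement above uses it:

```python
import torch
from ai_edge_torch.generative.layers import kv_cache as kv_utils
from ai_edge_torch.generative.layers.experimental import kv_cache as kv_exp

# Hypothetical sizes: batch 1, 2 KV groups, cache length 16, head_dim 8.
k_buf = torch.zeros(1, 2, 16, 8)  # b, k, s, h (transposed K layout)
v_buf = torch.zeros(1, 2, 8, 16)  # b, k, h, s (transposed V layout)
entry = kv_utils.KVCacheEntry(k_buf, v_buf, kv_utils.KV_LAYOUT_TRANSPOSED)

# Write a one-token K/V slice at position 0. The original entry is not
# mutated; a new entry comes back. The assert in update() rejects
# entries that do not carry the transposed layout.
k_slice = torch.randn(1, 2, 1, 8)
v_slice = torch.randn(1, 2, 8, 1)
new_entry = kv_exp.update(entry, torch.tensor([0]), k_slice, v_slice)
```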
--- a/ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py
+++ b/ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py
@@ -19,7 +19,7 @@ import math
 from typing import Optional
 
 from ai_edge_torch.generative.custom_ops import bmm_4d as bmm_lib
-from ai_edge_torch.generative.layers
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers.experimental import types
 from ai_edge_torch.hlfb import StableHLOCompositeBuilder
 from multipledispatch import dispatch
@@ -28,7 +28,7 @@ import torch.nn.functional as F
 
 
 def scaled_dot_product_attention(
-    kv: kv_utils.
+    kv: kv_utils.KVCacheEntry,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -37,10 +37,10 @@ def scaled_dot_product_attention(
     scale: Optional[float] = None,
     softcap: Optional[float] = None,
 ):
-  if hasattr(kv, "
+  if hasattr(kv, "kv_layout"):
     return _sdpa(
-        kv.
-        kv.
+        kv.kv_layout[0](),  # key layout
+        kv.kv_layout[1](),  # value layout
         query=query,
         key=key,
         value=value,
@@ -49,10 +49,7 @@ def scaled_dot_product_attention(
         scale=scale,
         softcap=softcap,
     )
-  raise ValueError(
-      f"SDPA for K type {type(kv.caches[0].k_type)} and V type"
-      f" {type(kv.caches[0].v_type)} not supported."
-  )
+  raise ValueError("No kv_layout attribute found in kv.")
 
 
 @dispatch(types.BNTH, types.BNHT)
@@ -85,7 +82,6 @@ def _sdpa(k_type, v_type, *args, **kwargs):
   padded_logits = logits + mask
   padded_logits = padded_logits.reshape(1, bk, gt, s)
   probs = F.softmax(padded_logits, dim=-1).type_as(key)
-
   encoded = bmm_lib.bmm_4d(probs, value)
 
   return encoded  # 1, bk, gt, h
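The layout dispatch works because `kv.kv_layout` is a pair of dimension-order classes: indexing gives the K and V layout classes, and calling them yields instances that `multipledispatch` matches against the registered `_sdpa` overloads, such as the `(types.BNTH, types.BNHT)` transposed overload above. A sketch of the mechanism in isolation (the `which` function is illustrative, not part of the library):

```python
from multipledispatch import dispatch
from ai_edge_torch.generative.layers.experimental import types


@dispatch(types.BNTH, types.BNHT)
def which(k_type, v_type):
  # Chosen when K is laid out (b, n, t, h) and V is laid out (b, n, h, t).
  return "transposed"


print(which(types.BNTH(), types.BNHT()))  # -> "transposed"
```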
--- a/ai_edge_torch/generative/layers/experimental/types.py
+++ b/ai_edge_torch/generative/layers/experimental/types.py
@@ -62,6 +62,9 @@ class TensorDimensionMeta(type):
   def __repr__(cls):
     return f'{cls.__name__}'
 
+  def __iter__(cls):
+    return iter(getattr(cls, 'dimensions'))
+
 
 def create_tensor_dimension_order_class(dims: Tuple[TensorDims]):
   """Creates a TensorDimensionMeta class with the specified dimensions.
|