ai-edge-torch-nightly 0.5.0.dev20250427__py3-none-any.whl → 0.5.0.dev20250429__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/gemma/gemma1.py +1 -0
- ai_edge_torch/generative/examples/gemma/gemma2.py +2 -1
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +18 -2
- ai_edge_torch/generative/examples/gemma/verify_util.py +15 -2
- ai_edge_torch/generative/examples/gemma3/verify_util.py +5 -3
- ai_edge_torch/generative/layers/attention_test.py +153 -0
- ai_edge_torch/generative/layers/attention_utils_test.py +64 -0
- ai_edge_torch/generative/layers/kv_cache.py +73 -3
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py +1 -2
- ai_edge_torch/generative/utilities/converter.py +1 -1
- ai_edge_torch/generative/utilities/export_config.py +8 -2
- ai_edge_torch/generative/utilities/verifier.py +24 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250427.dist-info → ai_edge_torch_nightly-0.5.0.dev20250429.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250427.dist-info → ai_edge_torch_nightly-0.5.0.dev20250429.dist-info}/RECORD +20 -20
- ai_edge_torch/generative/layers/experimental/__init__.py +0 -14
- ai_edge_torch/generative/layers/experimental/kv_cache.py +0 -90
- {ai_edge_torch_nightly-0.5.0.dev20250427.dist-info → ai_edge_torch_nightly-0.5.0.dev20250429.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250427.dist-info → ai_edge_torch_nightly-0.5.0.dev20250429.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250427.dist-info → ai_edge_torch_nightly-0.5.0.dev20250429.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py CHANGED
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags("gemma-2b")
-ExportConfig = export_config.ExportConfig
 
 
 def main(_):
@@ -35,7 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=
+      export_config=export_config.get_from_flags(),
   )
 
 
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py CHANGED
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags("gemma2-2b")
-ExportConfig = export_config.ExportConfig
 
 
 def main(_):
@@ -35,7 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=
+      export_config=export_config.get_from_flags(),
   )
 
 
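Both converter scripts drop the module-level ExportConfig alias and instead build the export configuration from command-line flags through one shared helper. A minimal sketch of the resulting call site; the keyword names are taken from this diff, while the convert_to_tflite entry point and the pytorch_model variable are assumptions carried over from the unchanged parts of these scripts:

    from ai_edge_torch.generative.utilities import converter
    from ai_edge_torch.generative.utilities import export_config

    flags = converter.define_conversion_flags("gemma-2b")


    def main(_):
      # pytorch_model is assumed to be built from the checkpoint flags, as
      # the unchanged part of these scripts does.
      converter.convert_to_tflite(
          pytorch_model,
          prefill_seq_len=flags.FLAGS.prefill_seq_lens,
          quantize=flags.FLAGS.quantize,
          lora_ranks=flags.FLAGS.lora_ranks,
          # New: mask and KV-cache layout options now come from flags.
          export_config=export_config.get_from_flags(),
      )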
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -236,6 +236,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-6,
       zero_centered=True,
+      enable_hlfb=True,
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
@@ -314,5 +315,5 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
           tensor_names=tensor_names,
           model_class=Gemma2,
       )
-    except KeyError as
+    except KeyError as _:
       continue
ai_edge_torch/generative/examples/gemma/verify_gemma2.py CHANGED
@@ -18,6 +18,7 @@
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.gemma import verify_util
+import ai_edge_torch.generative.layers.kv_cache as kv_utils
 import kagglehub
 
 
@@ -31,12 +32,27 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
     30,
     "The maximum size of the generated tokens.",
 )
-
+_MASK_AS_INPUT = flags.DEFINE_bool(
+    "mask_as_input",
+    True,
+    "Pass the causal self attention mask to the model.",
+)
+_TRANSPOSE_KV_CACHE = flags.DEFINE_bool(
+    "transpose_kv_cache",
+    True,
+    "Transpose the KV cache to reduce memory usage.",
+)
 
 def main(_):
   checkpoint = kagglehub.model_download("google/gemma-2/pyTorch/gemma-2-2b-it")
 
-  verify_util.verify_gemma2(
+  verify_util.verify_gemma2(
+      checkpoint,
+      _PROMPTS.value,
+      _MAX_NEW_TOKENS.value,
+      _MASK_AS_INPUT.value,
+      kv_utils.KV_LAYOUT_TRANSPOSED if _TRANSPOSE_KV_CACHE.value else kv_utils.KV_LAYOUT_DEFAULT,
+  )
 
 
 if __name__ == "__main__":
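For reference, a sketch of the call the new flags feed into, equivalent to running this script with its defaults (--mask_as_input and --transpose_kv_cache both default to True); the prompt string is illustrative, not from the diff:

    import ai_edge_torch.generative.layers.kv_cache as kv_utils
    from ai_edge_torch.generative.examples.gemma import verify_util

    verify_util.verify_gemma2(
        checkpoint,                        # path from kagglehub.model_download
        ["What is the meaning of life?"],  # illustrative prompt
        30,                                # _MAX_NEW_TOKENS default
        True,                              # mask_as_input
        kv_utils.KV_LAYOUT_TRANSPOSED,     # kv_layout when transposing
    )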
ai_edge_torch/generative/examples/gemma/verify_util.py CHANGED
@@ -21,6 +21,7 @@ from typing import List, Tuple
 
 from ai_edge_torch.generative.examples.gemma import gemma2
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.kv_cache as kv_utils
 from ai_edge_torch.generative.utilities import verifier
 from gemma import config as gemma_config
 from gemma import model as gemma_model
@@ -108,6 +109,8 @@ def verify_reauthored_gemma_model(
     weight_filename: str = "model.ckpt",
     tokenizer_filename: str = "tokenizer.model",
     max_new_tokens: int = 20,
+    mask_as_input: bool = False,
+    kv_layout: kv_utils.KVLayout = kv_utils.KV_LAYOUT_DEFAULT,
     rtol: float = 1e-05,
     atol: float = 1e-05,
 ) -> bool:
@@ -126,7 +129,11 @@
 
   return verifier.verify_reauthored_model(
       original_model=GemmaWrapper(original_model),
-      reauthored_model=verifier.ReauthoredModelWrapper(
+      reauthored_model=verifier.ReauthoredModelWrapper(
+          reauthored_model,
+          mask_as_input=mask_as_input,
+          kv_layout=kv_layout,
+      ),
       tokenizer=GemmaTokenizerWrapper(original_model.tokenizer),
       generate_prompts=generate_prompts,
       max_new_tokens=max_new_tokens,
@@ -137,7 +144,11 @@ def verify_reauthored_gemma_model(
 
 
 def verify_gemma2(
-    gemma2_model_path: str,
+    gemma2_model_path: str,
+    prompts: List[str],
+    max_new_tokens: int,
+    mask_as_input: bool = False,
+    kv_layout: kv_utils.KVLayout = kv_utils.KV_LAYOUT_DEFAULT,
 ) -> bool:
   """Verifies the reauthored Gemma2 model.
 
@@ -153,5 +164,7 @@
       generate_prompts=prompts,
       forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
       max_new_tokens=max_new_tokens,
+      mask_as_input=mask_as_input,
+      kv_layout=kv_layout,
       atol=1e-04,
   )
ai_edge_torch/generative/examples/gemma3/verify_util.py CHANGED
@@ -22,7 +22,7 @@ from typing import List, Optional, Tuple
 from ai_edge_torch.generative.examples.gemma3 import gemma3
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-from ai_edge_torch.generative.utilities
+from ai_edge_torch.generative.utilities import verifier
 from gemma import config as gemma_config
 from gemma import model as gemma_model
 import torch
@@ -92,10 +92,12 @@ class GemmaWrapper(verifier.ModelWrapper):
 class UnifiedGemma3Wrapper(verifier.ReauthoredModelWrapper):
   """Unified Gemma3 model wrapper for verification."""
 
+  def __init__(self, model: torch.nn.Module):
+    super().__init__(model, kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED)
+
   def _init_kv_cache(self):
-    """Returns an initialized KV cache."""
     return kv_utils.KVCache.from_model_config(
-        self.model.model.config, kv_layout=
+        self.model.model.config, kv_layout=self.kv_layout
     )
 
   def forward(
ai_edge_torch/generative/layers/attention_test.py ADDED
@@ -0,0 +1,153 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import model_config as cfg
+import torch
+
+from absl.testing import absltest as googletest
+from absl.testing import parameterized
+
+
+class AttentionTest(parameterized.TestCase):
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name="local_causal_self_attention",
+          attn_type=cfg.AttentionType.LOCAL_SLIDING,
+          expected_shape=(1, 10, 16),
+      ),
+      dict(
+          testcase_name="global_causal_self_attention",
+          attn_type=cfg.AttentionType.GLOBAL,
+          expected_shape=(1, 10, 16),
+      ),
+  )
+  def test_causal_self_attention(
+      self, attn_type: cfg.AttentionType, expected_shape: tuple[int, ...]
+  ):
+    norm_config = cfg.NormalizationConfig(
+        type=cfg.NormalizationType.RMS_NORM,
+        epsilon=1e-6,
+        zero_centered=True,
+        enable_hlfb=True,
+    )
+    attn_config = cfg.AttentionConfig(
+        num_heads=2,
+        head_dim=8,
+        num_query_groups=1,
+        rotary_base=100,
+        rotary_percentage=1.0,
+        qkv_transpose_before_split=True,
+        query_norm_config=norm_config,
+        key_norm_config=norm_config,
+        logit_softcap=None,
+        sliding_window_size=16,
+        attn_type=attn_type,
+    )
+    self_atten = attention.CausalSelfAttention(
+        dim=16,
+        config=attn_config,
+        enable_hlfb=True,
+    )
+    x = torch.randn(1, 10, 16)
+    attn_mask = torch.ones((1, 1, 10, 10), dtype=torch.float32)
+    out = self_atten(x, rope=None, mask=attn_mask)
+    self.assertEqual(out.shape, expected_shape)
+
+  def test_cross_attention(self):
+    norm_config = cfg.NormalizationConfig(
+        type=cfg.NormalizationType.RMS_NORM,
+        epsilon=1e-6,
+        zero_centered=True,
+        enable_hlfb=True,
+    )
+    attn_config = cfg.AttentionConfig(
+        num_heads=2,
+        head_dim=8,
+        num_query_groups=1,
+        rotary_base=100,
+        rotary_percentage=1.0,
+        qkv_transpose_before_split=True,
+        query_norm_config=norm_config,
+        key_norm_config=norm_config,
+        logit_softcap=None,
+        sliding_window_size=16,
+        attn_type=cfg.AttentionType.GLOBAL,
+    )
+    cross_atten = attention.CrossAttention(
+        query_dim=16,
+        cross_dim=16,
+        hidden_dim=16,
+        output_dim=16,
+        config=attn_config,
+        enable_hlfb=True,
+    )
+    x = torch.randn(1, 10, 16)
+    y = torch.randn(1, 10, 16)
+    out = cross_atten(x, y, rope=None)
+    self.assertEqual(out.shape, (1, 10, 16))
+
+  def test_transformer_block(self):
+    norm_config = cfg.NormalizationConfig(
+        type=cfg.NormalizationType.RMS_NORM,
+        epsilon=1e-6,
+        zero_centered=True,
+        enable_hlfb=True,
+    )
+    attn_config = cfg.AttentionConfig(
+        num_heads=2,
+        head_dim=8,
+        num_query_groups=1,
+        rotary_base=100,
+        rotary_percentage=1.0,
+        qkv_transpose_before_split=True,
+        query_norm_config=norm_config,
+        key_norm_config=norm_config,
+        logit_softcap=None,
+        sliding_window_size=16,
+        attn_type=cfg.AttentionType.GLOBAL,
+    )
+    ff_config = cfg.FeedForwardConfig(
+        type=cfg.FeedForwardType.GATED,
+        activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+        intermediate_size=32,
+    )
+    block_config = cfg.TransformerBlockConfig(
+        attn_config=attn_config,
+        ff_config=ff_config,
+        post_attention_norm_config=norm_config,
+        parallel_residual=True,
+    )
+    model_config = cfg.ModelConfig(
+        vocab_size=100,
+        embedding_dim=16,
+        enable_hlfb=True,
+        num_layers=1,
+        max_seq_len=10,
+        block_configs=[block_config],
+    )
+    transformer_block = attention.TransformerBlock(
+        config=block_config,
+        model_config=model_config,
+    )
+    x = torch.randn(1, 10, 16)
+    attn_mask = torch.ones((1, 1, 10, 10), dtype=torch.float32)
+    out = transformer_block(x, rope=None, mask=attn_mask)
+    self.assertEqual(out.shape, (1, 10, 16))
+
+
+if __name__ == "__main__":
+  googletest.main()
ai_edge_torch/generative/layers/attention_utils_test.py ADDED
@@ -0,0 +1,64 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from ai_edge_torch.generative.layers import attention_utils
+import torch
+
+from absl.testing import absltest as googletest
+
+
+class AttentionUtilsTest(googletest.TestCase):
+
+  def test_build_causal_mask_cache(self):
+    mask = attention_utils.build_causal_mask_cache(3)
+    self.assertEqual(mask.shape, (1, 1, 3, 3))
+    self.assertEqual(mask[0, 0, 0, 0], 0)
+    self.assertEqual(mask[0, 0, 0, 1], float("-inf"))
+    self.assertEqual(mask[0, 0, 0, 2], float("-inf"))
+    self.assertEqual(mask[0, 0, 1, 0], 0)
+    self.assertEqual(mask[0, 0, 1, 1], 0)
+    self.assertEqual(mask[0, 0, 1, 2], float("-inf"))
+    self.assertEqual(mask[0, 0, 2, 0], 0)
+    self.assertEqual(mask[0, 0, 2, 1], 0)
+    self.assertEqual(mask[0, 0, 2, 2], 0)
+
+  def test_build_sliding_window_mask_cache(self):
+    mask = attention_utils.build_sliding_window_mask_cache(3, 2)
+    self.assertEqual(mask.shape, (1, 1, 3, 3))
+    self.assertEqual(mask[0, 0, 0, 0], 0)
+    self.assertEqual(mask[0, 0, 0, 1], float("-inf"))
+    self.assertEqual(mask[0, 0, 0, 2], float("-inf"))
+    self.assertEqual(mask[0, 0, 1, 0], 0)
+    self.assertEqual(mask[0, 0, 1, 1], 0)
+    self.assertEqual(mask[0, 0, 1, 2], float("-inf"))
+    self.assertEqual(mask[0, 0, 2, 0], float("-inf"))
+    self.assertEqual(mask[0, 0, 2, 1], 0)
+    self.assertEqual(mask[0, 0, 2, 2], 0)
+
+  def test_build_relative_position_buckets(self):
+    buckets = attention_utils.build_relative_position_buckets(
+        query_length=3, key_length=3, bidirectional=True, num_buckets=4
+    )
+    print(buckets)
+    self.assertEqual(buckets.shape, (1, 1, 3, 3))
+    self.assertTrue(
+        torch.equal(
+            buckets, torch.tensor([[[[0, 3, 3], [1, 0, 3], [1, 1, 0]]]])
+        )
+    )
+
+
+if __name__ == "__main__":
+  googletest.main()
ai_edge_torch/generative/layers/kv_cache.py CHANGED
@@ -18,7 +18,7 @@
 import dataclasses
 from typing import Any, List, Tuple
 
-
+import ai_edge_torch.generative.custom_ops.dynamic_update_slice as dus_utils
 from ai_edge_torch.generative.layers import model_config
 from ai_edge_torch.generative.utilities import types
 import torch
@@ -266,8 +266,78 @@ def _update_kv_impl(
   k_slice_indices = _get_slice_indices(input_pos)
   v_slice_indices = _get_slice_indices(input_pos)
 
-  k = dynamic_update_slice(cache.k_cache, k_slice, k_slice_indices)
-  v = dynamic_update_slice(cache.v_cache, v_slice, v_slice_indices)
+  k = dus_utils.dynamic_update_slice(cache.k_cache, k_slice, k_slice_indices)
+  v = dus_utils.dynamic_update_slice(cache.v_cache, v_slice, v_slice_indices)
 
   updated_cache = KVCacheEntry(k, v, cache.kv_layout)
   return updated_cache
+
+
+def update_transposed(
+    cache: KVCacheEntry,
+    input_pos: torch.Tensor,
+    k_slice: torch.Tensor,
+    v_slice: torch.Tensor,
+) -> KVCacheEntry:
+  """Out of place update of Cache buffer.
+
+  Args:
+    cache (KVCacheEntry): The original cache buffer.
+    input_pos (torch.Tensor): The update slice positions.
+    k_slice (torch.Tensor): The K slice to be updated in the new cache.
+    v_slice (torch.Tensor): The V slice to be updated in the new cache.
+
+  Returns:
+    KVCacheEntry: The updated KVCacheBase entry based on the passed
+      inputs.
+  """
+  assert (
+      cache.kv_layout == KV_LAYOUT_TRANSPOSED
+  ), "KV entry must have transposed layout."
+  return _update_kv_impl_transposed(cache, input_pos, k_slice, v_slice)
+
+
+def _get_slice_indices_transposed(
+    positions: torch.Tensor, cache_dim: int, ts_idx: int
+) -> torch.Tensor:
+  """Returns the slice indices."""
+  positions = positions.float()[0].reshape(
+      1,
+  )
+
+  zeros = torch.zeros((1,), dtype=torch.float32)
+  indices = []
+  for i in range(cache_dim):
+    if i == ts_idx:
+      indices.append(positions)
+    else:
+      indices.append(zeros)
+  slice_indices = torch.cat(indices, dim=0)
+  slice_indices = slice_indices.int()
+  return slice_indices
+
+
+def _update_kv_impl_transposed(
+    cache: KVCacheEntry,
+    input_pos: torch.Tensor,
+    k_slice: torch.Tensor,
+    v_slice: torch.Tensor,
+) -> KVCacheEntry:
+  """Updates the cache buffer with High Level Function Boundary annotation."""
+  cache_dim = 4
+  k_ts_idx = 2
+  v_ts_idx = 3
+  positions = input_pos.clone()
+  k_slice_indices = _get_slice_indices_transposed(
+      positions, cache_dim, k_ts_idx
+  )
+  v_slice_indices = _get_slice_indices_transposed(
+      positions, cache_dim, v_ts_idx
+  )
+  k = dus_utils.dynamic_update_slice(
+      cache.k_cache, k_slice, [x for x in k_slice_indices]
+  )
+  v = dus_utils.dynamic_update_slice(
+      cache.v_cache, v_slice, [x for x in v_slice_indices]
+  )
+  return KVCacheEntry(k, v, cache.kv_layout)
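To make the transposed update concrete: k_ts_idx = 2 and v_ts_idx = 3 above mean the sequence axis sits at dimension 2 of the K cache and dimension 3 of the V cache, so each decode step writes a one-position-wide slice at input_pos. A self-contained sketch with a plain-torch stand-in for the dynamic_update_slice custom op (the stand-in and the dimension names are illustrative, not the library op):

    import torch

    def dus_standin(cache, update, indices):
      # Write `update` into a clone of `cache` at the per-dimension
      # offsets given by `indices` (one offset per cache dimension).
      out = cache.clone()
      out[tuple(
          slice(int(i), int(i) + s) for i, s in zip(indices, update.shape)
      )] = update
      return out

    v_cache = torch.zeros(1, 4, 8, 16)  # assumed (batch, kv_heads, head_dim, seq)
    v_slice = torch.ones(1, 4, 8, 1)    # one decoded position
    input_pos = torch.tensor([5])
    # _get_slice_indices_transposed builds [0, 0, 0, pos] when ts_idx == 3.
    updated = dus_standin(v_cache, v_slice, [0, 0, 0, int(input_pos[0])])
    assert float(updated[0, 0, 0, 5]) == 1.0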
ai_edge_torch/generative/layers/sdpa_with_kv_update.py CHANGED
@@ -19,7 +19,6 @@ from typing import Tuple
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
-from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
 
@@ -68,7 +67,7 @@ def _sdpa_with_kv_update_transposed(
       1, -1, config.head_dim, seq_len
   )  # 1, bk, h, s
 
-  kv =
+  kv = kv_utils.update_transposed(kv, input_pos, key, value)
   key, value = kv.k_cache, kv.v_cache
 
   sdpa_out = sdpa.scaled_dot_product_attention_transposed(
ai_edge_torch/generative/utilities/export_config.py CHANGED
@@ -50,8 +50,7 @@ def _build_mask(mask_len, kv_cache_max_len) -> torch.Tensor:
   mask = torch.full(
       (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
   )
-
-  return mask
+  return torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
 
 
 def get_from_flags() -> ExportConfig:
@@ -62,6 +61,13 @@ def get_from_flags() -> ExportConfig:
   export_config.prefill_mask = _build_mask(
       flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
   )
+  # Note that the decode mask is not a correct causal mask, but it is okay
+  # for the conversion purpose because only the shape matters in conversion.
+  # A correct causal mask of decode for a given token position of decode, it
+  # should be built like:
+  #
+  #   torch.triu(mask, diagonal=decode_position).unsqueeze(0).unsqueeze(0)
+  #
   export_config.decode_mask = _build_mask(1, flags.FLAGS.kv_cache_max_len)
 
   if flags.FLAGS.transpose_kv_cache:
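A small worked example of what _build_mask now produces, and why the decode placeholder above is shape-correct but position-agnostic (sizes chosen for illustration):

    import torch

    mask_len, kv_cache_max_len = 3, 6
    base = torch.full((mask_len, kv_cache_max_len), float("-inf"))
    prefill_mask = torch.triu(base, diagonal=1).unsqueeze(0).unsqueeze(0)
    # Shape (1, 1, 3, 6): row i keeps columns 0..i at 0 and the rest at
    # -inf, so token i attends only to positions <= i.

    decode_mask = torch.triu(
        torch.full((1, kv_cache_max_len), float("-inf")), diagonal=1
    ).unsqueeze(0).unsqueeze(0)
    # Only column 0 is visible. That is fine for tracing shapes during
    # conversion; a real decode step at position p needs the diagonal
    # shifted by the position, as the verifier's runtime mask does.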
ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -85,14 +85,35 @@ class ModelWrapper(torch.nn.Module):
 class ReauthoredModelWrapper(ModelWrapper):
   """A wrapper for the model reauthored with ai_edge_torch Generative API."""
 
+  def __init__(
+      self,
+      model: torch.nn.Module,
+      mask_as_input: bool = False,
+      kv_layout: kv_utils.KVLayout = kv_utils.KV_LAYOUT_DEFAULT,
+  ):
+    """Wraps a reauthored model with some options."""
+    super().__init__(model)
+    self.mask_as_input = mask_as_input
+    self.kv_layout = kv_layout
+
   def _init_kv_cache(self):
     """Returns an initialized KV cache."""
-    return kv_utils.KVCache.from_model_config(
+    return kv_utils.KVCache.from_model_config(
+        self.model.config, kv_layout=self.kv_layout
+    )
 
   def _get_extra_args_for_forward(self) -> dict[str, Any]:
     """Returns extra arguments for the forward() method."""
     return {}
 
+  def _build_mask(self, input_pos: torch.Tensor) -> torch.Tensor:
+    """Builds a mask for the model."""
+    kv_cache_max_len = self.model.config.kv_cache_max_len
+    mask = torch.full(
+        (len(input_pos), kv_cache_max_len), float("-inf"), dtype=torch.float32
+    )
+    return torch.triu(mask, diagonal=input_pos[0] + 1).unsqueeze(0).unsqueeze(0)
+
   def _forward_with_kv_cache(
       self,
       tokens: torch.Tensor,
@@ -119,6 +140,8 @@ class ReauthoredModelWrapper(ModelWrapper):
       extra_args["export_config"] = self.export_config
     if pixel_values is not None:
       extra_args["pixel_values"] = pixel_values
+    if self.mask_as_input:
+      extra_args["mask"] = self._build_mask(input_pos)
     output = self.model.forward(tokens, input_pos, kv_cache, **extra_args)
     return output["logits"], output["kv_cache"]
 
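A quick check of the runtime mask the wrapper now feeds the model, mirroring _build_mask above; kv_cache_max_len = 8 is an assumed config value, and int() is added so the sketch runs standalone:

    import torch

    def build_mask(input_pos, kv_cache_max_len):
      mask = torch.full(
          (len(input_pos), kv_cache_max_len), float("-inf"), dtype=torch.float32
      )
      diagonal = int(input_pos[0]) + 1
      return torch.triu(mask, diagonal=diagonal).unsqueeze(0).unsqueeze(0)

    prefill = build_mask(torch.arange(0, 4), 8)  # shape (1, 1, 4, 8)
    # Row i exposes columns 0..i: a standard causal mask over the prompt.
    decode = build_mask(torch.tensor([5]), 8)    # shape (1, 1, 1, 8)
    # Columns 0..5 are 0 and columns 6..7 are -inf: the token at position 5
    # sees everything already written to the cache, including itself.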
ai_edge_torch/version.py CHANGED
{ai_edge_torch_nightly-0.5.0.dev20250427.dist-info → ai_edge_torch_nightly-0.5.0.dev20250429.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.5.0.dev20250427
+Version: 0.5.0.dev20250429
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.5.0.dev20250427.dist-info → ai_edge_torch_nightly-0.5.0.dev20250429.dist-info}/RECORD CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=I820JmIf90_QKTKyhmQGVjX9U-WMGUVEo9_N-Q_aQuk,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -57,20 +57,20 @@ ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=l0OrPGmX8
 ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=yhS_i2kR0GJWpWciCt4p9Z9nHYh6A5uJ8Ycy2ebFN9w,2909
 ai_edge_torch/generative/examples/deepseek/verify.py,sha256=iYldze-pvZGvPkkqr6zA7EmitPnH9sXkzjNVx353IcE,2403
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
+ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=RRilUl2Ui08R9gy1Ua0jnaXNCrIJJb-oztgP62G3mX4,1526
+ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=7IlF-4NEfZAzIfkOUHR-HeCSLSUGEu7wnO52UtERCa4,1527
+ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=6ImjTzJcq6JoKz2Z-z8pjv5BsRu5nUeEsTK3IPs3xgI,3521
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=JQLLiHNVBM9jOrZqUF0EmgAwtDD0yTRlmIbLaWM7qTg,11557
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
-ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=
-ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=
+ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=jhiyinOqPt5ZZjEadDRZt_wY5fiLSCpMo54PcxFaL_Q,1789
+ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=n7f2nF6Lin_tDvPs0JVldsuaBzo7pAwi5YAHAhlIxQg,6139
 ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
 ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=MjkQDVynaw9C5z9ODzKfb85xW5JfxHUWBJ_Aco05FHo,1760
 ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=eXWE5CSX0KeUMsPevgsYOfvyajl9F1RFF4DCWhHcYPA,15646
 ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=GACDBI_MsFowR8A3wAWrpzradPYe-AUgB9ZjXaVBG-s,6485
 ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
 ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
-ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=
+ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=KnE9ME3mrpQkAxFlBOJLsqcQkjsdDL1ClNhJahX5K5I,8960
 ai_edge_torch/generative/examples/hammer/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/hammer/convert_to_tflite.py,sha256=946mchDmvUhMsv1kzslp4LHtCIuHn4qjimHYQ-XnxMo,2962
 ai_edge_torch/generative/examples/hammer/hammer.py,sha256=76INcjffvaNCQ02fzXcxJUW_6EKHs4sg3q1nDBbEpHE,3431
@@ -154,18 +154,18 @@ ai_edge_torch/generative/fx_passes/__init__.py,sha256=PFSMsA1vfBfrV9ssBCkYJNl8Hx
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/attention.py,sha256=uK1ih2kxPZherwi-pGSm8B--NNWnQ8npEAfgcjMIkEY,12964
+ai_edge_torch/generative/layers/attention_test.py,sha256=9v8v96TLyFPdqxEylU1JOAeRFAp2s0YoDHZN83SFJJA,4764
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
+ai_edge_torch/generative/layers/attention_utils_test.py,sha256=22gQ1gcRPkwqFG3_p82GZfRKVE3udEssSy58wNOqv0w,2431
 ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
-ai_edge_torch/generative/layers/kv_cache.py,sha256=
+ai_edge_torch/generative/layers/kv_cache.py,sha256=b-7shzDaKexmvQF7P3SiAmIz4ZofjYWv3m5u71GojsA,10460
 ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
 ai_edge_torch/generative/layers/model_config.py,sha256=nLXvTkDAIHJQ0PTaWODF8oxJQoJ-K8D10cKR9229SAw,8355
 ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=efqqGRZPJ55hKn1MQJ-cXfrJD85uS1v7W_juyGyts58,5648
-ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=
-ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
-ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=zgpFVftOfllvjh9-UEBSvUbm152SnQETn29rUMMMvAM,2978
+ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=iw7D_46CFe9iRvU0UumbkIoqWQEhDroxm9ABcK-CLlM,3600
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZteHZXK6HKyxYji49DQ46sA9aIy7U3Jnz0HZp6hfevY,28996
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
@@ -187,8 +187,8 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=vQWmpzMkJ2hP
 ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/converter.py,sha256=
-ai_edge_torch/generative/utilities/export_config.py,sha256=
+ai_edge_torch/generative/utilities/converter.py,sha256=8A1MvU8SbJQkn2SIhF-73TXbI_i6nrloCdkpw83P2xQ,10953
+ai_edge_torch/generative/utilities/export_config.py,sha256=yGkfdN8Qrp8b_K8e5H0qaYmDrg0Dx_eb75JLhOnlygQ,2827
 ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
 ai_edge_torch/generative/utilities/model_builder.py,sha256=ZYX1TxpFdj573du2QCyHJlFjx4q1m12R74fp4Gwl92A,6343
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
@@ -196,7 +196,7 @@ ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWt
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
 ai_edge_torch/generative/utilities/types.py,sha256=gZI9hIPB3XAo4oecKIIoVDfiyibLaSNFhecPFx4VDTM,2913
-ai_edge_torch/generative/utilities/verifier.py,sha256=
+ai_edge_torch/generative/utilities/verifier.py,sha256=ETO2ShU5KXG7MLP8eVOWuzuRLCUtapafYHcZ6TZHIkw,13061
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=JsVmYrM_JEuN_smMHXUsRlo3Liapp7UyktbPpPARwDk,5386
 ai_edge_torch/hlfb/mark_pattern/fx_utils.py,sha256=YCtMgu-4w2BQ5fpnlpWC6IauKPf_tVqc7Ff91OTqlSw,1796
@@ -246,8 +246,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
 ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
+ai_edge_torch_nightly-0.5.0.dev20250429.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.5.0.dev20250429.dist-info/METADATA,sha256=05nMBPcVBVJcZhDI9SzsjryW3d4vpeeH_9H07RaA-PI,2051
+ai_edge_torch_nightly-0.5.0.dev20250429.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.5.0.dev20250429.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.5.0.dev20250429.dist-info/RECORD,,
ai_edge_torch/generative/layers/experimental/__init__.py DELETED
@@ -1,14 +0,0 @@
-# Copyright 2025 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
ai_edge_torch/generative/layers/experimental/kv_cache.py DELETED
@@ -1,90 +0,0 @@
-# Copyright 2025 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""Utility functions for KV Cache.
-
-This is an experimental implementation and is subject to change at any time.
-"""
-
-from ai_edge_torch.generative.custom_ops import dynamic_update_slice as dus_utils
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import torch
-
-
-def update(
-    cache: kv_utils.KVCacheEntry,
-    input_pos: torch.Tensor,
-    k_slice: torch.Tensor,
-    v_slice: torch.Tensor,
-) -> kv_utils.KVCacheEntry:
-  """Out of place update of Cache buffer.
-
-  Args:
-    cache (kv_utils.KVCacheEntry): The original cache buffer.
-    input_pos (torch.Tensor): The update slice positions.
-    k_slice (torch.Tensor): The K slice to be updated in the new cache.
-    v_slice (torch.Tensor): The V slice to be updated in the new cache.
-
-  Returns:
-    kv_utils.KVCacheEntry: The updated KVCacheBase entry based on the passed
-      inputs.
-  """
-  assert (
-      cache.kv_layout == kv_utils.KV_LAYOUT_TRANSPOSED
-  ), "KV entry must have transposed layout."
-  update_kv_cache = _update_kv_impl_transposed
-  return update_kv_cache(cache, input_pos, k_slice, v_slice)
-
-
-def _get_slice_indices(
-    positions: torch.Tensor, cache_dim: int, ts_idx: int
-) -> torch.Tensor:
-  """Returns the slice indices."""
-  positions = positions.float()[0].reshape(
-      1,
-  )
-
-  zeros = torch.zeros((1,), dtype=torch.float32)
-  indices = []
-  for i in range(cache_dim):
-    if i == ts_idx:
-      indices.append(positions)
-    else:
-      indices.append(zeros)
-  slice_indices = torch.cat(indices, dim=0)
-  slice_indices = slice_indices.int()
-  return slice_indices
-
-
-def _update_kv_impl_transposed(
-    cache: kv_utils.KVCacheEntry,
-    input_pos: torch.Tensor,
-    k_slice: torch.Tensor,
-    v_slice: torch.Tensor,
-) -> kv_utils.KVCacheEntry:
-  """Update the cache buffer with High Level Function Boundary annotation."""
-  cache_dim = 4
-  k_ts_idx = 2
-  v_ts_idx = 3
-  positions = input_pos.clone()
-  k_slice_indices = _get_slice_indices(positions, cache_dim, k_ts_idx)
-  v_slice_indices = _get_slice_indices(positions, cache_dim, v_ts_idx)
-  k = dus_utils.dynamic_update_slice(
-      cache.k_cache, k_slice, [x for x in k_slice_indices]
-  )
-  v = dus_utils.dynamic_update_slice(
-      cache.v_cache, v_slice, [x for x in v_slice_indices]
-  )
-  return kv_utils.KVCacheEntry(k, v, cache.kv_layout)