ai-edge-torch-nightly 0.5.0.dev20250427__py3-none-any.whl → 0.5.0.dev20250429__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (22)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +1 -2
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +1 -2
  3. ai_edge_torch/generative/examples/gemma/gemma1.py +1 -0
  4. ai_edge_torch/generative/examples/gemma/gemma2.py +2 -1
  5. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +18 -2
  6. ai_edge_torch/generative/examples/gemma/verify_util.py +15 -2
  7. ai_edge_torch/generative/examples/gemma3/verify_util.py +5 -3
  8. ai_edge_torch/generative/layers/attention_test.py +153 -0
  9. ai_edge_torch/generative/layers/attention_utils_test.py +64 -0
  10. ai_edge_torch/generative/layers/kv_cache.py +73 -3
  11. ai_edge_torch/generative/layers/sdpa_with_kv_update.py +1 -2
  12. ai_edge_torch/generative/utilities/converter.py +1 -1
  13. ai_edge_torch/generative/utilities/export_config.py +8 -2
  14. ai_edge_torch/generative/utilities/verifier.py +24 -1
  15. ai_edge_torch/version.py +1 -1
  16. {ai_edge_torch_nightly-0.5.0.dev20250427.dist-info → ai_edge_torch_nightly-0.5.0.dev20250429.dist-info}/METADATA +1 -1
  17. {ai_edge_torch_nightly-0.5.0.dev20250427.dist-info → ai_edge_torch_nightly-0.5.0.dev20250429.dist-info}/RECORD +20 -20
  18. ai_edge_torch/generative/layers/experimental/__init__.py +0 -14
  19. ai_edge_torch/generative/layers/experimental/kv_cache.py +0 -90
  20. {ai_edge_torch_nightly-0.5.0.dev20250427.dist-info → ai_edge_torch_nightly-0.5.0.dev20250429.dist-info}/LICENSE +0 -0
  21. {ai_edge_torch_nightly-0.5.0.dev20250427.dist-info → ai_edge_torch_nightly-0.5.0.dev20250429.dist-info}/WHEEL +0 -0
  22. {ai_edge_torch_nightly-0.5.0.dev20250427.dist-info → ai_edge_torch_nightly-0.5.0.dev20250429.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py CHANGED
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags("gemma-2b")
-ExportConfig = export_config.ExportConfig
 
 
 def main(_):
@@ -35,7 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=ExportConfig(),
+      export_config=export_config.get_from_flags(),
   )
 
 
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py CHANGED
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags("gemma2-2b")
-ExportConfig = export_config.ExportConfig
 
 
 def main(_):
@@ -35,7 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=ExportConfig(),
+      export_config=export_config.get_from_flags(),
   )
 
 
ai_edge_torch/generative/examples/gemma/gemma1.py CHANGED
@@ -65,6 +65,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-6,
       zero_centered=True,
+      enable_hlfb=True,
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -236,6 +236,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-6,
       zero_centered=True,
+      enable_hlfb=True,
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
@@ -314,5 +315,5 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
           tensor_names=tensor_names,
           model_class=Gemma2,
       )
-    except KeyError as ke:
+    except KeyError as _:
      continue
ai_edge_torch/generative/examples/gemma/verify_gemma2.py CHANGED
@@ -18,6 +18,7 @@
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.gemma import verify_util
+import ai_edge_torch.generative.layers.kv_cache as kv_utils
 import kagglehub
 
 
@@ -31,12 +32,27 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
     30,
     "The maximum size of the generated tokens.",
 )
-
+_MASK_AS_INPUT = flags.DEFINE_bool(
+    "mask_as_input",
+    True,
+    "Pass the causal self attention mask to the model.",
+)
+_TRANSPOSE_KV_CACHE = flags.DEFINE_bool(
+    "transpose_kv_cache",
+    True,
+    "Transpose the KV cache to reduce memory usage.",
+)
 
 def main(_):
   checkpoint = kagglehub.model_download("google/gemma-2/pyTorch/gemma-2-2b-it")
 
-  verify_util.verify_gemma2(checkpoint, _PROMPTS.value, _MAX_NEW_TOKENS.value)
+  verify_util.verify_gemma2(
+      checkpoint,
+      _PROMPTS.value,
+      _MAX_NEW_TOKENS.value,
+      _MASK_AS_INPUT.value,
+      kv_utils.KV_LAYOUT_TRANSPOSED if _TRANSPOSE_KV_CACHE.value else kv_utils.KV_LAYOUT_DEFAULT,
+  )
 
 
 if __name__ == "__main__":
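
The verification above can also be driven programmatically instead of through flags. A minimal sketch, assuming a locally downloaded Gemma 2 checkpoint (the path is a placeholder), mirroring what main() passes when --transpose_kv_cache is false:

import ai_edge_torch.generative.layers.kv_cache as kv_utils
from ai_edge_torch.generative.examples.gemma import verify_util

verify_util.verify_gemma2(
    "/path/to/gemma-2-2b-it",          # placeholder checkpoint directory
    ["What is the meaning of life?"],  # prompts
    30,                                # max_new_tokens
    True,                              # mask_as_input
    kv_utils.KV_LAYOUT_DEFAULT,        # keep the default (non-transposed) layout
)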
ai_edge_torch/generative/examples/gemma/verify_util.py CHANGED
@@ -21,6 +21,7 @@ from typing import List, Tuple
 
 from ai_edge_torch.generative.examples.gemma import gemma2
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.kv_cache as kv_utils
 from ai_edge_torch.generative.utilities import verifier
 from gemma import config as gemma_config
 from gemma import model as gemma_model
@@ -108,6 +109,8 @@ def verify_reauthored_gemma_model(
     weight_filename: str = "model.ckpt",
     tokenizer_filename: str = "tokenizer.model",
     max_new_tokens: int = 20,
+    mask_as_input: bool = False,
+    kv_layout: kv_utils.KVLayout = kv_utils.KV_LAYOUT_DEFAULT,
     rtol: float = 1e-05,
     atol: float = 1e-05,
 ) -> bool:
@@ -126,7 +129,11 @@
 
   return verifier.verify_reauthored_model(
       original_model=GemmaWrapper(original_model),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      reauthored_model=verifier.ReauthoredModelWrapper(
+          reauthored_model,
+          mask_as_input=mask_as_input,
+          kv_layout=kv_layout,
+      ),
       tokenizer=GemmaTokenizerWrapper(original_model.tokenizer),
       generate_prompts=generate_prompts,
       max_new_tokens=max_new_tokens,
@@ -137,7 +144,11 @@
 
 
 def verify_gemma2(
-    gemma2_model_path: str, prompts: List[str], max_new_tokens: int
+    gemma2_model_path: str,
+    prompts: List[str],
+    max_new_tokens: int,
+    mask_as_input: bool = False,
+    kv_layout: kv_utils.KVLayout = kv_utils.KV_LAYOUT_DEFAULT,
 ) -> bool:
   """Verifies the reauthored Gemma2 model.
 
@@ -153,5 +164,7 @@
       generate_prompts=prompts,
       forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
       max_new_tokens=max_new_tokens,
+      mask_as_input=mask_as_input,
+      kv_layout=kv_layout,
       atol=1e-04,
   )
ai_edge_torch/generative/examples/gemma3/verify_util.py CHANGED
@@ -22,7 +22,7 @@ from typing import List, Optional, Tuple
 from ai_edge_torch.generative.examples.gemma3 import gemma3
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-from ai_edge_torch.generative.utilities.experimental import verifier
+from ai_edge_torch.generative.utilities import verifier
 from gemma import config as gemma_config
 from gemma import model as gemma_model
 import torch
@@ -92,10 +92,12 @@ class GemmaWrapper(verifier.ModelWrapper):
 class UnifiedGemma3Wrapper(verifier.ReauthoredModelWrapper):
   """Unified Gemma3 model wrapper for verification."""
 
+  def __init__(self, model: torch.nn.Module):
+    super().__init__(model, kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED)
+
   def _init_kv_cache(self):
-    """Returns an initialized KV cache."""
     return kv_utils.KVCache.from_model_config(
-        self.model.model.config, kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED
+        self.model.model.config, kv_layout=self.kv_layout
     )
 
   def forward(
ai_edge_torch/generative/layers/attention_test.py ADDED
@@ -0,0 +1,153 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import model_config as cfg
+import torch
+
+from absl.testing import absltest as googletest
+from absl.testing import parameterized
+
+
+class AttentionTest(parameterized.TestCase):
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name="local_causal_self_attention",
+          attn_type=cfg.AttentionType.LOCAL_SLIDING,
+          expected_shape=(1, 10, 16),
+      ),
+      dict(
+          testcase_name="global_causal_self_attention",
+          attn_type=cfg.AttentionType.GLOBAL,
+          expected_shape=(1, 10, 16),
+      ),
+  )
+  def test_causal_self_attention(
+      self, attn_type: cfg.AttentionType, expected_shape: tuple[int, ...]
+  ):
+    norm_config = cfg.NormalizationConfig(
+        type=cfg.NormalizationType.RMS_NORM,
+        epsilon=1e-6,
+        zero_centered=True,
+        enable_hlfb=True,
+    )
+    attn_config = cfg.AttentionConfig(
+        num_heads=2,
+        head_dim=8,
+        num_query_groups=1,
+        rotary_base=100,
+        rotary_percentage=1.0,
+        qkv_transpose_before_split=True,
+        query_norm_config=norm_config,
+        key_norm_config=norm_config,
+        logit_softcap=None,
+        sliding_window_size=16,
+        attn_type=attn_type,
+    )
+    self_atten = attention.CausalSelfAttention(
+        dim=16,
+        config=attn_config,
+        enable_hlfb=True,
+    )
+    x = torch.randn(1, 10, 16)
+    attn_mask = torch.ones((1, 1, 10, 10), dtype=torch.float32)
+    out = self_atten(x, rope=None, mask=attn_mask)
+    self.assertEqual(out.shape, expected_shape)
+
+  def test_cross_attention(self):
+    norm_config = cfg.NormalizationConfig(
+        type=cfg.NormalizationType.RMS_NORM,
+        epsilon=1e-6,
+        zero_centered=True,
+        enable_hlfb=True,
+    )
+    attn_config = cfg.AttentionConfig(
+        num_heads=2,
+        head_dim=8,
+        num_query_groups=1,
+        rotary_base=100,
+        rotary_percentage=1.0,
+        qkv_transpose_before_split=True,
+        query_norm_config=norm_config,
+        key_norm_config=norm_config,
+        logit_softcap=None,
+        sliding_window_size=16,
+        attn_type=cfg.AttentionType.GLOBAL,
+    )
+    cross_atten = attention.CrossAttention(
+        query_dim=16,
+        cross_dim=16,
+        hidden_dim=16,
+        output_dim=16,
+        config=attn_config,
+        enable_hlfb=True,
+    )
+    x = torch.randn(1, 10, 16)
+    y = torch.randn(1, 10, 16)
+    out = cross_atten(x, y, rope=None)
+    self.assertEqual(out.shape, (1, 10, 16))
+
+  def test_transformer_block(self):
+    norm_config = cfg.NormalizationConfig(
+        type=cfg.NormalizationType.RMS_NORM,
+        epsilon=1e-6,
+        zero_centered=True,
+        enable_hlfb=True,
+    )
+    attn_config = cfg.AttentionConfig(
+        num_heads=2,
+        head_dim=8,
+        num_query_groups=1,
+        rotary_base=100,
+        rotary_percentage=1.0,
+        qkv_transpose_before_split=True,
+        query_norm_config=norm_config,
+        key_norm_config=norm_config,
+        logit_softcap=None,
+        sliding_window_size=16,
+        attn_type=cfg.AttentionType.GLOBAL,
+    )
+    ff_config = cfg.FeedForwardConfig(
+        type=cfg.FeedForwardType.GATED,
+        activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+        intermediate_size=32,
+    )
+    block_config = cfg.TransformerBlockConfig(
+        attn_config=attn_config,
+        ff_config=ff_config,
+        post_attention_norm_config=norm_config,
+        parallel_residual=True,
+    )
+    model_config = cfg.ModelConfig(
+        vocab_size=100,
+        embedding_dim=16,
+        enable_hlfb=True,
+        num_layers=1,
+        max_seq_len=10,
+        block_configs=[block_config],
+    )
+    transformer_block = attention.TransformerBlock(
+        config=block_config,
+        model_config=model_config,
+    )
+    x = torch.randn(1, 10, 16)
+    attn_mask = torch.ones((1, 1, 10, 10), dtype=torch.float32)
+    out = transformer_block(x, rope=None, mask=attn_mask)
+    self.assertEqual(out.shape, (1, 10, 16))
+
+
+if __name__ == "__main__":
+  googletest.main()
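
A note on the shapes these tests pin down: with num_heads=2 and head_dim=8 the attention width equals the embedding width of 16, so outputs keep the input (batch, seq, dim) shape, and num_query_groups=1 means both query heads share one K/V head (multi-query attention). A small sanity-check sketch, plain arithmetic rather than package code:

# Shape bookkeeping behind the (1, 10, 16) expected_shape above.
num_heads, head_dim, num_query_groups = 2, 8, 1
q_width = num_heads * head_dim          # 16, matches dim=16 / embedding_dim=16
kv_width = num_query_groups * head_dim  # 8: a single shared K/V head
batch, seq = 1, 10
assert (batch, seq, q_width) == (1, 10, 16)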
ai_edge_torch/generative/layers/attention_utils_test.py ADDED
@@ -0,0 +1,64 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from ai_edge_torch.generative.layers import attention_utils
+import torch
+
+from absl.testing import absltest as googletest
+
+
+class AttentionUtilsTest(googletest.TestCase):
+
+  def test_build_causal_mask_cache(self):
+    mask = attention_utils.build_causal_mask_cache(3)
+    self.assertEqual(mask.shape, (1, 1, 3, 3))
+    self.assertEqual(mask[0, 0, 0, 0], 0)
+    self.assertEqual(mask[0, 0, 0, 1], float("-inf"))
+    self.assertEqual(mask[0, 0, 0, 2], float("-inf"))
+    self.assertEqual(mask[0, 0, 1, 0], 0)
+    self.assertEqual(mask[0, 0, 1, 1], 0)
+    self.assertEqual(mask[0, 0, 1, 2], float("-inf"))
+    self.assertEqual(mask[0, 0, 2, 0], 0)
+    self.assertEqual(mask[0, 0, 2, 1], 0)
+    self.assertEqual(mask[0, 0, 2, 2], 0)
+
+  def test_build_sliding_window_mask_cache(self):
+    mask = attention_utils.build_sliding_window_mask_cache(3, 2)
+    self.assertEqual(mask.shape, (1, 1, 3, 3))
+    self.assertEqual(mask[0, 0, 0, 0], 0)
+    self.assertEqual(mask[0, 0, 0, 1], float("-inf"))
+    self.assertEqual(mask[0, 0, 0, 2], float("-inf"))
+    self.assertEqual(mask[0, 0, 1, 0], 0)
+    self.assertEqual(mask[0, 0, 1, 1], 0)
+    self.assertEqual(mask[0, 0, 1, 2], float("-inf"))
+    self.assertEqual(mask[0, 0, 2, 0], float("-inf"))
+    self.assertEqual(mask[0, 0, 2, 1], 0)
+    self.assertEqual(mask[0, 0, 2, 2], 0)
+
+  def test_build_relative_position_buckets(self):
+    buckets = attention_utils.build_relative_position_buckets(
+        query_length=3, key_length=3, bidirectional=True, num_buckets=4
+    )
+    print(buckets)
+    self.assertEqual(buckets.shape, (1, 1, 3, 3))
+    self.assertTrue(
+        torch.equal(
+            buckets, torch.tensor([[[[0, 3, 3], [1, 0, 3], [1, 1, 0]]]])
+        )
+    )
+
+
+if __name__ == "__main__":
+  googletest.main()
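
The two mask caches exercised above differ only in how far back a position may attend: the causal mask lets position i see every j <= i, while a sliding window of size 2 restricts it to j in (i - 2, i], which is why the sliding-window test expects -inf at [2, 0]. A quick way to eyeball both (these print the same values the assertions check):

from ai_edge_torch.generative.layers import attention_utils

print(attention_utils.build_causal_mask_cache(3)[0, 0])
print(attention_utils.build_sliding_window_mask_cache(3, 2)[0, 0])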
ai_edge_torch/generative/layers/kv_cache.py CHANGED
@@ -18,7 +18,7 @@
 import dataclasses
 from typing import Any, List, Tuple
 
-from ai_edge_torch.generative.custom_ops.dynamic_update_slice import dynamic_update_slice
+import ai_edge_torch.generative.custom_ops.dynamic_update_slice as dus_utils
 from ai_edge_torch.generative.layers import model_config
 from ai_edge_torch.generative.utilities import types
 import torch
@@ -266,8 +266,78 @@ def _update_kv_impl(
   k_slice_indices = _get_slice_indices(input_pos)
   v_slice_indices = _get_slice_indices(input_pos)
 
-  k = dynamic_update_slice(cache.k_cache, k_slice, k_slice_indices)
-  v = dynamic_update_slice(cache.v_cache, v_slice, v_slice_indices)
+  k = dus_utils.dynamic_update_slice(cache.k_cache, k_slice, k_slice_indices)
+  v = dus_utils.dynamic_update_slice(cache.v_cache, v_slice, v_slice_indices)
 
   updated_cache = KVCacheEntry(k, v, cache.kv_layout)
   return updated_cache
+
+
+def update_transposed(
+    cache: KVCacheEntry,
+    input_pos: torch.Tensor,
+    k_slice: torch.Tensor,
+    v_slice: torch.Tensor,
+) -> KVCacheEntry:
+  """Out of place update of Cache buffer.
+
+  Args:
+    cache (KVCacheEntry): The original cache buffer.
+    input_pos (torch.Tensor): The update slice positions.
+    k_slice (torch.Tensor): The K slice to be updated in the new cache.
+    v_slice (torch.Tensor): The V slice to be updated in the new cache.
+
+  Returns:
+    KVCacheEntry: The updated KVCacheBase entry based on the passed
+      inputs.
+  """
+  assert (
+      cache.kv_layout == KV_LAYOUT_TRANSPOSED
+  ), "KV entry must have transposed layout."
+  return _update_kv_impl_transposed(cache, input_pos, k_slice, v_slice)
+
+
+def _get_slice_indices_transposed(
+    positions: torch.Tensor, cache_dim: int, ts_idx: int
+) -> torch.Tensor:
+  """Returns the slice indices."""
+  positions = positions.float()[0].reshape(
+      1,
+  )
+
+  zeros = torch.zeros((1,), dtype=torch.float32)
+  indices = []
+  for i in range(cache_dim):
+    if i == ts_idx:
+      indices.append(positions)
+    else:
+      indices.append(zeros)
+  slice_indices = torch.cat(indices, dim=0)
+  slice_indices = slice_indices.int()
+  return slice_indices
+
+
+def _update_kv_impl_transposed(
+    cache: KVCacheEntry,
+    input_pos: torch.Tensor,
+    k_slice: torch.Tensor,
+    v_slice: torch.Tensor,
+) -> KVCacheEntry:
+  """Updates the cache buffer with High Level Function Boundary annotation."""
+  cache_dim = 4
+  k_ts_idx = 2
+  v_ts_idx = 3
+  positions = input_pos.clone()
+  k_slice_indices = _get_slice_indices_transposed(
+      positions, cache_dim, k_ts_idx
+  )
+  v_slice_indices = _get_slice_indices_transposed(
+      positions, cache_dim, v_ts_idx
+  )
+  k = dus_utils.dynamic_update_slice(
+      cache.k_cache, k_slice, [x for x in k_slice_indices]
+  )
+  v = dus_utils.dynamic_update_slice(
+      cache.v_cache, v_slice, [x for x in v_slice_indices]
+  )
+  return KVCacheEntry(k, v, cache.kv_layout)
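
To see what _get_slice_indices_transposed produces: for the 4-D transposed cache it builds a start-index vector that is zero on every axis except the time-step axis, where it holds the write position. A standalone sketch of the same construction (logic inlined so it runs without the package):

import torch

def slice_indices(positions: torch.Tensor, cache_dim: int, ts_idx: int) -> torch.Tensor:
  # Mirrors _get_slice_indices_transposed above: zeros everywhere except ts_idx.
  pos = positions.float()[0].reshape(1)
  zeros = torch.zeros((1,), dtype=torch.float32)
  return torch.cat([pos if i == ts_idx else zeros for i in range(cache_dim)]).int()

print(slice_indices(torch.tensor([7]), cache_dim=4, ts_idx=2))  # K: tensor([0, 0, 7, 0])
print(slice_indices(torch.tensor([7]), cache_dim=4, ts_idx=3))  # V: tensor([0, 0, 0, 7])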
ai_edge_torch/generative/layers/sdpa_with_kv_update.py CHANGED
@@ -19,7 +19,6 @@ from typing import Tuple
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
-from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
 
@@ -68,7 +67,7 @@ def _sdpa_with_kv_update_transposed(
       1, -1, config.head_dim, seq_len
   )  # 1, bk, h, s
 
-  kv = kv_utils_experimental.update(kv, input_pos, key, value)
+  kv = kv_utils.update_transposed(kv, input_pos, key, value)
   key, value = kv.k_cache, kv.v_cache
 
   sdpa_out = sdpa.scaled_dot_product_attention_transposed(
ai_edge_torch/generative/utilities/converter.py CHANGED
@@ -57,7 +57,7 @@ def define_conversion_flags(model_name: str):
   )
   flags.DEFINE_string(
       'output_name_prefix',
-      model_name,
+      f'{model_name}',
       'The prefix of the output tflite model name.',
   )
   flags.DEFINE_multi_integer(
ai_edge_torch/generative/utilities/export_config.py CHANGED
@@ -50,8 +50,7 @@ def _build_mask(mask_len, kv_cache_max_len) -> torch.Tensor:
   mask = torch.full(
       (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
   )
-  mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
-  return mask
+  return torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
 
 
 def get_from_flags() -> ExportConfig:
@@ -62,6 +61,13 @@ def get_from_flags() -> ExportConfig:
     export_config.prefill_mask = _build_mask(
         flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
     )
+    # Note that the decode mask is not a correct causal mask, but it is okay
+    # for the conversion purpose because only the shape matters in conversion.
+    # A correct causal mask of decode for a given token position of decode, it
+    # should be built like:
+    #
+    #   torch.triu(mask, diagonal=decode_position).unsqueeze(0).unsqueeze(0)
+    #
     export_config.decode_mask = _build_mask(1, flags.FLAGS.kv_cache_max_len)
 
   if flags.FLAGS.transpose_kv_cache:
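
Since only the tensor shape matters at conversion time, _build_mask always uses diagonal=1. A worked example of its output, with the function inlined so the snippet runs standalone:

import torch

def build_mask(mask_len: int, kv_cache_max_len: int) -> torch.Tensor:
  # Same construction as _build_mask above: -inf strictly above the diagonal.
  mask = torch.full((mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32)
  return torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)

print(build_mask(3, 4)[0, 0])
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf]])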
ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -85,14 +85,35 @@ class ModelWrapper(torch.nn.Module):
 class ReauthoredModelWrapper(ModelWrapper):
   """A wrapper for the model reauthored with ai_edge_torch Generative API."""
 
+  def __init__(
+      self,
+      model: torch.nn.Module,
+      mask_as_input: bool = False,
+      kv_layout: kv_utils.KVLayout = kv_utils.KV_LAYOUT_DEFAULT,
+  ):
+    """Wraps a reauthored model with some options."""
+    super().__init__(model)
+    self.mask_as_input = mask_as_input
+    self.kv_layout = kv_layout
+
   def _init_kv_cache(self):
     """Returns an initialized KV cache."""
-    return kv_utils.KVCache.from_model_config(self.model.config)
+    return kv_utils.KVCache.from_model_config(
+        self.model.config, kv_layout=self.kv_layout
+    )
 
   def _get_extra_args_for_forward(self) -> dict[str, Any]:
     """Returns extra arguments for the forward() method."""
     return {}
 
+  def _build_mask(self, input_pos: torch.Tensor) -> torch.Tensor:
+    """Builds a mask for the model."""
+    kv_cache_max_len = self.model.config.kv_cache_max_len
+    mask = torch.full(
+        (len(input_pos), kv_cache_max_len), float("-inf"), dtype=torch.float32
+    )
+    return torch.triu(mask, diagonal=input_pos[0] + 1).unsqueeze(0).unsqueeze(0)
+
   def _forward_with_kv_cache(
       self,
       tokens: torch.Tensor,
@@ -119,6 +140,8 @@ class ReauthoredModelWrapper(ModelWrapper):
       extra_args["export_config"] = self.export_config
     if pixel_values is not None:
       extra_args["pixel_values"] = pixel_values
+    if self.mask_as_input:
+      extra_args["mask"] = self._build_mask(input_pos)
     output = self.model.forward(tokens, input_pos, kv_cache, **extra_args)
     return output["logits"], output["kv_cache"]
 
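Unlike the conversion-time mask in export_config.py, the verifier's _build_mask shifts the diagonal by the current decode position, so the token at position p attends to cache slots 0 through p. A standalone sketch (kv_cache_max_len of 6 is an arbitrary choice, and int() keeps the diagonal a plain Python int in eager mode):

import torch

input_pos = torch.tensor([3])
mask = torch.full((len(input_pos), 6), float("-inf"), dtype=torch.float32)
mask = torch.triu(mask, diagonal=int(input_pos[0]) + 1).unsqueeze(0).unsqueeze(0)
print(mask[0, 0])  # tensor([[0., 0., 0., 0., -inf, -inf]])
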
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.5.0.dev20250427"
+__version__ = "0.5.0.dev20250429"
{ai_edge_torch_nightly-0.5.0.dev20250427.dist-info → ai_edge_torch_nightly-0.5.0.dev20250429.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.5.0.dev20250427
+Version: 0.5.0.dev20250429
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.5.0.dev20250427.dist-info → ai_edge_torch_nightly-0.5.0.dev20250429.dist-info}/RECORD RENAMED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
-ai_edge_torch/version.py,sha256=RhNMNIs4sG78K3SOLk6zxuILeS_S2vhG7FJJOrV4cLM,706
+ai_edge_torch/version.py,sha256=I820JmIf90_QKTKyhmQGVjX9U-WMGUVEo9_N-Q_aQuk,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -57,20 +57,20 @@ ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=l0OrPGmX8
 ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=yhS_i2kR0GJWpWciCt4p9Z9nHYh6A5uJ8Ycy2ebFN9w,2909
 ai_edge_torch/generative/examples/deepseek/verify.py,sha256=iYldze-pvZGvPkkqr6zA7EmitPnH9sXkzjNVx353IcE,2403
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=tSEtGeS-Ndcc_cTm7c4CT4FqRiwrHedEv1oJk4Y_zYU,1552
-ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=k2BUkf2cciItc3gFAyzWqcWZhlVFrD3TVikTmLXq04c,1553
-ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=w8oWYibZzvEvCDyp39EYyAWmjgJljhzdYPyFCfAWxZA,3497
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=SyyRae8vWLn0WxduxtahzVRbdSq4T2k5-7t8PfCR_k8,11534
+ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=RRilUl2Ui08R9gy1Ua0jnaXNCrIJJb-oztgP62G3mX4,1526
+ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=7IlF-4NEfZAzIfkOUHR-HeCSLSUGEu7wnO52UtERCa4,1527
+ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=6ImjTzJcq6JoKz2Z-z8pjv5BsRu5nUeEsTK3IPs3xgI,3521
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=JQLLiHNVBM9jOrZqUF0EmgAwtDD0yTRlmIbLaWM7qTg,11557
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
-ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
-ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
+ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=jhiyinOqPt5ZZjEadDRZt_wY5fiLSCpMo54PcxFaL_Q,1789
+ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=n7f2nF6Lin_tDvPs0JVldsuaBzo7pAwi5YAHAhlIxQg,6139
 ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
 ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=MjkQDVynaw9C5z9ODzKfb85xW5JfxHUWBJ_Aco05FHo,1760
 ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=eXWE5CSX0KeUMsPevgsYOfvyajl9F1RFF4DCWhHcYPA,15646
 ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=GACDBI_MsFowR8A3wAWrpzradPYe-AUgB9ZjXaVBG-s,6485
 ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
 ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
-ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=nEv0qQ0l6gSXKxP5mNwkd2lRGxpFfD4e7FNV3V76zhw,8915
+ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=KnE9ME3mrpQkAxFlBOJLsqcQkjsdDL1ClNhJahX5K5I,8960
 ai_edge_torch/generative/examples/hammer/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/hammer/convert_to_tflite.py,sha256=946mchDmvUhMsv1kzslp4LHtCIuHn4qjimHYQ-XnxMo,2962
 ai_edge_torch/generative/examples/hammer/hammer.py,sha256=76INcjffvaNCQ02fzXcxJUW_6EKHs4sg3q1nDBbEpHE,3431
@@ -154,18 +154,18 @@ ai_edge_torch/generative/fx_passes/__init__.py,sha256=PFSMsA1vfBfrV9ssBCkYJNl8Hx
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/attention.py,sha256=uK1ih2kxPZherwi-pGSm8B--NNWnQ8npEAfgcjMIkEY,12964
+ai_edge_torch/generative/layers/attention_test.py,sha256=9v8v96TLyFPdqxEylU1JOAeRFAp2s0YoDHZN83SFJJA,4764
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
+ai_edge_torch/generative/layers/attention_utils_test.py,sha256=22gQ1gcRPkwqFG3_p82GZfRKVE3udEssSy58wNOqv0w,2431
 ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
-ai_edge_torch/generative/layers/kv_cache.py,sha256=dDeirtuo9AnlN1tYoLbFi_pKhIDmn35FQY1m6X28hSY,8468
+ai_edge_torch/generative/layers/kv_cache.py,sha256=b-7shzDaKexmvQF7P3SiAmIz4ZofjYWv3m5u71GojsA,10460
 ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
 ai_edge_torch/generative/layers/model_config.py,sha256=nLXvTkDAIHJQ0PTaWODF8oxJQoJ-K8D10cKR9229SAw,8355
 ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=efqqGRZPJ55hKn1MQJ-cXfrJD85uS1v7W_juyGyts58,5648
-ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=Hn8Zw-jiB9GH2uZ-yaRMcDdpmjECcW4uCy-YNH9zV8c,3693
-ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
-ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=zgpFVftOfllvjh9-UEBSvUbm152SnQETn29rUMMMvAM,2978
+ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=iw7D_46CFe9iRvU0UumbkIoqWQEhDroxm9ABcK-CLlM,3600
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZteHZXK6HKyxYji49DQ46sA9aIy7U3Jnz0HZp6hfevY,28996
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
@@ -187,8 +187,8 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=vQWmpzMkJ2hP
 ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/converter.py,sha256=4RNNl7vk3WN_JG5EZajofiRSqtPnUNCYosxTacdEOto,10948
-ai_edge_torch/generative/utilities/export_config.py,sha256=maUVt0T5FsLpHO5H-BZ-O0FRBZO_ejKwGhPR9Qq8ViM,2490
+ai_edge_torch/generative/utilities/converter.py,sha256=8A1MvU8SbJQkn2SIhF-73TXbI_i6nrloCdkpw83P2xQ,10953
+ai_edge_torch/generative/utilities/export_config.py,sha256=yGkfdN8Qrp8b_K8e5H0qaYmDrg0Dx_eb75JLhOnlygQ,2827
 ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
 ai_edge_torch/generative/utilities/model_builder.py,sha256=ZYX1TxpFdj573du2QCyHJlFjx4q1m12R74fp4Gwl92A,6343
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
@@ -196,7 +196,7 @@ ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWt
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
 ai_edge_torch/generative/utilities/types.py,sha256=gZI9hIPB3XAo4oecKIIoVDfiyibLaSNFhecPFx4VDTM,2913
-ai_edge_torch/generative/utilities/verifier.py,sha256=RSMQ8eda63VHM-5KmquKfogmTPyhGvGnqkoz9i4bppY,12270
+ai_edge_torch/generative/utilities/verifier.py,sha256=ETO2ShU5KXG7MLP8eVOWuzuRLCUtapafYHcZ6TZHIkw,13061
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=JsVmYrM_JEuN_smMHXUsRlo3Liapp7UyktbPpPARwDk,5386
 ai_edge_torch/hlfb/mark_pattern/fx_utils.py,sha256=YCtMgu-4w2BQ5fpnlpWC6IauKPf_tVqc7Ff91OTqlSw,1796
@@ -246,8 +246,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
 ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.5.0.dev20250427.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.5.0.dev20250427.dist-info/METADATA,sha256=g53PvQrw8WP7McVXcoMYSEF9lmh7VWexPnfQLGOTVJg,2051
-ai_edge_torch_nightly-0.5.0.dev20250427.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.5.0.dev20250427.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.5.0.dev20250427.dist-info/RECORD,,
+ai_edge_torch_nightly-0.5.0.dev20250429.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.5.0.dev20250429.dist-info/METADATA,sha256=05nMBPcVBVJcZhDI9SzsjryW3d4vpeeH_9H07RaA-PI,2051
+ai_edge_torch_nightly-0.5.0.dev20250429.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.5.0.dev20250429.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.5.0.dev20250429.dist-info/RECORD,,
ai_edge_torch/generative/layers/experimental/__init__.py DELETED
@@ -1,14 +0,0 @@
-# Copyright 2025 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
ai_edge_torch/generative/layers/experimental/kv_cache.py DELETED
@@ -1,90 +0,0 @@
-# Copyright 2025 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""Utility functions for KV Cache.
-
-This is an experimental implementation and is subject to change at any time.
-"""
-
-from ai_edge_torch.generative.custom_ops import dynamic_update_slice as dus_utils
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import torch
-
-
-def update(
-    cache: kv_utils.KVCacheEntry,
-    input_pos: torch.Tensor,
-    k_slice: torch.Tensor,
-    v_slice: torch.Tensor,
-) -> kv_utils.KVCacheEntry:
-  """Out of place update of Cache buffer.
-
-  Args:
-    cache (kv_utils.KVCacheEntry): The original cache buffer.
-    input_pos (torch.Tensor): The update slice positions.
-    k_slice (torch.Tensor): The K slice to be updated in the new cache.
-    v_slice (torch.Tensor): The V slice to be updated in the new cache.
-
-  Returns:
-    kv_utils.KVCacheEntry: The updated KVCacheBase entry based on the passed
-      inputs.
-  """
-  assert (
-      cache.kv_layout == kv_utils.KV_LAYOUT_TRANSPOSED
-  ), "KV entry must have transposed layout."
-  update_kv_cache = _update_kv_impl_transposed
-  return update_kv_cache(cache, input_pos, k_slice, v_slice)
-
-
-def _get_slice_indices(
-    positions: torch.Tensor, cache_dim: int, ts_idx: int
-) -> torch.Tensor:
-  """Returns the slice indices."""
-  positions = positions.float()[0].reshape(
-      1,
-  )
-
-  zeros = torch.zeros((1,), dtype=torch.float32)
-  indices = []
-  for i in range(cache_dim):
-    if i == ts_idx:
-      indices.append(positions)
-    else:
-      indices.append(zeros)
-  slice_indices = torch.cat(indices, dim=0)
-  slice_indices = slice_indices.int()
-  return slice_indices
-
-
-def _update_kv_impl_transposed(
-    cache: kv_utils.KVCacheEntry,
-    input_pos: torch.Tensor,
-    k_slice: torch.Tensor,
-    v_slice: torch.Tensor,
-) -> kv_utils.KVCacheEntry:
-  """Update the cache buffer with High Level Function Boundary annotation."""
-  cache_dim = 4
-  k_ts_idx = 2
-  v_ts_idx = 3
-  positions = input_pos.clone()
-  k_slice_indices = _get_slice_indices(positions, cache_dim, k_ts_idx)
-  v_slice_indices = _get_slice_indices(positions, cache_dim, v_ts_idx)
-  k = dus_utils.dynamic_update_slice(
-      cache.k_cache, k_slice, [x for x in k_slice_indices]
-  )
-  v = dus_utils.dynamic_update_slice(
-      cache.v_cache, v_slice, [x for x in v_slice_indices]
-  )
-  return kv_utils.KVCacheEntry(k, v, cache.kv_layout)