ai-edge-torch-nightly 0.5.0.dev20250408__py3-none-any.whl → 0.5.0.dev20250410__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. ai_edge_torch/_convert/conversion.py +1 -1
  2. ai_edge_torch/_convert/fx_passes/__init__.py +1 -0
  3. ai_edge_torch/_convert/fx_passes/cast_inputs_bf16_to_f32_pass.py +50 -0
  4. ai_edge_torch/_convert/test/test_convert.py +21 -0
  5. ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +2 -2
  6. ai_edge_torch/generative/examples/gemma3/decoder.py +8 -9
  7. ai_edge_torch/generative/examples/gemma3/verify_util.py +4 -2
  8. ai_edge_torch/generative/layers/experimental/attention.py +10 -40
  9. ai_edge_torch/generative/layers/experimental/kv_cache.py +13 -283
  10. ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +6 -10
  11. ai_edge_torch/generative/layers/experimental/types.py +3 -0
  12. ai_edge_torch/generative/layers/kv_cache.py +81 -14
  13. ai_edge_torch/generative/layers/sdpa_with_kv_update.py +124 -0
  14. ai_edge_torch/generative/test/test_kv_cache.py +12 -19
  15. ai_edge_torch/generative/utilities/converter.py +8 -3
  16. ai_edge_torch/generative/utilities/export_config.py +3 -1
  17. ai_edge_torch/lowertools/odml_torch_utils.py +1 -0
  18. ai_edge_torch/odml_torch/lowerings/_basic.py +19 -0
  19. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  20. ai_edge_torch/odml_torch/lowerings/utils.py +1 -0
  21. ai_edge_torch/version.py +1 -1
  22. {ai_edge_torch_nightly-0.5.0.dev20250408.dist-info → ai_edge_torch_nightly-0.5.0.dev20250410.dist-info}/METADATA +4 -2
  23. {ai_edge_torch_nightly-0.5.0.dev20250408.dist-info → ai_edge_torch_nightly-0.5.0.dev20250410.dist-info}/RECORD +26 -24
  24. {ai_edge_torch_nightly-0.5.0.dev20250408.dist-info → ai_edge_torch_nightly-0.5.0.dev20250410.dist-info}/LICENSE +0 -0
  25. {ai_edge_torch_nightly-0.5.0.dev20250408.dist-info → ai_edge_torch_nightly-0.5.0.dev20250410.dist-info}/WHEEL +0 -0
  26. {ai_edge_torch_nightly-0.5.0.dev20250408.dist-info → ai_edge_torch_nightly-0.5.0.dev20250410.dist-info}/top_level.txt +0 -0
@@ -40,8 +40,8 @@ def _run_convert_passes(
      fx_passes.OptimizeLayoutTransposesPass(),
      fx_passes.CanonicalizePass(),
      fx_passes.BuildAtenCompositePass(),
-     fx_passes.CanonicalizePass(),
      fx_passes.RemoveNonUserOutputsPass(),
+     fx_passes.CastInputsBf16ToF32Pass(),
      fx_passes.CanonicalizePass(),
    ]
 
@@ -17,6 +17,7 @@ from typing import Sequence, Union
 
  from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
  from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
+ from ai_edge_torch._convert.fx_passes.cast_inputs_bf16_to_f32_pass import CastInputsBf16ToF32Pass
  from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
  from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
  from ai_edge_torch._convert.fx_passes.remove_non_user_outputs_pass import RemoveNonUserOutputsPass
@@ -0,0 +1,50 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Pass to cast all inputs with torch.bfloat16 type to torch.float32."""
+
+
+ from ai_edge_torch import fx_infra
+ import torch
+
+
+ def cast_f32(x):
+   return x.to(torch.float32)
+
+
+ class CastInputsBf16ToF32Pass(fx_infra.ExportedProgramPassBase):
+   """This pass casts all inputs with torch.bfloat16 type to torch.float32."""
+
+   def call(self, exported_program: torch.export.ExportedProgram):
+     modified = False
+     for node in exported_program.graph.nodes:
+       if (
+           node.op == "placeholder"
+           and node.meta.get("val").dtype == torch.bfloat16
+       ):
+         if not node.users:
+           continue
+
+         modified = True
+         user = next(iter(node.users))
+         with exported_program.graph.inserting_before(user):
+           cast_node = exported_program.graph.call_function(
+               cast_f32,
+               (node,),
+           )
+           node.replace_all_uses_with(cast_node)
+           cast_node.replace_input_with(cast_node, node)
+
+     exported_program.graph_module.recompile()
+     return fx_infra.ExportedProgramPassResult(exported_program, modified)
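
For context, a minimal sketch of what the new pass does to an exported graph. Only CastInputsBf16ToF32Pass and torch come from this diff; the toy module is illustrative, and the released fx_infra runner may invoke the pass differently:

import torch

from ai_edge_torch._convert.fx_passes import CastInputsBf16ToF32Pass


class Mul(torch.nn.Module):

  def forward(self, x):
    return x * 2.0


ep = torch.export.export(Mul().eval(), (torch.randn(4, dtype=torch.bfloat16),))
# Before the pass: placeholder(x: bf16) feeds mul directly.
CastInputsBf16ToF32Pass().call(ep)
# After the pass: a cast_f32 call_function node sits between the placeholder
# and its first user, so the graph computes in f32 while the exported
# signature keeps the bf16 input.
print(ep.graph_module.code)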
@@ -553,6 +553,27 @@ class TestConvert(googletest.TestCase):
        self.fail(f"PT2E conversion failed: {err}")
      # pylint: enable=broad-except
 
+   def test_convert_model_with_bfloat16_inputs(self):
+     """Test converting a simple model with torch.bfloat16 input.
+
+     bf16 inputs remain in the converted model signature but are cast to f32
+     right after the model inputs.
+     """
+
+     class SampleModel(nn.Module):
+
+       def forward(self, x: torch.Tensor):
+         return (x + 1) * 1.2
+
+     model = SampleModel().eval()
+     args = (torch.randn(10, 10).to(torch.bfloat16),)
+     # pylint: disable=broad-except
+     try:
+       ai_edge_torch.convert(model, args)
+     except Exception as err:
+       self.fail(f"Conversion failed with bfloat16 inputs: {err}")
+     # pylint: enable=broad-except
+
 
  if __name__ == "__main__":
    googletest.main()
@@ -17,7 +17,7 @@
 
  from absl import app
  from ai_edge_torch.generative.examples.gemma3 import gemma3
- from ai_edge_torch.generative.layers.experimental import kv_cache
+ from ai_edge_torch.generative.layers import kv_cache
  from ai_edge_torch.generative.utilities import converter
  from ai_edge_torch.generative.utilities import export_config
  import torch
@@ -58,7 +58,7 @@ def _create_export_config(
    )
    decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
    export_config.decode_mask = decode_mask
-   export_config.kvcache_cls = kv_cache.KVCacheTransposed
+   export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
    return export_config
 
 
@@ -18,9 +18,9 @@
  from typing import List, Optional, Tuple
 
  from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  from ai_edge_torch.generative.layers.experimental import attention
- from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
  from ai_edge_torch.generative.utilities import export_config as export_cfg
@@ -81,8 +81,8 @@ class DecoderBlock(attention.TransformerBlock):
      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
      mask: Optional[torch.Tensor] = None,
      input_pos: Optional[torch.Tensor] = None,
-     kv_cache: kv_utils.KVCacheEntryBase = None,
- ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntryBase]]:
+     kv_cache: kv_utils.KVCacheEntry = None,
+ ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
    """Forward function of the Gemma3Block.
 
    Exactly the same as TransformerBlock but we call the post-attention norm
@@ -241,13 +241,12 @@ class Decoder(nn.Module):
      self,
      tokens: torch.Tensor,
      input_pos: torch.Tensor,
-     kv_cache: kv_utils.KVCacheBase,
+     kv_cache: kv_utils.KVCache,
      input_embeds: Optional[torch.Tensor] = None,
      mask: Optional[torch.Tensor] = None,
      image_indices: Optional[torch.Tensor] = None,
      export_config: Optional[export_cfg.ExportConfig] = None,
- ) -> dict[torch.Tensor, kv_utils.KVCacheBase]:
-
+ ) -> dict[torch.Tensor, kv_utils.KVCache]:
    pixel_mask = None
    if input_embeds is None:
      # token embeddings of shape (b, t, n_embd)
@@ -287,10 +286,10 @@ class Decoder(nn.Module):
      rope: List[Tuple[torch.Tensor, torch.Tensor]],
      mask: torch.Tensor | List[torch.Tensor],
      input_pos: torch.Tensor,
-     kv_cache: kv_utils.KVCacheBase,
+     kv_cache: kv_utils.KVCache,
      pixel_mask: Optional[torch.Tensor] = None,
      export_config: Optional[export_cfg.ExportConfig] = None,
- ) -> dict[torch.Tensor, kv_utils.KVCacheBase]:
+ ) -> dict[torch.Tensor, kv_utils.KVCache]:
    """Forwards the model with input embeddings."""
    assert len(self.transformer_blocks) == len(kv_cache.caches), (
        "The number of transformer blocks and the number of KV cache entries"
@@ -326,7 +325,7 @@ class Decoder(nn.Module):
        x, kv_entry = block(x, rope[i], mask_entry, input_pos, kv_entry)
        if kv_entry:
          updated_kv_entries.append(kv_entry)
-     updated_kv_cache = kv_utils.KVCacheBase(tuple(updated_kv_entries))
+     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
      if export_config is not None:
        if (
            torch.numel(input_pos) > 1
@@ -20,8 +20,8 @@ import os
  from typing import List, Optional, Tuple
 
  from ai_edge_torch.generative.examples.gemma3 import gemma3
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
- from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
  from ai_edge_torch.generative.utilities.experimental import verifier
  from gemma import config as gemma_config
  from gemma import model as gemma_model
@@ -94,7 +94,9 @@ class UnifiedGemma3Wrapper(verifier.ReauthoredModelWrapper):
 
    def _init_kv_cache(self):
      """Returns an initialized KV cache."""
-     return kv_utils.KVCacheTransposed.from_model_config(self.model.model.config)
+     return kv_utils.KVCache.from_model_config(
+         self.model.model.config, kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED
+     )
 
    def forward(
        self, tokens: torch.Tensor, pixel_values: torch.Tensor = None
@@ -22,9 +22,9 @@ at any time.
  from typing import Optional, Tuple, Union
 
  from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  from ai_edge_torch.generative.layers import lora as lora_utils
- from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
- from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
+ from ai_edge_torch.generative.layers import sdpa_with_kv_update
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
  import torch
@@ -69,9 +69,9 @@ class TransformerBlock(nn.Module):
      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
      mask: Optional[torch.Tensor] = None,
      input_pos: Optional[torch.Tensor] = None,
-     kv_cache: kv_utils.KVCacheEntryBase = None,
+     kv_cache: kv_utils.KVCacheEntry = None,
      lora: Optional[lora_utils.LoRAEntry] = None,
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntryBase]]:
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
    """Forward function of the TransformerBlock.
 
    Args:
@@ -79,7 +79,7 @@ class TransformerBlock(nn.Module):
      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
      mask (torch.Tensor): the optional mask tensor.
      input_pos (torch.Tensor): the optional input position tensor.
-     kv_cache (KVCacheEntryBase): the optional kv cache entry.
+     kv_cache (KVCacheEntry): the optional kv cache entry.
      lora (LoRAEntry): the optional lora entry.
 
    Returns:
@@ -146,7 +146,6 @@ class CausalSelfAttention(nn.Module):
    self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
    self.config = config
    self.enable_hlfb = enable_hlfb
-   self.sdpa_func = sdpa.scaled_dot_product_attention
 
  def forward(
      self,
@@ -154,9 +153,9 @@ class CausalSelfAttention(nn.Module):
      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
      mask: Optional[torch.Tensor] = None,
      input_pos: Optional[torch.Tensor] = None,
-     kv_cache: Optional[kv_utils.KVCacheEntryBase] = None,
+     kv_cache: Optional[kv_utils.KVCacheEntry] = None,
      lora: Optional[lora_utils.LoRAEntry] = None,
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntryBase]]:
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
    """Forward function of the CausalSelfAttention layer, which can support
 
    MQA, GQA and MHA.
@@ -166,8 +165,7 @@ class CausalSelfAttention(nn.Module):
      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
      mask (torch.Tensor): the optional mask tensor.
      input_pos (torch.Tensor): the optional input position tensor.
-     kv_cache (KVCacheEntryBase): the KV cache entry corresponding to this
-       module.
+     kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
      lora (LoRAEntry): the optional lora entry.
 
    Returns:
@@ -221,36 +219,8 @@ class CausalSelfAttention(nn.Module):
      cos, sin = rope
      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
 
-   # Transpose k/v to specific layout for GPU implementation.
-   b, _, n, h = q.shape
-   g = n // self.config.num_query_groups
-   # btnh -> bnth -> b(kg)th -> 1(bk)(gt)h
-   q = q.permute(0, 2, 1, 3).reshape(
-       1, b * self.config.num_query_groups, g * T, h
-   )
-
-   k = k.permute(0, 2, 1, 3).reshape(
-       1, -1, T, self.config.head_dim
-   )  # 1, bk, s, h
-   v = v.permute(0, 2, 3, 1).reshape(
-       1, -1, self.config.head_dim, T
-   )  # 1, bk, h, s
-
-   if kv_cache is not None:
-     kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
-     k, v = kv_cache.k_cache, kv_cache.v_cache
-
-   sdpa_out = self.sdpa_func(
-       kv_cache,
-       q,
-       k,
-       v,
-       self.config.head_dim,
-       mask=mask,
-       softcap=self.config.logit_softcap,
-   )  # 1, bk, gt, h
-   sdpa_out = (
-       sdpa_out.reshape(B, -1, T, h).permute(0, 2, 1, 3).reshape(B, T, -1)
+   sdpa_out, kv_cache = sdpa_with_kv_update.sdpa_with_kv_update(
+       q, k, v, kv_cache, input_pos, mask, self.config
    )
 
    # Compute the output projection.
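
The deleted layout juggling above now lives behind the new sdpa_with_kv_update helper (added in ai_edge_torch/generative/layers/sdpa_with_kv_update.py, +124 lines, not shown in this diff). Reconstructed purely from the removed lines, the transposed-layout path it replaces looks roughly like this sketch; kv_utils and sdpa name the experimental modules from the old imports, and the helper's real signature and internals may differ:

def _sdpa_with_kv_update_sketch(q, k, v, kv_cache, input_pos, mask, config):
  # q/k/v arrive in BTNH layout; reshape to the GPU-friendly layouts the
  # removed code used: q -> 1(bk)(gt)h, k -> 1(bk)sh, v -> 1(bk)hs.
  b, t, n, h = q.shape
  g = n // config.num_query_groups
  q = q.permute(0, 2, 1, 3).reshape(1, b * config.num_query_groups, g * t, h)
  k = k.permute(0, 2, 1, 3).reshape(1, -1, t, config.head_dim)  # 1, bk, s, h
  v = v.permute(0, 2, 3, 1).reshape(1, -1, config.head_dim, t)  # 1, bk, h, s
  if kv_cache is not None:
    # Out-of-place cache update in the transposed layout.
    kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
    k, v = kv_cache.k_cache, kv_cache.v_cache
  sdpa_out = sdpa.scaled_dot_product_attention(
      kv_cache, q, k, v, config.head_dim, mask=mask,
      softcap=config.logit_softcap,
  )  # 1, bk, gt, h
  # Back to (b, t, n*h) for the output projection.
  sdpa_out = sdpa_out.reshape(b, -1, t, h).permute(0, 2, 1, 3).reshape(b, t, -1)
  return sdpa_out, kv_cache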
@@ -18,303 +18,33 @@
  This is an experimental implementation and is subject to change at any time.
  """
 
- import dataclasses
- import functools
- from typing import Any, List, Tuple, Type
- from ai_edge_torch.generative.layers import model_config
- from ai_edge_torch.generative.layers.experimental import types
  from ai_edge_torch.generative.custom_ops import dynamic_update_slice as dus_utils
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import torch
- import torch.utils._pytree as pytree
-
-
- @dataclasses.dataclass
- class KVCacheEntryBase:
-   """A single cache entry that includes K and V caches.
-
-   The chaches are built based on the provided config with the shape of
-   (batch_size, kv_cache_max, num_query_groups, head_dim).
-   """
-
-   k_cache: torch.Tensor
-   v_cache: torch.Tensor
-
-   @classmethod
-   def _from_model_config(
-       cls,
-       k_shape: Tuple[int, ...],
-       v_shape: Tuple[int, ...],
-       dtype: torch.dtype = torch.float32,
-       device: torch.device = None,
-   ):
-     """Build an instance of the class based on model config."""
-     k = torch.zeros(k_shape, dtype=dtype, device=device)
-     v = torch.zeros(v_shape, dtype=dtype, device=device)
-     obj = cls(k_cache=k, v_cache=v)
-     return obj
-
-   @classmethod
-   def from_model_config(
-       cls,
-       kv_cache_max: int,
-       config: model_config.AttentionConfig,
-       dtype: torch.dtype = torch.float32,
-       device: torch.device = None,
-       batch_size: int = 1,
-   ):
-     """Build an instance of the class based on model config."""
-     shape = (batch_size, kv_cache_max, config.num_query_groups, config.head_dim)
-     return cls._from_model_config(shape, shape, dtype, device)
-
-
- @dataclasses.dataclass
- class KVCacheEntryBTNH(KVCacheEntryBase):
-   k_type = types.BTNH()
-   v_type = types.BTNH()
-
-
- @dataclasses.dataclass
- class KVCacheEntryTransposed(KVCacheEntryBase):
-
-   k_type = types.BNTH()
-   v_type = types.BNHT()
-
-   @classmethod
-   def from_model_config(
-       cls,
-       kv_cache_max: int,
-       config: model_config.AttentionConfig,
-       dtype: torch.dtype = torch.float32,
-       device: torch.device = None,
-       batch_size: int = 1,
-   ):
-     """Build an instance of the class based on model config."""
-     k_shape = (
-         batch_size,
-         config.num_query_groups,
-         kv_cache_max,
-         config.head_dim,
-     )  # b, k, s, h
-     v_shape = (
-         batch_size,
-         config.num_query_groups,
-         config.head_dim,
-         kv_cache_max,
-     )  # b, k, h, s
-     return cls._from_model_config(k_shape, v_shape, dtype, device)
-
-
- def _flatten_kv_entry(
-     kv_e: KVCacheEntryBase,
- ) -> Tuple[List[torch.Tensor], Any]:
-   return ([kv_e.k_cache, kv_e.v_cache], None)
-
-
- def _unflatten_kv_entry(
-     kv_entry_ty: Type[KVCacheEntryBase],
-     values: List[torch.Tensor],
-     unused_context: Any,
- ) -> KVCacheEntryBase:
-   return kv_entry_ty(*values)
-
-
- pytree.register_pytree_node(
-     KVCacheEntryTransposed,
-     _flatten_kv_entry,
-     functools.partial(_unflatten_kv_entry, KVCacheEntryTransposed),
-     serialized_type_name="",
- )
-
- pytree.register_pytree_node(
-     KVCacheEntryBase,
-     _flatten_kv_entry,
-     functools.partial(_unflatten_kv_entry, KVCacheEntryBase),
-     serialized_type_name="",
- )
-
-
- @dataclasses.dataclass
- class KVCacheBase:
-   """A utility class for holding KV cache entries per layer."""
-
-   caches: Tuple[KVCacheEntryBase, ...]
-
-   @classmethod
-   def _from_model_config(
-       cls,
-       kv_entry_cls,
-       config: model_config.ModelConfig,
-       dtype: torch.dtype = torch.float32,
-       device: torch.device = None,
-       batch_size: int = 1,
-   ):
-     caches = [
-         kv_entry_cls.from_model_config(
-             config.kv_cache_max,
-             config.block_config(idx).attn_config,
-             dtype,
-             device,
-             batch_size,
-         )
-         for idx in range(config.num_layers)
-     ]
-     obj = cls(caches=tuple(caches))
-     return obj
-
-   @classmethod
-   def from_model_config(
-       cls,
-       config: model_config.ModelConfig,
-       dtype: torch.dtype = torch.float32,
-       device: torch.device = None,
-       batch_size: int = 1,
-   ):
-     """Build an instance of the class based on model config.
-
-     Args:
-       config (ModelConfig): Model config used for building the cache.
-       dtype (torch.dtype, optional): The data type of the cache tensor.
-         Defaults to torch.float32.
-       device (torch.device, optional): The device placement of the cache
-         tensors. Defaults to None.
-       batch_size (int, optional): The batch size of the cache tensors.
-         Defaults to 1.
-
-     Returns:
-       KVCacheBase: The created cache object.
-     """
-     assert batch_size == 1, "Batch size must be 1 for KV Cache."
-     return cls._from_model_config(
-         KVCacheEntryBase,
-         config=config,
-         dtype=dtype,
-         device=device,
-         batch_size=batch_size,
-     )
-
-   def flatten(self) -> List[torch.Tensor]:
-     """Flatten the cache entries into a list of tensors with order k_i, v_i."""
-     flattened, _ = _flatten_kvc(self)
-     return flattened
-
-
- @dataclasses.dataclass
- class KVCacheBTNH(KVCacheBase):
-
-   @classmethod
-   def from_model_config(
-       cls,
-       config: model_config.ModelConfig,
-       dtype: torch.dtype = torch.float32,
-       device: torch.device = None,
-       batch_size: int = 1,
-   ):
-     return cls._from_model_config(
-         KVCacheEntryBTNH,
-         config=config,
-         dtype=dtype,
-         device=device,
-         batch_size=batch_size,
-     )
-
-
- @dataclasses.dataclass
- class KVCacheTransposed(KVCacheBase):
-
-   @classmethod
-   def from_model_config(
-       cls,
-       config: model_config.ModelConfig,
-       dtype: torch.dtype = torch.float32,
-       device: torch.device = None,
-       batch_size: int = 1,
-   ):
-     return cls._from_model_config(
-         KVCacheEntryTransposed,
-         config=config,
-         dtype=dtype,
-         device=device,
-         batch_size=batch_size,
-     )
-
-
- def _flatten_kvc(kvc: KVCacheBase) -> Tuple[List[str], List[str]]:
-   flattened = []
-   flat_names = []
-   none_names = []
-   for i, kv_entry in enumerate(kvc.caches):
-     flattened.append(kv_entry.k_cache)
-     flat_names.append(f"k_{i}")
-     flattened.append(kv_entry.v_cache)
-     flat_names.append(f"v_{i}")
-   return flattened, [flat_names, none_names]
-
-
- def _flatten_kvc_with_keys(kvc: KVCacheBase) -> Tuple[List, List]:
-   flattened, (flat_names, none_names) = _flatten_kvc(kvc)
-   return [
-       (pytree.MappingKey(k), v) for k, v in zip(flat_names, flattened)
-   ], flat_names
-
-
- def _unflatten_kvc(
-     kv_ty: Type[KVCacheBase],
-     kv_entry_type: Type[KVCacheEntryBase],
-     values: List[torch.Tensor],
-     context: Tuple[List, List],
- ) -> KVCacheBase:
-   assert len(values) % 2 == 0, "Found odd number of K and V entries."
-   num_layers = len(values) // 2
-   flat_names = context[0]
-   kv_entries = []
-   for i in range(num_layers):
-     k_cache_idx = flat_names.index(f"k_{i}")
-     v_cache_idx = flat_names.index(f"v_{i}")
-     kv_entries.append(
-         kv_entry_type(k_cache=values[k_cache_idx], v_cache=values[v_cache_idx])
-     )
-   obj = kv_ty(tuple(kv_entries))
-   return obj
-
-
- pytree.register_pytree_node(
-     KVCacheTransposed,
-     _flatten_kvc,
-     functools.partial(
-         _unflatten_kvc, KVCacheTransposed, KVCacheEntryTransposed
-     ),
-     flatten_with_keys_fn=_flatten_kvc_with_keys,
-     serialized_type_name="",
- )
-
- pytree.register_pytree_node(
-     KVCacheBase,
-     _flatten_kvc,
-     functools.partial(_unflatten_kvc, KVCacheBase, KVCacheEntryBase),
-     flatten_with_keys_fn=_flatten_kvc_with_keys,
-     serialized_type_name="",
- )
 
 
  def update(
-     cache: KVCacheEntryBase,
+     cache: kv_utils.KVCacheEntry,
      input_pos: torch.Tensor,
      k_slice: torch.Tensor,
      v_slice: torch.Tensor,
- ) -> KVCacheEntryBase:
+ ) -> kv_utils.KVCacheEntry:
    """Out of place update of Cache buffer.
 
    Args:
-     cache (KVCacheEntryBase): The original cache buffer.
+     cache (kv_utils.KVCacheEntry): The original cache buffer.
      input_pos (torch.Tensor): The update slice positions.
      k_slice (torch.Tensor): The K slice to be updated in the new cache.
      v_slice (torch.Tensor): The V slice to be updated in the new cache.
 
    Returns:
-     KVCacheEntryBase: The updated KVCacheBase entry based on the passed
+     kv_utils.KVCacheEntry: The updated KVCacheEntry based on the passed
        inputs.
    """
-   update_kv_cache = _update_kv_impl
+   assert (
+       cache.kv_layout == kv_utils.KV_LAYOUT_TRANSPOSED
+   ), "KV entry must have transposed layout."
+   update_kv_cache = _update_kv_impl_transposed
    return update_kv_cache(cache, input_pos, k_slice, v_slice)
 
 
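Net effect of this hunk: the experimental KVCacheEntryBase/KVCacheTransposed hierarchy and its pytree registrations are gone, and layout becomes data carried by the consolidated ai_edge_torch.generative.layers.kv_cache module. A sketch of the replacement API, with names taken from elsewhere in this diff (exact signatures may differ in the released module):

from ai_edge_torch.generative.layers import kv_cache as kv_utils


def build_transposed_cache(model_config):
  # model_config: a model_config.ModelConfig for the target model.
  cache = kv_utils.KVCache.from_model_config(
      model_config, kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED
  )
  # Each entry now records its own layout, which update() asserts on above.
  assert cache.caches[0].kv_layout == kv_utils.KV_LAYOUT_TRANSPOSED
  return cache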
@@ -338,12 +68,12 @@ def _get_slice_indices(
    return slice_indices
 
 
- def _update_kv_impl(
-     cache: KVCacheEntryTransposed,
+ def _update_kv_impl_transposed(
+     cache: kv_utils.KVCacheEntry,
      input_pos: torch.Tensor,
      k_slice: torch.Tensor,
      v_slice: torch.Tensor,
- ) -> KVCacheEntryTransposed:
+ ) -> kv_utils.KVCacheEntry:
    """Update the cache buffer with High Level Function Boundary annotation."""
    cache_dim = 4
    k_ts_idx = 2
@@ -357,4 +87,4 @@ def _update_kv_impl(
    v = dus_utils.dynamic_update_slice(
        cache.v_cache, v_slice, [x for x in v_slice_indices]
    )
-   return KVCacheEntryTransposed(k, v)
+   return kv_utils.KVCacheEntry(k, v, cache.kv_layout)
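
For intuition, a plain-torch reference of the transposed out-of-place update above. This is illustrative only: the real implementation goes through the custom dynamic_update_slice op so it lowers to a single HLO dynamic-update-slice rather than a clone plus copy.

import torch


def update_transposed_reference(cache_k, cache_v, input_pos, k_slice, v_slice):
  # K cache layout is (1, bk, s, h): time is dim 2 (k_ts_idx above).
  # V cache layout is (1, bk, h, s): time is the last dim.
  pos = int(input_pos[0])
  k = cache_k.clone()
  v = cache_v.clone()
  k[:, :, pos : pos + k_slice.shape[2], :] = k_slice
  v[:, :, :, pos : pos + v_slice.shape[3]] = v_slice
  return k, v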
@@ -19,7 +19,7 @@ import math
  from typing import Optional
 
  from ai_edge_torch.generative.custom_ops import bmm_4d as bmm_lib
- from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  from ai_edge_torch.generative.layers.experimental import types
  from ai_edge_torch.hlfb import StableHLOCompositeBuilder
  from multipledispatch import dispatch
@@ -28,7 +28,7 @@ import torch.nn.functional as F
 
 
  def scaled_dot_product_attention(
-     kv: kv_utils.KVCacheBase,
+     kv: kv_utils.KVCacheEntry,
      query: torch.Tensor,
      key: torch.Tensor,
      value: torch.Tensor,
@@ -37,10 +37,10 @@ def scaled_dot_product_attention(
      scale: Optional[float] = None,
      softcap: Optional[float] = None,
  ):
-   if hasattr(kv, "k_type") and hasattr(kv, "v_type"):
+   if hasattr(kv, "kv_layout"):
      return _sdpa(
-         kv.k_type,
-         kv.v_type,
+         kv.kv_layout[0](),  # key layout
+         kv.kv_layout[1](),  # value layout
          query=query,
          key=key,
          value=value,
@@ -49,10 +49,7 @@ def scaled_dot_product_attention(
          scale=scale,
          softcap=softcap,
      )
-   raise ValueError(
-       f"SDPA for K type {type(kv.caches[0].k_type)} and V type"
-       f" {type(kv.caches[0].v_type)} not supported."
-   )
+   raise ValueError("No kv_layout attribute found in kv.")
 
 
  @dispatch(types.BNTH, types.BNHT)
@@ -85,7 +82,6 @@ def _sdpa(k_type, v_type, *args, **kwargs):
    padded_logits = logits + mask
    padded_logits = padded_logits.reshape(1, bk, gt, s)
    probs = F.softmax(padded_logits, dim=-1).type_as(key)
-
    encoded = bmm_lib.bmm_4d(probs, value)
 
    return encoded  # 1, bk, gt, h
@@ -62,6 +62,9 @@ class TensorDimensionMeta(type):
    def __repr__(cls):
      return f'{cls.__name__}'
 
+   def __iter__(cls):
+     return iter(getattr(cls, 'dimensions'))
+
 
  def create_tensor_dimension_order_class(dims: Tuple[TensorDims]):
    """Creates a TensorDimensionMeta class with the specified dimensions.