ai-edge-torch-nightly 0.3.0.dev20250123__py3-none-any.whl → 0.3.0.dev20250125__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28) hide show
  1. ai_edge_torch/_config.py +9 -0
  2. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +11 -8
  3. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +22 -24
  4. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +3 -4
  5. ai_edge_torch/generative/examples/deepseek/__init__.py +14 -0
  6. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +80 -0
  7. ai_edge_torch/generative/examples/deepseek/deepseek.py +92 -0
  8. ai_edge_torch/generative/examples/deepseek/verify.py +70 -0
  9. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +3 -0
  10. ai_edge_torch/generative/layers/experimental/__init__.py +14 -0
  11. ai_edge_torch/generative/layers/experimental/attention.py +269 -0
  12. ai_edge_torch/generative/layers/experimental/kv_cache.py +314 -0
  13. ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +97 -0
  14. ai_edge_torch/generative/layers/experimental/types.py +97 -0
  15. ai_edge_torch/generative/layers/kv_cache.py +2 -1
  16. ai_edge_torch/generative/layers/model_config.py +5 -1
  17. ai_edge_torch/generative/test/test_model_conversion_large.py +11 -2
  18. ai_edge_torch/generative/utilities/bmm_4d.py +76 -0
  19. ai_edge_torch/generative/utilities/converter.py +18 -2
  20. ai_edge_torch/generative/utilities/model_builder.py +6 -1
  21. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +1 -1
  22. ai_edge_torch/quantize/pt2e_quantizer_utils.py +22 -2
  23. ai_edge_torch/version.py +1 -1
  24. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/METADATA +1 -1
  25. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/RECORD +28 -18
  26. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/LICENSE +0 -0
  27. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/WHEEL +0 -0
  28. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,269 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Common building blocks for a GPU-specific Attention layer.
17
+
18
+ This is a temporary implementation for the GPU. It is subject to change/removal
19
+ at any time.
20
+ """
21
+
22
+ from typing import Optional, Tuple, Union
23
+
24
+ from ai_edge_torch.generative.layers import builder
25
+ from ai_edge_torch.generative.layers import lora as lora_utils
26
+ from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
27
+ from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
28
+ import ai_edge_torch.generative.layers.model_config as cfg
29
+ import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
30
+ import torch
31
+ from torch import nn
32
+
33
+
34
class TransformerBlock(nn.Module):
  """Transformer block built on the experimental GPU attention layer.

  Composes a pre-attention norm, `CausalSelfAttention`, a post-attention norm
  and a feed-forward module, all constructed from the given configuration.
  """

  def __init__(
      self,
      config: cfg.TransformerBlockConfig,
      model_config: cfg.ModelConfig,
  ) -> None:
    """Initialize an instance of the TransformerBlock.

    Args:
      config (cfg.TransformerBlockConfig): the configuration object for this
        transformer block.
      model_config (cfg.ModelConfig): the configuration object for the model
        this transformer block belongs to.
    """
    super().__init__()
    # Sub-modules are created in a fixed order; keep it stable so parameter
    # registration order does not change.
    self.pre_atten_norm = builder.build_norm(
        model_config.embedding_dim,
        config.pre_attention_norm_config,
    )
    self.atten_func = CausalSelfAttention(
        model_config.batch_size,
        model_config.embedding_dim,
        config.attn_config,
        model_config.enable_hlfb,
    )
    self.post_atten_norm = builder.build_norm(
        model_config.embedding_dim,
        config.post_attention_norm_config,
    )
    self.ff = builder.build_ff(model_config.embedding_dim, config.ff_config)
    self.config = config

  def forward(
      self,
      x: torch.Tensor,
      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
      mask: Optional[torch.Tensor] = None,
      input_pos: Optional[torch.Tensor] = None,
      kv_cache: Optional[kv_utils.KVCacheEntryBase] = None,
      lora: Optional[lora_utils.LoRAEntry] = None,
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntryBase]]:
    """Forward function of the TransformerBlock.

    Args:
      x (torch.Tensor): the input tensor.
      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
      mask (torch.Tensor): the optional mask tensor.
      input_pos (torch.Tensor): the optional input position tensor.
      kv_cache (KVCacheEntryBase): the optional kv cache entry.
      lora (LoRAEntry): the optional lora entry.

    Returns:
      output activation from this transformer block, and updated kv cache (if
      passed in).
    """
    # Both residual layouts start the same way: normalize, then attend.
    x_norm = self.pre_atten_norm(x)
    atten_func_out = self.atten_func(
        x_norm, rope, mask, input_pos, kv_cache, lora
    )
    # The attention layer returns a bare tensor when no cache entry is given
    # and an (activation, cache entry) pair otherwise.
    if kv_cache is None:
      attn_out, kv = atten_func_out, None
    else:
      attn_out, kv = atten_func_out

    if self.config.parallel_residual:
      # Attention and feed-forward both consume the same normalized input;
      # the post-attention norm is not applied on this path.
      output = x + attn_out + self.ff(x_norm)
    else:
      # Sequential residual: the feed-forward sees the normalized attention
      # residual.
      x = x + attn_out
      output = x + self.ff(self.post_atten_norm(x))

    return output if kv is None else (output, kv)
116
+
117
+
118
class CausalSelfAttention(nn.Module):
  """Self-attention layer that lays out Q/K/V for the GPU-specific SDPA."""

  def __init__(
      self,
      batch_size: int,
      dim: int,
      config: cfg.AttentionConfig,
      enable_hlfb: bool,
  ) -> None:
    """Initialize an instance of CausalSelfAttention.

    Args:
      batch_size (int): batch size of the input tensor.
      dim (int): causal attention's input/output dimension.
      config (cfg.AttentionConfig): attention specific configurations.
      enable_hlfb (bool): whether hlfb is enabled or not.
    """
    super().__init__()
    self.kv_cache = None
    self.batch_size = batch_size
    # One fused projection yields every query head plus one key and one value
    # head per query group.
    qkv_shape = (
        config.num_heads + 2 * config.num_query_groups
    ) * config.head_dim
    output_shape = config.num_heads * config.head_dim
    # Key, query, value projections for all heads.
    self.qkv_projection = nn.Linear(dim, qkv_shape, bias=config.qkv_use_bias)
    self.output_projection = nn.Linear(
        output_shape, dim, bias=config.output_proj_use_bias
    )
    self.query_norm = builder.build_norm(
        config.head_dim, config.query_norm_config
    )
    self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
    self.config = config
    self.enable_hlfb = enable_hlfb
    self.sdpa_func = sdpa.scaled_dot_product_attention

  def forward(
      self,
      x: torch.Tensor,
      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
      mask: Optional[torch.Tensor] = None,
      input_pos: Optional[torch.Tensor] = None,
      kv_cache: Optional[kv_utils.KVCacheEntryBase] = None,
      lora: Optional[lora_utils.LoRAEntry] = None,
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntryBase]]:
    """Forward function of the CausalSelfAttention layer.

    Supports MQA, GQA and MHA.

    Args:
      x (torch.Tensor): the input tensor.
      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
      mask (torch.Tensor): the optional mask tensor.
      input_pos (torch.Tensor): the optional input position tensor.
      kv_cache (KVCacheEntryBase): the KV cache entry corresponding to this
        module.
      lora (LoRAEntry): the optional lora entry.

    Returns:
      output activation from this self attention layer, and the updated
      KV Cache Entry (if passed in).
    """
    # Batch size, sequence length, embedding dimensionality (unused here).
    B, T, _ = x.size()
    assert B == self.batch_size, (
        "batch size of input tensor must match with the batch size specified in"
        " the model configuration."
    )

    qkv = self.qkv_projection(x)

    # Assemble into a number of query groups to support MHA, MQA and GQA.
    q_per_kv = self.config.num_heads // self.config.num_query_groups
    # Each group has >=1 queries, 1 key, and 1 value.
    if self.config.qkv_transpose_before_split:
      # Heads are laid out as [all q | all k | all v] along the head axis.
      qkv = qkv.view(B, T, -1, self.config.head_dim)
      q, k, v = qkv.split(
          (
              q_per_kv * self.config.num_query_groups,
              self.config.num_query_groups,
              self.config.num_query_groups,
          ),
          dim=-2,
      )
    else:
      # Each group holds its own [q... | k | v] along the feature axis.
      qkv = qkv.view(B, T, self.config.num_query_groups, -1)
      q, k, v = qkv.split(
          (
              q_per_kv * self.config.head_dim,
              self.config.head_dim,
              self.config.head_dim,
          ),
          dim=-1,
      )

    if lora is not None:
      q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
      k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
      v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)

    q = self.query_norm(q)
    k = self.key_norm(k)

    q = q.reshape(B, T, -1, self.config.head_dim)
    k = k.reshape(B, T, -1, self.config.head_dim)
    v = v.reshape(B, T, -1, self.config.head_dim)

    if rope is not None:
      # Apply the precomputed rotary positional embedding to query and key.
      cos, sin = rope
      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

    # Transpose k/v to specific layout for GPU implementation.
    b, _, n, h = q.shape
    g = n // self.config.num_query_groups
    # btnh -> bnth -> b(kg)th -> 1(bk)(gt)h
    q = q.permute(0, 2, 1, 3).reshape(
        1, b * self.config.num_query_groups, g * T, h
    )

    k = k.permute(0, 2, 1, 3).reshape(
        1, -1, T, self.config.head_dim
    )  # 1, bk, s, h
    v = v.permute(0, 2, 3, 1).reshape(
        1, -1, self.config.head_dim, T
    )  # 1, bk, h, s

    if kv_cache is not None:
      # Write the new slices into the cache, then attend over the full cache.
      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
      k, v = kv_cache.k_cache, kv_cache.v_cache

    # The cache entry is passed along so SDPA can pick an implementation from
    # its K/V layout tags.
    sdpa_out = self.sdpa_func(
        kv_cache,
        q,
        k,
        v,
        self.config.head_dim,
        mask=mask,
        softcap=self.config.logit_softcap,
    )  # 1, bk, gt, h
    # Undo the GPU layout: 1(bk)(gt)h -> b(kg)th -> bt(kg)h -> bt(kgh).
    sdpa_out = (
        sdpa_out.reshape(B, -1, T, h).permute(0, 2, 1, 3).reshape(B, T, -1)
    )

    # Compute the output projection.
    y = self.output_projection(sdpa_out)
    if lora is not None:
      y += lora_utils.apply_lora(sdpa_out, lora.attention.output)

    return y if kv_cache is None else (y, kv_cache)
@@ -0,0 +1,314 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Utility functions for KV Cache.
17
+
18
+ This is an experimental implementation and is subject to change at any time.
19
+ """
20
+
21
+ import dataclasses
22
+ from typing import List, Tuple
23
+
24
+ from ai_edge_torch import hlfb
25
+ from ai_edge_torch.generative.layers import model_config
26
+ from ai_edge_torch.generative.layers.experimental import types as types
27
+ from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.utils._pytree as pytree
31
+
32
# Batch dimension used when building the cache tensor shapes in this module.
BATCH_SIZE = 1
33
+
34
+
35
@dataclasses.dataclass
class KVCacheEntryBase:
  """One layer's K and V cache tensors.

  When built through `from_model_config`, both caches are zero tensors of
  shape (batch_size=1, kv_cache_max, num_query_groups, head_dim).
  """

  k_cache: torch.Tensor
  v_cache: torch.Tensor

  @classmethod
  def _from_model_config(
      cls,
      kv_cache_max: int,
      config: model_config.AttentionConfig,
      k_shape: Tuple,
      v_shape: Tuple,
      dtype: torch.dtype = torch.float32,
      device: torch.device = None,
  ) -> "KVCacheEntryBase":
    """Allocate zero-filled K and V caches with the given shapes.

    `kv_cache_max` and `config` are accepted but not read here; callers encode
    them into `k_shape` and `v_shape`.
    """
    return cls(
        k_cache=torch.zeros(k_shape, dtype=dtype, device=device),
        v_cache=torch.zeros(v_shape, dtype=dtype, device=device),
    )

  @classmethod
  def from_model_config(
      cls,
      kv_cache_max: int,
      config: model_config.AttentionConfig,
      dtype: torch.dtype = torch.float32,
      device: torch.device = None,
  ) -> "KVCacheEntryBase":
    """Build an entry whose K and V caches share one shape from the config."""
    cache_shape = (
        BATCH_SIZE,
        kv_cache_max,
        config.num_query_groups,
        config.head_dim,
    )
    return cls._from_model_config(
        kv_cache_max, config, cache_shape, cache_shape, dtype, device
    )
75
+
76
+
77
@dataclasses.dataclass
class KVCacheEntryBTNH(KVCacheEntryBase):
  """Cache entry whose K and V layouts are both tagged as `types.BTNH`."""

  # Unannotated class attributes, so the dataclass decorator does not turn them
  # into fields; they only tag the layout of k_cache / v_cache.
  k_type = types.BTNH()
  v_type = types.BTNH()
81
+
82
+
83
@dataclasses.dataclass
class KVCacheEntryTransposed(KVCacheEntryBase):
  """Cache entry with K tagged as `types.BNTH` and V tagged as `types.BNHT`."""

  # Layout tags (class attributes, not dataclass fields).
  k_type = types.BNTH()
  v_type = types.BNHT()

  @classmethod
  def from_model_config(
      cls,
      kv_cache_max: int,
      config: model_config.AttentionConfig,
      dtype: torch.dtype = torch.float32,
      device: torch.device = None,
  ) -> "KVCacheEntryBase":
    """Build an entry with transposed K/V shapes derived from the config."""
    # Batch and KV heads are folded into a single axis behind a leading 1.
    folded_heads = BATCH_SIZE * config.num_query_groups
    head_dim = config.head_dim
    return cls._from_model_config(
        kv_cache_max,
        config,
        (1, folded_heads, kv_cache_max, head_dim),  # 1, bk, s, h
        (1, folded_heads, head_dim, kv_cache_max),  # 1, bk, h, s
        dtype,
        device,
    )
114
+
115
+
116
@dataclasses.dataclass
class KVCacheBase:
  """Container holding one KV cache entry per transformer layer."""

  caches: Tuple[KVCacheEntryBase, ...]

  @classmethod
  def _from_model_config(
      cls,
      kv_entry_cls,
      config: model_config.ModelConfig,
      dtype: torch.dtype = torch.float32,
      device: torch.device = None,
  ) -> "KVCacheBase":
    """Create one `kv_entry_cls` entry for each layer of the model."""
    entries = tuple(
        kv_entry_cls.from_model_config(
            config.kv_cache_max,
            config.block_config(layer).attn_config,
            dtype,
            device,
        )
        for layer in range(config.num_layers)
    )
    return cls(caches=entries)

  @classmethod
  def from_model_config(
      cls,
      config: model_config.ModelConfig,
      dtype: torch.dtype = torch.float32,
      device: torch.device = None,
  ) -> "KVCacheBase":
    """Create a cache made of `KVCacheEntryBase` entries.

    Args:
      config (ModelConfig): Model config used for building the cache.
      dtype (torch.dtype, optional): The data type of the cache tensor.
        Defaults to torch.float32.
      device (torch.device, optional): The device placement of the cache
        tensors. Defaults to None.

    Returns:
      KVCacheBase: The created cache object.
    """
    return cls._from_model_config(KVCacheEntryBase, config, dtype, device)

  def flatten(self) -> List[torch.Tensor]:
    """Return the cache tensors as a flat list ordered k_0, v_0, k_1, v_1, ..."""
    return _flatten_kvc(self)[0]
169
+
170
+
171
@dataclasses.dataclass
class KVCacheBTNH(KVCacheBase):
  """KV cache whose per-layer entries are `KVCacheEntryBTNH`."""

  @classmethod
  def from_model_config(
      cls,
      config: model_config.ModelConfig,
      dtype: torch.dtype = torch.float32,
      device: torch.device = None,
  ) -> "KVCacheBTNH":
    """Build the cache with one `KVCacheEntryBTNH` per layer."""
    return cls._from_model_config(KVCacheEntryBTNH, config, dtype, device)
184
+
185
+
186
@dataclasses.dataclass
class KVCacheTransposed(KVCacheBase):
  """KV cache whose per-layer entries are `KVCacheEntryTransposed`."""

  @classmethod
  def from_model_config(
      cls,
      config: model_config.ModelConfig,
      dtype: torch.dtype = torch.float32,
      device: torch.device = None,
  ) -> "KVCacheTransposed":
    """Build the cache with one `KVCacheEntryTransposed` per layer.

    Args:
      config (ModelConfig): Model config used for building the cache.
      dtype (torch.dtype, optional): The data type of the cache tensors.
        Defaults to torch.float32.
      device (torch.device, optional): The device placement of the cache
        tensors. Defaults to None.

    Returns:
      KVCacheTransposed: The created cache object.
    """
    return cls._from_model_config(
        KVCacheEntryTransposed, config=config, dtype=dtype, device=device
    )
199
+
200
+
201
def _flatten_kvc(
    kvc: "KVCacheBase",
) -> Tuple[List[torch.Tensor], List[List[str]]]:
  """Flatten a KV cache into its tensors plus a naming context.

  Args:
    kvc (KVCacheBase): the cache whose `caches` entries are flattened.

  Returns:
    A pair `(flattened, context)`. `flattened` lists the cache tensors in the
    order k_0, v_0, k_1, v_1, ...; `context` is `[flat_names, none_names]`
    where `flat_names` holds the matching "k_i"/"v_i" names and `none_names`
    is always empty.
  """
  flattened = []
  flat_names = []
  for i, kv_entry in enumerate(kvc.caches):
    flattened.extend((kv_entry.k_cache, kv_entry.v_cache))
    flat_names.extend((f"k_{i}", f"v_{i}"))
  # Nothing is ever recorded here, but the two-element context shape is what
  # the unflatten side indexes into, so it is preserved.
  none_names: List[str] = []
  return flattened, [flat_names, none_names]
211
+
212
+
213
def _flatten_kvc_with_keys(kvc: KVCacheBase) -> Tuple[List, List]:
  """Flatten a KV cache into (MappingKey, tensor) pairs plus the flat names."""
  tensors, (names, _unused_none_names) = _flatten_kvc(kvc)
  keyed_tensors = [
      (pytree.MappingKey(name), tensor) for name, tensor in zip(names, tensors)
  ]
  return keyed_tensors, names
218
+
219
+
220
def _unflatten_kvc(
    values: List[torch.Tensor], context: Tuple[List, List]
) -> KVCacheBase:
  """Rebuild a KV cache from flat tensors and the flatten-time name context.

  `context[0]` is the list of "k_i"/"v_i" names; each name's position in that
  list gives the position of the corresponding tensor in `values`.
  """
  # NOTE(review): this always produces KVCacheBase / KVCacheEntryBase, even
  # though it is also registered as the unflatten function for
  # KVCacheTransposed — confirm that dropping the subclass is intended.
  assert len(values) % 2 == 0, "Found odd number of K and V entries."
  names = context[0]
  entries = tuple(
      KVCacheEntryBase(
          k_cache=values[names.index(f"k_{layer}")],
          v_cache=values[names.index(f"v_{layer}")],
      )
      for layer in range(len(values) // 2)
  )
  return KVCacheBase(entries)
237
+
238
+
239
# Register the cache containers as pytree nodes so they can be flattened into
# their K/V tensors and rebuilt from them.
# NOTE(review): both registrations pass the same empty serialized_type_name and
# share `_unflatten_kvc`, which rebuilds a KVCacheBase — confirm both are
# intended.
pytree.register_pytree_node(
    KVCacheTransposed,
    _flatten_kvc,
    _unflatten_kvc,
    flatten_with_keys_fn=_flatten_kvc_with_keys,
    serialized_type_name="",
)

pytree.register_pytree_node(
    KVCacheBase,
    _flatten_kvc,
    _unflatten_kvc,
    flatten_with_keys_fn=_flatten_kvc_with_keys,
    serialized_type_name="",
)
254
+
255
+
256
def update(
    cache: KVCacheEntryBase,
    input_pos: torch.Tensor,
    k_slice: torch.Tensor,
    v_slice: torch.Tensor,
    use_dus: bool = True,
) -> KVCacheEntryBase:
  """Return a new cache entry with the K/V slices written in (out of place).

  Args:
    cache (KVCacheEntryBase): The original cache buffer.
    input_pos (torch.Tensor): The update slice positions.
    k_slice (torch.Tensor): The K slice to be updated in the new cache.
    v_slice (torch.Tensor): The V slice to be updated in the new cache.
    use_dus (bool): Accepted but not read by this function.

  Returns:
    KVCacheEntryBase: The updated KVCacheBase entry based on the passed
    inputs.
  """
  return _update_kv_impl(cache, input_pos, k_slice, v_slice)
277
+
278
+
279
def _get_slice_indices(
    positions: torch.Tensor, cache_dim: int, ts_idx: int
) -> torch.Tensor:
  """Build the start indices for a slice update along a single axis.

  Args:
    positions (torch.Tensor): position tensor; only its first element is used.
    cache_dim (int): number of axes of the cache, i.e. length of the result.
    ts_idx (int): the axis that receives the position; all others get 0.

  Returns:
    torch.Tensor: an int32 tensor of length `cache_dim` holding
    `positions[0]` at `ts_idx` and zeros elsewhere.
  """
  # The index vector is assembled in float32 and cast to int32 at the end,
  # exactly as before.
  start = positions.float()[0].reshape(1)
  # Allocate the filler on the same device as the position so torch.cat does
  # not fail with a device mismatch when positions are not on the default
  # device.
  zero = torch.zeros((1,), dtype=torch.float32, device=start.device)
  parts = [start if axis == ts_idx else zero for axis in range(cache_dim)]
  return torch.cat(parts, dim=0).int()
297
+
298
+
299
def _update_kv_impl(
    cache: KVCacheEntryTransposed,
    input_pos: torch.Tensor,
    k_slice: torch.Tensor,
    v_slice: torch.Tensor,
) -> KVCacheEntryTransposed:
  """Write the K/V slices into the cache tensors with dynamic_update_slice.

  The slice start is `input_pos[0]` along axis 2 for K and axis 3 for V, and 0
  on every other axis; the result is wrapped in a new KVCacheEntryTransposed.
  """
  cache_rank = 4
  pos = input_pos.clone()
  k_start = _get_slice_indices(pos, cache_rank, 2)
  v_start = _get_slice_indices(pos, cache_rank, 3)
  # dynamic_update_slice is handed one scalar index tensor per axis.
  new_k = dynamic_update_slice(
      cache.k_cache, k_slice, [idx for idx in k_start]
  )
  new_v = dynamic_update_slice(
      cache.v_cache, v_slice, [idx for idx in v_start]
  )
  return KVCacheEntryTransposed(new_k, new_v)
@@ -0,0 +1,97 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # Implements scaled dot product attention. This is experimental and
16
+ # GPU-specific code.
17
+
18
+ import math
19
+ from typing import Optional
20
+
21
+ from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
22
+ from ai_edge_torch.generative.layers.experimental import types
23
+ from ai_edge_torch.generative.utilities import bmm_4d as bmm_lib
24
+ from ai_edge_torch.hlfb import StableHLOCompositeBuilder
25
+ from multipledispatch import dispatch
26
+ import torch
27
+ import torch.nn.functional as F
28
+
29
+
30
def scaled_dot_product_attention(
    kv: "Optional[kv_utils.KVCacheEntryBase]",
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    head_size: int,
    mask: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    softcap: Optional[float] = None,
):
  """Run the SDPA implementation matching the cache entry's K/V layout tags.

  Args:
    kv: the cache entry whose `k_type` / `v_type` attributes select the
      implementation.
    query (torch.Tensor): the query tensor.
    key (torch.Tensor): the key tensor.
    value (torch.Tensor): the value tensor.
    head_size (int): the head dimension, forwarded to the implementation.
    mask (torch.Tensor): the optional mask tensor.
    scale (float): the optional scale, forwarded to the implementation.
    softcap (float): the optional logit soft cap.

  Returns:
    Whatever the selected `_sdpa` implementation returns.

  Raises:
    ValueError: if `kv` carries no `k_type` / `v_type` tags (including when it
      is None), or if no implementation is registered for the tag pair.
  """
  if hasattr(kv, "k_type") and hasattr(kv, "v_type"):
    return _sdpa(
        kv.k_type,
        kv.v_type,
        query=query,
        key=key,
        value=value,
        head_size=head_size,
        mask=mask,
        scale=scale,
        softcap=softcap,
    )
  # Read the tags defensively: `kv` may be None or an untagged entry, and the
  # error report itself must not fail with an AttributeError.
  k_type = getattr(kv, "k_type", None)
  v_type = getattr(kv, "v_type", None)
  raise ValueError(
      f"SDPA for K type {type(k_type)} and V type"
      f" {type(v_type)} not supported."
  )
56
+
57
+
58
@dispatch(types.BNTH, types.BNHT)
def _sdpa(k_type, v_type, *args, **kwargs):
  """SDPA for a BNTH-tagged key and a BNHT-tagged value.

  Reads `query`, `key`, `value` and `head_size` (required) and `mask`,
  `scale`, `softcap` (optional) from keyword arguments. A mask is mandatory;
  its third axis is used as the query length `t`.
  """
  q = kwargs["query"]
  k = kwargs["key"]
  v = kwargs["value"]
  head_size = kwargs["head_size"]
  mask = kwargs.get("mask")
  scale = kwargs.get("scale")
  softcap = kwargs.get("softcap")

  # Default to the usual 1/sqrt(head_size) scaling.
  if scale is None:
    scale = 1.0 / math.sqrt(head_size)
  q = q * scale

  assert mask is not None, "Mask should not be None!"
  t = mask.shape[2]

  scores = bmm_lib.bmm_4d(q, k)
  _, bk, gt, s = scores.shape
  # Unfold the fused (g * t) axis so the mask can be added per query position.
  scores = scores.reshape((bk, gt // t, t, s))
  if softcap is not None:
    scores = torch.tanh(scores / softcap)
    scores = scores * softcap

  masked = (scores + mask).reshape(1, bk, gt, s)
  probs = F.softmax(masked, dim=-1).type_as(k)

  return bmm_lib.bmm_4d(probs, v)  # 1, bk, gt, h
92
+
93
+
94
@dispatch(object, object)
def _sdpa(k_type, v_type, *args, **kwargs):
  """Catch-all overload: reject layout pairs with no dedicated implementation."""
  message = f"No implementations for k={k_type} and v={v_type}"
  raise ValueError(message)