ai-edge-torch-nightly 0.3.0.dev20250123__py3-none-any.whl → 0.3.0.dev20250125__py3-none-any.whl

Files changed (28)
  1. ai_edge_torch/_config.py +9 -0
  2. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +11 -8
  3. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +22 -24
  4. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +3 -4
  5. ai_edge_torch/generative/examples/deepseek/__init__.py +14 -0
  6. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +80 -0
  7. ai_edge_torch/generative/examples/deepseek/deepseek.py +92 -0
  8. ai_edge_torch/generative/examples/deepseek/verify.py +70 -0
  9. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +3 -0
  10. ai_edge_torch/generative/layers/experimental/__init__.py +14 -0
  11. ai_edge_torch/generative/layers/experimental/attention.py +269 -0
  12. ai_edge_torch/generative/layers/experimental/kv_cache.py +314 -0
  13. ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +97 -0
  14. ai_edge_torch/generative/layers/experimental/types.py +97 -0
  15. ai_edge_torch/generative/layers/kv_cache.py +2 -1
  16. ai_edge_torch/generative/layers/model_config.py +5 -1
  17. ai_edge_torch/generative/test/test_model_conversion_large.py +11 -2
  18. ai_edge_torch/generative/utilities/bmm_4d.py +76 -0
  19. ai_edge_torch/generative/utilities/converter.py +18 -2
  20. ai_edge_torch/generative/utilities/model_builder.py +6 -1
  21. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +1 -1
  22. ai_edge_torch/quantize/pt2e_quantizer_utils.py +22 -2
  23. ai_edge_torch/version.py +1 -1
  24. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/METADATA +1 -1
  25. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/RECORD +28 -18
  26. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/LICENSE +0 -0
  27. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/WHEEL +0 -0
  28. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/experimental/attention.py
@@ -0,0 +1,269 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Common building blocks for a GPU-specific Attention layer.
+
+ This is a temporary implementation for the GPU. It is subject to change/removal
+ at any time.
+ """
+
+ from typing import Optional, Tuple, Union
+
+ from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import lora as lora_utils
+ from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
+ from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
+ import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
+ import torch
+ from torch import nn
+
+
+ class TransformerBlock(nn.Module):
+
+   def __init__(
+       self,
+       config: cfg.TransformerBlockConfig,
+       model_config: cfg.ModelConfig,
+   ) -> None:
+     """Initialize an instance of the TransformerBlock.
+
+     Args:
+       config (cfg.TransformerBlockConfig): the configuration object for this
+         transformer block.
+       model_config (cfg.ModelConfig): the configuration object for the model
+         this transformer block belongs to.
+     """
+     super().__init__()
+     self.pre_atten_norm = builder.build_norm(
+         model_config.embedding_dim,
+         config.pre_attention_norm_config,
+     )
+     self.atten_func = CausalSelfAttention(
+         model_config.batch_size,
+         model_config.embedding_dim,
+         config.attn_config,
+         model_config.enable_hlfb,
+     )
+     self.post_atten_norm = builder.build_norm(
+         model_config.embedding_dim,
+         config.post_attention_norm_config,
+     )
+     self.ff = builder.build_ff(model_config.embedding_dim, config.ff_config)
+     self.config = config
+
+   def forward(
+       self,
+       x: torch.Tensor,
+       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+       mask: Optional[torch.Tensor] = None,
+       input_pos: Optional[torch.Tensor] = None,
+       kv_cache: Optional[kv_utils.KVCacheEntryBase] = None,
+       lora: Optional[lora_utils.LoRAEntry] = None,
+   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntryBase]]:
+     """Forward function of the TransformerBlock.
+
+     Args:
+       x (torch.Tensor): the input tensor.
+       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
+       mask (torch.Tensor): the optional mask tensor.
+       input_pos (torch.Tensor): the optional input position tensor.
+       kv_cache (KVCacheEntryBase): the optional kv cache entry.
+       lora (LoRAEntry): the optional lora entry.
+
+     Returns:
+       output activation from this transformer block, and updated kv cache (if
+       passed in).
+     """
+     kv = None
+     if self.config.parallel_residual:
+       x_norm = self.pre_atten_norm(x)
+       atten_func_out = self.atten_func(
+           x_norm, rope, mask, input_pos, kv_cache, lora
+       )
+       if kv_cache is None:
+         attn_out = atten_func_out
+       else:
+         attn_out, kv = atten_func_out
+       ff_out = self.ff(x_norm)
+       output = x + attn_out + ff_out
+     else:
+       x_norm = self.pre_atten_norm(x)
+       atten_func_out = self.atten_func(
+           x_norm, rope, mask, input_pos, kv_cache, lora
+       )
+       if kv_cache is None:
+         attn_out = atten_func_out
+       else:
+         attn_out, kv = atten_func_out
+       x = x + attn_out
+       x_norm = self.post_atten_norm(x)
+       output = x + self.ff(x_norm)
+
+     return output if kv is None else (output, kv)
+
+
+ class CausalSelfAttention(nn.Module):
+
+   def __init__(
+       self,
+       batch_size: int,
+       dim: int,
+       config: cfg.AttentionConfig,
+       enable_hlfb: bool,
+   ) -> None:
+     """Initialize an instance of CausalSelfAttention.
+
+     Args:
+       batch_size (int): batch size of the input tensor.
+       dim (int): causal attention's input/output dimension.
+       config (cfg.AttentionConfig): attention specific configurations.
+       enable_hlfb (bool): whether hlfb is enabled or not.
+     """
+     super().__init__()
+     self.kv_cache = None
+     self.batch_size = batch_size
+     qkv_shape = (
+         config.num_heads + 2 * config.num_query_groups
+     ) * config.head_dim
+     output_shape = config.num_heads * config.head_dim
+     # Key, query, value projections for all heads.
+     self.qkv_projection = nn.Linear(dim, qkv_shape, bias=config.qkv_use_bias)
+     self.output_projection = nn.Linear(
+         output_shape, dim, bias=config.output_proj_use_bias
+     )
+     self.query_norm = builder.build_norm(
+         config.head_dim, config.query_norm_config
+     )
+     self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
+     self.config = config
+     self.enable_hlfb = enable_hlfb
+     self.sdpa_func = sdpa.scaled_dot_product_attention
+
+   def forward(
+       self,
+       x: torch.Tensor,
+       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+       mask: Optional[torch.Tensor] = None,
+       input_pos: Optional[torch.Tensor] = None,
+       kv_cache: Optional[kv_utils.KVCacheEntryBase] = None,
+       lora: Optional[lora_utils.LoRAEntry] = None,
+   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntryBase]]:
+     """Forward function of the CausalSelfAttention layer, which can support
+
+     MQA, GQA and MHA.
+
+     Args:
+       x (torch.Tensor): the input tensor.
+       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
+       mask (torch.Tensor): the optional mask tensor.
+       input_pos (torch.Tensor): the optional input position tensor.
+       kv_cache (KVCacheEntryBase): the KV cache entry corresponding to this
+         module.
+       lora (LoRAEntry): the optional lora entry.
+
+     Returns:
+       output activation from this self attention layer, and the updated
+       KV Cache Entry (if passed in).
+     """
+     # Batch size, sequence length, embedding dimensionality.
+     B, T, E = x.size()
+     assert B == self.batch_size, (
+         "batch size of input tensor must match with the batch size specified in"
+         " the model configuration."
+     )
+
+     qkv = self.qkv_projection(x)
+
+     # Assemble into a number of query groups to support MHA, MQA and GQA.
+     q_per_kv = self.config.num_heads // self.config.num_query_groups
+     # Each group has >=1 queries, 1 key, and 1 value.
+     if self.config.qkv_transpose_before_split:
+       qkv = qkv.view(B, T, -1, self.config.head_dim)
+       q, k, v = qkv.split(
+           (
+               q_per_kv * self.config.num_query_groups,
+               self.config.num_query_groups,
+               self.config.num_query_groups,
+           ),
+           dim=-2,
+       )
+     else:
+       qkv = qkv.view(B, T, self.config.num_query_groups, -1)
+       q, k, v = qkv.split(
+           (
+               q_per_kv * self.config.head_dim,
+               self.config.head_dim,
+               self.config.head_dim,
+           ),
+           dim=-1,
+       )
+
+     if lora is not None:
+       q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
+       k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
+       v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
+
+     q = self.query_norm(q)
+     k = self.key_norm(k)
+
+     q = q.reshape(B, T, -1, self.config.head_dim)
+     k = k.reshape(B, T, -1, self.config.head_dim)
+     v = v.reshape(B, T, -1, self.config.head_dim)
+
+     if rope is not None:
+       # Compute rotary positional embedding for query and key.
+       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
+       cos, sin = rope
+       q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
+
+     # Transpose k/v to specific layout for GPU implementation.
+     b, _, n, h = q.shape
+     g = n // self.config.num_query_groups
+     # btnh -> bnth -> b(kg)th -> 1(bk)(gt)h
+     q = q.permute(0, 2, 1, 3).reshape(
+         1, b * self.config.num_query_groups, g * T, h
+     )
+
+     k = k.permute(0, 2, 1, 3).reshape(
+         1, -1, T, self.config.head_dim
+     )  # 1, bk, s, h
+     v = v.permute(0, 2, 3, 1).reshape(
+         1, -1, self.config.head_dim, T
+     )  # 1, bk, h, s
+
+     if kv_cache is not None:
+       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
+       k, v = kv_cache.k_cache, kv_cache.v_cache
+
+     sdpa_out = self.sdpa_func(
+         kv_cache,
+         q,
+         k,
+         v,
+         self.config.head_dim,
+         mask=mask,
+         softcap=self.config.logit_softcap,
+     )  # 1, bk, gt, h
+     sdpa_out = (
+         sdpa_out.reshape(B, -1, T, h).permute(0, 2, 1, 3).reshape(B, T, -1)
+     )
+
+     # Compute the output projection.
+     y = self.output_projection(sdpa_out)
+     if lora is not None:
+       y += lora_utils.apply_lora(sdpa_out, lora.attention.output)
+
+     return y if kv_cache is None else (y, kv_cache)
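For readers tracing the layout comment in the forward pass above (btnh -> bnth -> b(kg)th -> 1(bk)(gt)h), the following standalone sketch reproduces the same query transform in plain torch. The tensor sizes are made-up assumptions for illustration and are not taken from the package.

import torch

# Hypothetical sizes: batch, sequence length, query heads, head dim.
b, t, n, h = 2, 16, 8, 64
num_query_groups = 2           # "k" in the layout comment
g = n // num_query_groups      # queries per KV group

q = torch.randn(b, t, n, h)                 # btnh
q_gpu = q.permute(0, 2, 1, 3).reshape(      # btnh -> bnth
    1, b * num_query_groups, g * t, h       # bnth -> 1(bk)(gt)h
)
print(q_gpu.shape)  # torch.Size([1, 4, 64, 64])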
ai_edge_torch/generative/layers/experimental/kv_cache.py
@@ -0,0 +1,314 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Utility functions for KV Cache.
+
+ This is an experimental implementation and is subject to change at any time.
+ """
+
+ import dataclasses
+ from typing import List, Tuple
+
+ from ai_edge_torch import hlfb
+ from ai_edge_torch.generative.layers import model_config
+ from ai_edge_torch.generative.layers.experimental import types as types
+ from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
+ import torch
+ import torch.nn as nn
+ import torch.utils._pytree as pytree
+
+ BATCH_SIZE = 1
+
+
+ @dataclasses.dataclass
+ class KVCacheEntryBase:
+   """A single cache entry that includes K and V caches.
+
+   The caches are built based on the provided config with the shape of
+   (batch_size=1, kv_cache_max, num_query_groups, head_dim).
+   """
+
+   k_cache: torch.Tensor
+   v_cache: torch.Tensor
+
+   @classmethod
+   def _from_model_config(
+       cls,
+       kv_cache_max: int,
+       config: model_config.AttentionConfig,
+       k_shape: Tuple,
+       v_shape: Tuple,
+       dtype: torch.dtype = torch.float32,
+       device: torch.device = None,
+   ) -> "KVCacheEntryBase":
+     """Build an instance of the class based on model config."""
+     k = torch.zeros(k_shape, dtype=dtype, device=device)
+     v = torch.zeros(v_shape, dtype=dtype, device=device)
+     obj = cls(k_cache=k, v_cache=v)
+     return obj
+
+   @classmethod
+   def from_model_config(
+       cls,
+       kv_cache_max: int,
+       config: model_config.AttentionConfig,
+       dtype: torch.dtype = torch.float32,
+       device: torch.device = None,
+   ) -> "KVCacheEntryBase":
+     """Build an instance of the class based on model config."""
+     shape = (BATCH_SIZE, kv_cache_max, config.num_query_groups, config.head_dim)
+     return cls._from_model_config(
+         kv_cache_max, config, shape, shape, dtype, device
+     )
+
+
+ @dataclasses.dataclass
+ class KVCacheEntryBTNH(KVCacheEntryBase):
+   k_type = types.BTNH()
+   v_type = types.BTNH()
+
+
+ @dataclasses.dataclass
+ class KVCacheEntryTransposed(KVCacheEntryBase):
+
+   k_type = types.BNTH()
+   v_type = types.BNHT()
+
+   @classmethod
+   def from_model_config(
+       cls,
+       kv_cache_max: int,
+       config: model_config.AttentionConfig,
+       dtype: torch.dtype = torch.float32,
+       device: torch.device = None,
+   ) -> "KVCacheEntryBase":
+     """Build an instance of the class based on model config."""
+     num_kv_heads = config.num_query_groups
+     k_shape = (
+         1,
+         BATCH_SIZE * num_kv_heads,
+         kv_cache_max,
+         config.head_dim,
+     )  # 1, bk, s, h
+     v_shape = (
+         1,
+         BATCH_SIZE * num_kv_heads,
+         config.head_dim,
+         kv_cache_max,
+     )  # 1, bk, h, s
+     return cls._from_model_config(
+         kv_cache_max, config, k_shape, v_shape, dtype, device
+     )
+
+
+ @dataclasses.dataclass
+ class KVCacheBase:
+   """A utility class for holding KV cache entries per layer."""
+
+   caches: Tuple[KVCacheEntryBase, ...]
+
+   @classmethod
+   def _from_model_config(
+       cls,
+       kv_entry_cls,
+       config: model_config.ModelConfig,
+       dtype: torch.dtype = torch.float32,
+       device: torch.device = None,
+   ) -> "KVCacheBase":
+     caches = [
+         kv_entry_cls.from_model_config(
+             config.kv_cache_max,
+             config.block_config(idx).attn_config,
+             dtype,
+             device,
+         )
+         for idx in range(config.num_layers)
+     ]
+     obj = cls(caches=tuple(caches))
+     return obj
+
+   @classmethod
+   def from_model_config(
+       cls,
+       config: model_config.ModelConfig,
+       dtype: torch.dtype = torch.float32,
+       device: torch.device = None,
+   ) -> "KVCacheBase":
+     """Build an instance of the class based on model config.
+
+     Args:
+       config (ModelConfig): Model config used for building the cache.
+       dtype (torch.dtype, optional): The data type of the cache tensor.
+         Defaults to torch.float32.
+       device (torch.device, optional): The device placement of the cache
+         tensors. Defaults to None.
+
+     Returns:
+       KVCacheBase: The created cache object.
+     """
+     return cls._from_model_config(
+         KVCacheEntryBase, config=config, dtype=dtype, device=device
+     )
+
+   def flatten(self) -> List[torch.Tensor]:
+     """Flatten the cache entries into a list of tensors with order k_i, v_i."""
+     flattened, _ = _flatten_kvc(self)
+     return flattened
+
+
+ @dataclasses.dataclass
+ class KVCacheBTNH(KVCacheBase):
+
+   @classmethod
+   def from_model_config(
+       cls,
+       config: model_config.ModelConfig,
+       dtype: torch.dtype = torch.float32,
+       device: torch.device = None,
+   ) -> "KVCacheBTNH":
+     return cls._from_model_config(
+         KVCacheEntryBTNH, config=config, dtype=dtype, device=device
+     )
+
+
+ @dataclasses.dataclass
+ class KVCacheTransposed(KVCacheBase):
+
+   @classmethod
+   def from_model_config(
+       cls,
+       config: model_config.ModelConfig,
+       dtype: torch.dtype = torch.float32,
+       device: torch.device = None,
+   ) -> "KVCacheBTNH":
+     return cls._from_model_config(
+         KVCacheEntryTransposed, config=config, dtype=dtype, device=device
+     )
+
+
+ def _flatten_kvc(kvc: KVCacheBase) -> Tuple[List[str], List[str]]:
+   flattened = []
+   flat_names = []
+   none_names = []
+   for i, kv_entry in enumerate(kvc.caches):
+     flattened.append(kv_entry.k_cache)
+     flat_names.append(f"k_{i}")
+     flattened.append(kv_entry.v_cache)
+     flat_names.append(f"v_{i}")
+   return flattened, [flat_names, none_names]
+
+
+ def _flatten_kvc_with_keys(kvc: KVCacheBase) -> Tuple[List, List]:
+   flattened, (flat_names, none_names) = _flatten_kvc(kvc)
+   return [
+       (pytree.MappingKey(k), v) for k, v in zip(flat_names, flattened)
+   ], flat_names
+
+
+ def _unflatten_kvc(
+     values: List[torch.Tensor], context: Tuple[List, List]
+ ) -> KVCacheBase:
+   assert len(values) % 2 == 0, "Found odd number of K and V entries."
+   num_layers = len(values) // 2
+   flat_names = context[0]
+   kv_entries = []
+   for i in range(num_layers):
+     k_cache_idx = flat_names.index(f"k_{i}")
+     v_cache_idx = flat_names.index(f"v_{i}")
+     kv_entries.append(
+         KVCacheEntryBase(
+             k_cache=values[k_cache_idx], v_cache=values[v_cache_idx]
+         )
+     )
+   obj = KVCacheBase(tuple(kv_entries))
+   return obj
+
+
+ pytree.register_pytree_node(
+     KVCacheTransposed,
+     _flatten_kvc,
+     _unflatten_kvc,
+     flatten_with_keys_fn=_flatten_kvc_with_keys,
+     serialized_type_name="",
+ )
+
+ pytree.register_pytree_node(
+     KVCacheBase,
+     _flatten_kvc,
+     _unflatten_kvc,
+     flatten_with_keys_fn=_flatten_kvc_with_keys,
+     serialized_type_name="",
+ )
+
+
+ def update(
+     cache: KVCacheEntryBase,
+     input_pos: torch.Tensor,
+     k_slice: torch.Tensor,
+     v_slice: torch.Tensor,
+     use_dus: bool = True,
+ ) -> KVCacheEntryBase:
+   """Out of place update of Cache buffer.
+
+   Args:
+     cache (KVCacheEntryBase): The original cache buffer.
+     input_pos (torch.Tensor): The update slice positions.
+     k_slice (torch.Tensor): The K slice to be updated in the new cache.
+     v_slice (torch.Tensor): The V slice to be updated in the new cache.
+
+   Returns:
+     KVCacheEntryBase: The updated KVCacheBase entry based on the passed
+     inputs.
+   """
+   update_kv_cache = _update_kv_impl
+   return update_kv_cache(cache, input_pos, k_slice, v_slice)
+
+
+ def _get_slice_indices(
+     positions: torch.Tensor, cache_dim: int, ts_idx: int
+ ) -> torch.Tensor:
+   """Returns the slice indices."""
+   positions = positions.float()[0].reshape(
+       1,
+   )
+
+   zeros = torch.zeros((1,), dtype=torch.float32)
+   indices = []
+   for i in range(cache_dim):
+     if i == ts_idx:
+       indices.append(positions)
+     else:
+       indices.append(zeros)
+   slice_indices = torch.cat(indices, dim=0)
+   slice_indices = slice_indices.int()
+   return slice_indices
+
+
+ def _update_kv_impl(
+     cache: KVCacheEntryTransposed,
+     input_pos: torch.Tensor,
+     k_slice: torch.Tensor,
+     v_slice: torch.Tensor,
+ ) -> KVCacheEntryTransposed:
+   """Update the cache buffer with High Level Function Boundary annotation."""
+   cache_dim = 4
+   k_ts_idx = 2
+   v_ts_idx = 3
+   positions = input_pos.clone()
+   k_slice_indices = _get_slice_indices(positions, cache_dim, k_ts_idx)
+   v_slice_indices = _get_slice_indices(positions, cache_dim, v_ts_idx)
+   k = dynamic_update_slice(cache.k_cache, k_slice, [x for x in k_slice_indices])
+   v = dynamic_update_slice(cache.v_cache, v_slice, [x for x in v_slice_indices])
+   return KVCacheEntryTransposed(k, v)
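As a rough orientation to the transposed cache layout used above, the pure-torch sketch below shows which axes the K and V slices land on during an update. The sizes are hypothetical, and the in-place indexing merely illustrates the placement that _update_kv_impl performs out of place via dynamic_update_slice.

import torch

# Hypothetical sizes: batch*kv_heads, cache length, head dim, step length.
bk, s_max, h, t = 4, 128, 64, 16
k_cache = torch.zeros(1, bk, s_max, h)   # 1, bk, s, h
v_cache = torch.zeros(1, bk, h, s_max)   # 1, bk, h, s
k_slice = torch.randn(1, bk, t, h)
v_slice = torch.randn(1, bk, h, t)
pos = 32                                 # first element of input_pos

# K is written along axis 2 (sequence), V along axis 3.
k_cache[:, :, pos:pos + t, :] = k_slice
v_cache[:, :, :, pos:pos + t] = v_slice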
ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py
@@ -0,0 +1,97 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # Implements scaled dot product attention. This is experimental and
+ # GPU-specific code.
+
+ import math
+ from typing import Optional
+
+ from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
+ from ai_edge_torch.generative.layers.experimental import types
+ from ai_edge_torch.generative.utilities import bmm_4d as bmm_lib
+ from ai_edge_torch.hlfb import StableHLOCompositeBuilder
+ from multipledispatch import dispatch
+ import torch
+ import torch.nn.functional as F
+
+
+ def scaled_dot_product_attention(
+     kv: kv_utils.KVCacheBase,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     head_size: int,
+     mask: Optional[torch.Tensor] = None,
+     scale: Optional[float] = None,
+     softcap: Optional[float] = None,
+ ):
+   if hasattr(kv, "k_type") and hasattr(kv, "v_type"):
+     return _sdpa(
+         kv.k_type,
+         kv.v_type,
+         query=query,
+         key=key,
+         value=value,
+         head_size=head_size,
+         mask=mask,
+         scale=scale,
+         softcap=softcap,
+     )
+   raise ValueError(
+       f"SDPA for K type {type(kv.caches[0].k_type)} and V type"
+       f" {type(kv.caches[0].v_type)} not supported."
+   )
+
+
+ @dispatch(types.BNTH, types.BNHT)
+ def _sdpa(k_type, v_type, *args, **kwargs):
+   query = kwargs["query"]
+   key = kwargs["key"]
+   value = kwargs["value"]
+   head_size = kwargs["head_size"]
+   mask = kwargs.get("mask", None)
+   scale = kwargs.get("scale", None)
+   softcap = kwargs.get("softcap", None)
+
+   if scale is None:
+     scale = 1.0 / math.sqrt(head_size)
+
+   query = query * scale
+
+   assert mask is not None, "Mask should not be None!"
+   t = mask.shape[2]
+
+   logits = bmm_lib.bmm_4d(query, key)
+
+   _, bk, gt, s = logits.shape
+   g = gt // t
+   logits = logits.reshape((bk, g, t, s))
+   if softcap is not None:
+     logits = torch.tanh(logits / softcap)
+     logits = logits * softcap
+
+   padded_logits = logits + mask
+   padded_logits = padded_logits.reshape(1, bk, gt, s)
+   probs = F.softmax(padded_logits, dim=-1).type_as(key)
+
+   encoded = bmm_lib.bmm_4d(probs, value)
+
+   return encoded  # 1, bk, gt, h
+
+
+ @dispatch(object, object)
+ def _sdpa(k_type, v_type, *args, **kwargs):
+
+   raise ValueError(f"No implementations for k={k_type} and v={v_type}")