ai-edge-torch-nightly 0.5.0.dev20250423__py3-none-any.whl → 0.5.0.dev20250425__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (25)
  1. ai_edge_torch/_convert/conversion.py +1 -3
  2. ai_edge_torch/_convert/fx_passes/__init__.py +0 -1
  3. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +63 -2
  4. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +2 -1
  5. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +38 -4
  6. ai_edge_torch/generative/examples/deepseek/deepseek.py +1 -0
  7. ai_edge_torch/generative/examples/gemma3/decoder.py +1 -1
  8. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +37 -2
  9. ai_edge_torch/generative/examples/qwen/qwen.py +1 -0
  10. ai_edge_torch/generative/layers/attention.py +4 -18
  11. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +51 -0
  12. ai_edge_torch/generative/layers/sdpa_with_kv_update.py +38 -44
  13. ai_edge_torch/generative/test/test_model_conversion.py +38 -33
  14. ai_edge_torch/generative/test/test_model_conversion_large.py +3 -75
  15. ai_edge_torch/generative/utilities/converter.py +5 -0
  16. ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +2 -0
  17. ai_edge_torch/version.py +1 -1
  18. {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/METADATA +1 -1
  19. {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/RECORD +22 -25
  20. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +0 -129
  21. ai_edge_torch/generative/layers/experimental/attention.py +0 -231
  22. ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +0 -93
  23. {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/LICENSE +0 -0
  24. {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/WHEEL +0 -0
  25. {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/layers/experimental/attention.py (deleted)
@@ -1,231 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""Common building blocks for a GPU-specific Attention layer.
-
-This is a temporary implementation for the GPU. It is subject to change/removal
-at any time.
-"""
-
-from typing import Optional, Tuple, Union
-
-from ai_edge_torch.generative.layers import builder
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.layers import lora as lora_utils
-from ai_edge_torch.generative.layers import sdpa_with_kv_update
-import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
-import torch
-from torch import nn
-
-
-class TransformerBlock(nn.Module):
-
-  def __init__(
-      self,
-      config: cfg.TransformerBlockConfig,
-      model_config: cfg.ModelConfig,
-  ) -> None:
-    """Initialize an instance of the TransformerBlock.
-
-    Args:
-      config (cfg.TransformerBlockConfig): the configuration object for this
-        transformer block.
-      model_config (cfg.ModelConfig): the configuration object for the model
-        this transformer block belongs to.
-    """
-    super().__init__()
-    self.pre_atten_norm = builder.build_norm(
-        model_config.embedding_dim,
-        config.pre_attention_norm_config,
-    )
-    self.atten_func = CausalSelfAttention(
-        model_config.embedding_dim,
-        config.attn_config,
-        model_config.enable_hlfb,
-    )
-    self.post_atten_norm = builder.build_norm(
-        model_config.embedding_dim,
-        config.post_attention_norm_config,
-    )
-    self.ff = builder.build_ff(model_config.embedding_dim, config.ff_config)
-    self.config = config
-
-  def forward(
-      self,
-      x: torch.Tensor,
-      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-      mask: Optional[torch.Tensor] = None,
-      input_pos: Optional[torch.Tensor] = None,
-      kv_cache: kv_utils.KVCacheEntry = None,
-      lora: Optional[lora_utils.LoRAEntry] = None,
-  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
-    """Forward function of the TransformerBlock.
-
-    Args:
-      x (torch.Tensor): the input tensor.
-      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
-      mask (torch.Tensor): the optional mask tensor.
-      input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry): the optional kv cache entry.
-      lora (LoRAEntry): the optional lora entry.
-
-    Returns:
-      output activation from this transformer block, and updated kv cache (if
-      passed in).
-    """
-    kv = None
-    if self.config.parallel_residual:
-      x_norm = self.pre_atten_norm(x)
-      atten_func_out = self.atten_func(
-          x_norm, rope, mask, input_pos, kv_cache, lora
-      )
-      if kv_cache is None:
-        attn_out = atten_func_out
-      else:
-        attn_out, kv = atten_func_out
-      ff_out = self.ff(x_norm)
-      output = x + attn_out + ff_out
-    else:
-      x_norm = self.pre_atten_norm(x)
-      atten_func_out = self.atten_func(
-          x_norm, rope, mask, input_pos, kv_cache, lora
-      )
-      if kv_cache is None:
-        attn_out = atten_func_out
-      else:
-        attn_out, kv = atten_func_out
-      x = x + attn_out
-      x_norm = self.post_atten_norm(x)
-      output = x + self.ff(x_norm)
-
-    return output if kv is None else (output, kv)
-
-
-class CausalSelfAttention(nn.Module):
-
-  def __init__(
-      self,
-      dim: int,
-      config: cfg.AttentionConfig,
-      enable_hlfb: bool,
-  ) -> None:
-    """Initialize an instance of CausalSelfAttention.
-
-    Args:
-      dim (int): causal attention's input/output dimension.
-      config (cfg.AttentionConfig): attention specific configurations.
-      enable_hlfb (bool): whether hlfb is enabled or not.
-    """
-    super().__init__()
-    self.kv_cache = None
-    qkv_shape = (
-        config.num_heads + 2 * config.num_query_groups
-    ) * config.head_dim
-    output_shape = config.num_heads * config.head_dim
-    # Key, query, value projections for all heads.
-    self.qkv_projection = nn.Linear(dim, qkv_shape, bias=config.qkv_use_bias)
-    self.output_projection = nn.Linear(
-        output_shape, dim, bias=config.output_proj_use_bias
-    )
-    self.query_norm = builder.build_norm(
-        config.head_dim, config.query_norm_config
-    )
-    self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
-    self.config = config
-    self.enable_hlfb = enable_hlfb
-
-  def forward(
-      self,
-      x: torch.Tensor,
-      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-      mask: Optional[torch.Tensor] = None,
-      input_pos: Optional[torch.Tensor] = None,
-      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
-      lora: Optional[lora_utils.LoRAEntry] = None,
-  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
-    """Forward function of the CausalSelfAttention layer, which can support
-
-    MQA, GQA and MHA.
-
-    Args:
-      x (torch.Tensor): the input tensor.
-      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
-      mask (torch.Tensor): the optional mask tensor.
-      input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
-      lora (LoRAEntry): the optional lora entry.
-
-    Returns:
-      output activation from this self attention layer, and the updated
-      KV Cache Entry (if passed in).
-    """
-    # Batch size, sequence length, embedding dimensionality.
-    B, T, E = x.size()
-
-    qkv = self.qkv_projection(x)
-
-    # Assemble into a number of query groups to support MHA, MQA and GQA.
-    q_per_kv = self.config.num_heads // self.config.num_query_groups
-    # Each group has >=1 queries, 1 key, and 1 value.
-    if self.config.qkv_transpose_before_split:
-      qkv = qkv.view(B, T, -1, self.config.head_dim)
-      q, k, v = qkv.split(
-          (
-              q_per_kv * self.config.num_query_groups,
-              self.config.num_query_groups,
-              self.config.num_query_groups,
-          ),
-          dim=-2,
-      )
-    else:
-      qkv = qkv.view(B, T, self.config.num_query_groups, -1)
-      q, k, v = qkv.split(
-          (
-              q_per_kv * self.config.head_dim,
-              self.config.head_dim,
-              self.config.head_dim,
-          ),
-          dim=-1,
-      )
-
-    if lora is not None:
-      q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
-      k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
-      v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
-
-    q = self.query_norm(q)
-    k = self.key_norm(k)
-
-    q = q.reshape(B, T, -1, self.config.head_dim)
-    k = k.reshape(B, T, -1, self.config.head_dim)
-    v = v.reshape(B, T, -1, self.config.head_dim)
-
-    if rope is not None:
-      # Compute rotary positional embedding for query and key.
-      n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      cos, sin = rope
-      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
-
-    sdpa_out, kv_cache = sdpa_with_kv_update.sdpa_with_kv_update(
-        q, k, v, kv_cache, input_pos, mask, self.config
-    )
-
-    # Compute the output projection.
-    y = self.output_projection(sdpa_out)
-    if lora is not None:
-      y += lora_utils.apply_lora(sdpa_out, lora.attention.output)
-
-    return y if kv_cache is None else (y, kv_cache)
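
For reference, the removed TransformerBlock.forward wires its residual connections in one of two ways, selected by config.parallel_residual. The sketch below summarizes that control flow in plain PyTorch; pre_norm, post_norm, attn, and ff are placeholder callables standing in for the builder-made submodules, and the KV-cache/LoRA plumbing is omitted.

import torch
from torch import nn


def block_forward(
    x: torch.Tensor,
    pre_norm: nn.Module,
    post_norm: nn.Module,
    attn: nn.Module,
    ff: nn.Module,
    parallel_residual: bool,
) -> torch.Tensor:
  """Sketch of the two residual wirings used by the removed block."""
  if parallel_residual:
    # Attention and feed-forward both read the same pre-normed input and are
    # added into a single residual sum.
    x_norm = pre_norm(x)
    return x + attn(x_norm) + ff(x_norm)
  # Sequential: attention residual first, then a post-norm feed-forward
  # residual.
  x = x + attn(pre_norm(x))
  return x + ff(post_norm(x))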

ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py (deleted)
@@ -1,93 +0,0 @@
-# Copyright 2025 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-# Implements scaled dot product attention. This is experimental and
-# GPU-specific code.
-
-import math
-from typing import Optional
-
-from ai_edge_torch.generative.custom_ops import bmm_4d as bmm_lib
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.utilities import types
-from ai_edge_torch.hlfb import StableHLOCompositeBuilder
-from multipledispatch import dispatch
-import torch
-import torch.nn.functional as F
-
-
-def scaled_dot_product_attention(
-    kv: kv_utils.KVCacheEntry,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    head_size: int,
-    mask: Optional[torch.Tensor] = None,
-    scale: Optional[float] = None,
-    softcap: Optional[float] = None,
-):
-  if hasattr(kv, "kv_layout"):
-    return _sdpa(
-        kv.kv_layout[0](),  # key layout
-        kv.kv_layout[1](),  # value layout
-        query=query,
-        key=key,
-        value=value,
-        head_size=head_size,
-        mask=mask,
-        scale=scale,
-        softcap=softcap,
-    )
-  raise ValueError("No kv_layout attribute found in kv.")
-
-
-@dispatch(types.BNTH, types.BNHT)
-def _sdpa(k_type, v_type, *args, **kwargs):
-  query = kwargs["query"]
-  key = kwargs["key"]
-  value = kwargs["value"]
-  head_size = kwargs["head_size"]
-  mask = kwargs.get("mask", None)
-  scale = kwargs.get("scale", None)
-  softcap = kwargs.get("softcap", None)
-
-  if scale is None:
-    scale = 1.0 / math.sqrt(head_size)
-
-  query = query * scale
-
-  assert mask is not None, "Mask should not be None!"
-  t = mask.shape[2]
-
-  logits = bmm_lib.bmm_4d(query, key)
-
-  _, bk, gt, s = logits.shape
-  g = gt // t
-  logits = logits.reshape((bk, g, t, s))
-  if softcap is not None:
-    logits = torch.tanh(logits / softcap)
-    logits = logits * softcap
-
-  padded_logits = logits + mask
-  padded_logits = padded_logits.reshape(1, bk, gt, s)
-  probs = F.softmax(padded_logits, dim=-1).type_as(key)
-  encoded = bmm_lib.bmm_4d(probs, value)
-
-  return encoded  # 1, bk, gt, h
-
-
-@dispatch(object, object)
-def _sdpa(k_type, v_type, *args, **kwargs):
-
-  raise ValueError(f"No implementations for k={k_type} and v={v_type}")
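
The removed _sdpa specialization dispatches on the cache layouts named in the source (types.BNTH keys, types.BNHT values) and uses the custom bmm_4d op: it scales the query, forms query-key logits, optionally applies a tanh soft cap, adds the mask, softmaxes, and contracts with the values. The sketch below restates that math in plain PyTorch, assuming ordinary (B, N, T, H) tensors and torch.matmul rather than the layout-specific custom op; it is illustrative, not a drop-in replacement for the removed code.

import math
from typing import Optional

import torch
import torch.nn.functional as F


def sdpa_reference(
    query: torch.Tensor,  # (B, N, T, H)
    key: torch.Tensor,  # (B, N, S, H)
    value: torch.Tensor,  # (B, N, S, H)
    mask: torch.Tensor,  # additive mask broadcastable to (B, N, T, S)
    scale: Optional[float] = None,
    softcap: Optional[float] = None,
) -> torch.Tensor:
  """Reference-only sketch of scale -> logits -> softcap -> mask -> softmax."""
  head_size = query.shape[-1]
  if scale is None:
    scale = 1.0 / math.sqrt(head_size)
  logits = torch.matmul(query * scale, key.transpose(-2, -1))  # (B, N, T, S)
  if softcap is not None:
    # Soft capping squashes the logits into (-softcap, softcap) via tanh.
    logits = torch.tanh(logits / softcap) * softcap
  probs = F.softmax(logits + mask, dim=-1).type_as(key)
  return torch.matmul(probs, value)  # (B, N, T, H)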