ai-edge-torch-nightly 0.5.0.dev20250423__py3-none-any.whl → 0.5.0.dev20250424__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma3/decoder.py +1 -1
- ai_edge_torch/generative/layers/attention.py +4 -18
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py +36 -40
- ai_edge_torch/generative/test/test_model_conversion.py +44 -9
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250424.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250424.dist-info}/RECORD +10 -11
- ai_edge_torch/generative/layers/experimental/attention.py +0 -231
- {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250424.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250424.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250424.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma3/decoder.py
CHANGED
@@ -17,10 +17,10 @@
 
 from typing import List, Optional, Tuple
 
+from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-from ai_edge_torch.generative.layers.experimental import attention
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import export_config as export_cfg
ai_edge_torch/generative/layers/attention.py
CHANGED
@@ -21,6 +21,7 @@ from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import lora as lora_utils
 from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
+from ai_edge_torch.generative.layers import sdpa_with_kv_update
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import torch
@@ -142,11 +143,6 @@ class CausalSelfAttention(nn.Module):
     self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
     self.config = config
     self.enable_hlfb = enable_hlfb
-    self.sdpa_func = (
-        sdpa.scaled_dot_product_attention_with_hlfb
-        if enable_hlfb
-        else sdpa.scaled_dot_product_attention
-    )
 
   def forward(
       self,
@@ -174,7 +170,7 @@ class CausalSelfAttention(nn.Module):
       KV Cach Entry (if passed in).
     """
     # Batch size, sequence length, embedding dimensionality.
-    B, T, E = x.size()
+    B, T, _ = x.size()
     qkv = self.qkv_projection(x)
 
     # Assemble into a number of query groups to support MHA, MQA and GQA.
@@ -218,19 +214,9 @@ class CausalSelfAttention(nn.Module):
       cos, sin = rope
       q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
 
-    if kv_cache is not None:
-      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
-      k, v = kv_cache.k_cache, kv_cache.v_cache
-
-    sdpa_out = self.sdpa_func(
-        q,
-        k,
-        v,
-        self.config.head_dim,
-        mask=mask,
-        softcap=self.config.logit_softcap,
+    sdpa_out, kv_cache = sdpa_with_kv_update.sdpa_with_kv_update(
+        q, k, v, kv_cache, input_pos, mask, self.config, self.enable_hlfb
     )
-    sdpa_out = sdpa_out.reshape(B, T, -1)
 
     # Compute the output projection.
     y = self.output_projection(sdpa_out)
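
Note: the net effect of the attention.py hunks above is that CausalSelfAttention no longer owns the SDPA kernel choice (the deleted self.sdpa_func) or the inline KV-cache update; both are delegated to sdpa_with_kv_update, which also absorbs the old sdpa_out.reshape(B, T, -1) step. A minimal sketch of the new contract, with illustrative tensor shapes assumed (only the helper call itself is taken from the diff):

    # q, k, v: (B, T, num_heads_or_groups, head_dim) after the per-head reshape.
    # Returns the attention output, already flattened for output_projection,
    # together with the updated cache entry.
    sdpa_out, kv_cache = sdpa_with_kv_update.sdpa_with_kv_update(
        q, k, v, kv_cache, input_pos, mask, self.config, self.enable_hlfb
    )
    y = self.output_projection(sdpa_out)
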
ai_edge_torch/generative/layers/sdpa_with_kv_update.py
CHANGED
@@ -12,16 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-
+
+"""Common utility functions for data loading etc."""
+
 from typing import Tuple
+
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa_default
 from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
 from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
 import ai_edge_torch.generative.layers.model_config as cfg
-from ai_edge_torch.generative.utilities import types
-from multipledispatch import dispatch
 import torch
 
 
@@ -33,32 +33,27 @@ def sdpa_with_kv_update(
     input_pos: torch.Tensor,
     mask: torch.Tensor,
     config: cfg.AttentionConfig,
+    enable_hlfb: bool,
 ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
-
-
-
-
-
-
-      kv=kv,
-      input_pos=input_pos,
-      mask=mask,
-      config=config,
+  """Wrapper function for scaled dot product attention with KV cache update."""
+  if kv is not None and kv.kv_layout == kv_utils.KV_LAYOUT_TRANSPOSED:
+    return _sdpa_with_kv_update_transposed(
+        query, key, value, kv, input_pos, mask, config
+    )
+  return _sdpa_with_kv_update_default(
+      query, key, value, kv, input_pos, mask, config, enable_hlfb
   )
 
 
-
-
-
+def _sdpa_with_kv_update_transposed(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv: kv_utils.KVCacheEntry,
+    input_pos: torch.Tensor,
+    mask: torch.Tensor,
+    config: cfg.AttentionConfig,
 ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
-  query = kwargs["query"]
-  key = kwargs["key"]
-  value = kwargs["value"]
-  kv = kwargs["kv"]
-  input_pos = kwargs["input_pos"]
-  mask = kwargs["mask"]
-  config = kwargs["config"]
-
   # Transpose k/v to specific layout for GPU implementation.
   b, seq_len, n, h = query.shape
   g = n // config.num_query_groups
@@ -74,9 +69,8 @@ def sdpa_with_kv_update_impl(
       1, -1, config.head_dim, seq_len
   )  # 1, bk, h, s
 
-
-
-  key, value = kv.k_cache, kv.v_cache
+  kv = kv_utils_experimental.update(kv, input_pos, key, value)
+  key, value = kv.k_cache, kv.v_cache
 
   sdpa_out = sdpa.scaled_dot_product_attention(
       kv,
@@ -95,24 +89,26 @@ def sdpa_with_kv_update_impl(
   return sdpa_out, kv
 
 
-
-
-
+def _sdpa_with_kv_update_default(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv: kv_utils.KVCacheEntry,
+    input_pos: torch.Tensor,
+    mask: torch.Tensor,
+    config: cfg.AttentionConfig,
+    enable_hlfb: bool,
 ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
-  query = kwargs["query"]
-  key = kwargs["key"]
-  value = kwargs["value"]
-  kv = kwargs["kv"]
-  input_pos = kwargs["input_pos"]
-  mask = kwargs["mask"]
-  config = kwargs["config"]
-
   b, seq_len, _, _ = query.shape
   if kv is not None:
     kv = kv_utils.update(kv, input_pos, key, value)
     key, value = kv.k_cache, kv.v_cache
 
-
+  if enable_hlfb:
+    sdpa_func = sdpa_default.scaled_dot_product_attention_with_hlfb
+  else:
+    sdpa_func = sdpa_default.scaled_dot_product_attention
+  sdpa_out = sdpa_func(
       query,
       key,
       value,
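
Note: sdpa_with_kv_update.py drops the multipledispatch-based sdpa_with_kv_update_impl (which unpacked its arguments from **kwargs and dispatched on KV-layout types) in favor of a plain branch on the cache entry's layout. A hedged sketch of the resulting behavior; everything except the two module imports and the names visible in the hunks above is a placeholder:

    from ai_edge_torch.generative.layers import kv_cache as kv_utils
    from ai_edge_torch.generative.layers import sdpa_with_kv_update

    out, kv_entry = sdpa_with_kv_update.sdpa_with_kv_update(
        q, k, v, kv_entry, input_pos, mask, config, enable_hlfb
    )
    # kv_entry.kv_layout == kv_utils.KV_LAYOUT_TRANSPOSED routes to the
    # experimental transposed (GPU-oriented) path, which does not take
    # enable_hlfb; any other layout takes the default path, where enable_hlfb
    # picks between scaled_dot_product_attention_with_hlfb and the plain kernel.
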
ai_edge_torch/generative/test/test_model_conversion.py
CHANGED
@@ -41,7 +41,7 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  def _get_params(self, enable_hlfb: bool):
+  def _get_params(self, enable_hlfb: bool, kv_layout: kv_cache.KVLayout):
    """Returns a model, edge model and the kwargs to use for testing."""
    config = toy_model_with_kv_cache.get_model_config()
    config.enable_hlfb = enable_hlfb
@@ -49,7 +49,7 @@ class TestModelConversion(googletest.TestCase):
    tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
        [10], dtype=torch.int
    )
-    kv = kv_cache.KVCache.from_model_config(config)
+    kv = kv_cache.KVCache.from_model_config(config, kv_layout=kv_layout)
    kwargs = {
        "tokens": tokens,
        "input_pos": input_pos,
@@ -65,8 +65,12 @@ class TestModelConversion(googletest.TestCase):
     )
     return pytorch_model, edge_model, kwargs
 
-  def _test_model_with_kv_cache(self, enable_hlfb: bool):
-    pytorch_model, edge_model, kwargs = self._get_params(enable_hlfb)
+  def _test_model_with_kv_cache(
+      self,
+      enable_hlfb: bool = False,
+      kv_layout: kv_cache.KVLayout = kv_cache.KV_LAYOUT_DEFAULT,
+  ):
+    pytorch_model, edge_model, kwargs = self._get_params(enable_hlfb, kv_layout)
 
     self.assertTrue(
         test_utils.compare_tflite_torch(
@@ -95,13 +99,22 @@ class TestModelConversion(googletest.TestCase):
   def test_toy_model_with_kv_cache_with_hlfb(self):
     self._test_model_with_kv_cache(enable_hlfb=True)
 
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def test_toy_model_with_kv_cache_transposed(self):
+    self._test_model_with_kv_cache(kv_layout=kv_cache.KV_LAYOUT_TRANSPOSED)
+
   @googletest.skipIf(
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
   )
   def test_toy_model_has_dus_op(self):
     """Tests that the model has the dynamic update slice op."""
-    _, edge_model, _ = self._get_params(enable_hlfb=True)
+    _, edge_model, _ = self._get_params(
+        enable_hlfb=True, kv_layout=kv_cache.KV_LAYOUT_DEFAULT
+    )
     interpreter_ = interpreter.InterpreterWithCustomOps(
         custom_op_registerers=["GenAIOpsRegisterer"],
         model_content=edge_model.tflite_model(),
@@ -112,7 +125,14 @@ class TestModelConversion(googletest.TestCase):
     op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
     self.assertIn("DYNAMIC_UPDATE_SLICE", op_names)
 
-  def _test_multisig_model(self, config, pytorch_model, atol, rtol):
+  def _test_multisig_model(
+      self,
+      config,
+      pytorch_model,
+      atol,
+      rtol,
+      kv_layout=kv_cache.KV_LAYOUT_DEFAULT,
+  ):
     # prefill
     seq_len = 10
     prefill_tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
@@ -124,7 +144,7 @@ class TestModelConversion(googletest.TestCase):
     decode_token = torch.tensor([[1]], dtype=torch.int)
     decode_input_pos = torch.tensor([5], dtype=torch.int)
 
-    kv = kv_cache.KVCache.from_model_config(config)
+    kv = kv_cache.KVCache.from_model_config(config, kv_layout=kv_layout)
 
     edge_model = (
         ai_edge_torch.signature(
@@ -160,7 +180,7 @@ class TestModelConversion(googletest.TestCase):
             kv,
             signature_name="prefill",
             atol=atol,
-            rtol=
+            rtol=rtol,
         )
     )
 
@@ -173,7 +193,7 @@ class TestModelConversion(googletest.TestCase):
             kv,
             signature_name="decode",
             atol=atol,
-            rtol=
+            rtol=rtol,
         )
     )
 
@@ -186,6 +206,21 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = tiny_llama.TinyLlama(config).eval()
     self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
 
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def test_tiny_llama_multisig_kv_layout_transposed(self):
+    config = tiny_llama.get_fake_model_config()
+    pytorch_model = tiny_llama.TinyLlama(config).eval()
+    self._test_multisig_model(
+        config,
+        pytorch_model,
+        atol=1e-5,
+        rtol=1e-5,
+        kv_layout=kv_cache.KV_LAYOUT_TRANSPOSED,
+    )
+
 
 if __name__ == "__main__":
   googletest.main()
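
Note: the test changes thread an optional kv_layout through _get_params, _test_model_with_kv_cache and _test_multisig_model, so the same toy and TinyLlama conversions can be re-run against the transposed KV layout (guarded by skipIf in OSS builds, since the reason strings above tie the layout to custom ops). A sketch of the two entry points, using only names from the hunks above; config is a placeholder:

    self._test_model_with_kv_cache(kv_layout=kv_cache.KV_LAYOUT_TRANSPOSED)
    kv = kv_cache.KVCache.from_model_config(
        config, kv_layout=kv_cache.KV_LAYOUT_TRANSPOSED
    )
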
ai_edge_torch/version.py
CHANGED
-__version__ = "0.5.0.dev20250423"
+__version__ = "0.5.0.dev20250424"
{ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250424.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.5.0.dev20250423
+Version: 0.5.0.dev20250424
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250424.dist-info}/RECORD
CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=Nixp49eAXZPPMWEWkqpm_M4Mi_WGPx-I8q2noKuh0hw,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=dOr3TUfF0UCvkmlUrMqKvgaN4jh3lJ9XFuO-sHaAmIw,5521
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -67,7 +67,7 @@ ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
 ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
 ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=szssSBrIUYdNIoU7LHdAq7wCqgjaY6qbV8yvTgg796Q,2945
-ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=
+ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=eXWE5CSX0KeUMsPevgsYOfvyajl9F1RFF4DCWhHcYPA,15646
 ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=GACDBI_MsFowR8A3wAWrpzradPYe-AUgB9ZjXaVBG-s,6485
 ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
 ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
@@ -150,7 +150,7 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=LRu6PSw7Lqu6HGbv1t
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=PFSMsA1vfBfrV9ssBCkYJNl8Hx_bLdWjN01iyjPM5jE,1094
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=
+ai_edge_torch/generative/layers/attention.py,sha256=uK1ih2kxPZherwi-pGSm8B--NNWnQ8npEAfgcjMIkEY,12964
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
 ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
@@ -160,9 +160,8 @@ ai_edge_torch/generative/layers/model_config.py,sha256=nLXvTkDAIHJQ0PTaWODF8oxJQ
 ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
-ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=
+ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=D4rATT2Ppa9Su7yuRHYnQPJ1dFvUDAyH1GrFnCed7p8,3810
 ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
-ai_edge_torch/generative/layers/experimental/attention.py,sha256=XYbo1KlmiMEuwArye0Ul86jEsdxLr1RG-usRpidZiT8,8001
 ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=zgpFVftOfllvjh9-UEBSvUbm152SnQETn29rUMMMvAM,2978
 ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=YFcIGOkaNb-vvQKjI-G9-bC2Z1W0O_qRyIZPlsLl72U,2797
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -181,7 +180,7 @@ ai_edge_torch/generative/test/test_custom_dus.py,sha256=MjIhTvkTko872M35XMciobvI
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=1sXN2RPntq0PP3IEy0NkvIbzQ0Y8JhPIwRSFwO9JLlE,5728
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=jSNJ0Eex6VYCkGn3FXbCOOJ2S3-F_QuwJctu3VycjR4,7200
 ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=-v2Vj7Qdd3GyBn4k7BWVgyGzrbcL30Su3nxZYLtwkCs,14787
 ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
@@ -245,8 +244,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
 ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/METADATA,sha256=
-ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/RECORD,,
+ai_edge_torch_nightly-0.5.0.dev20250424.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.5.0.dev20250424.dist-info/METADATA,sha256=Gz8c2qvL6qiK7lrd001P55TXltKdycDvDaAq4d4Y-eQ,2051
+ai_edge_torch_nightly-0.5.0.dev20250424.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.5.0.dev20250424.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.5.0.dev20250424.dist-info/RECORD,,
ai_edge_torch/generative/layers/experimental/attention.py
DELETED
@@ -1,231 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""Common building blocks for a GPU-specific Attention layer.
-
-This is a temporary implemenation for the GPU. It is subject to change/removal
-at any time.
-"""
-
-from typing import Optional, Tuple, Union
-
-from ai_edge_torch.generative.layers import builder
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.layers import lora as lora_utils
-from ai_edge_torch.generative.layers import sdpa_with_kv_update
-import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
-import torch
-from torch import nn
-
-
-class TransformerBlock(nn.Module):
-
-  def __init__(
-      self,
-      config: cfg.TransformerBlockConfig,
-      model_config: cfg.ModelConfig,
-  ) -> None:
-    """Initialize an instance of the TransformerBlock.
-
-    Args:
-      config (cfg.TransformerBlockConfig): the configuration object for this
-        transformer block.
-      model_config (cfg.ModelConfig): the configuration object for the model
-        this transformer block belongs to.
-    """
-    super().__init__()
-    self.pre_atten_norm = builder.build_norm(
-        model_config.embedding_dim,
-        config.pre_attention_norm_config,
-    )
-    self.atten_func = CausalSelfAttention(
-        model_config.embedding_dim,
-        config.attn_config,
-        model_config.enable_hlfb,
-    )
-    self.post_atten_norm = builder.build_norm(
-        model_config.embedding_dim,
-        config.post_attention_norm_config,
-    )
-    self.ff = builder.build_ff(model_config.embedding_dim, config.ff_config)
-    self.config = config
-
-  def forward(
-      self,
-      x: torch.Tensor,
-      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-      mask: Optional[torch.Tensor] = None,
-      input_pos: Optional[torch.Tensor] = None,
-      kv_cache: kv_utils.KVCacheEntry = None,
-      lora: Optional[lora_utils.LoRAEntry] = None,
-  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
-    """Forward function of the TransformerBlock.
-
-    Args:
-      x (torch.Tensor): the input tensor.
-      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
-      mask (torch.Tensor): the optional mask tensor.
-      input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry): the optional kv cache entry.
-      lora (LoRAEntry): the optional lora entry.
-
-    Returns:
-      output activation from this transformer block, and updated kv cache (if
-      passed in).
-    """
-    kv = None
-    if self.config.parallel_residual:
-      x_norm = self.pre_atten_norm(x)
-      atten_func_out = self.atten_func(
-          x_norm, rope, mask, input_pos, kv_cache, lora
-      )
-      if kv_cache is None:
-        attn_out = atten_func_out
-      else:
-        attn_out, kv = atten_func_out
-      ff_out = self.ff(x_norm)
-      output = x + attn_out + ff_out
-    else:
-      x_norm = self.pre_atten_norm(x)
-      atten_func_out = self.atten_func(
-          x_norm, rope, mask, input_pos, kv_cache, lora
-      )
-      if kv_cache is None:
-        attn_out = atten_func_out
-      else:
-        attn_out, kv = atten_func_out
-      x = x + attn_out
-      x_norm = self.post_atten_norm(x)
-      output = x + self.ff(x_norm)
-
-    return output if kv is None else (output, kv)
-
-
-class CausalSelfAttention(nn.Module):
-
-  def __init__(
-      self,
-      dim: int,
-      config: cfg.AttentionConfig,
-      enable_hlfb: bool,
-  ) -> None:
-    """Initialize an instance of CausalSelfAttention.
-
-    Args:
-      dim (int): causal attention's input/output dimmension.
-      config (cfg.AttentionConfig): attention specific configurations.
-      enable_hlfb (bool): whether hlfb is enabled or not.
-    """
-    super().__init__()
-    self.kv_cache = None
-    qkv_shape = (
-        config.num_heads + 2 * config.num_query_groups
-    ) * config.head_dim
-    output_shape = config.num_heads * config.head_dim
-    # Key, query, value projections for all heads.
-    self.qkv_projection = nn.Linear(dim, qkv_shape, bias=config.qkv_use_bias)
-    self.output_projection = nn.Linear(
-        output_shape, dim, bias=config.output_proj_use_bias
-    )
-    self.query_norm = builder.build_norm(
-        config.head_dim, config.query_norm_config
-    )
-    self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
-    self.config = config
-    self.enable_hlfb = enable_hlfb
-
-  def forward(
-      self,
-      x: torch.Tensor,
-      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-      mask: Optional[torch.Tensor] = None,
-      input_pos: Optional[torch.Tensor] = None,
-      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
-      lora: Optional[lora_utils.LoRAEntry] = None,
-  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
-    """Forward function of the CausalSelfAttention layer, which can support
-
-    MQA, GQA and MHA.
-
-    Args:
-      x (torch.Tensor): the input tensor.
-      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
-      mask (torch.Tensor): the optional mask tensor.
-      input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
-      lora (LoRAEntry): the optional lora entry.
-
-    Returns:
-      output activation from this self attention layer, and the updated
-      KV Cach Entry (if passed in).
-    """
-    # Batch size, sequence length, embedding dimensionality.
-    B, T, E = x.size()
-
-    qkv = self.qkv_projection(x)
-
-    # Assemble into a number of query groups to support MHA, MQA and GQA.
-    q_per_kv = self.config.num_heads // self.config.num_query_groups
-    # Each group has >=1 queries, 1 key, and 1 value.
-    if self.config.qkv_transpose_before_split:
-      qkv = qkv.view(B, T, -1, self.config.head_dim)
-      q, k, v = qkv.split(
-          (
-              q_per_kv * self.config.num_query_groups,
-              self.config.num_query_groups,
-              self.config.num_query_groups,
-          ),
-          dim=-2,
-      )
-    else:
-      qkv = qkv.view(B, T, self.config.num_query_groups, -1)
-      q, k, v = qkv.split(
-          (
-              q_per_kv * self.config.head_dim,
-              self.config.head_dim,
-              self.config.head_dim,
-          ),
-          dim=-1,
-      )
-
-    if lora is not None:
-      q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
-      k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
-      v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
-
-    q = self.query_norm(q)
-    k = self.key_norm(k)
-
-    q = q.reshape(B, T, -1, self.config.head_dim)
-    k = k.reshape(B, T, -1, self.config.head_dim)
-    v = v.reshape(B, T, -1, self.config.head_dim)
-
-    if rope is not None:
-      # Compute rotary positional embedding for query and key.
-      n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      cos, sin = rope
-      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
-
-    sdpa_out, kv_cache = sdpa_with_kv_update.sdpa_with_kv_update(
-        q, k, v, kv_cache, input_pos, mask, self.config
-    )
-
-    # Compute the output projection.
-    y = self.output_projection(sdpa_out)
-    if lora is not None:
-      y += lora_utils.apply_lora(sdpa_out, lora.attention.output)
-
-    return y if kv_cache is None else (y, kv_cache)
{ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250424.dist-info}/LICENSE
File without changes
{ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250424.dist-info}/WHEEL
File without changes
{ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250424.dist-info}/top_level.txt
File without changes