ai-edge-torch-nightly 0.7.0.dev20251012__py3-none-any.whl → 0.7.0.dev20251013__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/layers/attention.py +25 -2
- ai_edge_torch/generative/layers/attention_test.py +13 -1
- ai_edge_torch/generative/layers/attention_utils.py +62 -1
- ai_edge_torch/generative/layers/attention_utils_test.py +20 -0
- ai_edge_torch/generative/layers/builder.py +4 -2
- ai_edge_torch/generative/layers/model_config.py +4 -0
- ai_edge_torch/generative/layers/normalization.py +8 -2
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +35 -5
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py +8 -3
- ai_edge_torch/generative/utilities/loader.py +2 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.7.0.dev20251012.dist-info → ai_edge_torch_nightly-0.7.0.dev20251013.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.7.0.dev20251012.dist-info → ai_edge_torch_nightly-0.7.0.dev20251013.dist-info}/RECORD +16 -16
- {ai_edge_torch_nightly-0.7.0.dev20251012.dist-info → ai_edge_torch_nightly-0.7.0.dev20251013.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.7.0.dev20251012.dist-info → ai_edge_torch_nightly-0.7.0.dev20251013.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.7.0.dev20251012.dist-info → ai_edge_torch_nightly-0.7.0.dev20251013.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/attention.py CHANGED
@@ -18,6 +18,7 @@
 import abc
 from typing import Optional, Tuple, Union
 
+from ai_edge_torch.generative.layers import attention_utils
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import lora as lora_utils
@@ -240,13 +241,35 @@ class CausalSelfAttention(CausalSelfAttentionBase):
     k = k.reshape(B, T, -1, self.config.head_dim)
     v = v.reshape(B, T, -1, self.config.head_dim)
 
-    if rope is not None:
+    alibi_bias = None
+    if self.config.use_alibi:
+      k_size = T
+      if mask is not None:
+        k_size = mask.shape[-1]
+      elif input_pos is not None:
+        # If mask is not present, assume current sequence length is key length.
+        k_size = input_pos[-1].item() + 1
+      alibi_bias = attention_utils.build_alibi_bias(
+          n_heads=self.config.num_heads,
+          k_size=k_size,
+          dtype=x.dtype,
+          device=x.device,
+      )
+    elif rope is not None:
       # Compute rotary positional embedding for query and key.
       cos, sin = rope
       q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
 
     sdpa_out, kv_cache = sdpa_with_kv_update.sdpa_with_kv_update(
-        q, k, v, kv_cache, input_pos, mask, self.config, self.enable_hlfb
+        q,
+        k,
+        v,
+        kv_cache,
+        input_pos,
+        mask,
+        self.config,
+        self.enable_hlfb,
+        alibi_bias=alibi_bias,
     )
 
     # Compute the output projection.
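For orientation on the new branch above: the ALiBi bias has to span the key dimension of the attention scores, which differs between prefill (a mask is passed and already covers the full KV length) and decode (only `input_pos` is known). Below is a minimal standalone sketch of that key-length selection; the `pick_k_size` helper is hypothetical and only illustrates the logic in the hunk, it is not part of the package.

```python
from typing import Optional

import torch


def pick_k_size(
    t: int,
    mask: Optional[torch.Tensor],
    input_pos: Optional[torch.Tensor],
) -> int:
  """Mirrors the k_size selection in the hunk above (illustrative only)."""
  if mask is not None:
    # Prefill: the mask's last dimension already spans the full KV length.
    return mask.shape[-1]
  if input_pos is not None:
    # Decode: assume keys up to and including the last written position.
    return int(input_pos[-1].item()) + 1
  # Fallback: use the current sequence length.
  return t


# Prefill over 10 tokens against a 16-wide KV cache mask.
print(pick_k_size(10, torch.zeros(1, 1, 10, 16), None))  # -> 16
# Single-token decode at position 7 with no mask supplied.
print(pick_k_size(1, None, torch.tensor([7])))           # -> 8
```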
ai_edge_torch/generative/layers/attention_test.py CHANGED
@@ -27,16 +27,27 @@ class AttentionTest(parameterized.TestCase):
       dict(
           testcase_name="local_causal_self_attention",
           attn_type=cfg.AttentionType.LOCAL_SLIDING,
+          use_alibi=False,
           expected_shape=(1, 10, 16),
       ),
       dict(
           testcase_name="global_causal_self_attention",
           attn_type=cfg.AttentionType.GLOBAL,
+          use_alibi=False,
+          expected_shape=(1, 10, 16),
+      ),
+      dict(
+          testcase_name="alibi_attention",
+          attn_type=cfg.AttentionType.GLOBAL,
+          use_alibi=True,
           expected_shape=(1, 10, 16),
       ),
   )
   def test_causal_self_attention(
-      self, attn_type: cfg.AttentionType, expected_shape: tuple[int, ...]
+      self,
+      attn_type: cfg.AttentionType,
+      use_alibi: bool,
+      expected_shape: tuple[int, ...],
   ):
     norm_config = cfg.NormalizationConfig(
         type=cfg.NormalizationType.RMS_NORM,
@@ -56,6 +67,7 @@ class AttentionTest(parameterized.TestCase):
         logit_softcap=None,
         sliding_window_size=16,
         attn_type=attn_type,
+        use_alibi=use_alibi,
     )
     self_atten = attention.CausalSelfAttention(
         dim=16,
ai_edge_torch/generative/layers/attention_utils.py CHANGED
@@ -15,11 +15,72 @@
 # Common utility functions used with attention module.
 
 import math
-from typing import Tuple
+from typing import List, Tuple
 
 import torch
 
 
+def _get_alibi_slopes(n_heads: int) -> List[float]:
+  """Returns slopes for ALiBi implementation.
+
+  The slopes are taken from the ALiBi paper
+  [https://arxiv.org/abs/2108.12409].
+  The slopes are later used to calculate the bias which is added to the
+  attention scores.
+
+  Args:
+    n_heads (int): The number of attention heads.
+  """
+
+  def get_slopes_power_of_2(n):
+    start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+    return [start**i for i in range(1, n + 1)]
+
+  if math.log2(n_heads).is_integer():
+    return get_slopes_power_of_2(n_heads)
+  else:
+    closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
+    return (
+        get_slopes_power_of_2(closest_power_of_2)
+        + _get_alibi_slopes(2 * closest_power_of_2)[0::2][
+            : n_heads - closest_power_of_2
+        ]
+    )
+
+
+def build_alibi_bias(
+    n_heads: int,
+    k_size: int,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device = None,
+) -> torch.Tensor:
+  """Builds ALiBi bias tensor based on key position.
+
+  The bias tensor is added to the attention scores before softmax.
+  Replicates HuggingFace Falcon implementation behavior where bias only depends
+  on key position j, not relative position j-i.
+
+  Args:
+    n_heads (int): The number of attention heads.
+    k_size (int): The key size of the bias tensor.
+    dtype (torch.dtype, optional): Output tensor's data type. Defaults to
+      torch.float32.
+    device (torch.device, optional): Output tensor's data type. Defaults to
+      None in which case "cpu" is used.
+
+  Returns:
+    torch.Tensor: The ALiBi bias tensor of shape (1, n_heads, 1, k_size).
+  """
+  if device is None:
+    device = torch.device('cpu')
+  slopes = torch.tensor(_get_alibi_slopes(n_heads), dtype=dtype, device=device)
+  k_pos = torch.arange(k_size, device=device)
+  # According to HF implementation, bias only depends on key position.
+  # slopes[h] * k_pos[j]
+  alibi_bias = slopes.unsqueeze(-1) * k_pos.unsqueeze(0)  # Shape: H, K
+  return alibi_bias[None, :, None, :].to(dtype)
+
+
 def build_rope_cache(
     size: int,
     dim: int,
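As a quick sanity check on the slope and bias math added above, the sketch below recomputes the same quantities with plain torch for power-of-two head counts and compares them against the values asserted by the new unit tests (see attention_utils_test.py further down). The `alibi_slopes_pow2` helper is local to this sketch, not the package API.

```python
import math

import torch


def alibi_slopes_pow2(n_heads: int) -> list[float]:
  # Geometric sequence from the ALiBi paper; valid when n_heads is a power of 2.
  start = 2 ** (-(2 ** -(math.log2(n_heads) - 3)))
  return [start**i for i in range(1, n_heads + 1)]


# Matches the values asserted by test_get_alibi_slopes.
assert alibi_slopes_pow2(4) == [0.25, 0.0625, 0.015625, 0.00390625]

# Key-position-only bias, broadcast to (1, n_heads, 1, k_size).
slopes = torch.tensor(alibi_slopes_pow2(2))  # (H,)
k_pos = torch.arange(3)                      # (K,)
bias = (slopes[:, None] * k_pos[None, :])[None, :, None, :]
expected = torch.tensor(
    [[[[0.0, 0.0625, 0.125]], [[0.0, 0.00390625, 0.0078125]]]]
)
torch.testing.assert_close(bias, expected)
print(bias.shape)  # torch.Size([1, 2, 1, 3])
```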
ai_edge_torch/generative/layers/attention_utils_test.py CHANGED
@@ -21,6 +21,26 @@ from absl.testing import absltest as googletest
 
 class AttentionUtilsTest(googletest.TestCase):
 
+  def test_get_alibi_slopes(self):
+    slopes = attention_utils._get_alibi_slopes(1)
+    self.assertSequenceAlmostEqual(slopes, [0.00390625], places=6)
+    slopes = attention_utils._get_alibi_slopes(2)
+    self.assertSequenceAlmostEqual(slopes, [0.0625, 0.00390625], places=6)
+    slopes = attention_utils._get_alibi_slopes(4)
+    self.assertSequenceAlmostEqual(
+        slopes, [0.25, 0.0625, 0.015625, 0.00390625], places=6
+    )
+    slopes = attention_utils._get_alibi_slopes(3)
+    self.assertSequenceAlmostEqual(slopes, [0.0625, 0.00390625, 0.25], places=6)
+
+  def test_build_alibi_bias(self):
+    bias = attention_utils.build_alibi_bias(n_heads=2, k_size=3)
+    self.assertEqual(bias.shape, (1, 2, 1, 3))
+    expected = torch.tensor(
+        [[[[0.0, 0.0625, 0.125]], [[0.0, 0.00390625, 0.0078125]]]]
+    )
+    torch.testing.assert_close(bias, expected)
+
   def test_build_causal_mask_cache(self):
     mask = attention_utils.build_causal_mask_cache(3)
     self.assertEqual(mask.shape, (1, 1, 3, 3))
ai_edge_torch/generative/layers/builder.py CHANGED
@@ -71,7 +71,7 @@ def build_norm(
   Raises:
     ValueError: If config's `layer_norm_type` is not supported.
   """
-  if config.type == cfg.NormalizationType.NONE:
+  if config is None or config.type == cfg.NormalizationType.NONE:
     return lambda x: x
   elif config.type == cfg.NormalizationType.RMS_NORM:
     return normalization.RMSNorm(
@@ -84,7 +84,9 @@ def build_norm(
         init_fn=init_fn,
     )
   elif config.type == cfg.NormalizationType.LAYER_NORM:
-    return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
+    return normalization.LayerNorm(
+        dim, config.epsilon, config.use_bias, config.enable_hlfb
+    )
   elif config.type == cfg.NormalizationType.GROUP_NORM:
     return normalization.GroupNorm(
         config.group_num, dim, config.epsilon, config.enable_hlfb
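The first hunk above lets callers pass no normalization config at all: `None` now takes the same identity path as `NormalizationType.NONE` instead of failing on `config.type`. A minimal sketch of that guard pattern, with a hypothetical `norm_or_identity` standing in for the real builder:

```python
from typing import Callable, Optional

import torch


def norm_or_identity(
    norm: Optional[torch.nn.Module],
) -> Callable[[torch.Tensor], torch.Tensor]:
  # Hypothetical stand-in: a missing norm maps to an identity function,
  # so callers can always apply the result unconditionally.
  if norm is None:
    return lambda x: x
  return norm


x = torch.randn(2, 8)
identity = norm_or_identity(None)
layer_norm = norm_or_identity(torch.nn.LayerNorm(8))
print(torch.equal(identity(x), x))  # True: identity path for a missing config
print(layer_norm(x).shape)          # torch.Size([2, 8])
```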
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -75,6 +75,8 @@ class NormalizationConfig:
   scale_shift: float = 0.0
   # Number of groups used in group normalization.
   group_num: Optional[float] = None
+  # Whether to use bias in norm.
+  use_bias: bool = True
 
 
 # Exprimental feature and may subject to change.
@@ -108,6 +110,8 @@ class AttentionConfig:
   rotary_base: int = 10_000
   # Percentage of Rotary Positional Embedding added Q and K projections.
   rotary_percentage: Optional[float] = None
+  # Whether to use ALiBi positional encoding.
+  use_alibi: bool = False
   # Whether to transpose the query groups of qkv bundled tensor before
   # splitting into separated tensors.
   qkv_transpose_before_split: bool = False
ai_edge_torch/generative/layers/normalization.py CHANGED
@@ -148,6 +148,7 @@ class LayerNorm(torch.nn.Module):
       self,
       dim: int,
       eps: float = 1e-5,
+      use_bias: bool = True,
       enable_hlfb: bool = False,
   ):
     """Initialize the LayerNorm layer.
@@ -156,6 +157,7 @@ class LayerNorm(torch.nn.Module):
       dim (int): dimension of the input tensor.
       eps (float): A small float value to ensure numerical stability (default:
         1e-5).
+      use_bias (bool): Whether to use bias in LayerNorm.
       enable_hlfb (bool): Whether to convert this normalization into a single
         op.
     """
@@ -164,7 +166,11 @@ class LayerNorm(torch.nn.Module):
     self.normalized_shape = (dim,)
     self.eps = eps
     self.weight = torch.nn.Parameter(torch.empty(dim), requires_grad=False)
-    self.bias = torch.nn.Parameter(torch.empty(dim), requires_grad=False)
+    self.bias = (
+        torch.nn.Parameter(torch.empty(dim), requires_grad=False)
+        if use_bias
+        else None
+    )
 
   def forward(self, x):
     """Running the forward pass of LayerNorm layer.
@@ -175,7 +181,7 @@ class LayerNorm(torch.nn.Module):
     Returns:
       torch.Tensor: output tensor after applying LayerNorm.
     """
-    if self.enable_hlfb:
+    if self.enable_hlfb and self.bias is not None:
       return layer_norm_with_hlfb(
           x, self.normalized_shape, self.weight, self.bias, self.eps
       )
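Functionally, `use_bias=False` means the layer carries no bias parameter, which matches calling `torch.nn.functional.layer_norm` with `bias=None`. The sketch below illustrates that equivalence with the functional API directly, rather than the package's `LayerNorm` class:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 16)
weight = torch.ones(16)

# Bias-free layer norm: normalize, then scale only (no shift).
y_no_bias = F.layer_norm(x, (16,), weight, None, 1e-5)

# Equivalent manual computation with a biased variance estimate.
mean = x.mean(-1, keepdim=True)
var = x.var(-1, unbiased=False, keepdim=True)
y_manual = (x - mean) / torch.sqrt(var + 1e-5) * weight

torch.testing.assert_close(y_no_bias, y_manual)
print(y_no_bias.shape)  # torch.Size([2, 5, 16])
```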
ai_edge_torch/generative/layers/scaled_dot_product_attention.py CHANGED
@@ -32,6 +32,7 @@ def scaled_dot_product_attention(
     mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
     softcap: Optional[float] = None,
+    alibi_bias: Optional[torch.Tensor] = None,
 ):
   """Scaled dot product attention.
 
@@ -41,14 +42,23 @@ def scaled_dot_product_attention(
     v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H].
     head_size (int): head dimension.
     mask (torch.Tensor): the optional mask tensor.
+    scale (float): the optional scale factor.
+    softcap (float): the optional softcap for the logits.
+    alibi_bias (torch.Tensor): optional alibi bias tensor.
 
   Returns:
     The output tensor of scaled_dot_product_attention.
   """
-
   if scale is None:
     scale = 1.0 / math.sqrt(head_size)
 
+  if alibi_bias is not None:
+    alibi_bias = alibi_bias * scale
+    if mask is None:
+      mask = alibi_bias
+    else:
+      mask = mask + alibi_bias
+
   q = q.transpose(1, 2)
   k = k.transpose(1, 2)
   v = v.transpose(1, 2)
@@ -72,7 +82,8 @@ def scaled_dot_product_attention(
     scores = scores / softcap
     scores = torch.tanh(scores)
     scores = scores * softcap
-
+  if mask is not None:
+    scores = scores + mask
   out = F.softmax(scores.float(), dim=-1).type_as(q)
   y = torch.matmul(out, v)
 
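The merge logic above (repeated in the HLFB and transposed variants below) scales the ALiBi bias by the same factor as the attention scores and then folds it into the additive mask, or uses it as the mask when none is given. A minimal reference of that merge, assuming a standard [B, H, Q, D] layout rather than the package's internal layouts:

```python
import math
from typing import Optional

import torch
import torch.nn.functional as F


def sdpa_with_alibi_reference(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    alibi_bias: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
  """Illustrative only: fold a scaled ALiBi bias into the additive mask."""
  scale = 1.0 / math.sqrt(q.shape[-1])
  bias = alibi_bias * scale                      # same scaling as in the diff
  mask = bias if mask is None else mask + bias   # merge with any causal mask
  scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale + mask
  return F.softmax(scores, dim=-1) @ v


B, H, Q, K, D = 1, 2, 4, 4, 8
q, k, v = (torch.randn(B, H, Q, D), torch.randn(B, H, K, D),
           torch.randn(B, H, K, D))
slopes = torch.tensor([0.0625, 0.00390625])                    # 2-head slopes
alibi = (slopes[:, None] * torch.arange(K))[None, :, None, :]  # (1, H, 1, K)
causal = torch.triu(torch.full((Q, K), float("-inf")), diagonal=1)
print(sdpa_with_alibi_reference(q, k, v, alibi, causal).shape)  # (1, 2, 4, 8)
```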
@@ -87,6 +98,7 @@ def scaled_dot_product_attention_with_hlfb(
     mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
     softcap: Optional[float] = None,
+    alibi_bias: Optional[torch.Tensor] = None,
 ):
   """Scaled dot product attention with high-level function boundary enabled.
 
@@ -96,14 +108,23 @@ def scaled_dot_product_attention_with_hlfb(
     v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H].
     head_size (int): head dimension.
     mask (torch.Tensor): the optional mask tensor.
+    scale (float): the optional scale factor.
+    softcap (float): the optional softcap for the logits.
+    alibi_bias (torch.Tensor): optional alibi bias tensor.
 
   Returns:
     The output tensor of scaled_dot_product_attention.
   """
-
   if scale is None:
     scale = 1.0 / math.sqrt(head_size)
 
+  if alibi_bias is not None:
+    alibi_bias = alibi_bias * scale
+    if mask is None:
+      mask = alibi_bias
+    else:
+      mask = mask + alibi_bias
+
   attrs = {"scale": scale}
 
   if softcap is not None:
@@ -137,7 +158,8 @@ def scaled_dot_product_attention_with_hlfb(
     scores = scores / softcap
     scores = torch.tanh(scores)
     scores = scores * softcap
-
+  if mask is not None:
+    scores = scores + mask
   out = F.softmax(scores.float(), dim=-1).type_as(q)
   y = torch.matmul(out, v)
 
@@ -154,6 +176,7 @@ def scaled_dot_product_attention_transposed(
     mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
     softcap: Optional[float] = None,
+    alibi_bias: Optional[torch.Tensor] = None,
 ):
   """Scaled dot product attention with transposed key and value.
 
@@ -165,14 +188,21 @@ def scaled_dot_product_attention_transposed(
     mask (torch.Tensor): the optional mask tensor.
     scale (float): the optional scale factor.
     softcap (float): the optional softcap for the logits.
+    alibi_bias (torch.Tensor): optional alibi bias tensor.
 
   Returns:
     The output tensor of scaled_dot_product_attention_transposed.
   """
-
   if scale is None:
     scale = 1.0 / math.sqrt(head_size)
 
+  if alibi_bias is not None:
+    alibi_bias = alibi_bias * scale
+    if mask is None:
+      mask = alibi_bias
+    else:
+      mask = mask + alibi_bias
+
   query = query * scale
 
   assert mask is not None, "Mask should not be None!"
ai_edge_torch/generative/layers/sdpa_with_kv_update.py CHANGED
@@ -15,7 +15,7 @@
 
 """Common utility functions for data loading etc."""
 
-from typing import Tuple
+from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
@@ -32,14 +32,15 @@ def sdpa_with_kv_update(
     mask: torch.Tensor,
     config: cfg.AttentionConfig,
     enable_hlfb: bool,
+    alibi_bias: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
   """Wrapper function for scaled dot product attention with KV cache update."""
   if kv is not None and kv.kv_layout == kv_utils.KV_LAYOUT_TRANSPOSED:
     return _sdpa_with_kv_update_transposed(
-        query, key, value, kv, input_pos, mask, config
+        query, key, value, kv, input_pos, mask, config, alibi_bias
     )
   return _sdpa_with_kv_update_default(
-      query, key, value, kv, input_pos, mask, config, enable_hlfb
+      query, key, value, kv, input_pos, mask, config, enable_hlfb, alibi_bias
   )
 
 
@@ -51,6 +52,7 @@ def _sdpa_with_kv_update_transposed(
     input_pos: torch.Tensor,
     mask: torch.Tensor,
     config: cfg.AttentionConfig,
+    alibi_bias: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
   # Transpose k/v to specific layout for GPU implementation.
   b, seq_len, n, h = query.shape
@@ -77,6 +79,7 @@ def _sdpa_with_kv_update_transposed(
       config.head_dim,
       mask=mask,
       softcap=config.logit_softcap,
+      alibi_bias=alibi_bias,
   )  # 1, bk, gt, h
   sdpa_out = (
       sdpa_out.reshape(b, -1, seq_len, h)
@@ -95,6 +98,7 @@ def _sdpa_with_kv_update_default(
     mask: torch.Tensor,
     config: cfg.AttentionConfig,
     enable_hlfb: bool,
+    alibi_bias: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
   b, seq_len, _, _ = query.shape
   if kv is not None:
@@ -112,6 +116,7 @@ def _sdpa_with_kv_update_default(
       config.head_dim,
       mask=mask,
       softcap=config.logit_softcap,
+      alibi_bias=alibi_bias,
   )
   sdpa_out = sdpa_out.reshape(b, seq_len, -1)
   return sdpa_out, kv
ai_edge_torch/generative/utilities/loader.py CHANGED
@@ -135,7 +135,8 @@ def load_pytorch_statedict(full_path: str):
 
   tensors = {}
   for file in files:
-    this_file_tensors = torch.load(file)
+    map_location = "cpu" if not torch.cuda.is_available() else None
+    this_file_tensors = torch.load(file, map_location=map_location)
     for k in this_file_tensors:
       assert k not in tensors
     tensors.update(this_file_tensors)
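The loader change above forces checkpoints onto the CPU when no CUDA device is present, so GPU-saved state dicts still load on CPU-only hosts. A self-contained sketch of the same `map_location` pattern; the temporary checkpoint path is illustrative only:

```python
import os
import tempfile

import torch

# Save a tiny state dict, then load it the way the patched loader does.
state = {"w": torch.randn(3, 3)}
path = os.path.join(tempfile.mkdtemp(), "ckpt.pt")
torch.save(state, path)

# Fall back to CPU tensors when CUDA is unavailable; otherwise keep the
# storages on whatever device they were saved from.
map_location = "cpu" if not torch.cuda.is_available() else None
tensors = torch.load(path, map_location=map_location)
print(tensors["w"].shape)  # torch.Size([3, 3])
```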
ai_edge_torch/version.py CHANGED

{ai_edge_torch_nightly-0.7.0.dev20251012.dist-info → ai_edge_torch_nightly-0.7.0.dev20251013.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.7.0.dev20251012
+Version: 0.7.0.dev20251013
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.7.0.dev20251012.dist-info → ai_edge_torch_nightly-0.7.0.dev20251013.dist-info}/RECORD CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=YS7m8oFfpEVS_P7dwwhNyAA1nGOJd7lnZi3I8852GLo,806
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -169,24 +169,24 @@ ai_edge_torch/generative/examples/tiny_llama/verify_util.py,sha256=z6vPBXDWAL6gN
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=PFSMsA1vfBfrV9ssBCkYJNl8Hx_bLdWjN01iyjPM5jE,1094
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=
-ai_edge_torch/generative/layers/attention_test.py,sha256=
-ai_edge_torch/generative/layers/attention_utils.py,sha256=
-ai_edge_torch/generative/layers/attention_utils_test.py,sha256=
-ai_edge_torch/generative/layers/builder.py,sha256=
+ai_edge_torch/generative/layers/attention.py,sha256=ZjU3vX-7gOq1KQb3xSZ1NT3xryOTXbYb_vkx_DlcizA,14524
+ai_edge_torch/generative/layers/attention_test.py,sha256=ON9jQRY1r2kFpVq-Qkg6b13Ob95fd4PqHo1hic3RbOQ,5057
+ai_edge_torch/generative/layers/attention_utils.py,sha256=3Ox1XjW_vaqz1-RuVG9RbzRKUqCberFW8P2BQcoNm7A,9659
+ai_edge_torch/generative/layers/attention_utils_test.py,sha256=IHIk39wqaPvxmkZtW27VD3_4xUpyFow_7mScf8OWdqU,3292
+ai_edge_torch/generative/layers/builder.py,sha256=5QL59CbOOW_mk3mlPdcdirGcAxdLee5atbZlnu5Z3ts,5079
 ai_edge_torch/generative/layers/einsum.py,sha256=LH4CNHr-pFfLUuCpwbYL3GpoAMgHJ4nLju3XCqA4VwM,1416
 ai_edge_torch/generative/layers/einsum_test.py,sha256=ltIE773bvvNLv_9aLQxFwe1MgQ762sez0c5E2tejxuA,1079
 ai_edge_torch/generative/layers/feed_forward.py,sha256=_GmtHxwL068l9gh_F_WFcFk7La-Tl5SfoQ9v2hMabZM,5541
 ai_edge_torch/generative/layers/feed_forward_test.py,sha256=Y5l1eC9NgfYixHcfIfE1W4FGh7oC-9UGGyHdKS9tQKc,1880
 ai_edge_torch/generative/layers/kv_cache.py,sha256=A0IFXZ1HD2ZHOWRLfsDO4almgE0KQfjyBOdBFZIGnAs,10893
 ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
-ai_edge_torch/generative/layers/model_config.py,sha256=
-ai_edge_torch/generative/layers/normalization.py,sha256=
+ai_edge_torch/generative/layers/model_config.py,sha256=Pvoa766jIf1LWvRwEDMNce43C9NgPOvIpT30VUcnpqA,10390
+ai_edge_torch/generative/layers/normalization.py,sha256=syasVh3dRDVp2Nwhl0x7zucL-chTnCqWgeV1mb87DFY,7435
 ai_edge_torch/generative/layers/normalization_test.py,sha256=zwurZly-TgFxdgVVdpzu9vCpcLbd5RYt_gKg9Lfg1jI,2248
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
-ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=
+ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=1zhOsJpI4CTn78weOs0uRwkRxYu6wGfBvYVFpGFl0qQ,6681
 ai_edge_torch/generative/layers/scaled_dot_product_attention_test.py,sha256=c6JBMQsq9XeMmR1XvGEIidNsoh-YIvichXo2LwVHgr4,3301
-ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=
+ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=fK_h9M-03ai5dV8ZyQzvB0y84IKlNg9h-4bt9F6bU0g,3833
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZteHZXK6HKyxYji49DQ46sA9aIy7U3Jnz0HZp6hfevY,28996
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
@@ -211,7 +211,7 @@ ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5l
 ai_edge_torch/generative/utilities/converter.py,sha256=d8pehTq6EzEdVR8ioL2b1ECGTR4G1K1fczc9amu_Oyk,23106
 ai_edge_torch/generative/utilities/export_config.py,sha256=5B15nYyqf96kjjYlHfPctUfsIdsBsh1f8rxKitJpwKQ,2384
 ai_edge_torch/generative/utilities/litertlm_builder.py,sha256=0cNuaqhc7cQcAa4NRalUXyoPQUQC9O3-aHAJEDV1Mps,4265
-ai_edge_torch/generative/utilities/loader.py,sha256=
+ai_edge_torch/generative/utilities/loader.py,sha256=QQeEu0cTC7gWnB7RkHonjWLdVGjMbDHd1lfYO_TcJyU,16047
 ai_edge_torch/generative/utilities/model_builder.py,sha256=xBvcTxihB9TN88UtQiXA9sAITQgf-pA77R-VZlLgUeU,6950
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
@@ -270,8 +270,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
 ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.7.0.
-ai_edge_torch_nightly-0.7.0.
-ai_edge_torch_nightly-0.7.0.
-ai_edge_torch_nightly-0.7.0.
-ai_edge_torch_nightly-0.7.0.
+ai_edge_torch_nightly-0.7.0.dev20251013.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.7.0.dev20251013.dist-info/METADATA,sha256=N7flnuaI5R5i_3F7gRTTt5AM0wUzUgZLRvhpV3XpueQ,2074
+ai_edge_torch_nightly-0.7.0.dev20251013.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.7.0.dev20251013.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.7.0.dev20251013.dist-info/RECORD,,