x-transformers 1.42.5__py3-none-any.whl → 1.42.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +22 -5
- {x_transformers-1.42.5.dist-info → x_transformers-1.42.7.dist-info}/METADATA +1 -1
- {x_transformers-1.42.5.dist-info → x_transformers-1.42.7.dist-info}/RECORD +6 -6
- {x_transformers-1.42.5.dist-info → x_transformers-1.42.7.dist-info}/LICENSE +0 -0
- {x_transformers-1.42.5.dist-info → x_transformers-1.42.7.dist-info}/WHEEL +0 -0
- {x_transformers-1.42.5.dist-info → x_transformers-1.42.7.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -1055,7 +1055,12 @@ class Attention(Module):
         logit_softclamp_value = 50.,
         neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
         neutreno_alpha = 0.4,
-        onnxable = False
+        onnxable = False,
+        attend_sdp_kwargs: dict = dict(
+            enable_flash = True,
+            enable_math = True,
+            enable_mem_efficient = True
+        )
     ):
         super().__init__()
         dim_kv = default(dim_context, dim)

@@ -1188,7 +1193,8 @@ class Attention(Module):
             softclamp_logits = softclamp_logits,
             logit_softclamp_value = logit_softclamp_value,
             cope = cope,
-            onnxable = onnxable
+            onnxable = onnxable,
+            sdp_kwargs = attend_sdp_kwargs
         )

         # head scaling
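The new attend_sdp_kwargs argument is forwarded to the attention backend as sdp_kwargs (second hunk above); its three flags mirror the backend toggles of torch.backends.cuda.sdp_kernel. A minimal usage sketch, assuming the library's usual attn_ prefix routing from Decoder/AttentionLayers kwargs down to Attention; the model size and hyperparameters here are illustrative only:

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_flash = True,                 # use F.scaled_dot_product_attention
        attn_attend_sdp_kwargs = dict(     # assumed to route to Attention's new attend_sdp_kwargs
            enable_flash = True,
            enable_math = False,           # e.g. forbid the slow math fallback
            enable_mem_efficient = True
        )
    )
)

tokens = torch.randint(0, 20000, (1, 1024))
logits = model(tokens)                     # (1, 1024, 20000)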
@@ -1240,6 +1246,7 @@ class Attention(Module):
         rel_pos = None,
         attn_bias = None,
         rotary_pos_emb = None,
+        pos = None, # for custom alibi positions
         prev_attn = None,
         mem = None,
         mem_mask = None,

@@ -1386,7 +1393,14 @@ class Attention(Module):

         if exists(rel_pos):
             assert not exists(attn_bias)
-            attn_bias = rel_pos(i, j)
+
+            if exists(pos):
+                assert isinstance(rel_pos, AlibiPositionalBias), 'only alibi allowed for custom positions at the moment'
+                # allow for custom positions to be passed in
+                attn_bias = rel_pos.forward_custom_pos(pos)
+            else:
+                attn_bias = rel_pos(i, j)
+
             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values

         # prepare data dependent alibi from forgetting transformers paper, if needed
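This branch lets callers supply explicit token positions and have the ALiBi bias computed from them via forward_custom_pos, instead of from the default contiguous 0..n-1 range. The sketch below is an illustration of that idea only, not x-transformers' AlibiPositionalBias implementation: an ALiBi-style additive bias built from arbitrary integer positions, where more distant position pairs receive a more negative bias.

import torch

def alibi_bias_from_positions(pos, heads):
    # illustration only — not the library's implementation
    # pos: (seq,) integer positions, possibly non-contiguous
    # returns: (heads, seq, seq) additive attention bias
    slopes = torch.tensor([2 ** (-8 * (h + 1) / heads) for h in range(heads)])
    dist = (pos[:, None] - pos[None, :]).abs().float()   # |pos_i - pos_j|
    return -slopes[:, None, None] * dist[None, :, :]     # farther apart -> more negative

bias = alibi_bias_from_positions(torch.tensor([0, 1, 5, 6]), heads = 8)
print(bias.shape)   # torch.Size([8, 4, 4])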
@@ -1837,6 +1851,7 @@ class AttentionLayers(Module):
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None,
+        pos = None,
         attn_bias = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090

@@ -1900,7 +1915,9 @@ class AttentionLayers(Module):
             maybe_mem = mems[0] # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
             mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0

-            pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+            if not exists(pos):
+                pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+
             rotary_pos_emb = self.rotary_pos_emb(pos)

             # assume cached key / values

@@ -2024,7 +2041,7 @@ class AttentionLayers(Module):
            # forward depending on layer type

            if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, return_intermediates = True)
             elif layer_type == 'f':
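AttentionLayers.forward now accepts pos and threads it into each self-attention block, falling back to the previous torch.arange default when it is not provided. A minimal end-to-end sketch, assuming ALiBi is enabled via alibi_pos_bias and that a 1-D LongTensor of positions (mirroring the default arange) is accepted:

import torch
from x_transformers import Decoder

decoder = Decoder(
    dim = 512,
    depth = 2,
    heads = 8,
    alibi_pos_bias = True          # rel_pos becomes an AlibiPositionalBias
)

x = torch.randn(1, 4, 512)
pos = torch.tensor([0, 1, 5, 6])   # custom, non-contiguous positions (illustrative)

out = decoder(x, pos = pos)        # per this diff, the bias is built via rel_pos.forward_custom_pos(pos)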
{x_transformers-1.42.5.dist-info → x_transformers-1.42.7.dist-info}/RECORD
CHANGED

@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=6jXSMHViCU64gLMbxRJ6C8bgcLrPFbT-m-fhtusqq3g,93117
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
+x_transformers-1.42.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.7.dist-info/METADATA,sha256=tM7s2gIMFH8hy_YZY84BhZ-yUoH6PTyjusK0dMOpTN8,689
+x_transformers-1.42.7.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.42.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.7.dist-info/RECORD,,

{x_transformers-1.42.5.dist-info → x_transformers-1.42.7.dist-info}/LICENSE
File without changes

{x_transformers-1.42.5.dist-info → x_transformers-1.42.7.dist-info}/WHEEL
File without changes

{x_transformers-1.42.5.dist-info → x_transformers-1.42.7.dist-info}/top_level.txt
File without changes