x-transformers 1.42.6__py3-none-any.whl → 1.42.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +14 -3
- {x_transformers-1.42.6.dist-info → x_transformers-1.42.7.dist-info}/METADATA +1 -1
- {x_transformers-1.42.6.dist-info → x_transformers-1.42.7.dist-info}/RECORD +6 -6
- {x_transformers-1.42.6.dist-info → x_transformers-1.42.7.dist-info}/LICENSE +0 -0
- {x_transformers-1.42.6.dist-info → x_transformers-1.42.7.dist-info}/WHEEL +0 -0
- {x_transformers-1.42.6.dist-info → x_transformers-1.42.7.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -1246,6 +1246,7 @@ class Attention(Module):
         rel_pos = None,
         attn_bias = None,
         rotary_pos_emb = None,
+        pos = None, # for custom alibi positions
         prev_attn = None,
         mem = None,
         mem_mask = None,
@@ -1392,7 +1393,14 @@ class Attention(Module):

         if exists(rel_pos):
             assert not exists(attn_bias)
-            attn_bias = rel_pos(i, j)
+
+            if exists(pos):
+                assert isinstance(rel_pos, AlibiPositionalBias), 'only alibi allowed for custom positions at the moment'
+                # allow for custom positions to be passed in
+                attn_bias = rel_pos.forward_custom_pos(pos)
+            else:
+                attn_bias = rel_pos(i, j)
+
             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values

         # prepare data dependent alibi from forgetting transformers paper, if needed
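The new branch lets a caller supply explicit token positions: when pos is given, the bias is built with rel_pos.forward_custom_pos(pos) instead of rel_pos(i, j), which derives positions from the query/key lengths. The implementation of forward_custom_pos is not part of this diff; the sketch below only illustrates the standard ALiBi form such a bias takes, bias[h, i, j] = -slope_h * |pos_j - pos_i|, using a hypothetical helper name rather than the library's method.

import torch

def alibi_bias_from_positions(pos, slopes):
    # hypothetical illustration, not x-transformers' forward_custom_pos
    # pos: (n,) token positions, not necessarily contiguous
    # slopes: (heads,) per-head ALiBi slopes
    # returns: (heads, n, n) bias to be added to the attention logits
    rel_dist = (pos[None, :] - pos[:, None]).abs()   # pairwise |pos_j - pos_i|
    return -slopes[:, None, None] * rel_dist         # scale by the negative per-head slope

slopes = 2.0 ** -torch.arange(1, 9).float()          # geometric slope schedule for 8 heads
pos = torch.tensor([0., 1., 2., 10., 11., 12.])      # positions with a gap
bias = alibi_bias_from_positions(pos, slopes)        # shape (8, 6, 6)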
@@ -1843,6 +1851,7 @@ class AttentionLayers(Module):
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None,
+        pos = None,
         attn_bias = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
@@ -1906,7 +1915,9 @@ class AttentionLayers(Module):
             maybe_mem = mems[0] # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
             mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0

-            pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+            if not exists(pos):
+                pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+
             rotary_pos_emb = self.rotary_pos_emb(pos)

         # assume cached key / values
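When no pos is supplied, the wrapped assignment reproduces the old default: positions cover the memory tokens plus the current sequence, shifted so memories sit at negative positions and the first current token at 0. A quick check of that arithmetic with illustrative sizes:

import torch

mem_len, seq_len = 2, 4
pos = torch.arange(seq_len + mem_len) - mem_len
print(pos)   # tensor([-2, -1,  0,  1,  2,  3])  memories at negative positions, current tokens from 0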
@@ -2030,7 +2041,7 @@ class AttentionLayers(Module):
            # forward depending on layer type

            if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
            elif layer_type == 'c':
                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, return_intermediates = True)
            elif layer_type == 'f':
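Taken together, the change threads a user-supplied pos from AttentionLayers.forward into each self-attention block, where it feeds the ALiBi bias (and the rotary embedding default shown above). A minimal usage sketch, assuming the Decoder constructor flag documented in the project README (alibi_pos_bias) and a 1-D pos of length equal to the sequence; whether batched position tensors are accepted is not shown in this diff:

import torch
from x_transformers import Decoder

# decoder stack with the ALiBi relative position bias enabled
layers = Decoder(
    dim = 512,
    depth = 2,
    heads = 8,
    alibi_pos_bias = True    # installs AlibiPositionalBias as rel_pos
)

x = torch.randn(1, 6, 512)                    # (batch, seq, dim) token embeddings
pos = torch.tensor([0, 1, 2, 10, 11, 12])     # non-contiguous custom positions, one per token

out = layers(x, pos = pos)                    # 1.42.7 routes pos to rel_pos.forward_custom_pos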
{x_transformers-1.42.6.dist-info → x_transformers-1.42.7.dist-info}/RECORD
CHANGED
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=6jXSMHViCU64gLMbxRJ6C8bgcLrPFbT-m-fhtusqq3g,93117
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
+x_transformers-1.42.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.7.dist-info/METADATA,sha256=tM7s2gIMFH8hy_YZY84BhZ-yUoH6PTyjusK0dMOpTN8,689
+x_transformers-1.42.7.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.42.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.7.dist-info/RECORD,,
{x_transformers-1.42.6.dist-info → x_transformers-1.42.7.dist-info}/LICENSE
File without changes
{x_transformers-1.42.6.dist-info → x_transformers-1.42.7.dist-info}/WHEEL
File without changes
{x_transformers-1.42.6.dist-info → x_transformers-1.42.7.dist-info}/top_level.txt
File without changes