x-transformers 1.42.5__py3-none-any.whl → 1.42.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/x_transformers.py

@@ -1055,7 +1055,12 @@ class Attention(Module):
         logit_softclamp_value = 50.,
         neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
         neutreno_alpha = 0.4,
-        onnxable = False
+        onnxable = False,
+        attend_sdp_kwargs: dict = dict(
+            enable_flash = True,
+            enable_math = True,
+            enable_mem_efficient = True
+        )
     ):
         super().__init__()
         dim_kv = default(dim_context, dim)
@@ -1188,7 +1193,8 @@ class Attention(Module):
             softclamp_logits = softclamp_logits,
             logit_softclamp_value = logit_softclamp_value,
             cope = cope,
-            onnxable = onnxable
+            onnxable = onnxable,
+            sdp_kwargs = attend_sdp_kwargs
         )

         # head scaling
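
The new attend_sdp_kwargs argument is forwarded to the inner Attend module as sdp_kwargs, so callers can now choose which PyTorch scaled-dot-product-attention backends are allowed. Below is a minimal usage sketch, assuming Attention is imported from x_transformers.x_transformers; the tensor shapes and the choice to disable the math backend are illustrative, not taken from the diff.

import torch
from x_transformers.x_transformers import Attention

attn = Attention(
    dim = 512,
    heads = 8,
    attend_sdp_kwargs = dict(      # forwarded to Attend as sdp_kwargs (new in 1.42.7)
        enable_flash = True,
        enable_math = False,       # illustrative: disallow the fallback math kernel
        enable_mem_efficient = True
    )
)

x = torch.randn(2, 1024, 512)                              # (batch, seq, dim)
out, intermediates = attn(x, return_intermediates = True)  # mirrors how AttentionLayers calls the block
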
@@ -1240,6 +1246,7 @@ class Attention(Module):
         rel_pos = None,
         attn_bias = None,
         rotary_pos_emb = None,
+        pos = None, # for custom alibi positions
         prev_attn = None,
         mem = None,
         mem_mask = None,
@@ -1386,7 +1393,14 @@ class Attention(Module):

         if exists(rel_pos):
             assert not exists(attn_bias)
-            attn_bias = rel_pos(i, j)
+
+            if exists(pos):
+                assert isinstance(rel_pos, AlibiPositionalBias), 'only alibi allowed for custom positions at the moment'
+                # allow for custom positions to be passed in
+                attn_bias = rel_pos.forward_custom_pos(pos)
+            else:
+                attn_bias = rel_pos(i, j)
+
             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values

         # prepare data dependent alibi from forgetting transformers paper, if needed
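
For intuition, the alibi bias produced from custom positions is each head's slope times the negative pairwise distance between positions, so non-contiguous position ids (e.g. packed documents) yield the correct distances. The snippet below is a conceptual sketch of that computation, not the library's forward_custom_pos implementation, and the slope values are made up for illustration.

import torch

def alibi_bias_from_positions(pos, slopes):
    # pos:    (seq,) position ids, not necessarily contiguous
    # slopes: (heads,) per-head alibi slopes
    dist = (pos[:, None] - pos[None, :]).abs()                 # (seq, seq) pairwise distances
    return -dist[None, :, :].float() * slopes[:, None, None]   # (heads, seq, seq), more negative when farther apart

pos = torch.tensor([0, 1, 2, 10, 11, 12])          # e.g. two packed segments with a gap in positions
slopes = torch.tensor([0.5, 0.25, 0.125, 0.0625])  # illustrative slopes for 4 heads
bias = alibi_bias_from_positions(pos, slopes)      # added to the attention logits as attn_bias
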
@@ -1837,6 +1851,7 @@ class AttentionLayers(Module):
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None,
+        pos = None,
         attn_bias = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
@@ -1900,7 +1915,9 @@ class AttentionLayers(Module):
             maybe_mem = mems[0] # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
             mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0

-            pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+            if not exists(pos):
+                pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+
             rotary_pos_emb = self.rotary_pos_emb(pos)

         # assume cached key / values
@@ -2024,7 +2041,7 @@ class AttentionLayers(Module):
            # forward depending on layer type

            if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
            elif layer_type == 'c':
                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, return_intermediates = True)
            elif layer_type == 'f':
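
Taken together, these hunks thread a user-supplied pos from AttentionLayers.forward down into every self-attention block, where it is only honored for alibi relative position bias. A minimal end-to-end sketch, assuming a Decoder configured with the pre-existing alibi_pos_bias option and fed already-embedded tokens; the dimensions and position values are illustrative.

import torch
from x_transformers import Decoder

layers = Decoder(
    dim = 512,
    depth = 2,
    heads = 8,
    alibi_pos_bias = True        # rel_pos becomes AlibiPositionalBias, as the assert above requires
)

x = torch.randn(1, 6, 512)                    # (batch, seq, dim) already-embedded tokens
pos = torch.tensor([0, 1, 2, 10, 11, 12])     # custom positions, e.g. packed sequences with a gap

out = layers(x, pos = pos)                    # new in 1.42.7: overrides the default torch.arange positions
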
x_transformers-1.42.5.dist-info/METADATA → x_transformers-1.42.7.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.5
+Version: 1.42.7
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.42.5.dist-info/RECORD → x_transformers-1.42.7.dist-info/RECORD

@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=KIR7efx59xl0BVshU1e6RO0YKgz7zeYBXITDNYWJ4mQ,92506
+x_transformers/x_transformers.py,sha256=6jXSMHViCU64gLMbxRJ6C8bgcLrPFbT-m-fhtusqq3g,93117
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.42.5.dist-info/METADATA,sha256=cLnay5nt6F6GKdghsaqHiZQsVmJ9dS5l-IDozZIs3ec,689
-x_transformers-1.42.5.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-x_transformers-1.42.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.42.5.dist-info/RECORD,,
+x_transformers-1.42.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.7.dist-info/METADATA,sha256=tM7s2gIMFH8hy_YZY84BhZ-yUoH6PTyjusK0dMOpTN8,689
+x_transformers-1.42.7.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.42.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.7.dist-info/RECORD,,