x-transformers 1.42.6__py3-none-any.whl → 1.42.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- x_transformers/x_transformers.py
+++ x_transformers/x_transformers.py
@@ -1246,6 +1246,7 @@ class Attention(Module):
         rel_pos = None,
         attn_bias = None,
         rotary_pos_emb = None,
+        pos = None, # for custom alibi positions
         prev_attn = None,
         mem = None,
         mem_mask = None,
@@ -1392,7 +1393,14 @@ class Attention(Module):
 
         if exists(rel_pos):
             assert not exists(attn_bias)
-            attn_bias = rel_pos(i, j)
+
+            if exists(pos):
+                assert isinstance(rel_pos, AlibiPositionalBias), 'only alibi allowed for custom positions at the moment'
+                # allow for custom positions to be passed in
+                attn_bias = rel_pos.forward_custom_pos(pos)
+            else:
+                attn_bias = rel_pos(i, j)
+
             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values
 
         # prepare data dependent alibi from forgetting transformers paper, if needed
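The new pos path only applies to ALiBi. For intuition, ALiBi biases each attention logit by a per-head slope times the distance between query and key positions; with a custom pos tensor that distance comes from the supplied positions rather than from a default arange. Below is a minimal illustrative sketch of that bias computation, not the library's forward_custom_pos implementation; the hard-coded slopes are placeholders:

    import torch

    def alibi_bias_from_positions(pos, slopes):
        # pos: (seq,) token positions, possibly non-contiguous
        # slopes: (heads,) per-head ALiBi slopes
        # bias[h, i, j] = -|pos[j] - pos[i]| * slopes[h]
        dist = (pos[None, :] - pos[:, None]).abs()
        return -dist[None, :, :] * slopes[:, None, None]   # (heads, seq, seq)

    # example: tokens at irregular positions (e.g. a packed sequence with a gap)
    pos    = torch.tensor([0, 1, 2, 10, 11, 12]).float()
    slopes = torch.tensor([0.5, 0.25, 0.125, 0.0625])
    bias   = alibi_bias_from_positions(pos, slopes)        # added to the attention logits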
@@ -1843,6 +1851,7 @@ class AttentionLayers(Module):
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None,
+        pos = None,
         attn_bias = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
@@ -1906,7 +1915,9 @@ class AttentionLayers(Module):
             maybe_mem = mems[0] # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
             mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
 
-            pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+            if not exists(pos):
+                pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+
             rotary_pos_emb = self.rotary_pos_emb(pos)
 
         # assume cached key / values
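When no pos is supplied, the fallback above reproduces the previous behaviour: memory tokens are placed at negative positions so the current segment still starts at 0. A quick worked example with a sequence length of 4 and a mem_len of 2:

    import torch

    seq_len, mem_len = 4, 2
    pos = torch.arange(seq_len + mem_len) - mem_len
    # tensor([-2, -1,  0,  1,  2,  3]) -- memory tokens sit at -2 and -1,
    # the current tokens at 0..3, exactly as before this change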
@@ -2030,7 +2041,7 @@ class AttentionLayers(Module):
             # forward depending on layer type
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, return_intermediates = True)
             elif layer_type == 'f':
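Taken together, a caller can now hand AttentionLayers explicit token positions, which are forwarded to every self-attention block. A hedged usage sketch, assuming x-transformers >= 1.42.7 and a Decoder configured with alibi_pos_bias = True (the only relative position scheme the new assertion allows); the position values here are illustrative:

    import torch
    from x_transformers import Decoder

    layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        alibi_pos_bias = True   # rel_pos becomes AlibiPositionalBias, as the assertion requires
    )

    x   = torch.randn(1, 6, 512)                  # (batch, seq, dim) token embeddings
    pos = torch.tensor([0, 1, 2, 10, 11, 12])     # custom, possibly non-contiguous positions

    out = layers(x, pos = pos)                    # pos reaches each attention block's ALiBi bias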
--- x_transformers-1.42.6.dist-info/METADATA
+++ x_transformers-1.42.7.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.6
+Version: 1.42.7
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
--- x_transformers-1.42.6.dist-info/RECORD
+++ x_transformers-1.42.7.dist-info/RECORD
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=cPsSl1s14_c9fMdn9cZwe6Eg3aDbcRyCTsoXUJusWUg,92706
+x_transformers/x_transformers.py,sha256=6jXSMHViCU64gLMbxRJ6C8bgcLrPFbT-m-fhtusqq3g,93117
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.6.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.42.6.dist-info/METADATA,sha256=OANeMK9I504gC7iErAdYMTGBUEl6FOcEwm97o4OyC1k,689
-x_transformers-1.42.6.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-x_transformers-1.42.6.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.42.6.dist-info/RECORD,,
+x_transformers-1.42.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.7.dist-info/METADATA,sha256=tM7s2gIMFH8hy_YZY84BhZ-yUoH6PTyjusK0dMOpTN8,689
+x_transformers-1.42.7.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.42.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.7.dist-info/RECORD,,