x-transformers 1.42.24.tar.gz → 1.42.26.tar.gz

Files changed (22)
  1. {x_transformers-1.42.24/x_transformers.egg-info → x_transformers-1.42.26}/PKG-INFO +1 -1
  2. {x_transformers-1.42.24 → x_transformers-1.42.26}/README.md +2 -0
  3. {x_transformers-1.42.24 → x_transformers-1.42.26}/setup.py +1 -1
  4. {x_transformers-1.42.24 → x_transformers-1.42.26}/tests/test_x_transformers.py +33 -0
  5. {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers/x_transformers.py +33 -11
  6. {x_transformers-1.42.24 → x_transformers-1.42.26/x_transformers.egg-info}/PKG-INFO +1 -1
  7. {x_transformers-1.42.24 → x_transformers-1.42.26}/LICENSE +0 -0
  8. {x_transformers-1.42.24 → x_transformers-1.42.26}/setup.cfg +0 -0
  9. {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers/neo_mlp.py +0 -0
  16. {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers/nonautoregressive_wrapper.py +0 -0
  17. {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  18. {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers/xval.py +0 -0
  19. {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers.egg-info/SOURCES.txt +0 -0
  20. {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers.egg-info/dependency_links.txt +0 -0
  21. {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers.egg-info/requires.txt +0 -0
  22. {x_transformers-1.42.24 → x_transformers-1.42.26}/x_transformers.egg-info/top_level.txt +0 -0
--- x_transformers-1.42.24/x_transformers.egg-info/PKG-INFO
+++ x_transformers-1.42.26/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.24
+Version: 1.42.26
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
--- x_transformers-1.42.24/README.md
+++ x_transformers-1.42.26/README.md
@@ -317,6 +317,8 @@ model = TransformerWrapper(
 
 Update: MetaAI researchers <a href="https://arxiv.org/abs/2309.16588">have found</a> that adding memory tokens (they call them register tokens), alleviates outliers (which is suspected now to be a pathology of attention networks unable to <a href="https://arxiv.org/abs/2306.12929">attend to nothing</a>).
 
+Update 2: a hybrid architecture out of Nvidia named <a href="https://openreview.net/forum?id=A1ztozypga">Hymba</a> used memory tokens successfully in the autoregressive case, termed meta tokens in their paper
+
 ### Transformers Without Tears
 
 <img src="./images/scalenorm.png"></img>
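For reference, the README passage amended above concerns the memory (register) tokens feature, which is enabled through the `num_memory_tokens` argument of `TransformerWrapper` (a minimal sketch mirroring the existing README example):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    num_memory_tokens = 20,   # learned memory / register tokens prepended to the sequence
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)  # (1, 1024, 20000) - memory tokens are stripped from the output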
--- x_transformers-1.42.24/setup.py
+++ x_transformers-1.42.26/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.42.24',
+  version = '1.42.26',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
--- x_transformers-1.42.24/tests/test_x_transformers.py
+++ x_transformers-1.42.26/tests/test_x_transformers.py
@@ -557,3 +557,36 @@ def test_laser():
     x = torch.randint(0, 20000, (2, 1024))
 
     model(x)
+
+@pytest.mark.parametrize('self_attn_custom_pos', (True, False))
+@pytest.mark.parametrize('cross_attn_rotary', (True, False))
+def test_cross_attn_rotary(
+    self_attn_custom_pos: bool,
+    cross_attn_rotary: bool
+):
+
+    x = torch.randn((1, 64, 256))
+    mask = torch.ones((1, 64)).bool()
+    context = torch.randn((1, 128, 512))
+    context_mask = torch.ones((1, 128)).bool()
+
+    model = Encoder(
+        dim = 256,
+        depth = 4,
+        heads = 4,
+        rotary_pos_emb = True,
+        cross_attend = True,
+        cross_attn_dim_context = 512
+    )
+
+    pos = torch.arange(64) if self_attn_custom_pos else None
+    context_pos = torch.arange(128) if cross_attn_rotary else None
+
+    embed = model(
+        x = x,
+        mask = mask,
+        context = context,
+        pos = pos,
+        context_pos = context_pos,
+        context_mask = context_mask
+    )
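Outside the test harness, the behaviour exercised above boils down to the following usage pattern (a sketch derived from the new test; shapes and hyperparameters are arbitrary):

import torch
from x_transformers import Encoder

enc = Encoder(
    dim = 256,
    depth = 2,
    heads = 4,
    rotary_pos_emb = True,         # rotary embeddings for self attention
    cross_attend = True,           # add cross attention blocks
    cross_attn_dim_context = 512   # context may have a different dimension
)

x = torch.randn(1, 64, 256)
context = torch.randn(1, 128, 512)

# new in this release: passing context positions applies rotary embeddings
# to the cross attention keys as well
embed = enc(x, context = context, context_pos = torch.arange(128))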
--- x_transformers-1.42.24/x_transformers/x_transformers.py
+++ x_transformers-1.42.26/x_transformers/x_transformers.py
@@ -51,8 +51,8 @@ def default(val, d):
         return val
     return d() if callable(d) else d
 
-def first(it):
-    return it[0]
+def first(it, default = None):
+    return it[0] if len(it) > 0 else default
 
 def is_empty(x):
     return len(x) == 0
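The helper is now total over empty sequences, which the rotary-position change further down relies on when no memories are passed (`first(mems, None)`); a quick illustration:

def first(it, default = None):
    return it[0] if len(it) > 0 else default

assert first([3, 4]) == 3
assert first([], None) is None   # previously raised IndexError
assert first([], 0) == 0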
@@ -1284,6 +1284,7 @@ class Attention(Module):
         rel_pos = None,
         attn_bias = None,
         rotary_pos_emb = None,
+        context_rotary_pos_emb = None,
         pos = None, # for custom alibi positions
         prev_attn = None,
         mem = None,
@@ -1355,11 +1356,18 @@ class Attention(Module):
             q = q * self.qk_norm_q_scale
             k = k * self.qk_norm_k_scale
 
-        if exists(rotary_pos_emb) and not has_context:
+        if exists(rotary_pos_emb):
             freqs, xpos_scale = rotary_pos_emb
             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
 
             q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)
+
+            if has_context:
+                # override with `context_rotary_pos_emb` if provided
+
+                freqs, xpos_scale = context_rotary_pos_emb
+                _, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
+
             k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)
 
         if self.rotary_embed_values:
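In other words, when cross attending, queries are now rotated with the self-attention frequencies while keys are rotated with frequencies derived from the context positions, so attention scores depend on the relative offset between the two sequences. A self-contained sketch of the underlying rotate-half rotary embedding (illustrative only; the library's `apply_rotary_pos_emb` additionally handles details such as xpos scaling and partial rotary dimensions):

import torch

def rope(t, pos, dim_head = 64, theta = 10000):
    # standard rotate-half rotary embedding, written out for illustration
    inv_freq = 1. / (theta ** (torch.arange(0, dim_head, 2).float() / dim_head))
    freqs = torch.einsum('i, j -> i j', pos.float(), inv_freq)
    freqs = torch.cat((freqs, freqs), dim = -1)
    t1, t2 = t.chunk(2, dim = -1)
    rotated = torch.cat((-t2, t1), dim = -1)
    return t * freqs.cos() + rotated * freqs.sin()

q = rope(torch.randn(64, 64), torch.arange(64))     # queries use the self positions
k = rope(torch.randn(128, 64), torch.arange(128))   # keys use the context positions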
@@ -1848,7 +1856,6 @@ class AttentionLayers(Module):
                 layer = Attention(dim, heads = heads, causal = causal, learned_value_residual_mix = self_attn_learned_value_residual, **attn_kwargs)
                 is_first_self_attn = False
             elif layer_type == 'c':
-                cross_attn_learned_value_residual = learned_value_residual_mix and not is_first_cross_attn
                 layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
                 is_first_cross_attn = False
             elif layer_type == 'f':
@@ -1917,6 +1924,7 @@ class AttentionLayers(Module):
         return_hiddens = False,
         rotary_pos_emb = None,
         pos = None,
+        context_pos = None,
         attn_bias = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
@@ -1976,14 +1984,28 @@ class AttentionLayers(Module):
 
         # rotary positions
 
-        if not exists(rotary_pos_emb) and exists(self.rotary_pos_emb):
-            maybe_mem = mems[0] # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
-            mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
+        cross_attn_rotary_pos_emb = dict()
+
+        if exists(self.rotary_pos_emb):
+            if not exists(rotary_pos_emb):
+                maybe_mem = first(mems, None) # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
+                mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
+
+                if not exists(pos):
+                    pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+
+                rotary_pos_emb = self.rotary_pos_emb(pos)
+
+            # allow for rotary positions for context if provided
 
-            if not exists(pos):
-                pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+            if exists(context_pos):
+                assert self.cross_attend
+                context_rotary_pos_emb = self.rotary_pos_emb(context_pos)
 
-            rotary_pos_emb = self.rotary_pos_emb(pos)
+                cross_attn_rotary_pos_emb.update(
+                    rotary_pos_emb = rotary_pos_emb,
+                    context_rotary_pos_emb = context_rotary_pos_emb
+                )
 
         # assume cached key / values
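Note the convention for the default positions: prepended memories occupy negative positions so that the real tokens still start at zero, e.g.:

import torch

seq_len, mem_len = 6, 2
pos = torch.arange(seq_len + mem_len) - mem_len
# tensor([-2, -1,  0,  1,  2,  3,  4,  5])

When `context_pos` is supplied, both rotary embeddings are packed into `cross_attn_rotary_pos_emb` and forwarded to the cross attention blocks at the call site changed below.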
 
@@ -2108,7 +2130,7 @@ class AttentionLayers(Module):
             if layer_type == 'a':
                 out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
-                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
                 out = block(x)
 
--- x_transformers-1.42.24/PKG-INFO
+++ x_transformers-1.42.26/x_transformers.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.24
+Version: 1.42.26
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang