x-transformers 1.42.25.tar.gz → 1.42.27.tar.gz

Files changed (22)
  1. {x_transformers-1.42.25/x_transformers.egg-info → x_transformers-1.42.27}/PKG-INFO +1 -1
  2. {x_transformers-1.42.25 → x_transformers-1.42.27}/setup.py +1 -1
  3. {x_transformers-1.42.25 → x_transformers-1.42.27}/tests/test_x_transformers.py +6 -2
  4. {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/x_transformers.py +4 -5
  5. {x_transformers-1.42.25 → x_transformers-1.42.27/x_transformers.egg-info}/PKG-INFO +1 -1
  6. {x_transformers-1.42.25 → x_transformers-1.42.27}/LICENSE +0 -0
  7. {x_transformers-1.42.25 → x_transformers-1.42.27}/README.md +0 -0
  8. {x_transformers-1.42.25 → x_transformers-1.42.27}/setup.cfg +0 -0
  9. {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/neo_mlp.py +0 -0
  16. {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/nonautoregressive_wrapper.py +0 -0
  17. {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  18. {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/xval.py +0 -0
  19. {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers.egg-info/SOURCES.txt +0 -0
  20. {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers.egg-info/dependency_links.txt +0 -0
  21. {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers.egg-info/requires.txt +0 -0
  22. {x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.42.25/x_transformers.egg-info → x_transformers-1.42.27}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.25
+Version: 1.42.27
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
{x_transformers-1.42.25 → x_transformers-1.42.27}/setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.42.25',
+  version = '1.42.27',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
{x_transformers-1.42.25 → x_transformers-1.42.27}/tests/test_x_transformers.py

@@ -558,8 +558,10 @@ def test_laser():
 
     model(x)
 
+@pytest.mark.parametrize('self_attn_custom_pos', (True, False))
 @pytest.mark.parametrize('cross_attn_rotary', (True, False))
 def test_cross_attn_rotary(
+    self_attn_custom_pos: bool,
     cross_attn_rotary: bool
 ):
 
@@ -577,12 +579,14 @@ def test_cross_attn_rotary(
         cross_attn_dim_context = 512
     )
 
-    context_pos = torch.arange(128)
+    pos = torch.arange(64) if self_attn_custom_pos else None
+    context_pos = torch.arange(128) if cross_attn_rotary else None
 
     embed = model(
         x = x,
         mask = mask,
         context = context,
-        context_pos = context_pos if cross_attn_rotary else None,
+        pos = pos,
+        context_pos = context_pos,
         context_mask = context_mask
     )
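As context for the test change: the new self_attn_custom_pos parameter exercises passing an explicit pos tensor for the self-attention rotary embeddings alongside the existing context_pos for the cross-attended context. A minimal sketch of the same call follows; the model construction is assumed (only cross_attn_dim_context = 512 and the keyword arguments are confirmed by the diff above).

    import torch
    from x_transformers import Encoder

    # assumed construction; the diff only confirms cross_attn_dim_context = 512
    model = Encoder(
        dim = 256,
        depth = 2,
        heads = 4,
        rotary_pos_emb = True,        # rotary embeddings for self attention
        cross_attend = True,          # enable cross attention blocks
        cross_attn_dim_context = 512
    )

    x = torch.randn(1, 64, 256)
    mask = torch.ones(1, 64).bool()
    context = torch.randn(1, 128, 512)
    context_mask = torch.ones(1, 128).bool()

    pos = torch.arange(64)           # custom positions for self-attention rotary embeddings
    context_pos = torch.arange(128)  # custom positions for the cross-attended context

    embed = model(
        x = x,
        mask = mask,
        context = context,
        pos = pos,
        context_pos = context_pos,
        context_mask = context_mask
    )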
{x_transformers-1.42.25 → x_transformers-1.42.27}/x_transformers/x_transformers.py

@@ -51,8 +51,8 @@ def default(val, d):
         return val
     return d() if callable(d) else d
 
-def first(it):
-    return it[0]
+def first(it, default = None):
+    return it[0] if len(it) > 0 else default
 
 def is_empty(x):
     return len(x) == 0
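The updated helper no longer assumes a non-empty sequence: an empty input now returns the (optional) default instead of raising, e.g.

    def first(it, default = None):
        # first element of a sequence, or `default` when the sequence is empty
        return it[0] if len(it) > 0 else default

    assert first([3, 2, 1]) == 3
    assert first([]) is None            # the previous definition raised IndexError here
    assert first([], default = 0) == 0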
@@ -1077,7 +1077,7 @@ class Attention(Module):
         logit_softclamp_value = 50.,
         neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
         neutreno_alpha = 0.4,
-        learned_value_residual_mix = False,
+        learned_value_residual_mix = True,
         laser = False, # https://arxiv.org/abs/2411.03493v1
         laser_softclamp_value = 15.,
         onnxable = False,
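The flipped default appears to opt into a learned value-residual mixing coefficient whenever value residuals are in use. If you construct the inner Attention module directly and want the previous behavior, the old default can still be passed explicitly; a sketch, with placeholder dim and heads values:

    from x_transformers.x_transformers import Attention

    # keeps the pre-1.42.27 behavior of a non-learned value residual mix
    attn = Attention(
        dim = 512,
        heads = 8,
        learned_value_residual_mix = False
    )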
@@ -1357,7 +1357,6 @@ class Attention(Module):
             k = k * self.qk_norm_k_scale
 
         if exists(rotary_pos_emb):
-
             freqs, xpos_scale = rotary_pos_emb
             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
 
@@ -1989,7 +1988,7 @@ class AttentionLayers(Module):
 
         if exists(self.rotary_pos_emb):
             if not exists(rotary_pos_emb):
-                maybe_mem = mems[0] # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
+                maybe_mem = first(mems, None) # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
                 mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
 
                 if not exists(pos):
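Together with the first() change above, an empty mems list now yields maybe_mem = None and a mem_len of 0 instead of an IndexError from mems[0]; schematically:

    from x_transformers.x_transformers import exists, first

    mems = []  # no per-layer memories supplied

    maybe_mem = first(mems, None)  # None, where mems[0] would have raised IndexError
    mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
    assert mem_len == 0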
{x_transformers-1.42.25 → x_transformers-1.42.27/x_transformers.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.25
+Version: 1.42.27
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang