x-transformers 1.44.5.tar.gz → 1.44.6.tar.gz

Files changed (22)
  1. {x_transformers-1.44.5/x_transformers.egg-info → x_transformers-1.44.6}/PKG-INFO +1 -1
  2. {x_transformers-1.44.5 → x_transformers-1.44.6}/setup.py +1 -1
  3. {x_transformers-1.44.5 → x_transformers-1.44.6}/x_transformers/x_transformers.py +8 -0
  4. {x_transformers-1.44.5 → x_transformers-1.44.6/x_transformers.egg-info}/PKG-INFO +1 -1
  5. {x_transformers-1.44.5 → x_transformers-1.44.6}/LICENSE +0 -0
  6. {x_transformers-1.44.5 → x_transformers-1.44.6}/README.md +0 -0
  7. {x_transformers-1.44.5 → x_transformers-1.44.6}/setup.cfg +0 -0
  8. {x_transformers-1.44.5 → x_transformers-1.44.6}/tests/test_x_transformers.py +0 -0
  9. {x_transformers-1.44.5 → x_transformers-1.44.6}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.44.5 → x_transformers-1.44.6}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.44.5 → x_transformers-1.44.6}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.44.5 → x_transformers-1.44.6}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.44.5 → x_transformers-1.44.6}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.44.5 → x_transformers-1.44.6}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.44.5 → x_transformers-1.44.6}/x_transformers/neo_mlp.py +0 -0
  16. {x_transformers-1.44.5 → x_transformers-1.44.6}/x_transformers/nonautoregressive_wrapper.py +0 -0
  17. {x_transformers-1.44.5 → x_transformers-1.44.6}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  18. {x_transformers-1.44.5 → x_transformers-1.44.6}/x_transformers/xval.py +0 -0
  19. {x_transformers-1.44.5 → x_transformers-1.44.6}/x_transformers.egg-info/SOURCES.txt +0 -0
  20. {x_transformers-1.44.5 → x_transformers-1.44.6}/x_transformers.egg-info/dependency_links.txt +0 -0
  21. {x_transformers-1.44.5 → x_transformers-1.44.6}/x_transformers.egg-info/requires.txt +0 -0
  22. {x_transformers-1.44.5 → x_transformers-1.44.6}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.44.5/x_transformers.egg-info → x_transformers-1.44.6}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: x-transformers
- Version: 1.44.5
+ Version: 1.44.6
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
{x_transformers-1.44.5 → x_transformers-1.44.6}/setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
  setup(
    name = 'x-transformers',
    packages = find_packages(exclude=['examples']),
-   version = '1.44.5',
+   version = '1.44.6',
    license='MIT',
    description = 'X-Transformers - Pytorch',
    author = 'Phil Wang',
{x_transformers-1.44.5 → x_transformers-1.44.6}/x_transformers/x_transformers.py

@@ -1736,6 +1736,7 @@ class AttentionLayers(Module):
          unet_skips = False,
          num_residual_streams = 1,
          reinject_input = False, # seen first in the DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
+         learned_reinject_input_gate = False,
          add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 - further corroboration by https://arxiv.org/abs/2412.15113 (faster emergence of ICL) - looks like this setting may become a necessity for every transformer soon
          learned_value_residual_mix = True, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependence. here we will use per-token learned
          rel_pos_kwargs: dict = dict(),
@@ -1993,6 +1994,7 @@ class AttentionLayers(Module):

          self.reinject_input = reinject_input
          self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None
+         self.learned_reinject_input_gate = nn.Linear(dim, 1, bias = False) if learned_reinject_input_gate else None

          # add the value from the first self attention block to all later projected self attention values as a residual

@@ -2224,7 +2226,9 @@ class AttentionLayers(Module):
          layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)

          # derived input for reinjection if needed
+
          inp_inject = None
+
          if self.reinject_input:
              assert not exists(in_attn_cond)
              inp_inject = self.reinject_input_proj(x)
@@ -2233,6 +2237,10 @@ class AttentionLayers(Module):
              # handle in-attention conditioning, which serves the same purpose of having the network learn the residual
              inp_inject = in_attn_cond if in_attn_cond.ndim == 3 else rearrange(in_attn_cond, 'b d -> b 1 d')

+         if exists(inp_inject) and exists(self.learned_reinject_input_gate):
+             inp_inject_gate = self.learned_reinject_input_gate(x).sigmoid()
+             inp_inject = inp_inject * inp_inject_gate
+
          # store all hiddens for skips

          skip_hiddens = []
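Taken together, 1.44.6 adds an optional learned gate on the reinjected input: a bias-free linear layer maps each token's hidden state to a single logit, and the sigmoid of that logit scales the reinjected stream per token. A minimal sketch of the mechanism in isolation (the standalone setup and tensor shapes are illustrative assumptions; the module names mirror the diff):

    import torch
    from torch import nn

    dim = 512

    # mirror self.reinject_input_proj and self.learned_reinject_input_gate from the diff
    reinject_input_proj = nn.Linear(dim, dim, bias = False)
    learned_reinject_input_gate = nn.Linear(dim, 1, bias = False)

    x = torch.randn(2, 1024, dim)                    # (batch, seq, dim) embeddings entering the layer stack

    inp_inject = reinject_input_proj(x)              # derived input for reinjection
    gate = learned_reinject_input_gate(x).sigmoid()  # (batch, seq, 1), per-token scalar in (0, 1)
    inp_inject = inp_inject * gate                   # gate broadcasts over the feature dimension

Assuming the new kwarg is forwarded through Decoder the same way the other AttentionLayers options are, enabling it would look like:

    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 12,
            heads = 8,
            reinject_input = True,                # the gate only applies when there is an input to reinject
            learned_reinject_input_gate = True    # new in 1.44.6
        )
    )

Since the gate projection is bias-free, its logits start near zero and the gate near 0.5, letting the network learn per token where reinjection helps and where it should be suppressed.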
{x_transformers-1.44.5 → x_transformers-1.44.6/x_transformers.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: x-transformers
- Version: 1.44.5
+ Version: 1.44.6
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang