x-transformers 1.40.3__tar.gz → 1.40.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. {x_transformers-1.40.3/x_transformers.egg-info → x_transformers-1.40.4}/PKG-INFO +1 -1
  2. {x_transformers-1.40.3 → x_transformers-1.40.4}/README.md +11 -0
  3. {x_transformers-1.40.3 → x_transformers-1.40.4}/setup.py +1 -1
  4. {x_transformers-1.40.3 → x_transformers-1.40.4}/tests/test_x_transformers.py +18 -0
  5. {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers/x_transformers.py +14 -0
  6. {x_transformers-1.40.3 → x_transformers-1.40.4/x_transformers.egg-info}/PKG-INFO +1 -1
  7. {x_transformers-1.40.3 → x_transformers-1.40.4}/LICENSE +0 -0
  8. {x_transformers-1.40.3 → x_transformers-1.40.4}/setup.cfg +0 -0
  9. {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers/nonautoregressive_wrapper.py +0 -0
  16. {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  17. {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers/xval.py +0 -0
  18. {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers.egg-info/SOURCES.txt +0 -0
  19. {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers.egg-info/dependency_links.txt +0 -0
  20. {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers.egg-info/requires.txt +0 -0
  21. {x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.40.3/x_transformers.egg-info → x_transformers-1.40.4}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.40.3
+ Version: 1.40.4
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
{x_transformers-1.40.3 → x_transformers-1.40.4}/README.md

@@ -2287,4 +2287,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```

+ ```bibtex
+ @article{Bai2019DeepEM,
+     title   = {Deep Equilibrium Models},
+     author  = {Shaojie Bai and J. Zico Kolter and Vladlen Koltun},
+     journal = {ArXiv},
+     year    = {2019},
+     volume  = {abs/1909.01377},
+     url     = {https://api.semanticscholar.org/CorpusID:202539738}
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-1.40.3 → x_transformers-1.40.4}/setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
  setup(
    name = 'x-transformers',
    packages = find_packages(exclude=['examples']),
-   version = '1.40.3',
+   version = '1.40.4',
    license='MIT',
    description = 'X-Transformers - Pytorch',
    author = 'Phil Wang',
{x_transformers-1.40.3 → x_transformers-1.40.4}/tests/test_x_transformers.py

@@ -298,3 +298,21 @@ def test_l2_distance(attn_one_kv_head):
      x = torch.randint(0, 256, (1, 1024))

      model(x)
+
+ def test_reinject_input():
+
+     model = TransformerWrapper(
+         num_tokens = 20000,
+         max_seq_len = 1024,
+         recycling = True,
+         attn_layers = Decoder(
+             dim = 512,
+             depth = 12,
+             heads = 8,
+             reinject_input = True
+         )
+     )
+
+     x = torch.randint(0, 256, (1, 12))
+
+     model(x) # (1, 1024, 20000)
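For readers skimming the diff, a hedged usage sketch of the new flag outside the test suite follows. `TransformerWrapper`, `Decoder`, and the `reinject_input` kwarg come straight from the hunks in this diff; everything else (the token count, sequence length, and the expected `(batch, seq_len, num_tokens)` logits shape) is assumed from the usual x-transformers conventions rather than asserted by this release.

```python
# Hedged sketch (not part of the packaged diff): enable input reinjection on a decoder stack.
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8,
        reinject_input = True  # new in 1.40.4
    )
)

x = torch.randint(0, 20000, (1, 12))
logits = model(x)  # expected logits shape: (1, 12, 20000)
```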
{x_transformers-1.40.3 → x_transformers-1.40.4}/x_transformers/x_transformers.py

@@ -1353,6 +1353,7 @@ class AttentionLayers(Module):
          use_layerscale = False,
          layerscale_init_value = 0.,
          unet_skips = False,
+         reinject_input = True, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
          **kwargs
      ):
          super().__init__()

@@ -1582,6 +1583,11 @@ class AttentionLayers(Module):

          self.skip_combines = ModuleList([])

+         # whether there is reinjection of input at every layer
+
+         self.reinject_input = reinject_input
+         self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None
+
          # iterate and construct layers

          for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):

@@ -1766,6 +1772,11 @@ class AttentionLayers(Module):

          layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)

+         # derived input for reinjection if needed
+
+         if self.reinject_input:
+             inp_inject = self.reinject_input_proj(x)
+
          # store all hiddens for skips

          skip_hiddens = []

@@ -1810,6 +1821,9 @@ class AttentionLayers(Module):
              post_branch_norm = maybe(partial)(post_branch_norm, **norm_kwargs)
              post_main_norm = maybe(partial)(post_main_norm, **norm_kwargs)

+             if self.reinject_input:
+                 x = x + inp_inject
+
              if exists(pre_norm):
                  x = pre_norm(x)
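Taken together, the x_transformers.py hunks wire up one small mechanism: the embedded input is passed once through a bias-free linear projection, and that projection is added back into the residual stream before every layer runs. A minimal, self-contained sketch of the idea is below; `ToyBlock` and `ToyLayers` are illustrative stand-ins, not classes from the package.

```python
# Standalone sketch of input reinjection as the diff above arranges it:
# project the initial representation once, then add it back before every layer.
import torch
from torch import nn

class ToyBlock(nn.Module):
    # a plain pre-norm feedforward residual block, standing in for attention/ff layers
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward(self, x):
        return x + self.ff(self.norm(x))

class ToyLayers(nn.Module):
    def __init__(self, dim, depth, reinject_input = False):
        super().__init__()
        self.layers = nn.ModuleList([ToyBlock(dim) for _ in range(depth)])
        self.reinject_input = reinject_input
        self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None

    def forward(self, x):
        # project the initial input once, mirroring reinject_input_proj in the diff
        inp_inject = self.reinject_input_proj(x) if self.reinject_input else None

        for layer in self.layers:
            if self.reinject_input:
                x = x + inp_inject  # add the projected input back before every layer
            x = layer(x)

        return x

layers = ToyLayers(dim = 512, depth = 12, reinject_input = True)
out = layers(torch.randn(1, 12, 512))  # (1, 12, 512)
```

The learned projection controls how strongly the raw input is mixed back in at each depth, which is the role `reinject_input_proj` plays in the hunks above.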
 
{x_transformers-1.40.3 → x_transformers-1.40.4/x_transformers.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.40.3
+ Version: 1.40.4
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang