x-transformers 1.40.3.tar.gz → 1.40.5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. {x_transformers-1.40.3/x_transformers.egg-info → x_transformers-1.40.5}/PKG-INFO +1 -1
  2. {x_transformers-1.40.3 → x_transformers-1.40.5}/README.md +23 -0
  3. {x_transformers-1.40.3 → x_transformers-1.40.5}/setup.py +1 -1
  4. {x_transformers-1.40.3 → x_transformers-1.40.5}/tests/test_x_transformers.py +18 -0
  5. {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers/x_transformers.py +20 -0
  6. {x_transformers-1.40.3 → x_transformers-1.40.5/x_transformers.egg-info}/PKG-INFO +1 -1
  7. {x_transformers-1.40.3 → x_transformers-1.40.5}/LICENSE +0 -0
  8. {x_transformers-1.40.3 → x_transformers-1.40.5}/setup.cfg +0 -0
  9. {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers/nonautoregressive_wrapper.py +0 -0
  16. {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  17. {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers/xval.py +0 -0
  18. {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers.egg-info/SOURCES.txt +0 -0
  19. {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers.egg-info/dependency_links.txt +0 -0
  20. {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers.egg-info/requires.txt +0 -0
  21. {x_transformers-1.40.3 → x_transformers-1.40.5}/x_transformers.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.40.3
+ Version: 1.40.5
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
@@ -2287,4 +2287,27 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```
 
+ ```bibtex
+ @article{Bai2019DeepEM,
+     title   = {Deep Equilibrium Models},
+     author  = {Shaojie Bai and J. Zico Kolter and Vladlen Koltun},
+     journal = {ArXiv},
+     year    = {2019},
+     volume  = {abs/1909.01377},
+     url     = {https://api.semanticscholar.org/CorpusID:202539738}
+ }
+ ```
+
+ ```bibtex
+ @article{Wu2021MuseMorphoseFA,
+     title   = {MuseMorphose: Full-Song and Fine-Grained Piano Music Style Transfer With One Transformer VAE},
+     author  = {Shih-Lun Wu and Yi-Hsuan Yang},
+     journal = {IEEE/ACM Transactions on Audio, Speech, and Language Processing},
+     year    = {2021},
+     volume  = {31},
+     pages   = {1953-1967},
+     url     = {https://api.semanticscholar.org/CorpusID:234338162}
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
  setup(
    name = 'x-transformers',
    packages = find_packages(exclude=['examples']),
-   version = '1.40.3',
+   version = '1.40.5',
    license='MIT',
    description = 'X-Transformers - Pytorch',
    author = 'Phil Wang',
@@ -298,3 +298,21 @@ def test_l2_distance(attn_one_kv_head):
      x = torch.randint(0, 256, (1, 1024))
 
      model(x)
+
+ def test_reinject_input():
+
+     model = TransformerWrapper(
+         num_tokens = 20000,
+         max_seq_len = 1024,
+         recycling = True,
+         attn_layers = Decoder(
+             dim = 512,
+             depth = 12,
+             heads = 8,
+             reinject_input = True
+         )
+     )
+
+     x = torch.randint(0, 256, (1, 12))
+
+     model(x) # (1, 12, 20000)
@@ -1353,6 +1353,7 @@ class AttentionLayers(Module):
          use_layerscale = False,
          layerscale_init_value = 0.,
          unet_skips = False,
+         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
          **kwargs
      ):
          super().__init__()
@@ -1582,6 +1583,11 @@ class AttentionLayers(Module):
 
          self.skip_combines = ModuleList([])
 
+         # whether there is reinjection of input at every layer
+
+         self.reinject_input = reinject_input
+         self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None
+
          # iterate and construct layers
 
          for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
@@ -1667,6 +1673,7 @@ class AttentionLayers(Module):
          rotary_pos_emb = None,
          attn_bias = None,
          condition = None,
+         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
          layers_execute_order: tuple[int, ...] | None = None
      ):
          assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
@@ -1766,6 +1773,16 @@ class AttentionLayers(Module):
 
          layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
 
+         # derived input for reinjection if needed
+
+         if self.reinject_input:
+             assert not exists(in_attn_cond)
+             inp_inject = self.reinject_input_proj(x)
+
+         elif exists(in_attn_cond):
+             # handle in-attention conditioning, which serves the same purpose of having the network learn the residual
+             inp_inject = in_attn_cond if in_attn_cond.ndim == 3 else rearrange(in_attn_cond, 'b d -> b 1 d')
+
          # store all hiddens for skips
 
          skip_hiddens = []
@@ -1810,6 +1827,9 @@ class AttentionLayers(Module):
              post_branch_norm = maybe(partial)(post_branch_norm, **norm_kwargs)
              post_main_norm = maybe(partial)(post_main_norm, **norm_kwargs)
 
+             if self.reinject_input:
+                 x = x + inp_inject
+
              if exists(pre_norm):
                  x = pre_norm(x)
 
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.40.3
+ Version: 1.40.5
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
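Taken together, the `x_transformers/x_transformers.py` hunks above derive a per-layer injection tensor once per forward pass (either a linear projection of the input embedding when `reinject_input = True`, or an external condition broadcast across the sequence via `in_attn_cond`) and add it to the hidden state ahead of each block. The standalone sketch below only illustrates that pattern; the `ToyReinjectLayers` module is hypothetical, uses plain feedforward blocks instead of the package's actual `AttentionLayers`, and for simplicity applies the injection whenever one was derived.

```python
import torch
from torch import nn
from einops import rearrange

class ToyReinjectLayers(nn.Module):
    # hypothetical stand-in for AttentionLayers: a stack of simple residual
    # feedforward blocks, with the input-reinjection path from the diff above
    def __init__(self, dim, depth, reinject_input = False):
        super().__init__()
        self.reinject_input = reinject_input

        # project the initial embedding once, mirroring reinject_input_proj
        self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None

        self.layers = nn.ModuleList([
            nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
            for _ in range(depth)
        ])

    def forward(self, x, in_attn_cond = None):
        # derive what gets injected at every layer: a projection of the input
        # embedding (DEQ-style reinjection), or an external condition broadcast
        # over the sequence (in-attention conditioning, as in MuseMorphose)
        inp_inject = None

        if self.reinject_input:
            assert in_attn_cond is None
            inp_inject = self.reinject_input_proj(x)
        elif in_attn_cond is not None:
            inp_inject = in_attn_cond if in_attn_cond.ndim == 3 else rearrange(in_attn_cond, 'b d -> b 1 d')

        for block in self.layers:
            if inp_inject is not None:
                x = x + inp_inject  # reinject before every block

            x = block(x) + x        # usual residual branch

        return x

layers = ToyReinjectLayers(dim = 512, depth = 4, reinject_input = True)
out = layers(torch.randn(1, 12, 512))  # (1, 12, 512)
```

In the package itself the flag is simply forwarded through the attention layers' constructor, as the new `test_reinject_input` above does with `Decoder(..., reinject_input = True)`.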