x-transformers 1.40.3__py3-none-any.whl → 1.40.4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
--- x_transformers/x_transformers.py
+++ x_transformers/x_transformers.py
@@ -1353,6 +1353,7 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
+        reinject_input = True,  # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
         **kwargs
     ):
         super().__init__()
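The new keyword lands in AttentionLayers.__init__, so it should also be reachable from the Encoder/Decoder wrappers, which forward their extra keyword arguments to AttentionLayers. A minimal usage sketch under that assumption (the vocabulary size, sequence length and model dimensions below are illustrative only):

import torch
from x_transformers import TransformerWrapper, Decoder

# toy autoregressive model with input reinjection enabled at every layer
model = TransformerWrapper(
    num_tokens = 20000,            # illustrative vocab size
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 8,
        heads = 8,
        reinject_input = True      # flag introduced in this release
    )
)

tokens = torch.randint(0, 20000, (1, 1024))
logits = model(tokens)             # (1, 1024, 20000)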
@@ -1582,6 +1583,11 @@ class AttentionLayers(Module):
 
         self.skip_combines = ModuleList([])
 
+        # whether there is reinjection of input at every layer
+
+        self.reinject_input = reinject_input
+        self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None
+
         # iterate and construct layers
 
         for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
@@ -1766,6 +1772,11 @@ class AttentionLayers(Module):
 
         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
 
+        # derived input for reinjection if needed
+
+        if self.reinject_input:
+            inp_inject = self.reinject_input_proj(x)
+
         # store all hiddens for skips
 
         skip_hiddens = []
@@ -1810,6 +1821,9 @@ class AttentionLayers(Module):
             post_branch_norm = maybe(partial)(post_branch_norm, **norm_kwargs)
             post_main_norm = maybe(partial)(post_main_norm, **norm_kwargs)
 
+            if self.reinject_input:
+                x = x + inp_inject
+
             if exists(pre_norm):
                 x = pre_norm(x)
 
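Taken together, the three hunks above implement one mechanism: the representation entering the attention stack is projected once by a bias-free linear layer, and that projection is added back onto the hidden state at the start of every layer. A self-contained sketch of the same idea in plain PyTorch (SimpleReinjectedStack and its feedforward blocks are hypothetical stand-ins, not the library's classes):

import torch
from torch import nn

class SimpleReinjectedStack(nn.Module):
    # hypothetical illustration of input reinjection, not the library's implementation
    def __init__(self, dim, depth):
        super().__init__()
        self.reinject_input_proj = nn.Linear(dim, dim, bias = False)
        self.blocks = nn.ModuleList([
            nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
            for _ in range(depth)
        ])

    def forward(self, x):
        inp_inject = self.reinject_input_proj(x)   # derived once from the stack's input
        for block in self.blocks:
            x = x + inp_inject                     # reinject the input before every block
            x = x + block(x)                       # residual block stands in for attention / feedforward
        return x

stack = SimpleReinjectedStack(dim = 512, depth = 6)
out = stack(torch.randn(2, 128, 512))              # (2, 128, 512)

As in the diff, the projection is computed once per forward pass rather than per layer, so the extra cost is a single dim-by-dim projection plus one addition per layer.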
--- x_transformers-1.40.3.dist-info/METADATA
+++ x_transformers-1.40.4.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.40.3
+Version: 1.40.4
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
--- x_transformers-1.40.3.dist-info/RECORD
+++ x_transformers-1.40.4.dist-info/RECORD
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=17wNbMh5Mt6cyrAvkBToIozmFt-p9ZBhQCqqlnDyHPI,84676
+x_transformers/x_transformers.py,sha256=g7cGYUcq344QZejqdB_PPHHz1O_zqmdplxgQA07hqks,85298
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.40.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.40.3.dist-info/METADATA,sha256=2Qb9WF8pa8bmqMI4VdO9f9bBEmKMcUatlbcU2SDJGwM,661
-x_transformers-1.40.3.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
-x_transformers-1.40.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.40.3.dist-info/RECORD,,
+x_transformers-1.40.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.40.4.dist-info/METADATA,sha256=K7tcPd4ZY7qO773XErYQbik-XDEfLbFQ0lFYSfjAFvY,661
+x_transformers-1.40.4.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.40.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.40.4.dist-info/RECORD,,