x-transformers 1.40.3__py3-none-any.whl → 1.40.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

x_transformers/x_transformers.py

@@ -1353,6 +1353,7 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
+        reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
         **kwargs
     ):
         super().__init__()
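
For context, `reinject_input` reaches `AttentionLayers` through the regular keyword arguments, so it can be switched on from the usual x-transformers constructors. A minimal sketch of enabling it, assuming the standard `TransformerWrapper` / `Decoder` usage (the model sizes below are purely illustrative):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,         # illustrative vocab size
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8,
        reinject_input = True   # added between 1.40.3 and 1.40.5: re-add a projection of the embedded input before every layer
    )
)

tokens = torch.randint(0, 20000, (1, 1024))
logits = model(tokens)          # (1, 1024, 20000)
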
@@ -1582,6 +1583,11 @@ class AttentionLayers(Module):
 
         self.skip_combines = ModuleList([])
 
+        # whether there is reinjection of input at every layer
+
+        self.reinject_input = reinject_input
+        self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None
+
         # iterate and construct layers
 
         for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
@@ -1667,6 +1673,7 @@ class AttentionLayers(Module):
         rotary_pos_emb = None,
         attn_bias = None,
         condition = None,
+        in_attn_cond = None, # https://arxiv.org/abs/2105.04090
         layers_execute_order: tuple[int, ...] | None = None
     ):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
@@ -1766,6 +1773,16 @@ class AttentionLayers(Module):
 
         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
 
+        # derived input for reinjection if needed
+
+        if self.reinject_input:
+            assert not exists(in_attn_cond)
+            inp_inject = self.reinject_input_proj(x)
+
+        elif exists(in_attn_cond):
+            # handle in-attention conditioning, which serves the same purpose of having the network learn the residual
+            inp_inject = in_attn_cond if in_attn_cond.ndim == 3 else rearrange(in_attn_cond, 'b d -> b 1 d')
+
         # store all hiddens for skips
 
         skip_hiddens = []
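
The `elif` branch accepts the external condition either per token, shape `(batch, seq, dim)`, or per sequence, shape `(batch, dim)`; a per-sequence condition is reshaped so it broadcasts over the sequence dimension. A small illustration of that reshape (tensor names here are only for the example):

import torch
from einops import rearrange

cond = torch.randn(2, 512)                    # one conditioning vector per sequence: (b, d)
inp_inject = rearrange(cond, 'b d -> b 1 d')  # (2, 1, 512), broadcasts against (b, n, d) hidden states
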
@@ -1810,6 +1827,9 @@ class AttentionLayers(Module):
             post_branch_norm = maybe(partial)(post_branch_norm, **norm_kwargs)
             post_main_norm = maybe(partial)(post_main_norm, **norm_kwargs)
 
+            if self.reinject_input:
+                x = x + inp_inject
+
             if exists(pre_norm):
                 x = pre_norm(x)
 
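
Taken together with the previous hunk, the per-layer effect when `reinject_input` is enabled is roughly the following. This is a conceptual sketch using generic PyTorch blocks rather than the library's own layers; `proj` stands in for `reinject_input_proj`:

import torch
from torch import nn

dim, depth = 512, 6
proj = nn.Linear(dim, dim, bias = False)   # plays the role of reinject_input_proj
blocks = nn.ModuleList([
    nn.TransformerEncoderLayer(dim, nhead = 8, batch_first = True, norm_first = True)
    for _ in range(depth)
])

x0 = torch.randn(1, 128, dim)              # embedded input entering the layer stack
inp_inject = proj(x0)                      # derived once, before iterating over the layers

x = x0
for block in blocks:
    x = x + inp_inject                     # reinjected before each (pre-norm) block
    x = block(x)
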
x_transformers-1.40.5.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.40.3
+Version: 1.40.5
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang

x_transformers-1.40.5.dist-info/RECORD

@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=17wNbMh5Mt6cyrAvkBToIozmFt-p9ZBhQCqqlnDyHPI,84676
+x_transformers/x_transformers.py,sha256=FKHaJFQuMNiFMrjDF13OE3vk-iYf_qwogBNxVpiQSc4,85671
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.40.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.40.3.dist-info/METADATA,sha256=2Qb9WF8pa8bmqMI4VdO9f9bBEmKMcUatlbcU2SDJGwM,661
-x_transformers-1.40.3.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
-x_transformers-1.40.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.40.3.dist-info/RECORD,,
+x_transformers-1.40.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.40.5.dist-info/METADATA,sha256=WfbVonMAKfuqdCoXwi_AfnwsmCyx1310dqKoFnEWtiY,661
+x_transformers-1.40.5.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.40.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.40.5.dist-info/RECORD,,