x-transformers 1.40.3-py3-none-any.whl → 1.40.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +20 -0
- {x_transformers-1.40.3.dist-info → x_transformers-1.40.5.dist-info}/METADATA +1 -1
- {x_transformers-1.40.3.dist-info → x_transformers-1.40.5.dist-info}/RECORD +6 -6
- {x_transformers-1.40.3.dist-info → x_transformers-1.40.5.dist-info}/LICENSE +0 -0
- {x_transformers-1.40.3.dist-info → x_transformers-1.40.5.dist-info}/WHEEL +0 -0
- {x_transformers-1.40.3.dist-info → x_transformers-1.40.5.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -1353,6 +1353,7 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
+        reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
         **kwargs
     ):
         super().__init__()
@@ -1582,6 +1583,11 @@ class AttentionLayers(Module):

         self.skip_combines = ModuleList([])

+        # whether there is reinjection of input at every layer
+
+        self.reinject_input = reinject_input
+        self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None
+
         # iterate and construct layers

         for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
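Taken together with the constructor change above, these two hunks set up a simple pattern: the block input is projected once by a bias-free linear layer and re-added ahead of every layer. Below is a minimal standalone sketch of that pattern; ToyReinjectLayers, its toy linear "blocks", and all dimensions are illustrative inventions for this note, not part of x-transformers.

import torch
from torch import nn

class ToyReinjectLayers(nn.Module):
    """Illustrative stand-in for the reinjection logic added in this diff."""
    def __init__(self, dim, depth, reinject_input = False):
        super().__init__()
        self.reinject_input = reinject_input
        # bias-free projection of the original input, mirroring reinject_input_proj
        self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None
        # stand-in "blocks"; the real library builds attention / feedforward layers here
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])

    def forward(self, x):
        inp_inject = self.reinject_input_proj(x) if self.reinject_input else None

        for layer in self.layers:
            if self.reinject_input:
                x = x + inp_inject      # re-add the projected input before every block
            x = x + layer(x)            # toy residual block

        return x

toy = ToyReinjectLayers(dim = 8, depth = 3, reinject_input = True)
out = toy(torch.randn(2, 16, 8))        # (batch, seq, dim)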
@@ -1667,6 +1673,7 @@ class AttentionLayers(Module):
         rotary_pos_emb = None,
         attn_bias = None,
         condition = None,
+        in_attn_cond = None, # https://arxiv.org/abs/2105.04090
         layers_execute_order: tuple[int, ...] | None = None
     ):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
@@ -1766,6 +1773,16 @@ class AttentionLayers(Module):

         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)

+        # derived input for reinjection if needed
+
+        if self.reinject_input:
+            assert not exists(in_attn_cond)
+            inp_inject = self.reinject_input_proj(x)
+
+        elif exists(in_attn_cond):
+            # handle in-attention conditioning, which serves the same purpose of having the network learn the residual
+            inp_inject = in_attn_cond if in_attn_cond.ndim == 3 else rearrange(in_attn_cond, 'b d -> b 1 d')
+
         # store all hiddens for skips

         skip_hiddens = []
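The rearrange on the last added line only lifts a per-sample condition of shape (b, d) to (b, 1, d) so it broadcasts over the sequence dimension when added to the hidden states. A quick self-contained illustration; the tensor sizes are made up for the example.

import torch
from einops import rearrange

b, n, d = 2, 16, 8
x = torch.randn(b, n, d)              # hidden states
in_attn_cond = torch.randn(b, d)      # one conditioning vector per sample

# same shape handling as the diff: (b, d) -> (b, 1, d); a 3-dim input passes through unchanged
inp_inject = in_attn_cond if in_attn_cond.ndim == 3 else rearrange(in_attn_cond, 'b d -> b 1 d')

assert inp_inject.shape == (b, 1, d)
assert (x + inp_inject).shape == (b, n, d)   # broadcasts across all n positions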
@@ -1810,6 +1827,9 @@ class AttentionLayers(Module):
                 post_branch_norm = maybe(partial)(post_branch_norm, **norm_kwargs)
                 post_main_norm = maybe(partial)(post_main_norm, **norm_kwargs)

+            if self.reinject_input:
+                x = x + inp_inject
+
             if exists(pre_norm):
                 x = pre_norm(x)

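From the user's side, the new flag is just another keyword argument that flows through to AttentionLayers. A hedged sketch of enabling it through the library's usual entry points follows; TransformerWrapper and Decoder are the standard x-transformers classes, the token count, sequence length, and model sizes below are illustrative, and only the reinject_input name comes from this diff.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,              # illustrative vocabulary size
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        reinject_input = True      # new kwarg from this diff: re-add a projection of the input at every layer
    )
)

tokens = torch.randint(0, 256, (1, 1024))
logits = model(tokens)             # (1, 1024, 256)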
{x_transformers-1.40.3.dist-info → x_transformers-1.40.5.dist-info}/RECORD
CHANGED
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=FKHaJFQuMNiFMrjDF13OE3vk-iYf_qwogBNxVpiQSc4,85671
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.40.
-x_transformers-1.40.
-x_transformers-1.40.
-x_transformers-1.40.
-x_transformers-1.40.
+x_transformers-1.40.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.40.5.dist-info/METADATA,sha256=WfbVonMAKfuqdCoXwi_AfnwsmCyx1310dqKoFnEWtiY,661
+x_transformers-1.40.5.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.40.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.40.5.dist-info/RECORD,,
{x_transformers-1.40.3.dist-info → x_transformers-1.40.5.dist-info}/LICENSE
File without changes
{x_transformers-1.40.3.dist-info → x_transformers-1.40.5.dist-info}/WHEEL
File without changes
{x_transformers-1.40.3.dist-info → x_transformers-1.40.5.dist-info}/top_level.txt
File without changes