x-transformers 1.40.4__py3-none-any.whl → 1.40.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +7 -1
- x_transformers/xval.py +2 -1
- {x_transformers-1.40.4.dist-info → x_transformers-1.40.6.dist-info}/METADATA +1 -1
- {x_transformers-1.40.4.dist-info → x_transformers-1.40.6.dist-info}/RECORD +7 -7
- {x_transformers-1.40.4.dist-info → x_transformers-1.40.6.dist-info}/LICENSE +0 -0
- {x_transformers-1.40.4.dist-info → x_transformers-1.40.6.dist-info}/WHEEL +0 -0
- {x_transformers-1.40.4.dist-info → x_transformers-1.40.6.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -1353,7 +1353,7 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
-        reinject_input = False,
+        reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
         **kwargs
     ):
         super().__init__()
@@ -1673,6 +1673,7 @@ class AttentionLayers(Module):
         rotary_pos_emb = None,
         attn_bias = None,
         condition = None,
+        in_attn_cond = None, # https://arxiv.org/abs/2105.04090
         layers_execute_order: tuple[int, ...] | None = None
     ):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
@@ -1775,8 +1776,13 @@ class AttentionLayers(Module):
         # derived input for reinjection if needed
 
         if self.reinject_input:
+            assert not exists(in_attn_cond)
             inp_inject = self.reinject_input_proj(x)
 
+        elif exists(in_attn_cond):
+            # handle in-attention conditioning, which serves the same purpose of having the network learn the residual
+            inp_inject = in_attn_cond if in_attn_cond.ndim == 3 else rearrange(in_attn_cond, 'b d -> b 1 d')
+
         # store all hiddens for skips
 
         skip_hiddens = []
{x_transformers-1.40.4.dist-info → x_transformers-1.40.6.dist-info}/RECORD
CHANGED
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=FKHaJFQuMNiFMrjDF13OE3vk-iYf_qwogBNxVpiQSc4,85671
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
-x_transformers/xval.py,sha256=
-x_transformers-1.40.
-x_transformers-1.40.
-x_transformers-1.40.
-x_transformers-1.40.
-x_transformers-1.40.
+x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
+x_transformers-1.40.6.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.40.6.dist-info/METADATA,sha256=HsoNUu71hkonsBhThVN46rakFnIAGOav3pHDpYnX9t8,661
+x_transformers-1.40.6.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.40.6.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.40.6.dist-info/RECORD,,
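Each RECORD entry has the form path,sha256=<digest>,<size>, where the digest is the urlsafe base64 encoding of the file's SHA-256 hash with the trailing padding stripped, per the wheel spec. Below is a small sketch for checking an entry against an unpacked wheel; the helper name is ours, not part of any packaging tool.

# sketch: recompute a wheel RECORD entry (urlsafe-base64 sha256 without padding, plus byte size)
import base64
import hashlib
from pathlib import Path

def record_entry(path):
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b'=').decode('ascii')
    return f'{path},sha256={digest},{len(data)}'

# run inside the unpacked 1.40.6 wheel; should reproduce the line
# x_transformers/x_transformers.py,sha256=FKHaJFQuMNiFMrjDF13OE3vk-iYf_qwogBNxVpiQSc4,85671
print(record_entry('x_transformers/x_transformers.py'))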
{x_transformers-1.40.4.dist-info → x_transformers-1.40.6.dist-info}/LICENSE
File without changes
{x_transformers-1.40.4.dist-info → x_transformers-1.40.6.dist-info}/WHEEL
File without changes
{x_transformers-1.40.4.dist-info → x_transformers-1.40.6.dist-info}/top_level.txt
File without changes