x-transformers 1.44.4__py3-none-any.whl → 1.44.6__py3-none-any.whl
- x_transformers/x_transformers.py +9 -1
- {x_transformers-1.44.4.dist-info → x_transformers-1.44.6.dist-info}/METADATA +11 -2
- {x_transformers-1.44.4.dist-info → x_transformers-1.44.6.dist-info}/RECORD +6 -6
- {x_transformers-1.44.4.dist-info → x_transformers-1.44.6.dist-info}/WHEEL +1 -1
- {x_transformers-1.44.4.dist-info → x_transformers-1.44.6.dist-info}/LICENSE +0 -0
- {x_transformers-1.44.4.dist-info → x_transformers-1.44.6.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py CHANGED

@@ -1736,6 +1736,7 @@ class AttentionLayers(Module):
         unet_skips = False,
         num_residual_streams = 1,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
+        learned_reinject_input_gate = False,
         add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 - further corroboration by https://arxiv.org/abs/2412.15113 (faster emergence of ICL) - looks like this setting may becoming a necessity for every transformer soon
         learned_value_residual_mix = True, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependent. here we will use per token learned
         rel_pos_kwargs: dict = dict(),
@@ -1993,6 +1994,7 @@ class AttentionLayers(Module):
 
         self.reinject_input = reinject_input
         self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None
+        self.learned_reinject_input_gate = nn.Linear(dim, 1, bias = False) if learned_reinject_input_gate else None
 
         # add the value from the first self attention block to all latter projected self attention values as a residual
 
@@ -2225,6 +2227,8 @@ class AttentionLayers(Module):
 
         # derived input for reinjection if needed
 
+        inp_inject = None
+
         if self.reinject_input:
             assert not exists(in_attn_cond)
             inp_inject = self.reinject_input_proj(x)
@@ -2233,6 +2237,10 @@ class AttentionLayers(Module):
             # handle in-attention conditioning, which serves the same purpose of having the network learn the residual
             inp_inject = in_attn_cond if in_attn_cond.ndim == 3 else rearrange(in_attn_cond, 'b d -> b 1 d')
 
+        if exists(inp_inject) and exists(self.learned_reinject_input_gate):
+            inp_inject_gate = self.learned_reinject_input_gate(x).sigmoid()
+            inp_inject = inp_inject * inp_inject_gate
+
         # store all hiddens for skips
 
         skip_hiddens = []
@@ -2282,7 +2290,7 @@ class AttentionLayers(Module):
             post_branch_norm = maybe(partial)(post_branch_norm, **norm_kwargs)
             post_main_norm = maybe(partial)(post_main_norm, **norm_kwargs)
 
-            if …
+            if exists(inp_inject):
                 x = x + inp_inject
 
             if exists(pre_norm):
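For readers skimming the diff: the new lines amount to a per-token sigmoid gate on the reinjected input. A bias-free nn.Linear(dim, 1) maps each token's hidden state to a scalar logit, and the sigmoid of that logit scales inp_inject before it is added back into the residual stream. A minimal standalone sketch of the same pattern (shapes and values are illustrative, not taken from the library):

import torch
from torch import nn

dim = 512
batch, seq = 2, 16

# bias-free per-token gate projection, mirroring nn.Linear(dim, 1, bias = False) above
to_gate = nn.Linear(dim, 1, bias = False)

x = torch.randn(batch, seq, dim)           # current hidden states
inp_inject = torch.randn(batch, seq, dim)  # stand-in for the projected, reinjected input

gate = to_gate(x).sigmoid()                # (batch, seq, 1): one gate value in (0, 1) per token
inp_inject = inp_inject * gate             # broadcasts across the feature dimension

x = x + inp_inject                         # gated reinjection into the residual stream

With zero-mean random initialization the gate starts near 0.5 on average, and the network can then learn to suppress or pass the reinjected signal independently at every position.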
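A hedged usage sketch: learned_reinject_input_gate is a constructor keyword on AttentionLayers, so it should be reachable through Decoder (or Encoder) in the usual x-transformers way, alongside the existing reinject_input flag. The model hyperparameters below are illustrative only:

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        reinject_input = True,               # reinject the input at every layer, per the DEQ-inspired option
        learned_reinject_input_gate = True   # added in this diff: per-token sigmoid gate on that reinjection
    )
)

x = torch.randint(0, 20000, (1, 256))
logits = model(x)  # (1, 256, 20000)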
{x_transformers-1.44.4.dist-info → x_transformers-1.44.6.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.2
 Name: x-transformers
-Version: 1.44.4
+Version: 1.44.6
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -19,3 +19,12 @@ Requires-Dist: einops>=0.8.0
 Requires-Dist: loguru
 Requires-Dist: packaging>=21.0
 Requires-Dist: torch>=2.0
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: keywords
+Dynamic: license
+Dynamic: requires-dist
+Dynamic: summary
{x_transformers-1.44.4.dist-info → x_transformers-1.44.6.dist-info}/RECORD CHANGED

@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=…
+x_transformers/x_transformers.py,sha256=PYrwLPEUaiWuPmDNV7nQQZChfMlPJbF9NULHl9Te3LQ,103494
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.44.4.dist-info/LICENSE,…
-x_transformers-1.44.4.dist-info/METADATA,…
-x_transformers-1.44.4.dist-info/WHEEL,…
-x_transformers-1.44.4.dist-info/top_level.txt,…
-x_transformers-1.44.4.dist-info/RECORD,,
+x_transformers-1.44.6.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.44.6.dist-info/METADATA,sha256=3_gOvzIcumtCNqhjGmlcAPMZ2FO6q4sVlhAV-_sybBA,924
+x_transformers-1.44.6.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+x_transformers-1.44.6.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.44.6.dist-info/RECORD,,
{x_transformers-1.44.4.dist-info → x_transformers-1.44.6.dist-info}/LICENSE
File without changes

{x_transformers-1.44.4.dist-info → x_transformers-1.44.6.dist-info}/top_level.txt
File without changes