x-transformers 1.44.4__py3-none-any.whl → 1.44.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +9 -1
- {x_transformers-1.44.4.dist-info → x_transformers-1.44.6.dist-info}/METADATA +11 -2
- {x_transformers-1.44.4.dist-info → x_transformers-1.44.6.dist-info}/RECORD +6 -6
- {x_transformers-1.44.4.dist-info → x_transformers-1.44.6.dist-info}/WHEEL +1 -1
- {x_transformers-1.44.4.dist-info → x_transformers-1.44.6.dist-info}/LICENSE +0 -0
- {x_transformers-1.44.4.dist-info → x_transformers-1.44.6.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -1736,6 +1736,7 @@ class AttentionLayers(Module):
         unet_skips = False,
         num_residual_streams = 1,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
+        learned_reinject_input_gate = False,
         add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 - further corroboration by https://arxiv.org/abs/2412.15113 (faster emergence of ICL) - looks like this setting may becoming a necessity for every transformer soon
         learned_value_residual_mix = True, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependent. here we will use per token learned
         rel_pos_kwargs: dict = dict(),
@@ -1993,6 +1994,7 @@ class AttentionLayers(Module):

         self.reinject_input = reinject_input
         self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None
+        self.learned_reinject_input_gate = nn.Linear(dim, 1, bias = False) if learned_reinject_input_gate else None

         # add the value from the first self attention block to all latter projected self attention values as a residual

@@ -2225,6 +2227,8 @@ class AttentionLayers(Module):

         # derived input for reinjection if needed

+        inp_inject = None
+
         if self.reinject_input:
             assert not exists(in_attn_cond)
             inp_inject = self.reinject_input_proj(x)
@@ -2233,6 +2237,10 @@ class AttentionLayers(Module):
             # handle in-attention conditioning, which serves the same purpose of having the network learn the residual
             inp_inject = in_attn_cond if in_attn_cond.ndim == 3 else rearrange(in_attn_cond, 'b d -> b 1 d')

+        if exists(inp_inject) and exists(self.learned_reinject_input_gate):
+            inp_inject_gate = self.learned_reinject_input_gate(x).sigmoid()
+            inp_inject = inp_inject * inp_inject_gate
+
         # store all hiddens for skips

         skip_hiddens = []
@@ -2282,7 +2290,7 @@ class AttentionLayers(Module):
             post_branch_norm = maybe(partial)(post_branch_norm, **norm_kwargs)
             post_main_norm = maybe(partial)(post_main_norm, **norm_kwargs)

-            if
+            if exists(inp_inject):
                 x = x + inp_inject

             if exists(pre_norm):
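The functional change in x_transformers.py: when an input is being reinjected into the residual stream (via reinject_input or in-attention conditioning), the new learned_reinject_input_gate option scales that reinjected input by a learned per-token gate, sigmoid(Linear(dim, 1)(x)), before it is added back at each layer. Below is a minimal usage sketch; it assumes the standard TransformerWrapper / Decoder entry points from the x-transformers README, and the token count, dimensions, and depth are illustrative values, not taken from the diff.

import torch
from x_transformers import TransformerWrapper, Decoder

# illustrative hyperparameters; only the two reinject flags relate to this diff
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        reinject_input = True,              # re-add a projection of the block input at every layer
        learned_reinject_input_gate = True  # new in this release range: per-token sigmoid gate on the reinjection
    )
)

tokens = torch.randint(0, 256, (1, 1024))
logits = model(tokens)  # (1, 1024, 256)

With the gate disabled (the default), reinjection behaves as before; with it enabled, each position learns how much of the projected input to add back, the (batch, seq, 1) gate broadcasting over the feature dimension.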
{x_transformers-1.44.4.dist-info → x_transformers-1.44.6.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.2
 Name: x-transformers
-Version: 1.44.4
+Version: 1.44.6
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -19,3 +19,12 @@ Requires-Dist: einops>=0.8.0
 Requires-Dist: loguru
 Requires-Dist: packaging>=21.0
 Requires-Dist: torch>=2.0
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: keywords
+Dynamic: license
+Dynamic: requires-dist
+Dynamic: summary
{x_transformers-1.44.4.dist-info → x_transformers-1.44.6.dist-info}/RECORD
CHANGED
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=PYrwLPEUaiWuPmDNV7nQQZChfMlPJbF9NULHl9Te3LQ,103494
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.44.
-x_transformers-1.44.
-x_transformers-1.44.
-x_transformers-1.44.
-x_transformers-1.44.
+x_transformers-1.44.6.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.44.6.dist-info/METADATA,sha256=3_gOvzIcumtCNqhjGmlcAPMZ2FO6q4sVlhAV-_sybBA,924
+x_transformers-1.44.6.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+x_transformers-1.44.6.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.44.6.dist-info/RECORD,,
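For reference, each RECORD entry above follows the wheel format's path,sha256=<digest>,<size> convention, where the digest is the urlsafe-base64 encoding of the file's SHA-256 hash with trailing padding stripped. A small sketch for reproducing an entry from an unpacked wheel; the path is illustrative and assumes the archive has been extracted into the current directory.

import base64
import hashlib

def record_entry(path):
    # read the packaged file and hash it exactly as it is stored in the wheel
    data = open(path, 'rb').read()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b'=').decode()
    return f'{path},sha256={digest},{len(data)}'

# should reproduce the corresponding line of the RECORD diff above
print(record_entry('x_transformers/x_transformers.py'))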
{x_transformers-1.44.4.dist-info → x_transformers-1.44.6.dist-info}/LICENSE
File without changes

{x_transformers-1.44.4.dist-info → x_transformers-1.44.6.dist-info}/top_level.txt
File without changes