x-transformers 1.44.4__py3-none-any.whl → 1.44.6__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -1736,6 +1736,7 @@ class AttentionLayers(Module):
         unet_skips = False,
         num_residual_streams = 1,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
+        learned_reinject_input_gate = False,
         add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 - further corroboration by https://arxiv.org/abs/2412.15113 (faster emergence of ICL) - looks like this setting may becoming a necessity for every transformer soon
         learned_value_residual_mix = True, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependent. here we will use per token learned
         rel_pos_kwargs: dict = dict(),
@@ -1993,6 +1994,7 @@ class AttentionLayers(Module):
 
         self.reinject_input = reinject_input
         self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None
+        self.learned_reinject_input_gate = nn.Linear(dim, 1, bias = False) if learned_reinject_input_gate else None
 
         # add the value from the first self attention block to all latter projected self attention values as a residual
 
@@ -2225,6 +2227,8 @@ class AttentionLayers(Module):
 
         # derived input for reinjection if needed
 
+        inp_inject = None
+
         if self.reinject_input:
             assert not exists(in_attn_cond)
             inp_inject = self.reinject_input_proj(x)
@@ -2233,6 +2237,10 @@ class AttentionLayers(Module):
             # handle in-attention conditioning, which serves the same purpose of having the network learn the residual
             inp_inject = in_attn_cond if in_attn_cond.ndim == 3 else rearrange(in_attn_cond, 'b d -> b 1 d')
 
+        if exists(inp_inject) and exists(self.learned_reinject_input_gate):
+            inp_inject_gate = self.learned_reinject_input_gate(x).sigmoid()
+            inp_inject = inp_inject * inp_inject_gate
+
         # store all hiddens for skips
 
         skip_hiddens = []
@@ -2282,7 +2290,7 @@ class AttentionLayers(Module):
             post_branch_norm = maybe(partial)(post_branch_norm, **norm_kwargs)
             post_main_norm = maybe(partial)(post_main_norm, **norm_kwargs)
 
-            if self.reinject_input:
+            if exists(inp_inject):
                 x = x + inp_inject
 
             if exists(pre_norm):
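
The functional change between 1.44.4 and 1.44.6 is the optional learned_reinject_input_gate on AttentionLayers: when enabled, a bias-free nn.Linear(dim, 1) followed by a sigmoid produces a per-token scalar gate from the current hidden state, and the reinjected input (from reinject_input or in-attention conditioning) is scaled by that gate before being added to x at each layer. The sketch below is a hypothetical usage example, assuming the standard TransformerWrapper / Decoder entry points of x-transformers; it is not taken from this diff, and the hyperparameter values are placeholders.

# Hypothetical usage sketch: enabling the new gate on a decoder stack.
# reinject_input and learned_reinject_input_gate are the AttentionLayers kwargs shown above,
# forwarded through Decoder; everything else follows the public x-transformers API.
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        reinject_input = True,               # re-add a projection of the embedded input at every layer
        learned_reinject_input_gate = True   # new in 1.44.6: per-token sigmoid gate on the reinjected input
    )
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x)  # (1, 1024, 256)

Because the gate is data dependent, the network can learn per token how strongly each position receives the reinjected input, rather than always adding it at full strength as in 1.44.4.
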
--- a/x_transformers-1.44.4.dist-info/METADATA
+++ b/x_transformers-1.44.6.dist-info/METADATA
@@ -1,6 +1,6 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.2
 Name: x-transformers
-Version: 1.44.4
+Version: 1.44.6
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -19,3 +19,12 @@ Requires-Dist: einops>=0.8.0
 Requires-Dist: loguru
 Requires-Dist: packaging>=21.0
 Requires-Dist: torch>=2.0
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: keywords
+Dynamic: license
+Dynamic: requires-dist
+Dynamic: summary
--- a/x_transformers-1.44.4.dist-info/RECORD
+++ b/x_transformers-1.44.6.dist-info/RECORD
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=lc9mmhV-O9MesX7Di7P93KjMioRzM5zzZ0U9sVoDLqU,103100
+x_transformers/x_transformers.py,sha256=PYrwLPEUaiWuPmDNV7nQQZChfMlPJbF9NULHl9Te3LQ,103494
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.44.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.44.4.dist-info/METADATA,sha256=09PGX7zKwq8DjFeEX3FmF3YmSpN1XU4fiyDisewXlDg,738
-x_transformers-1.44.4.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
-x_transformers-1.44.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.44.4.dist-info/RECORD,,
+x_transformers-1.44.6.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.44.6.dist-info/METADATA,sha256=3_gOvzIcumtCNqhjGmlcAPMZ2FO6q4sVlhAV-_sybBA,924
+x_transformers-1.44.6.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+x_transformers-1.44.6.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.44.6.dist-info/RECORD,,
--- a/x_transformers-1.44.4.dist-info/WHEEL
+++ b/x_transformers-1.44.6.dist-info/WHEEL
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (75.7.0)
+Generator: setuptools (75.8.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 