x-transformers 1.42.12__py3-none-any.whl → 1.42.15__py3-none-any.whl
- x_transformers/x_transformers.py +24 -5
- {x_transformers-1.42.12.dist-info → x_transformers-1.42.15.dist-info}/METADATA +5 -6
- {x_transformers-1.42.12.dist-info → x_transformers-1.42.15.dist-info}/RECORD +6 -6
- {x_transformers-1.42.12.dist-info → x_transformers-1.42.15.dist-info}/WHEEL +1 -1
- {x_transformers-1.42.12.dist-info → x_transformers-1.42.15.dist-info}/LICENSE +0 -0
- {x_transformers-1.42.12.dist-info → x_transformers-1.42.15.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -1072,6 +1072,7 @@ class Attention(Module):
         logit_softclamp_value = 50.,
         neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
         neutreno_alpha = 0.4,
+        learned_value_residual_mix = False,
         onnxable = False,
         attend_sdp_kwargs: dict = dict(
             enable_flash = True,
@@ -1231,6 +1232,14 @@ class Attention(Module):
             self.mem_k = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
             self.mem_v = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))

+        # maybe learned value residual mixer per token
+
+        self.to_value_residual_mix = nn.Sequential(
+            nn.Linear(dim, 1),
+            nn.Sigmoid(),
+            Rearrange('b n 1 -> b 1 n 1')
+        ) if learned_value_residual_mix else always(0.5)
+
         # attention on attention

         self.attn_on_attn = on_attn
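When `learned_value_residual_mix` is enabled, the gate above maps each token of the block's input to a scalar in (0, 1) and reshapes it so it broadcasts over heads and the head dimension. A minimal standalone sketch of those shapes (the dimensions and tensor names are illustrative, not taken from the library):

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    dim = 512                            # model dimension, illustrative

    # the same three-stage gate as in the hunk above, built standalone
    to_value_residual_mix = nn.Sequential(
        nn.Linear(dim, 1),               # one scalar per token
        nn.Sigmoid(),                    # squash it into (0, 1)
        Rearrange('b n 1 -> b 1 n 1')    # broadcastable over (batch, heads, seq, head dim)
    )

    x = torch.randn(2, 1024, dim)        # (batch, seq, dim), the attention block's input
    mix = to_value_residual_mix(x)
    print(mix.shape)                     # torch.Size([2, 1, 1024, 1])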
@@ -1303,7 +1312,8 @@ class Attention(Module):
                 diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)
             else:
                 # https://arxiv.org/abs/2410.17897v1
-                v = 0.5 * (v + value_residual)
+                value_residual_mix = self.to_value_residual_mix(q_input)
+                v = v * value_residual_mix + value_residual * (1. - value_residual_mix)

         # take care of caching

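The gate output then blends the current layer's values with the value residual cached from the first attention layer (resformer, https://arxiv.org/abs/2410.17897v1). With the flag off, `always(0.5)` supplies a constant 0.5, which reduces the update to the plain averaging used before this change. An illustrative check with made-up shapes:

    import torch

    b, h, n, d = 2, 8, 1024, 64                # batch, heads, seq, head dim, illustrative
    v = torch.randn(b, h, n, d)                # values computed by the current layer
    value_residual = torch.randn(b, h, n, d)   # values cached from the first attention layer

    mix = torch.full((b, 1, n, 1), 0.5)        # what always(0.5) amounts to when the gate is not learned
    v_mixed = v * mix + value_residual * (1. - mix)

    # identical to the fixed averaging this change replaces
    assert torch.allclose(v_mixed, 0.5 * (v + value_residual))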
@@ -1541,8 +1551,9 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
-        reinject_input = False,
-        add_value_residual = False,
+        reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
+        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
+        learned_value_residual_mix = False, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependent. here we will use per token learned
         rel_pos_kwargs: dict = dict(),
         **kwargs
     ):
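Both flags are plain keyword arguments on `AttentionLayers`, so they can be passed through the usual entry points. A sketch of end-to-end usage, assuming the standard `TransformerWrapper` / `Decoder` wiring (which forwards extra keyword arguments to `AttentionLayers`; that plumbing is not part of this diff):

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 6,
            heads = 8,
            add_value_residual = True,          # cache first-layer values and blend them back in (resformer)
            learned_value_residual_mix = True   # learn the per-token mix instead of the fixed 0.5
        )
    )

    tokens = torch.randint(0, 20000, (1, 256))
    logits = model(tokens)                      # (1, 256, 20000)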
@@ -1786,6 +1797,10 @@ class AttentionLayers(Module):

         self.add_value_residual = add_value_residual

+        is_first_self_attn = True
+        is_first_cross_attn = True
+        learned_value_residual_mix &= add_value_residual
+
         # iterate and construct layers

         for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
@@ -1801,9 +1816,13 @@ class AttentionLayers(Module):
             # attention, cross attention, feedforward

             if layer_type == 'a':
-                layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
+                self_attn_learned_value_residual = learned_value_residual_mix and not is_first_self_attn
+                layer = Attention(dim, heads = heads, causal = causal, learned_value_residual_mix = self_attn_learned_value_residual, **attn_kwargs)
+                is_first_self_attn = False
             elif layer_type == 'c':
-                layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
+                cross_attn_learned_value_residual = learned_value_residual_mix and not is_first_cross_attn
+                layer = Attention(dim, heads = heads, learned_value_residual_mix = cross_attn_learned_value_residual, **{**attn_kwargs, **cross_attn_kwargs})
+                is_first_cross_attn = False
             elif layer_type == 'f':
                 layer = FeedForward(dim, **ff_kwargs)
                 layer = layer if not macaron else Scale(0.5, layer)
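The first self-attention (and first cross-attention) layer is the one whose values become the residual, so it never receives a learned mix, and `learned_value_residual_mix &= add_value_residual` keeps the gate from being built when the value residual feature is off. A condensed sketch of that decision (function and argument names are illustrative, not the library's):

    # function and argument names are shortened for illustration only
    def uses_learned_mix(layer_index, add_value_residual, learned_mix_requested):
        learned_mix = learned_mix_requested and add_value_residual  # mirrors `learned_value_residual_mix &= add_value_residual`
        is_first_self_attn = layer_index == 0                       # the first layer is the source of the residual values
        return learned_mix and not is_first_self_attn

    assert not uses_learned_mix(0, True, True)    # first self-attention layer: no mix to learn
    assert uses_learned_mix(3, True, True)        # deeper layers: learned per-token mix
    assert not uses_learned_mix(3, False, True)   # value residual disabled: learned mix disabled too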
{x_transformers-1.42.12.dist-info → x_transformers-1.42.15.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.12
+Version: 1.42.15
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -14,8 +14,7 @@ Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3.6
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: torch
-Requires-Dist: einx
-Requires-Dist: einops
-Requires-Dist: packaging
-
+Requires-Dist: torch>=2.0
+Requires-Dist: einx>=0.3.0
+Requires-Dist: einops>=0.8.0
+Requires-Dist: packaging>=21.0
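The previously unpinned dependencies now declare minimum versions. A small sketch for checking an installed environment against those floors (uses only the standard `importlib.metadata` plus `packaging`, which is itself a declared dependency; the dictionary simply mirrors the METADATA lines above):

    from importlib.metadata import version
    from packaging.version import Version

    # minimum pins copied from the METADATA hunk above
    minimums = {
        'torch': '2.0',
        'einx': '0.3.0',
        'einops': '0.8.0',
        'packaging': '21.0',
    }

    for name, floor in minimums.items():
        assert Version(version(name)) >= Version(floor), f'{name} is older than {floor}'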
{x_transformers-1.42.12.dist-info → x_transformers-1.42.15.dist-info}/RECORD
CHANGED
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256
+x_transformers/x_transformers.py,sha256=-gi7UiCRdp-5y34cUJEMk7uFSi-I7khXxON1gErAKbY,95125
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
+x_transformers-1.42.15.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.15.dist-info/METADATA,sha256=zqzIQ3mdFjs4WV7IgTu4YYEmFM-6GKWast8twY4__Tg,717
+x_transformers-1.42.15.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+x_transformers-1.42.15.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.15.dist-info/RECORD,,
{x_transformers-1.42.12.dist-info → x_transformers-1.42.15.dist-info}/LICENSE
File without changes
{x_transformers-1.42.12.dist-info → x_transformers-1.42.15.dist-info}/top_level.txt
File without changes