x-transformers 2.0.1__py3-none-any.whl → 2.0.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- x_transformers/x_transformers.py +8 -2
- {x_transformers-2.0.1.dist-info → x_transformers-2.0.2.dist-info}/METADATA +2 -1
- {x_transformers-2.0.1.dist-info → x_transformers-2.0.2.dist-info}/RECORD +5 -5
- {x_transformers-2.0.1.dist-info → x_transformers-2.0.2.dist-info}/WHEEL +0 -0
- {x_transformers-2.0.1.dist-info → x_transformers-2.0.2.dist-info}/licenses/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -1204,6 +1204,7 @@ class Attention(Module):
|
|
1204
1204
|
hybrid_module: Module | None = None,
|
1205
1205
|
hybrid_mask_kwarg: str | None = None,
|
1206
1206
|
hybrid_fold_axial_dim: int | None = None,
|
1207
|
+
hybrid_learned_mix = False,
|
1207
1208
|
one_kv_head = False,
|
1208
1209
|
kv_heads = None,
|
1209
1210
|
value_dim_head = None,
|
@@ -1446,7 +1447,7 @@ class Attention(Module):
|
|
1446
1447
|
|
1447
1448
|
if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
|
1448
1449
|
hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
|
1449
|
-
hybrid_mix = LinearNoBias(dim, heads)
|
1450
|
+
hybrid_mix = LinearNoBias(dim, heads) if hybrid_learned_mix else None
|
1450
1451
|
|
1451
1452
|
hybrid_norms = ModuleList([
|
1452
1453
|
MultiheadRMSNorm(dim_head, heads = heads),
|
@@ -1779,7 +1780,12 @@ class Attention(Module):
|
|
1779
1780
|
out = out_norm(out)
|
1780
1781
|
hybrid_out = hybrid_out_norm(hybrid_out)
|
1781
1782
|
|
1782
|
-
|
1783
|
+
if exists(self.hybrid_mix):
|
1784
|
+
mix = self.hybrid_mix(x)
|
1785
|
+
mix = rearrange(mix, 'b n h -> b h n 1')
|
1786
|
+
out = out.lerp(hybrid_out, mix.sigmoid())
|
1787
|
+
else:
|
1788
|
+
out = 0.5 * (out + hybrid_out)
|
1783
1789
|
|
1784
1790
|
# merge heads
|
1785
1791
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: x-transformers
|
3
|
-
Version: 2.0.1
|
3
|
+
Version: 2.0.2
|
4
4
|
Summary: X-Transformers
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/x-transformers/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/x-transformers
|
@@ -40,6 +40,7 @@ Requires-Dist: loguru
|
|
40
40
|
Requires-Dist: packaging>=21.0
|
41
41
|
Requires-Dist: torch>=2.0
|
42
42
|
Provides-Extra: examples
|
43
|
+
Requires-Dist: lion-pytorch; extra == 'examples'
|
43
44
|
Requires-Dist: torchvision; extra == 'examples'
|
44
45
|
Requires-Dist: tqdm; extra == 'examples'
|
45
46
|
Provides-Extra: test
|
@@ -6,10 +6,10 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
|
|
6
6
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
7
7
|
x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
|
8
8
|
x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
|
9
|
-
x_transformers/x_transformers.py,sha256=
|
9
|
+
x_transformers/x_transformers.py,sha256=1s8KCSfHXMN9TKLFdS-RzzCskBDkh4CuBk2_XRb6IXk,107537
|
10
10
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
11
11
|
x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
|
12
|
-
x_transformers-2.0.
|
13
|
-
x_transformers-2.0.
|
14
|
-
x_transformers-2.0.
|
15
|
-
x_transformers-2.0.1.dist-info/RECORD,,
|
12
|
+
x_transformers-2.0.2.dist-info/METADATA,sha256=tNdI3H2S4HnnGK1hPY3l94FoXH3SB9vGAb55pcah6Yw,86506
|
13
|
+
x_transformers-2.0.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
14
|
+
x_transformers-2.0.2.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
15
|
+
x_transformers-2.0.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|