x-transformers 1.44.0__py3-none-any.whl → 1.44.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- x_transformers/x_transformers.py +20 -2
- {x_transformers-1.44.0.dist-info → x_transformers-1.44.2.dist-info}/METADATA +1 -1
- {x_transformers-1.44.0.dist-info → x_transformers-1.44.2.dist-info}/RECORD +6 -6
- {x_transformers-1.44.0.dist-info → x_transformers-1.44.2.dist-info}/WHEEL +1 -1
- {x_transformers-1.44.0.dist-info → x_transformers-1.44.2.dist-info}/LICENSE +0 -0
- {x_transformers-1.44.0.dist-info → x_transformers-1.44.2.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -7,10 +7,11 @@ from random import random, randrange
|
|
7
7
|
from packaging import version
|
8
8
|
|
9
9
|
import torch
|
10
|
+
from torch.amp import autocast
|
10
11
|
import torch.nn.functional as F
|
11
12
|
from torch import nn, einsum, Tensor
|
13
|
+
from torch.utils._pytree import tree_flatten
|
12
14
|
from torch.nn import Module, ModuleList, ModuleDict
|
13
|
-
from torch.amp import autocast
|
14
15
|
|
15
16
|
from functools import partial, wraps
|
16
17
|
from collections import namedtuple
|
@@ -1138,6 +1139,7 @@ class Attention(Module):
|
|
1138
1139
|
selective = False,
|
1139
1140
|
custom_attn_fn: Callable | None = None,
|
1140
1141
|
hybrid_module: Module | None = None,
|
1142
|
+
hybrid_mask_kwarg: str | None = None,
|
1141
1143
|
one_kv_head = False,
|
1142
1144
|
kv_heads = None,
|
1143
1145
|
shared_kv = False,
|
@@ -1341,6 +1343,8 @@ class Attention(Module):
|
|
1341
1343
|
|
1342
1344
|
self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
|
1343
1345
|
|
1346
|
+
self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
|
1347
|
+
|
1344
1348
|
# output dimension by default same as input, but can be overridden
|
1345
1349
|
|
1346
1350
|
dim_out = default(dim_out, dim)
|
@@ -1592,7 +1596,21 @@ class Attention(Module):
|
|
1592
1596
|
# hybrid module
|
1593
1597
|
|
1594
1598
|
if exists(self.hybrid_module):
|
1595
|
-
|
1599
|
+
|
1600
|
+
# hybrid input
|
1601
|
+
|
1602
|
+
hybrid_forward_kwargs = dict()
|
1603
|
+
|
1604
|
+
if not self.causal and exists(self.hybrid_mask_kwarg):
|
1605
|
+
hybrid_forward_kwargs = {self.hybrid_mask_kwarg: mask}
|
1606
|
+
|
1607
|
+
# hybrid forward
|
1608
|
+
|
1609
|
+
hybrid_outputs = self.hybrid_module(x, **hybrid_forward_kwargs)
|
1610
|
+
|
1611
|
+
# handle hybrid out
|
1612
|
+
|
1613
|
+
(hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
|
1596
1614
|
out = 0.5 * (out + hybrid_out)
|
1597
1615
|
|
1598
1616
|
# alphafold2 styled gating of the values
|
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
|
|
6
6
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
7
7
|
x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
|
8
8
|
x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
|
9
|
-
x_transformers/x_transformers.py,sha256=
|
9
|
+
x_transformers/x_transformers.py,sha256=W8qNdv-1CctdU4zwN2rYwu2CrgVMT1WL_3lwTSy5cCg,101862
|
10
10
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
11
11
|
x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
|
12
|
-
x_transformers-1.44.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
13
|
-
x_transformers-1.44.
|
14
|
-
x_transformers-1.44.
|
15
|
-
x_transformers-1.44.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
16
|
-
x_transformers-1.44.0.dist-info/RECORD,,
|
12
|
+
x_transformers-1.44.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
13
|
+
x_transformers-1.44.2.dist-info/METADATA,sha256=yNNX1JyY8T3KF0Q6vO525au_90esaJZhw3quktk8n7g,738
|
14
|
+
x_transformers-1.44.2.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
|
15
|
+
x_transformers-1.44.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
16
|
+
x_transformers-1.44.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|