x-transformers 1.44.0-py3-none-any.whl → 1.44.2-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- x_transformers/x_transformers.py +20 -2
- {x_transformers-1.44.0.dist-info → x_transformers-1.44.2.dist-info}/METADATA +1 -1
- {x_transformers-1.44.0.dist-info → x_transformers-1.44.2.dist-info}/RECORD +6 -6
- {x_transformers-1.44.0.dist-info → x_transformers-1.44.2.dist-info}/WHEEL +1 -1
- {x_transformers-1.44.0.dist-info → x_transformers-1.44.2.dist-info}/LICENSE +0 -0
- {x_transformers-1.44.0.dist-info → x_transformers-1.44.2.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py CHANGED
@@ -7,10 +7,11 @@ from random import random, randrange
 from packaging import version
 
 import torch
+from torch.amp import autocast
 import torch.nn.functional as F
 from torch import nn, einsum, Tensor
+from torch.utils._pytree import tree_flatten
 from torch.nn import Module, ModuleList, ModuleDict
-from torch.amp import autocast
 
 from functools import partial, wraps
 from collections import namedtuple
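Two imports change here: `autocast` simply moves up next to `torch`, and `tree_flatten` from torch's (private) pytree utility is new; the forward change further down uses it to accept hybrid modules that return either a bare tensor or a nested tuple. A minimal sketch of its behavior:

# Sketch of the newly imported helper: torch.utils._pytree.tree_flatten
# flattens any nested structure of tensors into a flat list of leaves plus
# a spec describing the original layout (note the module is private API).
import torch
from torch.utils._pytree import tree_flatten

t = torch.randn(2, 1024, 512)

leaves, _ = tree_flatten(t)                    # bare tensor  -> [t]
nested, _ = tree_flatten((t, (t + 1, t * 2)))  # nested tuple -> three leaves

assert len(leaves) == 1 and len(nested) == 3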
@@ -1138,6 +1139,7 @@ class Attention(Module):
         selective = False,
         custom_attn_fn: Callable | None = None,
         hybrid_module: Module | None = None,
+        hybrid_mask_kwarg: str | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1341,6 +1343,8 @@
 
         self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
 
+        self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
+
         # output dimension by default same as input, but can be overridden
 
         dim_out = default(dim_out, dim)
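Together with the new constructor argument above, this lets a non-causal Attention forward its padding mask into the hybrid branch. A hypothetical usage sketch: `MaskedMixer` and its `lengths_mask` keyword are invented for illustration, while `dim`, `heads`, `causal` and the `mask` forward argument follow Attention's existing signature.

import torch
from torch import nn
from x_transformers.x_transformers import Attention

class MaskedMixer(nn.Module):
    # toy bidirectional mixer that knows how to honor a padding mask
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, lengths_mask = None):
        if lengths_mask is not None:
            x = x * lengths_mask[..., None]  # zero out padded positions
        return self.proj(x)

attn = Attention(
    dim = 512,
    heads = 8,
    causal = False,                      # mask forwarding only happens when non-causal
    hybrid_module = MaskedMixer(512),
    hybrid_mask_kwarg = 'lengths_mask'   # the keyword under which `mask` is passed along
)

x = torch.randn(2, 1024, 512)
mask = torch.ones(2, 1024, dtype = torch.bool)

out, intermediates = attn(x, mask = mask)  # mask now also reaches MaskedMixer.forward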
@@ -1592,7 +1596,21 @@
         # hybrid module
 
         if exists(self.hybrid_module):
-
+
+            # hybrid input
+
+            hybrid_forward_kwargs = dict()
+
+            if not self.causal and exists(self.hybrid_mask_kwarg):
+                hybrid_forward_kwargs = {self.hybrid_mask_kwarg: mask}
+
+            # hybrid forward
+
+            hybrid_outputs = self.hybrid_module(x, **hybrid_forward_kwargs)
+
+            # handle hybrid out
+
+            (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
             out = 0.5 * (out + hybrid_out)
 
         # alphafold2 styled gating of the values
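Read as straight-line code, the new branch does the following (a standalone restatement for clarity, not a public API; `attn_out`, `x`, `mask` and friends stand in for Attention.forward's locals):

# Restatement of the new hybrid branch above
from torch.utils._pytree import tree_flatten

def merge_with_hybrid(attn_out, x, hybrid_module, causal, hybrid_mask_kwarg = None, mask = None):
    kwargs = dict()

    if not causal and hybrid_mask_kwarg is not None:
        # bidirectional only: pass the padding mask under the module's expected keyword
        kwargs = {hybrid_mask_kwarg: mask}

    hybrid_outputs = hybrid_module(x, **kwargs)

    # tree_flatten tolerates a bare tensor or any nested tuple return;
    # the first leaf is taken as the hybrid branch's main output
    (hybrid_out, *_), _ = tree_flatten(hybrid_outputs)

    # simple average of the attention path and the hybrid path
    return 0.5 * (attn_out + hybrid_out)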
{x_transformers-1.44.0.dist-info → x_transformers-1.44.2.dist-info}/RECORD CHANGED
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=W8qNdv-1CctdU4zwN2rYwu2CrgVMT1WL_3lwTSy5cCg,101862
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.44.
-x_transformers-1.44.
-x_transformers-1.44.
-x_transformers-1.44.
-x_transformers-1.44.
+x_transformers-1.44.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.44.2.dist-info/METADATA,sha256=yNNX1JyY8T3KF0Q6vO525au_90esaJZhw3quktk8n7g,738
+x_transformers-1.44.2.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
+x_transformers-1.44.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.44.2.dist-info/RECORD,,
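For anyone checking the RECORD changes by hand: each entry is `path,sha256=<digest>,<size>`, where the digest is the urlsafe-base64 sha256 of the file with its `=` padding stripped, per the wheel spec and PEP 376. A small sketch that reproduces an entry:

# Sketch: derive a wheel RECORD line's hash and size fields
import base64, hashlib, os

def record_entry(path):
    with open(path, 'rb') as f:
        digest = hashlib.sha256(f.read()).digest()
    encoded = base64.urlsafe_b64encode(digest).rstrip(b'=').decode()
    return f'{path},sha256={encoded},{os.path.getsize(path)}'

# record_entry('x_transformers/x_transformers.py'), run against the unpacked
# 1.44.2 wheel, should reproduce the '+' line above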
{x_transformers-1.44.0.dist-info → x_transformers-1.44.2.dist-info}/LICENSE: File without changes
{x_transformers-1.44.0.dist-info → x_transformers-1.44.2.dist-info}/top_level.txt: File without changes