x-transformers 1.44.1__py3-none-any.whl → 1.44.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +44 -3
- {x_transformers-1.44.1.dist-info → x_transformers-1.44.4.dist-info}/METADATA +1 -1
- {x_transformers-1.44.1.dist-info → x_transformers-1.44.4.dist-info}/RECORD +6 -6
- {x_transformers-1.44.1.dist-info → x_transformers-1.44.4.dist-info}/LICENSE +0 -0
- {x_transformers-1.44.1.dist-info → x_transformers-1.44.4.dist-info}/WHEEL +0 -0
- {x_transformers-1.44.1.dist-info → x_transformers-1.44.4.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -10,7 +10,7 @@ import torch
 from torch.amp import autocast
 import torch.nn.functional as F
 from torch import nn, einsum, Tensor
-from torch.utils._pytree import tree_flatten
+from torch.utils._pytree import tree_flatten, tree_unflatten
 from torch.nn import Module, ModuleList, ModuleDict
 
 from functools import partial, wraps
@@ -966,6 +966,42 @@ class ShiftTokens(Module):
         x = torch.cat((*segments_to_shift, *rest), dim = -1)
         return self.fn(x, **kwargs)
 
+class FoldAxially(Module):
+    def __init__(
+        self,
+        axial_dim,
+        fn: Module
+    ):
+        super().__init__()
+        self.fn = fn
+        self.axial_dim = axial_dim # will fold the sequence as rearrange("b (n axial_dim) ... -> (b axial_dim) n ...")
+
+    def forward(
+        self,
+        x,
+        **kwargs
+    ):
+        if self.axial_dim == 1:
+            return self.fn(x, **kwargs)
+
+        seq_len, axial_dim = x.shape[1], self.axial_dim
+
+        next_multiple = math.ceil(seq_len / axial_dim) * axial_dim
+        x = pad_at_dim(x, (0, next_multiple - seq_len), dim = 1)
+
+        x = rearrange(x, 'b (n axial_dim) ... -> (b axial_dim) n ...', axial_dim = axial_dim)
+
+        out = self.fn(x, **kwargs)
+
+        (out, *rest_out), tree_spec = tree_flatten(out)
+
+        out = rearrange(out, '(b axial_dim) n ... -> b (n axial_dim) ...', axial_dim = axial_dim)
+
+        out = out[:, :seq_len]
+        out = tree_unflatten((out, *rest_out), tree_spec)
+
+        return out
+
 # post branch operator
 
 class LayerScale(Module):
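The new FoldAxially wrapper folds the sequence axis into the batch axis: the wrapped module runs on axial_dim interleaved sub-sequences (every axial_dim-th token), padded up to a multiple of axial_dim and trimmed back afterwards. A minimal sketch, not taken from the package, of the rearrange round trip it performs (axial_dim written as `a`, with a = 4):

```python
# Illustration only: mirrors the fold/unfold pattern FoldAxially applies around self.fn.
import torch
from einops import rearrange

x = torch.arange(8).reshape(1, 8, 1)                     # (batch = 1, seq = 8, dim = 1)

folded = rearrange(x, 'b (n a) d -> (b a) n d', a = 4)   # fold: 4x the batch, 1/4 the length
print(folded[..., 0])                                    # tensor([[0, 4], [1, 5], [2, 6], [3, 7]])

unfolded = rearrange(folded, '(b a) n d -> b (n a) d', a = 4)
assert torch.equal(unfolded, x)                          # unfolding restores the original token order
```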
@@ -1140,6 +1176,7 @@ class Attention(Module):
         custom_attn_fn: Callable | None = None,
         hybrid_module: Module | None = None,
         hybrid_mask_kwarg: str | None = None,
+        hybrid_fold_axial_dim: int | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1341,8 +1378,12 @@ class Attention(Module):
 
         # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
 
-
+        hybrid_module = maybe(deepcopy)(hybrid_module)
+
+        if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
+            hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
 
+        self.hybrid_module = hybrid_module
         self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
 
         # output dimension by default same as input, but can be overridden
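A hedged usage sketch, not part of this diff: one plausible way to pass the new hybrid_fold_axial_dim when constructing an Attention layer directly, assuming (as in earlier 1.44.x releases) that the hybrid branch's output features line up with the attention branch where the two are averaged; here dim and dim_head * heads are both 512, so the bidirectional LSTM's 2 * 256 features match either way.

```python
import torch
from torch import nn
from x_transformers.x_transformers import Attention

# construction only; per the diff above, the hybrid module is deep-copied
# and wrapped in FoldAxially when hybrid_fold_axial_dim is given
attn = Attention(
    dim = 512,
    heads = 8,
    dim_head = 64,
    hybrid_module = nn.LSTM(512, 256, batch_first = True, bidirectional = True),
    hybrid_fold_axial_dim = 4,   # new argument: the LSTM sees 4 interleaved sub-sequences
)
```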
@@ -1610,7 +1651,7 @@ class Attention(Module):
 
         # handle hybrid out
 
-        (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(
+        (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
         out = 0.5 * (out + hybrid_out)
 
         # alphafold2 styled gating of the values
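Both FoldAxially.forward and the corrected hybrid-output handling above lean on torch.utils._pytree because a wrapped module may return a lone tensor or a structure such as an RNN's (output, hidden) tuple. A small standalone sketch of that flatten/unflatten pattern:

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

rnn = torch.nn.GRU(16, 16, batch_first = True)
returned = rnn(torch.randn(2, 12, 16))                 # a (output, hidden) tuple

(first, *rest), tree_spec = tree_flatten(returned)     # leading tensor, remaining leaves, structure
first = first[:, :10]                                  # edit only the main output (e.g. trim padding)
restored = tree_unflatten((first, *rest), tree_spec)   # rebuild the original tuple structure

print(type(restored), restored[0].shape)               # <class 'tuple'> torch.Size([2, 10, 16])
```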
{x_transformers-1.44.1.dist-info → x_transformers-1.44.4.dist-info}/RECORD
CHANGED
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=lc9mmhV-O9MesX7Di7P93KjMioRzM5zzZ0U9sVoDLqU,103100
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.44.
-x_transformers-1.44.
-x_transformers-1.44.
-x_transformers-1.44.
-x_transformers-1.44.
+x_transformers-1.44.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.44.4.dist-info/METADATA,sha256=09PGX7zKwq8DjFeEX3FmF3YmSpN1XU4fiyDisewXlDg,738
+x_transformers-1.44.4.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
+x_transformers-1.44.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.44.4.dist-info/RECORD,,
{x_transformers-1.44.1.dist-info → x_transformers-1.44.4.dist-info}/LICENSE
File without changes
{x_transformers-1.44.1.dist-info → x_transformers-1.44.4.dist-info}/WHEEL
File without changes
{x_transformers-1.44.1.dist-info → x_transformers-1.44.4.dist-info}/top_level.txt
File without changes