x-transformers 1.44.2__py3-none-any.whl → 1.44.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +45 -4
- {x_transformers-1.44.2.dist-info → x_transformers-1.44.5.dist-info}/METADATA +11 -2
- {x_transformers-1.44.2.dist-info → x_transformers-1.44.5.dist-info}/RECORD +6 -6
- {x_transformers-1.44.2.dist-info → x_transformers-1.44.5.dist-info}/WHEEL +1 -1
- {x_transformers-1.44.2.dist-info → x_transformers-1.44.5.dist-info}/LICENSE +0 -0
- {x_transformers-1.44.2.dist-info → x_transformers-1.44.5.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -10,7 +10,7 @@ import torch
 from torch.amp import autocast
 import torch.nn.functional as F
 from torch import nn, einsum, Tensor
-from torch.utils._pytree import tree_flatten
+from torch.utils._pytree import tree_flatten, tree_unflatten
 from torch.nn import Module, ModuleList, ModuleDict

 from functools import partial, wraps
@@ -966,6 +966,42 @@ class ShiftTokens(Module):
         x = torch.cat((*segments_to_shift, *rest), dim = -1)
         return self.fn(x, **kwargs)

+class FoldAxially(Module):
+    def __init__(
+        self,
+        axial_dim,
+        fn: Module
+    ):
+        super().__init__()
+        self.fn = fn
+        self.axial_dim = axial_dim # will fold the sequence as rearrange("b (n axial_dim) ... -> (b axial_dim) n ...")
+
+    def forward(
+        self,
+        x,
+        **kwargs
+    ):
+        if self.axial_dim == 1:
+            return self.fn(x, **kwargs)
+
+        seq_len, axial_dim = x.shape[1], self.axial_dim
+
+        next_multiple = math.ceil(seq_len / axial_dim) * axial_dim
+        x = pad_at_dim(x, (0, next_multiple - seq_len), dim = 1)
+
+        x = rearrange(x, 'b (n axial_dim) ... -> (b axial_dim) n ...', axial_dim = axial_dim)
+
+        out = self.fn(x, **kwargs)
+
+        (out, *rest_out), tree_spec = tree_flatten(out)
+
+        out = rearrange(out, '(b axial_dim) n ... -> b (n axial_dim) ...', axial_dim = axial_dim)
+
+        out = out[:, :seq_len]
+        out = tree_unflatten((out, *rest_out), tree_spec)
+
+        return out
+
 # post branch operator

 class LayerScale(Module):
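To make the new wrapper's behavior concrete, here is a minimal sketch of FoldAxially at call time, assuming it is importable from x_transformers.x_transformers as shipped in 1.44.5. ShapeProbe is a hypothetical helper used only to observe the shape the wrapped module receives; a sequence of length 1023 is padded to 1024, folded into the batch dimension, then unfolded and trimmed back on the way out.

    import torch
    from torch import nn
    from x_transformers.x_transformers import FoldAxially  # new in 1.44.5

    # toy inner module that records the shape it is called with
    class ShapeProbe(nn.Module):
        def forward(self, x):
            self.seen_shape = tuple(x.shape)
            return x

    probe = ShapeProbe()
    folded = FoldAxially(axial_dim = 4, fn = probe)

    x = torch.randn(2, 1023, 512)   # length not divisible by 4, padded to 1024 internally
    out = folded(x)

    print(probe.seen_shape)         # (8, 256, 512) - batch folded, sequence shortened
    print(out.shape)                # torch.Size([2, 1023, 512]) - unfolded and trimmed back

Because of the 'b (n axial_dim) ... -> (b axial_dim) n ...' pattern, each folded stream sees every axial_dim-th token of the original sequence, starting at a different offset.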
@@ -1140,6 +1176,7 @@ class Attention(Module):
         custom_attn_fn: Callable | None = None,
         hybrid_module: Module | None = None,
         hybrid_mask_kwarg: str | None = None,
+        hybrid_fold_axial_dim: int | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1341,8 +1378,12 @@ class Attention(Module):

        # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676

-        self.hybrid_module = maybe(deepcopy)(hybrid_module)
+        hybrid_module = maybe(deepcopy)(hybrid_module)
+
+        if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
+            hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)

+        self.hybrid_module = hybrid_module
         self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths

         # output dimension by default same as input, but can be overridden
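The two hunks above wire the wrapper into attention: Attention now accepts hybrid_fold_axial_dim, and when it is given together with hybrid_module, the (deep-copied) hybrid branch is wrapped in FoldAxially so it runs over strided subsequences rather than the full sequence. A rough usage sketch, assuming the usual Attention constructor arguments (dim, heads); GRUBranch is a hypothetical stand-in for whatever recurrent or SSM branch is being hybridized in the spirit of the Hymba reference:

    import torch
    from torch import nn
    from x_transformers.x_transformers import Attention, FoldAxially

    # hypothetical recurrent branch to run alongside attention; any Module whose
    # output keeps the (batch, seq, dim) layout of its input would do
    class GRUBranch(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.gru = nn.GRU(dim, dim, batch_first = True)

        def forward(self, x):
            out, _ = self.gru(x)
            return out

    attn = Attention(
        dim = 512,
        heads = 8,
        hybrid_module = GRUBranch(512),
        hybrid_fold_axial_dim = 4       # new in 1.44.5: fold the sequence 4-way for the branch
    )

    assert isinstance(attn.hybrid_module, FoldAxially)  # the branch was wrapped automatically

When configuring through Encoder/Decoder rather than Attention directly, these settings are typically passed with the attn_ prefix (attn_hybrid_module = ..., attn_hybrid_fold_axial_dim = 4), following the library's usual kwarg routing.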
@@ -2183,7 +2224,7 @@ class AttentionLayers(Module):
         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)

         # derived input for reinjection if needed
-
+        inp_inject = None
         if self.reinject_input:
             assert not exists(in_attn_cond)
             inp_inject = self.reinject_input_proj(x)
@@ -2241,7 +2282,7 @@ class AttentionLayers(Module):
             post_branch_norm = maybe(partial)(post_branch_norm, **norm_kwargs)
             post_main_norm = maybe(partial)(post_main_norm, **norm_kwargs)

-            if …
+            if exists(inp_inject):
                 x = x + inp_inject

             if exists(pre_norm):
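The two hunks above are a small robustness fix: inp_inject is now initialized to None up front, and the residual add inside the layer loop only fires when something was actually injected, via exists(inp_inject). A self-contained sketch of that pattern, with a toy 0.5-scaling standing in for reinject_input_proj (exists is the library's usual is-not-None helper):

    import torch

    def exists(v):
        return v is not None

    def layer_pass(x, reinject_input: bool):
        # mirrors the change: inp_inject always starts out defined ...
        inp_inject = None

        if reinject_input:
            inp_inject = x * 0.5        # toy stand-in for reinject_input_proj(x)

        # ... so the add can be guarded uniformly, and the path where nothing
        # was injected is a clean no-op
        if exists(inp_inject):
            x = x + inp_inject

        return x

    x = torch.randn(2, 8, 16)
    layer_pass(x, reinject_input = False)
    layer_pass(x, reinject_input = True)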
{x_transformers-1.44.2.dist-info → x_transformers-1.44.5.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.2
 Name: x-transformers
-Version: 1.44.2
+Version: 1.44.5
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -19,3 +19,12 @@ Requires-Dist: einops>=0.8.0
 Requires-Dist: loguru
 Requires-Dist: packaging>=21.0
 Requires-Dist: torch>=2.0
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: keywords
+Dynamic: license
+Dynamic: requires-dist
+Dynamic: summary
{x_transformers-1.44.2.dist-info → x_transformers-1.44.5.dist-info}/RECORD
CHANGED
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=…
+x_transformers/x_transformers.py,sha256=-B4xutlod09EwaA1tE24GgD-5ioi4HAXIwmI5N4MYio,103124
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.44.2.dist-info/LICENSE,sha256=…
-x_transformers-1.44.2.dist-info/METADATA,sha256=…
-x_transformers-1.44.2.dist-info/WHEEL,sha256=…
-x_transformers-1.44.2.dist-info/top_level.txt,sha256=…
-x_transformers-1.44.2.dist-info/RECORD,,
+x_transformers-1.44.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.44.5.dist-info/METADATA,sha256=j_8nbdft6Qx1NtIjSd1NyERsrpu-W9L3-CuR0hug_S0,924
+x_transformers-1.44.5.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+x_transformers-1.44.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.44.5.dist-info/RECORD,,
{x_transformers-1.44.2.dist-info → x_transformers-1.44.5.dist-info}/LICENSE
File without changes
{x_transformers-1.44.2.dist-info → x_transformers-1.44.5.dist-info}/top_level.txt
File without changes