x-transformers 1.44.2__tar.gz → 1.44.5__tar.gz
- {x_transformers-1.44.2/x_transformers.egg-info → x_transformers-1.44.5}/PKG-INFO +11 -2
- {x_transformers-1.44.2 → x_transformers-1.44.5}/setup.py +1 -1
- {x_transformers-1.44.2 → x_transformers-1.44.5}/tests/test_x_transformers.py +4 -1
- {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/x_transformers.py +45 -4
- {x_transformers-1.44.2 → x_transformers-1.44.5/x_transformers.egg-info}/PKG-INFO +11 -2
- {x_transformers-1.44.2 → x_transformers-1.44.5}/LICENSE +0 -0
- {x_transformers-1.44.2 → x_transformers-1.44.5}/README.md +0 -0
- {x_transformers-1.44.2 → x_transformers-1.44.5}/setup.cfg +0 -0
- {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/__init__.py +0 -0
- {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/attend.py +0 -0
- {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/continuous.py +0 -0
- {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/dpo.py +0 -0
- {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/multi_input.py +0 -0
- {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/xval.py +0 -0
- {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.44.2/x_transformers.egg-info → x_transformers-1.44.5}/PKG-INFO

@@ -1,6 +1,6 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.2
 Name: x-transformers
-Version: 1.44.2
+Version: 1.44.5
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -19,3 +19,12 @@ Requires-Dist: einops>=0.8.0
 Requires-Dist: loguru
 Requires-Dist: packaging>=21.0
 Requires-Dist: torch>=2.0
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: keywords
+Dynamic: license
+Dynamic: requires-dist
+Dynamic: summary
{x_transformers-1.44.2 → x_transformers-1.44.5}/tests/test_x_transformers.py

@@ -614,7 +614,8 @@ def test_hyper_connections(tanh):
 
     model(x)
 
-def test_hybrid():
+@pytest.mark.parametrize('hybrid_axial_dim', (1, 4))
+def test_hybrid(hybrid_axial_dim):
     from torch.nn import GRU
 
     dec = TransformerWrapper(
@@ -625,6 +626,7 @@ def test_hybrid():
             depth = 6,
             heads = 8,
             attn_dim_head = 64,
+            attn_hybrid_fold_axial_dim = hybrid_axial_dim,
             attn_hybrid_module = GRU(128, 64 * 8, batch_first = True)
         )
     )
@@ -641,6 +643,7 @@ def test_hybrid():
             depth = 6,
             heads = 8,
             attn_dim_head = 64,
+            attn_hybrid_fold_axial_dim = hybrid_axial_dim,
             attn_hybrid_module = GRU(128, 64 * 4, batch_first = True, bidirectional = True)
         )
     )
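Read together, these hunks show the new `attn_hybrid_fold_axial_dim` argument being passed alongside `attn_hybrid_module` (the `attn_` prefix routes both through to `Attention`). Since the hunks only contain the changed lines, the following is a hedged reconstruction of one complete setup: `num_tokens`, `max_seq_len`, the input length, and the fold value `4` are illustrative assumptions, and `dim = 128` is inferred from the `GRU(128, ...)` input size in the test.

```python
import torch
from torch.nn import GRU
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,          # assumed for illustration
    max_seq_len = 1024,          # assumed for illustration
    attn_layers = Decoder(
        dim = 128,               # inferred from GRU(128, ...)
        depth = 6,
        heads = 8,
        attn_dim_head = 64,
        attn_hybrid_fold_axial_dim = 4,                             # new in 1.44.5
        attn_hybrid_module = GRU(128, 64 * 8, batch_first = True)   # output dim matches heads * dim_head
    )
)

x = torch.randint(0, 20000, (2, 256))
logits = model(x)                # (2, 256, 20000)
```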
{x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/x_transformers.py

@@ -10,7 +10,7 @@ import torch
 from torch.amp import autocast
 import torch.nn.functional as F
 from torch import nn, einsum, Tensor
-from torch.utils._pytree import tree_flatten
+from torch.utils._pytree import tree_flatten, tree_unflatten
 from torch.nn import Module, ModuleList, ModuleDict
 
 from functools import partial, wraps
@@ -966,6 +966,42 @@ class ShiftTokens(Module):
         x = torch.cat((*segments_to_shift, *rest), dim = -1)
         return self.fn(x, **kwargs)
 
+class FoldAxially(Module):
+    def __init__(
+        self,
+        axial_dim,
+        fn: Module
+    ):
+        super().__init__()
+        self.fn = fn
+        self.axial_dim = axial_dim # will fold the sequence as rearrange("b (n axial_dim) ... -> (b axial_dim) n ...")
+
+    def forward(
+        self,
+        x,
+        **kwargs
+    ):
+        if self.axial_dim == 1:
+            return self.fn(x, **kwargs)
+
+        seq_len, axial_dim = x.shape[1], self.axial_dim
+
+        next_multiple = math.ceil(seq_len / axial_dim) * axial_dim
+        x = pad_at_dim(x, (0, next_multiple - seq_len), dim = 1)
+
+        x = rearrange(x, 'b (n axial_dim) ... -> (b axial_dim) n ...', axial_dim = axial_dim)
+
+        out = self.fn(x, **kwargs)
+
+        (out, *rest_out), tree_spec = tree_flatten(out)
+
+        out = rearrange(out, '(b axial_dim) n ... -> b (n axial_dim) ...', axial_dim = axial_dim)
+
+        out = out[:, :seq_len]
+        out = tree_unflatten((out, *rest_out), tree_spec)
+
+        return out
+
 # post branch operator
 
 class LayerScale(Module):
@@ -1140,6 +1176,7 @@ class Attention(Module):
         custom_attn_fn: Callable | None = None,
         hybrid_module: Module | None = None,
         hybrid_mask_kwarg: str | None = None,
+        hybrid_fold_axial_dim: int | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1341,8 +1378,12 @@ class Attention(Module):
 
         # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
 
-        self.hybrid_module = maybe(deepcopy)(hybrid_module)
+        hybrid_module = maybe(deepcopy)(hybrid_module)
+
+        if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
+            hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
 
+        self.hybrid_module = hybrid_module
         self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
 
         # output dimension by default same as input, but can be overridden
@@ -2183,7 +2224,7 @@ class AttentionLayers(Module):
         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
 
         # derived input for reinjection if needed
-
+        inp_inject = None
         if self.reinject_input:
             assert not exists(in_attn_cond)
             inp_inject = self.reinject_input_proj(x)
@@ -2241,7 +2282,7 @@ class AttentionLayers(Module):
                 post_branch_norm = maybe(partial)(post_branch_norm, **norm_kwargs)
                 post_main_norm = maybe(partial)(post_main_norm, **norm_kwargs)
 
-            if self.reinject_input:
+            if exists(inp_inject):
                 x = x + inp_inject
 
             if exists(pre_norm):
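The new `FoldAxially` wrapper lets the hybrid branch (e.g. a GRU) run on a sequence folded along an axial dimension: pad the length up to a multiple of `axial_dim`, fold that factor into the batch, run the wrapped module on the shorter sequences, then unfold and trim back. Below is a minimal standalone sketch of that fold/unfold round trip; it is not the library's code, and the toy shapes, the stand-in GRU, and the plain `F.pad` call are illustrative assumptions.

```python
import math
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

axial_dim = 4
gru = nn.GRU(32, 32, batch_first = True)          # stand-in for the hybrid module

x = torch.randn(2, 10, 32)                        # (batch, seq, dim); 10 is not a multiple of 4
seq_len = x.shape[1]

# pad the sequence length up to the next multiple of axial_dim
next_multiple = math.ceil(seq_len / axial_dim) * axial_dim
x = F.pad(x, (0, 0, 0, next_multiple - seq_len))  # -> (2, 12, 32)

# fold the axial factor into the batch, mirroring FoldAxially's rearrange pattern
folded = rearrange(x, 'b (n a) d -> (b a) n d', a = axial_dim)   # -> (8, 3, 32)

out, _ = gru(folded)                              # wrapped module sees length-3 sequences

# unfold and trim back to the original length
out = rearrange(out, '(b a) n d -> b (n a) d', a = axial_dim)    # -> (2, 12, 32)
out = out[:, :seq_len]                            # -> (2, 10, 32)
```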
{x_transformers-1.44.2 → x_transformers-1.44.5/x_transformers.egg-info}/PKG-INFO

Same changes as the PKG-INFO diff above: Metadata-Version 2.1 → 2.2, Version 1.44.2 → 1.44.5, and the nine added Dynamic: fields.