x-transformers 1.44.2.tar.gz → 1.44.5.tar.gz

Files changed (22)
  1. {x_transformers-1.44.2/x_transformers.egg-info → x_transformers-1.44.5}/PKG-INFO +11 -2
  2. {x_transformers-1.44.2 → x_transformers-1.44.5}/setup.py +1 -1
  3. {x_transformers-1.44.2 → x_transformers-1.44.5}/tests/test_x_transformers.py +4 -1
  4. {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/x_transformers.py +45 -4
  5. {x_transformers-1.44.2 → x_transformers-1.44.5/x_transformers.egg-info}/PKG-INFO +11 -2
  6. {x_transformers-1.44.2 → x_transformers-1.44.5}/LICENSE +0 -0
  7. {x_transformers-1.44.2 → x_transformers-1.44.5}/README.md +0 -0
  8. {x_transformers-1.44.2 → x_transformers-1.44.5}/setup.cfg +0 -0
  9. {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/neo_mlp.py +0 -0
  16. {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/nonautoregressive_wrapper.py +0 -0
  17. {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  18. {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers/xval.py +0 -0
  19. {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers.egg-info/SOURCES.txt +0 -0
  20. {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers.egg-info/dependency_links.txt +0 -0
  21. {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers.egg-info/requires.txt +0 -0
  22. {x_transformers-1.44.2 → x_transformers-1.44.5}/x_transformers.egg-info/top_level.txt +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.2
  Name: x-transformers
- Version: 1.44.2
+ Version: 1.44.5
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
@@ -19,3 +19,12 @@ Requires-Dist: einops>=0.8.0
  Requires-Dist: loguru
  Requires-Dist: packaging>=21.0
  Requires-Dist: torch>=2.0
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: classifier
+ Dynamic: description-content-type
+ Dynamic: home-page
+ Dynamic: keywords
+ Dynamic: license
+ Dynamic: requires-dist
+ Dynamic: summary
setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
  setup(
    name = 'x-transformers',
    packages = find_packages(exclude=['examples']),
-   version = '1.44.2',
+   version = '1.44.5',
    license='MIT',
    description = 'X-Transformers - Pytorch',
    author = 'Phil Wang',
tests/test_x_transformers.py

@@ -614,7 +614,8 @@ def test_hyper_connections(tanh):

      model(x)

- def test_hybrid():
+ @pytest.mark.parametrize('hybrid_axial_dim', (1, 4))
+ def test_hybrid(hybrid_axial_dim):
      from torch.nn import GRU

      dec = TransformerWrapper(
@@ -625,6 +626,7 @@ def test_hybrid():
          depth = 6,
          heads = 8,
          attn_dim_head = 64,
+         attn_hybrid_fold_axial_dim = hybrid_axial_dim,
          attn_hybrid_module = GRU(128, 64 * 8, batch_first = True)
      )
  )
@@ -641,6 +643,7 @@ def test_hybrid():
          depth = 6,
          heads = 8,
          attn_dim_head = 64,
+         attn_hybrid_fold_axial_dim = hybrid_axial_dim,
          attn_hybrid_module = GRU(128, 64 * 4, batch_first = True, bidirectional = True)
      )
  )
x_transformers/x_transformers.py

@@ -10,7 +10,7 @@ import torch
  from torch.amp import autocast
  import torch.nn.functional as F
  from torch import nn, einsum, Tensor
- from torch.utils._pytree import tree_flatten
+ from torch.utils._pytree import tree_flatten, tree_unflatten
  from torch.nn import Module, ModuleList, ModuleDict

  from functools import partial, wraps
@@ -966,6 +966,42 @@ class ShiftTokens(Module):
          x = torch.cat((*segments_to_shift, *rest), dim = -1)
          return self.fn(x, **kwargs)

+ class FoldAxially(Module):
+     def __init__(
+         self,
+         axial_dim,
+         fn: Module
+     ):
+         super().__init__()
+         self.fn = fn
+         self.axial_dim = axial_dim # will fold the sequence as rearrange("b (n axial_dim) ... -> (b axial_dim) n ...")
+
+     def forward(
+         self,
+         x,
+         **kwargs
+     ):
+         if self.axial_dim == 1:
+             return self.fn(x, **kwargs)
+
+         seq_len, axial_dim = x.shape[1], self.axial_dim
+
+         next_multiple = math.ceil(seq_len / axial_dim) * axial_dim
+         x = pad_at_dim(x, (0, next_multiple - seq_len), dim = 1)
+
+         x = rearrange(x, 'b (n axial_dim) ... -> (b axial_dim) n ...', axial_dim = axial_dim)
+
+         out = self.fn(x, **kwargs)
+
+         (out, *rest_out), tree_spec = tree_flatten(out)
+
+         out = rearrange(out, '(b axial_dim) n ... -> b (n axial_dim) ...', axial_dim = axial_dim)
+
+         out = out[:, :seq_len]
+         out = tree_unflatten((out, *rest_out), tree_spec)
+
+         return out
+
  # post branch operator

  class LayerScale(Module):
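
To make the shape bookkeeping in `FoldAxially` concrete, here is a self-contained sketch of the same fold/unfold, simplified to 3-D `(batch, seq, dim)` tensors. `F.pad` stands in for the library's `pad_at_dim` helper, and the function names are invented for illustration.

```python
# Hedged sketch of FoldAxially's reshaping, assuming (batch, seq, dim) inputs.
import math
import torch
import torch.nn.functional as F
from einops import rearrange

def fold_axially(x, axial_dim):
    seq_len = x.shape[1]
    next_multiple = math.ceil(seq_len / axial_dim) * axial_dim
    x = F.pad(x, (0, 0, 0, next_multiple - seq_len))             # right-pad the sequence dimension
    return rearrange(x, 'b (n a) d -> (b a) n d', a = axial_dim)

def unfold_axially(out, axial_dim, seq_len):
    out = rearrange(out, '(b a) n d -> b (n a) d', a = axial_dim)
    return out[:, :seq_len]                                      # drop the padding again

x = torch.randn(2, 10, 64)
folded = fold_axially(x, axial_dim = 4)             # (8, 3, 64): 4 interleaved streams per batch element
restored = unfold_axially(folded, 4, seq_len = 10)  # back to (2, 10, 64)
assert torch.equal(x, restored)                     # the fold/unfold round trip is lossless
```

Because the fold groups the sequence as `(n axial_dim)`, the wrapped recurrent module sees every `axial_dim`-th token, cutting its effective sequence length by that factor.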
@@ -1140,6 +1176,7 @@ class Attention(Module):
          custom_attn_fn: Callable | None = None,
          hybrid_module: Module | None = None,
          hybrid_mask_kwarg: str | None = None,
+         hybrid_fold_axial_dim: int | None = None,
          one_kv_head = False,
          kv_heads = None,
          shared_kv = False,
@@ -1341,8 +1378,12 @@ class Attention(Module):

          # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676

-         self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
+         hybrid_module = maybe(deepcopy)(hybrid_module)
+
+         if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
+             hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)

+         self.hybrid_module = hybrid_module
          self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths

          # output dimension by default same as input, but can be overridden
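
In the hunk above, `maybe(deepcopy)` replaces the explicit ternary: the helper applies a function only when its argument exists, so a missing `hybrid_module` stays `None` rather than reaching `deepcopy`. A minimal re-implementation for illustration; the real `maybe` and `exists` helpers live in `x_transformers/x_transformers.py`.

```python
# Illustrative re-implementation of the library's small helpers, not the canonical source.
from copy import deepcopy

def exists(v):
    return v is not None

def maybe(fn):
    def inner(x, *args, **kwargs):
        if not exists(x):
            return x
        return fn(x, *args, **kwargs)
    return inner

assert maybe(deepcopy)(None) is None        # no hybrid module: passes None through
assert maybe(deepcopy)([1, 2]) == [1, 2]    # otherwise a defensive copy is made
```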
@@ -2183,7 +2224,7 @@ class AttentionLayers(Module):
          layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)

          # derived input for reinjection if needed
-
+         inp_inject = None
          if self.reinject_input:
              assert not exists(in_attn_cond)
              inp_inject = self.reinject_input_proj(x)
@@ -2241,7 +2282,7 @@ class AttentionLayers(Module):
              post_branch_norm = maybe(partial)(post_branch_norm, **norm_kwargs)
              post_main_norm = maybe(partial)(post_main_norm, **norm_kwargs)

-             if self.reinject_input:
+             if exists(inp_inject):
                  x = x + inp_inject

              if exists(pre_norm):
PKG-INFO

@@ -1,6 +1,6 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.2
  Name: x-transformers
- Version: 1.44.2
+ Version: 1.44.5
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
@@ -19,3 +19,12 @@ Requires-Dist: einops>=0.8.0
  Requires-Dist: loguru
  Requires-Dist: packaging>=21.0
  Requires-Dist: torch>=2.0
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: classifier
+ Dynamic: description-content-type
+ Dynamic: home-page
+ Dynamic: keywords
+ Dynamic: license
+ Dynamic: requires-dist
+ Dynamic: summary