x-transformers: x_transformers-1.44.1-py3-none-any.whl → x_transformers-1.44.4-py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -10,7 +10,7 @@ import torch
10
10
  from torch.amp import autocast
11
11
  import torch.nn.functional as F
12
12
  from torch import nn, einsum, Tensor
13
- from torch.utils._pytree import tree_flatten
13
+ from torch.utils._pytree import tree_flatten, tree_unflatten
14
14
  from torch.nn import Module, ModuleList, ModuleDict
15
15
 
16
16
  from functools import partial, wraps
@@ -966,6 +966,42 @@ class ShiftTokens(Module):
966
966
  x = torch.cat((*segments_to_shift, *rest), dim = -1)
967
967
  return self.fn(x, **kwargs)
968
968
 
969
+ class FoldAxially(Module):
970
+ def __init__(
971
+ self,
972
+ axial_dim,
973
+ fn: Module
974
+ ):
975
+ super().__init__()
976
+ self.fn = fn
977
+ self.axial_dim = axial_dim # will fold the sequence as rearrange("b (n axial_dim) ... -> (b axial_dim) n ...")
978
+
979
+ def forward(
980
+ self,
981
+ x,
982
+ **kwargs
983
+ ):
984
+ if self.axial_dim == 1:
985
+ return self.fn(x, **kwargs)
986
+
987
+ seq_len, axial_dim = x.shape[1], self.axial_dim
988
+
989
+ next_multiple = math.ceil(seq_len / axial_dim) * axial_dim
990
+ x = pad_at_dim(x, (0, next_multiple - seq_len), dim = 1)
991
+
992
+ x = rearrange(x, 'b (n axial_dim) ... -> (b axial_dim) n ...', axial_dim = axial_dim)
993
+
994
+ out = self.fn(x, **kwargs)
995
+
996
+ (out, *rest_out), tree_spec = tree_flatten(out)
997
+
998
+ out = rearrange(out, '(b axial_dim) n ... -> b (n axial_dim) ...', axial_dim = axial_dim)
999
+
1000
+ out = out[:, :seq_len]
1001
+ out = tree_unflatten((out, *rest_out), tree_spec)
1002
+
1003
+ return out
1004
+
969
1005
  # post branch operator
970
1006
 
971
1007
  class LayerScale(Module):
@@ -1140,6 +1176,7 @@ class Attention(Module):
1140
1176
  custom_attn_fn: Callable | None = None,
1141
1177
  hybrid_module: Module | None = None,
1142
1178
  hybrid_mask_kwarg: str | None = None,
1179
+ hybrid_fold_axial_dim: int | None = None,
1143
1180
  one_kv_head = False,
1144
1181
  kv_heads = None,
1145
1182
  shared_kv = False,
@@ -1341,8 +1378,12 @@ class Attention(Module):
1341
1378
 
1342
1379
  # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
1343
1380
 
1344
- self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
1381
+ hybrid_module = maybe(deepcopy)(hybrid_module)
1382
+
1383
+ if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
1384
+ hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
1345
1385
 
1386
+ self.hybrid_module = hybrid_module
1346
1387
  self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
1347
1388
 
1348
1389
  # output dimension by default same as input, but can be overridden
@@ -1610,7 +1651,7 @@ class Attention(Module):
1610
1651
 
1611
1652
  # handle hybrid out
1612
1653
 
1613
- (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outs)
1654
+ (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
1614
1655
  out = 0.5 * (out + hybrid_out)
1615
1656
 
1616
1657
  # alphafold2 styled gating of the values
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.44.1
3
+ Version: 1.44.4
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
6
6
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
7
7
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
8
8
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
9
- x_transformers/x_transformers.py,sha256=yjtB4kV4N9mzHdliIM9MjyA6SoMtvpzc2Z4iU6R9_Uc,101859
9
+ x_transformers/x_transformers.py,sha256=lc9mmhV-O9MesX7Di7P93KjMioRzM5zzZ0U9sVoDLqU,103100
10
10
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
11
11
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
12
- x_transformers-1.44.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
13
- x_transformers-1.44.1.dist-info/METADATA,sha256=Zw_Rscb4vNZxlKosWSHSQy4EsICF45U58K0hipxydpQ,738
14
- x_transformers-1.44.1.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
15
- x_transformers-1.44.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
16
- x_transformers-1.44.1.dist-info/RECORD,,
12
+ x_transformers-1.44.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
13
+ x_transformers-1.44.4.dist-info/METADATA,sha256=09PGX7zKwq8DjFeEX3FmF3YmSpN1XU4fiyDisewXlDg,738
14
+ x_transformers-1.44.4.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
15
+ x_transformers-1.44.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
16
+ x_transformers-1.44.4.dist-info/RECORD,,