x-transformers 1.44.1__py3-none-any.whl → 1.44.4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -10,7 +10,7 @@ import torch
 from torch.amp import autocast
 import torch.nn.functional as F
 from torch import nn, einsum, Tensor
-from torch.utils._pytree import tree_flatten
+from torch.utils._pytree import tree_flatten, tree_unflatten
 from torch.nn import Module, ModuleList, ModuleDict
 
 from functools import partial, wraps
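
For context on the new import: tree_flatten breaks a nested output (tuples, dicts, tensors) into a flat list of leaves plus a spec, and tree_unflatten rebuilds the original structure from those leaves. A minimal sketch, independent of x-transformers:

    import torch
    from torch.utils._pytree import tree_flatten, tree_unflatten

    nested = (torch.zeros(2, 4), {'hidden': torch.ones(3)})
    leaves, spec = tree_flatten(nested)        # [tensor(2, 4), tensor(3)], TreeSpec
    leaves[0] = leaves[0] + 1                  # edit only the primary leaf
    rebuilt = tree_unflatten(leaves, spec)     # same (tuple, dict) structure as `nested`

This round-trip is exactly what the new FoldAxially module below relies on: unfold the first leaf back to the original sequence length while leaving the rest of the returned structure untouched.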
@@ -966,6 +966,42 @@ class ShiftTokens(Module):
         x = torch.cat((*segments_to_shift, *rest), dim = -1)
         return self.fn(x, **kwargs)
 
+class FoldAxially(Module):
+    def __init__(
+        self,
+        axial_dim,
+        fn: Module
+    ):
+        super().__init__()
+        self.fn = fn
+        self.axial_dim = axial_dim # will fold the sequence as rearrange("b (n axial_dim) ... -> (b axial_dim) n ...")
+
+    def forward(
+        self,
+        x,
+        **kwargs
+    ):
+        if self.axial_dim == 1:
+            return self.fn(x, **kwargs)
+
+        seq_len, axial_dim = x.shape[1], self.axial_dim
+
+        next_multiple = math.ceil(seq_len / axial_dim) * axial_dim
+        x = pad_at_dim(x, (0, next_multiple - seq_len), dim = 1)
+
+        x = rearrange(x, 'b (n axial_dim) ... -> (b axial_dim) n ...', axial_dim = axial_dim)
+
+        out = self.fn(x, **kwargs)
+
+        (out, *rest_out), tree_spec = tree_flatten(out)
+
+        out = rearrange(out, '(b axial_dim) n ... -> b (n axial_dim) ...', axial_dim = axial_dim)
+
+        out = out[:, :seq_len]
+        out = tree_unflatten((out, *rest_out), tree_spec)
+
+        return out
+
 # post branch operator
 
 class LayerScale(Module):
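
To make the folding concrete, here is a hedged usage sketch. It assumes FoldAxially is importable from x_transformers.x_transformers in 1.44.4, and uses a toy GRU-based mixer that is not part of the library: a batch of 2 sequences of length 10 is padded to 12, reshaped to (2 * 4, 3, dim) so the wrapped module only ever sees the shorter folded sequences, then unfolded and trimmed back to length 10.

    import torch
    from torch import nn
    from x_transformers.x_transformers import FoldAxially  # assumption: module-level class in 1.44.4

    class GRUMixer(nn.Module):
        # toy sequence mixer returning a single tensor, so the pytree round-trip is trivial
        def __init__(self, dim):
            super().__init__()
            self.gru = nn.GRU(dim, dim, batch_first = True)

        def forward(self, x):
            out, _ = self.gru(x)
            return out

    dim = 32
    folded = FoldAxially(axial_dim = 4, fn = GRUMixer(dim))

    x = torch.randn(2, 10, dim)   # length 10 is padded to 12 (next multiple of 4) inside FoldAxially
    out = folded(x)               # the GRU actually runs on shape (8, 3, 32)
    print(out.shape)              # torch.Size([2, 10, 32]) - padding trimmed before returning

Note the rearrange pattern 'b (n axial_dim) ... -> (b axial_dim) n ...': every axial_dim-th token lands in the same folded row, so the wrapped module processes strided subsequences of length ceil(seq_len / axial_dim) rather than the full sequence.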
@@ -1140,6 +1176,7 @@ class Attention(Module):
         custom_attn_fn: Callable | None = None,
         hybrid_module: Module | None = None,
         hybrid_mask_kwarg: str | None = None,
+        hybrid_fold_axial_dim: int | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
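
A hedged sketch of how the new kwarg might be wired from the public API, via the usual attn_ prefix that the attention layers forward into Attention. The GRU sizing follows the hymba-style hybrid convention of matching heads * dim_head; treat the exact recipe as an assumption rather than the documented one.

    import torch
    from torch.nn import GRU
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 2,
            heads = 8,
            attn_dim_head = 64,
            # hybrid recurrent branch averaged with the attention output (hymba-style)
            attn_hybrid_module = GRU(512, 64 * 8, batch_first = True),
            # new in this release: fold the sequence by 2 before it reaches the GRU
            attn_hybrid_fold_axial_dim = 2
        )
    )

    ids = torch.randint(0, 20000, (1, 1024))
    logits = model(ids)   # (1, 1024, 20000)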
@@ -1341,8 +1378,12 @@ class Attention(Module):
 
         # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
 
-        self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
+        hybrid_module = maybe(deepcopy)(hybrid_module)
+
+        if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
+            hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
 
+        self.hybrid_module = hybrid_module
         self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
 
         # output dimension by default same as input, but can be overridden
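
On the maybe(deepcopy)(hybrid_module) line: maybe is an x-transformers helper that applies a function only when its argument exists, so a None hybrid_module passes straight through. A minimal sketch of the semantics (not the library's exact source):

    from copy import deepcopy

    def exists(v):
        return v is not None

    def maybe(fn):
        # apply `fn` only when the first argument is not None
        def inner(x, *args, **kwargs):
            if not exists(x):
                return x
            return fn(x, *args, **kwargs)
        return inner

    maybe(deepcopy)(None)       # -> None
    # maybe(deepcopy)(module)   # -> an independent copy of the module

The deepcopy, already present in 1.44.1, means each Attention block gets its own copy of the hybrid module rather than sharing one instance (and its weights) across layers; the new branch then optionally wraps that copy in FoldAxially.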
@@ -1610,7 +1651,7 @@ class Attention(Module):
 
             # handle hybrid out
 
-            (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outs)
+            (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
             out = 0.5 * (out + hybrid_out)
 
         # alphafold2 styled gating of the values
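
The change above is a rename of an undefined name (hybrid_outs) to the variable actually produced by the hybrid forward pass (hybrid_outputs). The tree_flatten call then pulls the primary sequence tensor as the first leaf whether the hybrid module returns a bare tensor or a tuple such as a GRU's (output, h_n); a small standalone illustration:

    import torch
    from torch.utils._pytree import tree_flatten

    # GRU-style output: (sequence output, final hidden state)
    hybrid_outputs = (torch.randn(2, 8, 512), torch.randn(1, 2, 512))

    (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
    print(hybrid_out.shape)   # torch.Size([2, 8, 512]) - the tensor averaged as 0.5 * (out + hybrid_out)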
--- a/x_transformers-1.44.1.dist-info/METADATA
+++ b/x_transformers-1.44.4.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.44.1
+Version: 1.44.4
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
--- a/x_transformers-1.44.1.dist-info/RECORD
+++ b/x_transformers-1.44.4.dist-info/RECORD
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=yjtB4kV4N9mzHdliIM9MjyA6SoMtvpzc2Z4iU6R9_Uc,101859
+x_transformers/x_transformers.py,sha256=lc9mmhV-O9MesX7Di7P93KjMioRzM5zzZ0U9sVoDLqU,103100
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.44.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.44.1.dist-info/METADATA,sha256=Zw_Rscb4vNZxlKosWSHSQy4EsICF45U58K0hipxydpQ,738
-x_transformers-1.44.1.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
-x_transformers-1.44.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.44.1.dist-info/RECORD,,
+x_transformers-1.44.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.44.4.dist-info/METADATA,sha256=09PGX7zKwq8DjFeEX3FmF3YmSpN1XU4fiyDisewXlDg,738
+x_transformers-1.44.4.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
+x_transformers-1.44.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.44.4.dist-info/RECORD,,