x-transformers 1.44.2__py3-none-any.whl → 1.44.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/x_transformers.py

@@ -10,7 +10,7 @@ import torch
 from torch.amp import autocast
 import torch.nn.functional as F
 from torch import nn, einsum, Tensor
-from torch.utils._pytree import tree_flatten
+from torch.utils._pytree import tree_flatten, tree_unflatten
 from torch.nn import Module, ModuleList, ModuleDict
 
 from functools import partial, wraps
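The added `tree_unflatten` import pairs with the existing `tree_flatten` in the new `FoldAxially` wrapper below: the wrapped module may return a bare tensor or a tuple (for example an RNN returning `(output, hidden)`), and the pytree helpers let the wrapper modify only the first tensor while preserving the rest of the structure. A minimal sketch of that round trip; the `(sequence, hidden)` tuple here is an illustrative stand-in, not something from this diff:

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# stand-in output: e.g. (sequence, hidden) as returned by an RNN
out = (torch.randn(8, 256, 512), torch.randn(1, 8, 512))

# flatten into a list of tensor leaves plus a spec describing the structure
(first, *rest), tree_spec = tree_flatten(out)

first = first[:, :200]  # stand-in for the unfold / un-pad applied to the first tensor only

# rebuild the original tuple structure around the modified first leaf
restored = tree_unflatten([first, *rest], tree_spec)

assert isinstance(restored, tuple) and restored[0].shape == (8, 200, 512)
```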
@@ -966,6 +966,42 @@ class ShiftTokens(Module):
         x = torch.cat((*segments_to_shift, *rest), dim = -1)
         return self.fn(x, **kwargs)
 
+class FoldAxially(Module):
+    def __init__(
+        self,
+        axial_dim,
+        fn: Module
+    ):
+        super().__init__()
+        self.fn = fn
+        self.axial_dim = axial_dim # will fold the sequence as rearrange("b (n axial_dim) ... -> (b axial_dim) n ...")
+
+    def forward(
+        self,
+        x,
+        **kwargs
+    ):
+        if self.axial_dim == 1:
+            return self.fn(x, **kwargs)
+
+        seq_len, axial_dim = x.shape[1], self.axial_dim
+
+        next_multiple = math.ceil(seq_len / axial_dim) * axial_dim
+        x = pad_at_dim(x, (0, next_multiple - seq_len), dim = 1)
+
+        x = rearrange(x, 'b (n axial_dim) ... -> (b axial_dim) n ...', axial_dim = axial_dim)
+
+        out = self.fn(x, **kwargs)
+
+        (out, *rest_out), tree_spec = tree_flatten(out)
+
+        out = rearrange(out, '(b axial_dim) n ... -> b (n axial_dim) ...', axial_dim = axial_dim)
+
+        out = out[:, :seq_len]
+        out = tree_unflatten((out, *rest_out), tree_spec)
+
+        return out
+
 # post branch operator
 
 class LayerScale(Module):
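`FoldAxially` pads the sequence up to a multiple of `axial_dim`, folds that factor into the batch (each folded sequence sees every `axial_dim`-th token), runs the wrapped module on the shorter sequences, then unfolds and trims back to the original length. A minimal usage sketch; `GRUBlock` is a hypothetical stand-in for any sequence-mixing module:

```python
import torch
from torch import nn
from x_transformers.x_transformers import FoldAxially

# hypothetical wrapped module: a GRU that returns only its sequence output
class GRUBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gru = nn.GRU(dim, dim, batch_first = True)

    def forward(self, x):
        out, _ = self.gru(x)
        return out

block = FoldAxially(axial_dim = 4, fn = GRUBlock(dim = 64))

x = torch.randn(2, 1021, 64)  # length not divisible by axial_dim
out = block(x)                # padded to 1024, run as (2 * 4, 256, 64), then unfolded and trimmed

assert out.shape == (2, 1021, 64)
```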
@@ -1140,6 +1176,7 @@ class Attention(Module):
         custom_attn_fn: Callable | None = None,
         hybrid_module: Module | None = None,
         hybrid_mask_kwarg: str | None = None,
+        hybrid_fold_axial_dim: int | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1341,8 +1378,12 @@ class Attention(Module):
 
         # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
 
-        self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
+        hybrid_module = maybe(deepcopy)(hybrid_module)
+
+        if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
+            hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
 
+        self.hybrid_module = hybrid_module
         self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
 
         # output dimension by default same as input, but can be overridden
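With the new `hybrid_fold_axial_dim` argument, the deep-copied hybrid branch is wrapped in `FoldAxially`, so a recurrent module only processes sequences folded by that factor. A sketch of how this might be configured through the library's usual `attn_`-prefixed kwargs; the GRU sizing (input `dim`, hidden `dim_head * heads` so its output matches the attention branch) is an assumption for illustration, not something fixed by this diff:

```python
import torch
from torch import nn
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 4,
        heads = 8,
        attn_dim_head = 64,
        # recurrent branch run alongside attention in every layer, in the vein of hymba
        attn_hybrid_module = nn.GRU(128, 64 * 8, batch_first = True),
        # fold the sequence by 4 before the GRU sees it (wrapped in FoldAxially internally)
        attn_hybrid_fold_axial_dim = 4
    )
)

tokens = torch.randint(0, 20000, (1, 1021))
logits = model(tokens)  # (1, 1021, 20000)
```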
@@ -2183,7 +2224,7 @@ class AttentionLayers(Module):
         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
 
         # derived input for reinjection if needed
-
+        inp_inject = None
         if self.reinject_input:
             assert not exists(in_attn_cond)
             inp_inject = self.reinject_input_proj(x)
@@ -2241,7 +2282,7 @@ class AttentionLayers(Module):
            post_branch_norm = maybe(partial)(post_branch_norm, **norm_kwargs)
            post_main_norm = maybe(partial)(post_main_norm, **norm_kwargs)
 
-            if self.reinject_input:
+            if exists(inp_inject):
                x = x + inp_inject
 
            if exists(pre_norm):
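The reinjection guard now checks whether `inp_inject` was actually derived (`exists(inp_inject)`) instead of rechecking the `reinject_input` flag, with `inp_inject` initialized to `None` up front so the check is always defined. A minimal sketch of the feature this guard serves, assuming `reinject_input` is exposed as an `AttentionLayers`/`Decoder` kwarg matching the `self.reinject_input` attribute seen above:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 4,
        heads = 8,
        reinject_input = True  # re-add a projection of the embedded input at every block
    )
)

tokens = torch.randint(0, 256, (1, 512))
logits = model(tokens)  # x = x + inp_inject now runs only when inp_inject was actually derived
```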
x_transformers-1.44.2.dist-info/METADATA → x_transformers-1.44.5.dist-info/METADATA

@@ -1,6 +1,6 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.2
 Name: x-transformers
-Version: 1.44.2
+Version: 1.44.5
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -19,3 +19,12 @@ Requires-Dist: einops>=0.8.0
 Requires-Dist: loguru
 Requires-Dist: packaging>=21.0
 Requires-Dist: torch>=2.0
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: keywords
+Dynamic: license
+Dynamic: requires-dist
+Dynamic: summary
x_transformers-1.44.2.dist-info/RECORD → x_transformers-1.44.5.dist-info/RECORD

@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=W8qNdv-1CctdU4zwN2rYwu2CrgVMT1WL_3lwTSy5cCg,101862
+x_transformers/x_transformers.py,sha256=-B4xutlod09EwaA1tE24GgD-5ioi4HAXIwmI5N4MYio,103124
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.44.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.44.2.dist-info/METADATA,sha256=yNNX1JyY8T3KF0Q6vO525au_90esaJZhw3quktk8n7g,738
-x_transformers-1.44.2.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
-x_transformers-1.44.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.44.2.dist-info/RECORD,,
+x_transformers-1.44.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.44.5.dist-info/METADATA,sha256=j_8nbdft6Qx1NtIjSd1NyERsrpu-W9L3-CuR0hug_S0,924
+x_transformers-1.44.5.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+x_transformers-1.44.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.44.5.dist-info/RECORD,,
x_transformers-1.44.2.dist-info/WHEEL → x_transformers-1.44.5.dist-info/WHEEL

@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (75.7.0)
+Generator: setuptools (75.8.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 