x-transformers 1.44.0__py3-none-any.whl → 1.44.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,10 +7,11 @@ from random import random, randrange
7
7
  from packaging import version
8
8
 
9
9
  import torch
10
+ from torch.amp import autocast
10
11
  import torch.nn.functional as F
11
12
  from torch import nn, einsum, Tensor
13
+ from torch.utils._pytree import tree_flatten
12
14
  from torch.nn import Module, ModuleList, ModuleDict
13
- from torch.amp import autocast
14
15
 
15
16
  from functools import partial, wraps
16
17
  from collections import namedtuple
@@ -1138,6 +1139,7 @@ class Attention(Module):
1138
1139
  selective = False,
1139
1140
  custom_attn_fn: Callable | None = None,
1140
1141
  hybrid_module: Module | None = None,
1142
+ hybrid_mask_kwarg: str | None = None,
1141
1143
  one_kv_head = False,
1142
1144
  kv_heads = None,
1143
1145
  shared_kv = False,
@@ -1341,6 +1343,8 @@ class Attention(Module):
1341
1343
 
1342
1344
  self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
1343
1345
 
1346
+ self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
1347
+
1344
1348
  # output dimension by default same as input, but can be overridden
1345
1349
 
1346
1350
  dim_out = default(dim_out, dim)
@@ -1592,7 +1596,21 @@ class Attention(Module):
1592
1596
  # hybrid module
1593
1597
 
1594
1598
  if exists(self.hybrid_module):
1595
- hybrid_out, _ = self.hybrid_module(x)
1599
+
1600
+ # hybrid input
1601
+
1602
+ hybrid_forward_kwargs = dict()
1603
+
1604
+ if not self.causal and exists(self.hybrid_mask_kwarg):
1605
+ hybrid_forward_kwargs = {self.hybrid_mask_kwarg: mask}
1606
+
1607
+ # hybrid forward
1608
+
1609
+ hybrid_outputs = self.hybrid_module(x, **hybrid_forward_kwargs)
1610
+
1611
+ # handle hybrid out
1612
+
1613
+ (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
1596
1614
  out = 0.5 * (out + hybrid_out)
1597
1615
 
1598
1616
  # alphafold2 styled gating of the values
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.44.0
3
+ Version: 1.44.2
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
6
6
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
7
7
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
8
8
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
9
- x_transformers/x_transformers.py,sha256=BI3RU3XFvwSNDZgoQBrFBSJ4SavJr38rOCCVgHZBTx0,101241
9
+ x_transformers/x_transformers.py,sha256=W8qNdv-1CctdU4zwN2rYwu2CrgVMT1WL_3lwTSy5cCg,101862
10
10
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
11
11
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
12
- x_transformers-1.44.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
13
- x_transformers-1.44.0.dist-info/METADATA,sha256=MNVwW_pDeKEIHRVEA1XOUNfzFmL706X7Npoh7xc3wIk,738
14
- x_transformers-1.44.0.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
15
- x_transformers-1.44.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
16
- x_transformers-1.44.0.dist-info/RECORD,,
12
+ x_transformers-1.44.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
13
+ x_transformers-1.44.2.dist-info/METADATA,sha256=yNNX1JyY8T3KF0Q6vO525au_90esaJZhw3quktk8n7g,738
14
+ x_transformers-1.44.2.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
15
+ x_transformers-1.44.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
16
+ x_transformers-1.44.2.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.6.0)
2
+ Generator: setuptools (75.7.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5