x_transformers-1.44.0-py3-none-any.whl → x_transformers-1.44.2-py3-none-any.whl

--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -7,10 +7,11 @@ from random import random, randrange
 from packaging import version
 
 import torch
+from torch.amp import autocast
 import torch.nn.functional as F
 from torch import nn, einsum, Tensor
+from torch.utils._pytree import tree_flatten
 from torch.nn import Module, ModuleList, ModuleDict
-from torch.amp import autocast
 
 from functools import partial, wraps
 from collections import namedtuple
@@ -1138,6 +1139,7 @@ class Attention(Module):
         selective = False,
         custom_attn_fn: Callable | None = None,
         hybrid_module: Module | None = None,
+        hybrid_mask_kwarg: str | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1341,6 +1343,8 @@ class Attention(Module):
 
         self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
 
+        self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
+
         # output dimension by default same as input, but can be overridden
 
         dim_out = default(dim_out, dim)
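
The two constructor hunks above add `hybrid_mask_kwarg`: the name of the keyword argument under which the attention layer's boolean `mask` should be forwarded to the hybrid module (the forward-pass hunk below performs the actual forwarding, and only for non-causal attention). A minimal usage sketch, assuming a hypothetical bidirectional GRU branch that accepts a `mask` keyword; the `Attention` arguments shown in this diff are real, while `MaskedBiGRU` and its internals are illustrative:

```python
import torch
from torch import nn
from x_transformers.x_transformers import Attention

class MaskedBiGRU(nn.Module):
    # hypothetical hybrid branch: a bidirectional GRU that can use a boolean mask
    def __init__(self, dim):
        super().__init__()
        self.gru = nn.GRU(dim, dim // 2, batch_first = True, bidirectional = True)

    def forward(self, x, mask = None):
        if mask is not None:
            x = x * mask[..., None]   # zero out padded positions
        out, hidden = self.gru(x)     # bidirectional halves concatenate back to `dim`
        return out, hidden

attn = Attention(
    dim = 512,
    causal = False,                   # the mask is only forwarded for non-causal attention
    hybrid_module = MaskedBiGRU(512),
    hybrid_mask_kwarg = 'mask'        # forward attention's `mask` as `mask = ...` to the GRU
)

x = torch.randn(2, 1024, 512)
mask = torch.ones(2, 1024).bool()
out, _ = attn(x, mask = mask)         # hybrid branch now sees the same variable-length mask
```
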
@@ -1592,7 +1596,21 @@ class Attention(Module):
         # hybrid module
 
         if exists(self.hybrid_module):
-            hybrid_out, _ = self.hybrid_module(x)
+
+            # hybrid input
+
+            hybrid_forward_kwargs = dict()
+
+            if not self.causal and exists(self.hybrid_mask_kwarg):
+                hybrid_forward_kwargs = {self.hybrid_mask_kwarg: mask}
+
+            # hybrid forward
+
+            hybrid_outputs = self.hybrid_module(x, **hybrid_forward_kwargs)
+
+            # handle hybrid out
+
+            (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
             out = 0.5 * (out + hybrid_out)
 
         # alphafold2 styled gating of the values
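
Note the other half of the forward-pass change: the hard-coded `hybrid_out, _ = self.hybrid_module(x)` unpacking is replaced by `tree_flatten` (imported in the first hunk), so the hybrid module may return a bare tensor, an `(output, hidden)` tuple, or any nested structure; the first flattened leaf becomes `hybrid_out`. A small sketch of that normalization, with made-up shapes and return structures:

```python
import torch
from torch.utils._pytree import tree_flatten

t = torch.randn(2, 8, 512)    # stand-in for the hybrid branch's main output
h = torch.randn(1, 2, 512)    # stand-in for auxiliary state, e.g. an RNN hidden

# three return structures a hybrid module might use
for outputs in (
    t,                        # bare tensor
    (t, h),                   # (output, hidden), as nn.GRU returns
    (t, (h, [h, h])),         # arbitrarily nested containers
):
    (hybrid_out, *rest), _treespec = tree_flatten(outputs)
    assert hybrid_out is t    # the first flattened leaf is always the main output
```
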
--- a/x_transformers-1.44.0.dist-info/METADATA
+++ b/x_transformers-1.44.2.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.44.0
+Version: 1.44.2
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
--- a/x_transformers-1.44.0.dist-info/RECORD
+++ b/x_transformers-1.44.2.dist-info/RECORD
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=BI3RU3XFvwSNDZgoQBrFBSJ4SavJr38rOCCVgHZBTx0,101241
+x_transformers/x_transformers.py,sha256=W8qNdv-1CctdU4zwN2rYwu2CrgVMT1WL_3lwTSy5cCg,101862
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.44.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.44.0.dist-info/METADATA,sha256=MNVwW_pDeKEIHRVEA1XOUNfzFmL706X7Npoh7xc3wIk,738
-x_transformers-1.44.0.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
-x_transformers-1.44.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.44.0.dist-info/RECORD,,
+x_transformers-1.44.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.44.2.dist-info/METADATA,sha256=yNNX1JyY8T3KF0Q6vO525au_90esaJZhw3quktk8n7g,738
+x_transformers-1.44.2.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
+x_transformers-1.44.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.44.2.dist-info/RECORD,,
--- a/x_transformers-1.44.0.dist-info/WHEEL
+++ b/x_transformers-1.44.2.dist-info/WHEEL
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (75.6.0)
+Generator: setuptools (75.7.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 