x-transformers 2.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl
- x_transformers/x_transformers.py +36 -4
- x_transformers-2.0.1.dist-info/METADATA +2419 -0
- {x_transformers-2.0.0.dist-info → x_transformers-2.0.1.dist-info}/RECORD +5 -6
- {x_transformers-2.0.0.dist-info → x_transformers-2.0.1.dist-info}/WHEEL +1 -2
- x_transformers-2.0.0.dist-info/METADATA +0 -30
- x_transformers-2.0.0.dist-info/top_level.txt +0 -1
- {x_transformers-2.0.0.dist-info → x_transformers-2.0.1.dist-info/licenses}/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -841,6 +841,15 @@ class SimpleRMSNorm(Module):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale
 
+class MultiheadRMSNorm(Module):
+    def __init__(self, dim, heads):
+        super().__init__()
+        self.rmsnorm = SimpleRMSNorm(dim)
+        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
+
+    def forward(self, x):
+        return self.rmsnorm(x) * (self.gamma + 1.)
+
 # residual and residual gates
 
 class Residual(Module):
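
The new MultiheadRMSNorm wraps the existing SimpleRMSNorm and adds a learned per-head gain whose zero initialization makes it start out as a plain RMS norm. Below is a minimal, self-contained sketch of the idea; the `dim ** 0.5` scale inside SimpleRMSNorm and the (batch, heads, seq, dim_head) input shape are assumptions, not taken from the hunk above.

import torch
import torch.nn.functional as F
from torch import nn

# RMS norm without a learned scale: L2-normalize the last dim, then rescale by sqrt(dim)
# (the sqrt(dim) scale is assumed; only the forward appears in the diff above)
class SimpleRMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale

# the added module: per-head RMS norm with a learned gain per head,
# gamma is zero-initialized so (gamma + 1.) starts at exactly 1
class MultiheadRMSNorm(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.rmsnorm = SimpleRMSNorm(dim)
        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))

    def forward(self, x):
        # x assumed to be (batch, heads, seq, dim); gamma (heads, 1, dim) broadcasts over batch and seq
        return self.rmsnorm(x) * (self.gamma + 1.)

norm = MultiheadRMSNorm(dim = 64, heads = 8)
print(norm(torch.randn(2, 8, 16, 64)).shape)   # torch.Size([2, 8, 16, 64])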
@@ -1431,12 +1440,22 @@ class Attention(Module):
 
         # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
 
+        hybrid_mix = None
+        hybrid_norms = None
         hybrid_module = maybe(deepcopy)(hybrid_module)
 
         if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
             hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
+            hybrid_mix = LinearNoBias(dim, heads)
+
+            hybrid_norms = ModuleList([
+                MultiheadRMSNorm(dim_head, heads = heads),
+                MultiheadRMSNorm(dim_head, heads = heads)
+            ])
 
         self.hybrid_module = hybrid_module
+        self.hybrid_norms = hybrid_norms
+        self.hybrid_mix = hybrid_mix
         self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
 
         # output dimension by default same as input, but can be overridden
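
In the constructor, the hybrid branch now gets a per-head mixing projection and a pair of MultiheadRMSNorms (one for the attention output, one for the hybrid output), both defaulting to None so nothing is registered when no hybrid module is configured. The sketch below illustrates that wiring outside the library, reusing the MultiheadRMSNorm sketch above; `ToyAttention` is a hypothetical stand-in for the relevant slice of Attention.__init__, `LinearNoBias` is assumed to be a bias-free nn.Linear, and the guard is simplified to the hybrid module merely existing.

from torch import nn

class ToyAttention(nn.Module):
    def __init__(self, dim, dim_head, heads, hybrid_module = None):
        super().__init__()

        # optional components default to None so the module tree is unchanged without a hybrid branch
        hybrid_mix = None
        hybrid_norms = None

        if hybrid_module is not None:
            # assumed equivalent of LinearNoBias(dim, heads): one mixing logit per head per token,
            # i.e. (batch, seq, dim) -> (batch, seq, heads)
            hybrid_mix = nn.Linear(dim, heads, bias = False)

            hybrid_norms = nn.ModuleList([
                MultiheadRMSNorm(dim_head, heads = heads),   # norms the attention branch
                MultiheadRMSNorm(dim_head, heads = heads)    # norms the hybrid branch
            ])

        self.hybrid_module = hybrid_module
        self.hybrid_norms = hybrid_norms
        self.hybrid_mix = hybrid_mix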
@@ -1729,11 +1748,9 @@ class Attention(Module):
             head_gate = self.to_v_head_gate(x)
             out = einx.multiply('b n h, b h n d -> b h n d', head_gate.sigmoid(), out)
 
-        # merge heads
-
-        out = self.merge_heads(out)
+        # if exists hybrid module, must do a normalization
 
-
+        # hybrid module
 
         if exists(self.hybrid_module):
 
@@ -1751,8 +1768,23 @@
             # handle hybrid out
 
             (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
+
+            # handle variable hybrid output and multi rmsnorm before summing to main attention output (also normed)
+
+            if hybrid_out.ndim == 3:
+                hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)
+
+            out_norm, hybrid_out_norm = self.hybrid_norms
+
+            out = out_norm(out)
+            hybrid_out = hybrid_out_norm(hybrid_out)
+
             out = 0.5 * (out + hybrid_out)
 
+        # merge heads
+
+        out = self.merge_heads(out)
+
         # alphafold2 styled gating of the values
 
         if exists(self.to_v_gate):
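
In the forward pass, the attention output is no longer merged right after the head gating; instead, when a hybrid module is present, both branch outputs are RMS-normalized per head, averaged with equal weight, and only then are the heads merged. A minimal sketch of that combination step is below, reusing the MultiheadRMSNorm sketch above; `combine_with_hybrid` is a hypothetical helper rather than the library's API, and the final rearrange stands in for what `self.merge_heads` is assumed to do.

import torch
from einops import rearrange

def combine_with_hybrid(out, hybrid_out, out_norm, hybrid_out_norm):
    # out: attention output before head merging, assumed (batch, heads, seq, dim_head)
    h = out.shape[1]

    # the hybrid module may return a flat (batch, seq, heads * dim_head) tensor - split it per head
    if hybrid_out.ndim == 3:
        hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)

    # each branch gets its own multihead RMS norm before summing
    out = out_norm(out)
    hybrid_out = hybrid_out_norm(hybrid_out)

    # equal-weight average of the two branches
    out = 0.5 * (out + hybrid_out)

    # merge heads afterwards (assumed equivalent to merge_heads in the library)
    return rearrange(out, 'b h n d -> b n (h d)')

b, h, n, d = 2, 8, 16, 64
out_norm, hybrid_out_norm = MultiheadRMSNorm(d, heads = h), MultiheadRMSNorm(d, heads = h)
merged = combine_with_hybrid(torch.randn(b, h, n, d), torch.randn(b, n, h * d), out_norm, hybrid_out_norm)
print(merged.shape)   # torch.Size([2, 16, 512])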