x-transformers 2.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +36 -4
- x_transformers-2.0.1.dist-info/METADATA +2419 -0
- {x_transformers-2.0.0.dist-info → x_transformers-2.0.1.dist-info}/RECORD +5 -6
- {x_transformers-2.0.0.dist-info → x_transformers-2.0.1.dist-info}/WHEEL +1 -2
- x_transformers-2.0.0.dist-info/METADATA +0 -30
- x_transformers-2.0.0.dist-info/top_level.txt +0 -1
- {x_transformers-2.0.0.dist-info → x_transformers-2.0.1.dist-info/licenses}/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -841,6 +841,15 @@ class SimpleRMSNorm(Module):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale
 
+class MultiheadRMSNorm(Module):
+    def __init__(self, dim, heads):
+        super().__init__()
+        self.rmsnorm = SimpleRMSNorm(dim)
+        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
+
+    def forward(self, x):
+        return self.rmsnorm(x) * (self.gamma + 1.)
+
 # residual and residual gates
 
 class Residual(Module):
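A quick usage sketch (not part of the diff) of the MultiheadRMSNorm added above. The input layout (batch, heads, seq_len, dim_head) and the import path are assumptions based on how the class is used in the Attention hunks below; gamma is zero-initialized, so the learned per-head gain (gamma + 1.) starts out as the identity.

import torch
from x_transformers.x_transformers import MultiheadRMSNorm  # defined at module level from 2.0.1 onward

norm = MultiheadRMSNorm(dim = 64, heads = 8)

x = torch.randn(2, 8, 16, 64)   # assumed layout: (batch, heads, seq_len, dim_head)
y = norm(x)                     # RMS-normalize the last dim, then apply the per-head gain (gamma + 1.)

print(y.shape)                  # torch.Size([2, 8, 16, 64]) - shape is preserved

Because gamma has shape (heads, 1, dim), each head learns its own gain over the head dimension while sharing the same underlying SimpleRMSNorm.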
@@ -1431,12 +1440,22 @@ class Attention(Module):
 
         # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
 
+        hybrid_mix = None
+        hybrid_norms = None
         hybrid_module = maybe(deepcopy)(hybrid_module)
 
         if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
             hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
+            hybrid_mix = LinearNoBias(dim, heads)
+
+            hybrid_norms = ModuleList([
+                MultiheadRMSNorm(dim_head, heads = heads),
+                MultiheadRMSNorm(dim_head, heads = heads)
+            ])
 
         self.hybrid_module = hybrid_module
+        self.hybrid_norms = hybrid_norms
+        self.hybrid_mix = hybrid_mix
         self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
 
         # output dimension by default same as input, but can be overridden
@@ -1729,11 +1748,9 @@ class Attention(Module):
             head_gate = self.to_v_head_gate(x)
             out = einx.multiply('b n h, b h n d ->b h n d', head_gate.sigmoid(), out)
 
-        # merge heads
-
-        out = self.merge_heads(out)
+        # if exists hybrid module, must do a normalization
 
-
+        # hybrid module
 
         if exists(self.hybrid_module):
 
@@ -1751,8 +1768,23 @@ class Attention(Module):
             # handle hybrid out
 
             (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
+
+            # handle variable hybrid output and multi rmsnorm before summing to main attention output (also normed)
+
+            if hybrid_out.ndim == 3:
+                hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)
+
+            out_norm, hybrid_out_norm = self.hybrid_norms
+
+            out = out_norm(out)
+            hybrid_out = hybrid_out_norm(hybrid_out)
+
             out = 0.5 * (out + hybrid_out)
 
+        # merge heads
+
+        out = self.merge_heads(out)
+
         # alphafold2 styled gating of the values
 
         if exists(self.to_v_gate):