x-transformers 2.0.0-py3-none-any.whl → 2.0.2-py3-none-any.whl
- x_transformers/x_transformers.py +43 -5
- x_transformers-2.0.2.dist-info/METADATA +2420 -0
- {x_transformers-2.0.0.dist-info → x_transformers-2.0.2.dist-info}/RECORD +5 -6
- {x_transformers-2.0.0.dist-info → x_transformers-2.0.2.dist-info}/WHEEL +1 -2
- x_transformers-2.0.0.dist-info/METADATA +0 -30
- x_transformers-2.0.0.dist-info/top_level.txt +0 -1
- {x_transformers-2.0.0.dist-info → x_transformers-2.0.2.dist-info/licenses}/LICENSE +0 -0
x_transformers/x_transformers.py CHANGED
@@ -841,6 +841,15 @@ class SimpleRMSNorm(Module):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale
 
+class MultiheadRMSNorm(Module):
+    def __init__(self, dim, heads):
+        super().__init__()
+        self.rmsnorm = SimpleRMSNorm(dim)
+        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
+
+    def forward(self, x):
+        return self.rmsnorm(x) * (self.gamma + 1.)
+
 # residual and residual gates
 
 class Residual(Module):
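The new MultiheadRMSNorm pairs SimpleRMSNorm with a zero-initialized per-head gain, so at initialization the (gamma + 1.) factor is an identity scaling. A minimal standalone sketch of the same idea, operating on the (batch, heads, seq, dim_head) layout the later hunks use (the shapes and values below are illustrative assumptions, not taken from the package):

import torch
from torch import nn
import torch.nn.functional as F

class MultiheadRMSNormSketch(nn.Module):
    # RMS norm plus a learned per-head gain; gamma starts at zero so the
    # initial scaling (gamma + 1.) is the identity
    def __init__(self, dim, heads):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))

    def forward(self, x):
        # x: (batch, heads, seq, dim_head); gamma broadcasts over batch and seq
        return F.normalize(x, dim = -1) * self.scale * (self.gamma + 1.)

norm = MultiheadRMSNormSketch(dim = 64, heads = 8)
out = norm(torch.randn(2, 8, 128, 64))
print(out.shape)  # torch.Size([2, 8, 128, 64])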
@@ -1195,6 +1204,7 @@ class Attention(Module):
         hybrid_module: Module | None = None,
         hybrid_mask_kwarg: str | None = None,
         hybrid_fold_axial_dim: int | None = None,
+        hybrid_learned_mix = False,
         one_kv_head = False,
         kv_heads = None,
         value_dim_head = None,
@@ -1431,12 +1441,22 @@ class Attention(Module):
 
         # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
 
+        hybrid_mix = None
+        hybrid_norms = None
         hybrid_module = maybe(deepcopy)(hybrid_module)
 
         if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
             hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
+            hybrid_mix = LinearNoBias(dim, heads) if hybrid_learned_mix else None
+
+            hybrid_norms = ModuleList([
+                MultiheadRMSNorm(dim_head, heads = heads),
+                MultiheadRMSNorm(dim_head, heads = heads)
+            ])
 
         self.hybrid_module = hybrid_module
+        self.hybrid_norms = hybrid_norms
+        self.hybrid_mix = hybrid_mix
         self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
 
         # output dimension by default same as input, but can be overridden
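When hybrid_learned_mix is enabled, the mix is a bias-free projection from the token features to one logit per attention head. A standalone sketch of that projection (dim = 512, heads = 8, and the nn.Linear stand-in for LinearNoBias are illustrative assumptions):

import torch
from torch import nn

dim, heads = 512, 8
to_mix = nn.Linear(dim, heads, bias = False)  # stand-in for LinearNoBias(dim, heads)

x = torch.randn(2, 128, dim)              # (batch, seq, dim) token features
mix = to_mix(x)                            # (batch, seq, heads) - one logit per head per position
mix = mix.permute(0, 2, 1).unsqueeze(-1)   # (batch, heads, seq, 1), same as rearrange('b n h -> b h n 1')
gate = mix.sigmoid()                       # in [0, 1]: 0 keeps the attention branch, 1 keeps the hybrid branch
print(gate.shape)                          # torch.Size([2, 8, 128, 1])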
@@ -1729,11 +1749,9 @@ class Attention(Module):
             head_gate = self.to_v_head_gate(x)
             out = einx.multiply('b n h, b h n d ->b h n d', head_gate.sigmoid(), out)
 
-        # merge heads
+        # if exists hybrid module, must do a normalization
 
-        out = self.merge_heads(out)
-
-        # hybrid module
+        # hybrid module
 
         if exists(self.hybrid_module):
 
@@ -1751,7 +1769,27 @@ class Attention(Module):
             # handle hybrid out
 
             (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
-            out = 0.5 * (out + hybrid_out)
+
+            # handle variable hybrid output and multi rmsnorm before summing to main attention output (also normed)
+
+            if hybrid_out.ndim == 3:
+                hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)
+
+            out_norm, hybrid_out_norm = self.hybrid_norms
+
+            out = out_norm(out)
+            hybrid_out = hybrid_out_norm(hybrid_out)
+
+            if exists(self.hybrid_mix):
+                mix = self.hybrid_mix(x)
+                mix = rearrange(mix, 'b n h -> b h n 1')
+                out = out.lerp(hybrid_out, mix.sigmoid())
+            else:
+                out = 0.5 * (out + hybrid_out)
+
+        # merge heads
+
+        out = self.merge_heads(out)
 
         # alphafold2 styled gating of the values
 
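The forward-pass change RMS-normalizes both branches in the per-head layout, then either lerps them with the learned per-head gate or falls back to a plain average, and only merges heads afterwards. A self-contained sketch of that merge, with a GRU standing in for the hybrid module and the learned gains omitted (all shapes and module choices here are illustrative assumptions, not the package API):

import torch
import torch.nn.functional as F
from torch import nn

b, h, n, d = 2, 8, 128, 64
dim = h * d

out = torch.randn(b, h, n, d)                  # attention output, per head
gru = nn.GRU(dim, dim, batch_first = True)
hybrid_out, _ = gru(torch.randn(b, n, dim))    # (b, n, h*d) from the hybrid branch

# bring the hybrid branch into the per-head layout, 'b n (h d) -> b h n d'
hybrid_out = hybrid_out.reshape(b, n, h, d).permute(0, 2, 1, 3)

# RMS-normalize both branches (per-head gains omitted for brevity)
out = F.normalize(out, dim = -1) * d ** 0.5
hybrid_out = F.normalize(hybrid_out, dim = -1) * d ** 0.5

# learned mix: lerp(out, hybrid_out, g) == out * (1 - g) + hybrid_out * g
gate = torch.rand(b, h, n, 1)                  # stands in for sigmoid(hybrid_mix(x))
mixed = out.lerp(hybrid_out, gate)

# without a learned mix, the diff simply averages the two branches
averaged = 0.5 * (out + hybrid_out)
print(mixed.shape, averaged.shape)             # both (b, h, n, d), merged to (b, n, h*d) afterwards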