x-transformers 2.0.0__py3-none-any.whl → 2.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +43 -5
- x_transformers-2.0.2.dist-info/METADATA +2420 -0
- {x_transformers-2.0.0.dist-info → x_transformers-2.0.2.dist-info}/RECORD +5 -6
- {x_transformers-2.0.0.dist-info → x_transformers-2.0.2.dist-info}/WHEEL +1 -2
- x_transformers-2.0.0.dist-info/METADATA +0 -30
- x_transformers-2.0.0.dist-info/top_level.txt +0 -1
- {x_transformers-2.0.0.dist-info → x_transformers-2.0.2.dist-info/licenses}/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -841,6 +841,15 @@ class SimpleRMSNorm(Module):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale
 
+class MultiheadRMSNorm(Module):
+    def __init__(self, dim, heads):
+        super().__init__()
+        self.rmsnorm = SimpleRMSNorm(dim)
+        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
+
+    def forward(self, x):
+        return self.rmsnorm(x) * (self.gamma + 1.)
+
 # residual and residual gates
 
 class Residual(Module):
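For reference, the new `MultiheadRMSNorm` is just `SimpleRMSNorm` followed by a learned per-head gain that starts at identity (`gamma = 0`, so the gain is `1`). A minimal standalone sketch, with `SimpleRMSNorm` inlined for context and toy shapes that are illustrative only:

```python
import torch
from torch import nn
import torch.nn.functional as F

class SimpleRMSNorm(nn.Module):
    # unit-normalize the last dimension, then rescale by sqrt(dim)
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale

class MultiheadRMSNorm(nn.Module):
    # SimpleRMSNorm plus a learned gain per head, broadcast over the sequence
    def __init__(self, dim, heads):
        super().__init__()
        self.rmsnorm = SimpleRMSNorm(dim)
        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))

    def forward(self, x):
        return self.rmsnorm(x) * (self.gamma + 1.)

norm = MultiheadRMSNorm(dim = 64, heads = 8)
x = torch.randn(2, 8, 16, 64)  # (batch, heads, seq, dim_head)
print(norm(x).shape)           # torch.Size([2, 8, 16, 64])
```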
@@ -1195,6 +1204,7 @@ class Attention(Module):
         hybrid_module: Module | None = None,
         hybrid_mask_kwarg: str | None = None,
         hybrid_fold_axial_dim: int | None = None,
+        hybrid_learned_mix = False,
         one_kv_head = False,
         kv_heads = None,
         value_dim_head = None,
@@ -1431,12 +1441,22 @@ class Attention(Module):
 
         # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
 
+        hybrid_mix = None
+        hybrid_norms = None
         hybrid_module = maybe(deepcopy)(hybrid_module)
 
         if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
             hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
+            hybrid_mix = LinearNoBias(dim, heads) if hybrid_learned_mix else None
+
+            hybrid_norms = ModuleList([
+                MultiheadRMSNorm(dim_head, heads = heads),
+                MultiheadRMSNorm(dim_head, heads = heads)
+            ])
 
         self.hybrid_module = hybrid_module
+        self.hybrid_norms = hybrid_norms
+        self.hybrid_mix = hybrid_mix
         self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
 
         # output dimension by default same as input, but can be overridden
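A hypothetical construction sketch for the new flag follows. The kwarg names (`hybrid_module`, `hybrid_fold_axial_dim`, `hybrid_learned_mix`) and the `hybrid_mix` / `hybrid_norms` attributes come from the hunks above; the GRU sizes and the fold dimension are made-up values, and `dim` is chosen equal to `heads * dim_head` so the hybrid branch's `(batch, seq, dim)` output can later be split back into heads:

```python
from torch import nn
from x_transformers.x_transformers import Attention

attn = Attention(
    dim = 512,
    heads = 8,
    dim_head = 64,
    hybrid_module = nn.GRU(512, 512, batch_first = True),
    hybrid_fold_axial_dim = 2,   # per the hunk above, the norms and mix are built alongside the axial fold
    hybrid_learned_mix = True    # new: learn a per-head gate instead of a fixed 0.5 average
)

print(attn.hybrid_mix)    # bias-less Linear projecting dim -> heads (None when hybrid_learned_mix = False)
print(attn.hybrid_norms)  # ModuleList of two MultiheadRMSNorms, one per branch
```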
@@ -1729,11 +1749,9 @@ class Attention(Module):
             head_gate = self.to_v_head_gate(x)
             out = einx.multiply('b n h, b h n d ->b h n d', head_gate.sigmoid(), out)
 
-        # merge heads
+        # if exists hybrid module, must do a normalization
 
-        out = self.merge_heads(out)
-
-        # hybrid module
+        # hybrid module
 
         if exists(self.hybrid_module):
 
@@ -1751,7 +1769,27 @@ class Attention(Module):
             # handle hybrid out
 
             (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
-            out = 0.5 * (out + hybrid_out)
+
+            # handle variable hybrid output and multi rmsnorm before summing to main attention output (also normed)
+
+            if hybrid_out.ndim == 3:
+                hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)
+
+            out_norm, hybrid_out_norm = self.hybrid_norms
+
+            out = out_norm(out)
+            hybrid_out = hybrid_out_norm(hybrid_out)
+
+            if exists(self.hybrid_mix):
+                mix = self.hybrid_mix(x)
+                mix = rearrange(mix, 'b n h -> b h n 1')
+                out = out.lerp(hybrid_out, mix.sigmoid())
+            else:
+                out = 0.5 * (out + hybrid_out)
+
+        # merge heads
+
+        out = self.merge_heads(out)
 
         # alphafold2 styled gating of the values
 
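To make the forward-path change concrete, here is a self-contained sketch of the head-wise mixing with toy shapes, using a plain `nn.Linear(..., bias = False)` as a stand-in for `LinearNoBias`: both branches are multihead-RMS-normed, then either blended by a learned per-head, per-position sigmoid gate (`hybrid_learned_mix = True`) or simply averaged:

```python
import torch
from torch import nn
from einops import rearrange

b, h, n, d, dim = 2, 8, 16, 64, 512    # toy sizes, illustration only

out = torch.randn(b, h, n, d)           # attention branch, already normed per head
hybrid_out = torch.randn(b, h, n, d)    # hybrid branch, already normed per head
x = torch.randn(b, n, dim)              # layer input that the mix projection reads

hybrid_mix = nn.Linear(dim, h, bias = False)          # stand-in for LinearNoBias(dim, heads)

mix = rearrange(hybrid_mix(x), 'b n h -> b h n 1')    # one gate logit per head and position

mixed    = out.lerp(hybrid_out, mix.sigmoid())        # learned convex combination of the two branches
averaged = 0.5 * (out + hybrid_out)                   # fallback when hybrid_learned_mix = False

print(mixed.shape, averaged.shape)  # both (2, 8, 16, 64); heads are merged back to (b, n, h * d) afterwards
```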