x-transformers 2.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl

@@ -841,6 +841,15 @@ class SimpleRMSNorm(Module):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale
 
+class MultiheadRMSNorm(Module):
+    def __init__(self, dim, heads):
+        super().__init__()
+        self.rmsnorm = SimpleRMSNorm(dim)
+        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
+
+    def forward(self, x):
+        return self.rmsnorm(x) * (self.gamma + 1.)
+
 # residual and residual gates
 
 class Residual(Module):
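
The new MultiheadRMSNorm wraps SimpleRMSNorm with a learned per-head gain; since gamma is initialized to zeros, (gamma + 1.) starts as an identity scaling, so training begins from a plain RMS norm. A minimal, self-contained sketch of the behaviour (not the packaged code; SimpleRMSNorm's scale is assumed to be dim ** 0.5, only its forward appears in this diff):

import torch
from torch import nn
import torch.nn.functional as F

class SimpleRMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5  # assumed; only the forward is shown in the hunk above

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale

class MultiheadRMSNorm(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.rmsnorm = SimpleRMSNorm(dim)
        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))  # zero init -> identity gain at start

    def forward(self, x):
        # x: (batch, heads, seq, dim_head); gamma broadcasts over batch and seq
        return self.rmsnorm(x) * (self.gamma + 1.)

x = torch.randn(2, 8, 16, 64)               # (batch, heads, seq, dim_head)
norm = MultiheadRMSNorm(64, heads = 8)
assert norm(x).shape == x.shape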
@@ -1431,12 +1440,22 @@ class Attention(Module):
 
         # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
 
+        hybrid_mix = None
+        hybrid_norms = None
         hybrid_module = maybe(deepcopy)(hybrid_module)
 
         if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
             hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
+            hybrid_mix = LinearNoBias(dim, heads)
+
+            hybrid_norms = ModuleList([
+                MultiheadRMSNorm(dim_head, heads = heads),
+                MultiheadRMSNorm(dim_head, heads = heads)
+            ])
 
         self.hybrid_module = hybrid_module
+        self.hybrid_norms = hybrid_norms
+        self.hybrid_mix = hybrid_mix
         self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
 
         # output dimension by default same as input, but can be overridden
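
In the constructor, hybrid_mix and hybrid_norms are pre-declared as None so the attributes always exist, and are only built when a hybrid branch is configured. hybrid_mix is registered but not referenced in the forward hunks shown below; the pair of MultiheadRMSNorms is, one per branch. A sketch of the two new submodules, assuming LinearNoBias is the repo's usual bias-free nn.Linear alias and reusing the MultiheadRMSNorm sketch above:

from functools import partial
import torch
from torch import nn

LinearNoBias = partial(nn.Linear, bias = False)   # assumed alias, matching the repo's convention

dim, heads, dim_head = 512, 8, 64

hybrid_mix = LinearNoBias(dim, heads)             # one value per head from token features (not used in the hunks shown here)
hybrid_norms = nn.ModuleList([
    MultiheadRMSNorm(dim_head, heads = heads),    # norms the attention branch
    MultiheadRMSNorm(dim_head, heads = heads)     # norms the hybrid branch
])

tokens = torch.randn(2, 16, dim)
print(hybrid_mix(tokens).shape)                   # torch.Size([2, 16, 8])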
@@ -1729,11 +1748,9 @@ class Attention(Module):
             head_gate = self.to_v_head_gate(x)
             out = einx.multiply('b n h, b h n d ->b h n d', head_gate.sigmoid(), out)
 
-        # merge heads
-
-        out = self.merge_heads(out)
+        # if exists hybrid module, must do a normalization
 
-        # hybrid module
+        # hybrid module
 
         if exists(self.hybrid_module):
 
@@ -1751,8 +1768,23 @@ class Attention(Module):
             # handle hybrid out
 
             (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
+
+            # handle variable hybrid output and multi rmsnorm before summing to main attention output (also normed)
+
+            if hybrid_out.ndim == 3:
+                hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)
+
+            out_norm, hybrid_out_norm = self.hybrid_norms
+
+            out = out_norm(out)
+            hybrid_out = hybrid_out_norm(hybrid_out)
+
             out = 0.5 * (out + hybrid_out)
 
+        # merge heads
+
+        out = self.merge_heads(out)
+
         # alphafold2 styled gating of the values
 
         if exists(self.to_v_gate):
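
Net effect of the forward-path changes: the attention output and the hybrid output are each RMS-normed per head, averaged, and only then are heads merged (merge_heads moves from before the hybrid block to after it). A condensed sketch of that flow, reusing the MultiheadRMSNorm sketch above and assuming (batch, heads, seq, dim_head) tensors:

import torch
from einops import rearrange

b, h, n, d = 2, 8, 16, 64
out = torch.randn(b, h, n, d)            # main attention output, per head
hybrid_out = torch.randn(b, n, h * d)    # hypothetical hybrid branch output, heads still flattened

out_norm = MultiheadRMSNorm(d, heads = h)
hybrid_out_norm = MultiheadRMSNorm(d, heads = h)

if hybrid_out.ndim == 3:                 # split a flat hybrid output into heads, as the diff does
    hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)

out = 0.5 * (out_norm(out) + hybrid_out_norm(hybrid_out))   # norm each branch, then average
out = rearrange(out, 'b h n d -> b n (h d)')                # merge heads afterwards
assert out.shape == (b, n, h * d)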