x-transformers 2.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl

@@ -841,6 +841,15 @@ class SimpleRMSNorm(Module):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale
 
+class MultiheadRMSNorm(Module):
+    def __init__(self, dim, heads):
+        super().__init__()
+        self.rmsnorm = SimpleRMSNorm(dim)
+        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
+
+    def forward(self, x):
+        return self.rmsnorm(x) * (self.gamma + 1.)
+
 # residual and residual gates
 
 class Residual(Module):
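For orientation, a standalone restatement (not taken from the package) of what the new MultiheadRMSNorm computes, in plain torch ops: RMS-normalize the per-head feature dimension, rescale by sqrt(dim), and apply a learned per-head gain that starts as the identity, since gamma is initialized to zero. Shapes below are arbitrary examples.

    import torch
    import torch.nn.functional as F

    b, h, n, d = 2, 8, 16, 64                 # batch, heads, seq, dim_head (illustrative)
    x = torch.randn(b, h, n, d)               # per-head tensor, as used later in Attention.forward
    gamma = torch.zeros(h, 1, d)              # stored above as nn.Parameter(torch.zeros(heads, 1, dim))
    y = F.normalize(x, dim = -1) * (d ** 0.5) * (gamma + 1.)
    assert y.shape == x.shape                 # per-head layout is preserved

The (heads, 1, dim) gamma broadcasts across batch and sequence, so each head gets its own scaling of the normalized features.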
@@ -1431,12 +1440,22 @@ class Attention(Module):
 
         # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
 
+        hybrid_mix = None
+        hybrid_norms = None
         hybrid_module = maybe(deepcopy)(hybrid_module)
 
         if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
             hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
+            hybrid_mix = LinearNoBias(dim, heads)
+
+            hybrid_norms = ModuleList([
+                MultiheadRMSNorm(dim_head, heads = heads),
+                MultiheadRMSNorm(dim_head, heads = heads)
+            ])
 
         self.hybrid_module = hybrid_module
+        self.hybrid_norms = hybrid_norms
+        self.hybrid_mix = hybrid_mix
         self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
 
         # output dimension by default same as input, but can be overridden
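A quick shape check on the constructor additions (a sketch in plain torch, not package code): hybrid_norms holds one MultiheadRMSNorm per branch, one for the attention output and one for the hybrid module output, both over dim_head, while hybrid_mix is a bias-free dim -> heads projection. LinearNoBias is assumed here to be nn.Linear with bias = False, as elsewhere in the file; the sizes are arbitrary.

    import torch
    from torch import nn

    dim, heads, dim_head = 512, 8, 64

    hybrid_mix = nn.Linear(dim, heads, bias = False)    # stands in for LinearNoBias(dim, heads)
    mix = hybrid_mix(torch.randn(2, 16, dim))           # (batch, seq, heads): one scalar per head per position
    assert mix.shape == (2, 16, heads)

    # the two norms are unpacked in forward as: out_norm, hybrid_out_norm = self.hybrid_norms
    # each normalizes a (batch, heads, seq, dim_head) tensor, as in the MultiheadRMSNorm sketch above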
@@ -1729,11 +1748,9 @@ class Attention(Module):
             head_gate = self.to_v_head_gate(x)
             out = einx.multiply('b n h, b h n d ->b h n d', head_gate.sigmoid(), out)
 
-        # merge heads
-
-        out = self.merge_heads(out)
+        # if exists hybrid module, must do a normalization
 
-        # hybrid module
+        # hybrid module
 
         if exists(self.hybrid_module):
 
@@ -1751,8 +1768,23 @@ class Attention(Module):
             # handle hybrid out
 
             (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
+
+            # handle variable hybrid output and multi rmsnorm before summing to main attention output (also normed)
+
+            if hybrid_out.ndim == 3:
+                hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)
+
+            out_norm, hybrid_out_norm = self.hybrid_norms
+
+            out = out_norm(out)
+            hybrid_out = hybrid_out_norm(hybrid_out)
+
             out = 0.5 * (out + hybrid_out)
 
+        # merge heads
+
+        out = self.merge_heads(out)
+
         # alphafold2 styled gating of the values
 
         if exists(self.to_v_gate):
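Putting the forward changes together, a hedged end-to-end sketch of the reordered tail: both branches stay in per-head layout, each gets its own per-head RMSNorm, they are averaged, and only then are heads merged. Only the ndim check, the einops patterns, and the 0.5 averaging are taken from the diff itself; the plain-torch norms and the final rearrange are illustrative stand-ins for self.hybrid_norms and self.merge_heads.

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    b, h, n, d = 2, 8, 16, 64
    out = torch.randn(b, h, n, d)             # attention branch, heads not yet merged
    hybrid_out = torch.randn(b, n, h * d)     # hybrid branch, e.g. an RNN returning a flat feature dim

    if hybrid_out.ndim == 3:                  # variable hybrid output: bring it to per-head layout
        hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)

    gamma_out = torch.zeros(h, 1, d)          # identity gains at init, as in MultiheadRMSNorm
    gamma_hybrid = torch.zeros(h, 1, d)

    out = F.normalize(out, dim = -1) * (d ** 0.5) * (gamma_out + 1.)
    hybrid_out = F.normalize(hybrid_out, dim = -1) * (d ** 0.5) * (gamma_hybrid + 1.)

    out = 0.5 * (out + hybrid_out)            # average the two normed branches

    out = rearrange(out, 'b h n d -> b n (h d)')   # merge heads only after the sum
    assert out.shape == (b, n, h * d)

Normalizing each branch per head before the sum keeps the two contributions on a comparable scale regardless of what the hybrid module outputs, which is why merge_heads now runs after the hybrid block instead of before it.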