x-transformers 2.0.0__py3-none-any.whl → 2.0.2__py3-none-any.whl

@@ -841,6 +841,15 @@ class SimpleRMSNorm(Module):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale
 
+class MultiheadRMSNorm(Module):
+    def __init__(self, dim, heads):
+        super().__init__()
+        self.rmsnorm = SimpleRMSNorm(dim)
+        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
+
+    def forward(self, x):
+        return self.rmsnorm(x) * (self.gamma + 1.)
+
 # residual and residual gates
 
 class Residual(Module):
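
The new MultiheadRMSNorm wraps SimpleRMSNorm with a per-head learned gain; because gamma is zero-initialized, the (gamma + 1.) factor starts at 1 and the module behaves as a plain RMSNorm at initialization. A minimal sketch of the computation on a per-head tensor (variable names and shapes are illustrative, not taken from the package):

import torch
import torch.nn.functional as F

batch, heads, seq, dim_head = 2, 8, 16, 64
x = torch.randn(batch, heads, seq, dim_head)     # per-head attention output

gamma = torch.zeros(heads, 1, dim_head)          # zero-initialized, as in the diff above
scale = dim_head ** 0.5                          # SimpleRMSNorm's scale

normed = F.normalize(x, dim = -1) * scale        # unit-RMS normalize each feature vector
out = normed * (gamma + 1.)                      # per-head gain; identity at initialization

print(out.shape)                                 # torch.Size([2, 8, 16, 64])
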
@@ -1195,6 +1204,7 @@ class Attention(Module):
         hybrid_module: Module | None = None,
         hybrid_mask_kwarg: str | None = None,
         hybrid_fold_axial_dim: int | None = None,
+        hybrid_learned_mix = False,
         one_kv_head = False,
         kv_heads = None,
         value_dim_head = None,
@@ -1431,12 +1441,22 @@ class Attention(Module):
 
         # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
 
+        hybrid_mix = None
+        hybrid_norms = None
         hybrid_module = maybe(deepcopy)(hybrid_module)
 
         if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
             hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
+            hybrid_mix = LinearNoBias(dim, heads) if hybrid_learned_mix else None
+
+            hybrid_norms = ModuleList([
+                MultiheadRMSNorm(dim_head, heads = heads),
+                MultiheadRMSNorm(dim_head, heads = heads)
+            ])
 
         self.hybrid_module = hybrid_module
+        self.hybrid_norms = hybrid_norms
+        self.hybrid_mix = hybrid_mix
         self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
 
         # output dimension by default same as input, but can be overridden
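
Taken together, the constructor hunks register two MultiheadRMSNorms (one for the attention branch, one for the hybrid branch) and, when hybrid_learned_mix is set, a bias-free projection from the model dimension to one mixing logit per head. A hypothetical usage sketch follows; only hybrid_module and hybrid_learned_mix come from the hunks above, while the import path and the other keyword arguments are assumptions about the surrounding API:

from torch import nn
from x_transformers.x_transformers import Attention        # module path assumed

attn = Attention(
    dim = 512,                                              # model dimension (assumed kwarg)
    heads = 8,                                              # dim_head defaults to 64, so 8 * 64 = 512
    hybrid_module = nn.GRU(512, 64 * 8, batch_first = True),  # recurrent branch, hymba-style
    hybrid_learned_mix = True                               # new flag: learn a per-head gate instead of a fixed 0.5 average
)
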
@@ -1729,11 +1749,9 @@ class Attention(Module):
             head_gate = self.to_v_head_gate(x)
             out = einx.multiply('b n h, b h n d ->b h n d', head_gate.sigmoid(), out)
 
-        # merge heads
+        # if exists hybrid module, must do a normalization
 
-        out = self.merge_heads(out)
-
-        # hybrid module
+        # hybrid module
 
         if exists(self.hybrid_module):
 
@@ -1751,7 +1769,27 @@ class Attention(Module):
             # handle hybrid out
 
             (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
-            out = 0.5 * (out + hybrid_out)
+
+            # handle variable hybrid output and multi rmsnorm before summing to main attention output (also normed)
+
+            if hybrid_out.ndim == 3:
+                hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)
+
+            out_norm, hybrid_out_norm = self.hybrid_norms
+
+            out = out_norm(out)
+            hybrid_out = hybrid_out_norm(hybrid_out)
+
+            if exists(self.hybrid_mix):
+                mix = self.hybrid_mix(x)
+                mix = rearrange(mix, 'b n h -> b h n 1')
+                out = out.lerp(hybrid_out, mix.sigmoid())
+            else:
+                out = 0.5 * (out + hybrid_out)
+
+        # merge heads
+
+        out = self.merge_heads(out)
 
         # alphafold2 styled gating of the values
 
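
In the forward pass, merging heads now happens after the two branches are combined, since both the MultiheadRMSNorms and the learned mix gate act on the still-separate head axis. The combination is either a per-head, per-token lerp driven by sigmoid gates (when hybrid_learned_mix is enabled) or the previous fixed 0.5 average. A self-contained sketch of that step using stand-in tensors and assumed shapes, not package code:

import torch
import torch.nn.functional as F
from einops import rearrange

batch, heads, seq, dim_head, dim = 2, 8, 16, 64, 512

out = torch.randn(batch, heads, seq, dim_head)            # attention branch, heads still separate
hybrid_out = torch.randn(batch, seq, heads * dim_head)    # hybrid branch output, e.g. from a GRU (b, n, h*d)

hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = heads)

def multihead_rmsnorm(t):                                  # stand-in for MultiheadRMSNorm at init (gamma = 0)
    return F.normalize(t, dim = -1) * (dim_head ** 0.5)

out, hybrid_out = multihead_rmsnorm(out), multihead_rmsnorm(hybrid_out)

x = torch.randn(batch, seq, dim)                           # token features used to predict the mix
mix_weight = torch.randn(dim, heads)                       # stand-in for the LinearNoBias(dim, heads) weight
gate = rearrange(x @ mix_weight, 'b n h -> b h n 1').sigmoid()

mixed = out.lerp(hybrid_out, gate)                         # gate -> 0 keeps attention, gate -> 1 keeps the hybrid branch
fallback = 0.5 * (out + hybrid_out)                        # behaviour when hybrid_learned_mix is off

merged = rearrange(mixed, 'b h n d -> b n (h d)')          # heads are merged only after the combination
print(merged.shape)                                        # torch.Size([2, 16, 512])
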