x-transformers 2.0.0 → 2.0.2 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -841,6 +841,15 @@ class SimpleRMSNorm(Module):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale
 
+class MultiheadRMSNorm(Module):
+    def __init__(self, dim, heads):
+        super().__init__()
+        self.rmsnorm = SimpleRMSNorm(dim)
+        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
+
+    def forward(self, x):
+        return self.rmsnorm(x) * (self.gamma + 1.)
+
 # residual and residual gates
 
 class Residual(Module):
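
For reference, the newly added MultiheadRMSNorm is a per-head RMSNorm with a learned per-head gain: gamma is zero-initialized so (gamma + 1.) starts as identity scaling, and its (heads, 1, dim) shape broadcasts over the batch and sequence axes of a (batch, heads, seq, dim_head) tensor. A minimal self-contained sketch (the same code as in the hunk above, plus the imports it needs and an illustrative call):

import torch
import torch.nn.functional as F
from torch import nn

class SimpleRMSNorm(nn.Module):
    # scale-only RMS norm (no learned parameters), as defined earlier in the file
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale

class MultiheadRMSNorm(nn.Module):
    # RMS norm applied per attention head, with a learned per-head gain
    def __init__(self, dim, heads):
        super().__init__()
        self.rmsnorm = SimpleRMSNorm(dim)
        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))  # zero-init -> identity gain

    def forward(self, x):
        # x: (batch, heads, seq, dim_head); gamma broadcasts over batch and seq
        return self.rmsnorm(x) * (self.gamma + 1.)

norm = MultiheadRMSNorm(dim = 64, heads = 8)
out = norm(torch.randn(2, 8, 1024, 64))  # -> (2, 8, 1024, 64)
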
@@ -1195,6 +1204,7 @@ class Attention(Module):
         hybrid_module: Module | None = None,
         hybrid_mask_kwarg: str | None = None,
         hybrid_fold_axial_dim: int | None = None,
+        hybrid_learned_mix = False,
         one_kv_head = False,
         kv_heads = None,
         value_dim_head = None,
@@ -1431,12 +1441,22 @@ class Attention(Module):
 
         # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
 
+        hybrid_mix = None
+        hybrid_norms = None
         hybrid_module = maybe(deepcopy)(hybrid_module)
 
         if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
             hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
+            hybrid_mix = LinearNoBias(dim, heads) if hybrid_learned_mix else None
+
+            hybrid_norms = ModuleList([
+                MultiheadRMSNorm(dim_head, heads = heads),
+                MultiheadRMSNorm(dim_head, heads = heads)
+            ])
 
         self.hybrid_module = hybrid_module
+        self.hybrid_norms = hybrid_norms
+        self.hybrid_mix = hybrid_mix
         self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
 
         # output dimension by default same as input, but can be overridden
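
The new hybrid_learned_mix flag, when enabled, creates a LinearNoBias(dim, heads) projection that is later used to gate between the attention and hybrid branches. A hedged usage sketch: the snippet below follows the library's convention of routing attn_-prefixed kwargs from Decoder down to Attention and pairs the flag with a GRU hybrid module; apart from the kwarg names visible in this diff (hybrid_module, hybrid_fold_axial_dim, hybrid_learned_mix), the surrounding settings and sizes are illustrative assumptions rather than something this diff establishes.

import torch
from torch.nn import GRU

from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_dim_head = 64,
        # hybrid branch: the GRU's output dim should equal heads * dim_head so it can be
        # split into per-head form before being mixed with the attention output
        attn_hybrid_module = GRU(512, 64 * 8, batch_first = True),
        attn_hybrid_fold_axial_dim = 4,     # assumption: fold the sequence axis before the GRU
        attn_hybrid_learned_mix = True      # the new flag from this diff (learned per-head mix)
    )
)

logits = model(torch.randint(0, 20000, (1, 1024)))  # -> (1, 1024, 20000)
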
@@ -1729,11 +1749,9 @@ class Attention(Module):
             head_gate = self.to_v_head_gate(x)
             out = einx.multiply('b n h, b h n d -> b h n d', head_gate.sigmoid(), out)
 
-        # merge heads
+        # if exists hybrid module, must do a normalization
 
-        out = self.merge_heads(out)
-
-        # hybrid module
+        # hybrid module
 
         if exists(self.hybrid_module):
 
@@ -1751,7 +1769,27 @@ class Attention(Module):
             # handle hybrid out
 
             (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
-            out = 0.5 * (out + hybrid_out)
+
+            # handle variable hybrid output and multi rmsnorm before summing to main attention output (also normed)
+
+            if hybrid_out.ndim == 3:
+                hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)
+
+            out_norm, hybrid_out_norm = self.hybrid_norms
+
+            out = out_norm(out)
+            hybrid_out = hybrid_out_norm(hybrid_out)
+
+            if exists(self.hybrid_mix):
+                mix = self.hybrid_mix(x)
+                mix = rearrange(mix, 'b n h -> b h n 1')
+                out = out.lerp(hybrid_out, mix.sigmoid())
+            else:
+                out = 0.5 * (out + hybrid_out)
+
+        # merge heads
+
+        out = self.merge_heads(out)
 
         # alphafold2 styled gating of the values
 
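
Putting the forward-path changes together: the attention output now stays in per-head form instead of being merged first, both branches are RMS-normalized per head, and they are combined either by the original 0.5 average or, when hybrid_learned_mix is set, by a per-head, per-position sigmoid-gated interpolation; heads are merged only afterwards. A self-contained sketch of just this combination step, using a plain rms_norm function in place of MultiheadRMSNorm and illustrative tensor names:

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

b, h, n, d, dim = 2, 8, 1024, 64, 512

attn_out   = torch.randn(b, h, n, d)    # attention output, still per head
hybrid_out = torch.randn(b, n, h * d)   # hybrid module output, heads flattened
x          = torch.randn(b, n, dim)     # layer input, used to predict the mix

def rms_norm(t):
    # same normalization as SimpleRMSNorm above (per-head learned gain omitted for brevity)
    return F.normalize(t, dim = -1) * (t.shape[-1] ** 0.5)

# bring a flat (b, n, h*d) hybrid output into per-head form
if hybrid_out.ndim == 3:
    hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)

# normalize both branches before combining
attn_out, hybrid_out = rms_norm(attn_out), rms_norm(hybrid_out)

# learned mix: one gate per head and position, predicted from x
to_mix = nn.Linear(dim, h, bias = False)
mix = rearrange(to_mix(x), 'b n h -> b h n 1')

out = attn_out.lerp(hybrid_out, mix.sigmoid())   # hybrid_learned_mix = True
# out = 0.5 * (attn_out + hybrid_out)            # default behaviour otherwise

out = rearrange(out, 'b h n d -> b n (h d)')     # merge heads afterwards
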