x-transformers 1.23.1 (py3-none-any.whl) → 1.23.3 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/x_transformers.py

@@ -650,6 +650,7 @@ class Attention(nn.Module):
         num_mem_kv = 0,
         dropout = 0.,
         on_attn = False,
+        gate_value_heads = False,
         gate_values = False,
         zero_init_output = False,
         max_attend_past = None,
@@ -703,7 +704,14 @@ class Attention(nn.Module):
         if gate_values:
             self.to_v_gate = nn.Linear(dim, out_dim)
             nn.init.constant_(self.to_v_gate.weight, 0)
-            nn.init.constant_(self.to_v_gate.bias, 1)
+            nn.init.constant_(self.to_v_gate.bias, 10)
+
+        # add per head gating of the output values, from 'Attend to nothing' paper
+        self.to_v_head_gate = None
+        if gate_value_heads:
+            self.to_v_head_gate = nn.Linear(dim, heads)
+            nn.init.constant_(self.to_v_head_gate.weight, 0)
+            nn.init.constant_(self.to_v_head_gate.bias, 10)

         # cosine sim attention

         self.qk_norm = qk_norm
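Note on the initialization above: both the existing value gate and the new per-head gate zero their weights and set the bias to 10, so each sigmoid gate starts effectively open (sigmoid(10) ≈ 0.99995) and the gating is a near no-op at the start of training. A quick numeric check, not part of the package:

import torch

# with the gate weights zeroed, the pre-activation equals the bias,
# so every gate starts at sigmoid(10), i.e. essentially fully open
print(torch.sigmoid(torch.tensor(10.)).item())  # ~0.9999546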
@@ -905,6 +913,12 @@ class Attention(nn.Module):
         if head_scale:
             out = out * self.head_scale_params

+        # per head gating, from https://arxiv.org/abs/2306.12929
+
+        if exists(self.to_v_head_gate):
+            head_gate = self.to_v_head_gate(x)
+            out = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()
+
         # merge heads

         out = rearrange(out, 'b h n d -> b n (h d)')
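The hunk above applies the new gate inside Attention.forward: to_v_head_gate maps the layer input to one logit per head, and the sigmoid of that logit rescales the corresponding head's output at every position. Below is a self-contained sketch of the same pattern; the module name PerHeadValueGate and the toy shapes are illustrative and not part of the library, but the tensor layout (batch, heads, seq, dim_head) matches what x-transformers uses for the attention output.

import torch
from torch import nn
from einops import rearrange

class PerHeadValueGate(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        # one gate logit per head, derived from the token features
        self.to_head_gate = nn.Linear(dim, heads)
        nn.init.constant_(self.to_head_gate.weight, 0)
        nn.init.constant_(self.to_head_gate.bias, 10)  # gates start ~open

    def forward(self, x, attn_out):
        # x: (batch, seq, dim) layer input; attn_out: (batch, heads, seq, dim_head)
        head_gate = self.to_head_gate(x)                           # (b, n, h)
        gate = rearrange(head_gate, 'b n h -> b h n 1').sigmoid()  # broadcast over dim_head
        return attn_out * gate

gate = PerHeadValueGate(dim = 512, heads = 8)
out = gate(torch.randn(2, 128, 512), torch.randn(2, 8, 128, 64))
print(out.shape)  # torch.Size([2, 8, 128, 64])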
@@ -1608,6 +1622,7 @@ class ContinuousTransformerWrapper(nn.Module):
         dim_out = None,
         emb_dim = None,
         max_mem_len = 0,
+        num_memory_tokens = None,
         post_emb_norm = False,
         emb_dropout = 0.,
         use_abs_pos_emb = True,
@@ -1632,10 +1647,21 @@ class ContinuousTransformerWrapper(nn.Module):
         self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
         self.emb_dropout = nn.Dropout(emb_dropout)

-        self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
+        # memory tokens
+
+        num_memory_tokens = default(num_memory_tokens, 0)
+        self.has_memory_tokens = num_memory_tokens > 0
+
+        if num_memory_tokens > 0:
+            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
+
+        # attention layers

         self.attn_layers = attn_layers

+        # project in and out
+
+        self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
         self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()

     def forward(
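For ContinuousTransformerWrapper, the new num_memory_tokens argument creates a set of learned tokens that are packed in front of the continuous input before the attention layers and spliced back out afterwards (see the forward hunks below). A hedged usage sketch follows: num_memory_tokens is taken from this diff, while the remaining constructor arguments follow the project README and may differ slightly in this exact version.

import torch
from x_transformers import ContinuousTransformerWrapper, Encoder

model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 100,
    max_seq_len = 1024,
    num_memory_tokens = 4,      # new in this release: learned memory tokens
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

x = torch.randn(1, 256, 32)
out = model(x)                  # memory tokens are removed again before project_out
print(out.shape)                # expected: torch.Size([1, 256, 100])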
@@ -1651,11 +1677,19 @@ class ContinuousTransformerWrapper(nn.Module):
         prepend_embeds = None,
         **kwargs
     ):
+        batch = x.shape[0]
+
         x = self.project_in(x)
         x = x + self.pos_emb(x, pos = pos)

         x = self.post_emb_norm(x)

+        # memory tokens
+
+        if self.has_memory_tokens:
+            m = repeat(self.memory_tokens, 'm d -> b m d', b = batch)
+            x, mem_ps = pack([m, x], 'b * d')
+
         # whether to append embeds, as in PaLI, for image embeddings

         if exists(prepend_embeds):
@@ -1666,8 +1700,15 @@ class ContinuousTransformerWrapper(nn.Module):

         x = self.emb_dropout(x)

+        # attention layers
+
         x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)

+        # splice out memory tokens
+
+        if self.has_memory_tokens:
+            m, x = unpack(x, mem_ps, 'b * d')
+
         out = self.project_out(x) if not return_embeddings else x

         if return_intermediates:
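The forward changes rely on einops repeat/pack/unpack to prepend the memory tokens and then restore the original sequence length. The toy example below, with arbitrarily chosen sizes, illustrates that pattern outside the wrapper: pack records how the tensors were concatenated, so unpack can split the memory tokens back off after the attention layers.

import torch
from einops import repeat, pack, unpack

batch, seq, dim, num_mem = 2, 10, 8, 4

memory_tokens = torch.randn(num_mem, dim)     # a learned nn.Parameter in the wrapper
x = torch.randn(batch, seq, dim)

m = repeat(memory_tokens, 'm d -> b m d', b = batch)
packed, ps = pack([m, x], 'b * d')            # (2, 14, 8); ps remembers the split sizes

# ... packed would be run through the attention layers here ...

m, x = unpack(packed, ps, 'b * d')
print(m.shape, x.shape)                       # torch.Size([2, 4, 8]) torch.Size([2, 10, 8])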
x_transformers-1.23.3.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.23.1
+Version: 1.23.3
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.23.3.dist-info/RECORD

@@ -3,10 +3,10 @@ x_transformers/attend.py,sha256=hZcz_iijzbEqbXp2_BPEVL-1LoHXmYaHE6e6Oy-7hFE,1126
 x_transformers/autoregressive_wrapper.py,sha256=f2u0usjUfAlXwgTz87O8J8XjGTbsbrx2XEP6K2beSNI,8944
 x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=ikf99q_g1_v_wObZed972s2hHrbiDpAq_qGJDmNcVZc,60573
+x_transformers/x_transformers.py,sha256=goudsIa79mfyJtzuI0GqTSdGQ5CXG1ga5Is9h3UBC5Y,61861
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
-x_transformers-1.23.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.23.1.dist-info/METADATA,sha256=91Gu0qU9ztioZ2_oVeOrLeVkP--n6ngneDxEOMUHJe8,661
-x_transformers-1.23.1.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
-x_transformers-1.23.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.23.1.dist-info/RECORD,,
+x_transformers-1.23.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.23.3.dist-info/METADATA,sha256=SXNDjqYSGkklnbXVRg8S52VxDR6VVO62KvRH60abY_k,661
+x_transformers-1.23.3.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+x_transformers-1.23.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.23.3.dist-info/RECORD,,