x-transformers 1.23.2__py3-none-any.whl → 1.23.3__py3-none-any.whl

This diff shows the changes between the two published versions of the package as they appear in their public registry.
x_transformers/x_transformers.py

@@ -1622,6 +1622,7 @@ class ContinuousTransformerWrapper(nn.Module):
         dim_out = None,
         emb_dim = None,
         max_mem_len = 0,
+        num_memory_tokens = None,
         post_emb_norm = False,
         emb_dropout = 0.,
         use_abs_pos_emb = True,
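The new `num_memory_tokens` argument brings `ContinuousTransformerWrapper` in line with the memory-token option `TransformerWrapper` already exposes: a small bank of learned embeddings prepended to every input sequence before attention. The constructor changes in the next hunk lean on the module's `exists`/`default` helpers; the sketch below shows the usual lucidrains idiom for them (assumed for illustration, not copied from the source):

# Sketch of the helper conventions the diff below relies on;
# the real definitions live near the top of x_transformers.py.

def exists(val):
    # a value "exists" when it is not None
    return val is not None

def default(val, d):
    # fall back to d when val was not supplied
    return val if exists(val) else d

# hence, in __init__, default(num_memory_tokens, 0) resolves an
# unset num_memory_tokens to 0, leaving the feature disabled.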
@@ -1646,10 +1647,21 @@ class ContinuousTransformerWrapper(nn.Module):
         self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
         self.emb_dropout = nn.Dropout(emb_dropout)
 
-        self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
+        # memory tokens
+
+        num_memory_tokens = default(num_memory_tokens, 0)
+        self.has_memory_tokens = num_memory_tokens > 0
+
+        if num_memory_tokens > 0:
+            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
+
+        # attention layers
 
         self.attn_layers = attn_layers
 
+        # project in and out
+
+        self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
         self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
 
     def forward(
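The memory tokens are stored as a single learned `(num_memory_tokens, dim)` parameter shared across the batch; at call time it has to be broadcast to a per-batch-element copy. A minimal standalone sketch of that step, using `einops.repeat` as the diff does (the sizes here are illustrative):

import torch
from torch import nn
from einops import repeat

num_memory_tokens, dim, batch = 4, 512, 2

# one learned bank of memory tokens, shared by every sequence
memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

# broadcast the bank across the batch dimension
m = repeat(memory_tokens, 'm d -> b m d', b = batch)
print(m.shape)  # torch.Size([2, 4, 512])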
@@ -1665,11 +1677,19 @@ class ContinuousTransformerWrapper(nn.Module):
         prepend_embeds = None,
         **kwargs
     ):
+        batch = x.shape[0]
+
         x = self.project_in(x)
         x = x + self.pos_emb(x, pos = pos)
 
         x = self.post_emb_norm(x)
 
+        # memory tokens
+
+        if self.has_memory_tokens:
+            m = repeat(self.memory_tokens, 'm d -> b m d', b = batch)
+            x, mem_ps = pack([m, x], 'b * d')
+
         # whether to append embeds, as in PaLI, for image embeddings
 
         if exists(prepend_embeds):
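The `pack`/`unpack` pair used here comes from einops: `pack` concatenates the memory tokens and the sequence along the wildcard `*` axis and records the original sizes in `mem_ps`, so the exact same split can be recovered after the attention layers. A small self-contained demonstration (shapes illustrative):

import torch
from einops import pack, unpack

m = torch.randn(2, 4, 512)    # (batch, memory tokens, dim)
x = torch.randn(2, 100, 512)  # (batch, sequence, dim)

# concatenate along the wildcard axis; ps remembers the per-tensor sizes
packed, ps = pack([m, x], 'b * d')
print(packed.shape)  # torch.Size([2, 104, 512])

# invert the pack exactly, splicing the memory tokens back out
m2, x2 = unpack(packed, ps, 'b * d')
assert m2.shape == m.shape and x2.shape == x.shape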
@@ -1680,8 +1700,15 @@ class ContinuousTransformerWrapper(nn.Module):
 
         x = self.emb_dropout(x)
 
+        # attention layers
+
         x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
 
+        # splice out memory tokens
+
+        if self.has_memory_tokens:
+            m, x = unpack(x, mem_ps, 'b * d')
+
         out = self.project_out(x) if not return_embeddings else x
 
         if return_intermediates:
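Putting the change together, the new option is used from the public API like any other `ContinuousTransformerWrapper` argument. A minimal sketch following the project's README-style constructor usage; the specific hyperparameters (dim 512, depth 6, etc.) are illustrative, not taken from the diff:

import torch
from x_transformers import ContinuousTransformerWrapper, Encoder

model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 1024,
    num_memory_tokens = 4,  # new in 1.23.3
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

x = torch.randn(1, 1024, 32)
out = model(x)  # (1, 1024, 32) - memory tokens are spliced out before projection

Worth noting from the hunks above: `mask` is forwarded to `attn_layers` unchanged, so in this version it is not extended to cover the prepended memory tokens.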
x_transformers-1.23.3.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.23.2
+Version: 1.23.3
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.23.3.dist-info/RECORD

@@ -3,10 +3,10 @@ x_transformers/attend.py,sha256=hZcz_iijzbEqbXp2_BPEVL-1LoHXmYaHE6e6Oy-7hFE,1126
 x_transformers/autoregressive_wrapper.py,sha256=f2u0usjUfAlXwgTz87O8J8XjGTbsbrx2XEP6K2beSNI,8944
 x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=KQ9mU_jE27whl6yQI67grF0S8Xhd3GndnM6Yd0-q-lw,61162
+x_transformers/x_transformers.py,sha256=goudsIa79mfyJtzuI0GqTSdGQ5CXG1ga5Is9h3UBC5Y,61861
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
-x_transformers-1.23.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.23.2.dist-info/METADATA,sha256=8h0sbx8-4yNTOJuAZLbe5HQ16hsmZI1M_mT-rMIIMJc,661
-x_transformers-1.23.2.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
-x_transformers-1.23.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.23.2.dist-info/RECORD,,
+x_transformers-1.23.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.23.3.dist-info/METADATA,sha256=SXNDjqYSGkklnbXVRg8S52VxDR6VVO62KvRH60abY_k,661
+x_transformers-1.23.3.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+x_transformers-1.23.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.23.3.dist-info/RECORD,,