x-transformers 1.23.2__py3-none-any.whl → 1.23.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +28 -1
- {x_transformers-1.23.2.dist-info → x_transformers-1.23.3.dist-info}/METADATA +1 -1
- {x_transformers-1.23.2.dist-info → x_transformers-1.23.3.dist-info}/RECORD +6 -6
- {x_transformers-1.23.2.dist-info → x_transformers-1.23.3.dist-info}/LICENSE +0 -0
- {x_transformers-1.23.2.dist-info → x_transformers-1.23.3.dist-info}/WHEEL +0 -0
- {x_transformers-1.23.2.dist-info → x_transformers-1.23.3.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -1622,6 +1622,7 @@ class ContinuousTransformerWrapper(nn.Module):
         dim_out = None,
         emb_dim = None,
         max_mem_len = 0,
+        num_memory_tokens = None,
         post_emb_norm = False,
         emb_dropout = 0.,
         use_abs_pos_emb = True,
@@ -1646,10 +1647,21 @@ class ContinuousTransformerWrapper(nn.Module):
         self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
         self.emb_dropout = nn.Dropout(emb_dropout)

-        self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
+        # memory tokens
+
+        num_memory_tokens = default(num_memory_tokens, 0)
+        self.has_memory_tokens = num_memory_tokens > 0
+
+        if num_memory_tokens > 0:
+            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
+
+        # attention layers

         self.attn_layers = attn_layers

+        # project in and out
+
+        self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
         self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()

     def forward(
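The net effect of the constructor change is a new optional num_memory_tokens argument on ContinuousTransformerWrapper: when set, a learned (num_memory_tokens, dim) parameter is created, and the forward hunks below prepend and later splice out those tokens so the output length still matches the input. A usage sketch, with hyperparameters chosen purely for illustration and not taken from the package:

import torch
from x_transformers import ContinuousTransformerWrapper, Encoder

# illustrative sizes only
model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 1024,
    num_memory_tokens = 4,   # new keyword added in 1.23.3
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

x = torch.randn(1, 256, 32)
out = model(x)               # memory tokens are handled internally
assert out.shape == (1, 256, 32)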
@@ -1665,11 +1677,19 @@ class ContinuousTransformerWrapper(nn.Module):
         prepend_embeds = None,
         **kwargs
     ):
+        batch = x.shape[0]
+
         x = self.project_in(x)
         x = x + self.pos_emb(x, pos = pos)

         x = self.post_emb_norm(x)

+        # memory tokens
+
+        if self.has_memory_tokens:
+            m = repeat(self.memory_tokens, 'm d -> b m d', b = batch)
+            x, mem_ps = pack([m, x], 'b * d')
+
         # whether to append embeds, as in PaLI, for image embeddings

         if exists(prepend_embeds):
@@ -1680,8 +1700,15 @@ class ContinuousTransformerWrapper(nn.Module):

         x = self.emb_dropout(x)

+        # attention layers
+
         x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)

+        # splice out memory tokens
+
+        if self.has_memory_tokens:
+            m, x = unpack(x, mem_ps, 'b * d')
+
         out = self.project_out(x) if not return_embeddings else x

         if return_intermediates:
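The forward changes lean on einops repeat/pack/unpack to prepend the memory tokens before the attention layers and remove them afterwards. A minimal standalone sketch of that mechanism, with made-up shapes, assuming only torch and einops:

import torch
from einops import repeat, pack, unpack

batch, seq_len, dim, num_mem = 2, 5, 8, 3

memory_tokens = torch.randn(num_mem, dim)   # stands in for the learned nn.Parameter
x = torch.randn(batch, seq_len, dim)        # continuous token embeddings

# broadcast memory tokens over the batch and pack them in front of the sequence
m = repeat(memory_tokens, 'm d -> b m d', b = batch)
x, packed_shapes = pack([m, x], 'b * d')    # shape: (batch, num_mem + seq_len, dim)

# ... attention layers would process the packed sequence here ...

# splice the memory tokens back out, restoring the original sequence length
m, x = unpack(x, packed_shapes, 'b * d')
assert x.shape == (batch, seq_len, dim)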
{x_transformers-1.23.2.dist-info → x_transformers-1.23.3.dist-info}/RECORD
CHANGED
@@ -3,10 +3,10 @@ x_transformers/attend.py,sha256=hZcz_iijzbEqbXp2_BPEVL-1LoHXmYaHE6e6Oy-7hFE,1126
 x_transformers/autoregressive_wrapper.py,sha256=f2u0usjUfAlXwgTz87O8J8XjGTbsbrx2XEP6K2beSNI,8944
 x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=goudsIa79mfyJtzuI0GqTSdGQ5CXG1ga5Is9h3UBC5Y,61861
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
-x_transformers-1.23.
-x_transformers-1.23.
-x_transformers-1.23.
-x_transformers-1.23.
-x_transformers-1.23.
+x_transformers-1.23.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.23.3.dist-info/METADATA,sha256=SXNDjqYSGkklnbXVRg8S52VxDR6VVO62KvRH60abY_k,661
+x_transformers-1.23.3.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+x_transformers-1.23.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.23.3.dist-info/RECORD,,
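For reference, each RECORD entry above has the form path,sha256=<digest>,<size>, where the digest is the urlsafe-base64-encoded SHA-256 of the file with the trailing '=' padding stripped, per the wheel spec. A small sketch of how an entry could be recomputed locally (the path is only an example):

import base64
import hashlib

def record_entry(path):
    # read the file, hash it, and format the line the way RECORD does
    with open(path, 'rb') as f:
        data = f.read()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b'=').decode()
    return f"{path},sha256={digest},{len(data)}"

print(record_entry('x_transformers/x_transformers.py'))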
{x_transformers-1.23.2.dist-info → x_transformers-1.23.3.dist-info}/LICENSE
File without changes
{x_transformers-1.23.2.dist-info → x_transformers-1.23.3.dist-info}/WHEEL
File without changes
{x_transformers-1.23.2.dist-info → x_transformers-1.23.3.dist-info}/top_level.txt
File without changes