x-transformers 1.25.10__py3-none-any.whl → 1.25.12__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a public registry. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in that registry.
- x_transformers/continuous.py +7 -0
- x_transformers/x_transformers.py +9 -2
- {x_transformers-1.25.10.dist-info → x_transformers-1.25.12.dist-info}/METADATA +1 -1
- {x_transformers-1.25.10.dist-info → x_transformers-1.25.12.dist-info}/RECORD +7 -7
- {x_transformers-1.25.10.dist-info → x_transformers-1.25.12.dist-info}/WHEEL +1 -1
- {x_transformers-1.25.10.dist-info → x_transformers-1.25.12.dist-info}/LICENSE +0 -0
- {x_transformers-1.25.10.dist-info → x_transformers-1.25.12.dist-info}/top_level.txt +0 -0
x_transformers/continuous.py
CHANGED
@@ -85,6 +85,7 @@ class ContinuousTransformerWrapper(nn.Module):
         mems = None,
         pos = None,
         prepend_embeds = None,
+        prepend_mask = None,
         **kwargs
     ):
         batch = x.shape[0]
@@ -112,6 +113,12 @@ class ContinuousTransformerWrapper(nn.Module):
 
             x = torch.cat((prepend_embeds, x), dim = -2)
 
+            if exists(prepend_mask) or exists(mask):
+                mask = default(mask, lambda: torch.ones((b, n), device = device, dtype = torch.bool))
+                prepend_mask = default(prepend_mask, lambda: torch.ones((b, prepend_seq), device = device, dtype = torch.bool))
+
+                mask = torch.cat((prepend_mask, mask), dim = -1)
+
         x = self.emb_dropout(x)
 
         # attention layers
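ContinuousTransformerWrapper.forward gains a prepend_mask argument: when either mask or prepend_mask is supplied, the missing one defaults to all-True and the two are concatenated so the combined mask also covers the prepended positions. Below is a minimal, hypothetical usage sketch (not taken from the diff); the constructor keywords follow the x-transformers README, and all shapes and sizes are illustrative assumptions.

import torch
from x_transformers import ContinuousTransformerWrapper, Encoder

# Continuous-input wrapper, configured as in the project README (sizes are arbitrary).
model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 1024,
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
)

x = torch.randn(1, 256, 32)               # continuous inputs: (batch, seq, dim_in)
mask = torch.ones(1, 256).bool()          # boolean padding mask over the main sequence
prepend = torch.randn(1, 10, 512)         # embeddings prepended in the model dimension (assumed 512 here)
prepend_mask = torch.ones(1, 10).bool()   # the argument added in this diff: mask over the prepended positions

out = model(x, mask = mask, prepend_embeds = prepend, prepend_mask = prepend_mask)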
x_transformers/x_transformers.py
CHANGED
@@ -1545,6 +1545,7 @@ class TransformerWrapper(nn.Module):
         mems = None,
         pos = None,
         prepend_embeds = None,
+        prepend_mask = None,
         sum_embeds = None,
         return_attn_z_loss = False,
         attn_z_loss_weight = 1e-4,
@@ -1578,6 +1579,12 @@ class TransformerWrapper(nn.Module):
 
             x = torch.cat((prepend_embeds, x), dim = -2)
 
+            if exists(prepend_mask) or exists(mask):
+                mask = default(mask, lambda: torch.ones((b, n), device = device, dtype = torch.bool))
+                prepend_mask = default(prepend_mask, lambda: torch.ones((b, prepend_seq), device = device, dtype = torch.bool))
+
+                mask = torch.cat((prepend_mask, mask), dim = -1)
+
         # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model
 
         if emb_frac_gradient < 1:
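TransformerWrapper.forward receives the same treatment: a prepend_mask argument that defaults to all-True when only mask is given (and vice versa), then gets concatenated in front of the token mask. A minimal, hypothetical usage sketch follows (not taken from the diff); constructor keywords mirror the x-transformers README and all sizes are illustrative assumptions.

import torch
from x_transformers import TransformerWrapper, Decoder

# Token-based wrapper, configured as in the project README (sizes are arbitrary).
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

ids = torch.randint(0, 20000, (1, 256))   # token ids: (batch, seq)
mask = torch.ones(1, 256).bool()          # boolean padding mask over the token sequence
prepend = torch.randn(1, 10, 512)         # embeddings prepended before the tokens, in the model dimension
prepend_mask = torch.ones(1, 10).bool()   # the argument added in this diff: mask over the prepended positions

logits = model(ids, mask = mask, prepend_embeds = prepend, prepend_mask = prepend_mask)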
@@ -1712,11 +1719,11 @@ class XTransformer(nn.Module):
 
     def forward(self, src, tgt, mask = None, attn_mask = None, src_prepend_embeds = None):
 
+        enc = self.encoder(src, mask = mask, attn_mask = attn_mask, prepend_embeds = src_prepend_embeds, return_embeddings = True)
+
         if exists(src_prepend_embeds) and exists(mask):
             mask = pad_at_dim(mask, (src_prepend_embeds.shape[-2], 0), dim = -1, value = True)
 
-        enc = self.encoder(src, mask = mask, attn_mask = attn_mask, prepend_embeds = src_prepend_embeds, return_embeddings = True)
-
         if self.training and self.cross_attn_tokens_dropout > 0:
             enc, mask = dropout_seq(enc, mask, self.cross_attn_tokens_dropout)
 
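The hunk above also reorders XTransformer.forward: the encoder is now called with the original source-length mask, and the mask is lengthened only afterwards so that it lines up with the prepended encoder output consumed downstream (the cross-attention token dropout shown above and the decoder's cross-attention). A minimal, hypothetical usage sketch of that path follows (not taken from the diff); constructor keywords follow the x-transformers README and sizes are illustrative assumptions.

import torch
from x_transformers import XTransformer

# Encoder/decoder pair, configured as in the project README (sizes are arbitrary).
model = XTransformer(
    dim = 512,
    enc_num_tokens = 256, enc_depth = 6, enc_heads = 8, enc_max_seq_len = 1024,
    dec_num_tokens = 256, dec_depth = 6, dec_heads = 8, dec_max_seq_len = 1024
)

src = torch.randint(0, 256, (1, 512))     # source token ids
tgt = torch.randint(0, 256, (1, 512))     # target token ids
src_mask = torch.ones(1, 512).bool()      # source padding mask
src_prepend = torch.randn(1, 10, 512)     # embeddings prepended to the encoded source, in the model dimension

loss = model(src, tgt, mask = src_mask, src_prepend_embeds = src_prepend)  # returns the autoregressive loss
loss.backward()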
{x_transformers-1.25.10.dist-info → x_transformers-1.25.12.dist-info}/RECORD
CHANGED
@@ -1,13 +1,13 @@
 x_transformers/__init__.py,sha256=pXc_U4M3ONUQcpNgZySDIlCF1rp7u4FFmcOYjc4WuXw,629
 x_transformers/attend.py,sha256=MFl_FbgPsm9mziZPTi_s8QbxASETwbGeciMH8sUIwT8,10188
 x_transformers/autoregressive_wrapper.py,sha256=2mzOq_rl_vevgrxCDncBlVJOAJGS-XGm-iBKJqMjj_c,9041
-x_transformers/continuous.py,sha256=
+x_transformers/continuous.py,sha256=s46BlvSHN7OL2Tya28dMomVF3xmpYxj5reerO6tUDoc,5933
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=5T5fOdveqe4vwqGRrNZqJ053pmtw-dJ12AV0nNaWLRc,60814
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=lS9W_E_RskPQAqVZkPiUzbByoW1Ajsw_phsikA3JDAg,8139
-x_transformers-1.25.
-x_transformers-1.25.
-x_transformers-1.25.
-x_transformers-1.25.
-x_transformers-1.25.
+x_transformers-1.25.12.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.25.12.dist-info/METADATA,sha256=0GzTD3LTld07_sgfvdFJHgNQhiE5bnPkyEzc9G12Klc,662
+x_transformers-1.25.12.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+x_transformers-1.25.12.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.25.12.dist-info/RECORD,,
{x_transformers-1.25.10.dist-info → x_transformers-1.25.12.dist-info}/LICENSE
File without changes

{x_transformers-1.25.10.dist-info → x_transformers-1.25.12.dist-info}/top_level.txt
File without changes