x-transformers 1.25.10__py3-none-any.whl → 1.25.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- x_transformers/continuous.py
+++ x_transformers/continuous.py
@@ -85,6 +85,7 @@ class ContinuousTransformerWrapper(nn.Module):
         mems = None,
         pos = None,
         prepend_embeds = None,
+        prepend_mask = None,
         **kwargs
     ):
         batch = x.shape[0]
@@ -112,6 +113,12 @@ class ContinuousTransformerWrapper(nn.Module):
 
             x = torch.cat((prepend_embeds, x), dim = -2)
 
+            if exists(prepend_mask) or exists(mask):
+                mask = default(mask, lambda: torch.ones((b, n), device = device, dtype = torch.bool))
+                prepend_mask = default(prepend_mask, lambda: torch.ones((b, prepend_seq), device = device, dtype = torch.bool))
+
+                mask = torch.cat((prepend_mask, mask), dim = -1)
+
         x = self.emb_dropout(x)
 
         # attention layers
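
The hunks above add a prepend_mask keyword to ContinuousTransformerWrapper.forward: when prepend_embeds is supplied, a mask for the prepended positions can now be passed alongside the regular mask, and whichever of the two is missing defaults to all-True before both are concatenated along the sequence dimension. A minimal usage sketch; the configuration and tensor shapes are illustrative assumptions, not taken from the diff:

import torch
from x_transformers import ContinuousTransformerWrapper, Encoder

# assumed configuration, for illustration only
model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 256,
    attn_layers = Encoder(dim = 512, depth = 4, heads = 8)
)

x = torch.randn(2, 100, 32)              # continuous inputs
mask = torch.ones(2, 100).bool()         # padding mask for x
prepend = torch.randn(2, 10, 512)        # prepended embeddings, in the model dimension
prepend_mask = torch.ones(2, 10).bool()  # new in 1.25.12: mask for the prepended positions

out = model(x, mask = mask, prepend_embeds = prepend, prepend_mask = prepend_mask)
# the output covers the 10 prepended positions plus the 100 input positions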
--- x_transformers/x_transformers.py
+++ x_transformers/x_transformers.py
@@ -1545,6 +1545,7 @@ class TransformerWrapper(nn.Module):
         mems = None,
         pos = None,
         prepend_embeds = None,
+        prepend_mask = None,
         sum_embeds = None,
         return_attn_z_loss = False,
         attn_z_loss_weight = 1e-4,
@@ -1578,6 +1579,12 @@ class TransformerWrapper(nn.Module):
 
             x = torch.cat((prepend_embeds, x), dim = -2)
 
+            if exists(prepend_mask) or exists(mask):
+                mask = default(mask, lambda: torch.ones((b, n), device = device, dtype = torch.bool))
+                prepend_mask = default(prepend_mask, lambda: torch.ones((b, prepend_seq), device = device, dtype = torch.bool))
+
+                mask = torch.cat((prepend_mask, mask), dim = -1)
+
         # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model
 
         if emb_frac_gradient < 1:
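
TransformerWrapper.forward gains the same prepend_mask keyword, with identical defaulting and concatenation logic. A minimal sketch under assumed hyperparameters:

import torch
from x_transformers import TransformerWrapper, Decoder

# assumed configuration, for illustration only
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

ids = torch.randint(0, 20000, (2, 100))
mask = torch.ones(2, 100).bool()
prepend = torch.randn(2, 10, 512)        # must match the model dimension, per the existing assert
prepend_mask = torch.ones(2, 10).bool()  # new keyword in this release

logits = model(ids, mask = mask, prepend_embeds = prepend, prepend_mask = prepend_mask)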
@@ -1712,11 +1719,11 @@ class XTransformer(nn.Module):
 
     def forward(self, src, tgt, mask = None, attn_mask = None, src_prepend_embeds = None):
 
+        enc = self.encoder(src, mask = mask, attn_mask = attn_mask, prepend_embeds = src_prepend_embeds, return_embeddings = True)
+
         if exists(src_prepend_embeds) and exists(mask):
             mask = pad_at_dim(mask, (src_prepend_embeds.shape[-2], 0), dim = -1, value = True)
 
-        enc = self.encoder(src, mask = mask, attn_mask = attn_mask, prepend_embeds = src_prepend_embeds, return_embeddings = True)
-
         if self.training and self.cross_attn_tokens_dropout > 0:
             enc, mask = dropout_seq(enc, mask, self.cross_attn_tokens_dropout)
 
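In XTransformer.forward the encoder call is moved ahead of the mask padding: the encoder now receives the original mask together with src_prepend_embeds (relying on the new prepend-mask handling above), and pad_at_dim extends the mask only afterwards, for the steps that follow such as the cross-attention token dropout shown in the hunk. A usage sketch with assumed hyperparameters, using the keyword names from the diff:

import torch
from x_transformers import XTransformer

# small assumed configuration, for illustration only
model = XTransformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 3,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 3,
    dec_heads = 8,
    dec_max_seq_len = 1024
)

src = torch.randint(0, 256, (1, 100))
tgt = torch.randint(0, 256, (1, 100))
src_mask = torch.ones(1, 100).bool()
src_prepend = torch.randn(1, 10, 512)    # embeddings prepended to the encoder input

loss = model(src, tgt, mask = src_mask, src_prepend_embeds = src_prepend)
loss.backward()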
--- x_transformers-1.25.10.dist-info/METADATA
+++ x_transformers-1.25.12.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.25.10
+Version: 1.25.12
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
--- x_transformers-1.25.10.dist-info/RECORD
+++ x_transformers-1.25.12.dist-info/RECORD
@@ -1,13 +1,13 @@
 x_transformers/__init__.py,sha256=pXc_U4M3ONUQcpNgZySDIlCF1rp7u4FFmcOYjc4WuXw,629
 x_transformers/attend.py,sha256=MFl_FbgPsm9mziZPTi_s8QbxASETwbGeciMH8sUIwT8,10188
 x_transformers/autoregressive_wrapper.py,sha256=2mzOq_rl_vevgrxCDncBlVJOAJGS-XGm-iBKJqMjj_c,9041
-x_transformers/continuous.py,sha256=7zo4lnYyIkIYvs_a_NCj86DUA_ZccU5ndjq-13UnEqg,5554
+x_transformers/continuous.py,sha256=s46BlvSHN7OL2Tya28dMomVF3xmpYxj5reerO6tUDoc,5933
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=SuGWhp9P-RDT3jnquILbPVeTVBUbXHB22QMFxh8YvxU,60435
+x_transformers/x_transformers.py,sha256=5T5fOdveqe4vwqGRrNZqJ053pmtw-dJ12AV0nNaWLRc,60814
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=lS9W_E_RskPQAqVZkPiUzbByoW1Ajsw_phsikA3JDAg,8139
-x_transformers-1.25.10.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.25.10.dist-info/METADATA,sha256=tIFJS0oIe-1oRS2-LfC1kd32-bFXL4ZrlvLP9cM1DNc,662
-x_transformers-1.25.10.dist-info/WHEEL,sha256=Xo9-1PvkuimrydujYJAjF7pCkriuXBpUPEjma1nZyJ0,92
-x_transformers-1.25.10.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.25.10.dist-info/RECORD,,
+x_transformers-1.25.12.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.25.12.dist-info/METADATA,sha256=0GzTD3LTld07_sgfvdFJHgNQhiE5bnPkyEzc9G12Klc,662
+x_transformers-1.25.12.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+x_transformers-1.25.12.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.25.12.dist-info/RECORD,,
--- x_transformers-1.25.10.dist-info/WHEEL
+++ x_transformers-1.25.12.dist-info/WHEEL
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: bdist_wheel (0.41.3)
+Generator: bdist_wheel (0.42.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 