x-transformers 1.26.1__py3-none-any.whl → 1.26.3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
@@ -238,15 +238,19 @@ class AutoregressiveWrapper(Module):
 
             out = torch.cat((out, sample), dim=-1)
 
-            if exists(eos_token):
-                is_eos_tokens = (out == eos_token)
-
-                if is_eos_tokens.any(dim = -1).all():
-                    # mask out everything after the eos tokens
-                    shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
-                    mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
-                    out = out.masked_fill(mask, self.pad_value)
-                    break
+            if not exists(eos_token):
+                continue
+
+            is_eos_tokens = (out == eos_token)
+
+            if is_eos_tokens.any(dim = -1).all():
+                break
+
+        if exists(eos_token):
+            # mask out everything after the eos tokens
+            shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
+            mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
+            out = out.masked_fill(mask, self.pad_value)
 
         out = out[:, t:]
 
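The hunk above (in x_transformers/autoregressive_wrapper.py, per the RECORD change below) restructures the sampling loop in AutoregressiveWrapper.generate: the loop now skips the EOS check via continue when no eos_token is set, breaks as soon as every sequence contains an EOS, and applies the "mask everything after EOS" step once after the loop instead of only inside the all-sequences-finished branch. A minimal, self-contained sketch of that masking step on a toy batch follows; eos_token = 2 and pad_value = 0 are illustrative values, not values taken from the package.

import torch
import torch.nn.functional as F

# Toy batch of generated token ids (illustrative values only)
out = torch.tensor([[5, 7, 2, 9, 4],
                    [6, 2, 8, 2, 3]])
eos_token = 2
pad_value = 0

is_eos_tokens = (out == eos_token)

# Stop condition from the new loop: every sequence contains at least one EOS
assert is_eos_tokens.any(dim=-1).all()

# Masking step, now applied once after the loop: shift the EOS mask right by
# one position so the EOS token itself survives, then pad everything after it
shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1
out = out.masked_fill(mask, pad_value)

print(out)
# tensor([[5, 7, 2, 0, 0],
#         [6, 2, 0, 0, 0]])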
@@ -1559,7 +1559,7 @@ class TransformerWrapper(nn.Module):
         pos = None,
         prepend_embeds = None,
         prepend_mask = None,
-        embed_ids: Dict[str, Tensor] = None,
+        embed_ids: Dict[str, Tensor] = dict(),
         sum_embeds = None,
         return_attn_z_loss = False,
         attn_z_loss_weight = 1e-4,
@@ -1578,7 +1578,9 @@ class TransformerWrapper(nn.Module):
 
         # add additional embeddings
 
-        if exists(self.embeds) and exists(embed_ids):
+        if exists(self.embeds):
+            assert len(embed_ids) == len(self.embeds)
+
             for name, embed_id in embed_ids.items():
                 embed_key = f'{name}_embed'
 
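The two hunks above (in x_transformers/x_transformers.py) tighten the embed_ids contract of TransformerWrapper.forward: the argument now defaults to an empty dict rather than None, and when extra embeddings are registered the forward pass asserts that ids are supplied for every one of them instead of silently skipping the block. Below is a standalone sketch of that contract; the embedding names and sizes are made up for illustration, and only the f'{name}_embed' key pattern and the length assert come from the diff.

import torch
from torch import nn

# Hypothetical extra embeddings keyed by the f'{name}_embed' pattern from the hunk
embeds = nn.ModuleDict({
    'style_embed': nn.Embedding(4, 512),
    'speaker_embed': nn.Embedding(10, 512),
})

batch, seq_len = 2, 128
embed_ids = {
    'style': torch.randint(0, 4, (batch, seq_len)),
    'speaker': torch.randint(0, 10, (batch, seq_len)),
}

# New behavior in 1.26.3: ids must be passed for every registered embedding
# (previously a missing embed_ids just skipped this block entirely)
assert len(embed_ids) == len(embeds)

for name, embed_id in embed_ids.items():
    embed_key = f'{name}_embed'
    token_embed = embeds[embed_key](embed_id)  # shape (batch, seq_len, 512)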
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.26.1
+Version: 1.26.3
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -1,13 +1,13 @@
 x_transformers/__init__.py,sha256=pXc_U4M3ONUQcpNgZySDIlCF1rp7u4FFmcOYjc4WuXw,629
 x_transformers/attend.py,sha256=MFl_FbgPsm9mziZPTi_s8QbxASETwbGeciMH8sUIwT8,10188
-x_transformers/autoregressive_wrapper.py,sha256=2mzOq_rl_vevgrxCDncBlVJOAJGS-XGm-iBKJqMjj_c,9041
+x_transformers/autoregressive_wrapper.py,sha256=6O4fz0keP2EBaRssUda7I5ZFxY-LFCzM-m--NtQp9rw,9058
 x_transformers/continuous.py,sha256=ixfgi2_zpGN03SX_STXFkNYEOAkgwVIxuS53QgDCx-g,6026
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=k-_pU9lICUYnCumwimPv3VaaxjpKOeFKDdKEWgnvslc,61597
+x_transformers/x_transformers.py,sha256=gXq_IpWswCeM06r0VCxvZGTVndsxD69_uvKSJFC1544,61632
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=lS9W_E_RskPQAqVZkPiUzbByoW1Ajsw_phsikA3JDAg,8139
-x_transformers-1.26.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.26.1.dist-info/METADATA,sha256=mJsZmH-mfNGzd7eGYglFH6GlNzGUTeugQHLcH8oDecU,661
-x_transformers-1.26.1.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
-x_transformers-1.26.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.26.1.dist-info/RECORD,,
+x_transformers-1.26.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.26.3.dist-info/METADATA,sha256=K6TZ-TdM0z8jBSMvLMJjBI7nVTtp7rKHW-5wddYYHYQ,661
+x_transformers-1.26.3.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+x_transformers-1.26.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.26.3.dist-info/RECORD,,