x-transformers 1.26.2__py3-none-any.whl → 1.26.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -145,7 +145,7 @@ class AutoregressiveWrapper(Module):
145
145
  cache_kv = True,
146
146
  **kwargs
147
147
  ):
148
- max_seq_len, device = self.max_seq_len, prompts.device
148
+ max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
149
149
 
150
150
  prompts, ps = pack([prompts], '* n')
151
151
 
@@ -230,23 +230,30 @@ class AutoregressiveWrapper(Module):
230
230
 
231
231
  # filter by top_k, top_p (nucleus), top_a, or custom
232
232
 
233
- filtered_logits = filter_logits_fn(logits, **filter_kwargs)
233
+ if greedy:
234
+ sample = logits.argmax(dim = -1, keepdim = True)
235
+ else:
236
+ filtered_logits = filter_logits_fn(logits, **filter_kwargs)
237
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
238
+ sample = torch.multinomial(probs, 1)
234
239
 
235
- probs = F.softmax(filtered_logits / temperature, dim=-1)
236
-
237
- sample = torch.multinomial(probs, 1)
240
+ # concat sample
238
241
 
239
242
  out = torch.cat((out, sample), dim=-1)
240
243
 
241
- if exists(eos_token):
242
- is_eos_tokens = (out == eos_token)
244
+ if not exists(eos_token):
245
+ continue
246
+
247
+ is_eos_tokens = (out == eos_token)
248
+
249
+ if is_eos_tokens.any(dim = -1).all():
250
+ break
243
251
 
244
- if is_eos_tokens.any(dim = -1).all():
245
- # mask out everything after the eos tokens
246
- shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
247
- mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
248
- out = out.masked_fill(mask, self.pad_value)
249
- break
252
+ if exists(eos_token):
253
+ # mask out everything after the eos tokens
254
+ shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
255
+ mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
256
+ out = out.masked_fill(mask, self.pad_value)
250
257
 
251
258
  out = out[:, t:]
252
259
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.26.2
3
+ Version: 1.26.4
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -1,13 +1,13 @@
1
1
  x_transformers/__init__.py,sha256=pXc_U4M3ONUQcpNgZySDIlCF1rp7u4FFmcOYjc4WuXw,629
2
2
  x_transformers/attend.py,sha256=MFl_FbgPsm9mziZPTi_s8QbxASETwbGeciMH8sUIwT8,10188
3
- x_transformers/autoregressive_wrapper.py,sha256=2mzOq_rl_vevgrxCDncBlVJOAJGS-XGm-iBKJqMjj_c,9041
3
+ x_transformers/autoregressive_wrapper.py,sha256=xqr-CnB98N89y7K0z2SLEy8If8pE_3MP7C9WFTwjefs,9230
4
4
  x_transformers/continuous.py,sha256=ixfgi2_zpGN03SX_STXFkNYEOAkgwVIxuS53QgDCx-g,6026
5
5
  x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
6
6
  x_transformers/x_transformers.py,sha256=gXq_IpWswCeM06r0VCxvZGTVndsxD69_uvKSJFC1544,61632
7
7
  x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
8
8
  x_transformers/xval.py,sha256=lS9W_E_RskPQAqVZkPiUzbByoW1Ajsw_phsikA3JDAg,8139
9
- x_transformers-1.26.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
10
- x_transformers-1.26.2.dist-info/METADATA,sha256=fWgPsKNzN3QgsdPNc-fgtU8UO6pEJTZa1FCFZ4hlbVY,661
11
- x_transformers-1.26.2.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
12
- x_transformers-1.26.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
13
- x_transformers-1.26.2.dist-info/RECORD,,
9
+ x_transformers-1.26.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
10
+ x_transformers-1.26.4.dist-info/METADATA,sha256=fX_dEQcxjW00biyG6pOz9J88HGiS9qVS05KP1Z94H44,661
11
+ x_transformers-1.26.4.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
12
+ x_transformers-1.26.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
13
+ x_transformers-1.26.4.dist-info/RECORD,,