x-transformers 2.11.15__py3-none-any.whl → 2.11.16__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of x-transformers might be problematic.

x_transformers/free_transformer.py

@@ -225,8 +225,11 @@ class FreeTransformer(Module):
         self,
         decoder_head_embeds,
         mask = None,
-        return_kl_loss = False
+        return_kl_loss = False,
+        per_token_latents = None
     ):
+        per_token_latents = default(per_token_latents, self.per_token_latents)
+
         batch, seq_len, device = *decoder_head_embeds.shape[:2], decoder_head_embeds.device

         query_tokens = repeat(self.query_token_for_latents, 'd -> b 1 d', b = batch)
@@ -235,7 +238,7 @@ class FreeTransformer(Module):

         # handle the interesting per query token latents, as in the paper

-        if self.per_token_latents:
+        if per_token_latents:
             query_tokens = repeat(query_tokens, 'b 1 d -> b n d', n = seq_len)

             rotary_pos = torch.arange(seq_len, device = device)
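
In this release, encode_to_latents gains an optional per_token_latents argument. Left as None, default(...) falls back to the instance-level self.per_token_latents; passing an explicit value overrides the configured behavior for a single call. A minimal sketch of the new call shape, assuming an already constructed FreeTransformer instance named model and decoder-head embeddings named embeds of shape (batch, seq_len, dim); both names are illustrative and not part of the diff:

    # sketch only: `model` is an existing FreeTransformer, `embeds` stands in for
    # the decoder head embeddings normally produced inside forward()
    latents, kl_loss = model.encode_to_latents(
        embeds,
        mask = None,
        return_kl_loss = True,
        per_token_latents = False  # overrides model.per_token_latents for this call
    )
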
@@ -342,13 +345,13 @@ class FreeTransformer(Module):
     def forward(
         self,
         seq,
+        seq_for_latents = None,
         return_all_losses = False
     ):
         batch, device = seq.shape[0], seq.device

         seq, labels = seq[:, :-1], seq[:, 1:]

-        encoder_mask = seq != self.pad_id

         tokens = self.token_emb(seq)

@@ -357,9 +360,25 @@ class FreeTransformer(Module):
         if exists(self.decoder_head):
             tokens = self.decoder_head(tokens)

+        # determine whether to use a separate sequence for encoding latents
+
+        if exists(seq_for_latents):
+            tokens_for_latents = self.token_emb(seq_for_latents)
+
+            if exists(self.decoder_head):
+                tokens_for_latents = self.decoder_head(tokens_for_latents)
+
+            encoder_mask = seq_for_latents != self.pad_id
+            per_token_latents = False
+        else:
+
+            tokens_for_latents = tokens
+            encoder_mask = seq != self.pad_id
+            per_token_latents = None
+
         # get latent Z

-        latents, kl_loss = self.encode_to_latents(tokens, mask = encoder_mask, return_kl_loss = True)
+        latents, kl_loss = self.encode_to_latents(tokens_for_latents, mask = encoder_mask, per_token_latents = per_token_latents, return_kl_loss = True)

         condition = self.from_latent_to_condition(latents)

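Taken together, forward can now encode the latent Z from a different sequence than the one being predicted: when seq_for_latents is passed, it is embedded (and run through the decoder head if one exists), its padding defines the encoder mask, and per-token latents are disabled for that call; otherwise behavior is unchanged. A minimal usage sketch, assuming model is a constructed FreeTransformer with pad_id 0, a 256-token vocabulary, and a forward that returns a scalar training loss; all of these are assumptions for illustration, not shown in this diff:

    import torch

    # sketch only: token ids for the sequence being modeled and a separate,
    # possibly shorter sequence from which the latents are encoded
    seq             = torch.randint(1, 256, (2, 128))
    seq_for_latents = torch.randint(1, 256, (2, 64))

    loss = model(seq, seq_for_latents = seq_for_latents)  # assumed to return a loss
    loss.backward()
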
x_transformers-2.11.15.dist-info/METADATA → x_transformers-2.11.16.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.11.15
+Version: 2.11.16
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
x_transformers-2.11.15.dist-info/RECORD → x_transformers-2.11.16.dist-info/RECORD

@@ -5,7 +5,7 @@ x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTN
 x_transformers/continuous.py,sha256=WwpQCjyVY4PtuEAOFY68zqgklbF9I7AL5w6874YlDe8,13249
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
-x_transformers/free_transformer.py,sha256=a_sF_tx2RgKNsPCum22jlYam28OWEk3B0o1D4-Vo9Fw,10714
+x_transformers/free_transformer.py,sha256=_hYYkaro3xei3MC3rwtuCWi9gSnciXyAT91_7SrA0nw,11396
 x_transformers/gpt_vae.py,sha256=1zyjwgfZr6CRDsh5VMCPSdoCPg-sdX5mXmZ_mn4VyYQ,6082
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
@@ -14,7 +14,7 @@ x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,
 x_transformers/x_transformers.py,sha256=5ctPu8tvlbUMrtW360e_LPnoGv6xcgQFsyWdbvLo6Tk,127002
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.11.15.dist-info/METADATA,sha256=DSY5ug0mmywhOVxsCxjVkIzyWNY9ot4kmUxBFresdaE,96012
-x_transformers-2.11.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.11.15.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.11.15.dist-info/RECORD,,
+x_transformers-2.11.16.dist-info/METADATA,sha256=cvhm5LnIRCdqLuv25iSU4vj0a6Np9j2lv2O9W-V48-k,96012
+x_transformers-2.11.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.11.16.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.11.16.dist-info/RECORD,,