x-transformers 2.11.15__py3-none-any.whl → 2.11.17__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of x-transformers might be problematic.
x_transformers/free_transformer.py

@@ -149,6 +149,7 @@ class FreeTransformer(Module):
         enc_kwargs: dict = dict(),
         dec_kwargs: dict = dict(),
         kl_loss_weight = 1.,
+        latent_dropout_prob = 0.,
         pad_id = -1,
         **kwargs
     ):
@@ -187,6 +188,8 @@ class FreeTransformer(Module):
 
         self.from_latent_to_condition = nn.Linear(self.binary_mapper.num_codes, dim, bias = False)
 
+        self.latent_dropout = nn.Dropout(latent_dropout_prob)
+
         self.decoder_head = Decoder(
             dim = dim,
             depth = dec_head_depth,
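Read together with the forward() change further down, the new latent_dropout_prob argument simply places a standard dropout layer on the encoded latent codes before they are projected into the decoder's conditioning vector; the default of 0. keeps the previous behaviour. A minimal standalone sketch of that path (the sizes below are illustrative, not the library's defaults):

    import torch
    from torch import nn

    dim, num_codes = 512, 64                            # illustrative sizes only
    latent_dropout_prob = 0.1                           # the new constructor argument

    latent_dropout = nn.Dropout(latent_dropout_prob)
    from_latent_to_condition = nn.Linear(num_codes, dim, bias = False)

    latents = torch.randn(2, num_codes)                 # stand-in for the encoder's latent Z
    latents = latent_dropout(latents)                   # mirrors `latents = self.latent_dropout(latents)`
    condition = from_latent_to_condition(latents)       # mirrors `self.from_latent_to_condition(latents)`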
@@ -225,8 +228,11 @@ class FreeTransformer(Module):
         self,
         decoder_head_embeds,
         mask = None,
-        return_kl_loss = False
+        return_kl_loss = False,
+        per_token_latents = None
     ):
+        per_token_latents = default(per_token_latents, self.per_token_latents)
+
         batch, seq_len, device = *decoder_head_embeds.shape[:2], decoder_head_embeds.device
 
         query_tokens = repeat(self.query_token_for_latents, 'd -> b 1 d', b = batch)
@@ -235,7 +241,7 @@ class FreeTransformer(Module):
 
         # handle the interesting per query token latents, as in the paper
 
-        if self.per_token_latents:
+        if per_token_latents:
            query_tokens = repeat(query_tokens, 'b 1 d -> b n d', n = seq_len)
 
            rotary_pos = torch.arange(seq_len, device = device)
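The default(per_token_latents, self.per_token_latents) call above turns per-token latents into a per-call choice: passing None falls back to whatever was configured on the module. The helper follows the usual x-transformers pattern, roughly as below (a sketch, not a verbatim copy of the library's helpers):

    def exists(val):
        return val is not None

    def default(val, d):
        # use the explicitly passed value, otherwise fall back to the default
        return val if exists(val) else d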
@@ -342,13 +348,13 @@ class FreeTransformer(Module):
     def forward(
         self,
         seq,
+        seq_for_latents = None,
         return_all_losses = False
     ):
         batch, device = seq.shape[0], seq.device
 
         seq, labels = seq[:, :-1], seq[:, 1:]
 
-        encoder_mask = seq != self.pad_id
 
         tokens = self.token_emb(seq)
 
@@ -357,9 +363,27 @@ class FreeTransformer(Module):
         if exists(self.decoder_head):
             tokens = self.decoder_head(tokens)
 
+        # determine whether to use a separate sequence for encoding latents
+
+        if exists(seq_for_latents):
+            tokens_for_latents = self.token_emb(seq_for_latents)
+
+            if exists(self.decoder_head):
+                tokens_for_latents = self.decoder_head(tokens_for_latents)
+
+            encoder_mask = seq_for_latents != self.pad_id
+            per_token_latents = False
+        else:
+
+            tokens_for_latents = tokens
+            encoder_mask = seq != self.pad_id
+            per_token_latents = None
+
         # get latent Z
 
-        latents, kl_loss = self.encode_to_latents(tokens, mask = encoder_mask, return_kl_loss = True)
+        latents, kl_loss = self.encode_to_latents(tokens_for_latents, mask = encoder_mask, per_token_latents = per_token_latents, return_kl_loss = True)
+
+        latents = self.latent_dropout(latents)
 
         condition = self.from_latent_to_condition(latents)
 
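In sum, forward() can now encode the latent Z from a sequence other than the one being decoded, in which case per-token latents are disabled and the encoder mask is built from that separate sequence. A condensed sketch of the new branching, pulled out into a hypothetical standalone function for readability (this function is not part of the library; the names mirror the diff above):

    def tokens_and_mask_for_latents(seq, tokens, token_emb, decoder_head, pad_id, seq_for_latents = None):
        # `tokens` are the embeddings of `seq`, already run through `decoder_head`
        if seq_for_latents is not None:
            tokens_for_latents = token_emb(seq_for_latents)

            if decoder_head is not None:
                tokens_for_latents = decoder_head(tokens_for_latents)

            encoder_mask = seq_for_latents != pad_id
            per_token_latents = False        # a separate latent sequence is not aligned token by token
        else:
            tokens_for_latents = tokens
            encoder_mask = seq != pad_id
            per_token_latents = None         # defer to the module-level setting via default()

        return tokens_for_latents, encoder_mask, per_token_latents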
x_transformers-{2.11.15 → 2.11.17}.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.11.15
+Version: 2.11.17
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
x_transformers-{2.11.15 → 2.11.17}.dist-info/RECORD

@@ -5,7 +5,7 @@ x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTN
 x_transformers/continuous.py,sha256=WwpQCjyVY4PtuEAOFY68zqgklbF9I7AL5w6874YlDe8,13249
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
-x_transformers/free_transformer.py,sha256=a_sF_tx2RgKNsPCum22jlYam28OWEk3B0o1D4-Vo9Fw,10714
+x_transformers/free_transformer.py,sha256=F0H_rfb_8_nO4oRbaVDLdfOa8EP4YcUNCOaI2rhkLV0,11541
 x_transformers/gpt_vae.py,sha256=1zyjwgfZr6CRDsh5VMCPSdoCPg-sdX5mXmZ_mn4VyYQ,6082
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
@@ -14,7 +14,7 @@ x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,
 x_transformers/x_transformers.py,sha256=5ctPu8tvlbUMrtW360e_LPnoGv6xcgQFsyWdbvLo6Tk,127002
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.11.15.dist-info/METADATA,sha256=DSY5ug0mmywhOVxsCxjVkIzyWNY9ot4kmUxBFresdaE,96012
-x_transformers-2.11.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.11.15.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.11.15.dist-info/RECORD,,
+x_transformers-2.11.17.dist-info/METADATA,sha256=9gqVZAutVIzE5Xs5ulYv8fZ97-M2vsCacbrWhJmkXm0,96012
+x_transformers-2.11.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.11.17.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.11.17.dist-info/RECORD,,