x-transformers 2.11.15.tar.gz → 2.11.17.tar.gz

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.


Files changed (68)
  1. {x_transformers-2.11.15 → x_transformers-2.11.17}/PKG-INFO +1 -1
  2. {x_transformers-2.11.15 → x_transformers-2.11.17}/pyproject.toml +1 -1
  3. {x_transformers-2.11.15 → x_transformers-2.11.17}/tests/test_x_transformers.py +6 -2
  4. {x_transformers-2.11.15 → x_transformers-2.11.17}/x_transformers/free_transformer.py +28 -4
  5. {x_transformers-2.11.15 → x_transformers-2.11.17}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.11.15 → x_transformers-2.11.17}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.11.15 → x_transformers-2.11.17}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.11.15 → x_transformers-2.11.17}/.gitignore +0 -0
  9. {x_transformers-2.11.15 → x_transformers-2.11.17}/LICENSE +0 -0
  10. {x_transformers-2.11.15 → x_transformers-2.11.17}/README.md +0 -0
  11. {x_transformers-2.11.15 → x_transformers-2.11.17}/data/README.md +0 -0
  12. {x_transformers-2.11.15 → x_transformers-2.11.17}/data/enwik8.gz +0 -0
  13. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/all-attention.png +0 -0
  14. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/deepnorm.png +0 -0
  17. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/fcm.png +0 -0
  23. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/ffglu.png +0 -0
  24. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/flash-attention.png +0 -0
  25. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/gate_values.png +0 -0
  26. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/gating.png +0 -0
  27. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/macaron-1.png +0 -0
  29. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/macaron-2.png +0 -0
  30. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/normformer.png +0 -0
  32. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/pia.png +0 -0
  33. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/resi_dual.png +0 -0
  35. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/residual_attn.png +0 -0
  36. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/rezero.png +0 -0
  37. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/rotary.png +0 -0
  38. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/sandwich.png +0 -0
  40. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/scalenorm.png +0 -0
  42. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/talking-heads.png +0 -0
  43. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/topk-attention.png +0 -0
  44. {x_transformers-2.11.15 → x_transformers-2.11.17}/images/xval.png +0 -0
  45. {x_transformers-2.11.15 → x_transformers-2.11.17}/train_belief_state.py +0 -0
  46. {x_transformers-2.11.15 → x_transformers-2.11.17}/train_copy.py +0 -0
  47. {x_transformers-2.11.15 → x_transformers-2.11.17}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.11.15 → x_transformers-2.11.17}/train_enwik8.py +0 -0
  49. {x_transformers-2.11.15 → x_transformers-2.11.17}/train_free.py +0 -0
  50. {x_transformers-2.11.15 → x_transformers-2.11.17}/train_gpt_vae.py +0 -0
  51. {x_transformers-2.11.15 → x_transformers-2.11.17}/train_length_extrapolate.py +0 -0
  52. {x_transformers-2.11.15 → x_transformers-2.11.17}/train_parity.py +0 -0
  53. {x_transformers-2.11.15 → x_transformers-2.11.17}/train_with_muon.py +0 -0
  54. {x_transformers-2.11.15 → x_transformers-2.11.17}/x_transformers/__init__.py +0 -0
  55. {x_transformers-2.11.15 → x_transformers-2.11.17}/x_transformers/attend.py +0 -0
  56. {x_transformers-2.11.15 → x_transformers-2.11.17}/x_transformers/autoregressive_wrapper.py +0 -0
  57. {x_transformers-2.11.15 → x_transformers-2.11.17}/x_transformers/belief_state_wrapper.py +0 -0
  58. {x_transformers-2.11.15 → x_transformers-2.11.17}/x_transformers/continuous.py +0 -0
  59. {x_transformers-2.11.15 → x_transformers-2.11.17}/x_transformers/dpo.py +0 -0
  60. {x_transformers-2.11.15 → x_transformers-2.11.17}/x_transformers/entropy_based_tokenizer.py +0 -0
  61. {x_transformers-2.11.15 → x_transformers-2.11.17}/x_transformers/gpt_vae.py +0 -0
  62. {x_transformers-2.11.15 → x_transformers-2.11.17}/x_transformers/multi_input.py +0 -0
  63. {x_transformers-2.11.15 → x_transformers-2.11.17}/x_transformers/neo_mlp.py +0 -0
  64. {x_transformers-2.11.15 → x_transformers-2.11.17}/x_transformers/nonautoregressive_wrapper.py +0 -0
  65. {x_transformers-2.11.15 → x_transformers-2.11.17}/x_transformers/up_wrapper.py +0 -0
  66. {x_transformers-2.11.15 → x_transformers-2.11.17}/x_transformers/x_transformers.py +0 -0
  67. {x_transformers-2.11.15 → x_transformers-2.11.17}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  68. {x_transformers-2.11.15 → x_transformers-2.11.17}/x_transformers/xval.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.11.15
+Version: 2.11.17
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.11.15"
+version = "2.11.17"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_x_transformers.py
@@ -1411,9 +1411,11 @@ def test_attn_negative_weights(

 @param('per_token_latents', (False, True))
 @param('dec_head_depth', (0, 4))
+@param('separate_seq_for_latents', (False, True))
 def test_free(
     dec_head_depth,
-    per_token_latents
+    per_token_latents,
+    separate_seq_for_latents
 ):
     from x_transformers.free_transformer import FreeTransformer

@@ -1432,7 +1434,9 @@ def test_free(

     seq = torch.randint(0, 256, (1, 1024))

-    loss, (ar_loss, aux_loss) = model(seq, return_all_losses = True)
+    separate_seq_for_latents = torch.randint(0, 256, (1, 32)) if separate_seq_for_latents else None
+
+    loss, (ar_loss, aux_loss) = model(seq, separate_seq_for_latents, return_all_losses = True)
     loss.backward()

     assert aux_loss.numel() == 1
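In this test module, `@param` appears to be the repository's alias for `pytest.mark.parametrize` (treat that alias as an assumption from context). The sketch below shows only the argument plumbing the new flag toggles, with the model call stubbed out, so it runs on its own:

```python
# minimal sketch, assuming `param` aliases pytest.mark.parametrize as it appears
# to in this test module; the model call is stubbed out since only the argument
# plumbing is being illustrated here
import pytest
import torch

param = pytest.mark.parametrize

@param('separate_seq_for_latents', (False, True))
def test_separate_latent_seq_plumbing(separate_seq_for_latents):
    seq = torch.randint(0, 256, (1, 1024))

    # mirrors the diff: either None (latents are inferred from `seq` itself)
    # or a shorter, independent sequence used only for encoding the latents
    seq_for_latents = torch.randint(0, 256, (1, 32)) if separate_seq_for_latents else None

    # with the real model this would be:
    #   loss, (ar_loss, aux_loss) = model(seq, seq_for_latents, return_all_losses = True)
    assert seq_for_latents is None or seq_for_latents.shape[-1] < seq.shape[-1]
```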
x_transformers/free_transformer.py
@@ -149,6 +149,7 @@ class FreeTransformer(Module):
         enc_kwargs: dict = dict(),
         dec_kwargs: dict = dict(),
         kl_loss_weight = 1.,
+        latent_dropout_prob = 0.,
         pad_id = -1,
         **kwargs
     ):
@@ -187,6 +188,8 @@ class FreeTransformer(Module):

         self.from_latent_to_condition = nn.Linear(self.binary_mapper.num_codes, dim, bias = False)

+        self.latent_dropout = nn.Dropout(latent_dropout_prob)
+
         self.decoder_head = Decoder(
             dim = dim,
             depth = dec_head_depth,
@@ -225,8 +228,11 @@ class FreeTransformer(Module):
         self,
         decoder_head_embeds,
         mask = None,
-        return_kl_loss = False
+        return_kl_loss = False,
+        per_token_latents = None
     ):
+        per_token_latents = default(per_token_latents, self.per_token_latents)
+
         batch, seq_len, device = *decoder_head_embeds.shape[:2], decoder_head_embeds.device

         query_tokens = repeat(self.query_token_for_latents, 'd -> b 1 d', b = batch)
@@ -235,7 +241,7 @@ class FreeTransformer(Module):

         # handle the interesting per query token latents, as in the paper

-        if self.per_token_latents:
+        if per_token_latents:
             query_tokens = repeat(query_tokens, 'b 1 d -> b n d', n = seq_len)

             rotary_pos = torch.arange(seq_len, device = device)
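The new `per_token_latents` argument to `encode_to_latents` uses the `default(value, fallback)` idiom: `None` defers to the module-level setting, while an explicit `False` (which `forward` now passes when a separate latent sequence is supplied, as seen in the next hunks) overrides it. A minimal, self-contained sketch of that idiom, with the helpers reimplemented so the snippet runs on its own:

```python
# minimal sketch of the override idiom; `exists` / `default` are reimplemented here
# (value is used if not None, otherwise the fallback) so the snippet is standalone

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

module_setting = True                      # stands in for self.per_token_latents

print(default(None, module_setting))       # None  -> defer to module setting -> True
print(default(False, module_setting))      # False -> explicit override       -> False
print(default(True, module_setting))       # True  -> explicit override       -> True
```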
@@ -342,13 +348,13 @@ class FreeTransformer(Module):
     def forward(
         self,
         seq,
+        seq_for_latents = None,
         return_all_losses = False
     ):
         batch, device = seq.shape[0], seq.device

         seq, labels = seq[:, :-1], seq[:, 1:]

-        encoder_mask = seq != self.pad_id

         tokens = self.token_emb(seq)

@@ -357,9 +363,27 @@
         if exists(self.decoder_head):
             tokens = self.decoder_head(tokens)

+        # determine whether to use a separate sequence for encoding latents
+
+        if exists(seq_for_latents):
+            tokens_for_latents = self.token_emb(seq_for_latents)
+
+            if exists(self.decoder_head):
+                tokens_for_latents = self.decoder_head(tokens_for_latents)
+
+            encoder_mask = seq_for_latents != self.pad_id
+            per_token_latents = False
+        else:
+
+            tokens_for_latents = tokens
+            encoder_mask = seq != self.pad_id
+            per_token_latents = None
+
         # get latent Z

-        latents, kl_loss = self.encode_to_latents(tokens, mask = encoder_mask, return_kl_loss = True)
+        latents, kl_loss = self.encode_to_latents(tokens_for_latents, mask = encoder_mask, per_token_latents = per_token_latents, return_kl_loss = True)
+
+        latents = self.latent_dropout(latents)

         condition = self.from_latent_to_condition(latents)
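Taken together, the `forward` changes let the latents be encoded from a different sequence than the one being decoded (with per-token latents forced off in that case) and apply an optional dropout to the latents before they condition the decoder. The sketch below mirrors that flow with stand-in modules; `token_emb`, `decoder_head`, `encode_to_latents`, and the shapes are placeholders, not the library's actual components or defaults:

```python
# standalone sketch with stand-in modules; shapes and hyperparameters are
# illustrative and not taken from the library

import torch
from torch import nn

def exists(v):
    return v is not None

pad_id = -1
token_emb = nn.Embedding(256, 64)
decoder_head = nn.Identity()            # stands in for the optional small Decoder head
encode_to_latents = nn.Linear(64, 32)   # stands in for the real encoder + binary mapper
latent_dropout = nn.Dropout(0.1)        # new: built from latent_dropout_prob

def get_latents(seq, seq_for_latents = None):
    tokens = decoder_head(token_emb(seq))

    if exists(seq_for_latents):
        # latents come from a separate sequence; in the real model, per-token
        # latents are forced off here, since they could not line up one-to-one
        # with the sequence being decoded
        tokens_for_latents = decoder_head(token_emb(seq_for_latents))
        encoder_mask = seq_for_latents != pad_id
    else:
        tokens_for_latents = tokens
        encoder_mask = seq != pad_id

    latents = encode_to_latents(tokens_for_latents)                       # (batch, seq, latent_dim)
    latents = latents.masked_fill(~encoder_mask[..., None], 0.).mean(dim = 1)  # pooled placeholder
    return latent_dropout(latents)

seq = torch.randint(0, 256, (1, 128))
seq_for_latents = torch.randint(0, 256, (1, 32))

print(get_latents(seq).shape)                    # latents from the decoded sequence itself
print(get_latents(seq, seq_for_latents).shape)   # latents from a separate, shorter sequence
```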