x-transformers 2.11.15__tar.gz → 2.11.16__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (68)
  1. {x_transformers-2.11.15 → x_transformers-2.11.16}/PKG-INFO +1 -1
  2. {x_transformers-2.11.15 → x_transformers-2.11.16}/pyproject.toml +1 -1
  3. {x_transformers-2.11.15 → x_transformers-2.11.16}/tests/test_x_transformers.py +6 -2
  4. {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/free_transformer.py +23 -4
  5. {x_transformers-2.11.15 → x_transformers-2.11.16}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.11.15 → x_transformers-2.11.16}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.11.15 → x_transformers-2.11.16}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.11.15 → x_transformers-2.11.16}/.gitignore +0 -0
  9. {x_transformers-2.11.15 → x_transformers-2.11.16}/LICENSE +0 -0
  10. {x_transformers-2.11.15 → x_transformers-2.11.16}/README.md +0 -0
  11. {x_transformers-2.11.15 → x_transformers-2.11.16}/data/README.md +0 -0
  12. {x_transformers-2.11.15 → x_transformers-2.11.16}/data/enwik8.gz +0 -0
  13. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/all-attention.png +0 -0
  14. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/deepnorm.png +0 -0
  17. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/fcm.png +0 -0
  23. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/ffglu.png +0 -0
  24. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/flash-attention.png +0 -0
  25. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/gate_values.png +0 -0
  26. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/gating.png +0 -0
  27. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/macaron-1.png +0 -0
  29. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/macaron-2.png +0 -0
  30. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/normformer.png +0 -0
  32. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/pia.png +0 -0
  33. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/resi_dual.png +0 -0
  35. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/residual_attn.png +0 -0
  36. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/rezero.png +0 -0
  37. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/rotary.png +0 -0
  38. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/sandwich.png +0 -0
  40. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/scalenorm.png +0 -0
  42. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/talking-heads.png +0 -0
  43. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/topk-attention.png +0 -0
  44. {x_transformers-2.11.15 → x_transformers-2.11.16}/images/xval.png +0 -0
  45. {x_transformers-2.11.15 → x_transformers-2.11.16}/train_belief_state.py +0 -0
  46. {x_transformers-2.11.15 → x_transformers-2.11.16}/train_copy.py +0 -0
  47. {x_transformers-2.11.15 → x_transformers-2.11.16}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.11.15 → x_transformers-2.11.16}/train_enwik8.py +0 -0
  49. {x_transformers-2.11.15 → x_transformers-2.11.16}/train_free.py +0 -0
  50. {x_transformers-2.11.15 → x_transformers-2.11.16}/train_gpt_vae.py +0 -0
  51. {x_transformers-2.11.15 → x_transformers-2.11.16}/train_length_extrapolate.py +0 -0
  52. {x_transformers-2.11.15 → x_transformers-2.11.16}/train_parity.py +0 -0
  53. {x_transformers-2.11.15 → x_transformers-2.11.16}/train_with_muon.py +0 -0
  54. {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/__init__.py +0 -0
  55. {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/attend.py +0 -0
  56. {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/autoregressive_wrapper.py +0 -0
  57. {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/belief_state_wrapper.py +0 -0
  58. {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/continuous.py +0 -0
  59. {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/dpo.py +0 -0
  60. {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/entropy_based_tokenizer.py +0 -0
  61. {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/gpt_vae.py +0 -0
  62. {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/multi_input.py +0 -0
  63. {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/neo_mlp.py +0 -0
  64. {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/nonautoregressive_wrapper.py +0 -0
  65. {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/up_wrapper.py +0 -0
  66. {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/x_transformers.py +0 -0
  67. {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  68. {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/xval.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.11.15
+Version: 2.11.16
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.11.15"
+version = "2.11.16"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_x_transformers.py
@@ -1411,9 +1411,11 @@ def test_attn_negative_weights(

 @param('per_token_latents', (False, True))
 @param('dec_head_depth', (0, 4))
+@param('separate_seq_for_latents', (False, True))
 def test_free(
     dec_head_depth,
-    per_token_latents
+    per_token_latents,
+    separate_seq_for_latents
 ):
     from x_transformers.free_transformer import FreeTransformer

@@ -1432,7 +1434,9 @@ def test_free(

     seq = torch.randint(0, 256, (1, 1024))

-    loss, (ar_loss, aux_loss) = model(seq, return_all_losses = True)
+    separate_seq_for_latents = torch.randint(0, 256, (1, 32)) if separate_seq_for_latents else None
+
+    loss, (ar_loss, aux_loss) = model(seq, separate_seq_for_latents, return_all_losses = True)
     loss.backward()

     assert aux_loss.numel() == 1
x_transformers/free_transformer.py
@@ -225,8 +225,11 @@ class FreeTransformer(Module):
         self,
         decoder_head_embeds,
         mask = None,
-        return_kl_loss = False
+        return_kl_loss = False,
+        per_token_latents = None
     ):
+        per_token_latents = default(per_token_latents, self.per_token_latents)
+
         batch, seq_len, device = *decoder_head_embeds.shape[:2], decoder_head_embeds.device

         query_tokens = repeat(self.query_token_for_latents, 'd -> b 1 d', b = batch)
@@ -235,7 +238,7 @@ class FreeTransformer(Module):

         # handle the interesting per query token latents, as in the paper

-        if self.per_token_latents:
+        if per_token_latents:
             query_tokens = repeat(query_tokens, 'b 1 d -> b n d', n = seq_len)

             rotary_pos = torch.arange(seq_len, device = device)
@@ -342,13 +345,13 @@ class FreeTransformer(Module):
     def forward(
         self,
         seq,
+        seq_for_latents = None,
         return_all_losses = False
     ):
         batch, device = seq.shape[0], seq.device

         seq, labels = seq[:, :-1], seq[:, 1:]

-        encoder_mask = seq != self.pad_id

         tokens = self.token_emb(seq)

@@ -357,9 +360,25 @@ class FreeTransformer(Module):
         if exists(self.decoder_head):
             tokens = self.decoder_head(tokens)

+        # determine whether to use a separate sequence for encoding latents
+
+        if exists(seq_for_latents):
+            tokens_for_latents = self.token_emb(seq_for_latents)
+
+            if exists(self.decoder_head):
+                tokens_for_latents = self.decoder_head(tokens_for_latents)
+
+            encoder_mask = seq_for_latents != self.pad_id
+            per_token_latents = False
+        else:
+
+            tokens_for_latents = tokens
+            encoder_mask = seq != self.pad_id
+            per_token_latents = None
+
         # get latent Z

-        latents, kl_loss = self.encode_to_latents(tokens, mask = encoder_mask, return_kl_loss = True)
+        latents, kl_loss = self.encode_to_latents(tokens_for_latents, mask = encoder_mask, per_token_latents = per_token_latents, return_kl_loss = True)

         condition = self.from_latent_to_condition(latents)
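
Taken together, these changes let FreeTransformer.forward accept an optional second sequence that is used only for encoding the latent Z. Below is a minimal usage sketch mirroring the updated test above; `model` stands for an already constructed FreeTransformer (its constructor arguments are elided in this diff), and the tensor shapes are the ones the test uses.

import torch

# `model` is assumed to be a FreeTransformer instance, set up as in test_free
# (the constructor call is not shown in this diff)

seq = torch.randint(0, 256, (1, 1024))            # sequence for the autoregressive loss
seq_for_latents = torch.randint(0, 256, (1, 32))  # optional separate sequence to encode the latent Z from

# new in 2.11.16: a second positional argument routes a different sequence through the
# latent encoder (per-token latents are disabled in that branch); passing None, or
# omitting the argument, keeps the previous behaviour of encoding latents from `seq` itself
loss, (ar_loss, aux_loss) = model(seq, seq_for_latents, return_all_losses = True)
loss.backward()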