x-transformers 2.11.12.tar.gz → 2.11.14.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of x-transformers might be problematic.
Files changed (68)
  1. {x_transformers-2.11.12 → x_transformers-2.11.14}/PKG-INFO +1 -1
  2. {x_transformers-2.11.12 → x_transformers-2.11.14}/pyproject.toml +1 -1
  3. {x_transformers-2.11.12 → x_transformers-2.11.14}/tests/test_x_transformers.py +5 -0
  4. {x_transformers-2.11.12 → x_transformers-2.11.14}/train_free.py +3 -4
  5. {x_transformers-2.11.12 → x_transformers-2.11.14}/x_transformers/free_transformer.py +4 -0
  6. {x_transformers-2.11.12 → x_transformers-2.11.14}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.11.12 → x_transformers-2.11.14}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.11.12 → x_transformers-2.11.14}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.11.12 → x_transformers-2.11.14}/.gitignore +0 -0
  10. {x_transformers-2.11.12 → x_transformers-2.11.14}/LICENSE +0 -0
  11. {x_transformers-2.11.12 → x_transformers-2.11.14}/README.md +0 -0
  12. {x_transformers-2.11.12 → x_transformers-2.11.14}/data/README.md +0 -0
  13. {x_transformers-2.11.12 → x_transformers-2.11.14}/data/enwik8.gz +0 -0
  14. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/all-attention.png +0 -0
  15. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/deepnorm.png +0 -0
  18. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/fcm.png +0 -0
  24. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/ffglu.png +0 -0
  25. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/flash-attention.png +0 -0
  26. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/gate_values.png +0 -0
  27. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/gating.png +0 -0
  28. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/macaron-1.png +0 -0
  30. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/macaron-2.png +0 -0
  31. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/normformer.png +0 -0
  33. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/pia.png +0 -0
  34. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/resi_dual.png +0 -0
  36. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/residual_attn.png +0 -0
  37. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/rezero.png +0 -0
  38. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/rotary.png +0 -0
  39. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/sandwich.png +0 -0
  41. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/scalenorm.png +0 -0
  43. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/talking-heads.png +0 -0
  44. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/topk-attention.png +0 -0
  45. {x_transformers-2.11.12 → x_transformers-2.11.14}/images/xval.png +0 -0
  46. {x_transformers-2.11.12 → x_transformers-2.11.14}/train_belief_state.py +0 -0
  47. {x_transformers-2.11.12 → x_transformers-2.11.14}/train_copy.py +0 -0
  48. {x_transformers-2.11.12 → x_transformers-2.11.14}/train_entropy_tokenizer.py +0 -0
  49. {x_transformers-2.11.12 → x_transformers-2.11.14}/train_enwik8.py +0 -0
  50. {x_transformers-2.11.12 → x_transformers-2.11.14}/train_gpt_vae.py +0 -0
  51. {x_transformers-2.11.12 → x_transformers-2.11.14}/train_length_extrapolate.py +0 -0
  52. {x_transformers-2.11.12 → x_transformers-2.11.14}/train_parity.py +0 -0
  53. {x_transformers-2.11.12 → x_transformers-2.11.14}/train_with_muon.py +0 -0
  54. {x_transformers-2.11.12 → x_transformers-2.11.14}/x_transformers/__init__.py +0 -0
  55. {x_transformers-2.11.12 → x_transformers-2.11.14}/x_transformers/attend.py +0 -0
  56. {x_transformers-2.11.12 → x_transformers-2.11.14}/x_transformers/autoregressive_wrapper.py +0 -0
  57. {x_transformers-2.11.12 → x_transformers-2.11.14}/x_transformers/belief_state_wrapper.py +0 -0
  58. {x_transformers-2.11.12 → x_transformers-2.11.14}/x_transformers/continuous.py +0 -0
  59. {x_transformers-2.11.12 → x_transformers-2.11.14}/x_transformers/dpo.py +0 -0
  60. {x_transformers-2.11.12 → x_transformers-2.11.14}/x_transformers/entropy_based_tokenizer.py +0 -0
  61. {x_transformers-2.11.12 → x_transformers-2.11.14}/x_transformers/gpt_vae.py +0 -0
  62. {x_transformers-2.11.12 → x_transformers-2.11.14}/x_transformers/multi_input.py +0 -0
  63. {x_transformers-2.11.12 → x_transformers-2.11.14}/x_transformers/neo_mlp.py +0 -0
  64. {x_transformers-2.11.12 → x_transformers-2.11.14}/x_transformers/nonautoregressive_wrapper.py +0 -0
  65. {x_transformers-2.11.12 → x_transformers-2.11.14}/x_transformers/up_wrapper.py +0 -0
  66. {x_transformers-2.11.12 → x_transformers-2.11.14}/x_transformers/x_transformers.py +0 -0
  67. {x_transformers-2.11.12 → x_transformers-2.11.14}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  68. {x_transformers-2.11.12 → x_transformers-2.11.14}/x_transformers/xval.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.11.12
+Version: 2.11.14
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.11.12"
+version = "2.11.14"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_x_transformers.py
@@ -1437,6 +1437,11 @@ def test_free(
 
     assert aux_loss.numel() == 1
 
+    rand_indices = torch.randint(0, 2 ** 8, ())
+    generated = model.generate(seq[:, :1], 32, latents = rand_indices)
+
+    assert generated.shape == (1, 32)
+
 def test_kv_input_residual():
     attn = Decoder(
         dim = 256,
train_free.py
@@ -63,8 +63,7 @@ model = FreeTransformer(
     latent_bits = LATENT_BITS
 ).cuda()
 
-rand_index = torch.randint(0, 2 ** LATENT_BITS, ())
-latents = F.one_hot(rand_index, 2 ** LATENT_BITS).float().cuda()
+one_hot_indices = torch.randint(0, 2 ** LATENT_BITS, ())
 
 # prepare enwik8 data
 
@@ -126,9 +125,9 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
         sample = model.generate(
             prompts = inp,
             seq_len = GENERATE_LENGTH,
-            latents = latents
+            latents = one_hot_indices
         )
 
         output_str = decode_tokens(sample)
 
-        print(f'\n\nlatent {rand_index.tolist()} - ', output_str)
+        print(f'\n\nlatent {one_hot_indices.tolist()} - ', output_str)
x_transformers/free_transformer.py
@@ -282,6 +282,10 @@ class FreeTransformer(Module):
         if not is_tensor(latents):
             latents = tensor(latents, device = self.device)
 
+        if latents.dtype in (torch.int, torch.long):
+            # if given as indices
+            latents = F.one_hot(latents, self.binary_mapper.num_codes).float()
+
         if latents.ndim == 1: # repeat latents
             latents = repeat(latents, 'd -> b 1 d', b = batch)
         elif latents.ndim == 2:
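
The functional change in this release: FreeTransformer now accepts latents given as plain integer indices and one-hot encodes them internally against binary_mapper.num_codes, so callers no longer build the one-hot vector themselves (as train_free.py previously did). Below is a minimal sketch of the new calling convention; the generate call and the latent_bits argument come from the diff, while the other constructor arguments (num_tokens, dim, depth) are illustrative assumptions, not values taken from this release.

import torch
from x_transformers.free_transformer import FreeTransformer

LATENT_BITS = 8

# constructor args other than `latent_bits` are assumed for illustration
model = FreeTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 4,
    latent_bits = LATENT_BITS
)

prompt = torch.randint(0, 256, (1, 1))

# before this release, the caller built the one-hot latent by hand, e.g.
#   latents = F.one_hot(torch.randint(0, 2 ** LATENT_BITS, ()), 2 ** LATENT_BITS).float()

# after: a plain integer index is enough; the model one-hot encodes it internally
latents = torch.randint(0, 2 ** LATENT_BITS, ())  # 0-dim long tensor, hits the new dtype branch

generated = model.generate(prompt, 32, latents = latents)
assert generated.shape == (1, 32)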