x-transformers 2.11.11__tar.gz → 2.11.14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of x-transformers has been flagged as possibly problematic.

Files changed (68)
  1. {x_transformers-2.11.11 → x_transformers-2.11.14}/PKG-INFO +1 -1
  2. {x_transformers-2.11.11 → x_transformers-2.11.14}/pyproject.toml +1 -1
  3. {x_transformers-2.11.11 → x_transformers-2.11.14}/tests/test_x_transformers.py +5 -0
  4. {x_transformers-2.11.11 → x_transformers-2.11.14}/train_free.py +3 -4
  5. {x_transformers-2.11.11 → x_transformers-2.11.14}/x_transformers/free_transformer.py +4 -0
  6. {x_transformers-2.11.11 → x_transformers-2.11.14}/x_transformers/x_transformers.py +4 -1
  7. {x_transformers-2.11.11 → x_transformers-2.11.14}/.github/FUNDING.yml +0 -0
  8. {x_transformers-2.11.11 → x_transformers-2.11.14}/.github/workflows/python-publish.yml +0 -0
  9. {x_transformers-2.11.11 → x_transformers-2.11.14}/.github/workflows/python-test.yaml +0 -0
  10. {x_transformers-2.11.11 → x_transformers-2.11.14}/.gitignore +0 -0
  11. {x_transformers-2.11.11 → x_transformers-2.11.14}/LICENSE +0 -0
  12. {x_transformers-2.11.11 → x_transformers-2.11.14}/README.md +0 -0
  13. {x_transformers-2.11.11 → x_transformers-2.11.14}/data/README.md +0 -0
  14. {x_transformers-2.11.11 → x_transformers-2.11.14}/data/enwik8.gz +0 -0
  15. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/all-attention.png +0 -0
  16. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/attention-on-attention.png +0 -0
  17. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/cosine-sim-attention.png +0 -0
  18. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/deepnorm.png +0 -0
  19. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/dynamic-pos-bias-linear.png +0 -0
  20. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/dynamic-pos-bias-log.png +0 -0
  21. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  22. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/dynamic-pos-bias.png +0 -0
  23. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/enhanced-recurrence.png +0 -0
  24. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/fcm.png +0 -0
  25. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/ffglu.png +0 -0
  26. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/flash-attention.png +0 -0
  27. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/gate_values.png +0 -0
  28. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/gating.png +0 -0
  29. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/length-extrapolation-scale.png +0 -0
  30. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/macaron-1.png +0 -0
  31. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/macaron-2.png +0 -0
  32. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/memory-transformer.png +0 -0
  33. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/normformer.png +0 -0
  34. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/pia.png +0 -0
  35. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/qknorm-analysis.png +0 -0
  36. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/resi_dual.png +0 -0
  37. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/residual_attn.png +0 -0
  38. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/rezero.png +0 -0
  39. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/rotary.png +0 -0
  40. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/sandwich-2.png +0 -0
  41. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/sandwich.png +0 -0
  42. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/sandwich_norm.png +0 -0
  43. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/scalenorm.png +0 -0
  44. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/talking-heads.png +0 -0
  45. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/topk-attention.png +0 -0
  46. {x_transformers-2.11.11 → x_transformers-2.11.14}/images/xval.png +0 -0
  47. {x_transformers-2.11.11 → x_transformers-2.11.14}/train_belief_state.py +0 -0
  48. {x_transformers-2.11.11 → x_transformers-2.11.14}/train_copy.py +0 -0
  49. {x_transformers-2.11.11 → x_transformers-2.11.14}/train_entropy_tokenizer.py +0 -0
  50. {x_transformers-2.11.11 → x_transformers-2.11.14}/train_enwik8.py +0 -0
  51. {x_transformers-2.11.11 → x_transformers-2.11.14}/train_gpt_vae.py +0 -0
  52. {x_transformers-2.11.11 → x_transformers-2.11.14}/train_length_extrapolate.py +0 -0
  53. {x_transformers-2.11.11 → x_transformers-2.11.14}/train_parity.py +0 -0
  54. {x_transformers-2.11.11 → x_transformers-2.11.14}/train_with_muon.py +0 -0
  55. {x_transformers-2.11.11 → x_transformers-2.11.14}/x_transformers/__init__.py +0 -0
  56. {x_transformers-2.11.11 → x_transformers-2.11.14}/x_transformers/attend.py +0 -0
  57. {x_transformers-2.11.11 → x_transformers-2.11.14}/x_transformers/autoregressive_wrapper.py +0 -0
  58. {x_transformers-2.11.11 → x_transformers-2.11.14}/x_transformers/belief_state_wrapper.py +0 -0
  59. {x_transformers-2.11.11 → x_transformers-2.11.14}/x_transformers/continuous.py +0 -0
  60. {x_transformers-2.11.11 → x_transformers-2.11.14}/x_transformers/dpo.py +0 -0
  61. {x_transformers-2.11.11 → x_transformers-2.11.14}/x_transformers/entropy_based_tokenizer.py +0 -0
  62. {x_transformers-2.11.11 → x_transformers-2.11.14}/x_transformers/gpt_vae.py +0 -0
  63. {x_transformers-2.11.11 → x_transformers-2.11.14}/x_transformers/multi_input.py +0 -0
  64. {x_transformers-2.11.11 → x_transformers-2.11.14}/x_transformers/neo_mlp.py +0 -0
  65. {x_transformers-2.11.11 → x_transformers-2.11.14}/x_transformers/nonautoregressive_wrapper.py +0 -0
  66. {x_transformers-2.11.11 → x_transformers-2.11.14}/x_transformers/up_wrapper.py +0 -0
  67. {x_transformers-2.11.11 → x_transformers-2.11.14}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  68. {x_transformers-2.11.11 → x_transformers-2.11.14}/x_transformers/xval.py +0 -0
--- x_transformers-2.11.11/PKG-INFO
+++ x_transformers-2.11.14/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.11.11
+Version: 2.11.14
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
--- x_transformers-2.11.11/pyproject.toml
+++ x_transformers-2.11.14/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.11.11"
+version = "2.11.14"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
--- x_transformers-2.11.11/tests/test_x_transformers.py
+++ x_transformers-2.11.14/tests/test_x_transformers.py
@@ -1437,6 +1437,11 @@ def test_free(
 
     assert aux_loss.numel() == 1
 
+    rand_indices = torch.randint(0, 2 ** 8, ())
+    generated = model.generate(seq[:, :1], 32, latents = rand_indices)
+
+    assert generated.shape == (1, 32)
+
 def test_kv_input_residual():
     attn = Decoder(
         dim = 256,
--- x_transformers-2.11.11/train_free.py
+++ x_transformers-2.11.14/train_free.py
@@ -63,8 +63,7 @@ model = FreeTransformer(
     latent_bits = LATENT_BITS
 ).cuda()
 
-rand_index = torch.randint(0, 2 ** LATENT_BITS, ())
-latents = F.one_hot(rand_index, 2 ** LATENT_BITS).float().cuda()
+one_hot_indices = torch.randint(0, 2 ** LATENT_BITS, ())
 
 # prepare enwik8 data
 
@@ -126,9 +125,9 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
     sample = model.generate(
         prompts = inp,
         seq_len = GENERATE_LENGTH,
-        latents = latents
+        latents = one_hot_indices
     )
 
     output_str = decode_tokens(sample)
 
-    print(f'\n\nlatent {rand_index.tolist()} - ', output_str)
+    print(f'\n\nlatent {one_hot_indices.tolist()} - ', output_str)
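The training script no longer builds the one-hot latent itself; it samples a bare index and hands that to model.generate, which performs the expansion internally (see the free_transformer.py hunk below). A minimal standalone sketch of the before/after shapes, assuming 8 latent bits (the actual LATENT_BITS constant is defined earlier in train_free.py and is not shown in this diff):

import torch
import torch.nn.functional as F

LATENT_BITS = 8  # assumed value, for illustration only

# before: the script expanded a sampled index into a one-hot float vector itself
rand_index = torch.randint(0, 2 ** LATENT_BITS, ())
latents_before = F.one_hot(rand_index, 2 ** LATENT_BITS).float()   # shape (256,)

# after: a bare 0-dim long tensor is passed as `latents` and expanded inside generate()
one_hot_indices = torch.randint(0, 2 ** LATENT_BITS, ())            # shape ()

print(latents_before.shape, one_hot_indices.shape)   # torch.Size([256]) torch.Size([])
print(one_hot_indices.tolist())                      # a plain int, as used in the print statement above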
--- x_transformers-2.11.11/x_transformers/free_transformer.py
+++ x_transformers-2.11.14/x_transformers/free_transformer.py
@@ -282,6 +282,10 @@ class FreeTransformer(Module):
         if not is_tensor(latents):
             latents = tensor(latents, device = self.device)
 
+        if latents.dtype in (torch.int, torch.long):
+            # if given as indices
+            latents = F.one_hot(latents, self.binary_mapper.num_codes).float()
+
         if latents.ndim == 1: # repeat latents
             latents = repeat(latents, 'd -> b 1 d', b = batch)
         elif latents.ndim == 2:
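A minimal sketch of the new code path outside the class, to make the shapes concrete. num_codes stands in for self.binary_mapper.num_codes and the batch size is made up; only the dtype branch and the ndim == 1 repeat from the surrounding method are reproduced here:

import torch
import torch.nn.functional as F
from einops import repeat

num_codes, batch = 2 ** 8, 4   # stand-ins; num_codes mirrors self.binary_mapper.num_codes

latents = torch.randint(0, num_codes, ())               # integer index, the newly accepted form

if latents.dtype in (torch.int, torch.long):
    latents = F.one_hot(latents, num_codes).float()     # -> (num_codes,) one-hot float vector

if latents.ndim == 1:
    latents = repeat(latents, 'd -> b 1 d', b = batch)  # same latent broadcast across the batch

print(latents.shape)   # torch.Size([4, 1, 256])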
--- x_transformers-2.11.11/x_transformers/x_transformers.py
+++ x_transformers-2.11.14/x_transformers/x_transformers.py
@@ -740,11 +740,14 @@ def apply_rotary_pos_emb(t, freqs, scale = 1):
     rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
 
     freqs = freqs[:, -seq_len:, :]
-    scale = scale[:, -seq_len:, :] if isinstance(scale, torch.Tensor) else scale
+    scale = scale[:, -seq_len:, :] if is_tensor(scale) else scale
 
     if t.ndim == 4 and freqs.ndim == 3:
         freqs = rearrange(freqs, 'b n d -> b 1 n d')
 
+    if is_tensor(scale):
+        scale = rearrange(scale, 'b n d -> b 1 n d')
+
     # partial rotary embeddings, Wang et al. GPT-J
     t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
     t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
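The xpos scale tensor has the same 'b n d' layout as freqs, so when queries/keys arrive with an extra heads dimension it needs the same singleton dimension inserted before the elementwise multiply on the last line. A small standalone sketch of the broadcasting problem being addressed (all shapes below are made up for illustration):

import torch
from einops import rearrange

b, h, n, d = 2, 8, 16, 32
t = torch.randn(b, h, n, d)    # e.g. queries after splitting out heads
scale = torch.rand(b, n, d)    # per-position xpos scale, same layout as freqs

# without the added rearrange, (b, n, d) * (b, h, n, d) aligns b against h
# and only broadcasts by accident (when b == h or either equals 1)
scale = rearrange(scale, 'b n d -> b 1 n d')

out = t * scale                # now broadcasts cleanly over the heads dimension
print(out.shape)               # torch.Size([2, 8, 16, 32])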