x-transformers 2.8.0.tar.gz → 2.8.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. {x_transformers-2.8.0 → x_transformers-2.8.1}/PKG-INFO +1 -1
  2. {x_transformers-2.8.0 → x_transformers-2.8.1}/pyproject.toml +1 -1
  3. {x_transformers-2.8.0 → x_transformers-2.8.1}/tests/test_x_transformers.py +20 -0
  4. {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/gpt_vae.py +13 -5
  5. {x_transformers-2.8.0 → x_transformers-2.8.1}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.8.0 → x_transformers-2.8.1}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.8.0 → x_transformers-2.8.1}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.8.0 → x_transformers-2.8.1}/.gitignore +0 -0
  9. {x_transformers-2.8.0 → x_transformers-2.8.1}/LICENSE +0 -0
  10. {x_transformers-2.8.0 → x_transformers-2.8.1}/README.md +0 -0
  11. {x_transformers-2.8.0 → x_transformers-2.8.1}/data/README.md +0 -0
  12. {x_transformers-2.8.0 → x_transformers-2.8.1}/data/enwik8.gz +0 -0
  13. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/all-attention.png +0 -0
  14. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/deepnorm.png +0 -0
  17. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/fcm.png +0 -0
  23. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/ffglu.png +0 -0
  24. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/flash-attention.png +0 -0
  25. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/gate_values.png +0 -0
  26. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/gating.png +0 -0
  27. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/macaron-1.png +0 -0
  29. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/macaron-2.png +0 -0
  30. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/normformer.png +0 -0
  32. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/pia.png +0 -0
  33. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/resi_dual.png +0 -0
  35. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/residual_attn.png +0 -0
  36. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/rezero.png +0 -0
  37. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/rotary.png +0 -0
  38. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/sandwich.png +0 -0
  40. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/scalenorm.png +0 -0
  42. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/talking-heads.png +0 -0
  43. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/topk-attention.png +0 -0
  44. {x_transformers-2.8.0 → x_transformers-2.8.1}/images/xval.png +0 -0
  45. {x_transformers-2.8.0 → x_transformers-2.8.1}/train_belief_state.py +0 -0
  46. {x_transformers-2.8.0 → x_transformers-2.8.1}/train_copy.py +0 -0
  47. {x_transformers-2.8.0 → x_transformers-2.8.1}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.8.0 → x_transformers-2.8.1}/train_enwik8.py +0 -0
  49. {x_transformers-2.8.0 → x_transformers-2.8.1}/train_gpt_vae.py +0 -0
  50. {x_transformers-2.8.0 → x_transformers-2.8.1}/train_length_extrapolate.py +0 -0
  51. {x_transformers-2.8.0 → x_transformers-2.8.1}/train_parity.py +0 -0
  52. {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/__init__.py +0 -0
  53. {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/attend.py +0 -0
  54. {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/autoregressive_wrapper.py +0 -0
  55. {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/belief_state_wrapper.py +0 -0
  56. {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/continuous.py +0 -0
  57. {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/dpo.py +0 -0
  58. {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/entropy_based_tokenizer.py +0 -0
  59. {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/multi_input.py +0 -0
  60. {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/neo_mlp.py +0 -0
  61. {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/nonautoregressive_wrapper.py +0 -0
  62. {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/up_wrapper.py +0 -0
  63. {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/x_transformers.py +0 -0
  64. {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  65. {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/xval.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.8.0
+Version: 2.8.1
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.8.0"
+version = "2.8.1"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_x_transformers.py
@@ -1340,3 +1340,23 @@ def test_qk_clip_attn_layers():
     out, intermediates = model(seq, return_intermediates = True)
 
     model.attn_qk_clip_(intermediates)
+
+def test_vae():
+    from x_transformers.gpt_vae import GPTVAE
+
+    model = GPTVAE(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        dim = 512,
+        depth = 4,
+        enc_depth = 2
+    )
+
+    seq = torch.randint(0, 256, (1, 1024))
+
+    loss = model(seq)
+    loss.backward()
+
+    style = torch.randint(0, 256, (1, 1024))
+
+    out = model.generate(seq[:, :512], 512, seq_for_latents = style)
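For context, the new test exercises GPTVAE end to end: a forward pass returning a scalar loss, a backward pass, and style-conditioned generation. A minimal training-style sketch using the same constructor arguments is shown below; the optimizer choice and loop length are illustrative assumptions and are not taken from the package's train_gpt_vae.py script.

    import torch
    from x_transformers.gpt_vae import GPTVAE

    # same constructor arguments as the new test_vae above
    model = GPTVAE(
        num_tokens = 256,
        max_seq_len = 1024,
        dim = 512,
        depth = 4,
        enc_depth = 2
    )

    # hypothetical optimizer and step count, for illustration only
    optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)

    for _ in range(2):
        seq = torch.randint(0, 256, (1, 1024))

        loss = model(seq)   # forward pass returns a scalar loss, as in the test
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()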
x_transformers/gpt_vae.py
@@ -68,10 +68,13 @@ class GPTVAE(Module):
 
         self.to_latent_mean_log_variance = nn.Sequential(
             nn.Linear(dim, dim_latent * 2),
-            Rearrange('b (two d) -> two b 1 d', two = 2)
+            Rearrange('b (two d) -> two b d', two = 2)
         )
 
-        self.from_latent_to_prepend_token = nn.Linear(dim_latent, dim)
+        self.from_latent_to_prepend_token = nn.Sequential(
+            nn.Linear(dim_latent, dim),
+            Rearrange('b d -> b 1 d')
+        )
 
         self.decoder = TransformerWrapper(
             num_tokens = num_tokens,
@@ -126,11 +129,19 @@ class GPTVAE(Module):
         prompts,
         seq_len,
         latents = None,
+        seq_for_latents = None,
         **generate_kwargs
     ):
         assert prompts.ndim in {1, 2}
         batch = prompts.shape[0] if prompts.ndim == 2 else 1
 
+        # if seq_for_latents passed in, derive latents from it
+
+        if exists(seq_for_latents):
+            assert not exists(latents), 'latents should not be passed in if given the seq from which to derive them'
+
+            latents = self.encode_to_latents(seq_for_latents)
+
         # prepend embeds
 
         prepend_embeds = None
@@ -143,9 +154,6 @@ class GPTVAE(Module):
 
         prepend_embeds = self.from_latent_to_prepend_token(latents)
 
-        if exists(prepend_embeds):
-            prepend_embeds = rearrange(prepend_embeds, 'b d -> b 1 d')
-
         # generated
 
         generated = self.ar_wrapped_decoder.generate(
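The generation change above means latents can now reach generate() in two mutually exclusive ways: precomputed and passed in, or derived on the fly from a reference sequence via encode_to_latents (the new assert rejects passing both). A minimal sketch of the two call paths, assuming the same constructor arguments used in the new test:

    import torch
    from x_transformers.gpt_vae import GPTVAE

    model = GPTVAE(
        num_tokens = 256,
        max_seq_len = 1024,
        dim = 512,
        depth = 4,
        enc_depth = 2
    )

    prompts = torch.randint(0, 256, (1, 512))
    style = torch.randint(0, 256, (1, 1024))

    # path 1 (new in 2.8.1): let generate() derive latents from a reference sequence
    out = model.generate(prompts, 512, seq_for_latents = style)

    # path 2: precompute latents and pass them in directly
    # (mutually exclusive with seq_for_latents, per the assert in the diff)
    latents = model.encode_to_latents(style)
    out = model.generate(prompts, 512, latents = latents)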