x-transformers 2.11.2.tar.gz → 2.11.5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of x-transformers might be problematic.

Files changed (68)
  1. {x_transformers-2.11.2 → x_transformers-2.11.5}/PKG-INFO +1 -1
  2. {x_transformers-2.11.2 → x_transformers-2.11.5}/pyproject.toml +1 -1
  3. {x_transformers-2.11.2 → x_transformers-2.11.5}/tests/test_x_transformers.py +32 -6
  4. {x_transformers-2.11.2 → x_transformers-2.11.5}/train_free.py +1 -1
  5. {x_transformers-2.11.2 → x_transformers-2.11.5}/x_transformers/free_transformer.py +47 -21
  6. {x_transformers-2.11.2 → x_transformers-2.11.5}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.11.2 → x_transformers-2.11.5}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.11.2 → x_transformers-2.11.5}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.11.2 → x_transformers-2.11.5}/.gitignore +0 -0
  10. {x_transformers-2.11.2 → x_transformers-2.11.5}/LICENSE +0 -0
  11. {x_transformers-2.11.2 → x_transformers-2.11.5}/README.md +0 -0
  12. {x_transformers-2.11.2 → x_transformers-2.11.5}/data/README.md +0 -0
  13. {x_transformers-2.11.2 → x_transformers-2.11.5}/data/enwik8.gz +0 -0
  14. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/all-attention.png +0 -0
  15. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/deepnorm.png +0 -0
  18. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/fcm.png +0 -0
  24. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/ffglu.png +0 -0
  25. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/flash-attention.png +0 -0
  26. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/gate_values.png +0 -0
  27. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/gating.png +0 -0
  28. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/macaron-1.png +0 -0
  30. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/macaron-2.png +0 -0
  31. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/normformer.png +0 -0
  33. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/pia.png +0 -0
  34. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/resi_dual.png +0 -0
  36. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/residual_attn.png +0 -0
  37. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/rezero.png +0 -0
  38. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/rotary.png +0 -0
  39. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/sandwich.png +0 -0
  41. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/scalenorm.png +0 -0
  43. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/talking-heads.png +0 -0
  44. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/topk-attention.png +0 -0
  45. {x_transformers-2.11.2 → x_transformers-2.11.5}/images/xval.png +0 -0
  46. {x_transformers-2.11.2 → x_transformers-2.11.5}/train_belief_state.py +0 -0
  47. {x_transformers-2.11.2 → x_transformers-2.11.5}/train_copy.py +0 -0
  48. {x_transformers-2.11.2 → x_transformers-2.11.5}/train_entropy_tokenizer.py +0 -0
  49. {x_transformers-2.11.2 → x_transformers-2.11.5}/train_enwik8.py +0 -0
  50. {x_transformers-2.11.2 → x_transformers-2.11.5}/train_gpt_vae.py +0 -0
  51. {x_transformers-2.11.2 → x_transformers-2.11.5}/train_length_extrapolate.py +0 -0
  52. {x_transformers-2.11.2 → x_transformers-2.11.5}/train_parity.py +0 -0
  53. {x_transformers-2.11.2 → x_transformers-2.11.5}/train_with_muon.py +0 -0
  54. {x_transformers-2.11.2 → x_transformers-2.11.5}/x_transformers/__init__.py +0 -0
  55. {x_transformers-2.11.2 → x_transformers-2.11.5}/x_transformers/attend.py +0 -0
  56. {x_transformers-2.11.2 → x_transformers-2.11.5}/x_transformers/autoregressive_wrapper.py +0 -0
  57. {x_transformers-2.11.2 → x_transformers-2.11.5}/x_transformers/belief_state_wrapper.py +0 -0
  58. {x_transformers-2.11.2 → x_transformers-2.11.5}/x_transformers/continuous.py +0 -0
  59. {x_transformers-2.11.2 → x_transformers-2.11.5}/x_transformers/dpo.py +0 -0
  60. {x_transformers-2.11.2 → x_transformers-2.11.5}/x_transformers/entropy_based_tokenizer.py +0 -0
  61. {x_transformers-2.11.2 → x_transformers-2.11.5}/x_transformers/gpt_vae.py +0 -0
  62. {x_transformers-2.11.2 → x_transformers-2.11.5}/x_transformers/multi_input.py +0 -0
  63. {x_transformers-2.11.2 → x_transformers-2.11.5}/x_transformers/neo_mlp.py +0 -0
  64. {x_transformers-2.11.2 → x_transformers-2.11.5}/x_transformers/nonautoregressive_wrapper.py +0 -0
  65. {x_transformers-2.11.2 → x_transformers-2.11.5}/x_transformers/up_wrapper.py +0 -0
  66. {x_transformers-2.11.2 → x_transformers-2.11.5}/x_transformers/x_transformers.py +0 -0
  67. {x_transformers-2.11.2 → x_transformers-2.11.5}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  68. {x_transformers-2.11.2 → x_transformers-2.11.5}/x_transformers/xval.py +0 -0
{x_transformers-2.11.2 → x_transformers-2.11.5}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.11.2
+Version: 2.11.5
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers

{x_transformers-2.11.2 → x_transformers-2.11.5}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.11.2"
+version = "2.11.5"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

{x_transformers-2.11.2 → x_transformers-2.11.5}/tests/test_x_transformers.py
@@ -586,12 +586,12 @@ def test_cross_attn_rotary(
     context_pos = torch.arange(128) if cross_attn_rotary else None
 
     embed = model(
-        x = x,
-        mask = mask,
-        context = context,
-        pos = pos,
-        context_pos = context_pos,
-        context_mask = context_mask
+        x = x,
+        mask = mask,
+        context = context,
+        pos = pos,
+        context_pos = context_pos,
+        context_mask = context_mask
     )
 
 @param('tanh', (True, False))
@@ -1408,3 +1408,29 @@ def test_attn_negative_weights(
     x = torch.randint(0, 256, (1, 10))
 
     logits = model(x)
+
+@param('per_token_latents', (False, True))
+def test_free(
+    per_token_latents
+):
+    from x_transformers.free_transformer import FreeTransformer
+
+    model = FreeTransformer(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        dim = 512,
+        heads = 8,
+        dec_head_depth = 4,
+        dec_tail_depth = 4,
+        enc_depth = 3,
+        kl_loss_weight = 1.,
+        per_token_latents = per_token_latents,
+        latent_bits = 8
+    )
+
+    seq = torch.randint(0, 256, (1, 1024))
+
+    loss, (ar_loss, aux_loss) = model(seq, return_all_losses = True)
+    loss.backward()
+
+    assert aux_loss.numel() == 1

{x_transformers-2.11.2 → x_transformers-2.11.5}/train_free.py
@@ -54,11 +54,11 @@ model = FreeTransformer(
     max_seq_len = SEQ_LEN,
     dim = 512,
     heads = 8,
-    rotary_pos_emb = True,
     dec_head_depth = 4,
     dec_tail_depth = 4,
     enc_depth = 3,
     kl_loss_weight = 1.,
+    per_token_latents = True,
     kl_loss_threshold = NAT,
     latent_bits = LATENT_BITS
 ).cuda()

{x_transformers-2.11.2 → x_transformers-2.11.5}/x_transformers/free_transformer.py
@@ -128,19 +128,19 @@ class FreeTransformer(Module):
         dim,
         dec_head_depth,
         dec_tail_depth,
-        enc_depth,
         max_seq_len,
+        enc_depth = 1,
         dim_latent = None,
         attn_dim_head = 64,
         heads = 8,
         latent_bits = 16,
+        per_token_latents = True, # they use a latent per token in the sequence, instead of one for entire sequence, iiuc
         kl_loss_threshold = NAT,
         binary_mapper_kwargs: dict = dict(),
         enc_kwargs: dict = dict(),
         dec_kwargs: dict = dict(),
         kl_loss_weight = 1.,
         pad_id = -1,
-        encoder: Module | None = None,
         **kwargs
     ):
         super().__init__()
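
For orientation, the reworked constructor can be exercised the same way the new test_free above does. A minimal sketch, reusing only arguments that appear in this diff (hyperparameters copied from the added test, not a recommended configuration):

    import torch
    from x_transformers.free_transformer import FreeTransformer

    # enc_depth now defaults to 1, per_token_latents defaults to True,
    # and the external `encoder: Module | None` argument is gone
    model = FreeTransformer(
        num_tokens = 256,
        max_seq_len = 1024,
        dim = 512,
        heads = 8,
        dec_head_depth = 4,
        dec_tail_depth = 4,
        enc_depth = 3,
        kl_loss_weight = 1.,
        per_token_latents = True,
        latent_bits = 8
    )

    seq = torch.randint(0, 256, (1, 1024))

    loss, (ar_loss, aux_loss) = model(seq, return_all_losses = True)
    loss.backward()
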
@@ -150,39 +150,40 @@ class FreeTransformer(Module):
 
         self.token_unembed = nn.Linear(dim, num_tokens, bias = False)
 
-        if not exists(encoder):
-            encoder = Encoder(
-                dim = dim,
-                depth = enc_depth,
-                attn_dim_head = attn_dim_head,
-                heads = heads,
-                **kwargs,
-                **enc_kwargs
-            )
+        self.query_token_for_latents = nn.Parameter(torch.randn(dim) * 1e-2)
 
-        self.encoder = encoder
+        self.per_token_latents = per_token_latents
 
-        self.to_latent_bit_logits = nn.Sequential(
-            Reduce('b n d -> b d', 'mean'),
-            nn.Linear(dim, latent_bits, bias = False),
+        self.encoder = Encoder(
+            dim = dim,
+            depth = enc_depth,
+            attn_dim_head = attn_dim_head,
+            heads = heads,
+            only_cross = True,
+            cross_attend = True,
+            use_rmsnorm = True,
+            rotary_pos_emb = True,
+            **kwargs,
+            **enc_kwargs
         )
 
+        self.to_latent_bit_logits = nn.Linear(dim, latent_bits, bias = False)
+
         self.binary_mapper = BinaryMapper(
             latent_bits,
             kl_loss_threshold,
             **binary_mapper_kwargs
         )
 
-        self.from_latent_to_condition = nn.Sequential(
-            nn.Linear(2 ** latent_bits, dim, bias = False),
-            Rearrange('b d -> b 1 d')
-        )
+        self.from_latent_to_condition = nn.Linear(self.binary_mapper.num_codes, dim, bias = False)
 
         self.decoder_head = Decoder(
             dim = dim,
             depth = dec_head_depth,
             attn_dim_head = attn_dim_head,
             heads = heads,
+            rotary_pos_emb = True,
+            use_rmsnorm = True,
             pre_norm_has_final_norm = False,
             **kwargs,
             **dec_kwargs
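
The from_latent_to_condition rework above drops the old Rearrange('b d -> b 1 d') step. A shape-level sketch of what the single nn.Linear amounts to, under two stated assumptions: that binary_mapper.num_codes equals 2 ** latent_bits, and that with per-token latents the code tensor already carries a sequence dimension (both inferred from this diff, not confirmed elsewhere):

    import torch
    from torch import nn

    latent_bits, dim = 8, 512
    num_codes = 2 ** latent_bits                      # assumed value of binary_mapper.num_codes

    from_latent_to_condition = nn.Linear(num_codes, dim, bias = False)

    # hypothetical per-token distribution over latent codes: (batch, seq_len, num_codes)
    codes = torch.softmax(torch.randn(2, 16, num_codes), dim = -1)

    condition = from_latent_to_condition(codes)       # (2, 16, 512): one conditioning vector per token
    assert condition.shape == (2, 16, dim)
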
@@ -193,6 +194,8 @@ class FreeTransformer(Module):
             depth = dec_tail_depth,
             attn_dim_head = attn_dim_head,
             heads = heads,
+            rotary_pos_emb = True,
+            use_rmsnorm = True,
             pre_norm_has_final_norm = True,
             **kwargs,
             **dec_kwargs
@@ -208,11 +211,34 @@
 
     def encode_to_latents(
         self,
-        seq,
+        decoder_head_embeds,
         mask = None,
         return_kl_loss = False
     ):
-        pooled = self.encoder(seq, mask = mask)
+        batch, seq_len, device = *decoder_head_embeds.shape[:2], decoder_head_embeds.device
+
+        query_tokens = repeat(self.query_token_for_latents, 'd -> b 1 d', b = batch)
+
+        encoder_kwargs = dict()
+
+        # handle the interesting per query token latents, as in the paper
+
+        if self.per_token_latents:
+            query_tokens = repeat(query_tokens, 'b 1 d -> b n d', n = seq_len)
+
+            rotary_pos = torch.arange(seq_len, device = device)
+
+            encoder_kwargs.update(
+                pos = rotary_pos,
+                context_pos = rotary_pos
+            )
+
+        pooled = self.encoder(
+            query_tokens,
+            context = decoder_head_embeds,
+            context_mask = mask,
+            **encoder_kwargs
+        )
 
         bit_logits = self.to_latent_bit_logits(pooled)
 
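
The query broadcasting inside the new encode_to_latents reduces to two einops repeat calls; here is a standalone shape check of that pattern (illustrative sizes only, not taken from the library):

    import torch
    from einops import repeat

    dim, batch, seq_len = 512, 2, 16

    query_token = torch.randn(dim)    # stands in for the learned query_token_for_latents parameter

    # one latent query for the whole sequence
    query_tokens = repeat(query_token, 'd -> b 1 d', b = batch)
    assert query_tokens.shape == (batch, 1, dim)

    # per_token_latents = True: one query position per sequence position,
    # so the cross-attending encoder returns one pooled vector per token
    query_tokens = repeat(query_tokens, 'b 1 d -> b n d', n = seq_len)
    assert query_tokens.shape == (batch, seq_len, dim)

The remaining 63 files listed above are unchanged between 2.11.2 and 2.11.5.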