x-transformers 2.11.8__tar.gz → 2.11.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of x-transformers might be problematic.

Files changed (68)
  1. {x_transformers-2.11.8 → x_transformers-2.11.10}/PKG-INFO +1 -1
  2. {x_transformers-2.11.8 → x_transformers-2.11.10}/pyproject.toml +1 -1
  3. {x_transformers-2.11.8 → x_transformers-2.11.10}/tests/test_x_transformers.py +4 -2
  4. {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/free_transformer.py +31 -6
  5. {x_transformers-2.11.8 → x_transformers-2.11.10}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.11.8 → x_transformers-2.11.10}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.11.8 → x_transformers-2.11.10}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.11.8 → x_transformers-2.11.10}/.gitignore +0 -0
  9. {x_transformers-2.11.8 → x_transformers-2.11.10}/LICENSE +0 -0
  10. {x_transformers-2.11.8 → x_transformers-2.11.10}/README.md +0 -0
  11. {x_transformers-2.11.8 → x_transformers-2.11.10}/data/README.md +0 -0
  12. {x_transformers-2.11.8 → x_transformers-2.11.10}/data/enwik8.gz +0 -0
  13. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/all-attention.png +0 -0
  14. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/deepnorm.png +0 -0
  17. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/fcm.png +0 -0
  23. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/ffglu.png +0 -0
  24. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/flash-attention.png +0 -0
  25. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/gate_values.png +0 -0
  26. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/gating.png +0 -0
  27. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/macaron-1.png +0 -0
  29. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/macaron-2.png +0 -0
  30. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/normformer.png +0 -0
  32. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/pia.png +0 -0
  33. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/resi_dual.png +0 -0
  35. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/residual_attn.png +0 -0
  36. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/rezero.png +0 -0
  37. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/rotary.png +0 -0
  38. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/sandwich.png +0 -0
  40. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/scalenorm.png +0 -0
  42. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/talking-heads.png +0 -0
  43. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/topk-attention.png +0 -0
  44. {x_transformers-2.11.8 → x_transformers-2.11.10}/images/xval.png +0 -0
  45. {x_transformers-2.11.8 → x_transformers-2.11.10}/train_belief_state.py +0 -0
  46. {x_transformers-2.11.8 → x_transformers-2.11.10}/train_copy.py +0 -0
  47. {x_transformers-2.11.8 → x_transformers-2.11.10}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.11.8 → x_transformers-2.11.10}/train_enwik8.py +0 -0
  49. {x_transformers-2.11.8 → x_transformers-2.11.10}/train_free.py +0 -0
  50. {x_transformers-2.11.8 → x_transformers-2.11.10}/train_gpt_vae.py +0 -0
  51. {x_transformers-2.11.8 → x_transformers-2.11.10}/train_length_extrapolate.py +0 -0
  52. {x_transformers-2.11.8 → x_transformers-2.11.10}/train_parity.py +0 -0
  53. {x_transformers-2.11.8 → x_transformers-2.11.10}/train_with_muon.py +0 -0
  54. {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/__init__.py +0 -0
  55. {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/attend.py +0 -0
  56. {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/autoregressive_wrapper.py +0 -0
  57. {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/belief_state_wrapper.py +0 -0
  58. {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/continuous.py +0 -0
  59. {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/dpo.py +0 -0
  60. {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/entropy_based_tokenizer.py +0 -0
  61. {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/gpt_vae.py +0 -0
  62. {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/multi_input.py +0 -0
  63. {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/neo_mlp.py +0 -0
  64. {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/nonautoregressive_wrapper.py +0 -0
  65. {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/up_wrapper.py +0 -0
  66. {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/x_transformers.py +0 -0
  67. {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  68. {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/xval.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.11.8
+Version: 2.11.10
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.11.8"
+version = "2.11.10"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_x_transformers.py
@@ -1410,7 +1410,9 @@ def test_attn_negative_weights(
     logits = model(x)

 @param('per_token_latents', (False, True))
+@param('dec_head_depth', (0, 4))
 def test_free(
+    dec_head_depth,
     per_token_latents
 ):
     from x_transformers.free_transformer import FreeTransformer
@@ -1420,9 +1422,9 @@ def test_free(
         max_seq_len = 1024,
         dim = 512,
         heads = 8,
-        dec_head_depth = 4,
+        dec_head_depth = dec_head_depth,
         dec_tail_depth = 4,
-        enc_depth = 3,
+        enc_depth = 2,
         kl_loss_weight = 1.,
         per_token_latents = per_token_latents,
         latent_bits = 8
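The stacked decorators run test_free over the cross product of per_token_latents and the new dec_head_depth flag, so the model is now also exercised with no decoder head at all. A minimal sketch of what the parametrized construction covers; param is assumed to be an alias for pytest.mark.parametrize, and num_tokens is an assumed keyword argument not shown in this hunk:

from x_transformers.free_transformer import FreeTransformer

for dec_head_depth in (0, 4):        # 0 now means the decoder head is skipped
    model = FreeTransformer(
        num_tokens = 256,            # assumed vocab-size kwarg, not shown above
        max_seq_len = 1024,
        dim = 512,
        heads = 8,
        dec_head_depth = dec_head_depth,
        dec_tail_depth = 4,
        enc_depth = 2,
        kl_loss_weight = 1.,
        latent_bits = 8
    )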
x_transformers/free_transformer.py
@@ -197,7 +197,9 @@ class FreeTransformer(Module):
            pre_norm_has_final_norm = False,
            **kwargs,
            **dec_kwargs
-        )
+        ) if dec_head_depth > 0 else None
+
+        assert dec_tail_depth > 0

         self.decoder_tail = Decoder(
             dim = dim,
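The decoder head is now only instantiated when dec_head_depth > 0; otherwise the attribute is None, and the new assert guarantees the tail always has at least one layer. A self-contained sketch of the same optional-submodule pattern, using hypothetical names rather than the library's own classes:

import torch
from torch import nn

def exists(v):
    return v is not None

class TwoStageStack(nn.Module):
    # hypothetical module mirroring the head / tail split above
    def __init__(self, dim, head_depth = 0, tail_depth = 4):
        super().__init__()

        # the head is skipped entirely when its depth is zero
        self.head = nn.Sequential(*(nn.Linear(dim, dim) for _ in range(head_depth))) if head_depth > 0 else None

        # the tail, like decoder_tail above, must always exist
        assert tail_depth > 0
        self.tail = nn.Sequential(*(nn.Linear(dim, dim) for _ in range(tail_depth)))

    def forward(self, x):
        if exists(self.head):
            x = self.head(x)

        return self.tail(x)

stack = TwoStageStack(512, head_depth = 0)
out = stack(torch.randn(2, 16, 512))    # works with or without the head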
@@ -266,7 +268,8 @@ class FreeTransformer(Module):
         seq_len,
         latents = None,
         filter_logits_fn = top_p,
-        logit_filter_kwargs: dict = dict(thres = 0.9)
+        logit_filter_kwargs: dict = dict(thres = 0.9),
+        use_kv_cache = True
     ):
         prompts, inverse_pack = pack_with_inverse(prompts, '* n')

@@ -280,10 +283,16 @@ class FreeTransformer(Module):
             latents = tensor(latents, device = self.device)

         if latents.ndim == 1: # repeat latents
-            latents = repeat(latents, 'd -> b d', b = batch)
+            latents = repeat(latents, 'd -> b 1 d', b = batch)
+        elif latents.ndim == 2:
+            latents = rearrange(latents, 'b d -> b 1 d')

         condition = self.from_latent_to_condition(latents)

+        # kv cache
+
+        head_cache = tail_cache = None
+
         # generated

         prompt_len = prompts.shape[-1]
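Latents are now normalized to a (batch, 1, dim) shape, whether a single latent vector is shared across the batch or one latent is supplied per sample. A quick shape check of the two einops calls used above:

import torch
from einops import repeat, rearrange

batch, dim = 2, 512

shared = torch.randn(dim)                               # one latent for the whole batch
print(repeat(shared, 'd -> b 1 d', b = batch).shape)    # torch.Size([2, 1, 512])

per_sample = torch.randn(batch, dim)                    # one latent per sample
print(rearrange(per_sample, 'b d -> b 1 d').shape)      # torch.Size([2, 1, 512])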
@@ -294,9 +303,20 @@ class FreeTransformer(Module):

         for _ in range(max(0, seq_len - prompt_len)):

-            head_embed = self.decoder_head(tokens)
+            # head, which may not exist
+
+            if exists(self.decoder_head):
+                head_embed, next_head_cache = self.decoder_head(tokens, cache = head_cache, return_hiddens = True)
+            else:
+                head_embed, next_head_cache = tokens, None

-            tail_embed = self.decoder_tail(head_embed, self_attn_kv_residuals = condition)
+            # handle one token being given to the decoder tail when doing kv caching - rotary embedding needs to know the seq position offset
+
+            seq_pos_offset = head_cache.cache_length if exists(head_cache) else 0
+
+            # tail
+
+            tail_embed, next_tail_cache = self.decoder_tail(head_embed, cache = tail_cache, seq_pos_offset = seq_pos_offset, self_attn_kv_residuals = condition, return_hiddens = True)

             tail_embed = tail_embed[:, -1]

@@ -309,6 +329,10 @@ class FreeTransformer(Module):
             generated, _ = pack((generated, sampled), 'b *')
             tokens, _ = pack((tokens, self.token_emb(sampled)), 'b * d')

+            if use_kv_cache:
+                head_cache = next_head_cache
+                tail_cache = next_tail_cache
+
         return inverse_pack(generated)

     def forward(
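The sampling loop now threads key/value caches through both decoder stages, with seq_pos_offset telling the rotary embedding how many positions are already cached, and use_kv_cache = False falls back to recomputing over the full token sequence each step. A hedged usage sketch of the new flag; the prompt shape and vocabulary size are assumptions, and model is the FreeTransformer from the earlier sketch:

import torch

# `model` is a FreeTransformer constructed as in the test sketch above
prompts = torch.randint(0, 256, (2, 16))    # assumed token ids within the assumed vocab size

cached   = model.generate(prompts, seq_len = 64)                         # kv cache on by default
uncached = model.generate(prompts, seq_len = 64, use_kv_cache = False)   # recompute every step

When the cache is active, each iteration's new query positions start at head_cache.cache_length, which is why that offset is forwarded to the tail decoder.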
@@ -326,7 +350,8 @@ class FreeTransformer(Module):

         # decoder head

-        tokens = self.decoder_head(tokens)
+        if exists(self.decoder_head):
+            tokens = self.decoder_head(tokens)

         # get latent Z