x-transformers 2.11.9__tar.gz → 2.11.11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of x-transformers might be problematic.

Files changed (68)
  1. {x_transformers-2.11.9 → x_transformers-2.11.11}/PKG-INFO +1 -1
  2. {x_transformers-2.11.9 → x_transformers-2.11.11}/pyproject.toml +1 -1
  3. {x_transformers-2.11.9 → x_transformers-2.11.11}/x_transformers/free_transformer.py +29 -6
  4. {x_transformers-2.11.9 → x_transformers-2.11.11}/x_transformers/x_transformers.py +1 -0
  5. {x_transformers-2.11.9 → x_transformers-2.11.11}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.11.9 → x_transformers-2.11.11}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.11.9 → x_transformers-2.11.11}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.11.9 → x_transformers-2.11.11}/.gitignore +0 -0
  9. {x_transformers-2.11.9 → x_transformers-2.11.11}/LICENSE +0 -0
  10. {x_transformers-2.11.9 → x_transformers-2.11.11}/README.md +0 -0
  11. {x_transformers-2.11.9 → x_transformers-2.11.11}/data/README.md +0 -0
  12. {x_transformers-2.11.9 → x_transformers-2.11.11}/data/enwik8.gz +0 -0
  13. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/all-attention.png +0 -0
  14. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/deepnorm.png +0 -0
  17. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/fcm.png +0 -0
  23. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/ffglu.png +0 -0
  24. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/flash-attention.png +0 -0
  25. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/gate_values.png +0 -0
  26. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/gating.png +0 -0
  27. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/macaron-1.png +0 -0
  29. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/macaron-2.png +0 -0
  30. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/normformer.png +0 -0
  32. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/pia.png +0 -0
  33. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/resi_dual.png +0 -0
  35. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/residual_attn.png +0 -0
  36. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/rezero.png +0 -0
  37. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/rotary.png +0 -0
  38. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/sandwich.png +0 -0
  40. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/scalenorm.png +0 -0
  42. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/talking-heads.png +0 -0
  43. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/topk-attention.png +0 -0
  44. {x_transformers-2.11.9 → x_transformers-2.11.11}/images/xval.png +0 -0
  45. {x_transformers-2.11.9 → x_transformers-2.11.11}/tests/test_x_transformers.py +0 -0
  46. {x_transformers-2.11.9 → x_transformers-2.11.11}/train_belief_state.py +0 -0
  47. {x_transformers-2.11.9 → x_transformers-2.11.11}/train_copy.py +0 -0
  48. {x_transformers-2.11.9 → x_transformers-2.11.11}/train_entropy_tokenizer.py +0 -0
  49. {x_transformers-2.11.9 → x_transformers-2.11.11}/train_enwik8.py +0 -0
  50. {x_transformers-2.11.9 → x_transformers-2.11.11}/train_free.py +0 -0
  51. {x_transformers-2.11.9 → x_transformers-2.11.11}/train_gpt_vae.py +0 -0
  52. {x_transformers-2.11.9 → x_transformers-2.11.11}/train_length_extrapolate.py +0 -0
  53. {x_transformers-2.11.9 → x_transformers-2.11.11}/train_parity.py +0 -0
  54. {x_transformers-2.11.9 → x_transformers-2.11.11}/train_with_muon.py +0 -0
  55. {x_transformers-2.11.9 → x_transformers-2.11.11}/x_transformers/__init__.py +0 -0
  56. {x_transformers-2.11.9 → x_transformers-2.11.11}/x_transformers/attend.py +0 -0
  57. {x_transformers-2.11.9 → x_transformers-2.11.11}/x_transformers/autoregressive_wrapper.py +0 -0
  58. {x_transformers-2.11.9 → x_transformers-2.11.11}/x_transformers/belief_state_wrapper.py +0 -0
  59. {x_transformers-2.11.9 → x_transformers-2.11.11}/x_transformers/continuous.py +0 -0
  60. {x_transformers-2.11.9 → x_transformers-2.11.11}/x_transformers/dpo.py +0 -0
  61. {x_transformers-2.11.9 → x_transformers-2.11.11}/x_transformers/entropy_based_tokenizer.py +0 -0
  62. {x_transformers-2.11.9 → x_transformers-2.11.11}/x_transformers/gpt_vae.py +0 -0
  63. {x_transformers-2.11.9 → x_transformers-2.11.11}/x_transformers/multi_input.py +0 -0
  64. {x_transformers-2.11.9 → x_transformers-2.11.11}/x_transformers/neo_mlp.py +0 -0
  65. {x_transformers-2.11.9 → x_transformers-2.11.11}/x_transformers/nonautoregressive_wrapper.py +0 -0
  66. {x_transformers-2.11.9 → x_transformers-2.11.11}/x_transformers/up_wrapper.py +0 -0
  67. {x_transformers-2.11.9 → x_transformers-2.11.11}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  68. {x_transformers-2.11.9 → x_transformers-2.11.11}/x_transformers/xval.py +0 -0
{x_transformers-2.11.9 → x_transformers-2.11.11}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.11.9
+Version: 2.11.11
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
{x_transformers-2.11.9 → x_transformers-2.11.11}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.11.9"
+version = "2.11.11"
 description = "X-Transformers"
 authors = [
   { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.11.9 → x_transformers-2.11.11}/x_transformers/free_transformer.py

@@ -197,7 +197,7 @@ class FreeTransformer(Module):
             pre_norm_has_final_norm = False,
             **kwargs,
             **dec_kwargs
-        ) if dec_head_depth > 0 else nn.Identity()
+        ) if dec_head_depth > 0 else None

         assert dec_tail_depth > 0

@@ -268,7 +268,8 @@
         seq_len,
         latents = None,
         filter_logits_fn = top_p,
-        logit_filter_kwargs: dict = dict(thres = 0.9)
+        logit_filter_kwargs: dict = dict(thres = 0.9),
+        use_kv_cache = True
     ):
         prompts, inverse_pack = pack_with_inverse(prompts, '* n')

@@ -282,10 +283,16 @@
             latents = tensor(latents, device = self.device)

         if latents.ndim == 1: # repeat latents
-            latents = repeat(latents, 'd -> b d', b = batch)
+            latents = repeat(latents, 'd -> b 1 d', b = batch)
+        elif latents.ndim == 2:
+            latents = rearrange(latents, 'b d -> b 1 d')

         condition = self.from_latent_to_condition(latents)

+        # kv cache
+
+        head_cache = tail_cache = None
+
         # generated

         prompt_len = prompts.shape[-1]
@@ -296,9 +303,20 @@

         for _ in range(max(0, seq_len - prompt_len)):

-            head_embed = self.decoder_head(tokens)
+            # head, which may not exist
+
+            if exists(self.decoder_head):
+                head_embed, next_head_cache = self.decoder_head(tokens, cache = head_cache, return_hiddens = True)
+            else:
+                head_embed, next_head_cache = tokens, None
+
+            # handle one token being given to the decoder tail when doing kv caching - rotary embedding needs to know the seq position offset

-            tail_embed = self.decoder_tail(head_embed, self_attn_kv_residuals = condition)
+            seq_pos_offset = head_cache.cache_length if exists(head_cache) else 0
+
+            # tail
+
+            tail_embed, next_tail_cache = self.decoder_tail(head_embed, cache = tail_cache, seq_pos_offset = seq_pos_offset, self_attn_kv_residuals = condition, return_hiddens = True)

             tail_embed = tail_embed[:, -1]

@@ -311,6 +329,10 @@
             generated, _ = pack((generated, sampled), 'b *')
             tokens, _ = pack((tokens, self.token_emb(sampled)), 'b * d')

+            if use_kv_cache:
+                head_cache = next_head_cache
+                tail_cache = next_tail_cache
+
         return inverse_pack(generated)

     def forward(
@@ -328,7 +350,8 @@

         # decoder head

-        tokens = self.decoder_head(tokens)
+        if exists(self.decoder_head):
+            tokens = self.decoder_head(tokens)

         # get latent Z

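The free_transformer.py changes above make the decoder head optional (it is now `None` rather than `nn.Identity()` when `dec_head_depth` is 0) and add KV caching to `FreeTransformer.generate` behind a new `use_kv_cache` flag. Below is a minimal sketch of calling the updated method; it assumes an already constructed `FreeTransformer` instance named `model`, and only the `generate` keywords shown are taken from the diff above.

```python
# Hedged sketch: `model` is assumed to be an already constructed FreeTransformer;
# only generate(prompts, seq_len = ..., use_kv_cache = ...) is confirmed by the diff.

import torch
from x_transformers.free_transformer import FreeTransformer

def sample(model: FreeTransformer, prompts: torch.Tensor) -> torch.Tensor:
    # new in 2.11.11: key / value caches for the head and tail decoders are
    # carried across decoding steps when use_kv_cache = True (the default)
    return model.generate(prompts, seq_len = 64, use_kv_cache = True)

def sample_without_cache(model: FreeTransformer, prompts: torch.Tensor) -> torch.Tensor:
    # recomputes attention over the full token prefix at every decoding step
    return model.generate(prompts, seq_len = 64, use_kv_cache = False)
```

As the in-diff comment notes, with caching active the tail decoder can be handed fewer tokens than the absolute sequence position implies, so `seq_pos_offset` (read from `head_cache.cache_length`) is passed along to keep the rotary embeddings aligned.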
{x_transformers-2.11.9 → x_transformers-2.11.11}/x_transformers/x_transformers.py

@@ -3438,6 +3438,7 @@ class TransformerWrapper(Module):

         kwargs = dict(
             **kwargs,
+            pos = pos,
             seq_pos_offset = seq_pos_offset,
             seq_start_pos = seq_start_pos,
             input_not_include_cache = input_not_include_cache
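The x_transformers.py change forwards the `pos` argument into the keyword arguments handed to the attention layers, next to `seq_pos_offset` and `seq_start_pos`. A hedged usage sketch follows; the `TransformerWrapper` and `Decoder` construction mirrors the library's documented usage, but the exact effect of the forwarded `pos` is an assumption based only on this one-line diff.

```python
# Hedged sketch of supplying custom positions through TransformerWrapper.
# The constructor follows the library's standard usage; treat the behaviour of
# the forwarded `pos` as an assumption, since the diff only shows it being
# passed down to the attention layers.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

tokens = torch.randint(0, 256, (1, 128))

# custom absolute positions, e.g. for packed or non-contiguous sequences
pos = torch.arange(128).unsqueeze(0)

logits = model(tokens, pos = pos)  # as of 2.11.11, pos also reaches the attention layers
```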