x-transformers 2.5.6__tar.gz → 2.6.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {x_transformers-2.5.6 → x_transformers-2.6.0}/PKG-INFO +1 -1
  2. {x_transformers-2.5.6 → x_transformers-2.6.0}/pyproject.toml +1 -1
  3. {x_transformers-2.5.6 → x_transformers-2.6.0}/tests/test_x_transformers.py +23 -0
  4. {x_transformers-2.5.6 → x_transformers-2.6.0}/x_transformers/x_transformers.py +21 -2
  5. {x_transformers-2.5.6 → x_transformers-2.6.0}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.5.6 → x_transformers-2.6.0}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.5.6 → x_transformers-2.6.0}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.5.6 → x_transformers-2.6.0}/.gitignore +0 -0
  9. {x_transformers-2.5.6 → x_transformers-2.6.0}/LICENSE +0 -0
  10. {x_transformers-2.5.6 → x_transformers-2.6.0}/README.md +0 -0
  11. {x_transformers-2.5.6 → x_transformers-2.6.0}/data/README.md +0 -0
  12. {x_transformers-2.5.6 → x_transformers-2.6.0}/data/enwik8.gz +0 -0
  13. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/all-attention.png +0 -0
  14. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/deepnorm.png +0 -0
  17. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/fcm.png +0 -0
  23. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/ffglu.png +0 -0
  24. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/flash-attention.png +0 -0
  25. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/gate_values.png +0 -0
  26. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/gating.png +0 -0
  27. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/macaron-1.png +0 -0
  29. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/macaron-2.png +0 -0
  30. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/normformer.png +0 -0
  32. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/pia.png +0 -0
  33. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/resi_dual.png +0 -0
  35. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/residual_attn.png +0 -0
  36. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/rezero.png +0 -0
  37. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/rotary.png +0 -0
  38. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/sandwich.png +0 -0
  40. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/scalenorm.png +0 -0
  42. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/talking-heads.png +0 -0
  43. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/topk-attention.png +0 -0
  44. {x_transformers-2.5.6 → x_transformers-2.6.0}/images/xval.png +0 -0
  45. {x_transformers-2.5.6 → x_transformers-2.6.0}/train_belief_state.py +0 -0
  46. {x_transformers-2.5.6 → x_transformers-2.6.0}/train_copy.py +0 -0
  47. {x_transformers-2.5.6 → x_transformers-2.6.0}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.5.6 → x_transformers-2.6.0}/train_enwik8.py +0 -0
  49. {x_transformers-2.5.6 → x_transformers-2.6.0}/train_length_extrapolate.py +0 -0
  50. {x_transformers-2.5.6 → x_transformers-2.6.0}/train_parity.py +0 -0
  51. {x_transformers-2.5.6 → x_transformers-2.6.0}/x_transformers/__init__.py +0 -0
  52. {x_transformers-2.5.6 → x_transformers-2.6.0}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.5.6 → x_transformers-2.6.0}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.5.6 → x_transformers-2.6.0}/x_transformers/belief_state_wrapper.py +0 -0
  55. {x_transformers-2.5.6 → x_transformers-2.6.0}/x_transformers/continuous.py +0 -0
  56. {x_transformers-2.5.6 → x_transformers-2.6.0}/x_transformers/dpo.py +0 -0
  57. {x_transformers-2.5.6 → x_transformers-2.6.0}/x_transformers/entropy_based_tokenizer.py +0 -0
  58. {x_transformers-2.5.6 → x_transformers-2.6.0}/x_transformers/multi_input.py +0 -0
  59. {x_transformers-2.5.6 → x_transformers-2.6.0}/x_transformers/neo_mlp.py +0 -0
  60. {x_transformers-2.5.6 → x_transformers-2.6.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
  61. {x_transformers-2.5.6 → x_transformers-2.6.0}/x_transformers/up_wrapper.py +0 -0
  62. {x_transformers-2.5.6 → x_transformers-2.6.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  63. {x_transformers-2.5.6 → x_transformers-2.6.0}/x_transformers/xval.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.5.6
+Version: 2.6.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.5.6"
+version = "2.6.0"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_x_transformers.py
@@ -1210,3 +1210,26 @@ def test_prompts_given_as_list_tensor():
     ], 256)

     assert sampled.shape == (4, 256)
+
+def test_external_key_values():
+    from x_transformers import AutoregressiveWrapper
+
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 2,
+            heads = 8,
+            attn_dim_head = 16
+        )
+    )
+
+    seq = torch.randint(0, 20000, (3, 1024))
+
+    key_values = [
+        (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16)),
+        (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16)),
+    ]
+
+    logits = model(seq, self_attn_additional_kv = key_values)
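
The new test doubles as usage documentation for the feature added in this release. Below is a condensed sketch of the call pattern it exercises; the per-layer shape convention of (batch, heads, extra_kv_len, attn_dim_head) is inferred from the test values, not from a documented API guarantee.

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(dim = 512, depth = 2, heads = 8, attn_dim_head = 16)
    )

    seq = torch.randint(0, 20000, (3, 1024))

    # one (key, value) pair per self-attention layer, in layer order
    # each tensor: (batch = 3, heads = 8, extra_kv_len = 32, attn_dim_head = 16)
    key_values = [
        (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16)),
        (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16)),
    ]

    logits = model(seq, self_attn_additional_kv = key_values)   # (3, 1024, 20000)
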
x_transformers/x_transformers.py
@@ -1617,7 +1617,8 @@ class Attention(Module):
         mem_mask = None,
         return_intermediates = False,
         cache: Intermediates | None = None,
-        value_residual = None
+        value_residual = None,
+        additional_key_values: tuple[Tensor, Tensor] | None = None
     ):
         b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals, is_multi_latent_attn = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals, self.use_latent_kv

@@ -1787,6 +1788,19 @@ class Attention(Module):
             if exists(input_mask):
                 input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)

+        # maybe append additional key / values
+
+        if exists(additional_key_values):
+
+            added_k, added_v = additional_key_values
+            added_kv_len = added_k.shape[-2]
+
+            k = cat((added_k, k), dim = -2)
+            v = cat((added_v, v), dim = -2)
+
+            if exists(input_mask):
+                input_mask = pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
+
         # determine masking

         mask_value = max_neg_value(q)
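
In plain terms, the hunk above prepends any externally supplied key/values along the key/value length dimension and left-extends the key-padding mask with True so the new positions remain attendable. A self-contained illustration in plain torch follows; shapes are borrowed from the new test, and it mirrors, rather than calls, the library's cat/pad_at_dim helpers.

    import torch

    k = torch.randn(3, 8, 128, 16)        # keys projected from the input sequence
    v = torch.randn(3, 8, 128, 16)        # values projected from the input sequence
    added_k = torch.randn(3, 8, 32, 16)   # externally supplied keys
    added_v = torch.randn(3, 8, 32, 16)   # externally supplied values

    added_kv_len = added_k.shape[-2]

    # external entries are prepended along the length dimension
    k = torch.cat((added_k, k), dim = -2)
    v = torch.cat((added_v, v), dim = -2)

    # the key-padding mask is left-extended with True, keeping the new positions visible
    input_mask = torch.ones(3, 128, dtype = torch.bool)
    input_mask = torch.cat((torch.ones(3, added_kv_len, dtype = torch.bool), input_mask), dim = -1)

    assert k.shape[-2] == v.shape[-2] == input_mask.shape[-1] == 128 + added_kv_len
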
@@ -2411,6 +2425,7 @@ class AttentionLayers(Module):
         context_pos = None,
         attn_bias = None,
         deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
+        self_attn_additional_kv: list[tuple[Tensor, Tensor]] | None = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
         layers_execute_order: tuple[int, ...] | None = None
@@ -2520,6 +2535,10 @@ class AttentionLayers(Module):

         iter_attn_cache = iter(attn_cache)

+        # additional self attn key / values
+
+        iter_self_attn_kv = iter(default(self_attn_additional_kv, ()))
+
         # handle deep embeds if needed

         deep_embeds = []
@@ -2647,7 +2666,7 @@ class AttentionLayers(Module):
            # forward depending on layer type

            if layer_type == 'a':
-               out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+               out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
            elif layer_type == 'c':
                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
            elif layer_type == 'f':
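
Because the layer loop consumes the list with next(iter_self_attn_kv, None), a list shorter than the number of self-attention layers should simply leave the remaining layers without extra key/values; this is an inference from the diff rather than documented behavior. A sketch, reusing the model and seq from the usage example above:

    # hypothetical partial supply: only the first self-attention layer gets external kv,
    # the second receives None and runs unchanged
    partial_kv = [(torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16))]

    logits = model(seq, self_attn_additional_kv = partial_kv)
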