x-transformers 2.5.6__tar.gz → 2.6.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {x_transformers-2.5.6 → x_transformers-2.6.1}/PKG-INFO +1 -1
  2. {x_transformers-2.5.6 → x_transformers-2.6.1}/pyproject.toml +1 -1
  3. {x_transformers-2.5.6 → x_transformers-2.6.1}/tests/test_x_transformers.py +25 -0
  4. {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/x_transformers.py +30 -2
  5. {x_transformers-2.5.6 → x_transformers-2.6.1}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.5.6 → x_transformers-2.6.1}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.5.6 → x_transformers-2.6.1}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.5.6 → x_transformers-2.6.1}/.gitignore +0 -0
  9. {x_transformers-2.5.6 → x_transformers-2.6.1}/LICENSE +0 -0
  10. {x_transformers-2.5.6 → x_transformers-2.6.1}/README.md +0 -0
  11. {x_transformers-2.5.6 → x_transformers-2.6.1}/data/README.md +0 -0
  12. {x_transformers-2.5.6 → x_transformers-2.6.1}/data/enwik8.gz +0 -0
  13. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/all-attention.png +0 -0
  14. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/deepnorm.png +0 -0
  17. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/fcm.png +0 -0
  23. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/ffglu.png +0 -0
  24. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/flash-attention.png +0 -0
  25. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/gate_values.png +0 -0
  26. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/gating.png +0 -0
  27. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/macaron-1.png +0 -0
  29. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/macaron-2.png +0 -0
  30. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/normformer.png +0 -0
  32. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/pia.png +0 -0
  33. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/resi_dual.png +0 -0
  35. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/residual_attn.png +0 -0
  36. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/rezero.png +0 -0
  37. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/rotary.png +0 -0
  38. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/sandwich.png +0 -0
  40. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/scalenorm.png +0 -0
  42. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/talking-heads.png +0 -0
  43. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/topk-attention.png +0 -0
  44. {x_transformers-2.5.6 → x_transformers-2.6.1}/images/xval.png +0 -0
  45. {x_transformers-2.5.6 → x_transformers-2.6.1}/train_belief_state.py +0 -0
  46. {x_transformers-2.5.6 → x_transformers-2.6.1}/train_copy.py +0 -0
  47. {x_transformers-2.5.6 → x_transformers-2.6.1}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.5.6 → x_transformers-2.6.1}/train_enwik8.py +0 -0
  49. {x_transformers-2.5.6 → x_transformers-2.6.1}/train_length_extrapolate.py +0 -0
  50. {x_transformers-2.5.6 → x_transformers-2.6.1}/train_parity.py +0 -0
  51. {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/__init__.py +0 -0
  52. {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/belief_state_wrapper.py +0 -0
  55. {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/continuous.py +0 -0
  56. {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/dpo.py +0 -0
  57. {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/entropy_based_tokenizer.py +0 -0
  58. {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/multi_input.py +0 -0
  59. {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/neo_mlp.py +0 -0
  60. {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/nonautoregressive_wrapper.py +0 -0
  61. {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/up_wrapper.py +0 -0
  62. {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  63. {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/xval.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.5.6
+Version: 2.6.1
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.5.6"
+version = "2.6.1"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_x_transformers.py
@@ -1210,3 +1210,28 @@ def test_prompts_given_as_list_tensor():
     ], 256)

     assert sampled.shape == (4, 256)
+
+def test_external_key_values():
+    from x_transformers import AutoregressiveWrapper
+
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 2,
+            heads = 8,
+            attn_dim_head = 16
+        )
+    )
+
+    seq = torch.randint(0, 20000, (3, 1024))
+
+    key_values = [
+        (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16)),
+        (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16)),
+    ]
+
+    additional_kv_mask = torch.randint(0, 2, (3, 32)).bool()
+
+    logits = model(seq, self_attn_additional_kv = key_values, additional_kv_mask = additional_kv_mask)
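The new test above is the only usage example in this release for the external key / value feature, so here is a minimal standalone sketch distilled from it. The shape convention (batch, heads, added_kv_len, dim_head) for the injected tensors and the expectation of one (key, value) tuple per self-attention layer are read off the test and the Attention changes below, not from documented API, so treat them as assumptions.

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 2,          # assumed: one (key, value) tuple per self-attention layer
            heads = 8,
            attn_dim_head = 16
        )
    )

    seq = torch.randint(0, 20000, (3, 1024))

    # externally computed memories, assumed shape (batch, heads, added_kv_len, dim_head)
    key_values = [
        (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16))
        for _ in range(2)
    ]

    # optional boolean mask over the 32 added positions, shape (batch, added_kv_len)
    additional_kv_mask = torch.randint(0, 2, (3, 32)).bool()

    logits = model(
        seq,
        self_attn_additional_kv = key_values,
        additional_kv_mask = additional_kv_mask
    )

    # logits has shape (batch, seq_len, num_tokens) -> (3, 1024, 20000)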
x_transformers/x_transformers.py
@@ -1617,7 +1617,9 @@ class Attention(Module):
         mem_mask = None,
         return_intermediates = False,
         cache: Intermediates | None = None,
-        value_residual = None
+        value_residual = None,
+        additional_key_values: tuple[Tensor, Tensor] | None = None,
+        additional_key_value_mask = None,
     ):
         b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals, is_multi_latent_attn = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals, self.use_latent_kv

@@ -1787,6 +1789,26 @@ class Attention(Module):
             if exists(input_mask):
                 input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)

+        # maybe append additional key / values
+
+        if exists(additional_key_values):
+            seq_len = k.shape[-2]
+
+            added_k, added_v = additional_key_values
+
+            k = cat((added_k, k), dim = -2)
+            v = cat((added_v, v), dim = -2)
+
+            if (exists(input_mask) or exists(additional_key_value_mask)):
+
+                if not exists(additional_key_value_mask):
+                    added_kv_len = added_k.shape[-2]
+                    input_mask = pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
+                elif not exists(input_mask):
+                    input_mask = pad_at_dim(additional_key_value_mask, (0, seq_len), dim = -1, value = True)
+                else:
+                    input_mask = cat((additional_key_value_mask, input_mask), dim = -1)
+
         # determine masking

         mask_value = max_neg_value(q)
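The added branch above handles three mask cases: only the existing input_mask is present (the injected positions are padded in as attendable), only additional_key_value_mask is present (the original key positions are padded in as attendable), or both are present (they are concatenated, added positions first, matching the key order). Below is a standalone sketch of the same decision logic, using torch.nn.functional.pad in place of the library's pad_at_dim helper; the function name combine_kv_masks is illustrative, not part of the library.

    import torch
    import torch.nn.functional as F

    def combine_kv_masks(input_mask, added_mask, added_len, seq_len):
        # returns a boolean mask over (added keys ++ original keys), or None if neither exists
        if input_mask is None and added_mask is None:
            return None
        if added_mask is None:
            # injected key / values are all attendable: pad True on the left
            return F.pad(input_mask, (added_len, 0), value = True)
        if input_mask is None:
            # original key positions are all attendable: pad True on the right
            return F.pad(added_mask, (0, seq_len), value = True)
        # both given: added-kv mask comes first, mirroring k = cat((added_k, k), dim = -2)
        return torch.cat((added_mask, input_mask), dim = -1)

    input_mask = torch.ones(3, 1024).bool()               # (batch, seq_len)
    added_mask = torch.randint(0, 2, (3, 32)).bool()      # (batch, added_kv_len)

    combined = combine_kv_masks(input_mask, added_mask, 32, 1024)
    assert combined.shape == (3, 32 + 1024)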
@@ -2411,6 +2433,8 @@ class AttentionLayers(Module):
         context_pos = None,
         attn_bias = None,
         deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
+        self_attn_additional_kv: list[tuple[Tensor, Tensor]] | None = None,
+        additional_kv_mask = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
         layers_execute_order: tuple[int, ...] | None = None
@@ -2520,6 +2544,10 @@

         iter_attn_cache = iter(attn_cache)

+        # additional self attn key / values
+
+        iter_self_attn_kv = iter(default(self_attn_additional_kv, ()))
+
         # handle deep embeds if needed

         deep_embeds = []
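The iterator set up above is how the per-layer loop later hands each self-attention block its own (key, value) tuple: default(self_attn_additional_kv, ()) falls back to an empty tuple when nothing was passed, so next(iter_self_attn_kv, None) yields None for every layer and the attention blocks skip the append path. A tiny illustration of that pattern in plain Python; the default helper is re-stated here only to keep the sketch self-contained.

    def default(val, d):
        # mirrors the usual x-transformers helper: fall back when val is None
        return val if val is not None else d

    # external key / values supplied: one tuple consumed per layer, in order
    it = iter(default([('k0', 'v0'), ('k1', 'v1')], ()))
    assert next(it, None) == ('k0', 'v0')
    assert next(it, None) == ('k1', 'v1')

    # nothing supplied: every layer sees None and leaves its keys / values untouched
    it = iter(default(None, ()))
    assert next(it, None) is None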
@@ -2647,7 +2675,7 @@
             # forward depending on layer type

             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':