x-transformers 2.6.0__tar.gz → 2.6.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {x_transformers-2.6.0 → x_transformers-2.6.2}/PKG-INFO +1 -1
  2. {x_transformers-2.6.0 → x_transformers-2.6.2}/pyproject.toml +1 -1
  3. {x_transformers-2.6.0 → x_transformers-2.6.2}/tests/test_x_transformers.py +4 -2
  4. {x_transformers-2.6.0 → x_transformers-2.6.2}/x_transformers/x_transformers.py +25 -9
  5. {x_transformers-2.6.0 → x_transformers-2.6.2}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.6.0 → x_transformers-2.6.2}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.6.0 → x_transformers-2.6.2}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.6.0 → x_transformers-2.6.2}/.gitignore +0 -0
  9. {x_transformers-2.6.0 → x_transformers-2.6.2}/LICENSE +0 -0
  10. {x_transformers-2.6.0 → x_transformers-2.6.2}/README.md +0 -0
  11. {x_transformers-2.6.0 → x_transformers-2.6.2}/data/README.md +0 -0
  12. {x_transformers-2.6.0 → x_transformers-2.6.2}/data/enwik8.gz +0 -0
  13. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/all-attention.png +0 -0
  14. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/deepnorm.png +0 -0
  17. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/fcm.png +0 -0
  23. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/ffglu.png +0 -0
  24. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/flash-attention.png +0 -0
  25. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/gate_values.png +0 -0
  26. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/gating.png +0 -0
  27. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/macaron-1.png +0 -0
  29. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/macaron-2.png +0 -0
  30. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/normformer.png +0 -0
  32. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/pia.png +0 -0
  33. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/resi_dual.png +0 -0
  35. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/residual_attn.png +0 -0
  36. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/rezero.png +0 -0
  37. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/rotary.png +0 -0
  38. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/sandwich.png +0 -0
  40. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/scalenorm.png +0 -0
  42. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/talking-heads.png +0 -0
  43. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/topk-attention.png +0 -0
  44. {x_transformers-2.6.0 → x_transformers-2.6.2}/images/xval.png +0 -0
  45. {x_transformers-2.6.0 → x_transformers-2.6.2}/train_belief_state.py +0 -0
  46. {x_transformers-2.6.0 → x_transformers-2.6.2}/train_copy.py +0 -0
  47. {x_transformers-2.6.0 → x_transformers-2.6.2}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.6.0 → x_transformers-2.6.2}/train_enwik8.py +0 -0
  49. {x_transformers-2.6.0 → x_transformers-2.6.2}/train_length_extrapolate.py +0 -0
  50. {x_transformers-2.6.0 → x_transformers-2.6.2}/train_parity.py +0 -0
  51. {x_transformers-2.6.0 → x_transformers-2.6.2}/x_transformers/__init__.py +0 -0
  52. {x_transformers-2.6.0 → x_transformers-2.6.2}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.6.0 → x_transformers-2.6.2}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.6.0 → x_transformers-2.6.2}/x_transformers/belief_state_wrapper.py +0 -0
  55. {x_transformers-2.6.0 → x_transformers-2.6.2}/x_transformers/continuous.py +0 -0
  56. {x_transformers-2.6.0 → x_transformers-2.6.2}/x_transformers/dpo.py +0 -0
  57. {x_transformers-2.6.0 → x_transformers-2.6.2}/x_transformers/entropy_based_tokenizer.py +0 -0
  58. {x_transformers-2.6.0 → x_transformers-2.6.2}/x_transformers/multi_input.py +0 -0
  59. {x_transformers-2.6.0 → x_transformers-2.6.2}/x_transformers/neo_mlp.py +0 -0
  60. {x_transformers-2.6.0 → x_transformers-2.6.2}/x_transformers/nonautoregressive_wrapper.py +0 -0
  61. {x_transformers-2.6.0 → x_transformers-2.6.2}/x_transformers/up_wrapper.py +0 -0
  62. {x_transformers-2.6.0 → x_transformers-2.6.2}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  63. {x_transformers-2.6.0 → x_transformers-2.6.2}/x_transformers/xval.py +0 -0
{x_transformers-2.6.0 → x_transformers-2.6.2}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.6.0
+Version: 2.6.2
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
{x_transformers-2.6.0 → x_transformers-2.6.2}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.6.0"
+version = "2.6.2"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.6.0 → x_transformers-2.6.2}/tests/test_x_transformers.py
@@ -1219,7 +1219,7 @@ def test_external_key_values():
         max_seq_len = 1024,
         attn_layers = Decoder(
             dim = 512,
-            depth = 2,
+            depth = 3,
             heads = 8,
             attn_dim_head = 16
         )
@@ -1232,4 +1232,6 @@ def test_external_key_values():
         (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16)),
     ]
 
-    logits = model(seq, self_attn_additional_kv = key_values)
+    additional_kv_mask = torch.randint(0, 2, (3, 32)).bool()
+
+    logits = model(seq, self_attn_additional_kv = key_values, additional_kv_mask = additional_kv_mask)
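Taken together, the test now exercises the new `additional_kv_mask` argument end to end. Below is a minimal, self-contained sketch of that call pattern; the `num_tokens` value, the token sequence shape, and the number of supplied key / value pairs are illustrative assumptions rather than values copied from the test file.

import torch
from x_transformers import TransformerWrapper, Decoder

# decoder configured as in the updated test
model = TransformerWrapper(
    num_tokens = 256,        # assumed vocabulary size
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 3,
        heads = 8,
        attn_dim_head = 16
    )
)

seq = torch.randint(0, 256, (3, 1024))  # (batch, seq len) - assumed shape

# externally produced key / values, one tuple per targeted self-attention layer
# each tensor is (batch, heads, added kv length, dim head)
key_values = [
    (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16)),
    (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16)),
]

# boolean mask over the 32 added key / value positions
additional_kv_mask = torch.randint(0, 2, (3, 32)).bool()

logits = model(seq, self_attn_additional_kv = key_values, additional_kv_mask = additional_kv_mask)

With `depth = 3` and only two pairs supplied, the new `route_additional_kv_to_top` default (see the x_transformers.py hunks below) aligns them with the top two self-attention layers and leaves the first layer without extra key / values.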
{x_transformers-2.6.0 → x_transformers-2.6.2}/x_transformers/x_transformers.py
@@ -1618,7 +1618,8 @@ class Attention(Module):
         return_intermediates = False,
         cache: Intermediates | None = None,
         value_residual = None,
-        additional_key_values: tuple[Tensor, Tensor] | None = None
+        additional_key_values: tuple[Tensor, Tensor] | None = None,
+        additional_key_value_mask = None,
     ):
         b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals, is_multi_latent_attn = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals, self.use_latent_kv
 
@@ -1791,15 +1792,22 @@ class Attention(Module):
         # maybe append additional key / values
 
         if exists(additional_key_values):
+            seq_len = k.shape[-2]
 
             added_k, added_v = additional_key_values
-            added_kv_len = added_k.shape[-2]
 
             k = cat((added_k, k), dim = -2)
             v = cat((added_v, v), dim = -2)
 
-            if exists(input_mask):
-                input_mask = pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
+            if (exists(input_mask) or exists(additional_key_value_mask)):
+
+                if not exists(additional_key_value_mask):
+                    added_kv_len = added_k.shape[-2]
+                    input_mask = pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
+                elif not exists(input_mask):
+                    input_mask = pad_at_dim(additional_key_value_mask, (0, seq_len), dim = -1, value = True)
+                else:
+                    input_mask = cat((additional_key_value_mask, input_mask), dim = -1)
 
         # determine masking
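The second hunk above handles three cases when external key / values are appended: only the regular `input_mask` exists, only the new `additional_key_value_mask` exists, or both. The sketch below restates that branching outside the library, using plain `torch.cat` with an all-True pad in place of the library's `pad_at_dim`; the helper name and shapes are hypothetical.

import torch

def merge_key_padding_masks(input_mask, additional_kv_mask, added_kv_len, seq_len, batch):
    # returns a key padding mask over (added kv positions ++ regular sequence positions)
    if input_mask is None and additional_kv_mask is None:
        return None

    if additional_kv_mask is None:
        # only the sequence mask exists: all added kv positions remain attendable
        pad = torch.ones(batch, added_kv_len, dtype = torch.bool)
        return torch.cat((pad, input_mask), dim = -1)

    if input_mask is None:
        # only the added kv mask exists: all sequence positions remain attendable
        pad = torch.ones(batch, seq_len, dtype = torch.bool)
        return torch.cat((additional_kv_mask, pad), dim = -1)

    # both exist: concatenate along the key dimension, added kv positions first
    return torch.cat((additional_kv_mask, input_mask), dim = -1)

# example: batch of 3, 32 added kv positions, sequence length of 128
input_mask = torch.ones(3, 128, dtype = torch.bool)
additional_kv_mask = torch.randint(0, 2, (3, 32)).bool()

merged = merge_key_padding_masks(input_mask, additional_kv_mask, 32, 128, 3)
assert merged.shape == (3, 160)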
 
@@ -2426,6 +2434,8 @@ class AttentionLayers(Module):
         attn_bias = None,
         deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
         self_attn_additional_kv: list[tuple[Tensor, Tensor]] | None = None,
+        additional_kv_mask = None,
+        route_additional_kv_to_top = True,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
         layers_execute_order: tuple[int, ...] | None = None
@@ -2535,10 +2545,6 @@
 
         iter_attn_cache = iter(attn_cache)
 
-        # additional self attn key / values
-
-        iter_self_attn_kv = iter(default(self_attn_additional_kv, ()))
-
         # handle deep embeds if needed
 
         deep_embeds = []
@@ -2573,6 +2579,16 @@
         layers_execute_order = default(layers_execute_order, self.layers_execute_order)
         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
 
+        # additional self attn key / values - say coming from vlm
+
+        if exists(self_attn_additional_kv) and route_additional_kv_to_top:
+            num_self_attns = sum([layer_type == 'a' for layer_type in first(layer_variables)])
+
+            self_attn_additional_kv = self_attn_additional_kv[-num_self_attns:]
+            self_attn_additional_kv = [None] * (num_self_attns - len(self_attn_additional_kv)) + self_attn_additional_kv
+
+        iter_self_attn_kv = iter(default(self_attn_additional_kv, ()))
+
         # derived input for reinjection if needed
 
         inp_inject = None
@@ -2666,7 +2682,7 @@
             # forward depending on layer type
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
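The new `route_additional_kv_to_top` flag (default `True`) lets a caller pass fewer key / value pairs than there are self-attention layers and have them attach to the top-most layers, e.g. features coming from a VLM. The sketch below isolates just that alignment step; the layer layout and the placeholder pair values are assumptions for illustration.

# alignment performed in AttentionLayers.forward when route_additional_kv_to_top = True

layer_types = ('a', 'f', 'a', 'f', 'a', 'f')           # assumed attention / feedforward layout
additional_kv = ['kv_for_layer_2', 'kv_for_layer_3']   # fewer pairs than self-attention layers

num_self_attns = sum(layer_type == 'a' for layer_type in layer_types)  # 3 in this layout

# keep at most num_self_attns pairs, then left-pad with None so the supplied
# pairs line up with the top-most self-attention layers
additional_kv = additional_kv[-num_self_attns:]
additional_kv = [None] * (num_self_attns - len(additional_kv)) + additional_kv

assert additional_kv == [None, 'kv_for_layer_2', 'kv_for_layer_3']

# each self-attention layer then draws the next entry from this iterator,
# receiving None (no extra key / values) for the untouched lower layers
iter_self_attn_kv = iter(additional_kv)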