x-transformers 2.6.6__tar.gz → 2.6.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {x_transformers-2.6.6 → x_transformers-2.6.7}/PKG-INFO +1 -1
  2. {x_transformers-2.6.6 → x_transformers-2.6.7}/pyproject.toml +1 -1
  3. {x_transformers-2.6.6 → x_transformers-2.6.7}/tests/test_x_transformers.py +35 -0
  4. {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/attend.py +1 -1
  5. {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/x_transformers.py +30 -2
  6. {x_transformers-2.6.6 → x_transformers-2.6.7}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.6.6 → x_transformers-2.6.7}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.6.6 → x_transformers-2.6.7}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.6.6 → x_transformers-2.6.7}/.gitignore +0 -0
  10. {x_transformers-2.6.6 → x_transformers-2.6.7}/LICENSE +0 -0
  11. {x_transformers-2.6.6 → x_transformers-2.6.7}/README.md +0 -0
  12. {x_transformers-2.6.6 → x_transformers-2.6.7}/data/README.md +0 -0
  13. {x_transformers-2.6.6 → x_transformers-2.6.7}/data/enwik8.gz +0 -0
  14. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/all-attention.png +0 -0
  15. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/deepnorm.png +0 -0
  18. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/fcm.png +0 -0
  24. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/ffglu.png +0 -0
  25. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/flash-attention.png +0 -0
  26. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/gate_values.png +0 -0
  27. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/gating.png +0 -0
  28. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/macaron-1.png +0 -0
  30. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/macaron-2.png +0 -0
  31. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/normformer.png +0 -0
  33. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/pia.png +0 -0
  34. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/resi_dual.png +0 -0
  36. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/residual_attn.png +0 -0
  37. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/rezero.png +0 -0
  38. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/rotary.png +0 -0
  39. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/sandwich.png +0 -0
  41. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/scalenorm.png +0 -0
  43. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/talking-heads.png +0 -0
  44. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/topk-attention.png +0 -0
  45. {x_transformers-2.6.6 → x_transformers-2.6.7}/images/xval.png +0 -0
  46. {x_transformers-2.6.6 → x_transformers-2.6.7}/train_belief_state.py +0 -0
  47. {x_transformers-2.6.6 → x_transformers-2.6.7}/train_copy.py +0 -0
  48. {x_transformers-2.6.6 → x_transformers-2.6.7}/train_entropy_tokenizer.py +0 -0
  49. {x_transformers-2.6.6 → x_transformers-2.6.7}/train_enwik8.py +0 -0
  50. {x_transformers-2.6.6 → x_transformers-2.6.7}/train_length_extrapolate.py +0 -0
  51. {x_transformers-2.6.6 → x_transformers-2.6.7}/train_parity.py +0 -0
  52. {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/__init__.py +0 -0
  53. {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/belief_state_wrapper.py +0 -0
  55. {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/continuous.py +0 -0
  56. {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/dpo.py +0 -0
  57. {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/entropy_based_tokenizer.py +0 -0
  58. {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/multi_input.py +0 -0
  59. {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/neo_mlp.py +0 -0
  60. {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/nonautoregressive_wrapper.py +0 -0
  61. {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/up_wrapper.py +0 -0
  62. {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  63. {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/xval.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.6.6
+Version: 2.6.7
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.6.6"
+version = "2.6.7"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_x_transformers.py
@@ -1252,3 +1252,38 @@ def test_learned_head_attn_sink():
     seq = torch.randint(0, 20000, (3, 1024))
 
     logits = model(seq)
+
+def test_accept_layer_intermediates():
+    from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper
+
+    vlm = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 3,
+            heads = 4,
+        )
+    )
+
+    seq = torch.randint(0, 20000, (3, 1024))
+    mask = torch.randint(0, 2, (3, 1024)).bool()
+
+    _, intermediates = vlm(seq, return_intermediates = True)
+
+    action_model = Decoder(
+        dim = 512,
+        depth = 6,
+        heads = 8,
+    )
+
+    seq = torch.randn(3, 32, 512)
+
+    embeds = action_model(
+        seq,
+        self_attn_additional_kv = intermediates,
+        detach_additional_kv = True,
+        additional_kv_mask = mask
+    )
+
+    assert embeds.shape == (3, 32, 512)
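For orientation, a hedged sketch of what the new test passes around, using only fields visible elsewhere in this diff (LayerIntermediates.attn_intermediates and Intermediates.cached_kv); the tensor layout in the comment is an assumption, not the library's documented contract.

# hedged sketch: what return_intermediates = True hands back, per the fields in this diff
logits, intermediates = vlm(seq, return_intermediates = True)

for attn_inter in intermediates.attn_intermediates:
    k, v = attn_inter.cached_kv      # per-layer cached key / value tensors
    print(k.shape, v.shape)          # assumed layout, e.g. (batch, heads, seq_len, dim_head)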
x_transformers/attend.py
@@ -23,7 +23,7 @@ class Intermediates:
     pre_softmax_attn: Tensor | None = None
     post_softmax_attn: Tensor | None = None
     values: Tensor | None = None
-    cached_kv: Tuple[Tensor, Tensor] | None = None
+    cached_kv: tuple[Tensor, Tensor] | None = None
     layer_type: str | None = None
     hybrid_hidden: Tensor | None = None
 
x_transformers/x_transformers.py
@@ -10,7 +10,7 @@ import torch
 from torch.amp import autocast
 import torch.nn.functional as F
 from torch import nn, einsum, tensor, Tensor, cat, stack, arange, is_tensor
-from torch.utils._pytree import tree_flatten, tree_unflatten
+from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
 from torch.nn import Module, ModuleList, ModuleDict
 
 from functools import partial, wraps
@@ -81,6 +81,9 @@ def cast_tuple(val, depth = 1):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def detach_all(obj):
+    return tree_map(lambda t: t.detach() if is_tensor(t) and t.requires_grad else t, obj)
+
 def maybe(fn = None):
     if not exists(fn):
         fn = identity
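For context, a minimal standalone sketch of the new detach_all helper in plain PyTorch: tree_map walks any nested structure of lists, tuples and dicts and applies the function to each leaf, so cached key/value tensors are cut from the autograd graph while non-tensor leaves pass through untouched. The example data below is illustrative only.

import torch
from torch.utils._pytree import tree_map

def detach_all(obj):
    # detach only tensors that carry gradients; leave everything else as-is
    return tree_map(lambda t: t.detach() if torch.is_tensor(t) and t.requires_grad else t, obj)

# illustrative nested structure: one (key, value) pair per attention layer
kvs = [(torch.randn(2, 8, 64, requires_grad = True), torch.randn(2, 8, 64, requires_grad = True))]
detached = detach_all(kvs)

assert not detached[0][0].requires_grad   # no gradients flow back into the cached keys / values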
@@ -157,6 +160,19 @@ def or_reduce(masks):
         head = head | rest
     return head
 
+# cache helpers
+
+def get_cached_kvs(
+    cache: LayerIntermediates
+) -> list[tuple[Tensor, Tensor]]:
+
+    cached_kvs = []
+
+    for attn_intermediate in cache.attn_intermediates:
+        cached_kvs.append(attn_intermediate.cached_kv)
+
+    return cached_kvs
+
 # entropy
 
 def calc_entropy(
@@ -2441,8 +2457,13 @@ class AttentionLayers(Module):
         context_pos = None,
         attn_bias = None,
         deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
-        self_attn_additional_kv: list[tuple[Tensor, Tensor]] | None = None,
+        self_attn_additional_kv: (
+            LayerIntermediates |
+            list[tuple[Tensor, Tensor]]
+            | None
+        ) = None,
         additional_kv_mask = None,
+        detach_additional_kv = False,
         route_additional_kv_to_top = True,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
@@ -2590,6 +2611,13 @@
         # additional self attn key / values - say coming from vlm
 
         if exists(self_attn_additional_kv) and route_additional_kv_to_top:
+
+            if isinstance(self_attn_additional_kv, LayerIntermediates):
+                self_attn_additional_kv = get_cached_kvs(self_attn_additional_kv)
+
+            if detach_additional_kv:
+                self_attn_additional_kv = detach_all(self_attn_additional_kv)
+
             num_self_attns = sum([layer_type == 'a' for layer_type in first(layer_variables)])
 
             self_attn_additional_kv = self_attn_additional_kv[-num_self_attns:]
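Taken together with the test added above, the convenient path is to pass the LayerIntermediates object directly along with detach_additional_kv = True. A roughly equivalent manual path, reusing the variables from that test and hedged because importing these helpers from x_transformers/x_transformers.py is an implementation detail rather than a documented API, would be:

from x_transformers.x_transformers import get_cached_kvs, detach_all

cached_kvs = get_cached_kvs(intermediates)   # list of per-layer (key, value) tensors
cached_kvs = detach_all(cached_kvs)          # same effect as detach_additional_kv = True

embeds = action_model(
    seq,
    self_attn_additional_kv = cached_kvs,    # the plain list form was already accepted before 2.6.7
    additional_kv_mask = mask
)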