x-transformers 2.6.6__tar.gz → 2.6.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-2.6.6 → x_transformers-2.6.7}/PKG-INFO +1 -1
- {x_transformers-2.6.6 → x_transformers-2.6.7}/pyproject.toml +1 -1
- {x_transformers-2.6.6 → x_transformers-2.6.7}/tests/test_x_transformers.py +35 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/attend.py +1 -1
- {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/x_transformers.py +30 -2
- {x_transformers-2.6.6 → x_transformers-2.6.7}/.github/FUNDING.yml +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/.gitignore +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/LICENSE +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/README.md +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/data/README.md +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/data/enwik8.gz +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/all-attention.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/attention-on-attention.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/deepnorm.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/fcm.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/ffglu.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/flash-attention.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/gate_values.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/gating.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/macaron-1.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/macaron-2.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/memory-transformer.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/normformer.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/pia.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/resi_dual.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/residual_attn.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/rezero.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/rotary.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/sandwich-2.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/sandwich.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/sandwich_norm.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/scalenorm.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/talking-heads.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/topk-attention.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/images/xval.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/train_belief_state.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/train_copy.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/train_enwik8.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/train_length_extrapolate.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/train_parity.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/__init__.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/continuous.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/dpo.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.6.7}/x_transformers/xval.py +0 -0
@@ -1252,3 +1252,38 @@ def test_learned_head_attn_sink():
|
|
1252
1252
|
seq = torch.randint(0, 20000, (3, 1024))
|
1253
1253
|
|
1254
1254
|
logits = model(seq)
|
1255
|
+
|
1256
|
+
def test_accept_layer_intermediates():
    """Feed one model's cached attention key/values into another Decoder.

    A TransformerWrapper (standing in for a VLM) is run with
    `return_intermediates = True`. The returned intermediates are passed
    directly as `self_attn_additional_kv` to a separate, deeper Decoder, with
    `detach_additional_kv = True` and a boolean `additional_kv_mask`. The
    receiving Decoder must still produce an output of the input's shape.
    """
    from x_transformers import TransformerWrapper, Decoder

    vlm = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 3,
            heads = 4,
        )
    )

    token_ids = torch.randint(0, 20000, (3, 1024))

    # random boolean mask over the vlm's 1024 positions, applied to the additional key / values
    mask = torch.randint(0, 2, (3, 1024)).bool()

    _, intermediates = vlm(token_ids, return_intermediates = True)

    action_model = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
    )

    # continuous input embeddings for the action model (distinct from the vlm token ids above)
    actions = torch.randn(3, 32, 512)

    embeds = action_model(
        actions,
        self_attn_additional_kv = intermediates,
        detach_additional_kv = True,
        additional_kv_mask = mask
    )

    assert embeds.shape == (3, 32, 512)
|
@@ -23,7 +23,7 @@ class Intermediates:
|
|
23
23
|
pre_softmax_attn: Tensor | None = None
|
24
24
|
post_softmax_attn: Tensor | None = None
|
25
25
|
values: Tensor | None = None
|
26
|
-
cached_kv:
|
26
|
+
cached_kv: tuple[Tensor, Tensor] | None = None
|
27
27
|
layer_type: str | None = None
|
28
28
|
hybrid_hidden: Tensor | None = None
|
29
29
|
|
@@ -10,7 +10,7 @@ import torch
|
|
10
10
|
from torch.amp import autocast
|
11
11
|
import torch.nn.functional as F
|
12
12
|
from torch import nn, einsum, tensor, Tensor, cat, stack, arange, is_tensor
|
13
|
-
from torch.utils._pytree import tree_flatten, tree_unflatten
|
13
|
+
from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
|
14
14
|
from torch.nn import Module, ModuleList, ModuleDict
|
15
15
|
|
16
16
|
from functools import partial, wraps
|
@@ -81,6 +81,9 @@ def cast_tuple(val, depth = 1):
|
|
81
81
|
def divisible_by(num, den):
    """Return whether `den` divides `num` with zero remainder."""
    remainder = num % den
    return remainder == 0
|
83
83
|
|
84
|
+
def detach_all(obj):
    """Map over the pytree `obj`, detaching every tensor leaf that requires grad.

    Leaves that are not tensors, and tensors with `requires_grad == False`,
    are passed through unchanged. The result is rebuilt by `tree_map`, so it
    has the same container structure as `obj`.
    """
    def _detach_leaf(leaf):
        # only tensors that track gradients need a detached copy
        needs_detach = is_tensor(leaf) and leaf.requires_grad
        if needs_detach:
            return leaf.detach()
        return leaf

    return tree_map(_detach_leaf, obj)
|
86
|
+
|
84
87
|
def maybe(fn = None):
|
85
88
|
if not exists(fn):
|
86
89
|
fn = identity
|
@@ -157,6 +160,19 @@ def or_reduce(masks):
|
|
157
160
|
head = head | rest
|
158
161
|
return head
|
159
162
|
|
163
|
+
# cache helpers
|
164
|
+
|
165
|
+
def get_cached_kvs(
    cache: LayerIntermediates
) -> list[tuple[Tensor, Tensor]]:
    """Collect the `cached_kv` entry of each attention intermediate in `cache`.

    Returns a new list with one entry per element of
    `cache.attn_intermediates`, in the same order.
    NOTE(review): entries may be None when an intermediate was recorded
    without a cached key / value pair (the field defaults to None) — confirm
    callers handle that.
    """
    return [intermediate.cached_kv for intermediate in cache.attn_intermediates]
|
175
|
+
|
160
176
|
# entropy
|
161
177
|
|
162
178
|
def calc_entropy(
|
@@ -2441,8 +2457,13 @@ class AttentionLayers(Module):
|
|
2441
2457
|
context_pos = None,
|
2442
2458
|
attn_bias = None,
|
2443
2459
|
deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
|
2444
|
-
self_attn_additional_kv:
|
2460
|
+
self_attn_additional_kv: (
|
2461
|
+
LayerIntermediates |
|
2462
|
+
list[tuple[Tensor, Tensor]]
|
2463
|
+
| None
|
2464
|
+
) = None,
|
2445
2465
|
additional_kv_mask = None,
|
2466
|
+
detach_additional_kv = False,
|
2446
2467
|
route_additional_kv_to_top = True,
|
2447
2468
|
condition = None,
|
2448
2469
|
in_attn_cond = None, # https://arxiv.org/abs/2105.04090
|
@@ -2590,6 +2611,13 @@ class AttentionLayers(Module):
|
|
2590
2611
|
# additional self attn key / values - say coming from vlm
|
2591
2612
|
|
2592
2613
|
if exists(self_attn_additional_kv) and route_additional_kv_to_top:
|
2614
|
+
|
2615
|
+
if isinstance(self_attn_additional_kv, LayerIntermediates):
|
2616
|
+
self_attn_additional_kv = get_cached_kvs(self_attn_additional_kv)
|
2617
|
+
|
2618
|
+
if detach_additional_kv:
|
2619
|
+
self_attn_additional_kv = detach_all(self_attn_additional_kv)
|
2620
|
+
|
2593
2621
|
num_self_attns = sum([layer_type == 'a' for layer_type in first(layer_variables)])
|
2594
2622
|
|
2595
2623
|
self_attn_additional_kv = self_attn_additional_kv[-num_self_attns:]
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|