x-transformers 2.6.5__tar.gz → 2.6.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-2.6.5 → x_transformers-2.6.7}/PKG-INFO +1 -1
- {x_transformers-2.6.5 → x_transformers-2.6.7}/pyproject.toml +1 -1
- {x_transformers-2.6.5 → x_transformers-2.6.7}/tests/test_x_transformers.py +36 -1
- {x_transformers-2.6.5 → x_transformers-2.6.7}/x_transformers/attend.py +9 -10
- {x_transformers-2.6.5 → x_transformers-2.6.7}/x_transformers/x_transformers.py +32 -4
- {x_transformers-2.6.5 → x_transformers-2.6.7}/.github/FUNDING.yml +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/.gitignore +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/LICENSE +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/README.md +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/data/README.md +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/data/enwik8.gz +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/all-attention.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/attention-on-attention.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/deepnorm.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/fcm.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/ffglu.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/flash-attention.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/gate_values.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/gating.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/macaron-1.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/macaron-2.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/memory-transformer.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/normformer.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/pia.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/resi_dual.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/residual_attn.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/rezero.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/rotary.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/sandwich-2.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/sandwich.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/sandwich_norm.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/scalenorm.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/talking-heads.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/topk-attention.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/images/xval.png +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/train_belief_state.py +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/train_copy.py +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/train_enwik8.py +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/train_length_extrapolate.py +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/train_parity.py +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/x_transformers/__init__.py +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/x_transformers/continuous.py +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/x_transformers/dpo.py +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.6.5 → x_transformers-2.6.7}/x_transformers/xval.py +0 -0
tests/test_x_transformers.py

@@ -1245,10 +1245,45 @@ def test_learned_head_attn_sink():
             dim = 512,
             depth = 12,
             heads = 8,
-
+            attn_head_learned_sink = True
         )
     )

     seq = torch.randint(0, 20000, (3, 1024))

     logits = model(seq)
+
+def test_accept_layer_intermediates():
+    from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper
+
+    vlm = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 3,
+            heads = 4,
+        )
+    )
+
+    seq = torch.randint(0, 20000, (3, 1024))
+    mask = torch.randint(0, 2, (3, 1024)).bool()
+
+    _, intermediates = vlm(seq, return_intermediates = True)
+
+    action_model = Decoder(
+        dim = 512,
+        depth = 6,
+        heads = 8,
+    )
+
+    seq = torch.randn(3, 32, 512)
+
+    embeds = action_model(
+        seq,
+        self_attn_additional_kv = intermediates,
+        detach_additional_kv = True,
+        additional_kv_mask = mask
+    )
+
+    assert embeds.shape == (3, 32, 512)
x_transformers/attend.py

@@ -23,7 +23,7 @@ class Intermediates:
     pre_softmax_attn: Tensor | None = None
     post_softmax_attn: Tensor | None = None
     values: Tensor | None = None
-    cached_kv:
+    cached_kv: tuple[Tensor, Tensor] | None = None
     layer_type: str | None = None
     hybrid_hidden: Tensor | None = None

@@ -176,7 +176,7 @@ class Attend(Module):
         softclamp_logits = False,
         logit_softclamp_value = 50.,
         add_zero_kv = False,
-
+        head_learned_sink = False,
         selective = False,
         hard = False,
         cope = None,
@@ -257,10 +257,10 @@ class Attend(Module):

         # learned sink concatted pre-softmax, working solution from gpt-oss

-
-        assert not (self.has_head_learned_sinks and flash), f'not supported for flash attention yet'
+        assert not (head_learned_sink and flash), f'not supported for flash attention yet'

-        self.
+        self.head_learned_sink = head_learned_sink
+        self.head_attn_sink = Parameter(torch.zeros(heads)) if head_learned_sink else None

         # soft clamp attention logit value

@@ -517,10 +517,9 @@ class Attend(Module):
         if self.selective:
             sim = selective_attn(sim)

-        if self.
+        if self.head_learned_sink:
             # add learned attention sink
-
-            attn_sink = repeat(self.head_attn_sinks, 'h sinks -> b h i sinks', b = sim.shape[0], i = sim.shape[2])
+            attn_sink = repeat(self.head_attn_sink, 'h -> b h i 1', b = sim.shape[0], i = sim.shape[2])
             sim = cat((attn_sink, sim), dim = -1)

         pre_softmax_attn = sim
@@ -531,9 +530,9 @@ class Attend(Module):

         post_softmax_attn = attn

-        if self.
+        if self.head_learned_sink:
             # remove attention sink
-            attn = attn[...,
+            attn = attn[..., 1:]

         attn = self.attn_dropout(attn)

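To make the two attend.py hunks above easier to follow: the learned sink is now a single logit per head, concatenated as an extra column before the softmax and sliced off afterwards, so a head can shed attention mass without assigning it to any real token. A minimal standalone sketch of that arithmetic (illustration only — plain tensors stand in for the library's `Parameter` and `Attend` internals):

```python
# minimal sketch of the per-head learned attention sink from this diff
import torch
from einops import repeat

b, h, i, j = 2, 8, 16, 16

sim = torch.randn(b, h, i, j)      # attention logits
head_attn_sink = torch.zeros(h)    # one learned logit per head (an nn.Parameter in the library)

# concat the sink logit as an extra "column" before the softmax
sink = repeat(head_attn_sink, 'h -> b h i 1', b = b, i = i)
sim_with_sink = torch.cat((sink, sim), dim = -1)   # (b, h, i, j + 1)

attn = sim_with_sink.softmax(dim = -1)

# drop the sink column afterwards - each row now sums to <= 1,
# so probability mass can flow to the sink instead of real tokens
attn = attn[..., 1:]                               # (b, h, i, j)

assert attn.shape == (b, h, i, j)
assert (attn.sum(dim = -1) <= 1 + 1e-6).all()
```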
x_transformers/x_transformers.py

@@ -10,7 +10,7 @@ import torch
 from torch.amp import autocast
 import torch.nn.functional as F
 from torch import nn, einsum, tensor, Tensor, cat, stack, arange, is_tensor
-from torch.utils._pytree import tree_flatten, tree_unflatten
+from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
 from torch.nn import Module, ModuleList, ModuleDict

 from functools import partial, wraps
@@ -81,6 +81,9 @@ def cast_tuple(val, depth = 1):
 def divisible_by(num, den):
     return (num % den) == 0

+def detach_all(obj):
+    return tree_map(lambda t: t.detach() if is_tensor(t) and t.requires_grad else t, obj)
+
 def maybe(fn = None):
     if not exists(fn):
         fn = identity
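The new `detach_all` helper leans on `torch.utils._pytree.tree_map`, so it detaches tensors inside arbitrarily nested containers, e.g. a list of cached `(k, v)` tuples. A small behavior sketch (the helper is copied inline rather than imported, since it is an internal module-level function and may not be re-exported):

```python
# behavior sketch for detach_all: tree_map walks a nested structure and
# detaches every tensor that currently carries gradients
import torch
from torch import is_tensor
from torch.utils._pytree import tree_map

def detach_all(obj):
    return tree_map(lambda t: t.detach() if is_tensor(t) and t.requires_grad else t, obj)

cached_kvs = [
    (torch.randn(2, 8, 16, 64, requires_grad = True),   # keys
     torch.randn(2, 8, 16, 64, requires_grad = True))   # values
    for _ in range(3)
]

detached = detach_all(cached_kvs)

assert all(not k.requires_grad and not v.requires_grad for k, v in detached)
```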
@@ -157,6 +160,19 @@ def or_reduce(masks):
         head = head | rest
     return head

+# cache helpers
+
+def get_cached_kvs(
+    cache: LayerIntermediates
+) -> list[tuple[Tensor, Tensor]]:
+
+    cached_kvs = []
+
+    for attn_intermediate in cache.attn_intermediates:
+        cached_kvs.append(attn_intermediate.cached_kv)
+
+    return cached_kvs
+
 # entropy

 def calc_entropy(
@@ -1319,7 +1335,7 @@ class Attention(Module):
         value_dim_head = None,
         dim_out = None,
         add_zero_kv = False, # same as add_zero_attn in pytorch
-
+        head_learned_sink = False,
         rotate_num_heads = None,
         data_dependent_alibi = False,
         data_dependent_alibi_per_row = False,
@@ -1516,7 +1532,7 @@ class Attention(Module):
             selective = selective,
             custom_attn_fn = custom_attn_fn,
             add_zero_kv = add_zero_kv,
-
+            head_learned_sink = head_learned_sink,
             flash = flash,
             softclamp_logits = softclamp_logits,
             logit_softclamp_value = logit_softclamp_value,
@@ -2441,8 +2457,13 @@ class AttentionLayers(Module):
         context_pos = None,
         attn_bias = None,
         deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
-        self_attn_additional_kv:
+        self_attn_additional_kv: (
+            LayerIntermediates |
+            list[tuple[Tensor, Tensor]]
+            | None
+        ) = None,
         additional_kv_mask = None,
+        detach_additional_kv = False,
         route_additional_kv_to_top = True,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
@@ -2590,6 +2611,13 @@ class AttentionLayers(Module):
         # additional self attn key / values - say coming from vlm

         if exists(self_attn_additional_kv) and route_additional_kv_to_top:
+
+            if isinstance(self_attn_additional_kv, LayerIntermediates):
+                self_attn_additional_kv = get_cached_kvs(self_attn_additional_kv)
+
+            if detach_additional_kv:
+                self_attn_additional_kv = detach_all(self_attn_additional_kv)
+
             num_self_attns = sum([layer_type == 'a' for layer_type in first(layer_variables)])

             self_attn_additional_kv = self_attn_additional_kv[-num_self_attns:]
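Taken together with the new test near the top of this diff, the feature can be exercised as sketched below. The first call mirrors the test; the second call, passing an explicit list of `(k, v)` tuples, is an assumption read off the widened type annotation rather than something the test covers:

```python
# one model's per-layer cached key/values are routed into another decoder's
# self attention, detached so no gradients flow back into the source model
import torch
from x_transformers import TransformerWrapper, Decoder

vlm = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 3, heads = 4)
)

seq = torch.randint(0, 20000, (3, 1024))
mask = torch.randint(0, 2, (3, 1024)).bool()

_, intermediates = vlm(seq, return_intermediates = True)

action_model = Decoder(dim = 512, depth = 6, heads = 8)

embeds = action_model(
    torch.randn(3, 32, 512),
    self_attn_additional_kv = intermediates,   # LayerIntermediates now accepted directly
    detach_additional_kv = True,               # new flag, applies detach_all to the cached kv
    additional_kv_mask = mask
)

# an explicit list of (k, v) tuples per attention layer should also be accepted,
# per the new annotation and the get_cached_kvs helper
cached_kvs = [inter.cached_kv for inter in intermediates.attn_intermediates]

embeds = action_model(
    torch.randn(3, 32, 512),
    self_attn_additional_kv = cached_kvs,
    additional_kv_mask = mask
)

assert embeds.shape == (3, 32, 512)
```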