x-transformers 2.6.2__tar.gz → 2.6.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-2.6.2 → x_transformers-2.6.4}/PKG-INFO +11 -1
- {x_transformers-2.6.2 → x_transformers-2.6.4}/README.md +10 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/pyproject.toml +1 -1
- {x_transformers-2.6.2 → x_transformers-2.6.4}/tests/test_x_transformers.py +19 -2
- {x_transformers-2.6.2 → x_transformers-2.6.4}/x_transformers/attend.py +21 -4
- {x_transformers-2.6.2 → x_transformers-2.6.4}/x_transformers/x_transformers.py +9 -1
- {x_transformers-2.6.2 → x_transformers-2.6.4}/.github/FUNDING.yml +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/.gitignore +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/LICENSE +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/data/README.md +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/data/enwik8.gz +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/all-attention.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/attention-on-attention.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/deepnorm.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/fcm.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/ffglu.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/flash-attention.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/gate_values.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/gating.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/macaron-1.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/macaron-2.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/memory-transformer.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/normformer.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/pia.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/resi_dual.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/residual_attn.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/rezero.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/rotary.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/sandwich-2.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/sandwich.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/sandwich_norm.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/scalenorm.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/talking-heads.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/topk-attention.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/images/xval.png +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/train_belief_state.py +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/train_copy.py +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/train_enwik8.py +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/train_length_extrapolate.py +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/train_parity.py +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/x_transformers/__init__.py +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/x_transformers/continuous.py +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/x_transformers/dpo.py +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.6.2 → x_transformers-2.6.4}/x_transformers/xval.py +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.6.2
+Version: 2.6.4
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2507,4 +2507,14 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```

+```bibtex
+@misc{openai_gpt_oss,
+    author = {OpenAI},
+    title = {Introducing gpt-oss},
+    howpublished = {https://openai.com/index/introducing-gpt-oss},
+    month = {August},
+    year = {2025}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
README.md

@@ -2459,4 +2459,14 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```

+```bibtex
+@misc{openai_gpt_oss,
+    author = {OpenAI},
+    title = {Introducing gpt-oss},
+    howpublished = {https://openai.com/index/introducing-gpt-oss},
+    month = {August},
+    year = {2025}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
tests/test_x_transformers.py

@@ -1228,10 +1228,27 @@ def test_external_key_values():
     seq = torch.randint(0, 20000, (3, 1024))

     key_values = [
-        (torch.randn(3,
-        (torch.randn(3,
+        (torch.randn(3, 2, 32, 16), torch.randn(3, 2, 32, 16)),
+        (torch.randn(3, 2, 32, 16), torch.randn(3, 2, 32, 16)),
     ]

     additional_kv_mask = torch.randint(0, 2, (3, 32)).bool()

     logits = model(seq, self_attn_additional_kv = key_values, additional_kv_mask = additional_kv_mask)
+
+def test_learned_head_attn_sink():
+
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 12,
+            heads = 8,
+            attn_head_learned_sink = True
+        )
+    )
+
+    seq = torch.randint(0, 20000, (3, 1024))
+
+    logits = model(seq)
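The new test above doubles as the usage pattern for the feature this release adds: a per-head learned attention sink, switched on through the `attn_` prefixed kwarg on the attention layers. A minimal sketch (model sizes here are illustrative, not taken from the package's examples):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

# `attn_` prefixed kwargs on Decoder are routed to each Attention layer,
# so `attn_head_learned_sink = True` enables the learned sink added in this release
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_head_learned_sink = True
    )
)

seq = torch.randint(0, 20000, (2, 256))
logits = model(seq)   # (2, 256, 20000)
```

Per the assert added in attend.py below, the sink is not yet supported together with flash attention, so it cannot be combined with `attn_flash = True` in this release.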
x_transformers/attend.py

@@ -4,8 +4,8 @@ from functools import partial
 from typing import Tuple, Callable

 import torch
-from torch.nn import Module
-from torch import nn, einsum, Tensor
+from torch.nn import Module, Parameter
+from torch import cat, nn, einsum, Tensor
 import torch.nn.functional as F

 from collections import namedtuple
@@ -176,6 +176,7 @@ class Attend(Module):
         softclamp_logits = False,
         logit_softclamp_value = 50.,
         add_zero_kv = False,
+        head_learned_sink = False,
         selective = False,
         hard = False,
         cope = None,
@@ -254,6 +255,13 @@ class Attend(Module):

         self.add_zero_kv = add_zero_kv

+        # learned sink concatted pre-softmax, working solution from gpt-oss
+
+        assert not (head_learned_sink and flash), f'not supported for flash attention yet'
+
+        self.head_learned_sink = head_learned_sink
+        self.head_attn_sink = Parameter(torch.zeros(heads)) if head_learned_sink else None
+
         # soft clamp attention logit value

         if softclamp_logits:
@@ -315,10 +323,10 @@ class Attend(Module):
         if self.l2_distance:
             k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
             k = F.pad(k, (0, 1), value = -1.)
-            k = torch.cat((k, k_norm_sq), dim = -1)
+            k = cat((k, k_norm_sq), dim = -1)

             q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
-            q = torch.cat((2 * q, q_norm_sq), dim = -1)
+            q = cat((2 * q, q_norm_sq), dim = -1)
             q = F.pad(q, (0, 1), value = -1.)

         # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
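Aside from the `torch.cat` → `cat` cleanup, the lines above are the standard trick of expressing squared L2 distance as a dot product of augmented vectors: with k' = [k, -1, ‖k‖²] and q' = [2q, ‖q‖², -1], the product q'·k' = 2q·k - ‖q‖² - ‖k‖² = -‖q - k‖². A quick standalone check of that identity (not package code):

```python
import torch
import torch.nn.functional as F

q = torch.randn(4, 8)
k = torch.randn(4, 8)

# augment k -> [k, -1, |k|^2], mirroring the lines in attend.py above
k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
k_aug = F.pad(k, (0, 1), value = -1.)
k_aug = torch.cat((k_aug, k_norm_sq), dim = -1)

# augment q -> [2q, |q|^2, -1]
q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
q_aug = torch.cat((2 * q, q_norm_sq), dim = -1)
q_aug = F.pad(q_aug, (0, 1), value = -1.)

sim = q_aug @ k_aug.t()            # 2 q.k - |q|^2 - |k|^2
target = -torch.cdist(q, k) ** 2   # -|q - k|^2

assert torch.allclose(sim, target, atol = 1e-4)
```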
@@ -509,6 +517,11 @@ class Attend(Module):
         if self.selective:
             sim = selective_attn(sim)

+        if self.head_learned_sink:
+            # add learned attention sink
+            attn_sink = repeat(self.head_attn_sink, 'h -> b h i 1', b = sim.shape[0], i = sim.shape[2])
+            sim = cat((attn_sink, sim), dim = -1)
+
         pre_softmax_attn = sim

         attn = self.attn_fn(sim)
@@ -517,6 +530,10 @@ class Attend(Module):

         post_softmax_attn = attn

+        if self.head_learned_sink:
+            # remove attention sink
+            attn = attn[..., 1:]
+
         attn = self.attn_dropout(attn)

         if exists(self.post_softmax_talking_heads):
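Taken together, the three attend.py hunks implement the learned attention sink: each head owns one learned logit that is prepended as an extra "key" column before the softmax and dropped afterwards, so a head can park probability mass on the sink instead of being forced to spread all of it over real positions (the gpt-oss approach referenced in the code comment and the new README citation). A self-contained sketch of the mechanic outside the Attend class:

```python
import torch
from torch import cat
from torch.nn import Parameter
from einops import repeat

heads = 8
head_attn_sink = Parameter(torch.zeros(heads))    # one learned sink logit per head

# raw attention logits: (batch, heads, query length, key length)
sim = torch.randn(2, heads, 16, 16)

# prepend the per-head sink logit as an extra key column before the softmax
attn_sink = repeat(head_attn_sink, 'h -> b h i 1', b = sim.shape[0], i = sim.shape[2])
sim = cat((attn_sink, sim), dim = -1)

attn = sim.softmax(dim = -1)

# drop the sink column after the softmax; whatever mass it absorbed is simply
# discarded, so each row now sums to at most 1 rather than exactly 1
attn = attn[..., 1:]
```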
x_transformers/x_transformers.py

@@ -1319,6 +1319,7 @@ class Attention(Module):
         value_dim_head = None,
         dim_out = None,
         add_zero_kv = False, # same as add_zero_attn in pytorch
+        head_learned_sink = False,
         rotate_num_heads = None,
         data_dependent_alibi = False,
         data_dependent_alibi_per_row = False,
@@ -1515,6 +1516,7 @@ class Attention(Module):
             selective = selective,
             custom_attn_fn = custom_attn_fn,
             add_zero_kv = add_zero_kv,
+            head_learned_sink = head_learned_sink,
             flash = flash,
             softclamp_logits = softclamp_logits,
             logit_softclamp_value = logit_softclamp_value,
@@ -1795,6 +1797,13 @@ class Attention(Module):
             seq_len = k.shape[-2]

             added_k, added_v = additional_key_values
+            added_kv_heads, added_kv_len = added_k.shape[1], added_k.shape[-2]
+
+            # take care of expanding to query heads if mismatch between key / value heads with the ones coming from vlm
+
+            if added_kv_heads != kv_h:
+                assert divisible_by(h, added_kv_heads)
+                k, v, added_k, added_v = tuple(repeat(t, 'b h ... -> b (r h) ...', r = h // t.shape[1]) for t in (k, v, added_k, added_v))

             k = cat((added_k, k), dim = -2)
             v = cat((added_v, v), dim = -2)
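The added lines cover externally supplied key / values (for example from a VLM) whose head count does not match the layer's own key / value heads: as long as the query head count is divisible by the incoming head count, everything is repeated up to the query heads before being concatenated along the sequence dimension. A shape-only sketch of that expansion (dimensions are illustrative):

```python
import torch
from einops import repeat

h = 8                                   # query heads in the layer

k       = torch.randn(3, 2, 1024, 64)   # layer keys / values with 2 kv heads (GQA style)
v       = torch.randn(3, 2, 1024, 64)
added_k = torch.randn(3, 4, 32, 64)     # external keys / values with 4 heads
added_v = torch.randn(3, 4, 32, 64)

# repeat every tensor along the head dim up to the query head count,
# using the same `h // t.shape[1]` factor as the diff above
k, v, added_k, added_v = tuple(
    repeat(t, 'b h ... -> b (r h) ...', r = h // t.shape[1])
    for t in (k, v, added_k, added_v)
)

k = torch.cat((added_k, k), dim = -2)   # external keys are prepended along the sequence
v = torch.cat((added_v, v), dim = -2)

print(k.shape)                          # torch.Size([3, 8, 1056, 64])
```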
@@ -1802,7 +1811,6 @@ class Attention(Module):
            if (exists(input_mask) or exists(additional_key_value_mask)):

                if not exists(additional_key_value_mask):
-                    added_kv_len = added_k.shape[-2]
                    input_mask = pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
                elif not exists(input_mask):
                    input_mask = pad_at_dim(additional_key_value_mask, (0, seq_len), dim = -1, value = True)