x-transformers 2.6.3__py3-none-any.whl → 2.6.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +21 -4
- x_transformers/x_transformers.py +2 -0
- {x_transformers-2.6.3.dist-info → x_transformers-2.6.4.dist-info}/METADATA +11 -1
- {x_transformers-2.6.3.dist-info → x_transformers-2.6.4.dist-info}/RECORD +6 -6
- {x_transformers-2.6.3.dist-info → x_transformers-2.6.4.dist-info}/WHEEL +0 -0
- {x_transformers-2.6.3.dist-info → x_transformers-2.6.4.dist-info}/licenses/LICENSE +0 -0
x_transformers/attend.py
CHANGED
@@ -4,8 +4,8 @@ from functools import partial
 from typing import Tuple, Callable
 
 import torch
-from torch.nn import Module
-from torch import nn, einsum, Tensor
+from torch.nn import Module, Parameter
+from torch import cat, nn, einsum, Tensor
 import torch.nn.functional as F
 
 from collections import namedtuple
@@ -176,6 +176,7 @@ class Attend(Module):
         softclamp_logits = False,
         logit_softclamp_value = 50.,
         add_zero_kv = False,
+        head_learned_sink = False,
         selective = False,
         hard = False,
         cope = None,
@@ -254,6 +255,13 @@ class Attend(Module):
 
         self.add_zero_kv = add_zero_kv
 
+        # learned sink concatted pre-softmax, working solution from gpt-oss
+
+        assert not (head_learned_sink and flash), f'not supported for flash attention yet'
+
+        self.head_learned_sink = head_learned_sink
+        self.head_attn_sink = Parameter(torch.zeros(heads)) if head_learned_sink else None
+
         # soft clamp attention logit value
 
         if softclamp_logits:
@@ -315,10 +323,10 @@ class Attend(Module):
         if self.l2_distance:
             k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
             k = F.pad(k, (0, 1), value = -1.)
-            k = torch.cat((k, k_norm_sq), dim = -1)
+            k = cat((k, k_norm_sq), dim = -1)
 
             q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
-            q = torch.cat((2 * q, q_norm_sq), dim = -1)
+            q = cat((2 * q, q_norm_sq), dim = -1)
             q = F.pad(q, (0, 1), value = -1.)
 
         # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
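The two changed lines here only swap `torch.cat` for the newly imported `cat`; the surrounding augmentation is the standard trick of expressing negative squared L2 distance as a plain dot product, so the rest of the attention path is untouched. A minimal standalone sketch (not library code) verifying the identity the augmented vectors satisfy:

```python
# sketch only: checks [2q, |q|^2, -1] · [k, -1, |k|^2] == -|q - k|^2,
# which is why the augmented q/k above let dot-product attention emulate L2 distance
import torch
from torch import cat, tensor

q = torch.randn(4)
k = torch.randn(4)

q_sq = (q * q).sum(dim = -1, keepdim = True)   # ||q||^2, shape (1,)
k_sq = (k * k).sum(dim = -1, keepdim = True)   # ||k||^2, shape (1,)

q_aug = cat((2 * q, q_sq, tensor([-1.])))      # mirrors: cat((2 * q, q_norm_sq)), then pad with -1
k_aug = cat((k, tensor([-1.]), k_sq))          # mirrors: pad k with -1, then cat((k, k_norm_sq))

# 2 q·k - ||q||^2 - ||k||^2 == -||q - k||^2
assert torch.allclose(q_aug @ k_aug, -((q - k) ** 2).sum(), atol = 1e-5)
```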
@@ -509,6 +517,11 @@ class Attend(Module):
         if self.selective:
             sim = selective_attn(sim)
 
+        if self.head_learned_sink:
+            # add learned attention sink
+            attn_sink = repeat(self.head_attn_sink, 'h -> b h i 1', b = sim.shape[0], i = sim.shape[2])
+            sim = cat((attn_sink, sim), dim = -1)
+
         pre_softmax_attn = sim
 
         attn = self.attn_fn(sim)
@@ -517,6 +530,10 @@ class Attend(Module):
 
         post_softmax_attn = attn
 
+        if self.head_learned_sink:
+            # remove attention sink
+            attn = attn[..., 1:]
+
         attn = self.attn_dropout(attn)
 
         if exists(self.post_softmax_talking_heads):
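Taken together, the hunks above implement the learned attention sink from gpt-oss: a per-head learned logit is prepended as an extra key column before the softmax, then its probability mass is dropped afterwards, so each attention row may sum to less than one (per the assert, only on the non-flash path for now). A minimal sketch of that mechanism outside the library; tensor names and shapes are illustrative, not the library's internals:

```python
# standalone sketch of a per-head learned attention sink, mirroring the diff above
import torch
from torch import nn, cat
from einops import repeat

heads = 8
head_attn_sink = nn.Parameter(torch.zeros(heads))    # one learned sink logit per head

def softmax_with_sink(sim):
    # sim: attention logits of shape (batch, heads, query_len, key_len)
    b, h, i, _ = sim.shape

    # prepend the learned sink column to the logits
    sink = repeat(head_attn_sink, 'h -> b h i 1', b = b, i = i)
    sim = cat((sink, sim), dim = -1)

    # softmax over keys plus the sink, then drop the sink column
    attn = sim.softmax(dim = -1)
    return attn[..., 1:]

attn = softmax_with_sink(torch.randn(2, heads, 16, 16))
assert (attn.sum(dim = -1) <= 1 + 1e-6).all()         # remaining mass was absorbed by the sink
```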
x_transformers/x_transformers.py
CHANGED
@@ -1319,6 +1319,7 @@ class Attention(Module):
         value_dim_head = None,
         dim_out = None,
         add_zero_kv = False, # same as add_zero_attn in pytorch
+        head_learned_sink = False,
         rotate_num_heads = None,
         data_dependent_alibi = False,
         data_dependent_alibi_per_row = False,
@@ -1515,6 +1516,7 @@ class Attention(Module):
             selective = selective,
             custom_attn_fn = custom_attn_fn,
             add_zero_kv = add_zero_kv,
+            head_learned_sink = head_learned_sink,
             flash = flash,
             softclamp_logits = softclamp_logits,
             logit_softclamp_value = logit_softclamp_value,
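With the pass-through above, the flag can be enabled on `Attention`, and, assuming the library's usual `attn_`-prefixed kwarg routing into its attention layers, from the high-level wrappers as well. A hedged usage sketch; the `attn_head_learned_sink` spelling follows that convention and is not taken from this diff:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

# assumes attn_* kwargs are forwarded to Attention, which now accepts head_learned_sink
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_head_learned_sink = True,   # new in 2.6.4
        attn_flash = False               # per the assert above, not yet supported with flash attention
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)   # (1, 1024, 20000)
```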
{x_transformers-2.6.3.dist-info → x_transformers-2.6.4.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.6.3
+Version: 2.6.4
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2507,4 +2507,14 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@misc{openai_gpt_oss,
+    author = {OpenAI},
+    title = {Introducing gpt-oss},
+    howpublished = {https://openai.com/index/introducing-gpt-oss},
+    month = {August},
+    year = {2025}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.6.3.dist-info → x_transformers-2.6.4.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
 x_transformers/__init__.py,sha256=aVuhUU0572TJHW88BVc4yA2tla0Zb8l3NH7W4RZ1AEs,1005
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=JJv6ypJbZIFmH1LQ49hFg6hD0Wf9Z7Im1AP2ekm9hVI,18091
 x_transformers/autoregressive_wrapper.py,sha256=BsGO9xfVYkvynqbU1__tu_S_cxl7gss0YwnkhIa2baY,18401
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=hpb1sSbt3k2LNzzjrjSd8F5xOIbKj7IluV9MBEAFLkw,13031
@@ -9,10 +9,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=vjRMEMA12Js94YwLVeZksYMEoRgK6CSKT6TJViMPp7U,122186
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.6.3.dist-info/METADATA,sha256=
-x_transformers-2.6.3.dist-info/WHEEL,sha256=
-x_transformers-2.6.3.dist-info/licenses/LICENSE,sha256=
-x_transformers-2.6.3.dist-info/RECORD,,
+x_transformers-2.6.4.dist-info/METADATA,sha256=pvv3zqf_syaWrLb2PbromP_T0H1o6ON7OGygS9dQN_M,90445
+x_transformers-2.6.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.6.4.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.6.4.dist-info/RECORD,,
{x_transformers-2.6.3.dist-info → x_transformers-2.6.4.dist-info}/WHEEL
File without changes

{x_transformers-2.6.3.dist-info → x_transformers-2.6.4.dist-info}/licenses/LICENSE
File without changes