x-transformers 2.6.3__py3-none-any.whl → 2.6.5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- x_transformers/attend.py +22 -4
- x_transformers/x_transformers.py +2 -0
- {x_transformers-2.6.3.dist-info → x_transformers-2.6.5.dist-info}/METADATA +11 -1
- {x_transformers-2.6.3.dist-info → x_transformers-2.6.5.dist-info}/RECORD +6 -6
- {x_transformers-2.6.3.dist-info → x_transformers-2.6.5.dist-info}/WHEEL +0 -0
- {x_transformers-2.6.3.dist-info → x_transformers-2.6.5.dist-info}/licenses/LICENSE +0 -0
x_transformers/attend.py
CHANGED
@@ -4,8 +4,8 @@ from functools import partial
 from typing import Tuple, Callable

 import torch
-from torch.nn import Module
-from torch import nn, einsum, Tensor
+from torch.nn import Module, Parameter
+from torch import cat, nn, einsum, Tensor
 import torch.nn.functional as F

 from collections import namedtuple
@@ -176,6 +176,7 @@ class Attend(Module):
         softclamp_logits = False,
         logit_softclamp_value = 50.,
         add_zero_kv = False,
+        head_learned_sinks = 0,
         selective = False,
         hard = False,
         cope = None,
@@ -254,6 +255,13 @@

         self.add_zero_kv = add_zero_kv

+        # learned sink concatted pre-softmax, working solution from gpt-oss
+
+        self.has_head_learned_sinks = head_learned_sinks > 0
+        assert not (self.has_head_learned_sinks and flash), f'not supported for flash attention yet'
+
+        self.head_attn_sinks = Parameter(torch.zeros(heads, head_learned_sinks)) if self.has_head_learned_sinks else None
+
         # soft clamp attention logit value

         if softclamp_logits:
@@ -315,10 +323,10 @@
         if self.l2_distance:
             k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
             k = F.pad(k, (0, 1), value = -1.)
-            k = torch.cat((k, k_norm_sq), dim = -1)
+            k = cat((k, k_norm_sq), dim = -1)

             q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
-            q = torch.cat((2 * q, q_norm_sq), dim = -1)
+            q = cat((2 * q, q_norm_sq), dim = -1)
             q = F.pad(q, (0, 1), value = -1.)

         # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
@@ -509,6 +517,12 @@
         if self.selective:
             sim = selective_attn(sim)

+        if self.has_head_learned_sinks:
+            # add learned attention sink
+            num_sinks = self.head_attn_sinks.shape[-1]
+            attn_sink = repeat(self.head_attn_sinks, 'h sinks -> b h i sinks', b = sim.shape[0], i = sim.shape[2])
+            sim = cat((attn_sink, sim), dim = -1)
+
         pre_softmax_attn = sim

         attn = self.attn_fn(sim)
@@ -517,6 +531,10 @@

         post_softmax_attn = attn

+        if self.has_head_learned_sinks:
+            # remove attention sink
+            attn = attn[..., num_sinks:]
+
         attn = self.attn_dropout(attn)

         if exists(self.post_softmax_talking_heads):
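The new `head_learned_sinks` path concatenates a set of learned, per-head sink logits to the attention scores before the softmax and slices them back off afterwards, so each query can shed probability mass into the sinks instead of being forced to spread it over real keys. Below is a minimal, self-contained sketch of that mechanism; the tensor names and shapes are illustrative, not taken from the library.

```python
# Illustrative sketch of the learned attention-sink mechanism added above.
# All names and shapes here are made up for demonstration purposes.
import torch
from torch import cat
from torch.nn import Parameter
from einops import repeat

b, h, i, j = 2, 8, 16, 16      # batch, heads, query length, key length
num_sinks = 4

sim = torch.randn(b, h, i, j)                      # pre-softmax attention logits
head_attn_sinks = Parameter(torch.zeros(h, num_sinks))

# broadcast the per-head sink logits across batch and query positions
attn_sink = repeat(head_attn_sinks, 'h s -> b h i s', b = b, i = i)

sim = cat((attn_sink, sim), dim = -1)              # prepend sink logits
attn = sim.softmax(dim = -1)                       # softmax over sinks + real keys
attn = attn[..., num_sinks:]                       # drop sink columns; rows may now sum to < 1

v = torch.randn(b, h, j, 64)
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
```

Because the sink columns are discarded after the softmax, they never attend to any value; their only effect is to let the attention weights over real keys sum to less than one.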
x_transformers/x_transformers.py
CHANGED
@@ -1319,6 +1319,7 @@ class Attention(Module):
         value_dim_head = None,
         dim_out = None,
         add_zero_kv = False, # same as add_zero_attn in pytorch
+        head_learned_sinks = 0,
         rotate_num_heads = None,
         data_dependent_alibi = False,
         data_dependent_alibi_per_row = False,
@@ -1515,6 +1516,7 @@
             selective = selective,
             custom_attn_fn = custom_attn_fn,
             add_zero_kv = add_zero_kv,
+            head_learned_sinks = head_learned_sinks,
             flash = flash,
             softclamp_logits = softclamp_logits,
             logit_softclamp_value = logit_softclamp_value,
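In x_transformers.py the new argument is simply threaded from `Attention` into `Attend`. Assuming the library's usual convention of prefixing attention kwargs with `attn_` when configuring layers, it should be reachable from the high-level API roughly as below; this is a hedged sketch, not an excerpt from the 2.6.5 README.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

# Assumed wiring: attn_-prefixed kwargs are forwarded to Attention,
# so attn_head_learned_sinks should map to Attention(head_learned_sinks = ...).
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_head_learned_sinks = 4,
        attn_flash = False   # learned sinks are asserted against flash attention for now
    )
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x)   # (1, 1024, 256)
```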
{x_transformers-2.6.3.dist-info → x_transformers-2.6.5.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.6.3
+Version: 2.6.5
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2507,4 +2507,14 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```

+```bibtex
+@misc{openai_gpt_oss,
+    author = {OpenAI},
+    title = {Introducing gpt-oss},
+    howpublished = {https://openai.com/index/introducing-gpt-oss},
+    month = {August},
+    year = {2025}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.6.3.dist-info → x_transformers-2.6.5.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
 x_transformers/__init__.py,sha256=aVuhUU0572TJHW88BVc4yA2tla0Zb8l3NH7W4RZ1AEs,1005
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=DX_qrDkz98Db0eNapbzciJbVp5dsWIFWdpv2LUfebJs,18223
 x_transformers/autoregressive_wrapper.py,sha256=BsGO9xfVYkvynqbU1__tu_S_cxl7gss0YwnkhIa2baY,18401
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=hpb1sSbt3k2LNzzjrjSd8F5xOIbKj7IluV9MBEAFLkw,13031
@@ -9,10 +9,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=F_ZR9jysYmkbqKvsZmzXqOP3VznVeivXVOstAwKIdPU,122185
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.6.
-x_transformers-2.6.
-x_transformers-2.6.
-x_transformers-2.6.
+x_transformers-2.6.5.dist-info/METADATA,sha256=yMl0MlBbo7D9dOu_cBQz38iJQ3a6F8PlaPCo5RQXrSA,90445
+x_transformers-2.6.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.6.5.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.6.5.dist-info/RECORD,,

{x_transformers-2.6.3.dist-info → x_transformers-2.6.5.dist-info}/WHEEL
File without changes

{x_transformers-2.6.3.dist-info → x_transformers-2.6.5.dist-info}/licenses/LICENSE
File without changes