x-transformers 1.37.10__py3-none-any.whl → 1.38.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- x_transformers/attend.py +39 -0
- x_transformers/x_transformers.py +2 -0
- {x_transformers-1.37.10.dist-info → x_transformers-1.38.1.dist-info}/METADATA +1 -1
- {x_transformers-1.37.10.dist-info → x_transformers-1.38.1.dist-info}/RECORD +7 -7
- {x_transformers-1.37.10.dist-info → x_transformers-1.38.1.dist-info}/LICENSE +0 -0
- {x_transformers-1.37.10.dist-info → x_transformers-1.38.1.dist-info}/WHEEL +0 -0
- {x_transformers-1.37.10.dist-info → x_transformers-1.38.1.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -65,6 +65,35 @@ def once(fn):
 
 print_once = once(print)
 
+# selective attention
+# https://arxiv.org/abs/2410.02703 - section 3.3
+# it is a parameter-less technique to allow each token to prevent itself from being attended to by future tokens
+# if sim_head_gate not supplied, will use the first head of the attention logits (sim in this framework)
+
+def selective_attn(
+    sim,
+    sim_head_gate = None,
+    no_mask_sos = True
+):
+    i, j, device = *sim.shape[-2:], sim.device
+    sim_head_gate = default(sim_head_gate, sim[:, 0])
+
+    gate = F.relu(sim_head_gate) # only positive
+
+    if no_mask_sos:
+        gate[..., -i] = 0.
+
+    eye = torch.eye(j, device = device)
+
+    if j > i:
+        eye = F.pad(eye, (j - i, 0), value = 1.)
+
+    gate = (1. - eye) * gate
+    gate = F.pad(gate, (0, 0, 1, -1), value = 0.) # only allow for masking the future
+    gate = gate.cumsum(dim = -2)
+
+    return sim - rearrange(gate, 'b i j -> b 1 i j')
+
 # alternative distance functions
 
 def qk_l2_dist_squared(q, k):
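The comments in the hunk above describe the technique; the sketch below (not part of the diff) illustrates what the new module-level function does when applied to a tensor of pre-softmax attention logits. It assumes x-transformers 1.38.1 is installed and that `sim` has the usual `(batch, heads, query_len, key_len)` shape.

```python
# Hypothetical illustration, not part of the package diff.
# Assumes x-transformers >= 1.38.1, where selective_attn is a module-level function in attend.py.
import torch
from x_transformers.attend import selective_attn

batch, heads, seq_len = 2, 8, 16

# raw pre-softmax attention logits, shaped (batch, heads, query_len, key_len)
sim = torch.randn(batch, heads, seq_len, seq_len)

# with sim_head_gate = None, the gate is derived from the first head (sim[:, 0]),
# passed through ReLU, zeroed on the diagonal, shifted down one row, cumulatively
# summed over the query dimension, then subtracted from every head's logits
gated_sim = selective_attn(sim)

attn = gated_sim.softmax(dim = -1)  # softmax is applied afterwards, as Attend does internally
print(attn.shape)                   # torch.Size([2, 8, 16, 16])
```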
@@ -109,6 +138,7 @@ class Attend(Module):
         softclamp_logits = False,
         logit_softclamp_value = 50.,
         add_zero_kv = False,
+        selective = False,
         sigsoftmax = False,
         cope = None,
         onnxable = False,
@@ -155,6 +185,12 @@ class Attend(Module):
             self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
             self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
 
+        # selective attention
+
+        assert not (flash and selective), 'selective attention cannot work on flash attention'
+        assert not (selective and not causal), 'selective attention is designed for autoregressive'
+        self.selective = selective
+
         # sparse topk
 
         assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention'
@@ -427,6 +463,9 @@ class Attend(Module):
         if exists(self.cope):
             sim = sim + self.cope(q, sim)
 
+        if self.selective:
+            sim = selective_attn(sim)
+
         pre_softmax_attn = sim
 
         if self.sigsoftmax:
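Taken together, the attend.py changes let the flag be passed straight to Attend. A minimal usage sketch, assuming Attend keeps its usual (q, k, v) -> (out, intermediates) interface (not shown in this diff):

```python
# Hypothetical usage sketch, not part of the diff.
# Assumes Attend's existing forward interface: attend(q, k, v) -> (out, intermediates).
import torch
from x_transformers.attend import Attend

attend = Attend(
    causal = True,     # required: the new assert rejects selective attention without causal
    selective = True   # flash stays at its False default, per the other new assert
)

b, h, n, d = 1, 8, 32, 64
q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d)

out, intermediates = attend(q, k, v)
print(out.shape)  # torch.Size([1, 8, 32, 64])
```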
x_transformers/x_transformers.py
CHANGED
@@ -925,6 +925,7 @@ class Attention(Module):
         qk_norm_dim_scale = False,
         l2_distance = False,
         sigmoid = False,
+        selective = False,
         custom_attn_fn: Callable | None = None,
         one_kv_head = False,
         kv_heads = None,
@@ -1042,6 +1043,7 @@ class Attention(Module):
             scale = qk_norm_scale if qk_norm else self.scale,
             l2_distance = l2_distance,
             sigmoid = sigmoid,
+            selective = selective,
             custom_attn_fn = custom_attn_fn,
             add_zero_kv = add_zero_kv,
             flash = flash,
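With the Attention wrapper now exposing the flag, it can presumably also be enabled through the usual attn_-prefixed kwargs of the higher-level layers. A hedged sketch, assuming that routing applies to `selective` the same way it does for flags like `attn_flash`:

```python
# Hypothetical usage sketch, not part of the diff.
# Assumes the standard attn_* kwarg routing forwards `selective` from Decoder -> Attention -> Attend.
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_selective = True  # Decoder is causal, satisfying the new assert
    )
)

x = torch.randint(0, 256, (1, 128))
logits = model(x)
print(logits.shape)  # torch.Size([1, 128, 256])
```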
{x_transformers-1.37.10.dist-info → x_transformers-1.38.1.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=Ysh8fFQ-BW0aKl7Ls3-gBJQ0Y-QGahMCuvTRJZ-C1rM,15246
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=Dol6GMZOoHGOFdHwe21o2SbJp6b3YKCUHoIs_AjfvTo,83963
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers-1.38.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.38.1.dist-info/METADATA,sha256=mdClNuSZCRyILMjlRi0WUuRFRnRoRSDsXZEF5A4NbpA,661
+x_transformers-1.38.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.38.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.38.1.dist-info/RECORD,,
{x_transformers-1.37.10.dist-info → x_transformers-1.38.1.dist-info}/LICENSE
File without changes
{x_transformers-1.37.10.dist-info → x_transformers-1.38.1.dist-info}/WHEEL
File without changes
{x_transformers-1.37.10.dist-info → x_transformers-1.38.1.dist-info}/top_level.txt
File without changes