x-transformers 1.37.10__py3-none-any.whl → 1.38.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +38 -0
- x_transformers/x_transformers.py +2 -0
- {x_transformers-1.37.10.dist-info → x_transformers-1.38.0.dist-info}/METADATA +1 -1
- {x_transformers-1.37.10.dist-info → x_transformers-1.38.0.dist-info}/RECORD +7 -7
- {x_transformers-1.37.10.dist-info → x_transformers-1.38.0.dist-info}/LICENSE +0 -0
- {x_transformers-1.37.10.dist-info → x_transformers-1.38.0.dist-info}/WHEEL +0 -0
- {x_transformers-1.37.10.dist-info → x_transformers-1.38.0.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
```diff
@@ -65,6 +65,35 @@ def once(fn):
 
 print_once = once(print)
 
+# selective attention
+# https://arxiv.org/abs/2410.02703 - section 3.3
+# it is a parameter-less technique to allow each token to prevent itself from being attended to by future tokens
+# if sim_head_gate not supplied, will use the first head of the attention logits (sim in this framework)
+
+def selective_attn(
+    sim,
+    sim_head_gate = None,
+    no_mask_sos = True
+):
+    i, j, device = *sim.shape[-2:], sim.device
+    sim_head_gate = default(sim_head_gate, sim[:, 0])
+
+    gate = F.relu(sim_head_gate) # only positive
+
+    if no_mask_sos:
+        gate[..., -i] = 0.
+
+    eye = torch.eye(j, device = device)
+
+    if j > i:
+        eye = F.pad(eye, (j - i, 0), value = 1.)
+
+    gate = (1. - eye) * gate
+    gate = F.pad(gate, (0, 0, 1, -1), value = 0.) # only allow for masking the future
+    gate = gate.cumsum(dim = -2)
+
+    return sim - rearrange(gate, 'b i j -> b 1 i j')
+
 # alternative distance functions
 
 def qk_l2_dist_squared(q, k):
@@ -109,6 +138,7 @@ class Attend(Module):
         softclamp_logits = False,
         logit_softclamp_value = 50.,
         add_zero_kv = False,
+        selective = False,
         sigsoftmax = False,
         cope = None,
         onnxable = False,
@@ -155,6 +185,11 @@ class Attend(Module):
             self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
            self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
 
+        # selective attention
+
+        assert not (selective and not causal), 'selective attention is designed for autoregressive'
+        self.selective = selective
+
         # sparse topk
 
         assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention'
@@ -427,6 +462,9 @@ class Attend(Module):
         if exists(self.cope):
             sim = sim + self.cope(q, sim)
 
+        if self.selective:
+            sim = selective_attn(sim)
+
         pre_softmax_attn = sim
 
         if self.sigsoftmax:
```
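For orientation, a minimal sketch (not part of the diff) of exercising the new `selective_attn` on its own. It assumes only what the hunks above show: `selective_attn` is a module-level function in `x_transformers.attend`, and the batch, head, and sequence sizes below are arbitrary illustrations.

```python
# Sketch only: calling selective_attn on dummy pre-softmax attention logits
# of shape (batch, heads, query_len, key_len).
import torch
from x_transformers.attend import selective_attn

b, h, i, j = 2, 8, 16, 16
sim = torch.randn(b, h, i, j)      # pre-softmax attention logits

# with sim_head_gate left as None, the first head of `sim` drives the gating
gated = selective_attn(sim)

assert gated.shape == sim.shape
# the gate is non-negative and subtracted from the logits, so scores only go down
assert torch.all(gated <= sim)
```

As the code reads, the gate is built from the first attention head, kept non-negative with a ReLU, zeroed on the diagonal (and optionally for the first token), shifted so it only affects later query rows, and accumulated down the query dimension before being subtracted from the logits.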
x_transformers/x_transformers.py
CHANGED
```diff
@@ -925,6 +925,7 @@ class Attention(Module):
         qk_norm_dim_scale = False,
         l2_distance = False,
         sigmoid = False,
+        selective = False,
         custom_attn_fn: Callable | None = None,
         one_kv_head = False,
         kv_heads = None,
@@ -1042,6 +1043,7 @@ class Attention(Module):
             scale = qk_norm_scale if qk_norm else self.scale,
             l2_distance = l2_distance,
             sigmoid = sigmoid,
+            selective = selective,
             custom_attn_fn = custom_attn_fn,
             add_zero_kv = add_zero_kv,
             flash = flash,
```
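A hedged end-to-end sketch of turning the new flag on from the model side. It assumes x-transformers' usual `attn_`-prefixed kwarg routing into `Attention` (so `attn_selective` would reach the new `selective` argument added above); the vocabulary size, sequence length, and model dimensions are arbitrary.

```python
# Sketch only (not part of the diff): enabling selective attention end to end.
# attn_selective = True is assumed to be forwarded to Attention(selective = True)
# via the library's attn_* kwarg prefix convention.
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(      # Decoder is causal, satisfying the new assert
        dim = 512,
        depth = 4,
        heads = 8,
        attn_selective = True
    )
)

tokens = torch.randint(0, 256, (1, 1024))
logits = model(tokens)          # (1, 1024, 256)
```

Note the guard added in `Attend.__init__`: `selective` without `causal` raises, so the flag is only meant for autoregressive (decoder) attention.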
{x_transformers-1.37.10.dist-info → x_transformers-1.38.0.dist-info}/RECORD
CHANGED
```diff
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=…
+x_transformers/attend.py,sha256=tZOPsSgWwabIGgUlNdWrVZgcxOkF3sgiQQ7nrcOjl9I,15151
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=…
+x_transformers/x_transformers.py,sha256=Dol6GMZOoHGOFdHwe21o2SbJp6b3YKCUHoIs_AjfvTo,83963
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.37.10.dist-info/LICENSE,sha256=…
-x_transformers-1.37.10.dist-info/METADATA,sha256=…
-x_transformers-1.37.10.dist-info/WHEEL,sha256=…
-x_transformers-1.37.10.dist-info/top_level.txt,sha256=…
-x_transformers-1.37.10.dist-info/RECORD,,
+x_transformers-1.38.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.38.0.dist-info/METADATA,sha256=9KjkWIqzpGS3vs7MpWmfbFfPHPvdOJSE6hOhxQsu4XI,661
+x_transformers-1.38.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.38.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.38.0.dist-info/RECORD,,
```
{x_transformers-1.37.10.dist-info → x_transformers-1.38.0.dist-info}/LICENSE
File without changes
{x_transformers-1.37.10.dist-info → x_transformers-1.38.0.dist-info}/WHEEL
File without changes
{x_transformers-1.37.10.dist-info → x_transformers-1.38.0.dist-info}/top_level.txt
File without changes