x-transformers 1.38.3__py3-none-any.whl → 1.39.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +23 -9
- x_transformers/x_transformers.py +6 -2
- {x_transformers-1.38.3.dist-info → x_transformers-1.39.1.dist-info}/METADATA +1 -1
- {x_transformers-1.38.3.dist-info → x_transformers-1.39.1.dist-info}/RECORD +7 -7
- {x_transformers-1.38.3.dist-info → x_transformers-1.39.1.dist-info}/LICENSE +0 -0
- {x_transformers-1.38.3.dist-info → x_transformers-1.39.1.dist-info}/WHEEL +0 -0
- {x_transformers-1.38.3.dist-info → x_transformers-1.39.1.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
```diff
@@ -128,7 +128,9 @@ class Attend(Module):
         dropout = 0.,
         causal = False,
         heads = None,
-        talking_heads = False,
+        pre_talking_heads = False,
+        post_talking_heads = False,
+        pre_scale_post_talking_heads = False,
         sparse_topk = None,
         scale = None,
         qk_norm = False,
@@ -179,16 +181,22 @@ class Attend(Module):
 
         # talking heads
 
-        assert not (flash and talking_heads), 'talking heads not compatible with flash attention'
+        assert not (flash and (pre_talking_heads or post_talking_heads or pre_scale_post_talking_heads)), 'talking heads not compatible with flash attention'
 
-        self.talking_heads = talking_heads
-        if talking_heads:
-            self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
-            self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
+        self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_talking_heads else None
+        self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if post_talking_heads else None
+        self.pre_scale_post_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_scale_post_talking_heads else None
 
+        if exists(self.pre_softmax_talking_heads):
             nn.init.dirac_(self.pre_softmax_talking_heads.weight)
+
+        if exists(self.post_softmax_talking_heads):
             nn.init.dirac_(self.post_softmax_talking_heads.weight)
 
+        if exists(self.pre_scale_post_talking_heads):
+            # an improvisation where heads are combined pre-softmax attention, then used to scale post-softmax attention
+            nn.init.dirac_(self.pre_scale_post_talking_heads.weight)
+
         # selective attention
 
         assert not (flash and selective), 'selective attention cannot work on flash attention'
```
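Each of the new flags builds its own 1×1 `nn.Conv2d` over the head dimension, and `nn.init.dirac_` initializes that convolution as an identity map across heads, so every mixing layer starts out letting each head see only itself and learns to blend heads during training. A minimal sketch of that initialization property (the tensor shape is an illustrative assumption, not taken from the diff):

```python
import torch
import torch.nn as nn

heads = 8

# head-mixing layer constructed the same way as the new talking-heads modules above
mix = nn.Conv2d(heads, heads, 1, bias = False)
nn.init.dirac_(mix.weight)  # weight[i, i, 0, 0] = 1, i.e. identity over the head channel

# attention logits shaped (batch, heads, query_len, key_len) -- assumed shape for illustration
sim = torch.randn(2, heads, 16, 16)

# at initialization the convolution returns the logits unchanged
assert torch.allclose(mix(sim), sim, atol = 1e-6)
```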
```diff
@@ -434,8 +442,11 @@ class Attend(Module):
 
         qk_similarities = sim.clone()
 
-        if self.talking_heads:
-            sim = self.pre_softmax_talking_heads(sim)
+        if exists(self.pre_scale_post_talking_heads):
+            pre_to_post_scale = self.pre_scale_post_talking_heads(sim)
+
+        if exists(self.pre_softmax_talking_heads):
+            sim = sim + self.pre_softmax_talking_heads(sim)
 
         if exists(attn_bias):
             sim = sim + attn_bias
@@ -482,9 +493,12 @@ class Attend(Module):
 
         attn = self.attn_dropout(attn)
 
-        if self.talking_heads:
+        if exists(self.post_softmax_talking_heads):
             attn = self.post_softmax_talking_heads(attn)
 
+        if exists(self.pre_scale_post_talking_heads):
+            attn = attn * pre_to_post_scale
+
         out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
 
         intermediates = Intermediates(
```
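Taken together, the forward-pass hunks wire three head-mixing hooks around the softmax: `pre_talking_heads` adds a head-mixed copy of the logits before softmax, `post_talking_heads` mixes the attention probabilities after softmax, and the new `pre_scale_post_talking_heads` derives a head-mixed map from the pre-softmax logits and multiplies the post-softmax attention by it. The condensed sketch below mirrors that flow outside the library; the plain softmax attention around it, the function name, and the tensor shapes are assumptions for illustration, not the `Attend` class itself:

```python
import torch
import torch.nn as nn

def talking_heads_attention(
    q, k, v,                    # assumed shape: (batch, heads, seq, dim_head)
    pre_mix = None,             # each mixer: nn.Conv2d(heads, heads, 1, bias = False), or None
    post_mix = None,
    pre_scale_post_mix = None
):
    scale = q.shape[-1] ** -0.5
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale

    # new in 1.39.x: head-mixed map from the pre-softmax logits,
    # applied later as a multiplicative scale on the post-softmax attention
    pre_to_post_scale = pre_scale_post_mix(sim) if pre_scale_post_mix is not None else None

    # pre-softmax talking heads: add a head-mixed copy of the logits
    if pre_mix is not None:
        sim = sim + pre_mix(sim)

    attn = sim.softmax(dim = -1)

    # post-softmax talking heads: mix attention probabilities across heads
    if post_mix is not None:
        attn = post_mix(attn)

    if pre_to_post_scale is not None:
        attn = attn * pre_to_post_scale

    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)
```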
x_transformers/x_transformers.py
CHANGED
```diff
@@ -907,7 +907,9 @@ class Attention(Module):
         heads = 8,
         causal = False,
         flash = False,
-        talking_heads = False,
+        pre_talking_heads = False,
+        post_talking_heads = False,
+        pre_scale_post_talking_heads = False,
         head_scale = False,
         sparse_topk = None,
         num_mem_kv = 0,
@@ -1036,7 +1038,9 @@ class Attention(Module):
         self.attend = Attend(
             heads = heads,
             causal = causal,
-            talking_heads = talking_heads,
+            pre_talking_heads = pre_talking_heads,
+            post_talking_heads = post_talking_heads,
+            pre_scale_post_talking_heads = pre_scale_post_talking_heads,
             dropout = dropout,
             sparse_topk = sparse_topk,
             qk_norm = qk_norm,
```
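In `x_transformers.py` the change is pure plumbing: `Attention` now exposes the same three flags and forwards them to `Attend`. A hypothetical smoke test against `Attend` directly (the `(q, k, v)` call signature, the `(batch, heads, seq, dim_head)` layout, and the `(out, intermediates)` return value are assumptions from memory of the library, not confirmed by this diff):

```python
import torch
from x_transformers.attend import Attend

# flash attention must stay off: the new assert rejects combining it with any talking-heads flag
attend = Attend(
    heads = 8,
    causal = True,
    pre_talking_heads = True,
    post_talking_heads = True,
    pre_scale_post_talking_heads = True
)

q, k, v = (torch.randn(1, 8, 128, 64) for _ in range(3))

out, intermediates = attend(q, k, v)
print(out.shape)  # expected: torch.Size([1, 8, 128, 64])
```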
{x_transformers-1.38.3.dist-info → x_transformers-1.39.1.dist-info}/RECORD
CHANGED

```diff
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=JlWwEzigY5sl9ktDwDWWQD9np9uPRoj2eRo9XU6tJc0,16273
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=8ZQR6OLT4vusIjJXzrdSp12Fydmmpcc2t5cDE6SxPNc,84460
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers-1.39.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.39.1.dist-info/METADATA,sha256=KJfw4hIDzozyRlsnBaqbbfcFSXygbv1y2B_6cl9pu-4,661
+x_transformers-1.39.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.39.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.39.1.dist-info/RECORD,,
```
{x_transformers-1.38.3.dist-info → x_transformers-1.39.1.dist-info}/LICENSE
File without changes

{x_transformers-1.38.3.dist-info → x_transformers-1.39.1.dist-info}/WHEEL
File without changes

{x_transformers-1.38.3.dist-info → x_transformers-1.39.1.dist-info}/top_level.txt
File without changes