x-transformers 1.39.0__py3-none-any.whl → 1.39.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- x_transformers/attend.py +15 -3
- x_transformers/x_transformers.py +2 -0
- {x_transformers-1.39.0.dist-info → x_transformers-1.39.1.dist-info}/METADATA +1 -1
- {x_transformers-1.39.0.dist-info → x_transformers-1.39.1.dist-info}/RECORD +7 -7
- {x_transformers-1.39.0.dist-info → x_transformers-1.39.1.dist-info}/LICENSE +0 -0
- {x_transformers-1.39.0.dist-info → x_transformers-1.39.1.dist-info}/WHEEL +0 -0
- {x_transformers-1.39.0.dist-info → x_transformers-1.39.1.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -128,8 +128,9 @@ class Attend(Module):
         dropout = 0.,
         causal = False,
         heads = None,
-        pre_talking_heads =
-        post_talking_heads =
+        pre_talking_heads = False,
+        post_talking_heads = False,
+        pre_scale_post_talking_heads = False,
         sparse_topk = None,
         scale = None,
         qk_norm = False,
@@ -180,10 +181,11 @@ class Attend(Module):

         # talking heads

-        assert not (flash and (pre_talking_heads or post_talking_heads)), 'talking heads not compatible with flash attention'
+        assert not (flash and (pre_talking_heads or post_talking_heads or pre_scale_post_talking_heads)), 'talking heads not compatible with flash attention'

         self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_talking_heads else None
         self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if post_talking_heads else None
+        self.pre_scale_post_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_scale_post_talking_heads else None

         if exists(self.pre_softmax_talking_heads):
             nn.init.dirac_(self.pre_softmax_talking_heads.weight)
@@ -191,6 +193,10 @@ class Attend(Module):
         if exists(self.post_softmax_talking_heads):
             nn.init.dirac_(self.post_softmax_talking_heads.weight)

+        if exists(self.pre_scale_post_talking_heads):
+            # an improvisation where heads are combined pre-softmax attention, then used to scale post-softmax attention
+            nn.init.dirac_(self.pre_scale_post_talking_heads.weight)
+
         # selective attention

         assert not (flash and selective), 'selective attention cannot work on flash attention'
@@ -436,6 +442,9 @@ class Attend(Module):

         qk_similarities = sim.clone()

+        if exists(self.pre_scale_post_talking_heads):
+            pre_to_post_scale = self.pre_scale_post_talking_heads(sim)
+
         if exists(self.pre_softmax_talking_heads):
             sim = sim + self.pre_softmax_talking_heads(sim)

@@ -487,6 +496,9 @@ class Attend(Module):
         if exists(self.post_softmax_talking_heads):
             attn = self.post_softmax_talking_heads(attn)

+        if exists(self.pre_scale_post_talking_heads):
+            attn = attn * pre_to_post_scale
+
         out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)

         intermediates = Intermediates(
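Taken together, the attend.py hunks add one new module and two new steps in the forward pass: a Dirac-initialized 1x1 convolution over the head dimension is applied to the pre-softmax similarities, and its output is later used to rescale the post-softmax attention. The standalone sketch below reconstructs just that path from the hunks above; the helper name and tensor shapes are assumptions, and masking, dropout and the other talking-heads branches are omitted, so this is an illustration rather than the library's actual implementation.

import torch
from torch import nn, einsum

def attend_with_pre_scale_post_talking_heads(q, k, v, heads):
    # q, k, v: (batch, heads, seq_len, dim_head) -- assumed shapes
    scale = q.shape[-1] ** -0.5

    # 1x1 conv over the head dimension, Dirac (identity) initialized,
    # mirroring self.pre_scale_post_talking_heads in the diff above
    pre_scale_post_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
    nn.init.dirac_(pre_scale_post_talking_heads.weight)

    # pre-softmax similarities
    sim = einsum('b h i d, b h j d -> b h i j', q, k) * scale

    # heads are combined on the pre-softmax similarities ...
    pre_to_post_scale = pre_scale_post_talking_heads(sim)

    attn = sim.softmax(dim = -1)

    # ... and the result scales the post-softmax attention
    attn = attn * pre_to_post_scale

    return einsum('b h i j, b h j d -> b h i d', attn, v)

q = k = v = torch.randn(1, 8, 16, 64)
out = attend_with_pre_scale_post_talking_heads(q, k, v, heads = 8)  # (1, 8, 16, 64)

With the Dirac initialization the 1x1 convolution starts out as an identity mix across heads, matching how the existing pre- and post-softmax talking-heads convolutions are initialized in the same file.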
x_transformers/x_transformers.py
CHANGED
@@ -909,6 +909,7 @@ class Attention(Module):
         flash = False,
         pre_talking_heads = False,
         post_talking_heads = False,
+        pre_scale_post_talking_heads = False,
         head_scale = False,
         sparse_topk = None,
         num_mem_kv = 0,
@@ -1039,6 +1040,7 @@ class Attention(Module):
            causal = causal,
            pre_talking_heads = pre_talking_heads,
            post_talking_heads = post_talking_heads,
+           pre_scale_post_talking_heads = pre_scale_post_talking_heads,
            dropout = dropout,
            sparse_topk = sparse_topk,
            qk_norm = qk_norm,
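The x_transformers.py hunks simply thread the new flag from the Attention constructor into Attend. At the user-facing level it should be reachable the same way the existing talking-heads flags are, via an attn_-prefixed keyword on the attention layers; the spelling attn_pre_scale_post_talking_heads below follows that routing convention but is an assumption on my part, not something this diff confirms.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        # assumed kwarg: routed to Attention(pre_scale_post_talking_heads = True)
        # via the attn_ prefix convention used for the other attention flags
        attn_pre_scale_post_talking_heads = True
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)  # (1, 1024, 20000)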
{x_transformers-1.39.0.dist-info → x_transformers-1.39.1.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=JlWwEzigY5sl9ktDwDWWQD9np9uPRoj2eRo9XU6tJc0,16273
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=8ZQR6OLT4vusIjJXzrdSp12Fydmmpcc2t5cDE6SxPNc,84460
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.39.
-x_transformers-1.39.
-x_transformers-1.39.
-x_transformers-1.39.
-x_transformers-1.39.
+x_transformers-1.39.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.39.1.dist-info/METADATA,sha256=KJfw4hIDzozyRlsnBaqbbfcFSXygbv1y2B_6cl9pu-4,661
+x_transformers-1.39.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.39.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.39.1.dist-info/RECORD,,
{x_transformers-1.39.0.dist-info → x_transformers-1.39.1.dist-info}/LICENSE
File without changes

{x_transformers-1.39.0.dist-info → x_transformers-1.39.1.dist-info}/WHEEL
File without changes

{x_transformers-1.39.0.dist-info → x_transformers-1.39.1.dist-info}/top_level.txt
File without changes