x-transformers 1.38.3__py3-none-any.whl → 1.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +11 -9
- x_transformers/x_transformers.py +4 -2
- {x_transformers-1.38.3.dist-info → x_transformers-1.39.0.dist-info}/METADATA +1 -1
- {x_transformers-1.38.3.dist-info → x_transformers-1.39.0.dist-info}/RECORD +7 -7
- {x_transformers-1.38.3.dist-info → x_transformers-1.39.0.dist-info}/LICENSE +0 -0
- {x_transformers-1.38.3.dist-info → x_transformers-1.39.0.dist-info}/WHEEL +0 -0
- {x_transformers-1.38.3.dist-info → x_transformers-1.39.0.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -128,7 +128,8 @@ class Attend(Module):
         dropout = 0.,
         causal = False,
         heads = None,
-
+        pre_talking_heads = True,
+        post_talking_heads = True,
         sparse_topk = None,
         scale = None,
         qk_norm = False,
@@ -179,14 +180,15 @@ class Attend(Module):

         # talking heads

-        assert not (flash and
+        assert not (flash and (pre_talking_heads or post_talking_heads)), 'talking heads not compatible with flash attention'

-        self.
-        if
-            self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
-            self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
+        self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_talking_heads else None
+        self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if post_talking_heads else None

+        if exists(self.pre_softmax_talking_heads):
             nn.init.dirac_(self.pre_softmax_talking_heads.weight)
+
+        if exists(self.post_softmax_talking_heads):
             nn.init.dirac_(self.post_softmax_talking_heads.weight)

         # selective attention
@@ -434,8 +436,8 @@ class Attend(Module):

         qk_similarities = sim.clone()

-        if self.
-            sim = self.pre_softmax_talking_heads(sim)
+        if exists(self.pre_softmax_talking_heads):
+            sim = sim + self.pre_softmax_talking_heads(sim)

         if exists(attn_bias):
             sim = sim + attn_bias
@@ -482,7 +484,7 @@ class Attend(Module):

         attn = self.attn_dropout(attn)

-        if self.
+        if exists(self.post_softmax_talking_heads):
             attn = self.post_softmax_talking_heads(attn)

         out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
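For orientation, the effect of the attend.py change can be reproduced in isolation. The sketch below is illustrative only (the `mix` module name is ours, not the library's): the pre-softmax talking-heads projection is a 1x1 convolution over the head dimension, initialized with `nn.init.dirac_`, and 1.39.0 applies it residually (`sim + mix(sim)`), so at initialization the logits are doubled rather than passed through unchanged as in 1.38.3.

```python
import torch
from torch import nn

heads, seq = 8, 16
sim = torch.randn(2, heads, seq, seq)  # attention logits (batch, heads, i, j)

# pre-softmax talking heads, constructed as in the diff above
mix = nn.Conv2d(heads, heads, 1, bias = False)
nn.init.dirac_(mix.weight)  # identity initialization across heads

# 1.38.3: sim = mix(sim)        -> unchanged at initialization
# 1.39.0: sim = sim + mix(sim)  -> 2 * sim at initialization
mixed = sim + mix(sim)
assert torch.allclose(mixed, 2 * sim, atol = 1e-5)
```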
x_transformers/x_transformers.py
CHANGED
@@ -907,7 +907,8 @@ class Attention(Module):
         heads = 8,
         causal = False,
         flash = False,
-
+        pre_talking_heads = False,
+        post_talking_heads = False,
         head_scale = False,
         sparse_topk = None,
         num_mem_kv = 0,
@@ -1036,7 +1037,8 @@ class Attention(Module):
         self.attend = Attend(
             heads = heads,
             causal = causal,
-
+            pre_talking_heads = pre_talking_heads,
+            post_talking_heads = post_talking_heads,
             dropout = dropout,
             sparse_topk = sparse_topk,
             qk_norm = qk_norm,
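Both new Attention kwargs default to False, so existing configurations keep their behavior. A hedged sketch of enabling them from the high-level API follows; it assumes the usual `attn_`-prefix routing of keyword arguments into the attention layers, and the model hyperparameters are illustrative rather than taken from this diff.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_pre_talking_heads = True,   # assumed to route to Attention(pre_talking_heads = True)
        attn_post_talking_heads = True   # assumed to route to Attention(post_talking_heads = True)
    )
)

tokens = torch.randint(0, 256, (1, 1024))
logits = model(tokens)  # (1, 1024, 256)
```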
{x_transformers-1.38.3.dist-info → x_transformers-1.39.0.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=sTeX7DmUt6I5FhHtgcTDOIvmD1CvJ1PmVjZ_-lYO-QA,15596
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=nfq-EOLx5HWf8tXlmVuDbkhqNfFnfqRMCEkALK3SFkA,84341
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers-1.39.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.39.0.dist-info/METADATA,sha256=-ppCMMH6ZTsmwaMJB9q4b4Yvd-nU8v95l5SdzTY17OU,661
+x_transformers-1.39.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.39.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.39.0.dist-info/RECORD,,
{x_transformers-1.38.3.dist-info → x_transformers-1.39.0.dist-info}/LICENSE
File without changes
{x_transformers-1.38.3.dist-info → x_transformers-1.39.0.dist-info}/WHEEL
File without changes
{x_transformers-1.38.3.dist-info → x_transformers-1.39.0.dist-info}/top_level.txt
File without changes