x-transformers 1.39.0__py3-none-any.whl → 1.39.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
x_transformers/attend.py CHANGED
@@ -128,8 +128,9 @@ class Attend(Module):
         dropout = 0.,
         causal = False,
         heads = None,
-        pre_talking_heads = True,
-        post_talking_heads = True,
+        pre_talking_heads = False,
+        post_talking_heads = False,
+        pre_scale_post_talking_heads = False,
         sparse_topk = None,
         scale = None,
         qk_norm = False,
@@ -180,10 +181,11 @@ class Attend(Module):
 
         # talking heads
 
-        assert not (flash and (pre_talking_heads or post_talking_heads)), 'talking heads not compatible with flash attention'
+        assert not (flash and (pre_talking_heads or post_talking_heads or pre_scale_post_talking_heads)), 'talking heads not compatible with flash attention'
 
         self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_talking_heads else None
         self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if post_talking_heads else None
+        self.pre_scale_post_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_scale_post_talking_heads else None
 
         if exists(self.pre_softmax_talking_heads):
             nn.init.dirac_(self.pre_softmax_talking_heads.weight)
@@ -191,6 +193,10 @@ class Attend(Module):
         if exists(self.post_softmax_talking_heads):
             nn.init.dirac_(self.post_softmax_talking_heads.weight)
 
+        if exists(self.pre_scale_post_talking_heads):
+            # an improvisation where heads are combined pre-softmax attention, then used to scale post-softmax attention
+            nn.init.dirac_(self.pre_scale_post_talking_heads.weight)
+
         # selective attention
 
         assert not (flash and selective), 'selective attention cannot work on flash attention'
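For orientation, the new `pre_scale_post_talking_heads` option amounts to a dirac-initialized 1x1 convolution across heads applied to the pre-softmax attention logits, whose output is later used to rescale the post-softmax attention map. A minimal sketch of that mechanism in isolation, with illustrative names and shapes not taken from the package:

    import torch
    from torch import nn

    heads, seq = 8, 16
    sim = torch.randn(1, heads, seq, seq)           # pre-softmax attention logits, shape (b h i j)

    mix = nn.Conv2d(heads, heads, 1, bias = False)  # 1x1 conv mixes information across heads
    nn.init.dirac_(mix.weight)                      # identity mapping across heads at init (no mixing yet)

    pre_to_post_scale = mix(sim)                    # head-mixed logits, computed before the softmax
    attn = sim.softmax(dim = -1)
    attn = attn * pre_to_post_scale                 # used to scale the post-softmax attention, as in the hunks below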
@@ -436,6 +442,9 @@ class Attend(Module):
 
         qk_similarities = sim.clone()
 
+        if exists(self.pre_scale_post_talking_heads):
+            pre_to_post_scale = self.pre_scale_post_talking_heads(sim)
+
         if exists(self.pre_softmax_talking_heads):
             sim = sim + self.pre_softmax_talking_heads(sim)
 
@@ -487,6 +496,9 @@ class Attend(Module):
         if exists(self.post_softmax_talking_heads):
             attn = self.post_softmax_talking_heads(attn)
 
+        if exists(self.pre_scale_post_talking_heads):
+            attn = attn * pre_to_post_scale
+
         out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
 
         intermediates = Intermediates(
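A hedged usage sketch of the new flag on `Attend` directly, assuming the usual `(batch, heads, seq_len, dim_head)` inputs and the `(out, intermediates)` return visible in the hunk above; check the forward signature against the installed version:

    import torch
    from x_transformers.attend import Attend

    attend = Attend(
        heads = 8,
        causal = True,
        pre_scale_post_talking_heads = True   # new in 1.39.1; all talking-heads variants now default to False
    )

    q = torch.randn(2, 8, 1024, 64)
    k = torch.randn(2, 8, 1024, 64)
    v = torch.randn(2, 8, 1024, 64)

    out, intermediates = attend(q, k, v)      # out: (2, 8, 1024, 64)

Note that, per the assertion above, none of the talking-heads options can be combined with `flash = True`.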
x_transformers/x_transformers.py CHANGED
@@ -909,6 +909,7 @@ class Attention(Module):
         flash = False,
         pre_talking_heads = False,
         post_talking_heads = False,
+        pre_scale_post_talking_heads = False,
         head_scale = False,
         sparse_topk = None,
         num_mem_kv = 0,
@@ -1039,6 +1040,7 @@ class Attention(Module):
             causal = causal,
             pre_talking_heads = pre_talking_heads,
             post_talking_heads = post_talking_heads,
+            pre_scale_post_talking_heads = pre_scale_post_talking_heads,
             dropout = dropout,
             sparse_topk = sparse_topk,
             qk_norm = qk_norm,
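At the model level, the option should be reachable the same way as the other attention settings, via the `attn_`-prefixed kwargs that the attention layers forward to `Attention`; the kwarg name below is an assumption based on that convention, not something shown in this diff:

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 6,
            heads = 8,
            attn_pre_scale_post_talking_heads = True   # assumed kwarg, forwarded to Attention / Attend
        )
    )

    x = torch.randint(0, 20000, (1, 1024))
    logits = model(x)   # (1, 1024, 20000)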
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.39.0
+Version: 1.39.1
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=sTeX7DmUt6I5FhHtgcTDOIvmD1CvJ1PmVjZ_-lYO-QA,15596
+x_transformers/attend.py,sha256=JlWwEzigY5sl9ktDwDWWQD9np9uPRoj2eRo9XU6tJc0,16273
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=nfq-EOLx5HWf8tXlmVuDbkhqNfFnfqRMCEkALK3SFkA,84341
+x_transformers/x_transformers.py,sha256=8ZQR6OLT4vusIjJXzrdSp12Fydmmpcc2t5cDE6SxPNc,84460
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.39.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.39.0.dist-info/METADATA,sha256=-ppCMMH6ZTsmwaMJB9q4b4Yvd-nU8v95l5SdzTY17OU,661
-x_transformers-1.39.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-x_transformers-1.39.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.39.0.dist-info/RECORD,,
+x_transformers-1.39.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.39.1.dist-info/METADATA,sha256=KJfw4hIDzozyRlsnBaqbbfcFSXygbv1y2B_6cl9pu-4,661
+x_transformers-1.39.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.39.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.39.1.dist-info/RECORD,,