x-transformers 1.38.3__py3-none-any.whl → 1.39.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -128,7 +128,9 @@ class Attend(Module):
         dropout = 0.,
         causal = False,
         heads = None,
-        talking_heads = False,
+        pre_talking_heads = False,
+        post_talking_heads = False,
+        pre_scale_post_talking_heads = False,
         sparse_topk = None,
         scale = None,
         qk_norm = False,
@@ -179,16 +181,22 @@ class Attend(Module):

         # talking heads

-        assert not (flash and talking_heads), 'talking heads not compatible with flash attention'
+        assert not (flash and (pre_talking_heads or post_talking_heads or pre_scale_post_talking_heads)), 'talking heads not compatible with flash attention'

-        self.talking_heads = talking_heads
-        if talking_heads:
-            self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
-            self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
+        self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_talking_heads else None
+        self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if post_talking_heads else None
+        self.pre_scale_post_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_scale_post_talking_heads else None

+        if exists(self.pre_softmax_talking_heads):
             nn.init.dirac_(self.pre_softmax_talking_heads.weight)
+
+        if exists(self.post_softmax_talking_heads):
             nn.init.dirac_(self.post_softmax_talking_heads.weight)

+        if exists(self.pre_scale_post_talking_heads):
+            # an improvisation where heads are combined pre-softmax attention, then used to scale post-softmax attention
+            nn.init.dirac_(self.pre_scale_post_talking_heads.weight)
+
         # selective attention

         assert not (flash and selective), 'selective attention cannot work on flash attention'
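For orientation: the talking-heads layers in this hunk are 1x1 nn.Conv2d modules acting on the (batch, heads, i, j) attention map, i.e. a learned linear mix across the head dimension, initialized to the identity with nn.init.dirac_. A minimal standalone sketch of that building block (illustration only, not code from the package):

import torch
import torch.nn as nn

heads = 8

# a 1x1 conv over the (heads, i, j) attention map: each output head is a
# learned weighted combination of all input heads
mix = nn.Conv2d(heads, heads, 1, bias = False)
nn.init.dirac_(mix.weight)  # identity init, so head mixing starts out as a no-op

sim = torch.randn(2, heads, 16, 16)  # (batch, heads, query len, key len) attention logits
mixed = mix(sim)

assert torch.allclose(mixed, sim, atol = 1e-6)  # dirac-initialized mixing returns the input unchanged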
@@ -434,8 +442,11 @@ class Attend(Module):

         qk_similarities = sim.clone()

-        if self.talking_heads:
-            sim = self.pre_softmax_talking_heads(sim)
+        if exists(self.pre_scale_post_talking_heads):
+            pre_to_post_scale = self.pre_scale_post_talking_heads(sim)
+
+        if exists(self.pre_softmax_talking_heads):
+            sim = sim + self.pre_softmax_talking_heads(sim)

         if exists(attn_bias):
             sim = sim + attn_bias
@@ -482,9 +493,12 @@ class Attend(Module):

         attn = self.attn_dropout(attn)

-        if self.talking_heads:
+        if exists(self.post_softmax_talking_heads):
             attn = self.post_softmax_talking_heads(attn)

+        if exists(self.pre_scale_post_talking_heads):
+            attn = attn * pre_to_post_scale
+
         out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)

         intermediates = Intermediates(
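Reading the two forward-pass hunks above together: pre_talking_heads adds a head mix to the logits before the softmax (residually), post_talking_heads mixes the attention weights after the softmax, and pre_scale_post_talking_heads computes a head mix on the pre-softmax logits and uses it to rescale the post-softmax attention. The sketch below paraphrases that flow; it is not the library's own code, and pre_mix, post_mix, pre_scale_post_mix stand in for the three optional Conv2d layers. Masking, dropout and the attention bias from the surrounding code are omitted for brevity.

def talking_heads_attention(sim, pre_mix = None, post_mix = None, pre_scale_post_mix = None):
    # sim: (batch, heads, i, j) attention logits

    if pre_scale_post_mix is not None:
        # head mix of the raw logits, applied later as a post-softmax scale
        pre_to_post_scale = pre_scale_post_mix(sim)

    if pre_mix is not None:
        # residual pre-softmax talking heads
        sim = sim + pre_mix(sim)

    attn = sim.softmax(dim = -1)

    if post_mix is not None:
        # post-softmax talking heads
        attn = post_mix(attn)

    if pre_scale_post_mix is not None:
        attn = attn * pre_to_post_scale

    return attn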
x_transformers/x_transformers.py CHANGED
@@ -907,7 +907,9 @@ class Attention(Module):
         heads = 8,
         causal = False,
         flash = False,
-        talking_heads = False,
+        pre_talking_heads = False,
+        post_talking_heads = False,
+        pre_scale_post_talking_heads = False,
         head_scale = False,
         sparse_topk = None,
         num_mem_kv = 0,
@@ -1036,7 +1038,9 @@ class Attention(Module):
         self.attend = Attend(
             heads = heads,
             causal = causal,
-            talking_heads = talking_heads,
+            pre_talking_heads = pre_talking_heads,
+            post_talking_heads = post_talking_heads,
+            pre_scale_post_talking_heads = pre_scale_post_talking_heads,
             dropout = dropout,
             sparse_topk = sparse_topk,
             qk_norm = qk_norm,
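Assuming the usual x-transformers convention of routing attn_-prefixed keyword arguments through to the Attention module, the new flags would be switched on from the public API roughly as below. This is a usage sketch based on this diff, not taken from the package's README.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_pre_talking_heads = True,              # mix attention logits across heads before the softmax
        attn_post_talking_heads = True,             # mix attention weights across heads after the softmax
        attn_pre_scale_post_talking_heads = False,  # pre-softmax head mix used to rescale post-softmax attention
        attn_flash = False                          # per the assert in attend.py, talking heads is incompatible with flash attention
    )
)

tokens = torch.randint(0, 256, (1, 1024))
logits = model(tokens)  # (1, 1024, 256)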
x_transformers-1.39.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.38.3
+Version: 1.39.1
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.39.1.dist-info/RECORD CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=Z9cnbY3f9zCl1yUOMWVQT5_ee4keC0lo_NX4cd0rbKk,15393
+x_transformers/attend.py,sha256=JlWwEzigY5sl9ktDwDWWQD9np9uPRoj2eRo9XU6tJc0,16273
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=P3o1R2DY2ic71FlJJ4ie4w_z-g3jIIrkBXcbllRoXHA,84240
+x_transformers/x_transformers.py,sha256=8ZQR6OLT4vusIjJXzrdSp12Fydmmpcc2t5cDE6SxPNc,84460
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.38.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.38.3.dist-info/METADATA,sha256=XbmJ91NWqmOjO0xvyXh9uh82TIwVbU54L3eFmaFVvYs,661
-x_transformers-1.38.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-x_transformers-1.38.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.38.3.dist-info/RECORD,,
+x_transformers-1.39.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.39.1.dist-info/METADATA,sha256=KJfw4hIDzozyRlsnBaqbbfcFSXygbv1y2B_6cl9pu-4,661
+x_transformers-1.39.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.39.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.39.1.dist-info/RECORD,,