x-transformers 1.38.3__py3-none-any.whl → 1.39.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in the public registry, and is provided for informational purposes only.
x_transformers/attend.py CHANGED
@@ -128,7 +128,8 @@ class Attend(Module):
         dropout = 0.,
         causal = False,
         heads = None,
-        talking_heads = False,
+        pre_talking_heads = True,
+        post_talking_heads = True,
         sparse_topk = None,
         scale = None,
         qk_norm = False,
@@ -179,14 +180,15 @@ class Attend(Module):
 
         # talking heads
 
-        assert not (flash and talking_heads), 'talking heads not compatible with flash attention'
+        assert not (flash and (pre_talking_heads or post_talking_heads)), 'talking heads not compatible with flash attention'
 
-        self.talking_heads = talking_heads
-        if talking_heads:
-            self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
-            self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
+        self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_talking_heads else None
+        self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if post_talking_heads else None
 
+        if exists(self.pre_softmax_talking_heads):
             nn.init.dirac_(self.pre_softmax_talking_heads.weight)
+
+        if exists(self.post_softmax_talking_heads):
             nn.init.dirac_(self.post_softmax_talking_heads.weight)
 
         # selective attention
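
In the hunk above, each talking-heads mixer is a 1x1 convolution across the head dimension, created independently per flag and dirac-initialized so it starts out as an identity map. A minimal sketch of that pattern in plain PyTorch (hypothetical variable names, not the Attend class itself):

import torch
from torch import nn

heads = 8
pre_talking_heads, post_talking_heads = True, False   # the two mixers can now be toggled separately

# create a 1x1 conv over the head dimension only when its flag is set, mirroring the constructor above
pre_mix = nn.Conv2d(heads, heads, 1, bias = False) if pre_talking_heads else None
post_mix = nn.Conv2d(heads, heads, 1, bias = False) if post_talking_heads else None

if pre_mix is not None:
    nn.init.dirac_(pre_mix.weight)    # identity mixing across heads at initialization

if post_mix is not None:
    nn.init.dirac_(post_mix.weight)

# dirac init means the conv returns its input unchanged before any training
scores = torch.randn(2, heads, 4, 4)   # (batch, heads, queries, keys)
assert torch.allclose(pre_mix(scores), scores)
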
@@ -434,8 +436,8 @@ class Attend(Module):
 
         qk_similarities = sim.clone()
 
-        if self.talking_heads:
-            sim = self.pre_softmax_talking_heads(sim)
+        if exists(self.pre_softmax_talking_heads):
+            sim = sim + self.pre_softmax_talking_heads(sim)
 
         if exists(attn_bias):
             sim = sim + attn_bias
@@ -482,7 +484,7 @@ class Attend(Module):
 
         attn = self.attn_dropout(attn)
 
-        if self.talking_heads:
+        if exists(self.post_softmax_talking_heads):
             attn = self.post_softmax_talking_heads(attn)
 
         out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
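
The forward-path hunks above apply the two mixers at different points: the pre-softmax mixer is now added residually to the attention logits, while the post-softmax mixer is applied directly to the attention weights. A standalone sketch of that flow (shapes and names are illustrative, not the library's actual forward):

import torch
from torch import nn

b, heads, i, j, d = 2, 8, 16, 16, 64

pre_mix = nn.Conv2d(heads, heads, 1, bias = False)
post_mix = nn.Conv2d(heads, heads, 1, bias = False)
nn.init.dirac_(pre_mix.weight)
nn.init.dirac_(post_mix.weight)

sim = torch.randn(b, heads, i, j)      # attention logits

sim = sim + pre_mix(sim)               # residual pre-softmax mixing, as in the hunk above
attn = sim.softmax(dim = -1)
attn = post_mix(attn)                  # post-softmax mixing applied directly to the weights

v = torch.randn(b, heads, j, d)
out = attn @ v                         # aggregate values with the mixed attention weights
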
x_transformers/x_transformers.py CHANGED
@@ -907,7 +907,8 @@ class Attention(Module):
         heads = 8,
         causal = False,
         flash = False,
-        talking_heads = False,
+        pre_talking_heads = False,
+        post_talking_heads = False,
         head_scale = False,
         sparse_topk = None,
         num_mem_kv = 0,
@@ -1036,7 +1037,8 @@ class Attention(Module):
         self.attend = Attend(
             heads = heads,
             causal = causal,
-            talking_heads = talking_heads,
+            pre_talking_heads = pre_talking_heads,
+            post_talking_heads = post_talking_heads,
             dropout = dropout,
             sparse_topk = sparse_topk,
             qk_norm = qk_norm,
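
At the public API level, this feature has historically been enabled with attn_talking_heads = True on an Encoder/Decoder (the attn_ prefix routes keyword arguments down to Attention). Assuming that routing is unchanged in 1.39.0, the renamed flags would be passed as below; treat this as a sketch rather than documented usage:

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_pre_talking_heads = True,    # assumed replacement for the old attn_talking_heads = True
        attn_post_talking_heads = True
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)                         # (1, 1024, 20000)
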
x_transformers-1.38.3.dist-info/METADATA → x_transformers-1.39.0.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.38.3
+Version: 1.39.0
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.38.3.dist-info/RECORD → x_transformers-1.39.0.dist-info/RECORD RENAMED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=Z9cnbY3f9zCl1yUOMWVQT5_ee4keC0lo_NX4cd0rbKk,15393
+x_transformers/attend.py,sha256=sTeX7DmUt6I5FhHtgcTDOIvmD1CvJ1PmVjZ_-lYO-QA,15596
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=P3o1R2DY2ic71FlJJ4ie4w_z-g3jIIrkBXcbllRoXHA,84240
+x_transformers/x_transformers.py,sha256=nfq-EOLx5HWf8tXlmVuDbkhqNfFnfqRMCEkALK3SFkA,84341
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.38.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.38.3.dist-info/METADATA,sha256=XbmJ91NWqmOjO0xvyXh9uh82TIwVbU54L3eFmaFVvYs,661
-x_transformers-1.38.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-x_transformers-1.38.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.38.3.dist-info/RECORD,,
+x_transformers-1.39.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.39.0.dist-info/METADATA,sha256=-ppCMMH6ZTsmwaMJB9q4b4Yvd-nU8v95l5SdzTY17OU,661
+x_transformers-1.39.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.39.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.39.0.dist-info/RECORD,,