x-transformers 1.38.3__tar.gz → 1.39.0__tar.gz

Files changed (21)
  1. {x_transformers-1.38.3/x_transformers.egg-info → x_transformers-1.39.0}/PKG-INFO +1 -1
  2. {x_transformers-1.38.3 → x_transformers-1.39.0}/README.md +2 -1
  3. {x_transformers-1.38.3 → x_transformers-1.39.0}/setup.py +1 -1
  4. {x_transformers-1.38.3 → x_transformers-1.39.0}/x_transformers/attend.py +11 -9
  5. {x_transformers-1.38.3 → x_transformers-1.39.0}/x_transformers/x_transformers.py +4 -2
  6. {x_transformers-1.38.3 → x_transformers-1.39.0/x_transformers.egg-info}/PKG-INFO +1 -1
  7. {x_transformers-1.38.3 → x_transformers-1.39.0}/LICENSE +0 -0
  8. {x_transformers-1.38.3 → x_transformers-1.39.0}/setup.cfg +0 -0
  9. {x_transformers-1.38.3 → x_transformers-1.39.0}/tests/test_x_transformers.py +0 -0
  10. {x_transformers-1.38.3 → x_transformers-1.39.0}/x_transformers/__init__.py +0 -0
  11. {x_transformers-1.38.3 → x_transformers-1.39.0}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.38.3 → x_transformers-1.39.0}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.38.3 → x_transformers-1.39.0}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.38.3 → x_transformers-1.39.0}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.38.3 → x_transformers-1.39.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
  16. {x_transformers-1.38.3 → x_transformers-1.39.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  17. {x_transformers-1.38.3 → x_transformers-1.39.0}/x_transformers/xval.py +0 -0
  18. {x_transformers-1.38.3 → x_transformers-1.39.0}/x_transformers.egg-info/SOURCES.txt +0 -0
  19. {x_transformers-1.38.3 → x_transformers-1.39.0}/x_transformers.egg-info/dependency_links.txt +0 -0
  20. {x_transformers-1.38.3 → x_transformers-1.39.0}/x_transformers.egg-info/requires.txt +0 -0
  21. {x_transformers-1.38.3 → x_transformers-1.39.0}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.38.3/x_transformers.egg-info → x_transformers-1.39.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.38.3
+Version: 1.39.0
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
{x_transformers-1.38.3 → x_transformers-1.39.0}/README.md

@@ -549,7 +549,8 @@ model = TransformerWrapper(
         dim = 512,
         depth = 6,
         heads = 8,
-        attn_talking_heads = True # turn on information exchange between attention heads
+        attn_pre_talking_heads = True, # linear combination across pre-softmax attn logits across heads
+        attn_post_talking_heads = True # linear combination across post-softmax attn across heads
     )
 )
 ```
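For context, the README lines above sit inside the library's standard TransformerWrapper example; a runnable version using the renamed flags might look like the sketch below. The num_tokens/max_seq_len values are the README's usual example settings, not part of this diff.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_pre_talking_heads = True,   # mix pre-softmax attention logits across heads
        attn_post_talking_heads = True   # mix post-softmax attention weights across heads
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)  # expected shape: (1, 1024, 20000)
```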
{x_transformers-1.38.3 → x_transformers-1.39.0}/setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.38.3',
+  version = '1.39.0',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
{x_transformers-1.38.3 → x_transformers-1.39.0}/x_transformers/attend.py

@@ -128,7 +128,8 @@ class Attend(Module):
         dropout = 0.,
         causal = False,
         heads = None,
-        talking_heads = False,
+        pre_talking_heads = True,
+        post_talking_heads = True,
         sparse_topk = None,
         scale = None,
         qk_norm = False,
@@ -179,14 +180,15 @@ class Attend(Module):

         # talking heads

-        assert not (flash and talking_heads), 'talking heads not compatible with flash attention'
+        assert not (flash and (pre_talking_heads or post_talking_heads)), 'talking heads not compatible with flash attention'

-        self.talking_heads = talking_heads
-        if talking_heads:
-            self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
-            self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
+        self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_talking_heads else None
+        self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if post_talking_heads else None

+        if exists(self.pre_softmax_talking_heads):
             nn.init.dirac_(self.pre_softmax_talking_heads.weight)
+
+        if exists(self.post_softmax_talking_heads):
             nn.init.dirac_(self.post_softmax_talking_heads.weight)

         # selective attention
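Because each mixing layer is dirac-initialized, it starts out as an identity map over the head dimension, so enabling the flags is a no-op until training moves the weights. A quick standalone check of that property, independent of the package:

```python
import torch
from torch import nn

heads = 8
mix = nn.Conv2d(heads, heads, 1, bias = False)  # 1x1 conv = linear combination across heads
nn.init.dirac_(mix.weight)                      # identity initialization, as in the hunk above

sim = torch.randn(2, heads, 16, 16)             # (batch, heads, query len, key len)
assert torch.allclose(mix(sim), sim, atol = 1e-6)  # dirac-initialized mixing leaves the logits unchanged
```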
@@ -434,8 +436,8 @@

         qk_similarities = sim.clone()

-        if self.talking_heads:
-            sim = self.pre_softmax_talking_heads(sim)
+        if exists(self.pre_softmax_talking_heads):
+            sim = sim + self.pre_softmax_talking_heads(sim)

         if exists(attn_bias):
             sim = sim + attn_bias

@@ -482,7 +484,7 @@

         attn = self.attn_dropout(attn)

-        if self.talking_heads:
+        if exists(self.post_softmax_talking_heads):
             attn = self.post_softmax_talking_heads(attn)

         out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
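Taken together, the forward-pass changes apply the pre-softmax mix as a residual on the attention logits and the post-softmax mix directly on the attention weights. A condensed sketch of just that path, assuming the usual (batch, heads, seq, dim_head) layout and omitting masking, dropout, causal handling, and the flash branch:

```python
import torch
from torch import nn, einsum

def exists(v):
    return v is not None

class TalkingHeadsSketch(nn.Module):
    # illustrative only: mirrors the pre/post-softmax head mixing from the hunks above,
    # not the full Attend module
    def __init__(self, heads, pre_talking_heads = True, post_talking_heads = True):
        super().__init__()
        self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_talking_heads else None
        self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if post_talking_heads else None

        if exists(self.pre_softmax_talking_heads):
            nn.init.dirac_(self.pre_softmax_talking_heads.weight)

        if exists(self.post_softmax_talking_heads):
            nn.init.dirac_(self.post_softmax_talking_heads.weight)

    def forward(self, q, k, v):
        scale = q.shape[-1] ** -0.5
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * scale

        if exists(self.pre_softmax_talking_heads):
            sim = sim + self.pre_softmax_talking_heads(sim)    # residual mix of pre-softmax logits across heads

        attn = sim.softmax(dim = -1)

        if exists(self.post_softmax_talking_heads):
            attn = self.post_softmax_talking_heads(attn)       # mix post-softmax attention across heads

        return einsum('b h i j, b h j d -> b h i d', attn, v)

q = k = v = torch.randn(1, 8, 64, 64)
out = TalkingHeadsSketch(heads = 8)(q, k, v)  # (1, 8, 64, 64)
```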
{x_transformers-1.38.3 → x_transformers-1.39.0}/x_transformers/x_transformers.py

@@ -907,7 +907,8 @@ class Attention(Module):
         heads = 8,
         causal = False,
         flash = False,
-        talking_heads = False,
+        pre_talking_heads = False,
+        post_talking_heads = False,
         head_scale = False,
         sparse_topk = None,
         num_mem_kv = 0,

@@ -1036,7 +1037,8 @@ class Attention(Module):
         self.attend = Attend(
             heads = heads,
             causal = causal,
-            talking_heads = talking_heads,
+            pre_talking_heads = pre_talking_heads,
+            post_talking_heads = post_talking_heads,
             dropout = dropout,
             sparse_topk = sparse_topk,
             qk_norm = qk_norm,
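At the Attention level the two new keyword arguments simply replace talking_heads and are forwarded to the inner Attend module, so each mixing layer exists only when its flag is set. A small, hypothetical smoke test; the x_transformers.x_transformers import path and constructor call are assumptions, while the attend attribute and the sub-module names come from the hunks above:

```python
from x_transformers.x_transformers import Attention

attn = Attention(dim = 512, heads = 8, pre_talking_heads = True, post_talking_heads = False)

# the flags are routed into the inner Attend module shown in the attend.py hunks
assert attn.attend.pre_softmax_talking_heads is not None
assert attn.attend.post_softmax_talking_heads is None
```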
{x_transformers-1.38.3 → x_transformers-1.39.0/x_transformers.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.38.3
+Version: 1.39.0
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang