x-transformers 1.30.9__py3-none-any.whl → 1.30.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -311,19 +311,25 @@ class CoPE(Module):
311
311
  def __init__ (
312
312
  self,
313
313
  dim,
314
+ heads,
314
315
  max_pos,
315
316
  soft_onehot = False,
317
+ talking_heads = False
316
318
  ):
317
319
  super () . __init__ ()
318
320
  self.max_pos = max_pos
319
321
  self.pos_emb = nn.Parameter(torch.zeros(max_pos, dim))
320
322
 
323
+ self.maybe_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()
321
324
  self.soft_onehot = soft_onehot
322
325
 
323
326
  if soft_onehot:
324
327
  self.register_buffer('positions', torch.arange(max_pos))
325
328
 
326
329
  def forward(self, query, attn_logits, temp = 5e-2):
330
+
331
+ attn_logits = self.maybe_talking_heads(attn_logits)
332
+
327
333
  # compute positions
328
334
 
329
335
  gates = attn_logits.sigmoid()
@@ -770,6 +776,7 @@ class Attention(Module):
770
776
  use_cope = False,
771
777
  cope_max_pos = 16,
772
778
  cope_soft_onehot_pos = False,
779
+ cope_talking_heads = False,
773
780
  logit_softclamp_value = None,
774
781
  onnxable = False
775
782
  ):
@@ -853,7 +860,13 @@ class Attention(Module):
853
860
  assert causal, 'CoPE was designed for causal attention'
854
861
  assert not flash, 'CoPE is not flash attention compatible'
855
862
 
856
- cope = CoPE(dim_head, cope_max_pos, soft_onehot = cope_soft_onehot_pos)
863
+ cope = CoPE(
864
+ dim = dim_head,
865
+ heads = heads,
866
+ max_pos = cope_max_pos,
867
+ talking_heads = cope_talking_heads,
868
+ soft_onehot = cope_soft_onehot_pos
869
+ )
857
870
 
858
871
  # attend class - includes core attention algorithm + talking heads
859
872
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.30.9
3
+ Version: 1.30.10
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
4
4
  x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
5
5
  x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
6
6
  x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
7
- x_transformers/x_transformers.py,sha256=V_kR77GhWNu3TuUetu3xryWoFliwTJdNKRX2lVWnsRc,68274
7
+ x_transformers/x_transformers.py,sha256=qErIXwntVnZEmdO4MwTTsvxMc_vQPBU8zLLpEh535tE,68690
8
8
  x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
9
9
  x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
10
- x_transformers-1.30.9.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
11
- x_transformers-1.30.9.dist-info/METADATA,sha256=auoB6F1DCe054gY70tE1k048x6wQRtftk8-Pk6nJD-I,661
12
- x_transformers-1.30.9.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
13
- x_transformers-1.30.9.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
14
- x_transformers-1.30.9.dist-info/RECORD,,
10
+ x_transformers-1.30.10.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
11
+ x_transformers-1.30.10.dist-info/METADATA,sha256=Fo2EWkxK4rLMcwUqcyeSW9AljUJA6b_2hZ59W8axIO4,662
12
+ x_transformers-1.30.10.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
13
+ x_transformers-1.30.10.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
14
+ x_transformers-1.30.10.dist-info/RECORD,,