x-transformers 1.30.8__py3-none-any.whl → 1.30.10__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
x_transformers/x_transformers.py

@@ -311,30 +311,30 @@ class CoPE(Module):
     def __init__(
         self,
         dim,
+        heads,
         max_pos,
-        soft_onehot = True,
-        reverse = True
+        soft_onehot = False,
+        talking_heads = False
     ):
         super().__init__()
-        self.reverse = reverse
         self.max_pos = max_pos
         self.pos_emb = nn.Parameter(torch.zeros(max_pos, dim))
 
+        self.maybe_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()
         self.soft_onehot = soft_onehot
 
         if soft_onehot:
             self.register_buffer('positions', torch.arange(max_pos))
 
     def forward(self, query, attn_logits, temp = 5e-2):
+
+        attn_logits = self.maybe_talking_heads(attn_logits)
+
         # compute positions
 
         gates = attn_logits.sigmoid()
 
-        if self.reverse:
-            pos = gates.flip(-1).cumsum(dim = -1).flip(-1)
-        else:
-            pos = gates.cumsum(dim = -1)
-
+        pos = gates.flip(-1).cumsum(dim = -1).flip(-1)
         pos = pos.clamp(max = self.max_pos - 1)
 
         logits_int = einsum('b h n d, p d -> b h n p', query, self.pos_emb)
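In short, 1.30.10 removes CoPE's `reverse` flag (the reversed cumulative sum over the gates is now the only path), flips the `soft_onehot` default to `False`, and adds an optional talking-heads step: a bias-free 1x1 `Conv2d` that mixes the attention logits across heads before the position gates are computed. A minimal sketch of the new forward path in isolation, with made-up batch and sequence sizes (the `(batch, heads, seq, seq)` logit layout follows the einsum above):

import torch
from torch import nn

heads, max_pos = 8, 16

# the new optional talking-heads mix: a 1x1 conv that linearly
# recombines attention logits across the head dimension
maybe_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)

attn_logits = torch.randn(2, heads, 64, 64)     # (batch, heads, seq, seq)
attn_logits = maybe_talking_heads(attn_logits)  # same shape, heads mixed

# gating is now always the reversed cumsum; the old
# `reverse = False` (plain cumsum) branch was deleted
gates = attn_logits.sigmoid()
pos = gates.flip(-1).cumsum(dim = -1).flip(-1)
pos = pos.clamp(max = max_pos - 1)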
@@ -775,6 +775,8 @@ class Attention(Module):
         rotary_embed_values = False,
         use_cope = False,
         cope_max_pos = 16,
+        cope_soft_onehot_pos = False,
+        cope_talking_heads = False,
         logit_softclamp_value = None,
         onnxable = False
     ):
@@ -858,7 +860,13 @@ class Attention(Module):
             assert causal, 'CoPE was designed for causal attention'
             assert not flash, 'CoPE is not flash attention compatible'
 
-            cope = CoPE(dim_head, cope_max_pos)
+            cope = CoPE(
+                dim = dim_head,
+                heads = heads,
+                max_pos = cope_max_pos,
+                talking_heads = cope_talking_heads,
+                soft_onehot = cope_soft_onehot_pos
+            )
 
         # attend class - includes core attention algorithm + talking heads
 
x_transformers-1.30.8.dist-info/METADATA → x_transformers-1.30.10.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.30.8
+Version: 1.30.10
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.30.8.dist-info/RECORD → x_transformers-1.30.10.dist-info/RECORD

@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=5cof7yvOAfFviLh-luafmhtTJDemCPoy9rHHYjWxLu4,68338
+x_transformers/x_transformers.py,sha256=qErIXwntVnZEmdO4MwTTsvxMc_vQPBU8zLLpEh535tE,68690
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.30.8.dist-info/METADATA,sha256=2L0SfGhrbLMjpRKLwTp1_YH1Amu3g_j1nEuWIuGNqrQ,661
-x_transformers-1.30.8.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.30.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.30.8.dist-info/RECORD,,
+x_transformers-1.30.10.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.10.dist-info/METADATA,sha256=Fo2EWkxK4rLMcwUqcyeSW9AljUJA6b_2hZ59W8axIO4,662
+x_transformers-1.30.10.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.10.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.10.dist-info/RECORD,,