x-transformers 1.30.9__py3-none-any.whl → 1.30.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -311,19 +311,33 @@ class CoPE(Module):
     def __init__(
         self,
         dim,
+        heads,
         max_pos,
         soft_onehot = False,
+        talking_heads = False,
+        soft_onehot_temp = 5e-2
     ):
         super().__init__()
         self.max_pos = max_pos
         self.pos_emb = nn.Parameter(torch.zeros(max_pos, dim))
 
+        self.talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else None
         self.soft_onehot = soft_onehot
+        self.soft_onehot_temp = soft_onehot_temp
 
         if soft_onehot:
             self.register_buffer('positions', torch.arange(max_pos))
 
-    def forward(self, query, attn_logits, temp = 5e-2):
+    def forward(self, query, attn_logits):
+
+        if exists(self.talking_heads):
+            i, j = attn_logits.shape[-2:]
+            causal_mask = attn_logits.new_ones(i, j).triu_(j - i + 1).bool()
+
+            attn_logits = self.talking_heads(attn_logits)
+
+            attn_logits = attn_logits.masked_fill(causal_mask, -torch.finfo(attn_logits.dtype).max)
+
         # compute positions
 
         gates = attn_logits.sigmoid()
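The new talking-heads branch mixes CoPE's attention logits across heads with a 1x1 convolution before they are turned into gates, then re-applies the causal mask so positions a query should not attend to cannot leak back in through the mixing. A minimal standalone sketch of that step, with assumed toy shapes (batch, heads, i, j) and i <= j:

import torch
import torch.nn as nn

batch, heads, i, j = 2, 8, 4, 6
attn_logits = torch.randn(batch, heads, i, j)

# a 1x1 convolution over the head dimension is a learned linear mix of heads
talking_heads = nn.Conv2d(heads, heads, 1, bias = False)

# rebuild the causal mask; the diagonal offset j - i + 1 handles a cached prefix
causal_mask = attn_logits.new_ones(i, j).triu_(j - i + 1).bool()

mixed = talking_heads(attn_logits)

# re-mask after mixing, since the convolution can smear masked (future) positions
mixed = mixed.masked_fill(causal_mask, -torch.finfo(mixed.dtype).max)

gates = mixed.sigmoid()  # CoPE then sigmoids the logits into per-position gates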
@@ -335,7 +349,7 @@ class CoPE(Module):
 
         if self.soft_onehot:
             diff_pos = (pos[..., None] - self.positions).abs()
-            soft_onehot_pos = F.softmax(-diff_pos / temp, dim = -1)
+            soft_onehot_pos = F.softmax(-diff_pos / self.soft_onehot_temp, dim = -1)
             cope_pos_emb = einsum('b h i j p, b h i p -> b h i j', soft_onehot_pos, logits_int)
         else:
             # interpolate from integer positions
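This second hunk moves the softmax temperature from a forward() default argument into the constructor as soft_onehot_temp, so it is configured once at init. The temperature controls how sharply the soft one-hot spreads a fractional CoPE position over the integer position bins. A tiny numeric illustration with simplified shapes (the library code operates on per-head, per-query position tensors, not a single scalar):

import torch
import torch.nn.functional as F

max_pos = 8
positions = torch.arange(max_pos)   # the registered 'positions' buffer
pos = torch.tensor(2.3)             # one fractional CoPE position

diff_pos = (pos - positions).abs()
for temp in (5e-2, 5e-1):
    soft_onehot_pos = F.softmax(-diff_pos / temp, dim = -1)
    print(temp, soft_onehot_pos.round(decimals = 3))
# at temp = 5e-2 nearly all mass sits on bin 2; at 5e-1 it spreads to the neighbours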
@@ -770,6 +784,7 @@ class Attention(Module):
         use_cope = False,
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
+        cope_talking_heads = False,
         logit_softclamp_value = None,
         onnxable = False
     ):
@@ -853,7 +868,13 @@ class Attention(Module):
             assert causal, 'CoPE was designed for causal attention'
             assert not flash, 'CoPE is not flash attention compatible'
 
-            cope = CoPE(dim_head, cope_max_pos, soft_onehot = cope_soft_onehot_pos)
+            cope = CoPE(
+                dim = dim_head,
+                heads = heads,
+                max_pos = cope_max_pos,
+                talking_heads = cope_talking_heads,
+                soft_onehot = cope_soft_onehot_pos
+            )
 
         # attend class - includes core attention algorithm + talking heads
 
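On the Attention side the feature is exposed as a cope_talking_heads flag and forwarded into the CoPE constructor together with heads. A hedged usage sketch follows; the TransformerWrapper/Decoder surroundings and the attn_ prefix routing of attention kwargs reflect the usual x-transformers pattern and are assumptions here, not part of this diff (CoPE also requires causal, non-flash attention per the asserts above):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_use_cope = True,            # turn on contextual position encoding
        attn_cope_max_pos = 16,          # default shown in the diff
        attn_cope_talking_heads = True   # added between 1.30.9 and 1.30.11 per this diff
    )
)

x = torch.randint(0, 256, (1, 128))
logits = model(x)  # (1, 128, 256)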
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.30.9
+Version: 1.30.11
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=V_kR77GhWNu3TuUetu3xryWoFliwTJdNKRX2lVWnsRc,68274
+x_transformers/x_transformers.py,sha256=97WREYHfGljhrveGNUUSEWk2xbFvKdM52QXW7cnoBpk,69019
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.9.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.30.9.dist-info/METADATA,sha256=auoB6F1DCe054gY70tE1k048x6wQRtftk8-Pk6nJD-I,661
-x_transformers-1.30.9.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.30.9.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.30.9.dist-info/RECORD,,
+x_transformers-1.30.11.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.11.dist-info/METADATA,sha256=kpv8_mEb4DU4zL_oyLwzbWJ7z8WV7mf7c466Hehn-6c,662
+x_transformers-1.30.11.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.11.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.11.dist-info/RECORD,,