x-transformers 1.30.7__py3-none-any.whl → 1.30.9__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
@@ -308,28 +308,46 @@ class CoPE(Module):
     """
     Appendix B of https://arxiv.org/abs/2405.18719
     """
-    def __init__(self, dim, max_pos):
+    def __init__(
+        self,
+        dim,
+        max_pos,
+        soft_onehot = False,
+    ):
         super().__init__()
         self.max_pos = max_pos
         self.pos_emb = nn.Parameter(torch.zeros(max_pos, dim))
 
-    def forward(self, query, attn_logits):
+        self.soft_onehot = soft_onehot
+
+        if soft_onehot:
+            self.register_buffer('positions', torch.arange(max_pos))
+
+    def forward(self, query, attn_logits, temp = 5e-2):
         # compute positions
 
         gates = attn_logits.sigmoid()
+
         pos = gates.flip(-1).cumsum(dim = -1).flip(-1)
         pos = pos.clamp(max = self.max_pos - 1)
 
-        # interpolate from integer positions
-
-        pos_ceil = pos.ceil().long()
-        pos_floor = pos.floor().long()
         logits_int = einsum('b h n d, p d -> b h n p', query, self.pos_emb)
-        logits_ceil = logits_int.gather(-1, pos_ceil)
-        logits_floor = logits_int.gather(-1, pos_floor)
 
-        w = pos - pos_floor
-        return logits_ceil * w + logits_floor * (1 - w)
+        if self.soft_onehot:
+            diff_pos = (pos[..., None] - self.positions).abs()
+            soft_onehot_pos = F.softmax(-diff_pos / temp, dim = -1)
+            cope_pos_emb = einsum('b h i j p, b h i p -> b h i j', soft_onehot_pos, logits_int)
+        else:
+            # interpolate from integer positions
+            pos_ceil = pos.ceil().long()
+            pos_floor = pos.floor().long()
+            logits_ceil = logits_int.gather(-1, pos_ceil)
+            logits_floor = logits_int.gather(-1, pos_floor)
+
+            w = pos - pos_floor
+            cope_pos_emb = logits_ceil * w + logits_floor * (1 - w)
+
+        return cope_pos_emb
 
 class DynamicPositionBias(Module):
     def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
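
To see what changed, here is a minimal, self-contained sketch (mine, not the library's code) that runs the old interpolation read-out and the new soft-onehot read-out side by side on made-up tensors. All sizes and the random inputs are illustrative; only the formulas mirror the diff above.

import torch
import torch.nn.functional as F
from torch import einsum

# toy sizes: batch 1, 2 heads, sequence length 4, head dim 8 (all made up)
b, h, n, d, max_pos = 1, 2, 4, 8, 16

query = torch.randn(b, h, n, d)        # stand-in for the attention queries
attn_logits = torch.randn(b, h, n, n)  # stand-in for pre-softmax attention logits
pos_emb = torch.randn(max_pos, d)      # stand-in for CoPE's learned position embeddings

# fractional positions: reversed cumulative sum of sigmoid gates, clamped
gates = attn_logits.sigmoid()
pos = gates.flip(-1).cumsum(dim = -1).flip(-1).clamp(max = max_pos - 1)

# per-query logits against every integer position
logits_int = einsum('b h n d, p d -> b h n p', query, pos_emb)

# path 1 - linear interpolation between the two neighboring integer positions
pos_floor = pos.floor().long()
w = pos - pos_floor
interp = logits_int.gather(-1, pos.ceil().long()) * w + logits_int.gather(-1, pos_floor) * (1 - w)

# path 2 (new) - softmax over *all* integer positions instead of a hard gather
temp = 5e-2
soft_onehot = F.softmax(-(pos[..., None] - torch.arange(max_pos)).abs() / temp, dim = -1)
soft = einsum('b h i j p, b h i p -> b h i j', soft_onehot, logits_int)

assert interp.shape == soft.shape == (b, h, n, n)

At temp = 5e-2 the softmax is sharp enough to act as a soft one-hot over the nearest integer positions, so the two paths largely agree, while gradients still flow to every position embedding rather than only the floor/ceil neighbors.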
@@ -751,6 +769,7 @@ class Attention(Module):
         rotary_embed_values = False,
         use_cope = False,
         cope_max_pos = 16,
+        cope_soft_onehot_pos = False,
         logit_softclamp_value = None,
         onnxable = False
     ):
@@ -834,7 +853,7 @@ class Attention(Module):
             assert causal, 'CoPE was designed for causal attention'
             assert not flash, 'CoPE is not flash attention compatible'
 
-            cope = CoPE(dim_head, cope_max_pos)
+            cope = CoPE(dim_head, cope_max_pos, soft_onehot = cope_soft_onehot_pos)
 
         # attend class - includes core attention algorithm + talking heads
 
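Downstream, the new behavior is opt-in through the cope_soft_onehot_pos flag added above. A hedged usage sketch, assuming the standard x-transformers pattern of routing attn_-prefixed kwargs from the attention-layers wrapper into Attention (the wrapper setup follows the project README; only the last flag is new in this diff):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_use_cope = True,             # contextual position encoding (CoPE)
        attn_cope_max_pos = 16,           # matches the cope_max_pos default
        attn_cope_soft_onehot_pos = True  # the flag introduced in this diff
    )
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x)   # (1, 1024, 256)
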
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.30.7
+Version: 1.30.9
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=r9F_LLp5bQyAlue3bBTRwoRx02noTCh4ICF8oWCw1wE,67657
+x_transformers/x_transformers.py,sha256=V_kR77GhWNu3TuUetu3xryWoFliwTJdNKRX2lVWnsRc,68274
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.30.7.dist-info/METADATA,sha256=_TniCg2s6tlimpfzpWeMCsCMOjsoYwUObBiXFdY-JhA,661
-x_transformers-1.30.7.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.30.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.30.7.dist-info/RECORD,,
+x_transformers-1.30.9.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.9.dist-info/METADATA,sha256=auoB6F1DCe054gY70tE1k048x6wQRtftk8-Pk6nJD-I,661
+x_transformers-1.30.9.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.9.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.9.dist-info/RECORD,,