x-transformers 1.30.7__py3-none-any.whl → 1.30.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -308,28 +308,52 @@ class CoPE(Module):
308
308
  """
309
309
  Appendix B of https://arxiv.org/abs/2405.18719
310
310
  """
311
- def __init__ (self, dim, max_pos):
311
+ def __init__ (
312
+ self,
313
+ dim,
314
+ max_pos,
315
+ soft_onehot = True,
316
+ reverse = True
317
+ ):
312
318
  super () . __init__ ()
319
+ self.reverse = reverse
313
320
  self.max_pos = max_pos
314
321
  self.pos_emb = nn.Parameter(torch.zeros(max_pos, dim))
315
322
 
316
- def forward(self, query, attn_logits):
323
+ self.soft_onehot = soft_onehot
324
+
325
+ if soft_onehot:
326
+ self.register_buffer('positions', torch.arange(max_pos))
327
+
328
+ def forward(self, query, attn_logits, temp = 5e-2):
317
329
  # compute positions
318
330
 
319
331
  gates = attn_logits.sigmoid()
320
- pos = gates.flip(-1).cumsum(dim = -1).flip(-1)
321
- pos = pos.clamp(max = self.max_pos - 1)
322
332
 
323
- # interpolate from integer positions
333
+ if self.reverse:
334
+ pos = gates.flip(-1).cumsum(dim = -1).flip(-1)
335
+ else:
336
+ pos = gates.cumsum(dim = -1)
337
+
338
+ pos = pos.clamp(max = self.max_pos - 1)
324
339
 
325
- pos_ceil = pos.ceil().long()
326
- pos_floor = pos.floor().long()
327
340
  logits_int = einsum('b h n d, p d -> b h n p', query, self.pos_emb)
328
- logits_ceil = logits_int.gather(-1, pos_ceil)
329
- logits_floor = logits_int.gather(-1, pos_floor)
330
341
 
331
- w = pos - pos_floor
332
- return logits_ceil * w + logits_floor * (1 - w)
342
+ if self.soft_onehot:
343
+ diff_pos = (pos[..., None] - self.positions).abs()
344
+ soft_onehot_pos = F.softmax(-diff_pos / temp, dim = -1)
345
+ cope_pos_emb = einsum('b h i j p, b h i p -> b h i j', soft_onehot_pos, logits_int)
346
+ else:
347
+ # interpolate from integer positions
348
+ pos_ceil = pos.ceil().long()
349
+ pos_floor = pos.floor().long()
350
+ logits_ceil = logits_int.gather(-1, pos_ceil)
351
+ logits_floor = logits_int.gather(-1, pos_floor)
352
+
353
+ w = pos - pos_floor
354
+ cope_pos_emb = logits_ceil * w + logits_floor * (1 - w)
355
+
356
+ return cope_pos_emb
333
357
 
334
358
  class DynamicPositionBias(Module):
335
359
  def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.30.7
3
+ Version: 1.30.8
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
4
4
  x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
5
5
  x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
6
6
  x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
7
- x_transformers/x_transformers.py,sha256=r9F_LLp5bQyAlue3bBTRwoRx02noTCh4ICF8oWCw1wE,67657
7
+ x_transformers/x_transformers.py,sha256=5cof7yvOAfFviLh-luafmhtTJDemCPoy9rHHYjWxLu4,68338
8
8
  x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
9
9
  x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
10
- x_transformers-1.30.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
11
- x_transformers-1.30.7.dist-info/METADATA,sha256=_TniCg2s6tlimpfzpWeMCsCMOjsoYwUObBiXFdY-JhA,661
12
- x_transformers-1.30.7.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
13
- x_transformers-1.30.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
14
- x_transformers-1.30.7.dist-info/RECORD,,
10
+ x_transformers-1.30.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
11
+ x_transformers-1.30.8.dist-info/METADATA,sha256=2L0SfGhrbLMjpRKLwTp1_YH1Amu3g_j1nEuWIuGNqrQ,661
12
+ x_transformers-1.30.8.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
13
+ x_transformers-1.30.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
14
+ x_transformers-1.30.8.dist-info/RECORD,,