x-transformers 1.30.7__py3-none-any.whl → 1.30.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +35 -11
- {x_transformers-1.30.7.dist-info → x_transformers-1.30.8.dist-info}/METADATA +1 -1
- {x_transformers-1.30.7.dist-info → x_transformers-1.30.8.dist-info}/RECORD +6 -6
- {x_transformers-1.30.7.dist-info → x_transformers-1.30.8.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.7.dist-info → x_transformers-1.30.8.dist-info}/WHEEL +0 -0
- {x_transformers-1.30.7.dist-info → x_transformers-1.30.8.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -308,28 +308,52 @@ class CoPE(Module):
|
|
308
308
|
"""
|
309
309
|
Appendix B of https://arxiv.org/abs/2405.18719
|
310
310
|
"""
|
311
|
-
def __init__ (
|
311
|
+
def __init__ (
|
312
|
+
self,
|
313
|
+
dim,
|
314
|
+
max_pos,
|
315
|
+
soft_onehot = True,
|
316
|
+
reverse = True
|
317
|
+
):
|
312
318
|
super () . __init__ ()
|
319
|
+
self.reverse = reverse
|
313
320
|
self.max_pos = max_pos
|
314
321
|
self.pos_emb = nn.Parameter(torch.zeros(max_pos, dim))
|
315
322
|
|
316
|
-
|
323
|
+
self.soft_onehot = soft_onehot
|
324
|
+
|
325
|
+
if soft_onehot:
|
326
|
+
self.register_buffer('positions', torch.arange(max_pos))
|
327
|
+
|
328
|
+
def forward(self, query, attn_logits, temp = 5e-2):
|
317
329
|
# compute positions
|
318
330
|
|
319
331
|
gates = attn_logits.sigmoid()
|
320
|
-
pos = gates.flip(-1).cumsum(dim = -1).flip(-1)
|
321
|
-
pos = pos.clamp(max = self.max_pos - 1)
|
322
332
|
|
323
|
-
|
333
|
+
if self.reverse:
|
334
|
+
pos = gates.flip(-1).cumsum(dim = -1).flip(-1)
|
335
|
+
else:
|
336
|
+
pos = gates.cumsum(dim = -1)
|
337
|
+
|
338
|
+
pos = pos.clamp(max = self.max_pos - 1)
|
324
339
|
|
325
|
-
pos_ceil = pos.ceil().long()
|
326
|
-
pos_floor = pos.floor().long()
|
327
340
|
logits_int = einsum('b h n d, p d -> b h n p', query, self.pos_emb)
|
328
|
-
logits_ceil = logits_int.gather(-1, pos_ceil)
|
329
|
-
logits_floor = logits_int.gather(-1, pos_floor)
|
330
341
|
|
331
|
-
|
332
|
-
|
342
|
+
if self.soft_onehot:
|
343
|
+
diff_pos = (pos[..., None] - self.positions).abs()
|
344
|
+
soft_onehot_pos = F.softmax(-diff_pos / temp, dim = -1)
|
345
|
+
cope_pos_emb = einsum('b h i j p, b h i p -> b h i j', soft_onehot_pos, logits_int)
|
346
|
+
else:
|
347
|
+
# interpolate from integer positions
|
348
|
+
pos_ceil = pos.ceil().long()
|
349
|
+
pos_floor = pos.floor().long()
|
350
|
+
logits_ceil = logits_int.gather(-1, pos_ceil)
|
351
|
+
logits_floor = logits_int.gather(-1, pos_floor)
|
352
|
+
|
353
|
+
w = pos - pos_floor
|
354
|
+
cope_pos_emb = logits_ceil * w + logits_floor * (1 - w)
|
355
|
+
|
356
|
+
return cope_pos_emb
|
333
357
|
|
334
358
|
class DynamicPositionBias(Module):
|
335
359
|
def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
|
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
|
|
4
4
|
x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
|
5
5
|
x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
|
6
6
|
x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
|
7
|
-
x_transformers/x_transformers.py,sha256=
|
7
|
+
x_transformers/x_transformers.py,sha256=5cof7yvOAfFviLh-luafmhtTJDemCPoy9rHHYjWxLu4,68338
|
8
8
|
x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
|
9
9
|
x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
|
10
|
-
x_transformers-1.30.
|
11
|
-
x_transformers-1.30.
|
12
|
-
x_transformers-1.30.
|
13
|
-
x_transformers-1.30.
|
14
|
-
x_transformers-1.30.
|
10
|
+
x_transformers-1.30.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
11
|
+
x_transformers-1.30.8.dist-info/METADATA,sha256=2L0SfGhrbLMjpRKLwTp1_YH1Amu3g_j1nEuWIuGNqrQ,661
|
12
|
+
x_transformers-1.30.8.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
13
|
+
x_transformers-1.30.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
14
|
+
x_transformers-1.30.8.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|