x-transformers 1.30.7__py3-none-any.whl → 1.30.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +30 -11
- {x_transformers-1.30.7.dist-info → x_transformers-1.30.9.dist-info}/METADATA +1 -1
- {x_transformers-1.30.7.dist-info → x_transformers-1.30.9.dist-info}/RECORD +6 -6
- {x_transformers-1.30.7.dist-info → x_transformers-1.30.9.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.7.dist-info → x_transformers-1.30.9.dist-info}/WHEEL +0 -0
- {x_transformers-1.30.7.dist-info → x_transformers-1.30.9.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -308,28 +308,46 @@ class CoPE(Module):
     """
     Appendix B of https://arxiv.org/abs/2405.18719
     """
-    def __init__ (self, dim, max_pos):
+    def __init__ (
+        self,
+        dim,
+        max_pos,
+        soft_onehot = False,
+    ):
         super () . __init__ ()
         self.max_pos = max_pos
         self.pos_emb = nn.Parameter(torch.zeros(max_pos, dim))

-    def forward(self, query, attn_logits):
+        self.soft_onehot = soft_onehot
+
+        if soft_onehot:
+            self.register_buffer('positions', torch.arange(max_pos))
+
+    def forward(self, query, attn_logits, temp = 5e-2):
         # compute positions

         gates = attn_logits.sigmoid()
+
         pos = gates.flip(-1).cumsum(dim = -1).flip(-1)
         pos = pos.clamp(max = self.max_pos - 1)

-        # interpolate from integer positions
-
-        pos_ceil = pos.ceil().long()
-        pos_floor = pos.floor().long()
         logits_int = einsum('b h n d, p d -> b h n p', query, self.pos_emb)
-        logits_ceil = logits_int.gather(-1, pos_ceil)
-        logits_floor = logits_int.gather(-1, pos_floor)

-        w = pos - pos_floor
-        return logits_ceil * w + logits_floor * (1 - w)
+        if self.soft_onehot:
+            diff_pos = (pos[..., None] - self.positions).abs()
+            soft_onehot_pos = F.softmax(-diff_pos / temp, dim = -1)
+            cope_pos_emb = einsum('b h i j p, b h i p -> b h i j', soft_onehot_pos, logits_int)
+        else:
+            # interpolate from integer positions
+            pos_ceil = pos.ceil().long()
+            pos_floor = pos.floor().long()
+            logits_ceil = logits_int.gather(-1, pos_ceil)
+            logits_floor = logits_int.gather(-1, pos_floor)
+
+            w = pos - pos_floor
+            cope_pos_emb = logits_ceil * w + logits_floor * (1 - w)
+
+        return cope_pos_emb

 class DynamicPositionBias(Module):
     def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
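The substantive change in this hunk is the new `soft_onehot` path in `CoPE.forward`: instead of linearly interpolating between the floor and ceiling position embeddings, the fractional CoPE position is turned into a softmax over all integer positions, with `temp` controlling how sharply the weights concentrate on the nearest ones. Below is a minimal standalone sketch of just that weighting step in plain PyTorch; it does not import x-transformers, and the scalar position and values are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

max_pos   = 16
positions = torch.arange(max_pos)   # integer positions 0..15, as in the registered buffer
pos       = torch.tensor(3.3)       # one fractional CoPE position
temp      = 5e-2                    # same default as forward(..., temp = 5e-2) in the diff

# soft one-hot branch: softmax over negative distance to every integer position
diff_pos        = (pos[..., None] - positions).abs()
soft_onehot_pos = F.softmax(-diff_pos / temp, dim = -1)   # sums to 1, concentrated around position 3

# previous behaviour: linear interpolation puts weight only on floor and ceil
w = pos - pos.floor()                                     # ~0.3
lerp = torch.zeros(max_pos)
lerp[pos.floor().long()] = 1 - w                          # ~0.7 weight on position 3
lerp[pos.ceil().long()]  = w                              # ~0.3 weight on position 4

print(soft_onehot_pos)
print(lerp)
```

In the module itself these weights are applied to the per-position logits `logits_int` through the `'b h i j p, b h i p -> b h i j'` einsum, rather than gathering only the floor and ceiling entries.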
@@ -751,6 +769,7 @@ class Attention(Module):
         rotary_embed_values = False,
         use_cope = False,
         cope_max_pos = 16,
+        cope_soft_onehot_pos = False,
         logit_softclamp_value = None,
         onnxable = False
     ):
@@ -834,7 +853,7 @@ class Attention(Module):
             assert causal, 'CoPE was designed for causal attention'
             assert not flash, 'CoPE is not flash attention compatible'

-            cope = CoPE(dim_head, cope_max_pos)
+            cope = CoPE(dim_head, cope_max_pos, soft_onehot = cope_soft_onehot_pos)

         # attend class - includes core attention algorithm + talking heads

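At the `Attention` level the change is purely plumbing: a new `cope_soft_onehot_pos` keyword (default `False`) that is forwarded to the `CoPE` constructor as `soft_onehot`. A hedged sketch of turning it on follows; only `use_cope`, `cope_max_pos` and `cope_soft_onehot_pos` are confirmed by this diff, while the import path, the `dim`/`heads`/`causal` arguments and the `(output, intermediates)` return convention are assumptions about the surrounding codebase:

```python
import torch
from x_transformers.x_transformers import Attention  # assumed import path for the module changed here

# causal attention with contextual position encoding (CoPE),
# using the soft one-hot position option added in 1.30.9
attn = Attention(
    dim = 512,
    heads = 8,
    causal = True,                # the diff asserts CoPE is causal-only
    use_cope = True,
    cope_max_pos = 16,
    cope_soft_onehot_pos = True,  # new flag; False keeps the old floor/ceil interpolation
)

x = torch.randn(1, 1024, 512)
out, _ = attn(x)                  # assumed to return (output, intermediates)
print(out.shape)                  # expected: torch.Size([1, 1024, 512])
```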
{x_transformers-1.30.7.dist-info → x_transformers-1.30.9.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=V_kR77GhWNu3TuUetu3xryWoFliwTJdNKRX2lVWnsRc,68274
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
+x_transformers-1.30.9.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.9.dist-info/METADATA,sha256=auoB6F1DCe054gY70tE1k048x6wQRtftk8-Pk6nJD-I,661
+x_transformers-1.30.9.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.9.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.9.dist-info/RECORD,,
{x_transformers-1.30.7.dist-info → x_transformers-1.30.9.dist-info}/LICENSE
File without changes
{x_transformers-1.30.7.dist-info → x_transformers-1.30.9.dist-info}/WHEEL
File without changes
{x_transformers-1.30.7.dist-info → x_transformers-1.30.9.dist-info}/top_level.txt
File without changes