x-transformers 1.30.8__py3-none-any.whl → 1.30.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +4 -9
- {x_transformers-1.30.8.dist-info → x_transformers-1.30.9.dist-info}/METADATA +1 -1
- {x_transformers-1.30.8.dist-info → x_transformers-1.30.9.dist-info}/RECORD +6 -6
- {x_transformers-1.30.8.dist-info → x_transformers-1.30.9.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.8.dist-info → x_transformers-1.30.9.dist-info}/WHEEL +0 -0
- {x_transformers-1.30.8.dist-info → x_transformers-1.30.9.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -312,11 +312,9 @@ class CoPE(Module):
         self,
         dim,
         max_pos,
-        soft_onehot =
-        reverse = True
+        soft_onehot = False,
     ):
         super().__init__()
-        self.reverse = reverse
         self.max_pos = max_pos
         self.pos_emb = nn.Parameter(torch.zeros(max_pos, dim))

@@ -330,11 +328,7 @@ class CoPE(Module):

         gates = attn_logits.sigmoid()

-        if self.reverse:
-            pos = gates.flip(-1).cumsum(dim = -1).flip(-1)
-        else:
-            pos = gates.cumsum(dim = -1)
-
+        pos = gates.flip(-1).cumsum(dim = -1).flip(-1)
         pos = pos.clamp(max = self.max_pos - 1)

         logits_int = einsum('b h n d, p d -> b h n p', query, self.pos_emb)
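With the `reverse` branch removed, CoPE now always derives positions as a reverse cumulative sum of the sigmoid gates, clamped to the size of the position table. A minimal standalone sketch of that surviving computation; the tensor shapes below are illustrative and not taken from the diff:

import torch

# gates: (batch, heads, query_len, key_len), as produced by attn_logits.sigmoid()
gates = torch.rand(1, 2, 4, 4)

# reverse cumulative sum along the key dimension, then clamp to the CoPE table size
max_pos = 16
pos = gates.flip(-1).cumsum(dim = -1).flip(-1)
pos = pos.clamp(max = max_pos - 1)

print(pos.shape)  # torch.Size([1, 2, 4, 4])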
@@ -775,6 +769,7 @@ class Attention(Module):
         rotary_embed_values = False,
         use_cope = False,
         cope_max_pos = 16,
+        cope_soft_onehot_pos = False,
         logit_softclamp_value = None,
         onnxable = False
     ):
@@ -858,7 +853,7 @@ class Attention(Module):
             assert causal, 'CoPE was designed for causal attention'
             assert not flash, 'CoPE is not flash attention compatible'

-            cope = CoPE(dim_head, cope_max_pos)
+            cope = CoPE(dim_head, cope_max_pos, soft_onehot = cope_soft_onehot_pos)

         # attend class - includes core attention algorithm + talking heads

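Taken together, these hunks drop CoPE's `reverse` flag and expose the soft one-hot option on `Attention` as `cope_soft_onehot_pos`. A hedged usage sketch of the new keyword follows; `dim`, `heads`, the input shape, and the forward-call return convention are assumptions about the surrounding `Attention` API rather than part of this diff, while `causal = True` and `flash = False` mirror the asserts shown above:

import torch
from x_transformers.x_transformers import Attention, CoPE

# constructing CoPE directly with the 1.30.9 signature from this diff
cope = CoPE(dim = 64, max_pos = 16, soft_onehot = False)

# or letting Attention build it via the new keyword (other kwargs assumed from the class)
attn = Attention(
    dim = 512,
    heads = 8,
    causal = True,                # CoPE asserts causal attention
    flash = False,                # CoPE is not flash attention compatible
    use_cope = True,
    cope_max_pos = 16,
    cope_soft_onehot_pos = True   # new in 1.30.9
)

x = torch.randn(1, 128, 512)
out, intermediates = attn(x)      # assumed return convention of Attention.forward
print(out.shape)                  # expected: torch.Size([1, 128, 512])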
{x_transformers-1.30.8.dist-info → x_transformers-1.30.9.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=V_kR77GhWNu3TuUetu3xryWoFliwTJdNKRX2lVWnsRc,68274
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
+x_transformers-1.30.9.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.9.dist-info/METADATA,sha256=auoB6F1DCe054gY70tE1k048x6wQRtftk8-Pk6nJD-I,661
+x_transformers-1.30.9.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.9.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.9.dist-info/RECORD,,
{x_transformers-1.30.8.dist-info → x_transformers-1.30.9.dist-info}/LICENSE
File without changes
{x_transformers-1.30.8.dist-info → x_transformers-1.30.9.dist-info}/WHEEL
File without changes
{x_transformers-1.30.8.dist-info → x_transformers-1.30.9.dist-info}/top_level.txt
File without changes