x-transformers 1.30.8__py3-none-any.whl → 1.30.10__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
- x_transformers/x_transformers.py +17 -9
- {x_transformers-1.30.8.dist-info → x_transformers-1.30.10.dist-info}/METADATA +1 -1
- {x_transformers-1.30.8.dist-info → x_transformers-1.30.10.dist-info}/RECORD +6 -6
- {x_transformers-1.30.8.dist-info → x_transformers-1.30.10.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.8.dist-info → x_transformers-1.30.10.dist-info}/WHEEL +0 -0
- {x_transformers-1.30.8.dist-info → x_transformers-1.30.10.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -311,30 +311,30 @@ class CoPE(Module):
     def __init__ (
         self,
         dim,
+        heads,
         max_pos,
-        soft_onehot = False,
-        reverse = True
+        soft_onehot = False,
+        talking_heads = False
     ):
         super().__init__()
-        self.reverse = reverse
         self.max_pos = max_pos
         self.pos_emb = nn.Parameter(torch.zeros(max_pos, dim))

+        self.maybe_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()
         self.soft_onehot = soft_onehot

         if soft_onehot:
             self.register_buffer('positions', torch.arange(max_pos))

     def forward(self, query, attn_logits, temp = 5e-2):
+
+        attn_logits = self.maybe_talking_heads(attn_logits)
+
         # compute positions

         gates = attn_logits.sigmoid()

-        if self.reverse:
-            pos = gates.flip(-1).cumsum(dim = -1).flip(-1)
-        else:
-            pos = gates.cumsum(dim = -1)
-
+        pos = gates.flip(-1).cumsum(dim = -1).flip(-1)
         pos = pos.clamp(max = self.max_pos - 1)

         logits_int = einsum('b h n d, p d -> b h n p', query, self.pos_emb)
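For readers who want to try the new constructor directly, here is a minimal, hedged sketch of exercising the updated CoPE class. The tensor shapes and the return value are assumptions inferred from the einsum and the talking-heads Conv2d visible in the hunk above; they are not part of this diff.

```python
import torch
from x_transformers.x_transformers import CoPE

# Assumed shapes: query is (batch, heads, seq, dim_head), matching the
# einsum 'b h n d, p d -> b h n p'; attn_logits is (batch, heads, seq, seq)
# so the 1x1 Conv2d can mix it across the head dimension.
cope = CoPE(
    dim = 64,              # per-head dimension (Attention passes dim_head, per the hunks below)
    heads = 8,             # new in 1.30.10: sizes the optional talking-heads projection
    max_pos = 16,
    soft_onehot = False,
    talking_heads = True   # new in 1.30.10: mix attention logits across heads before gating
)

query = torch.randn(2, 8, 128, 64)
attn_logits = torch.randn(2, 8, 128, 128)

pos_bias = cope(query, attn_logits)  # assumed to return a bias matching attn_logits' shape
print(pos_bias.shape)
```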
@@ -775,6 +775,8 @@ class Attention(Module):
         rotary_embed_values = False,
         use_cope = False,
         cope_max_pos = 16,
+        cope_soft_onehot_pos = False,
+        cope_talking_heads = False,
         logit_softclamp_value = None,
         onnxable = False
     ):
@@ -858,7 +860,13 @@ class Attention(Module):
             assert causal, 'CoPE was designed for causal attention'
             assert not flash, 'CoPE is not flash attention compatible'

-            cope = CoPE(
+            cope = CoPE(
+                dim = dim_head,
+                heads = heads,
+                max_pos = cope_max_pos,
+                talking_heads = cope_talking_heads,
+                soft_onehot = cope_soft_onehot_pos
+            )

             # attend class - includes core attention algorithm + talking heads

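Taken together, the two hunks above expose CoPE's new options as Attention kwargs. As a hedged usage sketch, enabling them from the usual wrapper might look like the following; the attn_ prefix routing through Decoder and the hyperparameters below are assumptions, not part of this diff.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(                # CoPE asserts causal attention, hence a Decoder
        dim = 512,
        depth = 6,
        heads = 8,
        attn_use_cope = True,             # flag already present in 1.30.8
        attn_cope_max_pos = 16,
        attn_cope_talking_heads = True,   # new kwarg in 1.30.10
        attn_cope_soft_onehot_pos = False # new kwarg in 1.30.10
    )
)

tokens = torch.randint(0, 20000, (1, 1024))
logits = model(tokens)                    # (1, 1024, num_tokens)
```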
{x_transformers-1.30.8.dist-info → x_transformers-1.30.10.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=qErIXwntVnZEmdO4MwTTsvxMc_vQPBU8zLLpEh535tE,68690
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
+x_transformers-1.30.10.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.10.dist-info/METADATA,sha256=Fo2EWkxK4rLMcwUqcyeSW9AljUJA6b_2hZ59W8axIO4,662
+x_transformers-1.30.10.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.10.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.10.dist-info/RECORD,,
{x_transformers-1.30.8.dist-info → x_transformers-1.30.10.dist-info}/LICENSE
File without changes
{x_transformers-1.30.8.dist-info → x_transformers-1.30.10.dist-info}/WHEEL
File without changes
{x_transformers-1.30.8.dist-info → x_transformers-1.30.10.dist-info}/top_level.txt
File without changes