x-transformers 1.30.9__py3-none-any.whl → 1.30.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +24 -3
- {x_transformers-1.30.9.dist-info → x_transformers-1.30.11.dist-info}/METADATA +1 -1
- {x_transformers-1.30.9.dist-info → x_transformers-1.30.11.dist-info}/RECORD +6 -6
- {x_transformers-1.30.9.dist-info → x_transformers-1.30.11.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.9.dist-info → x_transformers-1.30.11.dist-info}/WHEEL +0 -0
- {x_transformers-1.30.9.dist-info → x_transformers-1.30.11.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -311,19 +311,33 @@ class CoPE(Module):
     def __init__(
         self,
         dim,
+        heads,
         max_pos,
         soft_onehot = False,
+        talking_heads = False,
+        soft_onehot_temp = 5e-2
     ):
         super().__init__()
         self.max_pos = max_pos
         self.pos_emb = nn.Parameter(torch.zeros(max_pos, dim))

+        self.talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else None
         self.soft_onehot = soft_onehot
+        self.soft_onehot_temp = soft_onehot_temp

         if soft_onehot:
             self.register_buffer('positions', torch.arange(max_pos))

-    def forward(self, query, attn_logits
+    def forward(self, query, attn_logits):
+
+        if exists(self.talking_heads):
+            i, j = attn_logits.shape[-2:]
+            causal_mask = attn_logits.new_ones(i, j).triu_(j - i + 1).bool()
+
+            attn_logits = self.talking_heads(attn_logits)
+
+            attn_logits = attn_logits.masked_fill(causal_mask, -torch.finfo(attn_logits.dtype).max)
+
         # compute positions

         gates = attn_logits.sigmoid()
@@ -335,7 +349,7 @@ class CoPE(Module):

         if self.soft_onehot:
             diff_pos = (pos[..., None] - self.positions).abs()
-            soft_onehot_pos = F.softmax(-diff_pos /
+            soft_onehot_pos = F.softmax(-diff_pos / self.soft_onehot_temp, dim = -1)
             cope_pos_emb = einsum('b h i j p, b h i p -> b h i j', soft_onehot_pos, logits_int)
         else:
             # interpolate from integer positions
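Taken together, the two CoPE hunks above move heads, talking_heads, and soft_onehot_temp into the constructor and read the soft one-hot temperature from self.soft_onehot_temp inside forward. The sketch below shows the new constructor surface in isolation; it is not taken from the package, the hyperparameter values are illustrative, and the tensor shapes (query as (batch, heads, seq, dim_head), attn_logits as (batch, heads, seq, seq)) are inferred from the Conv2d(heads, heads, 1) and the 'b h i j p, b h i p -> b h i j' einsum rather than stated in the diff.

import torch
from x_transformers.x_transformers import CoPE

# Illustrative hyperparameters, not taken from the diff.
cope = CoPE(
    dim = 64,                  # per-head dimension (dim_head in Attention)
    heads = 8,                 # new constructor argument
    max_pos = 16,
    soft_onehot = True,
    talking_heads = True,      # new: 1x1 Conv2d mixing of the attention logits across heads
    soft_onehot_temp = 5e-2    # new: temperature for the soft one-hot position softmax
)

# Assumed shapes: query (batch, heads, seq, dim_head), attn_logits (batch, heads, seq, seq).
query = torch.randn(2, 8, 32, 64)
attn_logits = torch.randn(2, 8, 32, 32)

cope_pos_emb = cope(query, attn_logits)   # positional term, assumed to match attn_logits in shape

Note that when talking_heads is enabled, the causal re-mask applied right after the convolution keeps the mixed logits from leaking future positions before the sigmoid gating that derives the positions.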
@@ -770,6 +784,7 @@ class Attention(Module):
         use_cope = False,
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
+        cope_talking_heads = False,
         logit_softclamp_value = None,
         onnxable = False
     ):
@@ -853,7 +868,13 @@ class Attention(Module):
             assert causal, 'CoPE was designed for causal attention'
             assert not flash, 'CoPE is not flash attention compatible'

-            cope = CoPE(
+            cope = CoPE(
+                dim = dim_head,
+                heads = heads,
+                max_pos = cope_max_pos,
+                talking_heads = cope_talking_heads,
+                soft_onehot = cope_soft_onehot_pos
+            )

         # attend class - includes core attention algorithm + talking heads
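On the Attention side, the new cope_talking_heads flag is threaded from Attention.__init__ straight into the CoPE constructor, alongside the existing cope_max_pos and cope_soft_onehot_pos options. A hedged sketch of enabling it directly on an Attention module follows; the argument values are illustrative, and causal = True together with the default non-flash path matches the two asserts shown above.

from x_transformers.x_transformers import Attention

attn = Attention(
    dim = 512,
    heads = 8,
    causal = True,                # asserted above: CoPE was designed for causal attention
    use_cope = True,
    cope_max_pos = 16,
    cope_soft_onehot_pos = False,
    cope_talking_heads = True     # new flag in this release
)

In practice these options would more commonly be reached through the attn_-prefixed hyperparameters that Decoder / TransformerWrapper forward to Attention (for example attn_use_cope = True, attn_cope_talking_heads = True), but the direct construction above keeps the sketch tied to the keyword arguments visible in this diff.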
{x_transformers-1.30.9.dist-info → x_transformers-1.30.11.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=97WREYHfGljhrveGNUUSEWk2xbFvKdM52QXW7cnoBpk,69019
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
+x_transformers-1.30.11.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.11.dist-info/METADATA,sha256=kpv8_mEb4DU4zL_oyLwzbWJ7z8WV7mf7c466Hehn-6c,662
+x_transformers-1.30.11.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.11.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.11.dist-info/RECORD,,
{x_transformers-1.30.9.dist-info → x_transformers-1.30.11.dist-info}/LICENSE
File without changes
{x_transformers-1.30.9.dist-info → x_transformers-1.30.11.dist-info}/WHEEL
File without changes
{x_transformers-1.30.9.dist-info → x_transformers-1.30.11.dist-info}/top_level.txt
File without changes