x-transformers 1.30.9__py3-none-any.whl → 1.30.10__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in the supported public registries. It is provided for informational purposes only.
- x_transformers/x_transformers.py +14 -1
- {x_transformers-1.30.9.dist-info → x_transformers-1.30.10.dist-info}/METADATA +1 -1
- {x_transformers-1.30.9.dist-info → x_transformers-1.30.10.dist-info}/RECORD +6 -6
- {x_transformers-1.30.9.dist-info → x_transformers-1.30.10.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.9.dist-info → x_transformers-1.30.10.dist-info}/WHEEL +0 -0
- {x_transformers-1.30.9.dist-info → x_transformers-1.30.10.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -311,19 +311,25 @@ class CoPE(Module):
     def __init__(
         self,
         dim,
+        heads,
         max_pos,
         soft_onehot = False,
+        talking_heads = False
     ):
         super().__init__()
         self.max_pos = max_pos
         self.pos_emb = nn.Parameter(torch.zeros(max_pos, dim))
 
+        self.maybe_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()
         self.soft_onehot = soft_onehot
 
         if soft_onehot:
             self.register_buffer('positions', torch.arange(max_pos))
 
     def forward(self, query, attn_logits, temp = 5e-2):
+
+        attn_logits = self.maybe_talking_heads(attn_logits)
+
         # compute positions
 
         gates = attn_logits.sigmoid()
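This release threads a "talking heads" option through CoPE: when `talking_heads = True`, the incoming attention logits are remixed across the head dimension by a 1×1 convolution before CoPE's sigmoid gating. A minimal sketch of what the new branch computes, with made-up tensor sizes and the usual (batch, heads, query_len, key_len) logits layout; this is illustrative, not the library's code:

import torch
from torch import nn

# Hypothetical sizes for illustration only
batch, heads, seq = 2, 8, 16
attn_logits = torch.randn(batch, heads, seq, seq)  # (b, h, i, j) logits

# nn.Conv2d(heads, heads, 1) treats the head axis as channels, so each
# output head is a learned linear mix of all input heads at every (i, j)
talking_heads = nn.Conv2d(heads, heads, 1, bias = False)

mixed = talking_heads(attn_logits)  # same shape, heads linearly remixed
assert mixed.shape == attn_logits.shape

With `talking_heads = False` the module falls back to `nn.Identity()`, so existing checkpoints and call sites are unaffected.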
@@ -770,6 +776,7 @@ class Attention(Module):
         use_cope = False,
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
+        cope_talking_heads = False,
         logit_softclamp_value = None,
         onnxable = False
     ):
@@ -853,7 +860,13 @@ class Attention(Module):
             assert causal, 'CoPE was designed for causal attention'
             assert not flash, 'CoPE is not flash attention compatible'
 
-            cope = CoPE(
+            cope = CoPE(
+                dim = dim_head,
+                heads = heads,
+                max_pos = cope_max_pos,
+                talking_heads = cope_talking_heads,
+                soft_onehot = cope_soft_onehot_pos
+            )
 
             # attend class - includes core attention algorithm + talking heads
 
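At the `Attention` level the behavior is toggled with `use_cope = True, cope_talking_heads = True`, per the hunk above; the module can also be exercised directly. A hedged usage sketch, assuming only the signature shown in this diff — the sizes are arbitrary, and the assumption that `forward` returns a positional term shaped like the logits is not confirmed by the diff itself:

import torch
from x_transformers.x_transformers import CoPE

# Instantiate CoPE with the keywords introduced in 1.30.10; sizes are made up
cope = CoPE(
    dim = 64,              # per-head dimension (dim_head in Attention)
    heads = 8,
    max_pos = 16,          # mirrors cope_max_pos
    talking_heads = True,  # new in this release
    soft_onehot = False
)

q = torch.randn(2, 8, 128, 64)        # (batch, heads, seq, dim_head) queries
logits = torch.randn(2, 8, 128, 128)  # raw attention logits

# Assumption: forward returns a positional contribution shaped like the logits
pos = cope(q, logits)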
{x_transformers-1.30.9.dist-info → x_transformers-1.30.10.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=qErIXwntVnZEmdO4MwTTsvxMc_vQPBU8zLLpEh535tE,68690
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
+x_transformers-1.30.10.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.10.dist-info/METADATA,sha256=Fo2EWkxK4rLMcwUqcyeSW9AljUJA6b_2hZ59W8axIO4,662
+x_transformers-1.30.10.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.10.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.10.dist-info/RECORD,,
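Each RECORD row has the form `path,sha256=<digest>,size`, where the digest is the file's SHA-256 hash, urlsafe-base64-encoded with the trailing `=` padding stripped (per PEP 376 / PEP 427). A small sketch for reproducing an entry such as the new `x_transformers.py` hash from a local checkout; the file path is illustrative:

import base64
import hashlib

def record_hash(path):
    # SHA-256 digest, urlsafe-base64 encoded, '=' padding removed:
    # the wheel RECORD hash convention (PEP 376 / PEP 427)
    digest = hashlib.sha256(open(path, 'rb').read()).digest()
    return 'sha256=' + base64.urlsafe_b64encode(digest).rstrip(b'=').decode()

# Expect 'sha256=qErIXwntVnZEmdO4MwTTsvxMc_vQPBU8zLLpEh535tE' for the 1.30.10 file
print(record_hash('x_transformers/x_transformers.py'))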
{x_transformers-1.30.9.dist-info → x_transformers-1.30.10.dist-info}/LICENSE
File without changes
{x_transformers-1.30.9.dist-info → x_transformers-1.30.10.dist-info}/WHEEL
File without changes
{x_transformers-1.30.9.dist-info → x_transformers-1.30.10.dist-info}/top_level.txt
File without changes