x-transformers 1.30.10__py3-none-any.whl → 1.30.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +13 -5
- {x_transformers-1.30.10.dist-info → x_transformers-1.30.11.dist-info}/METADATA +1 -1
- {x_transformers-1.30.10.dist-info → x_transformers-1.30.11.dist-info}/RECORD +6 -6
- {x_transformers-1.30.10.dist-info → x_transformers-1.30.11.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.10.dist-info → x_transformers-1.30.11.dist-info}/WHEEL +0 -0
- {x_transformers-1.30.10.dist-info → x_transformers-1.30.11.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -314,21 +314,29 @@ class CoPE(Module):
         heads,
         max_pos,
         soft_onehot = False,
-        talking_heads = False
+        talking_heads = False,
+        soft_onehot_temp = 5e-2
     ):
         super().__init__()
         self.max_pos = max_pos
         self.pos_emb = nn.Parameter(torch.zeros(max_pos, dim))
 
-        self.
+        self.talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else None
         self.soft_onehot = soft_onehot
+        self.soft_onehot_temp = soft_onehot_temp
 
         if soft_onehot:
             self.register_buffer('positions', torch.arange(max_pos))
 
-    def forward(self, query, attn_logits
+    def forward(self, query, attn_logits):
+
+        if exists(self.talking_heads):
+            i, j = attn_logits.shape[-2:]
+            causal_mask = attn_logits.new_ones(i, j).triu_(j - i + 1).bool()
+
+            attn_logits = self.talking_heads(attn_logits)
 
-
+            attn_logits = attn_logits.masked_fill(causal_mask, -torch.finfo(attn_logits.dtype).max)
 
         # compute positions
 
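The hunk above adds the new `soft_onehot_temp` argument after `talking_heads` and wires the talking-heads path into `forward`: the positional attention logits are mixed across heads with a 1x1 convolution and then re-masked, presumably so head mixing cannot leak values into causally masked positions. The standalone sketch below reproduces just that mix-then-remask pattern on dummy tensors; the batch/head/sequence sizes are illustrative assumptions, and the `exists(...)` helper from the hunk is dropped since this is not the released API.

import torch
from torch import nn

# dummy positional attention logits: (batch, heads, query length, key length)
b, h, i, j = 2, 8, 4, 4
attn_logits = torch.randn(b, h, i, j)

# "talking heads": a 1x1 convolution that mixes logits across the head dimension
talking_heads = nn.Conv2d(h, h, 1, bias = False)

# causal mask, True strictly above the diagonal (future key positions)
causal_mask = attn_logits.new_ones(i, j).triu_(j - i + 1).bool()

# mix across heads, then re-mask so the mixed logits cannot see future positions
attn_logits = talking_heads(attn_logits)
attn_logits = attn_logits.masked_fill(causal_mask, -torch.finfo(attn_logits.dtype).max)

print(attn_logits.shape)  # torch.Size([2, 8, 4, 4])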
@@ -341,7 +349,7 @@ class CoPE(Module):
 
         if self.soft_onehot:
             diff_pos = (pos[..., None] - self.positions).abs()
-            soft_onehot_pos = F.softmax(-diff_pos /
+            soft_onehot_pos = F.softmax(-diff_pos / self.soft_onehot_temp, dim = -1)
             cope_pos_emb = einsum('b h i j p, b h i p -> b h i j', soft_onehot_pos, logits_int)
         else:
             # interpolate from integer positions
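This second hunk swaps the divisor in the soft one-hot softmax (truncated in this view) for the new `soft_onehot_temp` hyperparameter, defaulting to 5e-2. A lower temperature concentrates the distribution over integer positions on the nearest one, closer to a true one-hot. The snippet below is a minimal sketch of that temperature effect; `max_pos` and the fractional position are toy values chosen purely for illustration.

import torch
import torch.nn.functional as F

max_pos = 8
positions = torch.arange(max_pos)   # integer positions, mirroring the registered buffer
pos = torch.tensor([2.3])           # a toy fractional (contextual) position

# absolute distance from the fractional position to every integer position
diff_pos = (pos[..., None] - positions).abs()

# temperature-scaled soft one-hot over integer positions:
# smaller temperatures put nearly all the mass on the nearest integer position
for temp in (5e-1, 5e-2):
    soft_onehot_pos = F.softmax(-diff_pos / temp, dim = -1)
    print(temp, soft_onehot_pos.round(decimals = 3))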
{x_transformers-1.30.10.dist-info → x_transformers-1.30.11.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=97WREYHfGljhrveGNUUSEWk2xbFvKdM52QXW7cnoBpk,69019
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
+x_transformers-1.30.11.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.11.dist-info/METADATA,sha256=kpv8_mEb4DU4zL_oyLwzbWJ7z8WV7mf7c466Hehn-6c,662
+x_transformers-1.30.11.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.11.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.11.dist-info/RECORD,,
{x_transformers-1.30.10.dist-info → x_transformers-1.30.11.dist-info}/LICENSE
File without changes
{x_transformers-1.30.10.dist-info → x_transformers-1.30.11.dist-info}/WHEEL
File without changes
{x_transformers-1.30.10.dist-info → x_transformers-1.30.11.dist-info}/top_level.txt
File without changes