x-transformers 1.37.9-py3-none-any.whl → 1.37.10-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +10 -7
- x_transformers/x_transformers.py +2 -0
- {x_transformers-1.37.9.dist-info → x_transformers-1.37.10.dist-info}/METADATA +1 -1
- {x_transformers-1.37.9.dist-info → x_transformers-1.37.10.dist-info}/RECORD +7 -7
- {x_transformers-1.37.9.dist-info → x_transformers-1.37.10.dist-info}/LICENSE +0 -0
- {x_transformers-1.37.9.dist-info → x_transformers-1.37.10.dist-info}/WHEEL +0 -0
- {x_transformers-1.37.9.dist-info → x_transformers-1.37.10.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from functools import partial
-from typing import Tuple
+from typing import Tuple, Callable
 
 import torch
 from torch.nn import Module
@@ -104,6 +104,7 @@ class Attend(Module):
         qk_norm = False,
         l2_distance = False,
         sigmoid = False,
+        custom_attn_fn: Callable | None = None,
         flash = False,
         softclamp_logits = False,
         logit_softclamp_value = 50.,
@@ -132,7 +133,9 @@
 
         self.sigmoid = sigmoid
 
-        if not sigmoid:
+        if exists(custom_attn_fn):
+            self.attn_fn = custom_attn_fn
+        elif not sigmoid:
             softmax_fn = partial(F.softmax, dim = -1)
             self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
         else:
@@ -404,11 +407,6 @@
 
         mask_value = -torch.finfo(sim.dtype).max
 
-        if exists(self.sparse_topk) and self.sparse_topk < j:
-            top_values, _ = sim.topk(self.sparse_topk, dim = -1)
-            sparse_topk_mask = sim < top_values[..., -1:]
-            mask = (mask & sparse_topk_mask) if exists(mask) else sparse_topk_mask
-
         if exists(mask):
            sim = sim.masked_fill(~mask, mask_value)
 
@@ -416,6 +414,11 @@
             causal_mask = self.create_causal_mask(i, j, device = device)
             sim = sim.masked_fill(causal_mask, mask_value)
 
+        if exists(self.sparse_topk):
+            top_values, _ = sim.topk(self.sparse_topk, dim = -1)
+            sparse_topk_mask = (sim >= top_values[..., -1:]) & (sim > mask_value)
+            sim = sim.masked_fill(~sparse_topk_mask, mask_value)
+
         row_is_entirely_masked = None
 
         if exists(mask):
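The second change above reorders the sparse top-k masking: it now runs after the causal mask rather than being folded into the key-padding mask, and it keeps only positions that are both in each row's top-k and not already at the mask value. A standalone sketch of that masking step, with made-up tensor shapes and values purely for illustration:

import torch

# illustrative reproduction of the reordered sparse-topk step from the hunk above;
# `sim`, `sparse_topk`, and the shapes here are made up for demonstration
sim = torch.randn(1, 8, 4, 6)            # (batch, heads, q_len, k_len) attention logits
sparse_topk = 3
mask_value = -torch.finfo(sim.dtype).max

# causal mask is applied first under the new ordering
i, j = sim.shape[-2:]
causal_mask = torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, mask_value)

# then keep only per-row top-k logits that were not already masked out
top_values, _ = sim.topk(sparse_topk, dim = -1)
sparse_topk_mask = (sim >= top_values[..., -1:]) & (sim > mask_value)
sim = sim.masked_fill(~sparse_topk_mask, mask_value)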
x_transformers/x_transformers.py
CHANGED
@@ -925,6 +925,7 @@ class Attention(Module):
         qk_norm_dim_scale = False,
         l2_distance = False,
         sigmoid = False,
+        custom_attn_fn: Callable | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1041,6 +1042,7 @@ class Attention(Module):
             scale = qk_norm_scale if qk_norm else self.scale,
             l2_distance = l2_distance,
             sigmoid = sigmoid,
+            custom_attn_fn = custom_attn_fn,
             add_zero_kv = add_zero_kv,
             flash = flash,
             softclamp_logits = softclamp_logits,
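Taken together, these hunks expose a custom_attn_fn hook on Attention that is forwarded to Attend and used in place of the default softmax when provided. A minimal usage sketch, assuming the callable receives the pre-softmax similarity tensor of shape (batch, heads, q_len, k_len) and returns attention weights of the same shape; the function name and settings below are illustrative, not part of the library:

import torch
import torch.nn.functional as F

from x_transformers.x_transformers import Attention

# hypothetical custom attention function: a temperature-scaled softmax.
# assumption: Attend calls it with the attention logits and expects
# normalized weights of the same shape back
def cooled_softmax(sim):
    return F.softmax(sim / 2.0, dim = -1)

attn = Attention(
    dim = 512,
    heads = 8,
    custom_attn_fn = cooled_softmax   # new in 1.37.10
)

x = torch.randn(1, 128, 512)
out, intermediates = attn(x)          # assumption: forward returns (output, intermediates)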
{x_transformers-1.37.9.dist-info → x_transformers-1.37.10.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=MJmMLIt0rHx-JNNsc2auUsCjsB-69NewufaRV32ADmA,14012
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=q6I6rvyYUWLgwtKOxPwF12UL1HzcIlauI8YrM8gvZac,83901
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.37.
-x_transformers-1.37.
-x_transformers-1.37.
-x_transformers-1.37.
-x_transformers-1.37.
+x_transformers-1.37.10.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.37.10.dist-info/METADATA,sha256=qA7YAj5ZeaesnGamBR-cPSR_0HSwquwPBptaUmi7c3c,662
+x_transformers-1.37.10.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.37.10.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.37.10.dist-info/RECORD,,
{x_transformers-1.37.9.dist-info → x_transformers-1.37.10.dist-info}/LICENSE
File without changes
{x_transformers-1.37.9.dist-info → x_transformers-1.37.10.dist-info}/WHEEL
File without changes
{x_transformers-1.37.9.dist-info → x_transformers-1.37.10.dist-info}/top_level.txt
File without changes