x-transformers 1.37.8__py3-none-any.whl → 1.37.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +30 -8
- x_transformers/x_transformers.py +4 -0
- {x_transformers-1.37.8.dist-info → x_transformers-1.37.10.dist-info}/METADATA +1 -1
- {x_transformers-1.37.8.dist-info → x_transformers-1.37.10.dist-info}/RECORD +7 -7
- {x_transformers-1.37.8.dist-info → x_transformers-1.37.10.dist-info}/LICENSE +0 -0
- {x_transformers-1.37.8.dist-info → x_transformers-1.37.10.dist-info}/WHEEL +0 -0
- {x_transformers-1.37.8.dist-info → x_transformers-1.37.10.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
```diff
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from functools import partial
-from typing import Tuple
+from typing import Tuple, Callable
 
 import torch
 from torch.nn import Module
@@ -36,6 +36,9 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
+def at_most_one_of(*bools):
+    return sum([*map(int, bools)]) <= 1
+
 def compact(arr):
     return [*filter(exists, arr)]
 
```
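The new `at_most_one_of` helper is a simple mutual-exclusivity check over boolean flags, used below to reject enabling `sigmoid` and `l2_distance` together. A quick standalone illustration (not part of the package):

```python
# at_most_one_of returns True when no more than one of the given flags is set
def at_most_one_of(*bools):
    return sum([*map(int, bools)]) <= 1

assert at_most_one_of(False, False)      # zero flags set -> fine
assert at_most_one_of(True, False)       # one flag set   -> fine
assert not at_most_one_of(True, True)    # two flags set  -> rejected
```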
```diff
@@ -100,6 +103,8 @@ class Attend(Module):
         scale = None,
         qk_norm = False,
         l2_distance = False,
+        sigmoid = False,
+        custom_attn_fn: Callable | None = None,
         flash = False,
         softclamp_logits = False,
         logit_softclamp_value = 50.,
@@ -116,10 +121,27 @@ class Attend(Module):
         super().__init__()
         self.scale = scale
 
+        # causal related
+
         self.causal = causal
         self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask
 
-
+        # attention type
+
+        assert not (flash and sigmoid), 'sigmoid attention not available for flash'
+        assert at_most_one_of(sigmoid, l2_distance)
+
+        self.sigmoid = sigmoid
+
+        if exists(custom_attn_fn):
+            self.attn_fn = custom_attn_fn
+        elif not sigmoid:
+            softmax_fn = partial(F.softmax, dim = -1)
+            self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
+        else:
+            self.attn_fn = F.sigmoid
+
+        # dropouts
 
         self.dropout = dropout
         self.attn_dropout = nn.Dropout(dropout)
```
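In short, `Attend` now picks its attention activation at construction time: a user-supplied `custom_attn_fn` wins, otherwise softmax (computed in float32 unless `qk_norm` is set), or an elementwise sigmoid when `sigmoid = True`. A minimal standalone sketch of that selection logic; `pick_attn_fn` is an illustrative name, not a library function:

```python
from functools import partial

import torch
import torch.nn.functional as F

def pick_attn_fn(sigmoid = False, custom_attn_fn = None, qk_norm = False):
    # mirrors the constructor logic above: the chosen callable maps
    # similarity logits of shape (..., i, j) to attention weights of the same shape
    if custom_attn_fn is not None:
        return custom_attn_fn
    if not sigmoid:
        softmax_fn = partial(F.softmax, dim = -1)
        # without qk norm, softmax is carried out in float32 for numerical stability
        return partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
    return F.sigmoid

sim = torch.randn(2, 8, 4, 4)                 # (batch, heads, query len, key len)
attn = pick_attn_fn(sigmoid = True)(sim)      # elementwise sigmoid; rows need not sum to 1
assert attn.shape == sim.shape
```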
```diff
@@ -385,11 +407,6 @@ class Attend(Module):
 
         mask_value = -torch.finfo(sim.dtype).max
 
-        if exists(self.sparse_topk) and self.sparse_topk < j:
-            top_values, _ = sim.topk(self.sparse_topk, dim = -1)
-            sparse_topk_mask = sim < top_values[..., -1:]
-            mask = (mask & sparse_topk_mask) if exists(mask) else sparse_topk_mask
-
         if exists(mask):
             sim = sim.masked_fill(~mask, mask_value)
 
@@ -397,6 +414,11 @@ class Attend(Module):
             causal_mask = self.create_causal_mask(i, j, device = device)
             sim = sim.masked_fill(causal_mask, mask_value)
 
+        if exists(self.sparse_topk):
+            top_values, _ = sim.topk(self.sparse_topk, dim = -1)
+            sparse_topk_mask = (sim >= top_values[..., -1:]) & (sim > mask_value)
+            sim = sim.masked_fill(~sparse_topk_mask, mask_value)
+
         row_is_entirely_masked = None
 
         if exists(mask):
@@ -410,7 +432,7 @@ class Attend(Module):
         if self.sigsoftmax:
            sim = sim + sim.sigmoid().log()
 
-        attn = self.attn_fn(sim
+        attn = self.attn_fn(sim)
         attn = attn.type(dtype)
 
         post_softmax_attn = attn
```
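The sparse top-k filtering also moves: it is now applied to the logits after padding and causal masking, keeping only entries that are within the row-wise top k and not already set to the mask value. A small standalone illustration of the new masking rule (`sparse_topk_filter` is a hypothetical name for this sketch):

```python
import torch

def sparse_topk_filter(sim, sparse_topk, mask_value):
    # keep entries that are both within the row-wise top-k and not already
    # knocked out by earlier (padding / causal) masking
    top_values, _ = sim.topk(sparse_topk, dim = -1)
    sparse_topk_mask = (sim >= top_values[..., -1:]) & (sim > mask_value)
    return sim.masked_fill(~sparse_topk_mask, mask_value)

mask_value = -torch.finfo(torch.float32).max
sim = torch.tensor([[3., 1., 2., mask_value]])   # last position already masked
filtered = sparse_topk_filter(sim, sparse_topk = 2, mask_value = mask_value)
# only the two largest unmasked logits (3. and 2.) survive
assert (filtered != mask_value).sum() == 2
```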
x_transformers/x_transformers.py
CHANGED
```diff
@@ -924,6 +924,8 @@ class Attention(Module):
         qk_norm_scale = 10,
         qk_norm_dim_scale = False,
         l2_distance = False,
+        sigmoid = False,
+        custom_attn_fn: Callable | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1039,6 +1041,8 @@ class Attention(Module):
             qk_norm = qk_norm,
             scale = qk_norm_scale if qk_norm else self.scale,
             l2_distance = l2_distance,
+            sigmoid = sigmoid,
+            custom_attn_fn = custom_attn_fn,
             add_zero_kv = add_zero_kv,
             flash = flash,
             softclamp_logits = softclamp_logits,
```
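On the `Attention` side the two options are simply declared and forwarded into `Attend`. A hedged usage sketch of how they could be reached from a full model, assuming the library's usual `attn_`-prefixed keyword routing from the attention-layers wrapper to each `Attention` block; the routing convention and construction details below are assumptions, not part of this diff:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 4,
        heads = 8,
        attn_sigmoid = True,            # should arrive at Attend as sigmoid = True
        # attn_custom_attn_fn = my_fn,  # alternatively: a callable mapping logits -> weights
    )
)

tokens = torch.randint(0, 256, (1, 128))
logits = model(tokens)                   # expected shape (1, 128, 256)
```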
{x_transformers-1.37.8.dist-info → x_transformers-1.37.10.dist-info}/RECORD
CHANGED
```diff
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=MJmMLIt0rHx-JNNsc2auUsCjsB-69NewufaRV32ADmA,14012
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256
+x_transformers/x_transformers.py,sha256=q6I6rvyYUWLgwtKOxPwF12UL1HzcIlauI8YrM8gvZac,83901
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.37.
-x_transformers-1.37.
-x_transformers-1.37.
-x_transformers-1.37.
-x_transformers-1.37.
+x_transformers-1.37.10.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.37.10.dist-info/METADATA,sha256=qA7YAj5ZeaesnGamBR-cPSR_0HSwquwPBptaUmi7c3c,662
+x_transformers-1.37.10.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.37.10.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.37.10.dist-info/RECORD,,
```
{x_transformers-1.37.8.dist-info → x_transformers-1.37.10.dist-info}/LICENSE
File without changes
{x_transformers-1.37.8.dist-info → x_transformers-1.37.10.dist-info}/WHEEL
File without changes
{x_transformers-1.37.8.dist-info → x_transformers-1.37.10.dist-info}/top_level.txt
File without changes