x-transformers 1.37.7__py3-none-any.whl → 1.37.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +25 -6
- x_transformers/x_transformers.py +2 -0
- {x_transformers-1.37.7.dist-info → x_transformers-1.37.9.dist-info}/METADATA +1 -1
- {x_transformers-1.37.7.dist-info → x_transformers-1.37.9.dist-info}/RECORD +7 -7
- {x_transformers-1.37.7.dist-info → x_transformers-1.37.9.dist-info}/LICENSE +0 -0
- {x_transformers-1.37.7.dist-info → x_transformers-1.37.9.dist-info}/WHEEL +0 -0
- {x_transformers-1.37.7.dist-info → x_transformers-1.37.9.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -36,6 +36,9 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
+def at_most_one_of(*bools):
+    return sum([*map(int, bools)]) <= 1
+
 def compact(arr):
     return [*filter(exists, arr)]
 
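The new `at_most_one_of` helper counts truthy flags and allows at most one; the constructor below uses it to reject enabling `sigmoid` and `l2_distance` together. A minimal standalone illustration (only the function body comes from the diff, the asserts are illustrative):

```python
# Standalone copy of the helper added above, with a hedged illustration of
# its behavior; only the function itself appears in the package.
def at_most_one_of(*bools):
    return sum([*map(int, bools)]) <= 1

assert at_most_one_of(False, False)      # no flags enabled  -> allowed
assert at_most_one_of(True, False)       # exactly one flag  -> allowed
assert not at_most_one_of(True, True)    # two flags enabled -> rejected
```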
@@ -100,6 +103,7 @@ class Attend(Module):
         scale = None,
         qk_norm = False,
         l2_distance = False,
+        sigmoid = False,
         flash = False,
         softclamp_logits = False,
         logit_softclamp_value = 50.,
@@ -116,10 +120,25 @@ class Attend(Module):
         super().__init__()
         self.scale = scale
 
+        # causal related
+
         self.causal = causal
         self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask
 
-        self.attn_fn = partial(F.softmax, dtype = torch.float32) if not qk_norm else F.softmax
+        # attention type
+
+        assert not (flash and sigmoid), 'sigmoid attention not available for flash'
+        assert at_most_one_of(sigmoid, l2_distance)
+
+        self.sigmoid = sigmoid
+
+        if not sigmoid:
+            softmax_fn = partial(F.softmax, dim = -1)
+            self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
+        else:
+            self.attn_fn = F.sigmoid
+
+        # dropouts
 
         self.dropout = dropout
         self.attn_dropout = nn.Dropout(dropout)
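For readers skimming the hunk: with `sigmoid = False` the logits are still normalized per query by a softmax (in float32 unless `qk_norm` is set), while `sigmoid = True` squashes each logit independently, so attention rows no longer sum to one. A minimal sketch of that difference; the shapes and variable names are illustrative, not from the package:

```python
import torch
import torch.nn.functional as F
from functools import partial

# Illustrative attention logits: (batch, heads, query length, key length)
sim = torch.randn(2, 8, 16, 16)

# softmax branch from the diff: dim is baked into the partial,
# float32 is requested when qk_norm is off
softmax_fn = partial(F.softmax, dim = -1)
softmax_attn = softmax_fn(sim, dtype = torch.float32)

# sigmoid branch from the diff: elementwise gate, no normalization over keys
sigmoid_attn = F.sigmoid(sim)

print(softmax_attn.sum(dim = -1)[0, 0, 0])   # ~1.0 for every query row
print(sigmoid_attn.sum(dim = -1)[0, 0, 0])   # unconstrained
```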
@@ -211,12 +230,12 @@ class Attend(Module):
 
         if self.l2_distance:
             k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
-            k = F.pad(k, (0, 1), value = 1.)
-            k = torch.cat((k,
+            k = F.pad(k, (0, 1), value = -1.)
+            k = torch.cat((k, k_norm_sq), dim = -1)
 
             q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
-            q = torch.cat((2 * q,
-            q = F.pad(q, (0, 1), value = 1.)
+            q = torch.cat((2 * q, q_norm_sq), dim = -1)
+            q = F.pad(q, (0, 1), value = -1.)
 
         # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
 
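With the new pad value of -1., the augmented vectors line up so that the plain dot product q'·k' equals 2q·k − ||q||² − ||k||² = −||q − k||², i.e. the attention logits become the negative squared euclidean distance. A hedged numerical check of that identity (shapes and names are illustrative, not part of the package):

```python
import torch
import torch.nn.functional as F

# Rebuild the augmentation exactly as in the new hunk and compare against the
# negative squared L2 distance computed directly.
q = torch.randn(4, 16)
k = torch.randn(4, 16)

k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
k_aug = torch.cat((F.pad(k, (0, 1), value = -1.), k_norm_sq), dim = -1)   # [k, -1, |k|^2]

q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
q_aug = F.pad(torch.cat((2 * q, q_norm_sq), dim = -1), (0, 1), value = -1.)  # [2q, |q|^2, -1]

sim = q_aug @ k_aug.t()                 # 2 q.k - |q|^2 - |k|^2
neg_sq_dist = -torch.cdist(q, k) ** 2   # reference: -|q - k|^2

print((sim - neg_sq_dist).abs().max())  # ~0 up to float error
```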
@@ -410,7 +429,7 @@ class Attend(Module):
         if self.sigsoftmax:
             sim = sim + sim.sigmoid().log()
 
-        attn = self.attn_fn(sim, dim = -1)
+        attn = self.attn_fn(sim)
         attn = attn.type(dtype)
 
         post_softmax_attn = attn
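Taken together, the attend.py changes expose sigmoid attention behind a single constructor flag, with the softmax `dim` now baked into the partial so the call site no longer passes it. A hedged usage sketch against the class's q/k/v interface (shapes illustrative; per the assertion above, the flag cannot be combined with flash):

```python
import torch
from x_transformers.attend import Attend

# Hedged sketch: enable the new flag on Attend directly. Per the assertions in
# the diff, sigmoid cannot be combined with flash or with l2_distance.
attend = Attend(sigmoid = True, dropout = 0.1)

q = torch.randn(1, 8, 32, 64)   # (batch, heads, seq len, dim head)
k = torch.randn(1, 8, 32, 64)
v = torch.randn(1, 8, 32, 64)

out, intermediates = attend(q, k, v)
print(out.shape)                # expected: torch.Size([1, 8, 32, 64])
```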
x_transformers/x_transformers.py
CHANGED
@@ -924,6 +924,7 @@ class Attention(Module):
         qk_norm_scale = 10,
         qk_norm_dim_scale = False,
         l2_distance = False,
+        sigmoid = False,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1039,6 +1040,7 @@ class Attention(Module):
             qk_norm = qk_norm,
             scale = qk_norm_scale if qk_norm else self.scale,
             l2_distance = l2_distance,
+            sigmoid = sigmoid,
             add_zero_kv = add_zero_kv,
             flash = flash,
             softclamp_logits = softclamp_logits,
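The x_transformers.py side only threads the new kwarg from `Attention` through to `Attend`. A hedged end-to-end sketch; it assumes the library's usual `attn_*` kwarg-forwarding convention also reaches the new flag:

```python
import torch
from x_transformers import TransformerWrapper, Encoder

# Hedged sketch: `attn_sigmoid = True` assumes the existing attn_* prefix
# routing forwards the kwarg to Attention(..., sigmoid = True), which in turn
# passes it to Attend. If not, pass sigmoid = True to Attention directly.
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 2,
        heads = 8,
        attn_sigmoid = True
    )
)

tokens = torch.randint(0, 256, (1, 128))
logits = model(tokens)          # expected shape: (1, 128, 256)
print(logits.shape)
```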
{x_transformers-1.37.7.dist-info → x_transformers-1.37.9.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=EtaTN6ahgRlFLkwfHA31RNL_bQyAHwhNBpGU1NIHJ-c,13894
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256
+x_transformers/x_transformers.py,sha256=Y77_ZPWSKTJ-oYk4bHjhwMEkgoMaq_LyxcmCkvOPZ9g,83808
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.37.
-x_transformers-1.37.
-x_transformers-1.37.
-x_transformers-1.37.
-x_transformers-1.37.
+x_transformers-1.37.9.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.37.9.dist-info/METADATA,sha256=9JF40JYlW1y_AOqdD1pwYZJpJWZn63SC0K0VX8IA2JU,661
+x_transformers-1.37.9.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.37.9.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.37.9.dist-info/RECORD,,
{x_transformers-1.37.7.dist-info → x_transformers-1.37.9.dist-info}/LICENSE
File without changes
{x_transformers-1.37.7.dist-info → x_transformers-1.37.9.dist-info}/WHEEL
File without changes
{x_transformers-1.37.7.dist-info → x_transformers-1.37.9.dist-info}/top_level.txt
File without changes