x-transformers 1.37.1__py3-none-any.whl → 1.37.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- x_transformers/attend.py +14 -5
- x_transformers/x_transformers.py +13 -1
- {x_transformers-1.37.1.dist-info → x_transformers-1.37.3.dist-info}/METADATA +1 -1
- {x_transformers-1.37.1.dist-info → x_transformers-1.37.3.dist-info}/RECORD +7 -7
- {x_transformers-1.37.1.dist-info → x_transformers-1.37.3.dist-info}/LICENSE +0 -0
- {x_transformers-1.37.1.dist-info → x_transformers-1.37.3.dist-info}/WHEEL +0 -0
- {x_transformers-1.37.1.dist-info → x_transformers-1.37.3.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -84,6 +84,7 @@ class Attend(Module):
         softclamp_logits = False,
         logit_softclamp_value = 50.,
         add_zero_kv = False,
+        sigsoftmax = False,
         cope = None,
         onnxable = False,
         sdp_kwargs: dict = dict(
@@ -117,6 +118,11 @@ class Attend(Module):
         assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention'
         self.sparse_topk = sparse_topk
 
+        # sig softmax
+
+        assert not (flash and sigsoftmax), 'sigsoftmax not available for flash attention'
+        self.sigsoftmax = sigsoftmax
+
         # add a key / value token composed of zeros
         # in case this helps controlling outliers, proposed by https://www.evanmiller.org/attention-is-off-by-one.html
 
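The new assert reflects that sigsoftmax is only usable on the non-flash path: the flash branch hands q/k/v to PyTorch's fused kernel, which computes the softmax internally, so the attention logits are never exposed for modification. A rough sketch of the two paths (not the library's exact code; shapes and scaling are illustrative assumptions):

    import torch
    import torch.nn.functional as F

    q = torch.randn(1, 8, 16, 64)   # (batch, heads, seq, dim_head)
    k = torch.randn(1, 8, 16, 64)
    v = torch.randn(1, 8, 16, 64)

    # fused / flash path: the softmax happens inside the kernel, so there is
    # no hook for adjusting the similarity matrix beforehand
    out_flash = F.scaled_dot_product_attention(q, k, v)

    # explicit path: the similarity matrix is a real tensor, so the
    # sigsoftmax term can be added before the softmax
    scale = q.shape[-1] ** -0.5
    sim = torch.einsum('bhid,bhjd->bhij', q, k) * scale
    sim = sim + sim.sigmoid().log()                    # sigsoftmax adjustment
    attn = sim.softmax(dim = -1)
    out_manual = torch.einsum('bhij,bhjd->bhid', attn, v)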
@@ -298,14 +304,14 @@ class Attend(Module):
         # handle grouped multi-query attention
 
         if kv_heads == 1:
-            k, v =
+            k, v = tuple(rearrange(t, 'b 1 n d -> b n d') for t in (k, v))
         elif kv_heads < heads:
-            k, v =
+            k, v = tuple(repeat(t, 'b kvh n d -> b (r kvh) n d', r = heads // kv_heads) for t in (k, v))
 
         # handle zero kv, as means for allowing network to attend to nothing
 
         if self.add_zero_kv:
-            k, v =
+            k, v = tuple(F.pad(t, (0, 0, 1, 0), value = 0.) for t in (k, v))
 
         if exists(mask):
             mask = F.pad(mask, (1, 0), value = True)
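For context, the rewritten lines cover grouped/multi-query key-value expansion and the optional all-zero key/value token. A minimal standalone sketch of the same transformations, with illustrative shapes (the head counts and dimensions are assumptions):

    import torch
    import torch.nn.functional as F
    from einops import repeat

    heads, kv_heads = 8, 2                    # grouped-query attention: 8 query heads share 2 kv heads
    k = torch.randn(1, kv_heads, 16, 64)      # (batch, kv_heads, seq, dim_head)
    v = torch.randn(1, kv_heads, 16, 64)

    # repeat the kv heads so every query head has a matching key/value head
    k, v = tuple(repeat(t, 'b kvh n d -> b (r kvh) n d', r = heads // kv_heads) for t in (k, v))

    # prepend an all-zero key/value token, letting the network attend to "nothing"
    k, v = tuple(F.pad(t, (0, 0, 1, 0), value = 0.) for t in (k, v))

    print(k.shape)   # torch.Size([1, 8, 17, 64])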
@@ -359,12 +365,15 @@ class Attend(Module):
         if exists(self.cope):
             sim = sim + self.cope(q, sim)
 
-        pre_softmax_attn = sim
+        pre_softmax_attn = sim
+
+        if self.sigsoftmax:
+            sim = sim + sim.sigmoid().log()
 
         attn = self.attn_fn(sim, dim = -1)
         attn = attn.type(dtype)
 
-        post_softmax_attn = attn
+        post_softmax_attn = attn
 
         attn = self.attn_dropout(attn)
 
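The added `sim = sim + sim.sigmoid().log()` is the sigsoftmax adjustment: since exp(x + log σ(x)) = exp(x) · σ(x), softmaxing the shifted scores is equivalent to a softmax whose numerators are additionally weighted by the sigmoid of each score. A small numerical check of that equivalence (shapes are illustrative):

    import torch

    sim = torch.randn(2, 4, 8, 8)   # (batch, heads, q_len, k_len)

    # sigsoftmax as implemented: shift the logits, then softmax
    out_a = (sim + sim.sigmoid().log()).softmax(dim = -1)

    # equivalent form: sigmoid-weighted softmax
    weighted = sim.exp() * sim.sigmoid()
    out_b = weighted / weighted.sum(dim = -1, keepdim = True)

    assert torch.allclose(out_a, out_b, atol = 1e-5)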
x_transformers/x_transformers.py
CHANGED
@@ -917,6 +917,7 @@ class Attention(Module):
         swiglu_values = False,
         gate_values = False,
         zero_init_output = False,
+        sigsoftmax = False,
         max_attend_past = None,
         qk_norm = False,
         qk_norm_groups = 1,
@@ -1039,6 +1040,7 @@ class Attention(Module):
             add_zero_kv = add_zero_kv,
             flash = flash,
             softclamp_logits = softclamp_logits,
+            sigsoftmax = sigsoftmax,
             logit_softclamp_value = logit_softclamp_value,
             cope = cope,
             onnxable = onnxable
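With the flag threaded through Attention into Attend, it should be reachable through the usual keyword plumbing. A hypothetical usage sketch, assuming the library's convention of routing attn_-prefixed Decoder kwargs into Attention (model sizes are arbitrary):

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 6,
            heads = 8,
            attn_sigsoftmax = True   # assumed to route to Attention(sigsoftmax = True), then Attend
        )
    )

    logits = model(torch.randint(0, 256, (1, 1024)))   # (1, 1024, 256)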
@@ -2003,6 +2005,7 @@ class TransformerWrapper(Module):
         token_emb: TokenEmbedding | None = None,
         mixture_of_softmax = False,
         mixture_of_softmax_k = 4,
+        sigsoftmax_logits = False
     ):
         super().__init__()
 
@@ -2090,6 +2093,10 @@ class TransformerWrapper(Module):
 
             self.combine_mixture = LinearNoBias(dim, mixture_of_softmax_k)
 
+        # sig softmax
+
+        self.sigsoftmax_logits = sigsoftmax_logits
+
         # output head, usually to logits of num_tokens
 
         logits_dim = default(logits_dim, num_tokens)
@@ -2258,7 +2265,7 @@ class TransformerWrapper(Module):
         # attention layers
 
         if not self.recycling:
-            assert recycle_steps == 1, 'you did not train with recycling'
+            assert not exists(recycle_steps) or recycle_steps == 1, 'you did not train with recycling'
 
         # regular
 
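The relaxed assertion lets a non-recycling model run when recycle_steps is simply left unset, instead of insisting it literally equal 1. A small sketch of the new predicate, assuming the library's usual exists helper (value is not None):

    def exists(v):
        return v is not None

    def check(recycling, recycle_steps):
        if not recycling:
            # new form: unset (None) or 1 is fine, anything else still trips the assert
            return not exists(recycle_steps) or recycle_steps == 1
        return True

    print(check(False, None))   # True  (previously this case failed the assert)
    print(check(False, 1))      # True
    print(check(False, 4))      # False ('you did not train with recycling')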
@@ -2322,6 +2329,11 @@ class TransformerWrapper(Module):
         else:
             logits = self.to_logits(x)
 
+        # maybe sig softmax
+
+        if self.sigsoftmax_logits:
+            logits = logits + logits.sigmoid().log()
+
         # handle maybe combine mixture
 
         if exists(combine_mixture):
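At the output head, the same sigsoftmax adjustment is optionally applied to the vocabulary logits before any downstream softmax. A hypothetical usage sketch (all constructor arguments other than sigsoftmax_logits are arbitrary):

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 512,
        sigsoftmax_logits = True,    # new flag in this release
        attn_layers = Decoder(dim = 512, depth = 4, heads = 8)
    )

    logits = model(torch.randint(0, 20000, (2, 512)))
    probs = logits.softmax(dim = -1)   # behaves like a sigmoid-weighted softmax over the vocabulary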
{x_transformers-1.37.1.dist-info → x_transformers-1.37.3.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=mV7duZ7ON2puS3-k4ctBifb2rq-jTJqrMbof7tI5jR4,12326
 x_transformers/autoregressive_wrapper.py,sha256=2FN4ZobFcdDGDGWEnUof_geb16dRGSJycZGwG899Pa4,10493
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=gOJBZzOJMu5RkIsxw9TZtde4Sx--D18yX8LjrYIsPbE,83677
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.37.
-x_transformers-1.37.
-x_transformers-1.37.
-x_transformers-1.37.
-x_transformers-1.37.
+x_transformers-1.37.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.37.3.dist-info/METADATA,sha256=SIGTCQMrLkyq_aksJAst0iXw9VfFT6QWlGvtUElbTMg,661
+x_transformers-1.37.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.37.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.37.3.dist-info/RECORD,,
{x_transformers-1.37.1.dist-info → x_transformers-1.37.3.dist-info}/LICENSE
File without changes
{x_transformers-1.37.1.dist-info → x_transformers-1.37.3.dist-info}/WHEEL
File without changes
{x_transformers-1.37.1.dist-info → x_transformers-1.37.3.dist-info}/top_level.txt
File without changes