x-transformers 1.37.1__py3-none-any.whl → 1.37.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -84,6 +84,7 @@ class Attend(Module):
         softclamp_logits = False,
         logit_softclamp_value = 50.,
         add_zero_kv = False,
+        sigsoftmax = False,
         cope = None,
         onnxable = False,
         sdp_kwargs: dict = dict(
@@ -117,6 +118,11 @@ class Attend(Module):
         assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention'
         self.sparse_topk = sparse_topk

+        # sig softmax
+
+        assert not (flash and sigsoftmax), 'sigsoftmax not available for flash attention'
+        self.sigsoftmax = sigsoftmax
+
         # add a key / value token composed of zeros
         # in case this helps controlling outliers, proposed by https://www.evanmiller.org/attention-is-off-by-one.html

@@ -298,14 +304,14 @@ class Attend(Module):
         # handle grouped multi-query attention

         if kv_heads == 1:
-            k, v = map(lambda t: rearrange(t, 'b 1 n d -> b n d'), (k, v))
+            k, v = tuple(rearrange(t, 'b 1 n d -> b n d') for t in (k, v))
         elif kv_heads < heads:
-            k, v = map(lambda t: repeat(t, 'b kvh n d -> b (r kvh) n d', r = heads // kv_heads), (k, v))
+            k, v = tuple(repeat(t, 'b kvh n d -> b (r kvh) n d', r = heads // kv_heads) for t in (k, v))

         # handle zero kv, as means for allowing network to attend to nothing

         if self.add_zero_kv:
-            k, v = map(lambda t: F.pad(t, (0, 0, 1, 0), value = 0.), (k, v))
+            k, v = tuple(F.pad(t, (0, 0, 1, 0), value = 0.) for t in (k, v))

             if exists(mask):
                 mask = F.pad(mask, (1, 0), value = True)
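The `map` to generator-expression rewrites above are stylistic and do not change behavior. As a quick illustration of what the grouped-query expansion and the zero key/value padding in this hunk do to tensor shapes, here is a stand-alone sketch that is not part of the package; the batch, head, length and dim sizes are arbitrary example values.

# sketch only: demonstrates the einops repeat and F.pad calls from the hunk above
# on a lone key tensor; shapes are illustrative, not taken from the library
import torch
import torch.nn.functional as F
from einops import repeat

heads, kv_heads = 8, 2
k = torch.randn(1, kv_heads, 16, 64)                                   # b kvh n d

# each of the 2 key/value heads is shared by heads // kv_heads = 4 query heads
k = repeat(k, 'b kvh n d -> b (r kvh) n d', r = heads // kv_heads)
assert k.shape == (1, heads, 16, 64)

# add_zero_kv: prepend an all-zero key position so the network can attend to nothing
k = F.pad(k, (0, 0, 1, 0), value = 0.)
assert k.shape == (1, heads, 17, 64)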
@@ -359,12 +365,15 @@ class Attend(Module):
         if exists(self.cope):
             sim = sim + self.cope(q, sim)

-        pre_softmax_attn = sim.clone()
+        pre_softmax_attn = sim
+
+        if self.sigsoftmax:
+            sim = sim + sim.sigmoid().log()

         attn = self.attn_fn(sim, dim = -1)
         attn = attn.type(dtype)

-        post_softmax_attn = attn.clone()
+        post_softmax_attn = attn

         attn = self.attn_dropout(attn)
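The core of this release is the `sigsoftmax` option above: adding `log(sigmoid(sim))` to the attention logits before the softmax is algebraically the same as normalizing `exp(sim) * sigmoid(sim)` over each row, since `exp(z + log(sigmoid(z))) = exp(z) * sigmoid(z)`. The following stand-alone sketch (not from the package; plain `softmax` stands in for `self.attn_fn`) checks that equivalence numerically.

import torch

sim = torch.randn(2, 8, 16, 16)                          # mock attention logits

# what the diff computes before handing the logits to the softmax
attn_a = (sim + sim.sigmoid().log()).softmax(dim = -1)

# explicit sig-softmax: exp(z) * sigmoid(z), row-normalized
weights = sim.exp() * sim.sigmoid()
attn_b = weights / weights.sum(dim = -1, keepdim = True)

assert torch.allclose(attn_a, attn_b, atol = 1e-6)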
 
x_transformers/x_transformers.py CHANGED
@@ -917,6 +917,7 @@ class Attention(Module):
         swiglu_values = False,
         gate_values = False,
         zero_init_output = False,
+        sigsoftmax = False,
         max_attend_past = None,
         qk_norm = False,
         qk_norm_groups = 1,
@@ -1039,6 +1040,7 @@ class Attention(Module):
             add_zero_kv = add_zero_kv,
             flash = flash,
             softclamp_logits = softclamp_logits,
+            sigsoftmax = sigsoftmax,
             logit_softclamp_value = logit_softclamp_value,
             cope = cope,
             onnxable = onnxable
@@ -2003,6 +2005,7 @@ class TransformerWrapper(Module):
         token_emb: TokenEmbedding | None = None,
         mixture_of_softmax = False,
         mixture_of_softmax_k = 4,
+        sigsoftmax_logits = False
     ):
         super().__init__()

@@ -2090,6 +2093,10 @@ class TransformerWrapper(Module):

             self.combine_mixture = LinearNoBias(dim, mixture_of_softmax_k)

+        # sig softmax
+
+        self.sigsoftmax_logits = sigsoftmax_logits
+
         # output head, usually to logits of num_tokens

         logits_dim = default(logits_dim, num_tokens)
@@ -2258,7 +2265,7 @@ class TransformerWrapper(Module):
         # attention layers

         if not self.recycling:
-            assert recycle_steps == 1, 'you did not train with recycling'
+            assert not exists(recycle_steps) or recycle_steps == 1, 'you did not train with recycling'

            # regular
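The relaxed assertion above means a `TransformerWrapper` built without recycling no longer trips on a `recycle_steps` that was simply left unset, while an explicit request for multiple recycle steps still fails. A tiny sketch (not from the package) of the check, using an `exists` helper equivalent to the one x-transformers defines:

def exists(v):
    return v is not None

# None (left unset) and 1 are both accepted by the new check
for recycle_steps in (None, 1):
    assert not exists(recycle_steps) or recycle_steps == 1, 'you did not train with recycling'

# explicitly asking for several recycle steps still raises
try:
    recycle_steps = 3
    assert not exists(recycle_steps) or recycle_steps == 1, 'you did not train with recycling'
except AssertionError as err:
    print(err)                                            # you did not train with recycling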
 
@@ -2322,6 +2329,11 @@ class TransformerWrapper(Module):
         else:
             logits = self.to_logits(x)

+        # maybe sig softmax
+
+        if self.sigsoftmax_logits:
+            logits = logits + logits.sigmoid().log()
+
         # handle maybe combine mixture

         if exists(combine_mixture):
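Taken together, the new flags can be switched on at construction time. The sketch below is a hedged usage example rather than something taken from the diff: `sigsoftmax_logits` on `TransformerWrapper` comes straight from the changes above, while `attn_sigsoftmax` assumes the library's usual convention of routing `attn_`-prefixed `Decoder` kwargs to the `Attention` layers (which gained the `sigsoftmax` keyword above).

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    sigsoftmax_logits = True,       # new in this release: sig-softmax applied to the output logits
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_sigsoftmax = True      # assumed routing to the new Attention `sigsoftmax` flag
    )
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x)                   # (1, 1024, 256)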
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.37.1
+Version: 1.37.3
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=7q996VGYHGIsc0FQnN8WNiwHn3xny3i1biRwx7yW5vg,12090
+x_transformers/attend.py,sha256=mV7duZ7ON2puS3-k4ctBifb2rq-jTJqrMbof7tI5jR4,12326
 x_transformers/autoregressive_wrapper.py,sha256=2FN4ZobFcdDGDGWEnUof_geb16dRGSJycZGwG899Pa4,10493
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=9lk6wtz0vNigyLoMWleo442Q0mhce-BCxEhazhSHuvI,83356
+x_transformers/x_transformers.py,sha256=gOJBZzOJMu5RkIsxw9TZtde4Sx--D18yX8LjrYIsPbE,83677
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.37.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.37.1.dist-info/METADATA,sha256=ik8UKwzq_pW9zdxCl6pt7POrjRC7_GwIi6gAnY7Fck0,661
-x_transformers-1.37.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-x_transformers-1.37.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.37.1.dist-info/RECORD,,
+x_transformers-1.37.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.37.3.dist-info/METADATA,sha256=SIGTCQMrLkyq_aksJAst0iXw9VfFT6QWlGvtUElbTMg,661
+x_transformers-1.37.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.37.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.37.3.dist-info/RECORD,,