x-transformers 1.37.8__py3-none-any.whl → 1.37.10__py3-none-any.whl

This diff covers publicly released versions of the package and reflects the changes between them as they appear in their public registry; it is provided for informational purposes only.
x_transformers/attend.py CHANGED
@@ -1,7 +1,7 @@
  from __future__ import annotations

  from functools import partial
- from typing import Tuple
+ from typing import Tuple, Callable

  import torch
  from torch.nn import Module
@@ -36,6 +36,9 @@ def exists(val):
  def default(val, d):
      return val if exists(val) else d

+ def at_most_one_of(*bools):
+     return sum([*map(int, bools)]) <= 1
+
  def compact(arr):
      return [*filter(exists, arr)]
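The new at_most_one_of helper converts each flag to an int and checks that at most one is truthy; it is used further down to forbid enabling sigmoid attention and l2_distance attention together. A minimal standalone check of its behavior:

```python
# The at_most_one_of helper added in this release, exercised on a few inputs:
# it casts each flag to int and verifies that no more than one is set.

def at_most_one_of(*bools):
    return sum([*map(int, bools)]) <= 1

assert at_most_one_of(False, False)      # no flags set  -> True
assert at_most_one_of(True, False)       # exactly one   -> True
assert not at_most_one_of(True, True)    # two flags set -> False
```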
@@ -100,6 +103,8 @@ class Attend(Module):
          scale = None,
          qk_norm = False,
          l2_distance = False,
+         sigmoid = False,
+         custom_attn_fn: Callable | None = None,
          flash = False,
          softclamp_logits = False,
          logit_softclamp_value = 50.,
@@ -116,10 +121,27 @@ class Attend(Module):
          super().__init__()
          self.scale = scale

+         # causal related
+
          self.causal = causal
          self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask

-         self.attn_fn = partial(F.softmax, dtype = torch.float32) if not qk_norm else F.softmax
+         # attention type
+
+         assert not (flash and sigmoid), 'sigmoid attention not available for flash'
+         assert at_most_one_of(sigmoid, l2_distance)
+
+         self.sigmoid = sigmoid
+
+         if exists(custom_attn_fn):
+             self.attn_fn = custom_attn_fn
+         elif not sigmoid:
+             softmax_fn = partial(F.softmax, dim = -1)
+             self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
+         else:
+             self.attn_fn = F.sigmoid
+
+         # dropouts

          self.dropout = dropout
          self.attn_dropout = nn.Dropout(dropout)
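With the rewrite above, the attention normalization is fixed at construction time: a user-supplied custom_attn_fn takes priority, otherwise softmax (computed in float32 unless qk_norm is enabled), and F.sigmoid when sigmoid = True; the new asserts forbid combining sigmoid with flash or with l2_distance. A hedged sketch of how Attend might be used with these options — the forward call (q, k, v tensors shaped batch x heads x seq x dim_head, returning out and intermediates) follows the library's existing convention but is not shown in this diff, and relu_attn is a made-up example function:

```python
# Hedged sketch: exercising the new sigmoid / custom_attn_fn options of Attend.
# The constructor arguments come from this diff; the forward call signature and
# return value are assumed from the library's existing convention.
import torch
import torch.nn.functional as F
from x_transformers.attend import Attend

q = torch.randn(1, 8, 16, 64)   # (batch, heads, seq, dim_head)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# sigmoid attention in place of softmax (cannot be combined with flash)
sigmoid_attend = Attend(causal = True, sigmoid = True)
out, _ = sigmoid_attend(q, k, v)

# a custom attention function receives the raw similarity logits and must return
# weights of the same shape (dim = -1 is no longer passed at call time)
relu_attn = lambda sim: F.relu(sim) / sim.shape[-1]   # hypothetical example fn
custom_attend = Attend(causal = True, custom_attn_fn = relu_attn)
out, _ = custom_attend(q, k, v)
```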
@@ -385,11 +407,6 @@ class Attend(Module):

          mask_value = -torch.finfo(sim.dtype).max

-         if exists(self.sparse_topk) and self.sparse_topk < j:
-             top_values, _ = sim.topk(self.sparse_topk, dim = -1)
-             sparse_topk_mask = sim < top_values[..., -1:]
-             mask = (mask & sparse_topk_mask) if exists(mask) else sparse_topk_mask
-
          if exists(mask):
              sim = sim.masked_fill(~mask, mask_value)
@@ -397,6 +414,11 @@ class Attend(Module):
              causal_mask = self.create_causal_mask(i, j, device = device)
              sim = sim.masked_fill(causal_mask, mask_value)

+         if exists(self.sparse_topk):
+             top_values, _ = sim.topk(self.sparse_topk, dim = -1)
+             sparse_topk_mask = (sim >= top_values[..., -1:]) & (sim > mask_value)
+             sim = sim.masked_fill(~sparse_topk_mask, mask_value)
+
          row_is_entirely_masked = None

          if exists(mask):
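Together with the removal in the previous hunk, the sparse top-k filter now runs after both the key-padding mask and the causal mask have been applied to sim, and it keeps a score only if it is within the row's top sparse_topk values and not already at the mask value, rather than folding the top-k condition into the key mask as before. A self-contained toy illustration of the new rule (the tensor values here are made up):

```python
# Illustrative sketch of the relocated sparse top-k filter: after masking, keep
# only scores that are in each row's top-k AND not already at the mask value,
# then fill everything else with the mask value.
import torch

sparse_topk = 2
sim = torch.tensor([[0.9, 0.1, -0.3, 0.5]])   # toy logits for a single query row
mask_value = -torch.finfo(sim.dtype).max

sim[..., -1] = mask_value                     # pretend the last key was masked out

top_values, _ = sim.topk(sparse_topk, dim = -1)
sparse_topk_mask = (sim >= top_values[..., -1:]) & (sim > mask_value)
sim = sim.masked_fill(~sparse_topk_mask, mask_value)

print(sim)  # only the two largest unmasked scores (0.9 and 0.1) survive
```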
@@ -410,7 +432,7 @@ class Attend(Module):
          if self.sigsoftmax:
              sim = sim + sim.sigmoid().log()

-         attn = self.attn_fn(sim, dim = -1)
+         attn = self.attn_fn(sim)
          attn = attn.type(dtype)

          post_softmax_attn = attn
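Because dim = -1 (and, for the float32 path, dtype) is now pre-bound with partial at construction time, the hot path simply calls self.attn_fn(sim); any attention function, built-in or user-supplied, therefore has to accept the similarity logits alone and return weights of the same shape. A small sketch of that contract using only the callables named in this diff:

```python
# Sketch of the callable contract implied by `attn = self.attn_fn(sim)`:
# the function maps raw similarity logits to attention weights of the same
# shape, with any options such as dim or dtype bound ahead of time.
from functools import partial

import torch
import torch.nn.functional as F

sim = torch.randn(2, 8, 16, 16)               # (batch, heads, i, j) logits

softmax_fn = partial(F.softmax, dim = -1)     # dim pre-bound, as in the diff
fp32_softmax_fn = partial(softmax_fn, dtype = torch.float32)

for attn_fn in (softmax_fn, fp32_softmax_fn, F.sigmoid):
    attn = attn_fn(sim)                       # called with the logits only
    assert attn.shape == sim.shape
```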
x_transformers/x_transformers.py CHANGED
@@ -924,6 +924,8 @@ class Attention(Module):
          qk_norm_scale = 10,
          qk_norm_dim_scale = False,
          l2_distance = False,
+         sigmoid = False,
+         custom_attn_fn: Callable | None = None,
          one_kv_head = False,
          kv_heads = None,
          shared_kv = False,
@@ -1039,6 +1041,8 @@ class Attention(Module):
              qk_norm = qk_norm,
              scale = qk_norm_scale if qk_norm else self.scale,
              l2_distance = l2_distance,
+             sigmoid = sigmoid,
+             custom_attn_fn = custom_attn_fn,
              add_zero_kv = add_zero_kv,
              flash = flash,
              softclamp_logits = softclamp_logits,
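The same two options are threaded from the higher-level Attention module through to Attend. A hedged usage sketch: sigmoid and custom_attn_fn come from this diff, while the dim/heads values, the (out, intermediates) return of Attention, and the attn_-prefixed routing through Decoder follow the library's usual API and are assumptions rather than part of this change:

```python
# Hedged sketch: the new flags surfaced on the higher-level modules.
# dim/heads, the tuple return of Attention, and the attn_* prefix routing via
# Decoder are assumed from x-transformers' existing API, not shown in this diff.
import torch
from x_transformers.x_transformers import Attention, Decoder

x = torch.randn(1, 128, 512)

# direct use of the Attention module with sigmoid attention
attn = Attention(dim = 512, heads = 8, causal = True, sigmoid = True)
out, _ = attn(x)

# or through a Decoder stack, where attention kwargs take the attn_ prefix
decoder = Decoder(dim = 512, depth = 2, heads = 8, attn_sigmoid = True)
hiddens = decoder(x)
```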
x_transformers-1.37.10.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.37.8
+ Version: 1.37.10
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
x_transformers-1.37.10.dist-info/RECORD CHANGED
@@ -1,15 +1,15 @@
  x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
- x_transformers/attend.py,sha256=we7mkwVCD7_ColUD8_Fj0HM5jjOaa3wbstllp_XXK4k,13434
+ x_transformers/attend.py,sha256=MJmMLIt0rHx-JNNsc2auUsCjsB-69NewufaRV32ADmA,14012
  x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
  x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
  x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
- x_transformers/x_transformers.py,sha256=-2fj6QcDSfMI5lJA_fzOW2mdzdS1C1LD6jMBtGQY48E,83752
+ x_transformers/x_transformers.py,sha256=q6I6rvyYUWLgwtKOxPwF12UL1HzcIlauI8YrM8gvZac,83901
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
  x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
- x_transformers-1.37.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
- x_transformers-1.37.8.dist-info/METADATA,sha256=fiT94VbrxWL-8jJBjxvFloWsH6n6reOGitRSlpAhvWs,661
- x_transformers-1.37.8.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
- x_transformers-1.37.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
- x_transformers-1.37.8.dist-info/RECORD,,
+ x_transformers-1.37.10.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-1.37.10.dist-info/METADATA,sha256=qA7YAj5ZeaesnGamBR-cPSR_0HSwquwPBptaUmi7c3c,662
+ x_transformers-1.37.10.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+ x_transformers-1.37.10.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+ x_transformers-1.37.10.dist-info/RECORD,,