x-transformers 1.39.4__py3-none-any.whl → 1.40.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
x_transformers/attend.py CHANGED
@@ -109,12 +109,35 @@ def qk_l2_dist_squared(q, k):
 
 # one-hot straight through softmax
 
-def one_hot_straight_through(t, temperature = 1.):
-    one_hot_indices = t.argmax(dim = -1, keepdim = True)
-    one_hot = torch.zeros_like(t).scatter(-1, one_hot_indices, 1.)
+def one_hot_straight_through(logits, temperature = 1.):
+    one_hot_indices = logits.argmax(dim = -1, keepdim = True)
+    one_hot = torch.zeros_like(logits).scatter(-1, one_hot_indices, 1.)
 
-    t = (t / temperature).softmax(dim = -1)
-    return one_hot + t - t.detach()
+    soft_attn = (logits / temperature).softmax(dim = -1)
+    return one_hot + soft_attn - soft_attn.detach()
+
+# sparse topk attention - only keep topk attn logits for softmax
+# optional straight through with masked out logits by setting `attn_sparse_topk_straight_through = True`
+
+def sparse_topk_attn(
+    logits,
+    sparse_topk,
+    temperature = 1.,
+    straight_through = False
+):
+    orig_logits = logits
+
+    mask_value = -torch.finfo(logits.dtype).max
+    top_values, _ = logits.topk(sparse_topk, dim = -1)
+    sparse_topk_mask = (logits >= top_values[..., -1:]) & (logits > mask_value)
+    logits = logits.masked_fill(~sparse_topk_mask, mask_value)
+    topk_attn = logits.softmax(dim = -1)
+
+    if not straight_through:
+        return topk_attn
+
+    soft_attn = (orig_logits / temperature).softmax(dim = -1)
+    return topk_attn.detach() + soft_attn - soft_attn.detach()
 
 # functions for creating causal mask
 # need a special one for onnx cpu (no support for .triu)
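
The new `sparse_topk_attn` keeps only the top-k logits in the forward pass and, with `straight_through = True`, lets gradients flow through the dense softmax of the original logits. A minimal sketch of that behavior (not part of the package; the import path follows the module-level definition in the hunk above, and the shapes are arbitrary):

import torch
from x_transformers.attend import sparse_topk_attn

logits = torch.randn(2, 8, requires_grad = True)
attn = sparse_topk_attn(logits, sparse_topk = 2, straight_through = True)

# forward: attention mass is concentrated on the top-2 logits per row
assert torch.allclose(attn.sum(dim = -1), torch.ones(2))

# backward: gradients reach every logit via the dense softmax term
(attn * torch.randn_like(attn)).sum().backward()
assert logits.grad is not None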
@@ -141,6 +164,7 @@ class Attend(Module):
         post_talking_heads = False,
         pre_scale_post_talking_heads = False,
         sparse_topk = None,
+        sparse_topk_straight_through = False,
         scale = None,
         qk_norm = False,
         l2_distance = False,
@@ -152,7 +176,6 @@ class Attend(Module):
         add_zero_kv = False,
         selective = False,
         hard = False,
-        sigsoftmax = False,
         cope = None,
         onnxable = False,
         sdp_kwargs: dict = dict(
@@ -171,9 +194,13 @@ class Attend(Module):
 
         # attention type
 
+        is_sparse_topk_attn = exists(sparse_topk)
+
         assert not (flash and sigmoid), 'sigmoid attention not available for flash'
         assert not (flash and hard), 'hard attention not available for flash'
-        assert at_most_one_of(sigmoid, hard, l2_distance)
+        assert not (flash and is_sparse_topk_attn), 'topk attention not available for flash'
+
+        assert at_most_one_of(sigmoid, hard, l2_distance, is_sparse_topk_attn)
 
         if exists(custom_attn_fn):
             self.attn_fn = custom_attn_fn
@@ -181,6 +208,8 @@ class Attend(Module):
             self.attn_fn = F.sigmoid
         elif hard:
             self.attn_fn = one_hot_straight_through
+        elif is_sparse_topk_attn:
+            self.attn_fn = partial(sparse_topk_attn, sparse_topk = sparse_topk, straight_through = sparse_topk_straight_through)
         else:
             softmax_fn = partial(F.softmax, dim = -1)
             self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
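
With the new flag set, `Attend` binds `sparse_topk_attn` as its attention function in place of plain softmax. A sketch of exercising it directly (the (batch, heads, seq, dim_head) layout and the returned intermediates are assumptions about the surrounding module, which is not shown in this diff):

import torch
from x_transformers.attend import Attend

attend = Attend(
    causal = True,
    sparse_topk = 8,                        # keep only the top-8 logits per query
    sparse_topk_straight_through = True     # gradients flow through the dense softmax
)

q, k, v = (torch.randn(1, 8, 256, 64) for _ in range(3))
out, intermediates = attend(q, k, v)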
@@ -214,16 +243,6 @@ class Attend(Module):
         assert not (selective and not causal), 'selective attention is designed for autoregressive'
         self.selective = selective
 
-        # sparse topk
-
-        assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention'
-        self.sparse_topk = sparse_topk
-
-        # sig softmax
-
-        assert not (flash and sigsoftmax), 'sigsoftmax not available for flash attention'
-        self.sigsoftmax = sigsoftmax
-
         # l2 distance attention
 
         self.l2_distance = l2_distance
@@ -476,11 +495,6 @@ class Attend(Module):
             causal_mask = self.create_causal_mask(i, j, device = device)
             sim = sim.masked_fill(causal_mask, mask_value)
 
-        if exists(self.sparse_topk):
-            top_values, _ = sim.topk(self.sparse_topk, dim = -1)
-            sparse_topk_mask = (sim >= top_values[..., -1:]) & (sim > mask_value)
-            sim = sim.masked_fill(~sparse_topk_mask, mask_value)
-
         row_is_entirely_masked = None
 
         if exists(mask):
@@ -494,9 +508,6 @@ class Attend(Module):
 
         pre_softmax_attn = sim
 
-        if self.sigsoftmax:
-            sim = sim + sim.sigmoid().log()
-
         attn = self.attn_fn(sim)
 
         attn = attn.type(dtype)
x_transformers/x_transformers.py CHANGED
@@ -912,6 +912,7 @@ class Attention(Module):
         pre_scale_post_talking_heads = False,
         head_scale = False,
         sparse_topk = None,
+        sparse_topk_straight_through = False,
         num_mem_kv = 0,
         dropout = 0.,
         on_attn = False,
@@ -920,7 +921,6 @@ class Attention(Module):
         gate_values = False,
         zero_init_output = False,
         hard = False,
-        sigsoftmax = False,
         max_attend_past = None,
         qk_norm = False,
         qk_norm_groups = 1,
@@ -1044,6 +1044,7 @@ class Attention(Module):
             pre_scale_post_talking_heads = pre_scale_post_talking_heads,
             dropout = dropout,
             sparse_topk = sparse_topk,
+            sparse_topk_straight_through = sparse_topk_straight_through,
             hard = hard,
             qk_norm = qk_norm,
             scale = qk_norm_scale if qk_norm else self.scale,
@@ -1054,7 +1055,6 @@ class Attention(Module):
             add_zero_kv = add_zero_kv,
             flash = flash,
             softclamp_logits = softclamp_logits,
-            sigsoftmax = sigsoftmax,
             logit_softclamp_value = logit_softclamp_value,
             cope = cope,
             onnxable = onnxable
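
At the `Attention` level the new flag is simply forwarded into `Attend`. The comment on `sparse_topk_attn` references `attn_sparse_topk_straight_through = True`, which matches the usual `attn_`-prefixed kwarg routing from the layer classes; a sketch under that assumption, with placeholder hyperparameters:

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_sparse_topk = 8,                      # keep only the top-8 attention logits per query
        attn_sparse_topk_straight_through = True   # new in 1.40.x: straight through the masked-out logits
    )
)

logits = model(torch.randint(0, 256, (1, 1024)))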
x_transformers-1.39.4.dist-info/METADATA → x_transformers-1.40.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.39.4
+Version: 1.40.1
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.39.4.dist-info/RECORD → x_transformers-1.40.1.dist-info/RECORD CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=013qsFwoABVbyc-1L3RZTRCWo6BW9fAD8IVnC_qALGk,16708
+x_transformers/attend.py,sha256=VbB0fi-ETgAF4dc2a_Meaqvt14LMaRVIjZ8NexUX8F0,17239
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=3JrVqwYVrd5UVf2esdunTcer7QL72H7VF4mL3UsCWOI,84508
+x_transformers/x_transformers.py,sha256=TGXJZXCWR5BiMkS5Kx-JhFQ85AxkiJabLiHnrCTC874,84562
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.39.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.39.4.dist-info/METADATA,sha256=2KawHim0IOdlRjbRJCVsELM10T7nojxnMy6WrWtG0UE,661
-x_transformers-1.39.4.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
-x_transformers-1.39.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.39.4.dist-info/RECORD,,
+x_transformers-1.40.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.40.1.dist-info/METADATA,sha256=WouMl3Ld1llknOwj7BcKi-_YZ9Hx9RZ-ni-eGCP_uQY,661
+x_transformers-1.40.1.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.40.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.40.1.dist-info/RECORD,,