x-transformers 1.39.3__py3-none-any.whl → 1.40.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -107,14 +107,32 @@ def qk_l2_dist_squared(q, k):
     l2_dist_squared = torch.cdist(q, k) ** 2
     return unpack_one(l2_dist_squared, packed_shape, '* i j')

-# gumbel softmax
+# one-hot straight through softmax

-def gumbel_softmax(t, temperature = 1.):
-    one_hot_indices = t.argmax(dim = -1, keepdim = True)
-    one_hot = torch.zeros_like(t).scatter(-1, one_hot_indices, 1.)
+def one_hot_straight_through(logits, temperature = 1.):
+    one_hot_indices = logits.argmax(dim = -1, keepdim = True)
+    one_hot = torch.zeros_like(logits).scatter(-1, one_hot_indices, 1.)

-    t = (t / temperature).softmax(dim = -1)
-    return one_hot + t - t.detach()
+    soft_attn = (logits / temperature).softmax(dim = -1)
+    return one_hot + soft_attn - soft_attn.detach()
+
+# sparse topk attention - only keep topk attn logits for softmax
+# optional straight through with masked out logits by setting `attn_sparse_topk_straight_through = True`
+
+def sparse_topk_attn(logits, sparse_topk, temperature = 1., straight_through = False):
+    orig_logits = logits
+
+    mask_value = -torch.finfo(logits.dtype).max
+    top_values, _ = logits.topk(sparse_topk, dim = -1)
+    sparse_topk_mask = (logits >= top_values[..., -1:]) & (logits > mask_value)
+    logits = logits.masked_fill(~sparse_topk_mask, mask_value)
+    topk_attn = logits.softmax(dim = -1)
+
+    if not straight_through:
+        return topk_attn
+
+    soft_attn = (orig_logits / temperature).softmax(dim = -1)
+    return topk_attn + soft_attn - soft_attn.detach()

 # functions for creating causal mask
 # need a special one for onnx cpu (no support for .triu)
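The new sparse_topk_attn keeps only the top-k logits in the forward pass but, with straight_through = True, returns topk_attn + soft_attn - soft_attn.detach(), so gradients flow through the dense tempered softmax. Below is a minimal sketch of that behavior (editorial, not part of the diff), assuming x-transformers 1.40.0 is installed so the function is importable from x_transformers.attend; tensor shapes are hypothetical.

import torch
from x_transformers.attend import sparse_topk_attn  # module-level function added in this hunk

logits = torch.randn(1, 1, 4, 8, requires_grad = True)   # hypothetical (batch, heads, queries, keys)
attn = sparse_topk_attn(logits, sparse_topk = 2, straight_through = True)

# forward values equal the sparse top-2 softmax, since soft_attn - soft_attn.detach() is numerically zero
assert ((attn > 0).sum(dim = -1) <= 2).all()

# backward pass routes gradients through the dense softmax, so logits that were
# masked out of the top-k still receive a gradient signal
values = torch.randn(1, 1, 8, 16)
(attn @ values).sum().backward()
assert logits.grad[attn == 0].ne(0).any()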
@@ -141,6 +159,7 @@ class Attend(Module):
         post_talking_heads = False,
         pre_scale_post_talking_heads = False,
         sparse_topk = None,
+        sparse_topk_straight_through = False,
         scale = None,
         qk_norm = False,
         l2_distance = False,
@@ -152,7 +171,6 @@ class Attend(Module):
         add_zero_kv = False,
         selective = False,
         hard = False,
-        sigsoftmax = False,
         cope = None,
         onnxable = False,
         sdp_kwargs: dict = dict(
@@ -171,16 +189,22 @@ class Attend(Module):

         # attention type

+        is_sparse_topk_attn = exists(sparse_topk)
+
         assert not (flash and sigmoid), 'sigmoid attention not available for flash'
         assert not (flash and hard), 'hard attention not available for flash'
-        assert at_most_one_of(sigmoid, hard, l2_distance)
+        assert not (flash and is_sparse_topk_attn), 'topk attention not available for flash'
+
+        assert at_most_one_of(sigmoid, hard, l2_distance, is_sparse_topk_attn)

         if exists(custom_attn_fn):
             self.attn_fn = custom_attn_fn
         elif sigmoid:
             self.attn_fn = F.sigmoid
         elif hard:
-            self.attn_fn = gumbel_softmax
+            self.attn_fn = one_hot_straight_through
+        elif is_sparse_topk_attn:
+            self.attn_fn = partial(sparse_topk_attn, sparse_topk = sparse_topk, straight_through = sparse_topk_straight_through)
         else:
             softmax_fn = partial(F.softmax, dim = -1)
             self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
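With the wiring above, Attend selects sparse_topk_attn as its attention function whenever sparse_topk is set, and the new assert forbids combining it with flash attention. A minimal usage sketch (editorial, not part of the diff), assuming Attend's forward signature (q, k, v, mask = None) is unchanged from 1.39.x and using hypothetical shapes:

import torch
from x_transformers.attend import Attend

attend = Attend(
    causal = True,
    sparse_topk = 8,                        # keep only the top-8 logits per query row
    sparse_topk_straight_through = True,    # new in 1.40.0: gradients via the dense softmax
    flash = False                           # required - the assert above disallows flash + sparse topk
)

# hypothetical (batch, heads, seq, dim_head) tensors
q = torch.randn(2, 8, 128, 64)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

out, intermediates = attend(q, k, v)
print(out.shape)   # torch.Size([2, 8, 128, 64])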
@@ -214,16 +238,6 @@ class Attend(Module):
         assert not (selective and not causal), 'selective attention is designed for autoregressive'
         self.selective = selective

-        # sparse topk
-
-        assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention'
-        self.sparse_topk = sparse_topk
-
-        # sig softmax
-
-        assert not (flash and sigsoftmax), 'sigsoftmax not available for flash attention'
-        self.sigsoftmax = sigsoftmax
-
         # l2 distance attention

         self.l2_distance = l2_distance
@@ -476,11 +490,6 @@ class Attend(Module):
             causal_mask = self.create_causal_mask(i, j, device = device)
             sim = sim.masked_fill(causal_mask, mask_value)

-        if exists(self.sparse_topk):
-            top_values, _ = sim.topk(self.sparse_topk, dim = -1)
-            sparse_topk_mask = (sim >= top_values[..., -1:]) & (sim > mask_value)
-            sim = sim.masked_fill(~sparse_topk_mask, mask_value)
-
         row_is_entirely_masked = None

         if exists(mask):
@@ -494,9 +503,6 @@ class Attend(Module):

         pre_softmax_attn = sim

-        if self.sigsoftmax:
-            sim = sim + sim.sigmoid().log()
-
         attn = self.attn_fn(sim)

         attn = attn.type(dtype)
x_transformers/x_transformers.py CHANGED
@@ -912,6 +912,7 @@ class Attention(Module):
         pre_scale_post_talking_heads = False,
         head_scale = False,
         sparse_topk = None,
+        sparse_topk_straight_through = False,
         num_mem_kv = 0,
         dropout = 0.,
         on_attn = False,
@@ -920,7 +921,6 @@ class Attention(Module):
         gate_values = False,
         zero_init_output = False,
         hard = False,
-        sigsoftmax = False,
         max_attend_past = None,
         qk_norm = False,
         qk_norm_groups = 1,
@@ -1044,6 +1044,7 @@ class Attention(Module):
             pre_scale_post_talking_heads = pre_scale_post_talking_heads,
             dropout = dropout,
             sparse_topk = sparse_topk,
+            sparse_topk_straight_through = sparse_topk_straight_through,
             hard = hard,
             qk_norm = qk_norm,
             scale = qk_norm_scale if qk_norm else self.scale,
@@ -1054,7 +1055,6 @@ class Attention(Module):
             add_zero_kv = add_zero_kv,
             flash = flash,
             softclamp_logits = softclamp_logits,
-            sigsoftmax = sigsoftmax,
             logit_softclamp_value = logit_softclamp_value,
             cope = cope,
             onnxable = onnxable
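Attention simply forwards the new flag into Attend. For reference, a hedged sketch (not part of the diff) of how the option would be reached from the high-level API, assuming the usual routing of attn_-prefixed kwargs into Attention; the kwarg name attn_sparse_topk_straight_through comes from the comment in the attend.py hunk above.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_sparse_topk = 8,                       # sparse top-k attention (pre-existing option)
        attn_sparse_topk_straight_through = True    # new flag in 1.40.0
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)   # shape (1, 1024, 20000)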
x_transformers-1.39.3.dist-info/METADATA → x_transformers-1.40.0.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.39.3
+Version: 1.40.0
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.39.3.dist-info/RECORD → x_transformers-1.40.0.dist-info/RECORD RENAMED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=K4DNcWAgOHzNtOTIJA15VA3VQ2KMyv-PX8oO1R0Z5Rw,16670
+x_transformers/attend.py,sha256=eoBEK0HdDCWaJgxwGZPeO36ydBt1NbB-gpij_Jkj4Mw,17212
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=3JrVqwYVrd5UVf2esdunTcer7QL72H7VF4mL3UsCWOI,84508
+x_transformers/x_transformers.py,sha256=TGXJZXCWR5BiMkS5Kx-JhFQ85AxkiJabLiHnrCTC874,84562
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.39.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.39.3.dist-info/METADATA,sha256=aqCockVfuZfLM3D9ZdlgN_HfUPdhZm4UWyKg2HLkuUo,661
-x_transformers-1.39.3.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
-x_transformers-1.39.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.39.3.dist-info/RECORD,,
+x_transformers-1.40.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.40.0.dist-info/METADATA,sha256=WxBpjG7F8utkdJN1AF9vZMAfbltT0lmpdgtUErBbYQY,661
+x_transformers-1.40.0.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.40.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.40.0.dist-info/RECORD,,