x-transformers 2.4.9__tar.gz → 2.4.11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {x_transformers-2.4.9 → x_transformers-2.4.11}/PKG-INFO +1 -1
  2. {x_transformers-2.4.9 → x_transformers-2.4.11}/pyproject.toml +1 -1
  3. {x_transformers-2.4.9 → x_transformers-2.4.11}/tests/test_x_transformers.py +27 -0
  4. {x_transformers-2.4.9 → x_transformers-2.4.11}/x_transformers/autoregressive_wrapper.py +183 -1
  5. {x_transformers-2.4.9 → x_transformers-2.4.11}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.4.9 → x_transformers-2.4.11}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.4.9 → x_transformers-2.4.11}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.4.9 → x_transformers-2.4.11}/.gitignore +0 -0
  9. {x_transformers-2.4.9 → x_transformers-2.4.11}/LICENSE +0 -0
  10. {x_transformers-2.4.9 → x_transformers-2.4.11}/README.md +0 -0
  11. {x_transformers-2.4.9 → x_transformers-2.4.11}/data/README.md +0 -0
  12. {x_transformers-2.4.9 → x_transformers-2.4.11}/data/enwik8.gz +0 -0
  13. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/all-attention.png +0 -0
  14. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/deepnorm.png +0 -0
  17. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/fcm.png +0 -0
  23. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/ffglu.png +0 -0
  24. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/flash-attention.png +0 -0
  25. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/gate_values.png +0 -0
  26. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/gating.png +0 -0
  27. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/macaron-1.png +0 -0
  29. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/macaron-2.png +0 -0
  30. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/normformer.png +0 -0
  32. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/pia.png +0 -0
  33. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/resi_dual.png +0 -0
  35. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/residual_attn.png +0 -0
  36. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/rezero.png +0 -0
  37. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/rotary.png +0 -0
  38. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/sandwich.png +0 -0
  40. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/scalenorm.png +0 -0
  42. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/talking-heads.png +0 -0
  43. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/topk-attention.png +0 -0
  44. {x_transformers-2.4.9 → x_transformers-2.4.11}/images/xval.png +0 -0
  45. {x_transformers-2.4.9 → x_transformers-2.4.11}/train_belief_state.py +0 -0
  46. {x_transformers-2.4.9 → x_transformers-2.4.11}/train_copy.py +0 -0
  47. {x_transformers-2.4.9 → x_transformers-2.4.11}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.4.9 → x_transformers-2.4.11}/train_enwik8.py +0 -0
  49. {x_transformers-2.4.9 → x_transformers-2.4.11}/train_length_extrapolate.py +0 -0
  50. {x_transformers-2.4.9 → x_transformers-2.4.11}/train_parity.py +0 -0
  51. {x_transformers-2.4.9 → x_transformers-2.4.11}/x_transformers/__init__.py +0 -0
  52. {x_transformers-2.4.9 → x_transformers-2.4.11}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.4.9 → x_transformers-2.4.11}/x_transformers/belief_state_wrapper.py +0 -0
  54. {x_transformers-2.4.9 → x_transformers-2.4.11}/x_transformers/continuous.py +0 -0
  55. {x_transformers-2.4.9 → x_transformers-2.4.11}/x_transformers/dpo.py +0 -0
  56. {x_transformers-2.4.9 → x_transformers-2.4.11}/x_transformers/entropy_based_tokenizer.py +0 -0
  57. {x_transformers-2.4.9 → x_transformers-2.4.11}/x_transformers/multi_input.py +0 -0
  58. {x_transformers-2.4.9 → x_transformers-2.4.11}/x_transformers/neo_mlp.py +0 -0
  59. {x_transformers-2.4.9 → x_transformers-2.4.11}/x_transformers/nonautoregressive_wrapper.py +0 -0
  60. {x_transformers-2.4.9 → x_transformers-2.4.11}/x_transformers/up_wrapper.py +0 -0
  61. {x_transformers-2.4.9 → x_transformers-2.4.11}/x_transformers/x_transformers.py +0 -0
  62. {x_transformers-2.4.9 → x_transformers-2.4.11}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  63. {x_transformers-2.4.9 → x_transformers-2.4.11}/x_transformers/xval.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.4.9
+Version: 2.4.11
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.4.9"
+version = "2.4.11"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -1126,3 +1126,30 @@ def test_up(
 
     loss = up_wrapper()
     loss.backward()
+
+@pytest.mark.parametrize('stochastic', (False, True))
+def test_beam_search(stochastic):
+    from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper
+
+    model = TransformerWrapper(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 12,
+            heads = 8
+        ),
+    )
+
+    x = torch.randint(0, 256, (2, 10))
+
+    wrapper = AutoregressiveWrapper(model)
+
+    generated = wrapper.beam_search(x[:, :1], 10, beams = 4, stochastic = stochastic)
+
+    assert generated.shape == (2, 10)
+
+    beams, scores = wrapper.beam_search(x[:, :1], 10, beams = 4, return_beams_and_scores = True, stochastic = stochastic)
+
+    assert beams.shape == (4, 2, 10)
+    assert scores.shape == (4, 2)
@@ -8,7 +8,7 @@ from torch import nn, Tensor
 from torch.nn import Module
 import torch.nn.functional as F
 
-from einops import rearrange, pack, unpack
+from einops import rearrange, repeat, pack, unpack
 
 def exists(val):
     return val is not None
@@ -34,6 +34,21 @@ def eval_decorator(fn):
         return out
     return inner
 
+# gumbel topk
+
+def log(t, eps = 1e-20):
+    return t.clamp(min = eps).log()
+
+def gumbel_noise(t):
+    return -log(-log(torch.rand_like(t)))
+
+# function for modifying all the cached key / values
+
+def modify_cached_kv(cache, fn):
+    for inter in cache.attn_intermediates:
+        if inter.layer_type == 'a':
+            inter.cached_kv = [fn(t) for t in inter.cached_kv]
+
 # for variable lengthed prefixes
 
 def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
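For context, the two small helpers added above implement the Gumbel-top-k trick that powers the stochastic path of the new beam search: adding Gumbel noise to the (temperature-scaled) logits and taking the top-k indices draws k distinct tokens at random, rather than always the k most likely ones. A minimal standalone sketch of the idea (not part of the diff; the shapes and vocabulary size are arbitrary):

import torch

def log(t, eps = 1e-20):
    return t.clamp(min = eps).log()

def gumbel_noise(t):
    # Gumbel(0, 1) noise via inverse transform sampling
    return -log(-log(torch.rand_like(t)))

logits = torch.randn(2, 256)                        # (batch, vocab)

greedy_topk = logits.topk(4, dim = -1).indices      # always the 4 highest-scoring tokens

perturbed = logits + gumbel_noise(logits)
sampled_topk = perturbed.topk(4, dim = -1).indices  # 4 tokens sampled without replacement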
@@ -161,6 +176,173 @@ class AutoregressiveWrapper(Module):
         self.add_continuous_pred_head = net.add_continuous_pred_head
         self.next_embed_loss_weight = next_embed_loss_weight
 
+    @torch.no_grad()
+    @eval_decorator
+    def beam_search(
+        self,
+        prompts,
+        seq_len,
+        beams = 4,
+        return_beams_and_scores = False,
+        eos_token = None,
+        temperature = 1.,
+        stochastic = False,
+        prompt_lens: Tensor | None = None,
+        filter_logits_fn: str | Callable = top_k,
+        restrict_to_max_seq_len = True,
+        filter_kwargs: dict = dict(),
+        cache_kv = True,
+        **kwargs
+    ):
+        assert not exists(eos_token), 'eos token not supported yet'
+
+        max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
+
+        prompts, packed_shape = pack([prompts], '* n')
+
+        batch, orig_seq_len = prompts.shape
+
+        # handle filter logits fn given as string
+
+        if isinstance(filter_logits_fn, str):
+            assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"
+
+            filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]
+
+        # handle variable lengthed prompts (prefixes)
+
+        seq_start_pos = None
+        if exists(prompt_lens):
+            prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
+            seq_start_pos = orig_seq_len - prompt_lens
+
+        # output from which sampled tokens appended to
+
+        out = prompts
+
+        # kv caches
+
+        cache = None
+
+        should_cache = cache_kv and self.net.can_cache_kv
+
+        # scores for the beams
+
+        scores = torch.zeros((batch,), device = device)
+
+        batch_arange = torch.arange(batch, device = device)
+
+        # sampling up to seq_len
+
+        for i in range(seq_len):
+            is_first = i == 0
+
+            if restrict_to_max_seq_len:
+                max_len_exceeded = out.shape[-1] > max_seq_len
+
+                assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'
+
+                x = out[:, -max_seq_len:]
+
+                if exists(cache):
+                    modify_cached_kv(cache, lambda t: t[..., -(max_seq_len - 1):, :])
+
+            logits, new_cache = self.net(
+                x,
+                return_intermediates = True,
+                cache = cache,
+                seq_start_pos = seq_start_pos,
+                **kwargs
+            )
+
+            if should_cache:
+                cache = new_cache
+
+            logits = logits[:, -1]
+
+            # to add to the scores
+
+            log_probs = logits.log_softmax(dim = -1)
+
+            # maybe filter by top_k, top_p (nucleus) for stochastic beam search
+
+            if stochastic and not greedy:
+                logits = filter_logits_fn(logits, **filter_kwargs)
+                logits = (logits / temperature) + gumbel_noise(logits)
+
+            # (gumbel) topk
+
+            samples = logits.topk(beams, dim = -1).indices
+
+            # get the scores for keeping track of beams
+
+            next_scores = log_probs.gather(-1, samples)
+
+            # expand beam times
+
+            scores = repeat(scores, 'b -> b beams', beams = beams)
+            scores = scores + next_scores
+
+            out = repeat(out, 'b ... -> (b beams) ...', beams = beams)
+            samples = rearrange(samples, 'b beams -> (b beams) 1')
+
+            if should_cache and is_first:
+                modify_cached_kv(cache, lambda t: repeat(t, 'b ... -> (b beams) ...', beams = beams))
+
+            # concat sample
+
+            out = torch.cat((out, samples), dim = -1)
+
+            # sort by score and excise
+            # excise out the beams
+
+            scores = rearrange(scores, '(b prev_beams) next_beams -> b (prev_beams next_beams)', b = batch)
+            curr_num_beams = scores.shape[-1]
+
+            if curr_num_beams > beams:
+                scores, sort_indices = scores.sort(dim = -1, descending = True)
+
+                scores = scores[:, :beams]
+                top_beams_indices = sort_indices[:, :beams]
+
+                top_beams_indices = curr_num_beams * batch_arange[:, None] + top_beams_indices
+
+                flattened_beam_indices = rearrange(top_beams_indices, 'b beams -> (b beams)')
+
+                out = out[flattened_beam_indices]
+
+            scores = rearrange(scores, 'b beams -> (b beams)')
+
+            if not exists(eos_token):
+                continue
+
+            is_eos_tokens = (out == eos_token)
+
+            if is_eos_tokens.any(dim = -1).all():
+                break
+
+        if exists(eos_token):
+            # mask out everything after the eos tokens
+            shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
+            mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
+            out = out.masked_fill(mask, self.pad_value)
+
+        # select out the top beam
+
+        out = rearrange(out, '(b beams) seq -> b beams seq', b = batch)
+
+        out = out[..., orig_seq_len:]
+
+        out, = unpack(out, packed_shape, '* beams n') # prompt may have no batch dimension
+
+        if not return_beams_and_scores:
+            return out[..., 0, :]
+
+        scores = rearrange(scores, '(b beams) -> beams b', b = batch)
+        out = rearrange(out, 'b beams n -> beams b n')
+
+        return out, scores
+
     @torch.no_grad()
     @eval_decorator
     def generate(
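Mirroring the test added in tests/test_x_transformers.py, the new beam_search method can be exercised roughly as follows. This is a minimal sketch: the model hyperparameters and prompt contents are arbitrary, and only the call signature and returned shapes are taken from the diff above.

import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

# small decoder-only model; sizes chosen arbitrarily for the sketch
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

wrapper = AutoregressiveWrapper(model)

prompts = torch.randint(0, 256, (2, 1))   # batch of 2 single-token prompts

# default: return only the highest-scoring beam -> shape (2, 10)
generated = wrapper.beam_search(prompts, 10, beams = 4)

# stochastic (Gumbel-perturbed) beam search, returning every beam plus its
# cumulative log-prob score -> beams (4, 2, 10), scores (4, 2)
beams, scores = wrapper.beam_search(
    prompts, 10,
    beams = 4,
    stochastic = True,
    return_beams_and_scores = True
)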