x-transformers 2.4.10__tar.gz → 2.4.11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {x_transformers-2.4.10 → x_transformers-2.4.11}/PKG-INFO +1 -1
  2. {x_transformers-2.4.10 → x_transformers-2.4.11}/pyproject.toml +1 -1
  3. {x_transformers-2.4.10 → x_transformers-2.4.11}/tests/test_x_transformers.py +5 -0
  4. {x_transformers-2.4.10 → x_transformers-2.4.11}/x_transformers/autoregressive_wrapper.py +13 -7
  5. {x_transformers-2.4.10 → x_transformers-2.4.11}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.4.10 → x_transformers-2.4.11}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.4.10 → x_transformers-2.4.11}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.4.10 → x_transformers-2.4.11}/.gitignore +0 -0
  9. {x_transformers-2.4.10 → x_transformers-2.4.11}/LICENSE +0 -0
  10. {x_transformers-2.4.10 → x_transformers-2.4.11}/README.md +0 -0
  11. {x_transformers-2.4.10 → x_transformers-2.4.11}/data/README.md +0 -0
  12. {x_transformers-2.4.10 → x_transformers-2.4.11}/data/enwik8.gz +0 -0
  13. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/all-attention.png +0 -0
  14. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/deepnorm.png +0 -0
  17. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/fcm.png +0 -0
  23. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/ffglu.png +0 -0
  24. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/flash-attention.png +0 -0
  25. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/gate_values.png +0 -0
  26. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/gating.png +0 -0
  27. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/macaron-1.png +0 -0
  29. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/macaron-2.png +0 -0
  30. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/normformer.png +0 -0
  32. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/pia.png +0 -0
  33. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/resi_dual.png +0 -0
  35. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/residual_attn.png +0 -0
  36. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/rezero.png +0 -0
  37. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/rotary.png +0 -0
  38. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/sandwich.png +0 -0
  40. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/scalenorm.png +0 -0
  42. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/talking-heads.png +0 -0
  43. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/topk-attention.png +0 -0
  44. {x_transformers-2.4.10 → x_transformers-2.4.11}/images/xval.png +0 -0
  45. {x_transformers-2.4.10 → x_transformers-2.4.11}/train_belief_state.py +0 -0
  46. {x_transformers-2.4.10 → x_transformers-2.4.11}/train_copy.py +0 -0
  47. {x_transformers-2.4.10 → x_transformers-2.4.11}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.4.10 → x_transformers-2.4.11}/train_enwik8.py +0 -0
  49. {x_transformers-2.4.10 → x_transformers-2.4.11}/train_length_extrapolate.py +0 -0
  50. {x_transformers-2.4.10 → x_transformers-2.4.11}/train_parity.py +0 -0
  51. {x_transformers-2.4.10 → x_transformers-2.4.11}/x_transformers/__init__.py +0 -0
  52. {x_transformers-2.4.10 → x_transformers-2.4.11}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.4.10 → x_transformers-2.4.11}/x_transformers/belief_state_wrapper.py +0 -0
  54. {x_transformers-2.4.10 → x_transformers-2.4.11}/x_transformers/continuous.py +0 -0
  55. {x_transformers-2.4.10 → x_transformers-2.4.11}/x_transformers/dpo.py +0 -0
  56. {x_transformers-2.4.10 → x_transformers-2.4.11}/x_transformers/entropy_based_tokenizer.py +0 -0
  57. {x_transformers-2.4.10 → x_transformers-2.4.11}/x_transformers/multi_input.py +0 -0
  58. {x_transformers-2.4.10 → x_transformers-2.4.11}/x_transformers/neo_mlp.py +0 -0
  59. {x_transformers-2.4.10 → x_transformers-2.4.11}/x_transformers/nonautoregressive_wrapper.py +0 -0
  60. {x_transformers-2.4.10 → x_transformers-2.4.11}/x_transformers/up_wrapper.py +0 -0
  61. {x_transformers-2.4.10 → x_transformers-2.4.11}/x_transformers/x_transformers.py +0 -0
  62. {x_transformers-2.4.10 → x_transformers-2.4.11}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  63. {x_transformers-2.4.10 → x_transformers-2.4.11}/x_transformers/xval.py +0 -0
{x_transformers-2.4.10 → x_transformers-2.4.11}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.4.10
+Version: 2.4.11
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
{x_transformers-2.4.10 → x_transformers-2.4.11}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.4.10"
+version = "2.4.11"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.4.10 → x_transformers-2.4.11}/tests/test_x_transformers.py
@@ -1148,3 +1148,8 @@ def test_beam_search(stochastic):
     generated = wrapper.beam_search(x[:, :1], 10, beams = 4, stochastic = stochastic)
 
     assert generated.shape == (2, 10)
+
+    beams, scores = wrapper.beam_search(x[:, :1], 10, beams = 4, return_beams_and_scores = True, stochastic = stochastic)
+
+    assert beams.shape == (4, 2, 10)
+    assert scores.shape == (4, 2)
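
For reference, a minimal usage sketch of the new return_beams_and_scores flag, mirroring the test added above. The TransformerWrapper / Decoder hyperparameters below are illustrative assumptions rather than values taken from the test file; only the beam_search call, the flag, and the returned shapes come from this diff.

import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

# illustrative model configuration (assumed, not from the test file)
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 128, depth = 2, heads = 4)
)

wrapper = AutoregressiveWrapper(model)

prompts = torch.randint(0, 256, (2, 1))   # (batch, prompt length)

# default behaviour: only the top beam per batch element is returned
generated = wrapper.beam_search(prompts, 10, beams = 4)   # shape (2, 10)

# new in 2.4.11: all beams plus their scores, beams-first
beams, scores = wrapper.beam_search(prompts, 10, beams = 4, return_beams_and_scores = True)
# beams.shape  == (4, 2, 10)   -> (beams, batch, seq_len)
# scores.shape == (4, 2)       -> (beams, batch)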
{x_transformers-2.4.10 → x_transformers-2.4.11}/x_transformers/autoregressive_wrapper.py
@@ -183,6 +183,7 @@ class AutoregressiveWrapper(Module):
         prompts,
         seq_len,
         beams = 4,
+        return_beams_and_scores = False,
         eos_token = None,
         temperature = 1.,
         stochastic = False,
@@ -285,7 +286,7 @@ class AutoregressiveWrapper(Module):
             out = repeat(out, 'b ... -> (b beams) ...', beams = beams)
             samples = rearrange(samples, 'b beams -> (b beams) 1')
 
-            if should_cache:
+            if should_cache and is_first:
                 modify_cached_kv(cache, lambda t: repeat(t, 'b ... -> (b beams) ...', beams = beams))
 
             # concat sample
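
A small sketch of the cache expansion this hunk now gates on is_first. The tensor shape below is an assumption for illustration; the point is that the cached keys/values are tiled from a batch of b rows to (b * beams) rows once, after the first decoding step, and repeating that expansion on later steps would multiply the leading dimension again.

import torch
from einops import repeat

beams = 4
cached_k = torch.randn(2, 8, 5, 64)   # assumed layout: (batch, heads, seq, dim_head)

# expansion performed once, on the first step
cached_k = repeat(cached_k, 'b ... -> (b beams) ...', beams = beams)
assert cached_k.shape[0] == 2 * beams   # 8

# running the same repeat on every step (the pre-2.4.11 condition checked
# only should_cache) would grow the leading dimension by a factor of beams
# on each iteration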
@@ -303,14 +304,13 @@ class AutoregressiveWrapper(Module):
 
             scores = scores[:, :beams]
             top_beams_indices = sort_indices[:, :beams]
-            top_beams_indices = beams * batch_arange[:, None] + top_beams_indices
+
+            top_beams_indices = curr_num_beams * batch_arange[:, None] + top_beams_indices
 
             flattened_beam_indices = rearrange(top_beams_indices, 'b beams -> (b beams)')
 
             out = out[flattened_beam_indices]
 
-            modify_cached_kv(cache, lambda t: t[flattened_beam_indices])
-
             scores = rearrange(scores, 'b beams -> (b beams)')
 
             if not exists(eos_token):
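
The index arithmetic changed above converts per-batch beam indices into indices over the flattened (batch * curr_num_beams) candidate dimension, so the stride has to be the current number of candidate beams rather than the number of beams kept; using beams as the stride would appear to mis-address rows whenever the two differ. A small worked example with assumed sizes:

import torch

batch, curr_num_beams = 2, 8                  # assumed sizes for illustration
batch_arange = torch.arange(batch)            # tensor([0, 1])

top_beams_indices = torch.tensor([[0, 3, 5, 7],
                                  [1, 2, 4, 6]])   # indices within each batch element's candidates

# offset each row by the number of candidates preceding it in the flattened layout
flat_indices = curr_num_beams * batch_arange[:, None] + top_beams_indices
# tensor([[ 0,  3,  5,  7],
#         [ 9, 10, 12, 14]])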
@@ -331,11 +331,17 @@ class AutoregressiveWrapper(Module):
 
         out = rearrange(out, '(b beams) seq -> b beams seq', b = batch)
 
-        out = out[:, 0, orig_seq_len:]
+        out = out[..., orig_seq_len:]
 
-        out, = unpack(out, packed_shape, '* n')
+        out, = unpack(out, packed_shape, '* beams n') # prompt may have no batch dimension
 
-        return out
+        if not return_beams_and_scores:
+            return out[..., 0, :]
+
+        scores = rearrange(scores, '(b beams) -> beams b', b = batch)
+        out = rearrange(out, 'b beams n -> beams b n')
+
+        return out, scores
 
     @torch.no_grad()
     @eval_decorator