x-transformers 2.5.3__tar.gz → 2.5.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {x_transformers-2.5.3 → x_transformers-2.5.4}/PKG-INFO +1 -1
  2. {x_transformers-2.5.3 → x_transformers-2.5.4}/pyproject.toml +1 -1
  3. {x_transformers-2.5.3 → x_transformers-2.5.4}/tests/test_x_transformers.py +29 -0
  4. {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/autoregressive_wrapper.py +18 -4
  5. {x_transformers-2.5.3 → x_transformers-2.5.4}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.5.3 → x_transformers-2.5.4}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.5.3 → x_transformers-2.5.4}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.5.3 → x_transformers-2.5.4}/.gitignore +0 -0
  9. {x_transformers-2.5.3 → x_transformers-2.5.4}/LICENSE +0 -0
  10. {x_transformers-2.5.3 → x_transformers-2.5.4}/README.md +0 -0
  11. {x_transformers-2.5.3 → x_transformers-2.5.4}/data/README.md +0 -0
  12. {x_transformers-2.5.3 → x_transformers-2.5.4}/data/enwik8.gz +0 -0
  13. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/all-attention.png +0 -0
  14. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/deepnorm.png +0 -0
  17. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/fcm.png +0 -0
  23. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/ffglu.png +0 -0
  24. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/flash-attention.png +0 -0
  25. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/gate_values.png +0 -0
  26. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/gating.png +0 -0
  27. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/macaron-1.png +0 -0
  29. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/macaron-2.png +0 -0
  30. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/normformer.png +0 -0
  32. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/pia.png +0 -0
  33. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/resi_dual.png +0 -0
  35. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/residual_attn.png +0 -0
  36. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/rezero.png +0 -0
  37. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/rotary.png +0 -0
  38. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/sandwich.png +0 -0
  40. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/scalenorm.png +0 -0
  42. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/talking-heads.png +0 -0
  43. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/topk-attention.png +0 -0
  44. {x_transformers-2.5.3 → x_transformers-2.5.4}/images/xval.png +0 -0
  45. {x_transformers-2.5.3 → x_transformers-2.5.4}/train_belief_state.py +0 -0
  46. {x_transformers-2.5.3 → x_transformers-2.5.4}/train_copy.py +0 -0
  47. {x_transformers-2.5.3 → x_transformers-2.5.4}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.5.3 → x_transformers-2.5.4}/train_enwik8.py +0 -0
  49. {x_transformers-2.5.3 → x_transformers-2.5.4}/train_length_extrapolate.py +0 -0
  50. {x_transformers-2.5.3 → x_transformers-2.5.4}/train_parity.py +0 -0
  51. {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/__init__.py +0 -0
  52. {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/belief_state_wrapper.py +0 -0
  54. {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/continuous.py +0 -0
  55. {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/dpo.py +0 -0
  56. {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/entropy_based_tokenizer.py +0 -0
  57. {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/multi_input.py +0 -0
  58. {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/neo_mlp.py +0 -0
  59. {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/nonautoregressive_wrapper.py +0 -0
  60. {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/up_wrapper.py +0 -0
  61. {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/x_transformers.py +0 -0
  62. {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  63. {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/xval.py +0 -0
--- x_transformers-2.5.3/PKG-INFO
+++ x_transformers-2.5.4/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.5.3
+Version: 2.5.4
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
--- x_transformers-2.5.3/pyproject.toml
+++ x_transformers-2.5.4/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.5.3"
+version = "2.5.4"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
--- x_transformers-2.5.3/tests/test_x_transformers.py
+++ x_transformers-2.5.4/tests/test_x_transformers.py
@@ -1181,3 +1181,32 @@ def test_attn_pooler(
     out = model(x)
 
     assert out.shape == (2, num_pooled_tokens, 77)
+
+def test_prompts_given_as_list_tensor():
+    from x_transformers import AutoregressiveWrapper
+
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 12,
+            heads = 8
+        )
+    )
+
+    wrapped = AutoregressiveWrapper(model)
+
+    seq = torch.randint(0, 20000, (3, 1024))
+
+    loss = wrapped(seq)
+    loss.backward()
+
+    sampled = wrapped.generate([
+        torch.randint(0, 20000, (3,)),
+        torch.randint(0, 20000, (5,)),
+        torch.randint(0, 20000, (2,)),
+        torch.randint(0, 20000, (7,)),
+    ], 256)
+
+    assert sampled.shape == (4, 256)
--- x_transformers-2.5.3/x_transformers/autoregressive_wrapper.py
+++ x_transformers-2.5.4/x_transformers/autoregressive_wrapper.py
@@ -4,9 +4,10 @@ from math import ceil, log
 from typing import Tuple, Callable
 
 import torch
-from torch import nn, Tensor
+from torch import nn, tensor, Tensor
 from torch.nn import Module
 import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
 
 from einops import rearrange, repeat, pack, unpack
 
@@ -347,7 +348,7 @@ class AutoregressiveWrapper(Module):
     @eval_decorator
     def generate(
         self,
-        prompts,
+        prompts: list[Tensor] | Tensor,
         seq_len,
         eos_token = None,
         temperature = 1.,
@@ -363,11 +364,23 @@ class AutoregressiveWrapper(Module):
         cache_kv = True,
         **kwargs
     ):
-        max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
+        max_seq_len, greedy = self.max_seq_len, temperature == 0.
+
+        # handle prompts given as list of variable lengthed token ids
+
+        if isinstance(prompts, list):
+            assert len(prompts) > 0, 'prompts cannot be empty list'
+            assert not exists(prompt_lens), '`prompt_len` will be auto derived if prompts are passed in as list of Tensors'
+
+            prompt_lens = tensor([t.shape[0] for t in prompts], device = prompts[0].device)
+
+            prompts = pad_sequence(prompts, batch_first = True)
+
+        # pack maybe no batch
 
         prompts, ps = pack([prompts], '* n')
 
-        b, t = prompts.shape
+        b, t, device = *prompts.shape, prompts.device
 
         # handle filter logits fn given as string
 
@@ -380,6 +393,7 @@ class AutoregressiveWrapper(Module):
 
         seq_start_pos = None
         if exists(prompt_lens):
+            print('prompt lens')
             prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
             seq_start_pos = t - prompt_lens
 
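The core of the new behavior is just right-padding the variable-length prompt tensors to a common length and recording their original lengths before the usual generation path runs. A minimal standalone sketch of that padding step is shown below; the prompt values and shapes are illustrative, not taken from the package.

import torch
from torch.nn.utils.rnn import pad_sequence

# variable-length prompts, as in the new test
prompts = [
    torch.randint(0, 20000, (3,)),
    torch.randint(0, 20000, (5,)),
    torch.randint(0, 20000, (2,)),
]

# lengths are derived before padding, mirroring the wrapper's auto-derived `prompt_lens`
prompt_lens = torch.tensor([t.shape[0] for t in prompts])

# pad_sequence pads each prompt with zeros up to the longest one -> (batch, max_len)
padded = pad_sequence(prompts, batch_first = True)

assert padded.shape == (3, 5)
assert prompt_lens.tolist() == [3, 5, 2]

From there the wrapper right-aligns the padded prompts with `align_right(prompts, prompt_lens, pad_id = self.pad_value)` and sets `seq_start_pos = t - prompt_lens`, the same path previously taken only when `prompt_lens` was passed in explicitly.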