x-transformers 2.1.36__tar.gz → 2.2.0__tar.gz

This diff shows the changes between publicly released versions of the package as they appear in its public registry, and is provided for informational purposes only.
Files changed (61)
  1. {x_transformers-2.1.36 → x_transformers-2.2.0}/PKG-INFO +12 -1
  2. {x_transformers-2.1.36 → x_transformers-2.2.0}/README.md +11 -0
  3. {x_transformers-2.1.36 → x_transformers-2.2.0}/pyproject.toml +1 -1
  4. {x_transformers-2.1.36 → x_transformers-2.2.0}/tests/test_x_transformers.py +22 -0
  5. {x_transformers-2.1.36 → x_transformers-2.2.0}/x_transformers/__init__.py +2 -0
  6. x_transformers-2.2.0/x_transformers/entropy_based_tokenizer.py +91 -0
  7. {x_transformers-2.1.36 → x_transformers-2.2.0}/x_transformers/x_transformers.py +5 -2
  8. {x_transformers-2.1.36 → x_transformers-2.2.0}/.github/FUNDING.yml +0 -0
  9. {x_transformers-2.1.36 → x_transformers-2.2.0}/.github/workflows/python-publish.yml +0 -0
  10. {x_transformers-2.1.36 → x_transformers-2.2.0}/.github/workflows/python-test.yaml +0 -0
  11. {x_transformers-2.1.36 → x_transformers-2.2.0}/.gitignore +0 -0
  12. {x_transformers-2.1.36 → x_transformers-2.2.0}/LICENSE +0 -0
  13. {x_transformers-2.1.36 → x_transformers-2.2.0}/data/README.md +0 -0
  14. {x_transformers-2.1.36 → x_transformers-2.2.0}/data/enwik8.gz +0 -0
  15. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/all-attention.png +0 -0
  16. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/attention-on-attention.png +0 -0
  17. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/cosine-sim-attention.png +0 -0
  18. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/deepnorm.png +0 -0
  19. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/dynamic-pos-bias-linear.png +0 -0
  20. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/dynamic-pos-bias-log.png +0 -0
  21. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  22. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/dynamic-pos-bias.png +0 -0
  23. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/enhanced-recurrence.png +0 -0
  24. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/fcm.png +0 -0
  25. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/ffglu.png +0 -0
  26. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/flash-attention.png +0 -0
  27. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/gate_values.png +0 -0
  28. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/gating.png +0 -0
  29. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/length-extrapolation-scale.png +0 -0
  30. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/macaron-1.png +0 -0
  31. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/macaron-2.png +0 -0
  32. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/memory-transformer.png +0 -0
  33. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/normformer.png +0 -0
  34. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/pia.png +0 -0
  35. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/qknorm-analysis.png +0 -0
  36. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/resi_dual.png +0 -0
  37. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/residual_attn.png +0 -0
  38. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/rezero.png +0 -0
  39. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/rotary.png +0 -0
  40. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/sandwich-2.png +0 -0
  41. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/sandwich.png +0 -0
  42. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/sandwich_norm.png +0 -0
  43. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/scalenorm.png +0 -0
  44. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/talking-heads.png +0 -0
  45. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/topk-attention.png +0 -0
  46. {x_transformers-2.1.36 → x_transformers-2.2.0}/images/xval.png +0 -0
  47. {x_transformers-2.1.36 → x_transformers-2.2.0}/train_belief_state.py +0 -0
  48. {x_transformers-2.1.36 → x_transformers-2.2.0}/train_copy.py +0 -0
  49. {x_transformers-2.1.36 → x_transformers-2.2.0}/train_enwik8.py +0 -0
  50. {x_transformers-2.1.36 → x_transformers-2.2.0}/train_length_extrapolate.py +0 -0
  51. {x_transformers-2.1.36 → x_transformers-2.2.0}/train_parity.py +0 -0
  52. {x_transformers-2.1.36 → x_transformers-2.2.0}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.1.36 → x_transformers-2.2.0}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.1.36 → x_transformers-2.2.0}/x_transformers/belief_state_wrapper.py +0 -0
  55. {x_transformers-2.1.36 → x_transformers-2.2.0}/x_transformers/continuous.py +0 -0
  56. {x_transformers-2.1.36 → x_transformers-2.2.0}/x_transformers/dpo.py +0 -0
  57. {x_transformers-2.1.36 → x_transformers-2.2.0}/x_transformers/multi_input.py +0 -0
  58. {x_transformers-2.1.36 → x_transformers-2.2.0}/x_transformers/neo_mlp.py +0 -0
  59. {x_transformers-2.1.36 → x_transformers-2.2.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
  60. {x_transformers-2.1.36 → x_transformers-2.2.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  61. {x_transformers-2.1.36 → x_transformers-2.2.0}/x_transformers/xval.py +0 -0

{x_transformers-2.1.36 → x_transformers-2.2.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.36
+Version: 2.2.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers

@@ -2464,4 +2464,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@article{Pagnoni2024ByteLT,
+    title = {Byte Latent Transformer: Patches Scale Better Than Tokens},
+    author = {Artidoro Pagnoni and Ram Pasunuru and Pedro Rodriguez and John Nguyen and Benjamin Muller and Margaret Li and Chunting Zhou and Lili Yu and Jason Weston and Luke S. Zettlemoyer and Gargi Ghosh and Mike Lewis and Ari Holtzman and Srinivasan Iyer},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2412.09871},
+    url = {https://api.semanticscholar.org/CorpusID:274762821}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis

{x_transformers-2.1.36 → x_transformers-2.2.0}/README.md

@@ -2416,4 +2416,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@article{Pagnoni2024ByteLT,
+    title = {Byte Latent Transformer: Patches Scale Better Than Tokens},
+    author = {Artidoro Pagnoni and Ram Pasunuru and Pedro Rodriguez and John Nguyen and Benjamin Muller and Margaret Li and Chunting Zhou and Lili Yu and Jason Weston and Luke S. Zettlemoyer and Gargi Ghosh and Mike Lewis and Ari Holtzman and Srinivasan Iyer},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2412.09871},
+    url = {https://api.semanticscholar.org/CorpusID:274762821}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis

{x_transformers-2.1.36 → x_transformers-2.2.0}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.1.36"
+version = "2.2.0"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

{x_transformers-2.1.36 → x_transformers-2.2.0}/tests/test_x_transformers.py

@@ -768,3 +768,25 @@ def test_dynamic_tanh():
     x = torch.randint(0, 20000, (2, 1024))
 
     model(x)
+
+def test_entropy_based_tokenizer():
+    from x_transformers.entropy_based_tokenizer import EntropyBasedTokenizer
+
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 128,
+            depth = 6,
+            heads = 8,
+            attn_dim_head = 64,
+        )
+    )
+
+    tokenizer = EntropyBasedTokenizer(model, entropy_threshold = 9.738)
+
+    seq = torch.randint(0, 20000, (2, 1024))
+
+    segmented_seq = tokenizer(seq, return_segmented_seq = True)
+
+    assert len(segmented_seq) == seq.shape[0]

{x_transformers-2.1.36 → x_transformers-2.2.0}/x_transformers/__init__.py

@@ -37,3 +37,5 @@ from x_transformers.dpo import (
 from x_transformers.neo_mlp import (
     NeoMLP
 )
+
+from x_transformers.entropy_based_tokenizer import EntropyBasedTokenizer
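With this re-export, the tokenizer is importable directly from the package root:

```python
from x_transformers import EntropyBasedTokenizer
```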

x_transformers-2.2.0/x_transformers/entropy_based_tokenizer.py (new file)

@@ -0,0 +1,91 @@
+import torch
+import torch.nn.functional as F
+from torch.nn import Module
+from torch.nn.utils.rnn import pad_sequence
+
+from x_transformers.x_transformers import Decoder, TransformerWrapper
+
+from einops import repeat, rearrange
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+# entropy based tokenizer applied in byte-latent transformer paper
+# they use a simple entropy threshold for segmenting a string into variable sized tokens
+
+# https://arxiv.org/abs/2412.09871
+
+class EntropyBasedTokenizer(Module):
+    def __init__(
+        self,
+        decoder: TransformerWrapper,
+        entropy_threshold = 1.5
+    ):
+        super().__init__()
+        assert isinstance(decoder.attn_layers, Decoder)
+
+        self.decoder = decoder
+        self.entropy_threshold = entropy_threshold
+
+    @torch.no_grad()
+    def forward(
+        self,
+        seq,
+        return_segmented_seq = False
+    ):
+        self.decoder.eval()
+
+        batch, seq_len, device = *seq.shape, seq.device
+
+        _, intermediates = self.decoder(seq, return_logit_entropies = True)
+
+        entropies = intermediates.logit_entropies
+
+        over_thres_mask = entropies >= self.entropy_threshold
+
+        arange = torch.arange(seq_len, device = device) + 1
+        arange = repeat(arange, 'n -> b n', b = batch)
+
+        # get a tensor of Int['b num_tokens'] with the token lengths, zero padded
+
+        boundaries = over_thres_mask.clone()
+        boundaries[..., -1] = True # last token is always a boundary
+
+        num_tokens = boundaries.sum(dim = -1) # number of tokens
+
+        boundaries = arange[boundaries].split(num_tokens.tolist())
+
+        # get the token lengths
+
+        token_lengths = []
+
+        for one_boundary in boundaries:
+            padded_boundary = F.pad(one_boundary, (1, 0), value = 0.)
+            one_token_lengths = padded_boundary[1:] - padded_boundary[:-1]
+
+            token_lengths.append(one_token_lengths)
+
+        token_lengths = pad_sequence(token_lengths, batch_first = True)
+
+        # early return
+
+        if not return_segmented_seq:
+            return token_lengths
+
+        # segment the sequence based on the token lengths
+
+        segmented_seq = []
+
+        for one_seq, one_token_length in zip(seq, token_lengths):
+
+            one_token_length = one_token_length[one_token_length > 0]
+
+            splitted_seq = one_seq.split(one_token_length.tolist())
+            segmented_seq.append(splitted_seq)
+
+        return segmented_seq
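Taken together with the added test, a minimal usage sketch of the new module might look like the following; the decoder configuration and the entropy_threshold value are illustrative, not prescriptive.

```python
# Minimal usage sketch of the new EntropyBasedTokenizer, mirroring the added test;
# the decoder configuration and entropy_threshold here are illustrative.
import torch
from x_transformers import TransformerWrapper, Decoder, EntropyBasedTokenizer

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8,
        attn_dim_head = 64
    )
)

tokenizer = EntropyBasedTokenizer(model, entropy_threshold = 1.5)

seq = torch.randint(0, 20000, (2, 1024))

# zero-padded token lengths per batch element
token_lengths = tokenizer(seq)

# or the original sequences split into variable-length segments
segmented_seq = tokenizer(seq, return_segmented_seq = True)

assert len(segmented_seq) == seq.shape[0]
```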

{x_transformers-2.1.36 → x_transformers-2.2.0}/x_transformers/x_transformers.py

@@ -2909,6 +2909,7 @@ class TransformerWrapper(Module):
         return_embeddings = False,
         return_logits_and_embeddings = False,
         return_intermediates = False,
+        return_embeddings_and_intermediates = False,
         return_logit_entropies = False,
         mask = None,
         return_mems = False,

@@ -2940,8 +2941,8 @@ class TransformerWrapper(Module):
 
         b, n, device, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask
 
-        return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
-        return_embeddings = return_embeddings | (not exists(self.to_logits))
+        return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss | return_embeddings_and_intermediates
+        return_embeddings = return_embeddings | (not exists(self.to_logits)) | return_embeddings_and_intermediates
 
         # absolute positional embedding
 

@@ -3131,6 +3132,8 @@ class TransformerWrapper(Module):
 
         if return_logits_and_embeddings:
             out = (logits, x)
+        elif return_embeddings_and_intermediates:
+            out = (x, intermediates)
         elif return_embeddings:
             out = x
         else:
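A short sketch of how the new return_embeddings_and_intermediates flag could be used, based on the change above; the small model configuration is illustrative only.

```python
# Sketch of the new return_embeddings_and_intermediates flag; model config is illustrative.
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
)

x = torch.randint(0, 256, (1, 512))

# when the flag is set, forward returns the final embeddings (pre-logits)
# together with the intermediates, instead of logits
embeddings, intermediates = model(x, return_embeddings_and_intermediates = True)
```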