x-transformers 2.2.0.tar.gz → 2.2.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. {x_transformers-2.2.0 → x_transformers-2.2.2}/PKG-INFO +1 -1
  2. {x_transformers-2.2.0 → x_transformers-2.2.2}/pyproject.toml +1 -1
  3. {x_transformers-2.2.0 → x_transformers-2.2.2}/tests/test_x_transformers.py +9 -2
  4. {x_transformers-2.2.0 → x_transformers-2.2.2}/x_transformers/entropy_based_tokenizer.py +43 -7
  5. {x_transformers-2.2.0 → x_transformers-2.2.2}/x_transformers/x_transformers.py +1 -1
  6. {x_transformers-2.2.0 → x_transformers-2.2.2}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.2.0 → x_transformers-2.2.2}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.2.0 → x_transformers-2.2.2}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.2.0 → x_transformers-2.2.2}/.gitignore +0 -0
  10. {x_transformers-2.2.0 → x_transformers-2.2.2}/LICENSE +0 -0
  11. {x_transformers-2.2.0 → x_transformers-2.2.2}/README.md +0 -0
  12. {x_transformers-2.2.0 → x_transformers-2.2.2}/data/README.md +0 -0
  13. {x_transformers-2.2.0 → x_transformers-2.2.2}/data/enwik8.gz +0 -0
  14. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/all-attention.png +0 -0
  15. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/deepnorm.png +0 -0
  18. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/fcm.png +0 -0
  24. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/ffglu.png +0 -0
  25. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/flash-attention.png +0 -0
  26. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/gate_values.png +0 -0
  27. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/gating.png +0 -0
  28. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/macaron-1.png +0 -0
  30. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/macaron-2.png +0 -0
  31. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/normformer.png +0 -0
  33. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/pia.png +0 -0
  34. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/resi_dual.png +0 -0
  36. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/residual_attn.png +0 -0
  37. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/rezero.png +0 -0
  38. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/rotary.png +0 -0
  39. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/sandwich.png +0 -0
  41. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/scalenorm.png +0 -0
  43. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/talking-heads.png +0 -0
  44. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/topk-attention.png +0 -0
  45. {x_transformers-2.2.0 → x_transformers-2.2.2}/images/xval.png +0 -0
  46. {x_transformers-2.2.0 → x_transformers-2.2.2}/train_belief_state.py +0 -0
  47. {x_transformers-2.2.0 → x_transformers-2.2.2}/train_copy.py +0 -0
  48. {x_transformers-2.2.0 → x_transformers-2.2.2}/train_enwik8.py +0 -0
  49. {x_transformers-2.2.0 → x_transformers-2.2.2}/train_length_extrapolate.py +0 -0
  50. {x_transformers-2.2.0 → x_transformers-2.2.2}/train_parity.py +0 -0
  51. {x_transformers-2.2.0 → x_transformers-2.2.2}/x_transformers/__init__.py +0 -0
  52. {x_transformers-2.2.0 → x_transformers-2.2.2}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.2.0 → x_transformers-2.2.2}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.2.0 → x_transformers-2.2.2}/x_transformers/belief_state_wrapper.py +0 -0
  55. {x_transformers-2.2.0 → x_transformers-2.2.2}/x_transformers/continuous.py +0 -0
  56. {x_transformers-2.2.0 → x_transformers-2.2.2}/x_transformers/dpo.py +0 -0
  57. {x_transformers-2.2.0 → x_transformers-2.2.2}/x_transformers/multi_input.py +0 -0
  58. {x_transformers-2.2.0 → x_transformers-2.2.2}/x_transformers/neo_mlp.py +0 -0
  59. {x_transformers-2.2.0 → x_transformers-2.2.2}/x_transformers/nonautoregressive_wrapper.py +0 -0
  60. {x_transformers-2.2.0 → x_transformers-2.2.2}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  61. {x_transformers-2.2.0 → x_transformers-2.2.2}/x_transformers/xval.py +0 -0
--- x_transformers-2.2.0/PKG-INFO
+++ x_transformers-2.2.2/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.2.0
+Version: 2.2.2
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
--- x_transformers-2.2.0/pyproject.toml
+++ x_transformers-2.2.2/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.2.0"
+version = "2.2.2"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
--- x_transformers-2.2.0/tests/test_x_transformers.py
+++ x_transformers-2.2.2/tests/test_x_transformers.py
@@ -769,7 +769,10 @@ def test_dynamic_tanh():
 
     model(x)
 
-def test_entropy_based_tokenizer():
+@pytest.mark.parametrize('var_length', (False, True))
+def test_entropy_based_tokenizer(
+    var_length
+):
     from x_transformers.entropy_based_tokenizer import EntropyBasedTokenizer
 
     model = TransformerWrapper(
@@ -787,6 +790,10 @@ def test_entropy_based_tokenizer():
 
     seq = torch.randint(0, 20000, (2, 1024))
 
-    segmented_seq = tokenizer(seq, return_segmented_seq = True)
+    lens = None
+    if var_length:
+        lens = torch.randint(512, 768, (2,))
+
+    segmented_seq = tokenizer(seq, lens, return_segmented_seq = True)
 
     assert len(segmented_seq) == seq.shape[0]
--- x_transformers-2.2.0/x_transformers/entropy_based_tokenizer.py
+++ x_transformers-2.2.2/x_transformers/entropy_based_tokenizer.py
@@ -1,10 +1,14 @@
+from itertools import zip_longest
+
 import torch
+from torch import tensor
 import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.utils.rnn import pad_sequence
 
 from x_transformers.x_transformers import Decoder, TransformerWrapper
 
+import einx
 from einops import repeat, rearrange
 
 # helper functions
@@ -24,7 +28,7 @@ class EntropyBasedTokenizer(Module):
     def __init__(
         self,
         decoder: TransformerWrapper,
-        entropy_threshold = 1.5
+        entropy_threshold: float
     ):
         super().__init__()
         assert isinstance(decoder.attn_layers, Decoder)
@@ -36,29 +40,56 @@ class EntropyBasedTokenizer(Module):
     def forward(
         self,
         seq,
+        lens = None, # Int['b']
        return_segmented_seq = False
     ):
         self.decoder.eval()
 
+        is_var_length = exists(lens)
         batch, seq_len, device = *seq.shape, seq.device
 
+        arange = torch.arange(seq_len, device = device)
+
+        # forward through a small trained decoder and get the entropies of the logits
+
         _, intermediates = self.decoder(seq, return_logit_entropies = True)
 
         entropies = intermediates.logit_entropies
 
-        over_thres_mask = entropies >= self.entropy_threshold
+        # get length mask for boundaries
+
+        mask = tensor(True, device = device)
+
+        if is_var_length:
+            mask = einx.less('n, b -> b n', arange, lens)
+
+        # the mask for tokens that were of a sufficient surprise level
 
-        arange = torch.arange(seq_len, device = device) + 1
-        arange = repeat(arange, 'n -> b n', b = batch)
+        over_thres_mask = (entropies >= self.entropy_threshold) & mask
+
+        # needed for selecting out indices at entropy threshold mask
+
+        arange_plus_one = arange + 1
+        arange_plus_one = repeat(arange_plus_one, 'n -> b n', b = batch)
 
         # get a tensor of Int['b num_tokens'] with the token lengths, zero padded
 
         boundaries = over_thres_mask.clone()
-        boundaries[..., -1] = True # last token is always a boundary
+
+        # set the boundary of the last token
+
+        # if `lens` not given, assume always last token
+        # but if `lens` were given, then properly set the index
+
+        if not is_var_length:
+            boundaries[..., -1] = True
+        else:
+            scatter_indices = rearrange(lens - 1, 'b -> b 1')
+            boundaries.scatter_(-1, scatter_indices, True)
 
         num_tokens = boundaries.sum(dim = -1) # number of tokens
 
-        boundaries = arange[boundaries].split(num_tokens.tolist())
+        boundaries = arange_plus_one[boundaries].split(num_tokens.tolist())
 
         # get the token lengths
 
@@ -79,12 +110,17 @@ class EntropyBasedTokenizer(Module):
 
         # segment the sequence based on the token lengths
 
+        lens = default(lens, (None,))
         segmented_seq = []
 
-        for one_seq, one_token_length in zip(seq, token_lengths):
+        for one_seq, one_len, one_token_length in zip_longest(seq, lens, token_lengths):
+
+            if exists(one_len):
+                one_seq = one_seq[:one_len]
 
             one_token_length = one_token_length[one_token_length > 0]
 
+            print(one_seq.shape, one_token_length)
             splitted_seq = one_seq.split(one_token_length.tolist())
             segmented_seq.append(splitted_seq)
 
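Taken together, the tokenizer changes amount to a small API shift: entropy_threshold no longer has a default and must be supplied at construction, and forward accepts an optional lens tensor for variable-length sequences. Below is a minimal usage sketch of the new call pattern, mirroring the updated test; the model hyperparameters and the 1.5 threshold are illustrative choices, not values fixed by this diff.

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.entropy_based_tokenizer import EntropyBasedTokenizer

# a small decoder supplies the logit entropies (sizes here are illustrative)
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 2,
        heads = 4
    )
)

# entropy_threshold is now required (the previous default was 1.5)
tokenizer = EntropyBasedTokenizer(model, entropy_threshold = 1.5)

seq = torch.randint(0, 20000, (2, 1024))

# optional per-sequence lengths exercise the new variable-length path
lens = torch.randint(512, 768, (2,))

segmented_seq = tokenizer(seq, lens, return_segmented_seq = True)
assert len(segmented_seq) == seq.shape[0]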
--- x_transformers-2.2.0/x_transformers/x_transformers.py
+++ x_transformers-2.2.2/x_transformers/x_transformers.py
@@ -2006,7 +2006,7 @@ class AttentionLayers(Module):
         assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
         self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None
 
-        assert at_most_one_of(alibi_pos_bias, rel_pos_bias, data_dependent_alibi), 'you can only choose one of Alibi positional bias, data dependent Alibi (forgetting transformers), or T5 relative positional bias'
+        assert at_most_one_of(alibi_pos_bias, rel_pos_bias, data_dependent_alibi), 'you can only choose one of Alibi positional bias, data dependent Alibi (forgetting transformers), dynamic tanh, or T5 relative positional bias'
         assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
 
         # relative positional bias
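The only change to x_transformers.py is the wording of the assertion message, but the constraint it guards is unchanged: at most one of alibi_pos_bias, rel_pos_bias, and data_dependent_alibi may be enabled on the attention layers. A minimal sketch of a configuration that passes the assertion (dimensions are illustrative):

from x_transformers import TransformerWrapper, Decoder

# enable at most one of: alibi_pos_bias, rel_pos_bias, data_dependent_alibi
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        alibi_pos_bias = True
    )
)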