x-transformers 2.2.1__py3-none-any.whl → 2.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/entropy_based_tokenizer.py

@@ -1,10 +1,14 @@
+from itertools import zip_longest
+
 import torch
+from torch import tensor
 import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.utils.rnn import pad_sequence
 
 from x_transformers.x_transformers import Decoder, TransformerWrapper
 
+import einx
 from einops import repeat, rearrange
 
 # helper functions
@@ -36,29 +40,56 @@ class EntropyBasedTokenizer(Module):
     def forward(
         self,
         seq,
+        lens = None, # Int['b']
         return_segmented_seq = False
     ):
         self.decoder.eval()
 
+        is_var_length = exists(lens)
         batch, seq_len, device = *seq.shape, seq.device
 
+        arange = torch.arange(seq_len, device = device)
+
+        # forward through a small trained decoder and get the entropies of the logits
+
         _, intermediates = self.decoder(seq, return_logit_entropies = True)
 
         entropies = intermediates.logit_entropies
 
-        over_thres_mask = entropies >= self.entropy_threshold
+        # get length mask for boundaries
+
+        mask = tensor(True, device = device)
+
+        if is_var_length:
+            mask = einx.less('n, b -> b n', arange, lens)
+
+        # the mask for tokens that were of a sufficient surprise level
 
-        arange = torch.arange(seq_len, device = device) + 1
-        arange = repeat(arange, 'n -> b n', b = batch)
+        over_thres_mask = (entropies >= self.entropy_threshold) & mask
+
+        # needed for selecting out indices at entropy threshold mask
+
+        arange_plus_one = arange + 1
+        arange_plus_one = repeat(arange_plus_one, 'n -> b n', b = batch)
 
         # get a tensor of Int['b num_tokens'] with the token lengths, zero padded
 
         boundaries = over_thres_mask.clone()
-        boundaries[..., -1] = True # last token is always a boundary
+
+        # set the boundary of the last token
+
+        # if `lens` not given, assume always last token
+        # but if `lens` were given, then properly set the index
+
+        if not is_var_length:
+            boundaries[..., -1] = True
+        else:
+            scatter_indices = rearrange(lens - 1, 'b -> b 1')
+            boundaries.scatter_(-1, scatter_indices, True)
 
         num_tokens = boundaries.sum(dim = -1) # number of tokens
 
-        boundaries = arange[boundaries].split(num_tokens.tolist())
+        boundaries = arange_plus_one[boundaries].split(num_tokens.tolist())
 
         # get the token lengths
 
@@ -79,12 +110,17 @@ class EntropyBasedTokenizer(Module):
 
         # segment the sequence based on the token lengths
 
+        lens = default(lens, (None,))
         segmented_seq = []
 
-        for one_seq, one_token_length in zip(seq, token_lengths):
+        for one_seq, one_len, one_token_length in zip_longest(seq, lens, token_lengths):
+
+            if exists(one_len):
+                one_seq = one_seq[:one_len]
 
             one_token_length = one_token_length[one_token_length > 0]
 
+            print(one_seq.shape, one_token_length)
             splitted_seq = one_seq.split(one_token_length.tolist())
             segmented_seq.append(splitted_seq)
 
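In short, 2.2.2 lets EntropyBasedTokenizer segment right-padded, variable-length batches: the new optional lens argument masks out padding positions before thresholding the entropies and places each sequence's final boundary at its true last token. Below is a minimal usage sketch; the constructor arguments, model sizes, and threshold value are illustrative assumptions and are not taken from this diff.

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.entropy_based_tokenizer import EntropyBasedTokenizer

# small decoder whose per-position logit entropies drive the segmentation
decoder = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 128, depth = 2, heads = 4)
)

tokenizer = EntropyBasedTokenizer(decoder, entropy_threshold = 2.5)  # threshold chosen arbitrarily

seq = torch.randint(0, 256, (2, 1024))   # right-padded batch of token ids
lens = torch.tensor([128, 1024])         # per-sequence valid lengths, the new `lens` argument

# with return_segmented_seq = True, each sequence is returned split at its high-entropy boundaries
segmented = tokenizer(seq, lens = lens, return_segmented_seq = True)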
x_transformers/x_transformers.py

@@ -2006,7 +2006,7 @@ class AttentionLayers(Module):
         assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
         self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None
 
-        assert at_most_one_of(alibi_pos_bias, rel_pos_bias, data_dependent_alibi), 'you can only choose one of Alibi positional bias, data dependent Alibi (forgetting transformers), or T5 relative positional bias'
+        assert at_most_one_of(alibi_pos_bias, rel_pos_bias, data_dependent_alibi), 'you can only choose one of Alibi positional bias, data dependent Alibi (forgetting transformers), dynamic tanh, or T5 relative positional bias'
         assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
 
         # relative positional bias
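The only source change in x_transformers.py is the wording of this assertion's error message; the constraint it enforces is unchanged: at most one of the listed positional bias schemes may be enabled per attention stack. A hedged sketch of what the constraint rules out, with illustrative kwarg values:

from x_transformers import Decoder

# enabling two mutually exclusive positional bias schemes trips the assertion above
try:
    Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        alibi_pos_bias = True,  # Alibi positional bias
        rel_pos_bias = True     # T5 relative positional bias
    )
except AssertionError as err:
    print(err)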
x_transformers-2.2.1.dist-info/METADATA → x_transformers-2.2.2.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.2.1
+Version: 2.2.2
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
x_transformers-2.2.1.dist-info/RECORD → x_transformers-2.2.2.dist-info/RECORD

@@ -4,14 +4,14 @@ x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67K
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
-x_transformers/entropy_based_tokenizer.py,sha256=2Foh9tBUL55Lu0CcgA8kWzyfHYI7DQHBF3zK_WdQt0o,2519
+x_transformers/entropy_based_tokenizer.py,sha256=hdYfw8GqMj0YVWY_gpaCCzhkMALnQB9yAUaCg8RWMss,3624
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=twoqq2kfVWxntitHKLs2sxFMhK1CPLxGGBDAmkiHXcM,111812
+x_transformers/x_transformers.py,sha256=Fl2CuAKTxJDOQvqwQo2FK8eO2s1iNLO-P1PP2Yw64rQ,111826
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.2.1.dist-info/METADATA,sha256=SEoe84lxCNvvJ8HVRqaLoESrTYRgQ9dc6Y-Po8nhxbg,88686
-x_transformers-2.2.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.2.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.2.1.dist-info/RECORD,,
+x_transformers-2.2.2.dist-info/METADATA,sha256=V0g-qeMS5RoayZAttx8bGdkIh0ZAsVmDDfTiPNI5qHM,88686
+x_transformers-2.2.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.2.2.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.2.2.dist-info/RECORD,,