x-transformers 2.2.1__py3-none-any.whl → 2.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/entropy_based_tokenizer.py +44 -9
- x_transformers/x_transformers.py +1 -1
- {x_transformers-2.2.1.dist-info → x_transformers-2.2.3.dist-info}/METADATA +1 -1
- {x_transformers-2.2.1.dist-info → x_transformers-2.2.3.dist-info}/RECORD +6 -6
- {x_transformers-2.2.1.dist-info → x_transformers-2.2.3.dist-info}/WHEEL +0 -0
- {x_transformers-2.2.1.dist-info → x_transformers-2.2.3.dist-info}/licenses/LICENSE +0 -0
@@ -1,10 +1,14 @@
|
|
1
|
+
from itertools import zip_longest
|
2
|
+
|
1
3
|
import torch
|
4
|
+
from torch import tensor
|
2
5
|
import torch.nn.functional as F
|
3
6
|
from torch.nn import Module
|
4
7
|
from torch.nn.utils.rnn import pad_sequence
|
5
8
|
|
6
9
|
from x_transformers.x_transformers import Decoder, TransformerWrapper
|
7
10
|
|
11
|
+
import einx
|
8
12
|
from einops import repeat, rearrange
|
9
13
|
|
10
14
|
# helper functions
|
@@ -36,37 +40,64 @@ class EntropyBasedTokenizer(Module):
|
|
36
40
|
def forward(
|
37
41
|
self,
|
38
42
|
seq,
|
43
|
+
lens = None, # Int['b']
|
39
44
|
return_segmented_seq = False
|
40
45
|
):
|
41
46
|
self.decoder.eval()
|
42
47
|
|
48
|
+
is_var_length = exists(lens)
|
43
49
|
batch, seq_len, device = *seq.shape, seq.device
|
44
50
|
|
51
|
+
arange = torch.arange(seq_len, device = device)
|
52
|
+
|
53
|
+
# forward through a small trained decoder and get the entropies of the logits
|
54
|
+
|
45
55
|
_, intermediates = self.decoder(seq, return_logit_entropies = True)
|
46
56
|
|
47
57
|
entropies = intermediates.logit_entropies
|
48
58
|
|
49
|
-
|
59
|
+
# get length mask for boundaries
|
60
|
+
|
61
|
+
mask = tensor(True, device = device)
|
62
|
+
|
63
|
+
if is_var_length:
|
64
|
+
mask = einx.less('n, b -> b n', arange, lens)
|
65
|
+
|
66
|
+
# the mask for tokens that were of a sufficient surprise level
|
50
67
|
|
51
|
-
|
52
|
-
|
68
|
+
over_thres_mask = (entropies >= self.entropy_threshold) & mask
|
69
|
+
|
70
|
+
# needed for selecting out indices at entropy threshold mask
|
71
|
+
|
72
|
+
arange_plus_one = arange + 1
|
73
|
+
arange_plus_one = repeat(arange_plus_one, 'n -> b n', b = batch)
|
53
74
|
|
54
75
|
# get a tensor of Int['b num_tokens'] with the token lengths, zero padded
|
55
76
|
|
56
77
|
boundaries = over_thres_mask.clone()
|
57
|
-
|
78
|
+
|
79
|
+
# set the boundary of the last token
|
80
|
+
|
81
|
+
# if `lens` not given, assume always last token
|
82
|
+
# but if `lens` were given, then properly set the index
|
83
|
+
|
84
|
+
if not is_var_length:
|
85
|
+
boundaries[..., -1] = True
|
86
|
+
else:
|
87
|
+
scatter_indices = rearrange(lens - 1, 'b -> b 1')
|
88
|
+
boundaries.scatter_(-1, scatter_indices, True)
|
58
89
|
|
59
90
|
num_tokens = boundaries.sum(dim = -1) # number of tokens
|
60
91
|
|
61
|
-
|
92
|
+
indices = arange_plus_one[boundaries].split(num_tokens.tolist())
|
62
93
|
|
63
94
|
# get the token lengths
|
64
95
|
|
65
96
|
token_lengths = []
|
66
97
|
|
67
|
-
for
|
68
|
-
|
69
|
-
one_token_lengths =
|
98
|
+
for one_indices in indices:
|
99
|
+
padded_indices = F.pad(one_indices, (1, 0), value = 0.)
|
100
|
+
one_token_lengths = padded_indices[1:] - padded_indices[:-1]
|
70
101
|
|
71
102
|
token_lengths.append(one_token_lengths)
|
72
103
|
|
@@ -79,9 +110,13 @@ class EntropyBasedTokenizer(Module):
|
|
79
110
|
|
80
111
|
# segment the sequence based on the token lengths
|
81
112
|
|
113
|
+
lens = default(lens, (None,))
|
82
114
|
segmented_seq = []
|
83
115
|
|
84
|
-
for one_seq, one_token_length in
|
116
|
+
for one_seq, one_len, one_token_length in zip_longest(seq, lens, token_lengths):
|
117
|
+
|
118
|
+
if exists(one_len):
|
119
|
+
one_seq = one_seq[:one_len]
|
85
120
|
|
86
121
|
one_token_length = one_token_length[one_token_length > 0]
|
87
122
|
|
x_transformers/x_transformers.py
CHANGED
@@ -2006,7 +2006,7 @@ class AttentionLayers(Module):
|
|
2006
2006
|
assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
|
2007
2007
|
self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None
|
2008
2008
|
|
2009
|
-
assert at_most_one_of(alibi_pos_bias, rel_pos_bias, data_dependent_alibi), 'you can only choose one of Alibi positional bias, data dependent Alibi (forgetting transformers), or T5 relative positional bias'
|
2009
|
+
assert at_most_one_of(alibi_pos_bias, rel_pos_bias, data_dependent_alibi), 'you can only choose one of Alibi positional bias, data dependent Alibi (forgetting transformers), dynamic tanh, or T5 relative positional bias'
|
2010
2010
|
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
|
2011
2011
|
|
2012
2012
|
# relative positional bias
|
@@ -4,14 +4,14 @@ x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67K
|
|
4
4
|
x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
|
5
5
|
x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
|
6
6
|
x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
|
7
|
-
x_transformers/entropy_based_tokenizer.py,sha256=
|
7
|
+
x_transformers/entropy_based_tokenizer.py,sha256=sEOf_J_9PGNFKPZ9Gks3MwNjiTa_JljjO1OU3ubIziQ,3562
|
8
8
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
9
9
|
x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
|
10
10
|
x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
|
11
|
-
x_transformers/x_transformers.py,sha256=
|
11
|
+
x_transformers/x_transformers.py,sha256=Fl2CuAKTxJDOQvqwQo2FK8eO2s1iNLO-P1PP2Yw64rQ,111826
|
12
12
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
13
13
|
x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
|
14
|
-
x_transformers-2.2.
|
15
|
-
x_transformers-2.2.
|
16
|
-
x_transformers-2.2.
|
17
|
-
x_transformers-2.2.
|
14
|
+
x_transformers-2.2.3.dist-info/METADATA,sha256=aHA3vvgyUsxooOXzLLPtw0FmRGYrqdzMXe2QVoRJk04,88686
|
15
|
+
x_transformers-2.2.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
16
|
+
x_transformers-2.2.3.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
17
|
+
x_transformers-2.2.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|