x-transformers 2.2.7__py3-none-any.whl → 2.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/entropy_based_tokenizer.py +28 -3
- x_transformers/x_transformers.py +4 -1
- {x_transformers-2.2.7.dist-info → x_transformers-2.2.9.dist-info}/METADATA +1 -1
- {x_transformers-2.2.7.dist-info → x_transformers-2.2.9.dist-info}/RECORD +6 -6
- {x_transformers-2.2.7.dist-info → x_transformers-2.2.9.dist-info}/WHEEL +0 -0
- {x_transformers-2.2.7.dist-info → x_transformers-2.2.9.dist-info}/licenses/LICENSE +0 -0
x_transformers/entropy_based_tokenizer.py
CHANGED
@@ -33,12 +33,15 @@ class EntropyBasedTokenizer(Module):
     def __init__(
         self,
         decoder: Module,
-        entropy_threshold: float
+        entropy_threshold: float,
+        max_token_size: int | None = None
     ):
         super().__init__()
         self.decoder = decoder
         self.entropy_threshold = entropy_threshold
 
+        self.max_token_size = max_token_size
+
     @torch.no_grad()
     def forward(
         self,
@@ -53,7 +56,7 @@ class EntropyBasedTokenizer(Module):
         self.decoder.eval()
 
         is_var_length = exists(lens)
-        batch, seq_len, device = *seq.shape, seq.device
+        batch, seq_len, device, max_token_size = *seq.shape, seq.device, self.max_token_size
 
         arange = torch.arange(seq_len, device = device)
 
@@ -94,7 +97,29 @@ class EntropyBasedTokenizer(Module):
         scatter_indices = rearrange(lens - 1, 'b -> b 1')
         boundaries.scatter_(-1, scatter_indices, True)
 
-        num_tokens = boundaries.sum(dim = -1) # number of tokens
+        # handle max token size - technique has the flaw that repeating subsequences are grouped into one large token
+
+        if exists(max_token_size):
+            token_ids = boundaries.cumsum(dim = -1)
+            token_ids = F.pad(token_ids, (1, -1), value = 0)
+
+            max_num_tokens = boundaries.sum(dim = -1).amax().item()
+            token_ids_seq = torch.arange(max_num_tokens, device = device)
+
+            token_mask = einx.equal('j, b i -> b j i', token_ids_seq, token_ids)
+
+            token_sub_seq_arange = token_mask.cumsum(dim = -1)
+
+            sub_seq_boundaries = (token_sub_seq_arange % max_token_size == 0)
+            sub_seq_boundaries = (sub_seq_boundaries & token_mask).any(dim = 1)
+
+            boundaries = boundaries | sub_seq_boundaries
+
+        # number of tokens
+
+        num_tokens = boundaries.sum(dim = -1)
+
+        # get number of tokens as well as derived indices
 
         indices = arange_plus_one[boundaries].split(num_tokens.tolist())
 
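The new max_token_size option caps how long an entropy-derived token may grow: whenever a token exceeds the cap, extra boundaries are injected every max_token_size positions inside it. Below is a minimal plain-PyTorch sketch of that splitting step, run on a hypothetical toy boundaries tensor (one sequence of length 10 with token ends at positions 2 and 9); the einx.equal('j, b i -> b j i', ...) call from the hunk is replaced with an equivalent broadcasted comparison, and all values are illustrative, not taken from the library.

import torch
import torch.nn.functional as F

# toy input: token sizes 3 and 7, cap of 3 (hypothetical values)
boundaries = torch.tensor([[False, False, True, False, False, False, False, False, False, True]])
max_token_size = 3

# per-position token ids: cumulative sum of boundaries, shifted right by one
token_ids = F.pad(boundaries.cumsum(dim = -1), (1, -1), value = 0)   # [[0,0,0,1,1,1,1,1,1,1]]

max_num_tokens = boundaries.sum(dim = -1).amax().item()              # 2
token_ids_seq = torch.arange(max_num_tokens)

# membership mask of shape (batch, num_tokens, seq_len) - broadcasted stand-in for einx.equal
token_mask = token_ids_seq[None, :, None] == token_ids[:, None, :]

# running position of every element inside its own token
token_sub_seq_arange = token_mask.cumsum(dim = -1)

# mark every max_token_size-th position within a token as an extra boundary
sub_seq_boundaries = (token_sub_seq_arange % max_token_size == 0)
sub_seq_boundaries = (sub_seq_boundaries & token_mask).any(dim = 1)

boundaries = boundaries | sub_seq_boundaries
print(boundaries.int())   # [[0,0,1,0,0,1,0,0,1,1]] -> token sizes 3, 3, 3, 1

With the new argument, construction would look like EntropyBasedTokenizer(decoder, entropy_threshold = 2.5, max_token_size = 8), where the threshold and cap values are placeholders. The hunk's own comment notes the motivation: repeating low-entropy subsequences can otherwise be merged into one arbitrarily large token.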
x_transformers/x_transformers.py
CHANGED
@@ -1196,6 +1196,7 @@ class FeedForward(Module):
         glu_mult_bias = False,
         swish = False,
         relu_squared = False,
+        custom_activation = None,
         post_act_ln = False,
         dropout = 0.,
         no_bias = False,
@@ -1205,7 +1206,9 @@ class FeedForward(Module):
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
 
-        if relu_squared:
+        if exists(custom_activation):
+            activation = deepcopy(custom_activation)
+        elif relu_squared:
             activation = ReluSquared()
         elif swish:
             activation = nn.SiLU()
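The custom_activation hook lets an arbitrary activation module replace the flag-selected built-ins (relu_squared, swish, and so on); per the diff, the module is deep-copied before being wired into the block. A minimal usage sketch, with the dim/mult values and the nn.ReLU6 activation chosen here purely for illustration:

import torch
import torch.nn as nn
from x_transformers.x_transformers import FeedForward

# any nn.Module works as the activation; FeedForward deep-copies it,
# so the same instance can be handed to several blocks safely
ff = FeedForward(dim = 512, mult = 4, custom_activation = nn.ReLU6())

x = torch.randn(2, 1024, 512)
out = ff(x)   # shape preserved: (2, 1024, 512), since dim_out defaults to dim

In the higher-level wrappers this should be reachable through the ff_ keyword prefix (presumably ff_custom_activation = ... on an Encoder/Decoder), following the library's existing convention of forwarding ff_* kwargs into FeedForward.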
{x_transformers-2.2.7.dist-info → x_transformers-2.2.9.dist-info}/RECORD
CHANGED
@@ -4,14 +4,14 @@ x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67K
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
-x_transformers/entropy_based_tokenizer.py,sha256=
+x_transformers/entropy_based_tokenizer.py,sha256=xvMJvQFT_W_E5zObBF8dn5gtGNLgF5540MJaobZ4lnY,4906
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=m2xiiTafFZiII-QZLCpPerdWbY8O41I6BAYCaaPdXig,111953
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.2.
-x_transformers-2.2.
-x_transformers-2.2.
-x_transformers-2.2.
+x_transformers-2.2.9.dist-info/METADATA,sha256=4r5W3sk0VVAVY7o5jZrrojbDVVh9oTVm37GNbvoLn6Y,88686
+x_transformers-2.2.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.2.9.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.2.9.dist-info/RECORD,,
{x_transformers-2.2.7.dist-info → x_transformers-2.2.9.dist-info}/WHEEL
File without changes
{x_transformers-2.2.7.dist-info → x_transformers-2.2.9.dist-info}/licenses/LICENSE
File without changes