x-transformers 2.2.6__py3-none-any.whl → 2.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/entropy_based_tokenizer.py +8 -7
- x_transformers/x_transformers.py +4 -1
- {x_transformers-2.2.6.dist-info → x_transformers-2.2.8.dist-info}/METADATA +1 -1
- {x_transformers-2.2.6.dist-info → x_transformers-2.2.8.dist-info}/RECORD +6 -6
- {x_transformers-2.2.6.dist-info → x_transformers-2.2.8.dist-info}/WHEEL +0 -0
- {x_transformers-2.2.6.dist-info → x_transformers-2.2.8.dist-info}/licenses/LICENSE +0 -0
x_transformers/entropy_based_tokenizer.py
CHANGED

```diff
@@ -6,12 +6,6 @@ import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.utils.rnn import pad_sequence
 
-from x_transformers.x_transformers import (
-    Decoder,
-    TransformerWrapper,
-    calc_entropy
-)
-
 import einx
 from einops import repeat, rearrange, pack, unpack
 
```
```diff
@@ -23,6 +17,13 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def log(t, eps = 1e-20):
+    return t.clamp(min = eps).log()
+
+def calc_entropy_from_logits(logits):
+    prob = logits.softmax(dim = -1)
+    return -(prob * log(prob)).sum(dim = -1)
+
 # entropy based tokenizer applied in byte-latent transformer paper
 # they use a simple entropy threshold for segmenting a string into variable sized tokens
 
```
```diff
@@ -60,7 +61,7 @@ class EntropyBasedTokenizer(Module):
 
         logits = self.decoder(seq, **decoder_forward_kwargs)
 
-        entropies = calc_entropy(logits)
+        entropies = calc_entropy_from_logits(logits)
 
         # get length mask for boundaries
 
```
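The net effect: the tokenizer no longer imports `calc_entropy` from `x_transformers.x_transformers` and instead computes Shannon entropy locally with the new `calc_entropy_from_logits` helper. A minimal sketch of the entropy-threshold segmentation the comments describe, reusing the helper from the diff; the tensor shapes and threshold value below are illustrative, not the library's API:

```python
import torch

def log(t, eps = 1e-20):
    # clamp before log so zero probabilities do not produce NaN/-inf
    return t.clamp(min = eps).log()

def calc_entropy_from_logits(logits):
    prob = logits.softmax(dim = -1)
    return -(prob * log(prob)).sum(dim = -1)

# toy logits: batch of 1, sequence of 8 positions, 256-symbol (byte) vocabulary
logits = torch.randn(1, 8, 256)

entropies = calc_entropy_from_logits(logits)  # shape (1, 8), one entropy per position

# byte-latent-transformer-style segmentation: place a token boundary wherever
# the model is "surprised", i.e. predictive entropy exceeds a threshold
threshold = 5.0  # illustrative value
boundaries = entropies > threshold
```

For a 256-symbol vocabulary the per-position entropy ranges from 0 nats (a fully confident prediction) up to ln 256 ≈ 5.55 nats (a uniform distribution), which is the scale a threshold like the one above would be tuned against.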
x_transformers/x_transformers.py
CHANGED
```diff
@@ -1196,6 +1196,7 @@ class FeedForward(Module):
         glu_mult_bias = False,
         swish = False,
         relu_squared = False,
+        custom_activation = None,
         post_act_ln = False,
         dropout = 0.,
         no_bias = False,
@@ -1205,7 +1206,9 @@ class FeedForward(Module):
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
 
-        if relu_squared:
+        if exists(custom_activation):
+            activation = deepcopy(custom_activation)
+        elif relu_squared:
             activation = ReluSquared()
         elif swish:
             activation = nn.SiLU()
```
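With this change, any `nn.Module` can be passed as the feedforward activation, taking precedence over the built-in `relu_squared` / `swish` flags; the module is deep-copied, so each `FeedForward` instance owns its own copy (which matters for activations with learnable parameters, e.g. `nn.PReLU`). A minimal usage sketch: the `custom_activation`, `dim`, and `mult` parameters appear in the diff above, while the remaining details of the call are assumptions:

```python
import torch
from torch import nn
from x_transformers.x_transformers import FeedForward

# hypothetical example: swap in GELU as the feedforward activation
ff = FeedForward(dim = 512, mult = 4, custom_activation = nn.GELU())

x = torch.randn(2, 1024, 512)
out = ff(x)  # dim_out defaults to dim, so the shape is preserved: (2, 1024, 512)
```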
{x_transformers-2.2.6.dist-info → x_transformers-2.2.8.dist-info}/RECORD
CHANGED

```diff
@@ -4,14 +4,14 @@ x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67K
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
-x_transformers/entropy_based_tokenizer.py,sha256=
+x_transformers/entropy_based_tokenizer.py,sha256=c9dcERmoK454mD3JKlxeU3cNwYVPi8I8bI91xrvVvOA,3938
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=m2xiiTafFZiII-QZLCpPerdWbY8O41I6BAYCaaPdXig,111953
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.2.6.dist-info/METADATA,sha256=
-x_transformers-2.2.6.dist-info/WHEEL,sha256=
-x_transformers-2.2.6.dist-info/licenses/LICENSE,sha256=
-x_transformers-2.2.6.dist-info/RECORD,,
+x_transformers-2.2.8.dist-info/METADATA,sha256=U54PnKoqbbDeOAU2moEtVMOAejOvXnF6BDpakJTHud8,88686
+x_transformers-2.2.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.2.8.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.2.8.dist-info/RECORD,,
```
{x_transformers-2.2.6.dist-info → x_transformers-2.2.8.dist-info}/WHEEL
File without changes

{x_transformers-2.2.6.dist-info → x_transformers-2.2.8.dist-info}/licenses/LICENSE
File without changes