x-transformers 2.1.37__py3-none-any.whl → 2.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/__init__.py +2 -0
- x_transformers/entropy_based_tokenizer.py +91 -0
- {x_transformers-2.1.37.dist-info → x_transformers-2.2.1.dist-info}/METADATA +12 -1
- {x_transformers-2.1.37.dist-info → x_transformers-2.2.1.dist-info}/RECORD +6 -5
- {x_transformers-2.1.37.dist-info → x_transformers-2.2.1.dist-info}/WHEEL +0 -0
- {x_transformers-2.1.37.dist-info → x_transformers-2.2.1.dist-info}/licenses/LICENSE +0 -0
x_transformers/entropy_based_tokenizer.py
ADDED
@@ -0,0 +1,91 @@
+import torch
+import torch.nn.functional as F
+from torch.nn import Module
+from torch.nn.utils.rnn import pad_sequence
+
+from x_transformers.x_transformers import Decoder, TransformerWrapper
+
+from einops import repeat, rearrange
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+# entropy based tokenizer applied in byte-latent transformer paper
+# they use a simple entropy threshold for segmenting a string into variable sized tokens
+
+# https://arxiv.org/abs/2412.09871
+
+class EntropyBasedTokenizer(Module):
+    def __init__(
+        self,
+        decoder: TransformerWrapper,
+        entropy_threshold: float
+    ):
+        super().__init__()
+        assert isinstance(decoder.attn_layers, Decoder)
+
+        self.decoder = decoder
+        self.entropy_threshold = entropy_threshold
+
+    @torch.no_grad()
+    def forward(
+        self,
+        seq,
+        return_segmented_seq = False
+    ):
+        self.decoder.eval()
+
+        batch, seq_len, device = *seq.shape, seq.device
+
+        _, intermediates = self.decoder(seq, return_logit_entropies = True)
+
+        entropies = intermediates.logit_entropies
+
+        over_thres_mask = entropies >= self.entropy_threshold
+
+        arange = torch.arange(seq_len, device = device) + 1
+        arange = repeat(arange, 'n -> b n', b = batch)
+
+        # get a tensor of Int['b num_tokens'] with the token lengths, zero padded
+
+        boundaries = over_thres_mask.clone()
+        boundaries[..., -1] = True # last token is always a boundary
+
+        num_tokens = boundaries.sum(dim = -1) # number of tokens
+
+        boundaries = arange[boundaries].split(num_tokens.tolist())
+
+        # get the token lengths
+
+        token_lengths = []
+
+        for one_boundary in boundaries:
+            padded_boundary = F.pad(one_boundary, (1, 0), value = 0.)
+            one_token_lengths = padded_boundary[1:] - padded_boundary[:-1]
+
+            token_lengths.append(one_token_lengths)
+
+        token_lengths = pad_sequence(token_lengths, batch_first = True)
+
+        # early return
+
+        if not return_segmented_seq:
+            return token_lengths
+
+        # segment the sequence based on the token lengths
+
+        segmented_seq = []
+
+        for one_seq, one_token_length in zip(seq, token_lengths):
+
+            one_token_length = one_token_length[one_token_length > 0]
+
+            splitted_seq = one_seq.split(one_token_length.tolist())
+            segmented_seq.append(splitted_seq)
+
+        return segmented_seq
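The new module can be exercised against any byte-level TransformerWrapper decoder, since the tokenizer only needs the per-position prediction entropies (the `return_logit_entropies` path it calls above). Below is a minimal usage sketch, not part of this diff; the model dimensions, entropy threshold, and random byte input are illustrative assumptions.

```python
# minimal sketch, not from the package diff; dims, threshold and input are illustrative
import torch

from x_transformers.x_transformers import Decoder, TransformerWrapper
from x_transformers.entropy_based_tokenizer import EntropyBasedTokenizer

# a byte-level decoder language model (would normally be trained on next-byte prediction)
model = TransformerWrapper(
    num_tokens = 256,        # one id per byte
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

tokenizer = EntropyBasedTokenizer(
    model,
    entropy_threshold = 2.5  # assumed value; a boundary is placed wherever next-byte entropy crosses it
)

seq = torch.randint(0, 256, (2, 1024))  # batch of byte ids

# zero padded token lengths per sequence, shape (batch, max number of tokens)
token_lengths = tokenizer(seq)

# or the sequences themselves, split into variable length segments
segments = tokenizer(seq, return_segmented_seq = True)
```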
{x_transformers-2.1.37.dist-info → x_transformers-2.2.1.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.37
+Version: 2.2.1
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers

@@ -2464,4 +2464,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@article{Pagnoni2024ByteLT,
+    title = {Byte Latent Transformer: Patches Scale Better Than Tokens},
+    author = {Artidoro Pagnoni and Ram Pasunuru and Pedro Rodriguez and John Nguyen and Benjamin Muller and Margaret Li and Chunting Zhou and Lili Yu and Jason Weston and Luke S. Zettlemoyer and Gargi Ghosh and Mike Lewis and Ari Holtzman and Srinivasan Iyer},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2412.09871},
+    url = {https://api.semanticscholar.org/CorpusID:274762821}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.1.37.dist-info → x_transformers-2.2.1.dist-info}/RECORD
CHANGED

@@ -1,16 +1,17 @@
-x_transformers/__init__.py,sha256=
+x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
+x_transformers/entropy_based_tokenizer.py,sha256=2Foh9tBUL55Lu0CcgA8kWzyfHYI7DQHBF3zK_WdQt0o,2519
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
 x_transformers/x_transformers.py,sha256=twoqq2kfVWxntitHKLs2sxFMhK1CPLxGGBDAmkiHXcM,111812
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
+x_transformers-2.2.1.dist-info/METADATA,sha256=SEoe84lxCNvvJ8HVRqaLoESrTYRgQ9dc6Y-Po8nhxbg,88686
+x_transformers-2.2.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.2.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.2.1.dist-info/RECORD,,
{x_transformers-2.1.37.dist-info → x_transformers-2.2.1.dist-info}/WHEEL
File without changes

{x_transformers-2.1.37.dist-info → x_transformers-2.2.1.dist-info}/licenses/LICENSE
File without changes