x-transformers 2.1.36__py3-none-any.whl → 2.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/__init__.py +2 -0
- x_transformers/entropy_based_tokenizer.py +91 -0
- x_transformers/x_transformers.py +5 -2
- {x_transformers-2.1.36.dist-info → x_transformers-2.2.0.dist-info}/METADATA +12 -1
- {x_transformers-2.1.36.dist-info → x_transformers-2.2.0.dist-info}/RECORD +7 -6
- {x_transformers-2.1.36.dist-info → x_transformers-2.2.0.dist-info}/WHEEL +0 -0
- {x_transformers-2.1.36.dist-info → x_transformers-2.2.0.dist-info}/licenses/LICENSE +0 -0
x_transformers/entropy_based_tokenizer.py
ADDED
@@ -0,0 +1,91 @@
+import torch
+import torch.nn.functional as F
+from torch.nn import Module
+from torch.nn.utils.rnn import pad_sequence
+
+from x_transformers.x_transformers import Decoder, TransformerWrapper
+
+from einops import repeat, rearrange
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+# entropy based tokenizer applied in byte-latent transformer paper
+# they use a simple entropy threshold for segmenting a string into variable sized tokens
+
+# https://arxiv.org/abs/2412.09871
+
+class EntropyBasedTokenizer(Module):
+    def __init__(
+        self,
+        decoder: TransformerWrapper,
+        entropy_threshold = 1.5
+    ):
+        super().__init__()
+        assert isinstance(decoder.attn_layers, Decoder)
+
+        self.decoder = decoder
+        self.entropy_threshold = entropy_threshold
+
+    @torch.no_grad()
+    def forward(
+        self,
+        seq,
+        return_segmented_seq = False
+    ):
+        self.decoder.eval()
+
+        batch, seq_len, device = *seq.shape, seq.device
+
+        _, intermediates = self.decoder(seq, return_logit_entropies = True)
+
+        entropies = intermediates.logit_entropies
+
+        over_thres_mask = entropies >= self.entropy_threshold
+
+        arange = torch.arange(seq_len, device = device) + 1
+        arange = repeat(arange, 'n -> b n', b = batch)
+
+        # get a tensor of Int['b num_tokens'] with the token lengths, zero padded
+
+        boundaries = over_thres_mask.clone()
+        boundaries[..., -1] = True # last token is always a boundary
+
+        num_tokens = boundaries.sum(dim = -1) # number of tokens
+
+        boundaries = arange[boundaries].split(num_tokens.tolist())
+
+        # get the token lengths
+
+        token_lengths = []
+
+        for one_boundary in boundaries:
+            padded_boundary = F.pad(one_boundary, (1, 0), value = 0.)
+            one_token_lengths = padded_boundary[1:] - padded_boundary[:-1]
+
+            token_lengths.append(one_token_lengths)
+
+        token_lengths = pad_sequence(token_lengths, batch_first = True)
+
+        # early return
+
+        if not return_segmented_seq:
+            return token_lengths
+
+        # segment the sequence based on the token lengths
+
+        segmented_seq = []
+
+        for one_seq, one_token_length in zip(seq, token_lengths):
+
+            one_token_length = one_token_length[one_token_length > 0]
+
+            splitted_seq = one_seq.split(one_token_length.tolist())
+            segmented_seq.append(splitted_seq)
+
+        return segmented_seq
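For reference, a minimal usage sketch of the new tokenizer (not part of the diff): the byte-level decoder configuration below is an illustrative assumption; only EntropyBasedTokenizer, its entropy_threshold argument, and the return_segmented_seq flag come from the module added above.

```python
# usage sketch only - the decoder hyperparameters here are assumptions for illustration
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.entropy_based_tokenizer import EntropyBasedTokenizer

# a byte-level autoregressive decoder, assumed trained for next-byte prediction
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

tokenizer = EntropyBasedTokenizer(
    model,
    entropy_threshold = 1.5  # segment wherever predicted entropy meets or exceeds this
)

seq = torch.randint(0, 256, (2, 1024))  # batch of byte ids

token_lengths = tokenizer(seq)                          # zero-padded segment lengths per batch element
segments = tokenizer(seq, return_segmented_seq = True)  # per batch element, a tuple of variable-length segments
```

With a trained decoder, segment boundaries fall where next-byte prediction becomes uncertain, which is the entropy-based patching rule described in the Byte Latent Transformer paper referenced in the module.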
x_transformers/x_transformers.py
CHANGED
@@ -2909,6 +2909,7 @@ class TransformerWrapper(Module):
         return_embeddings = False,
         return_logits_and_embeddings = False,
         return_intermediates = False,
+        return_embeddings_and_intermediates = False,
         return_logit_entropies = False,
         mask = None,
         return_mems = False,
@@ -2940,8 +2941,8 @@ class TransformerWrapper(Module):
 
         b, n, device, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask
 
-        return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
-        return_embeddings = return_embeddings | (not exists(self.to_logits))
+        return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss | return_embeddings_and_intermediates
+        return_embeddings = return_embeddings | (not exists(self.to_logits)) | return_embeddings_and_intermediates
 
         # absolute positional embedding
 
@@ -3131,6 +3132,8 @@ class TransformerWrapper(Module):
 
         if return_logits_and_embeddings:
             out = (logits, x)
+        elif return_embeddings_and_intermediates:
+            out = (x, intermediates)
         elif return_embeddings:
             out = x
         else:
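The only behavioral change to TransformerWrapper is the new return_embeddings_and_intermediates flag: per the hunks above, it forces return_embeddings and the hidden/intermediate collection, and makes forward return the final embeddings together with the intermediates object instead of logits. A minimal sketch, again with illustrative hyperparameters:

```python
# sketch of the new flag only - the model configuration is an illustrative assumption
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

x = torch.randint(0, 256, (1, 1024))

# skips the logit projection and returns (embeddings, intermediates)
embeds, intermediates = model(x, return_embeddings_and_intermediates = True)
```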
{x_transformers-2.1.36.dist-info → x_transformers-2.2.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.36
+Version: 2.2.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2464,4 +2464,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@article{Pagnoni2024ByteLT,
+    title   = {Byte Latent Transformer: Patches Scale Better Than Tokens},
+    author  = {Artidoro Pagnoni and Ram Pasunuru and Pedro Rodriguez and John Nguyen and Benjamin Muller and Margaret Li and Chunting Zhou and Lili Yu and Jason Weston and Luke S. Zettlemoyer and Gargi Ghosh and Mike Lewis and Ari Holtzman and Srinivasan Iyer},
+    journal = {ArXiv},
+    year    = {2024},
+    volume  = {abs/2412.09871},
+    url     = {https://api.semanticscholar.org/CorpusID:274762821}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.1.36.dist-info → x_transformers-2.2.0.dist-info}/RECORD
CHANGED
@@ -1,16 +1,17 @@
-x_transformers/__init__.py,sha256=
+x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
+x_transformers/entropy_based_tokenizer.py,sha256=s56Mfok-ulYTKiYmn06feB0QU91y4SQl1Pgj7W5EO3o,2518
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=twoqq2kfVWxntitHKLs2sxFMhK1CPLxGGBDAmkiHXcM,111812
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.
-x_transformers-2.
-x_transformers-2.
-x_transformers-2.
+x_transformers-2.2.0.dist-info/METADATA,sha256=oY5xaR9Xw3prT7RPBw3urS2PKWb7rJADn6SZzRW5Tnw,88686
+x_transformers-2.2.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.2.0.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.2.0.dist-info/RECORD,,
{x_transformers-2.1.36.dist-info → x_transformers-2.2.0.dist-info}/WHEEL
File without changes
{x_transformers-2.1.36.dist-info → x_transformers-2.2.0.dist-info}/licenses/LICENSE
File without changes