x-transformers 2.1.37__py3-none-any.whl → 2.2.1__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
--- x_transformers/__init__.py
+++ x_transformers/__init__.py
@@ -37,3 +37,5 @@ from x_transformers.dpo import (
 from x_transformers.neo_mlp import (
     NeoMLP
 )
+
+from x_transformers.entropy_based_tokenizer import EntropyBasedTokenizer
--- /dev/null
+++ x_transformers/entropy_based_tokenizer.py
@@ -0,0 +1,91 @@
+import torch
+import torch.nn.functional as F
+from torch.nn import Module
+from torch.nn.utils.rnn import pad_sequence
+
+from x_transformers.x_transformers import Decoder, TransformerWrapper
+
+from einops import repeat, rearrange
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+# entropy based tokenizer applied in byte-latent transformer paper
+# they use a simple entropy threshold for segmenting a string into variable sized tokens
+
+# https://arxiv.org/abs/2412.09871
+
+class EntropyBasedTokenizer(Module):
+    def __init__(
+        self,
+        decoder: TransformerWrapper,
+        entropy_threshold: float
+    ):
+        super().__init__()
+        assert isinstance(decoder.attn_layers, Decoder)
+
+        self.decoder = decoder
+        self.entropy_threshold = entropy_threshold
+
+    @torch.no_grad()
+    def forward(
+        self,
+        seq,
+        return_segmented_seq = False
+    ):
+        self.decoder.eval()
+
+        batch, seq_len, device = *seq.shape, seq.device
+
+        _, intermediates = self.decoder(seq, return_logit_entropies = True)
+
+        entropies = intermediates.logit_entropies
+
+        over_thres_mask = entropies >= self.entropy_threshold
+
+        arange = torch.arange(seq_len, device = device) + 1
+        arange = repeat(arange, 'n -> b n', b = batch)
+
+        # get a tensor of Int['b num_tokens'] with the token lengths, zero padded
+
+        boundaries = over_thres_mask.clone()
+        boundaries[..., -1] = True # last token is always a boundary
+
+        num_tokens = boundaries.sum(dim = -1) # number of tokens
+
+        boundaries = arange[boundaries].split(num_tokens.tolist())
+
+        # get the token lengths
+
+        token_lengths = []
+
+        for one_boundary in boundaries:
+            padded_boundary = F.pad(one_boundary, (1, 0), value = 0.)
+            one_token_lengths = padded_boundary[1:] - padded_boundary[:-1]
+
+            token_lengths.append(one_token_lengths)
+
+        token_lengths = pad_sequence(token_lengths, batch_first = True)
+
+        # early return
+
+        if not return_segmented_seq:
+            return token_lengths
+
+        # segment the sequence based on the token lengths
+
+        segmented_seq = []
+
+        for one_seq, one_token_length in zip(seq, token_lengths):
+
+            one_token_length = one_token_length[one_token_length > 0]
+
+            splitted_seq = one_seq.split(one_token_length.tolist())
+            segmented_seq.append(splitted_seq)
+
+        return segmented_seq
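
The new module leans on the `return_logit_entropies` flag of `TransformerWrapper`'s forward: a segment boundary is any position whose predicted next-token entropy meets the threshold, with the final position always forced to be a boundary. For intuition, with `seq_len = 6` and the threshold crossed at positions 2 and 4 (1-indexed), boundaries fall at 2, 4, and 6, and the successive differences give token lengths `[2, 2, 2]`; `pad_sequence` then zero-pads the per-sequence length vectors into a rectangular batch. Below is a minimal usage sketch, not part of the diff: the decoder configuration and the `entropy_threshold` of 1.5 are illustrative assumptions, and any `TransformerWrapper` whose `attn_layers` is a `Decoder` passes the assert in `__init__`.

```python
import torch

# EntropyBasedTokenizer is re-exported from the package root by the
# __init__.py change in the first hunk above
from x_transformers import TransformerWrapper, Decoder, EntropyBasedTokenizer

# a small byte-level decoder; dim / depth / heads are illustrative, not prescribed
decoder = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

tokenizer = EntropyBasedTokenizer(
    decoder,
    entropy_threshold = 1.5  # hypothetical value; tune per model
)

seq = torch.randint(0, 256, (2, 1024))  # batch of raw byte ids

# zero-padded Int['b num_tokens'] tensor of variable token lengths
token_lengths = tokenizer(seq)

# or recover the variable-length segments themselves,
# one tuple of tensors per sequence in the batch
segments = tokenizer(seq, return_segmented_seq = True)
```

Note the `@torch.no_grad()` decorator and the `decoder.eval()` call in `forward`: segmentation is an inference-time measurement over a frozen decoder, so the tokenizer contributes no gradients of its own.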
--- x_transformers-2.1.37.dist-info/METADATA
+++ x_transformers-2.2.1.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.37
+Version: 2.2.1
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2464,4 +2464,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@article{Pagnoni2024ByteLT,
+    title   = {Byte Latent Transformer: Patches Scale Better Than Tokens},
+    author  = {Artidoro Pagnoni and Ram Pasunuru and Pedro Rodriguez and John Nguyen and Benjamin Muller and Margaret Li and Chunting Zhou and Lili Yu and Jason Weston and Luke S. Zettlemoyer and Gargi Ghosh and Mike Lewis and Ari Holtzman and Srinivasan Iyer},
+    journal = {ArXiv},
+    year    = {2024},
+    volume  = {abs/2412.09871},
+    url     = {https://api.semanticscholar.org/CorpusID:274762821}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
--- x_transformers-2.1.37.dist-info/RECORD
+++ x_transformers-2.2.1.dist-info/RECORD
@@ -1,16 +1,17 @@
-x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
+x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
+x_transformers/entropy_based_tokenizer.py,sha256=2Foh9tBUL55Lu0CcgA8kWzyfHYI7DQHBF3zK_WdQt0o,2519
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
 x_transformers/x_transformers.py,sha256=twoqq2kfVWxntitHKLs2sxFMhK1CPLxGGBDAmkiHXcM,111812
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.37.dist-info/METADATA,sha256=uaCIy-GAGH4OPrYa0mxjJJ-FDtMlMuiIbg1sQPb3BRw,88161
-x_transformers-2.1.37.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.1.37.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.1.37.dist-info/RECORD,,
+x_transformers-2.2.1.dist-info/METADATA,sha256=SEoe84lxCNvvJ8HVRqaLoESrTYRgQ9dc6Y-Po8nhxbg,88686
+x_transformers-2.2.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.2.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.2.1.dist-info/RECORD,,