x_transformers-2.1.36-py3-none-any.whl → x_transformers-2.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/__init__.py
@@ -37,3 +37,5 @@ from x_transformers.dpo import (
 from x_transformers.neo_mlp import (
     NeoMLP
 )
+
+from x_transformers.entropy_based_tokenizer import EntropyBasedTokenizer
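Since `__init__.py` now re-exports the class, it should be importable from the package root as well as from the submodule — a minimal sketch, assuming x-transformers 2.2.0 is installed:

```python
# top-level re-export added in this release (see the __init__.py hunk above)
from x_transformers import EntropyBasedTokenizer
```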
x_transformers/entropy_based_tokenizer.py (new file)
@@ -0,0 +1,91 @@
+import torch
+import torch.nn.functional as F
+from torch.nn import Module
+from torch.nn.utils.rnn import pad_sequence
+
+from x_transformers.x_transformers import Decoder, TransformerWrapper
+
+from einops import repeat, rearrange
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+# entropy based tokenizer applied in byte-latent transformer paper
+# they use a simple entropy threshold for segmenting a string into variable sized tokens
+
+# https://arxiv.org/abs/2412.09871
+
+class EntropyBasedTokenizer(Module):
+    def __init__(
+        self,
+        decoder: TransformerWrapper,
+        entropy_threshold = 1.5
+    ):
+        super().__init__()
+        assert isinstance(decoder.attn_layers, Decoder)
+
+        self.decoder = decoder
+        self.entropy_threshold = entropy_threshold
+
+    @torch.no_grad()
+    def forward(
+        self,
+        seq,
+        return_segmented_seq = False
+    ):
+        self.decoder.eval()
+
+        batch, seq_len, device = *seq.shape, seq.device
+
+        _, intermediates = self.decoder(seq, return_logit_entropies = True)
+
+        entropies = intermediates.logit_entropies
+
+        over_thres_mask = entropies >= self.entropy_threshold
+
+        arange = torch.arange(seq_len, device = device) + 1
+        arange = repeat(arange, 'n -> b n', b = batch)
+
+        # get a tensor of Int['b num_tokens'] with the token lengths, zero padded
+
+        boundaries = over_thres_mask.clone()
+        boundaries[..., -1] = True # last token is always a boundary
+
+        num_tokens = boundaries.sum(dim = -1) # number of tokens
+
+        boundaries = arange[boundaries].split(num_tokens.tolist())
+
+        # get the token lengths
+
+        token_lengths = []
+
+        for one_boundary in boundaries:
+            padded_boundary = F.pad(one_boundary, (1, 0), value = 0.)
+            one_token_lengths = padded_boundary[1:] - padded_boundary[:-1]
+
+            token_lengths.append(one_token_lengths)
+
+        token_lengths = pad_sequence(token_lengths, batch_first = True)
+
+        # early return
+
+        if not return_segmented_seq:
+            return token_lengths
+
+        # segment the sequence based on the token lengths
+
+        segmented_seq = []
+
+        for one_seq, one_token_length in zip(seq, token_lengths):
+
+            one_token_length = one_token_length[one_token_length > 0]
+
+            splitted_seq = one_seq.split(one_token_length.tolist())
+            segmented_seq.append(splitted_seq)
+
+        return segmented_seq
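For orientation, here is a hedged usage sketch of the module added above: it pairs a byte-level `TransformerWrapper`/`Decoder` (the package's existing constructors) with `EntropyBasedTokenizer`. The model size, sequence length, and `entropy_threshold` below are illustrative choices, not values prescribed by the release.

```python
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.entropy_based_tokenizer import EntropyBasedTokenizer

# a small byte-level decoder; dimensions here are illustrative only
model = TransformerWrapper(
    num_tokens = 256,                  # byte vocabulary
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 4,
        heads = 8
    )
)

tokenizer = EntropyBasedTokenizer(
    model,
    entropy_threshold = 2.5            # assumed value; tune for the trained model
)

seq = torch.randint(0, 256, (2, 256))  # batch of byte ids

# default return: zero-padded token lengths per batch element
token_lengths = tokenizer(seq)

# or the variable-length segments themselves, one tuple of tensors per sequence
segmented = tokenizer(seq, return_segmented_seq = True)
```

In practice the wrapped decoder would be a trained next-byte model, so that high prediction entropy marks genuine patch boundaries as in the byte-latent transformer paper.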
x_transformers/x_transformers.py
@@ -2909,6 +2909,7 @@ class TransformerWrapper(Module):
         return_embeddings = False,
         return_logits_and_embeddings = False,
         return_intermediates = False,
+        return_embeddings_and_intermediates = False,
         return_logit_entropies = False,
         mask = None,
         return_mems = False,
@@ -2940,8 +2941,8 @@ class TransformerWrapper(Module):
 
         b, n, device, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask
 
-        return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
-        return_embeddings = return_embeddings | (not exists(self.to_logits))
+        return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss | return_embeddings_and_intermediates
+        return_embeddings = return_embeddings | (not exists(self.to_logits)) | return_embeddings_and_intermediates
 
         # absolute positional embedding
 
@@ -3131,6 +3132,8 @@ class TransformerWrapper(Module):
 
         if return_logits_and_embeddings:
            out = (logits, x)
+        elif return_embeddings_and_intermediates:
+            out = (x, intermediates)
         elif return_embeddings:
             out = x
         else:
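The `TransformerWrapper` hunks above add a single flag, `return_embeddings_and_intermediates`, which makes the forward pass return the final embeddings together with the intermediates object instead of logits. A minimal sketch, reusing the illustrative toy configuration from the previous example:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 4, heads = 8)
)

x = torch.randint(0, 256, (1, 128))

# new in 2.2.0: skip the logit projection and also receive the intermediates
embeds, intermediates = model(x, return_embeddings_and_intermediates = True)

# embeds has shape (1, 128, 512) for this illustrative config - hidden states, not logits
```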
x_transformers-2.1.36.dist-info/METADATA → x_transformers-2.2.0.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.36
+Version: 2.2.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2464,4 +2464,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@article{Pagnoni2024ByteLT,
+    title = {Byte Latent Transformer: Patches Scale Better Than Tokens},
+    author = {Artidoro Pagnoni and Ram Pasunuru and Pedro Rodriguez and John Nguyen and Benjamin Muller and Margaret Li and Chunting Zhou and Lili Yu and Jason Weston and Luke S. Zettlemoyer and Gargi Ghosh and Mike Lewis and Ari Holtzman and Srinivasan Iyer},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2412.09871},
+    url = {https://api.semanticscholar.org/CorpusID:274762821}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
x_transformers-2.1.36.dist-info/RECORD → x_transformers-2.2.0.dist-info/RECORD
@@ -1,16 +1,17 @@
-x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
+x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
+x_transformers/entropy_based_tokenizer.py,sha256=s56Mfok-ulYTKiYmn06feB0QU91y4SQl1Pgj7W5EO3o,2518
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=voN-uEBEKxpUu9K4MVcneSTrzdgJWnZGuQ1QRZQw4Q4,111596
+x_transformers/x_transformers.py,sha256=twoqq2kfVWxntitHKLs2sxFMhK1CPLxGGBDAmkiHXcM,111812
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.36.dist-info/METADATA,sha256=D0qdMRucK3PWwEi8WwdiJdZ8X_hGTm1r3_7bJzYiWSM,88161
-x_transformers-2.1.36.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.1.36.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.1.36.dist-info/RECORD,,
+x_transformers-2.2.0.dist-info/METADATA,sha256=oY5xaR9Xw3prT7RPBw3urS2PKWb7rJADn6SZzRW5Tnw,88686
+x_transformers-2.2.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.2.0.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.2.0.dist-info/RECORD,,