x-transformers 2.4.0__py3-none-any.whl → 2.4.2__py3-none-any.whl

This diff shows the changes between these publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
@@ -0,0 +1,225 @@
+ # https://arxiv.org/abs/2506.20057
+ # Peter Bloem
+
+ from __future__ import annotations
+ from functools import partial
+ from random import randrange, uniform
+
+ import torch
+ from torch import nn, cat, randperm
+ from torch.nn import LSTM, Module
+
+ from x_transformers.x_transformers import (
+     TransformerWrapper,
+     AutoregressiveWrapper
+ )
+
+ # functions
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ def divisible_by(num, den):
+     return (num % den) == 0
+
+ # random sequences, mixture of random and constant (unsure why constant is needed)
+
+ def random_sequences(
+     num_tokens,
+     seq_len,
+     num_samples_random,
+     num_samples_constant,
+     shuffle = True,
+     device = None
+ ):
+     assert num_samples_random > 0 or num_samples_constant > 0
+
+     rand_seq = torch.randint(0, num_tokens, (num_samples_random, seq_len))
+     const_seq = torch.full((num_samples_constant, seq_len), randrange(num_tokens))
+
+     all_seq = cat((rand_seq, const_seq))
+
+     if exists(device):
+         all_seq = all_seq.to(device)
+
+     if not shuffle:
+         return all_seq
+
+     # shuffle with randperm
+
+     rand_indices = randperm(all_seq.shape[0])
+     return all_seq[rand_indices]
+
+ # synthetic data generator
+
+ class SyntheticDataGenerator(Module):
+     def __init__(
+         self,
+         dim,
+         num_tokens,
+         max_seq_len = 512,
+         hidden_size = None
+     ):
+         super().__init__()
+
+         self.max_seq_len = max_seq_len
+
+         self.embed = nn.Embedding(num_tokens, dim)
+
+         hidden_size = default(hidden_size, dim)
+         self.lstm = LSTM(dim, hidden_size, batch_first = True)
+
+         self.to_logits = nn.Linear(dim, num_tokens, bias = False)
+
+         self.apply(self.init_)
+
+     @torch.no_grad()
+     def init_(self, m):
+         if isinstance(m, nn.Linear):
+             m.weight *= uniform(0., 1.1) # he scales the lstm weights from 0 to 1.1
+
+     @torch.inference_mode()
+     @torch.compile
+     def generate(
+         self,
+         length,
+         seed = None,
+         condition = None,
+         temperature = 1e-4 # he uses a near greedy temperature
+     ):
+         assert exists(seed) or exists(condition)
+         prefix = [*filter(exists, (seed, condition))]
+         seq_len = self.max_seq_len
+
+         seq = torch.cat(prefix, dim = -1)
+
+         net_input = seq
+         hiddens = None
+
+         for _ in range(length):
+
+             logits, hiddens = self.forward(net_input, hiddens)
+
+             last_logit = logits[:, -1]
+             prob = (last_logit / temperature).softmax(dim = -1)
+
+             sampled = torch.multinomial(prob, 1)
+             net_input = sampled
+
+             seq = torch.cat((seq, sampled), dim = -1)
+
+         return seq[:, -seq_len:]
+
+     def forward(
+         self,
+         input,
+         hiddens = None
+     ):
+
+         tokens = self.embed(input)
+
+         embed, hidden = self.lstm(tokens, hiddens)
+
+         logits = self.to_logits(embed)
+
+         return logits, hidden
+
+ # classes
+
+ class UniversalPretrainWrapper(Module):
+     def __init__(
+         self,
+         model: TransformerWrapper,
+         data_generator: SyntheticDataGenerator | None = None,
+         buffer_size = None,
+         num_reset = 20,
+         batch_size = 32,
+         seq_len = 512,
+         seed_length = 8
+     ):
+         super().__init__()
+
+         self.model = model
+         self.ar_wrapped = AutoregressiveWrapper(model)
+
+         assert model.attn_layers.causal
+
+         num_tokens = model.num_tokens
+         dim = model.attn_layers.dim
+
+         if not exists(data_generator):
+             data_generator = SyntheticDataGenerator(
+                 num_tokens = num_tokens,
+                 dim = dim
+             )
+
+         self.seq_len = seq_len
+         self.data_generator = data_generator
+
+         self.seed_length = seed_length
+         self.batch_size = batch_size
+
+         buffer_size = default(buffer_size, batch_size * 20)
+         assert buffer_size > batch_size, f'data buffer size must be greater than batch size'
+
+         assert divisible_by(num_reset, 2)
+         self.num_reset = num_reset
+
+         self.buffer_size = buffer_size
+
+         self.random_sequences_fn = partial(random_sequences, num_tokens, seq_len)
+
+         init_data_buffer = self.random_sequences_fn(buffer_size // 2, buffer_size // 2)
+
+         self.register_buffer('synth_data_buffer', init_data_buffer)
+
+     @property
+     def device(self):
+         return self.synth_data_buffer.device
+
+     def get_rand_sequences_from_buffer(self, size = None):
+         size = default(size, self.batch_size)
+         rand_indices = randperm(self.buffer_size, device = self.device)[:size]
+         return self.synth_data_buffer[rand_indices]
+
+     def forward(self):
+         # following algorithm 1.
+
+         conditions = self.get_rand_sequences_from_buffer()
+
+         # get seeds, which appears to be random sequences with random crops of seed length
+
+         seeds = self.get_rand_sequences_from_buffer()
+
+         seq_arange = torch.arange(self.seed_length)
+         rand_offset = torch.randint(0, self.seq_len - self.seed_length, (self.batch_size,))
+         seq_start_pos = rand_offset[:, None] + seq_arange
+
+         batch_arange = torch.arange(self.batch_size, device = self.device)[:, None]
+         seeds = seeds[batch_arange, seq_start_pos]
+
+         # seed, condition to turing machine
+
+         synthetic_data = self.data_generator.generate(
+             self.seq_len,
+             condition = conditions,
+             seed = seeds
+         )
+
+         # reset
+
+         if self.num_reset > 0:
+             buffer_to_reset = self.get_rand_sequences_from_buffer(self.num_reset)
+
+             with torch.no_grad():
+                 reset_sequences = self.random_sequences_fn(self.num_reset // 2, self.num_reset // 2, device = self.device)
+                 buffer_to_reset.copy_(reset_sequences)
+
+         # sample yet again according to pseudocode
+
+         data = self.get_rand_sequences_from_buffer()
+
+         return self.ar_wrapped(data)
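For orientation, here is a minimal sketch of how the new `up_wrapper.py` module might be driven for one training step. Only the module path and class names come from the file added above; the model size, optimizer, and hyperparameters below are illustrative assumptions, not values taken from this diff.

```python
# Hypothetical usage sketch (not part of the package diff); hyperparameters are assumptions.
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.up_wrapper import UniversalPretrainWrapper

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)  # decoder is causal, satisfying the assert in __init__
)

# wraps the model, builds the LSTM-based SyntheticDataGenerator, and initializes the sequence buffer
up_wrapper = UniversalPretrainWrapper(model, batch_size = 32, seq_len = 512)

optim = torch.optim.Adam(model.parameters(), lr = 3e-4)

# each call samples conditions and seeds from the buffer, generates synthetic sequences with the
# LSTM generator, refreshes `num_reset` buffer entries, and returns the autoregressive loss on a
# freshly sampled batch
loss = up_wrapper()
loss.backward()
optim.step()
optim.zero_grad()
```

Each call to `up_wrapper()` corresponds to one pass of the buffer-based pseudocode referenced in the file's own comments (Algorithm 1 of the paper linked at the top of the module).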
@@ -3263,7 +3263,7 @@ class TransformerWrapper(Module):
  
          # attention pool
  
-         if exists(self.attn_pool):
+         if exists(self.attn_pool) and return_intermediates:
              queries = repeat(self.attn_pool_queries, 'n d -> b n d', b = x.shape[0])
  
              attn_pooled_tokens = self.attn_pool(queries, context = x, context_mask = mask)
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: x-transformers
- Version: 2.4.0
+ Version: 2.4.2
  Summary: X-Transformers
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2495,4 +2495,16 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```
  
+ ```bibtex
+ @misc{bloem2025universalpretrainingiteratedrandom,
+     title   = {Universal pre-training by iterated random computation},
+     author  = {Peter Bloem},
+     year    = {2025},
+     eprint  = {2506.20057},
+     archivePrefix = {arXiv},
+     primaryClass = {cs.LG},
+     url     = {https://arxiv.org/abs/2506.20057},
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
@@ -8,10 +8,11 @@ x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaY
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
- x_transformers/x_transformers.py,sha256=IelVhLUuDmRnv6zXlQNvwUluW2RqVQQE2vYKCqctJyY,117583
+ x_transformers/up_wrapper.py,sha256=8mHA9_U5cTnGNp9Owtr__qQkN9kNsOKQlz6qHHztIdk,5929
+ x_transformers/x_transformers.py,sha256=9Fi0HvzpeIJqM6HlAd2M6JqsfjhTN1zEH9iFIimyjS4,117608
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
  x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
- x_transformers-2.4.0.dist-info/METADATA,sha256=RyKkjmTnjbGUHA4EL-znJCPR17VF6i7ebvvgMKpTXVY,89896
- x_transformers-2.4.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- x_transformers-2.4.0.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
- x_transformers-2.4.0.dist-info/RECORD,,
+ x_transformers-2.4.2.dist-info/METADATA,sha256=sr98RaCqCx78Ppt-XTsY2W-FMhNepFeFs8dgDGgbXs4,90223
+ x_transformers-2.4.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ x_transformers-2.4.2.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-2.4.2.dist-info/RECORD,,