x-transformers 2.4.0__py3-none-any.whl → 2.4.2__py3-none-any.whl
This diff compares publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- x_transformers/up_wrapper.py +225 -0
- x_transformers/x_transformers.py +1 -1
- {x_transformers-2.4.0.dist-info → x_transformers-2.4.2.dist-info}/METADATA +13 -1
- {x_transformers-2.4.0.dist-info → x_transformers-2.4.2.dist-info}/RECORD +6 -5
- {x_transformers-2.4.0.dist-info → x_transformers-2.4.2.dist-info}/WHEEL +0 -0
- {x_transformers-2.4.0.dist-info → x_transformers-2.4.2.dist-info}/licenses/LICENSE +0 -0
x_transformers/up_wrapper.py
ADDED
@@ -0,0 +1,225 @@
+# https://arxiv.org/abs/2506.20057
+# Peter Bloem
+
+from __future__ import annotations
+from functools import partial
+from random import randrange, uniform
+
+import torch
+from torch import nn, cat, randperm
+from torch.nn import LSTM, Module
+
+from x_transformers.x_transformers import (
+    TransformerWrapper,
+    AutoregressiveWrapper
+)
+
+# functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def divisible_by(num, den):
+    return (num % den) == 0
+
+# random sequences, mixture of random and constant (unsure why constant is needed)
+
+def random_sequences(
+    num_tokens,
+    seq_len,
+    num_samples_random,
+    num_samples_constant,
+    shuffle = True,
+    device = None
+):
+    assert num_samples_random > 0 or num_samples_constant > 0
+
+    rand_seq = torch.randint(0, num_tokens, (num_samples_random, seq_len))
+    const_seq = torch.full((num_samples_constant, seq_len), randrange(num_tokens))
+
+    all_seq = cat((rand_seq, const_seq))
+
+    if exists(device):
+        all_seq = all_seq.to(device)
+
+    if not shuffle:
+        return all_seq
+
+    # shuffle with randperm
+
+    rand_indices = randperm(all_seq.shape[0])
+    return all_seq[rand_indices]
+
+# synthetic data generator
+
+class SyntheticDataGenerator(Module):
+    def __init__(
+        self,
+        dim,
+        num_tokens,
+        max_seq_len = 512,
+        hidden_size = None
+    ):
+        super().__init__()
+
+        self.max_seq_len = max_seq_len
+
+        self.embed = nn.Embedding(num_tokens, dim)
+
+        hidden_size = default(hidden_size, dim)
+        self.lstm = LSTM(dim, hidden_size, batch_first = True)
+
+        self.to_logits = nn.Linear(dim, num_tokens, bias = False)
+
+        self.apply(self.init_)
+
+    @torch.no_grad()
+    def init_(self, m):
+        if isinstance(m, nn.Linear):
+            m.weight *= uniform(0., 1.1) # he scales the lstm weights from 0 to 1.1
+
+    @torch.inference_mode()
+    @torch.compile
+    def generate(
+        self,
+        length,
+        seed = None,
+        condition = None,
+        temperature = 1e-4 # he uses a near greedy temperature
+    ):
+        assert exists(seed) or exists(condition)
+        prefix = [*filter(exists, (seed, condition))]
+        seq_len = self.max_seq_len
+
+        seq = torch.cat(prefix, dim = -1)
+
+        net_input = seq
+        hiddens = None
+
+        for _ in range(length):
+
+            logits, hiddens = self.forward(net_input, hiddens)
+
+            last_logit = logits[:, -1]
+            prob = (last_logit / temperature).softmax(dim = -1)
+
+            sampled = torch.multinomial(prob, 1)
+            net_input = sampled
+
+            seq = torch.cat((seq, sampled), dim = -1)
+
+        return seq[:, -seq_len:]
+
+    def forward(
+        self,
+        input,
+        hiddens = None
+    ):
+
+        tokens = self.embed(input)
+
+        embed, hidden = self.lstm(tokens, hiddens)
+
+        logits = self.to_logits(embed)
+
+        return logits, hidden
+
+# classes
+
+class UniversalPretrainWrapper(Module):
+    def __init__(
+        self,
+        model: TransformerWrapper,
+        data_generator: SyntheticDataGenerator | None = None,
+        buffer_size = None,
+        num_reset = 20,
+        batch_size = 32,
+        seq_len = 512,
+        seed_length = 8
+    ):
+        super().__init__()
+
+        self.model = model
+        self.ar_wrapped = AutoregressiveWrapper(model)
+
+        assert model.attn_layers.causal
+
+        num_tokens = model.num_tokens
+        dim = model.attn_layers.dim
+
+        if not exists(data_generator):
+            data_generator = SyntheticDataGenerator(
+                num_tokens = num_tokens,
+                dim = dim
+            )
+
+        self.seq_len = seq_len
+        self.data_generator = data_generator
+
+        self.seed_length = seed_length
+        self.batch_size = batch_size
+
+        buffer_size = default(buffer_size, batch_size * 20)
+        assert buffer_size > batch_size, f'data buffer size must be greater than batch size'
+
+        assert divisible_by(num_reset, 2)
+        self.num_reset = num_reset
+
+        self.buffer_size = buffer_size
+
+        self.random_sequences_fn = partial(random_sequences, num_tokens, seq_len)
+
+        init_data_buffer = self.random_sequences_fn(buffer_size // 2, buffer_size // 2)
+
+        self.register_buffer('synth_data_buffer', init_data_buffer)
+
+    @property
+    def device(self):
+        return self.synth_data_buffer.device
+
+    def get_rand_sequences_from_buffer(self, size = None):
+        size = default(size, self.batch_size)
+        rand_indices = randperm(self.buffer_size, device = self.device)[:size]
+        return self.synth_data_buffer[rand_indices]
+
+    def forward(self):
+        # following algorithm 1.
+
+        conditions = self.get_rand_sequences_from_buffer()
+
+        # get seeds, which appears to be random sequences with random crops of seed length
+
+        seeds = self.get_rand_sequences_from_buffer()
+
+        seq_arange = torch.arange(self.seed_length)
+        rand_offset = torch.randint(0, self.seq_len - self.seed_length, (self.batch_size,))
+        seq_start_pos = rand_offset[:, None] + seq_arange
+
+        batch_arange = torch.arange(self.batch_size, device = self.device)[:, None]
+        seeds = seeds[batch_arange, seq_start_pos]
+
+        # seed, condition to turing machine
+
+        synthetic_data = self.data_generator.generate(
+            self.seq_len,
+            condition = conditions,
+            seed = seeds
+        )
+
+        # reset
+
+        if self.num_reset > 0:
+            buffer_to_reset = self.get_rand_sequences_from_buffer(self.num_reset)
+
+            with torch.no_grad():
+                reset_sequences = self.random_sequences_fn(self.num_reset // 2, self.num_reset // 2, device = self.device)
+                buffer_to_reset.copy_(reset_sequences)
+
+        # sample yet again according to pseudocode
+
+        data = self.get_rand_sequences_from_buffer()
+
+        return self.ar_wrapped(data)
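The new `UniversalPretrainWrapper` above takes a causal `TransformerWrapper`, drives the LSTM-based `SyntheticDataGenerator`, and returns an autoregressive loss each time it is called, so it can sit directly in a training loop. A minimal sketch of that loop, assuming the standard `TransformerWrapper`/`Decoder` constructor; the model sizes, optimizer, and step count are illustrative, not taken from the package:

```python
# minimal sketch: model sizes, optimizer and step count are assumptions
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.up_wrapper import UniversalPretrainWrapper

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

# each call returns the autoregressive loss of the wrapped transformer
# on sequences drawn from the internal synthetic data buffer
pretrainer = UniversalPretrainWrapper(
    model,
    batch_size = 32,
    seq_len = 512
)

optim = torch.optim.Adam(model.parameters(), lr = 3e-4)

for _ in range(100):
    loss = pretrainer()
    loss.backward()
    optim.step()
    optim.zero_grad()
```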
x_transformers/x_transformers.py
CHANGED
@@ -3263,7 +3263,7 @@ class TransformerWrapper(Module):
 
         # attention pool
 
-        if exists(self.attn_pool):
+        if exists(self.attn_pool) and return_intermediates:
             queries = repeat(self.attn_pool_queries, 'n d -> b n d', b = x.shape[0])
 
             attn_pooled_tokens = self.attn_pool(queries, context = x, context_mask = mask)
{x_transformers-2.4.0.dist-info → x_transformers-2.4.2.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.4.0
+Version: 2.4.2
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2495,4 +2495,16 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@misc{bloem2025universalpretrainingiteratedrandom,
+    title = {Universal pre-training by iterated random computation},
+    author = {Peter Bloem},
+    year = {2025},
+    eprint = {2506.20057},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2506.20057},
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.4.0.dist-info → x_transformers-2.4.2.dist-info}/RECORD
CHANGED
@@ -8,10 +8,11 @@ x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaY
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/
+x_transformers/up_wrapper.py,sha256=8mHA9_U5cTnGNp9Owtr__qQkN9kNsOKQlz6qHHztIdk,5929
+x_transformers/x_transformers.py,sha256=9Fi0HvzpeIJqM6HlAd2M6JqsfjhTN1zEH9iFIimyjS4,117608
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.4.
-x_transformers-2.4.
-x_transformers-2.4.
-x_transformers-2.4.
+x_transformers-2.4.2.dist-info/METADATA,sha256=sr98RaCqCx78Ppt-XTsY2W-FMhNepFeFs8dgDGgbXs4,90223
+x_transformers-2.4.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.4.2.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.4.2.dist-info/RECORD,,
{x_transformers-2.4.0.dist-info → x_transformers-2.4.2.dist-info}/WHEEL
File without changes
{x_transformers-2.4.0.dist-info → x_transformers-2.4.2.dist-info}/licenses/LICENSE
File without changes