x-transformers 2.4.0__tar.gz → 2.4.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-2.4.0 → x_transformers-2.4.2}/PKG-INFO +13 -1
- {x_transformers-2.4.0 → x_transformers-2.4.2}/README.md +12 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/pyproject.toml +1 -1
- {x_transformers-2.4.0 → x_transformers-2.4.2}/tests/test_x_transformers.py +20 -0
- x_transformers-2.4.2/x_transformers/up_wrapper.py +225 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/x_transformers.py +1 -1
- {x_transformers-2.4.0 → x_transformers-2.4.2}/.github/FUNDING.yml +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/.gitignore +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/LICENSE +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/data/README.md +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/data/enwik8.gz +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/all-attention.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/attention-on-attention.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/deepnorm.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/fcm.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/ffglu.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/flash-attention.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/gate_values.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/gating.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/macaron-1.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/macaron-2.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/memory-transformer.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/normformer.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/pia.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/resi_dual.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/residual_attn.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/rezero.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/rotary.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/sandwich-2.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/sandwich.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/sandwich_norm.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/scalenorm.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/talking-heads.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/topk-attention.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/images/xval.png +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/train_belief_state.py +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/train_copy.py +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/train_enwik8.py +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/train_length_extrapolate.py +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/train_parity.py +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/__init__.py +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/attend.py +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/continuous.py +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/dpo.py +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/xval.py +0 -0
{x_transformers-2.4.0 → x_transformers-2.4.2}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.4.0
+Version: 2.4.2
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2495,4 +2495,16 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@misc{bloem2025universalpretrainingiteratedrandom,
+    title = {Universal pre-training by iterated random computation},
+    author = {Peter Bloem},
+    year = {2025},
+    eprint = {2506.20057},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2506.20057},
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.4.0 → x_transformers-2.4.2}/README.md

@@ -2447,4 +2447,16 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@misc{bloem2025universalpretrainingiteratedrandom,
+    title = {Universal pre-training by iterated random computation},
+    author = {Peter Bloem},
+    year = {2025},
+    eprint = {2506.20057},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2506.20057},
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.4.0 → x_transformers-2.4.2}/tests/test_x_transformers.py

@@ -1099,3 +1099,23 @@ def add_attn_pool():
     logits, intermediates = model(x, return_intermediates = True)
 
     assert intermediates.attn_pooled_tokens.shape[1] == 3
+
+def test_up():
+    from x_transformers.up_wrapper import UniversalPretrainWrapper
+
+    model = TransformerWrapper(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        attn_pool = True,
+        num_attn_pool_queries = 3,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 12,
+            heads = 8
+        ),
+    )
+
+    up_wrapper = UniversalPretrainWrapper(model, seq_len = 16)
+
+    loss = up_wrapper()
+    loss.backward()
x_transformers-2.4.2/x_transformers/up_wrapper.py

@@ -0,0 +1,225 @@
+# https://arxiv.org/abs/2506.20057
+# Peter Bloem
+
+from __future__ import annotations
+from functools import partial
+from random import randrange, uniform
+
+import torch
+from torch import nn, cat, randperm
+from torch.nn import LSTM, Module
+
+from x_transformers.x_transformers import (
+    TransformerWrapper,
+    AutoregressiveWrapper
+)
+
+# functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def divisible_by(num, den):
+    return (num % den) == 0
+
+# random sequences, mixture of random and constant (unsure why constant is needed)
+
+def random_sequences(
+    num_tokens,
+    seq_len,
+    num_samples_random,
+    num_samples_constant,
+    shuffle = True,
+    device = None
+):
+    assert num_samples_random > 0 or num_samples_constant > 0
+
+    rand_seq = torch.randint(0, num_tokens, (num_samples_random, seq_len))
+    const_seq = torch.full((num_samples_constant, seq_len), randrange(num_tokens))
+
+    all_seq = cat((rand_seq, const_seq))
+
+    if exists(device):
+        all_seq = all_seq.to(device)
+
+    if not shuffle:
+        return all_seq
+
+    # shuffle with randperm
+
+    rand_indices = randperm(all_seq.shape[0])
+    return all_seq[rand_indices]
+
+# synthetic data generator
+
+class SyntheticDataGenerator(Module):
+    def __init__(
+        self,
+        dim,
+        num_tokens,
+        max_seq_len = 512,
+        hidden_size = None
+    ):
+        super().__init__()
+
+        self.max_seq_len = max_seq_len
+
+        self.embed = nn.Embedding(num_tokens, dim)
+
+        hidden_size = default(hidden_size, dim)
+        self.lstm = LSTM(dim, hidden_size, batch_first = True)
+
+        self.to_logits = nn.Linear(dim, num_tokens, bias = False)
+
+        self.apply(self.init_)
+
+    @torch.no_grad()
+    def init_(self, m):
+        if isinstance(m, nn.Linear):
+            m.weight *= uniform(0., 1.1) # he scales the lstm weights from 0 to 1.1
+
+    @torch.inference_mode()
+    @torch.compile
+    def generate(
+        self,
+        length,
+        seed = None,
+        condition = None,
+        temperature = 1e-4 # he uses a near greedy temperature
+    ):
+        assert exists(seed) or exists(condition)
+        prefix = [*filter(exists, (seed, condition))]
+        seq_len = self.max_seq_len
+
+        seq = torch.cat(prefix, dim = -1)
+
+        net_input = seq
+        hiddens = None
+
+        for _ in range(length):
+
+            logits, hiddens = self.forward(net_input, hiddens)
+
+            last_logit = logits[:, -1]
+            prob = (last_logit / temperature).softmax(dim = -1)
+
+            sampled = torch.multinomial(prob, 1)
+            net_input = sampled
+
+            seq = torch.cat((seq, sampled), dim = -1)
+
+        return seq[:, -seq_len:]
+
+    def forward(
+        self,
+        input,
+        hiddens = None
+    ):
+
+        tokens = self.embed(input)
+
+        embed, hidden = self.lstm(tokens, hiddens)
+
+        logits = self.to_logits(embed)
+
+        return logits, hidden
+
+# classes
+
+class UniversalPretrainWrapper(Module):
+    def __init__(
+        self,
+        model: TransformerWrapper,
+        data_generator: SyntheticDataGenerator | None = None,
+        buffer_size = None,
+        num_reset = 20,
+        batch_size = 32,
+        seq_len = 512,
+        seed_length = 8
+    ):
+        super().__init__()
+
+        self.model = model
+        self.ar_wrapped = AutoregressiveWrapper(model)
+
+        assert model.attn_layers.causal
+
+        num_tokens = model.num_tokens
+        dim = model.attn_layers.dim
+
+        if not exists(data_generator):
+            data_generator = SyntheticDataGenerator(
+                num_tokens = num_tokens,
+                dim = dim
+            )
+
+        self.seq_len = seq_len
+        self.data_generator = data_generator
+
+        self.seed_length = seed_length
+        self.batch_size = batch_size
+
+        buffer_size = default(buffer_size, batch_size * 20)
+        assert buffer_size > batch_size, f'data buffer size must be greater than batch size'
+
+        assert divisible_by(num_reset, 2)
+        self.num_reset = num_reset
+
+        self.buffer_size = buffer_size
+
+        self.random_sequences_fn = partial(random_sequences, num_tokens, seq_len)
+
+        init_data_buffer = self.random_sequences_fn(buffer_size // 2, buffer_size // 2)
+
+        self.register_buffer('synth_data_buffer', init_data_buffer)
+
+    @property
+    def device(self):
+        return self.synth_data_buffer.device
+
+    def get_rand_sequences_from_buffer(self, size = None):
+        size = default(size, self.batch_size)
+        rand_indices = randperm(self.buffer_size, device = self.device)[:size]
+        return self.synth_data_buffer[rand_indices]
+
+    def forward(self):
+        # following algorithm 1.
+
+        conditions = self.get_rand_sequences_from_buffer()
+
+        # get seeds, which appears to be random sequences with random crops of seed length
+
+        seeds = self.get_rand_sequences_from_buffer()
+
+        seq_arange = torch.arange(self.seed_length)
+        rand_offset = torch.randint(0, self.seq_len - self.seed_length, (self.batch_size,))
+        seq_start_pos = rand_offset[:, None] + seq_arange
+
+        batch_arange = torch.arange(self.batch_size, device = self.device)[:, None]
+        seeds = seeds[batch_arange, seq_start_pos]
+
+        # seed, condition to turing machine
+
+        synthetic_data = self.data_generator.generate(
+            self.seq_len,
+            condition = conditions,
+            seed = seeds
+        )
+
+        # reset
+
+        if self.num_reset > 0:
+            buffer_to_reset = self.get_rand_sequences_from_buffer(self.num_reset)
+
+            with torch.no_grad():
+                reset_sequences = self.random_sequences_fn(self.num_reset // 2, self.num_reset // 2, device = self.device)
+                buffer_to_reset.copy_(reset_sequences)
+
+        # sample yet again according to pseudocode
+
+        data = self.get_rand_sequences_from_buffer()
+
+        return self.ar_wrapped(data)
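The new `UniversalPretrainWrapper` maintains its own buffer of LSTM-generated synthetic sequences and returns an autoregressive loss each time it is called, as exercised by `test_up` above. Below is a minimal training-loop sketch built on that same usage; the optimizer, learning rate, step count, and model depth are illustrative assumptions, not part of the release.

```python
# Minimal sketch, assuming x-transformers >= 2.4.2; the optimizer and
# hyperparameters below are illustrative, not taken from the package.
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.up_wrapper import UniversalPretrainWrapper

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

# the wrapper keeps its own buffer of synthetic sequences, so no dataset
# is passed in - each call returns a language modeling loss
up_wrapper = UniversalPretrainWrapper(model, seq_len = 256, batch_size = 16)

optim = torch.optim.Adam(model.parameters(), lr = 3e-4)

for _ in range(100):
    loss = up_wrapper()
    loss.backward()

    optim.step()
    optim.zero_grad()
```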
{x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/x_transformers.py

@@ -3263,7 +3263,7 @@ class TransformerWrapper(Module):
 
         # attention pool
 
-        if exists(self.attn_pool):
+        if exists(self.attn_pool) and return_intermediates:
            queries = repeat(self.attn_pool_queries, 'n d -> b n d', b = x.shape[0])
 
            attn_pooled_tokens = self.attn_pool(queries, context = x, context_mask = mask)
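With the change above, the attention-pooled tokens are computed only on the `return_intermediates = True` path and are read off the returned intermediates, as in the existing test. A minimal sketch follows; the model configuration mirrors the test, with the depth reduced here purely for brevity.

```python
# Minimal sketch, mirroring the attn_pool assertion in tests/test_x_transformers.py;
# depth is reduced here only to keep the example light.
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_pool = True,
    num_attn_pool_queries = 3,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8)
)

x = torch.randint(0, 256, (1, 1024))

# pooled tokens are only produced when intermediates are requested
logits, intermediates = model(x, return_intermediates = True)
assert intermediates.attn_pooled_tokens.shape[1] == 3
```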