x-transformers 2.4.0.tar.gz → 2.4.2.tar.gz

This diff reflects the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (63)
  1. {x_transformers-2.4.0 → x_transformers-2.4.2}/PKG-INFO +13 -1
  2. {x_transformers-2.4.0 → x_transformers-2.4.2}/README.md +12 -0
  3. {x_transformers-2.4.0 → x_transformers-2.4.2}/pyproject.toml +1 -1
  4. {x_transformers-2.4.0 → x_transformers-2.4.2}/tests/test_x_transformers.py +20 -0
  5. x_transformers-2.4.2/x_transformers/up_wrapper.py +225 -0
  6. {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/x_transformers.py +1 -1
  7. {x_transformers-2.4.0 → x_transformers-2.4.2}/.github/FUNDING.yml +0 -0
  8. {x_transformers-2.4.0 → x_transformers-2.4.2}/.github/workflows/python-publish.yml +0 -0
  9. {x_transformers-2.4.0 → x_transformers-2.4.2}/.github/workflows/python-test.yaml +0 -0
  10. {x_transformers-2.4.0 → x_transformers-2.4.2}/.gitignore +0 -0
  11. {x_transformers-2.4.0 → x_transformers-2.4.2}/LICENSE +0 -0
  12. {x_transformers-2.4.0 → x_transformers-2.4.2}/data/README.md +0 -0
  13. {x_transformers-2.4.0 → x_transformers-2.4.2}/data/enwik8.gz +0 -0
  14. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/all-attention.png +0 -0
  15. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/deepnorm.png +0 -0
  18. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/fcm.png +0 -0
  24. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/ffglu.png +0 -0
  25. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/flash-attention.png +0 -0
  26. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/gate_values.png +0 -0
  27. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/gating.png +0 -0
  28. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/macaron-1.png +0 -0
  30. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/macaron-2.png +0 -0
  31. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/normformer.png +0 -0
  33. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/pia.png +0 -0
  34. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/resi_dual.png +0 -0
  36. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/residual_attn.png +0 -0
  37. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/rezero.png +0 -0
  38. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/rotary.png +0 -0
  39. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/sandwich.png +0 -0
  41. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/scalenorm.png +0 -0
  43. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/talking-heads.png +0 -0
  44. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/topk-attention.png +0 -0
  45. {x_transformers-2.4.0 → x_transformers-2.4.2}/images/xval.png +0 -0
  46. {x_transformers-2.4.0 → x_transformers-2.4.2}/train_belief_state.py +0 -0
  47. {x_transformers-2.4.0 → x_transformers-2.4.2}/train_copy.py +0 -0
  48. {x_transformers-2.4.0 → x_transformers-2.4.2}/train_entropy_tokenizer.py +0 -0
  49. {x_transformers-2.4.0 → x_transformers-2.4.2}/train_enwik8.py +0 -0
  50. {x_transformers-2.4.0 → x_transformers-2.4.2}/train_length_extrapolate.py +0 -0
  51. {x_transformers-2.4.0 → x_transformers-2.4.2}/train_parity.py +0 -0
  52. {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/__init__.py +0 -0
  53. {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/attend.py +0 -0
  54. {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/autoregressive_wrapper.py +0 -0
  55. {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/belief_state_wrapper.py +0 -0
  56. {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/continuous.py +0 -0
  57. {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/dpo.py +0 -0
  58. {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/entropy_based_tokenizer.py +0 -0
  59. {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/multi_input.py +0 -0
  60. {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/neo_mlp.py +0 -0
  61. {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/nonautoregressive_wrapper.py +0 -0
  62. {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  63. {x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/xval.py +0 -0
{x_transformers-2.4.0 → x_transformers-2.4.2}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: x-transformers
- Version: 2.4.0
+ Version: 2.4.2
  Summary: X-Transformers
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2495,4 +2495,16 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```

+ ```bibtex
+ @misc{bloem2025universalpretrainingiteratedrandom,
+     title = {Universal pre-training by iterated random computation},
+     author = {Peter Bloem},
+     year = {2025},
+     eprint = {2506.20057},
+     archivePrefix = {arXiv},
+     primaryClass = {cs.LG},
+     url = {https://arxiv.org/abs/2506.20057},
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.4.0 → x_transformers-2.4.2}/README.md
@@ -2447,4 +2447,16 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```

+ ```bibtex
+ @misc{bloem2025universalpretrainingiteratedrandom,
+     title = {Universal pre-training by iterated random computation},
+     author = {Peter Bloem},
+     year = {2025},
+     eprint = {2506.20057},
+     archivePrefix = {arXiv},
+     primaryClass = {cs.LG},
+     url = {https://arxiv.org/abs/2506.20057},
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.4.0 → x_transformers-2.4.2}/pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "x-transformers"
- version = "2.4.0"
+ version = "2.4.2"
  description = "X-Transformers"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.4.0 → x_transformers-2.4.2}/tests/test_x_transformers.py
@@ -1099,3 +1099,23 @@ def add_attn_pool():
      logits, intermediates = model(x, return_intermediates = True)

      assert intermediates.attn_pooled_tokens.shape[1] == 3
+
+ def test_up():
+     from x_transformers.up_wrapper import UniversalPretrainWrapper
+
+     model = TransformerWrapper(
+         num_tokens = 256,
+         max_seq_len = 1024,
+         attn_pool = True,
+         num_attn_pool_queries = 3,
+         attn_layers = Decoder(
+             dim = 512,
+             depth = 12,
+             heads = 8
+         ),
+     )
+
+     up_wrapper = UniversalPretrainWrapper(model, seq_len = 16)
+
+     loss = up_wrapper()
+     loss.backward()
x_transformers-2.4.2/x_transformers/up_wrapper.py
@@ -0,0 +1,225 @@
+ # https://arxiv.org/abs/2506.20057
+ # Peter Bloem
+
+ from __future__ import annotations
+ from functools import partial
+ from random import randrange, uniform
+
+ import torch
+ from torch import nn, cat, randperm
+ from torch.nn import LSTM, Module
+
+ from x_transformers.x_transformers import (
+     TransformerWrapper,
+     AutoregressiveWrapper
+ )
+
+ # functions
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ def divisible_by(num, den):
+     return (num % den) == 0
+
+ # random sequences, mixture of random and constant (unsure why constant is needed)
+
+ def random_sequences(
+     num_tokens,
+     seq_len,
+     num_samples_random,
+     num_samples_constant,
+     shuffle = True,
+     device = None
+ ):
+     assert num_samples_random > 0 or num_samples_constant > 0
+
+     rand_seq = torch.randint(0, num_tokens, (num_samples_random, seq_len))
+     const_seq = torch.full((num_samples_constant, seq_len), randrange(num_tokens))
+
+     all_seq = cat((rand_seq, const_seq))
+
+     if exists(device):
+         all_seq = all_seq.to(device)
+
+     if not shuffle:
+         return all_seq
+
+     # shuffle with randperm
+
+     rand_indices = randperm(all_seq.shape[0])
+     return all_seq[rand_indices]
+
+ # synthetic data generator
+
+ class SyntheticDataGenerator(Module):
+     def __init__(
+         self,
+         dim,
+         num_tokens,
+         max_seq_len = 512,
+         hidden_size = None
+     ):
+         super().__init__()
+
+         self.max_seq_len = max_seq_len
+
+         self.embed = nn.Embedding(num_tokens, dim)
+
+         hidden_size = default(hidden_size, dim)
+         self.lstm = LSTM(dim, hidden_size, batch_first = True)
+
+         self.to_logits = nn.Linear(dim, num_tokens, bias = False)
+
+         self.apply(self.init_)
+
+     @torch.no_grad()
+     def init_(self, m):
+         if isinstance(m, nn.Linear):
+             m.weight *= uniform(0., 1.1) # he scales the lstm weights from 0 to 1.1
+
+     @torch.inference_mode()
+     @torch.compile
+     def generate(
+         self,
+         length,
+         seed = None,
+         condition = None,
+         temperature = 1e-4 # he uses a near greedy temperature
+     ):
+         assert exists(seed) or exists(condition)
+         prefix = [*filter(exists, (seed, condition))]
+         seq_len = self.max_seq_len
+
+         seq = torch.cat(prefix, dim = -1)
+
+         net_input = seq
+         hiddens = None
+
+         for _ in range(length):
+
+             logits, hiddens = self.forward(net_input, hiddens)
+
+             last_logit = logits[:, -1]
+             prob = (last_logit / temperature).softmax(dim = -1)
+
+             sampled = torch.multinomial(prob, 1)
+             net_input = sampled
+
+             seq = torch.cat((seq, sampled), dim = -1)
+
+         return seq[:, -seq_len:]
+
+     def forward(
+         self,
+         input,
+         hiddens = None
+     ):
+
+         tokens = self.embed(input)
+
+         embed, hidden = self.lstm(tokens, hiddens)
+
+         logits = self.to_logits(embed)
+
+         return logits, hidden
+
+ # classes
+
+ class UniversalPretrainWrapper(Module):
+     def __init__(
+         self,
+         model: TransformerWrapper,
+         data_generator: SyntheticDataGenerator | None = None,
+         buffer_size = None,
+         num_reset = 20,
+         batch_size = 32,
+         seq_len = 512,
+         seed_length = 8
+     ):
+         super().__init__()
+
+         self.model = model
+         self.ar_wrapped = AutoregressiveWrapper(model)
+
+         assert model.attn_layers.causal
+
+         num_tokens = model.num_tokens
+         dim = model.attn_layers.dim
+
+         if not exists(data_generator):
+             data_generator = SyntheticDataGenerator(
+                 num_tokens = num_tokens,
+                 dim = dim
+             )
+
+         self.seq_len = seq_len
+         self.data_generator = data_generator
+
+         self.seed_length = seed_length
+         self.batch_size = batch_size
+
+         buffer_size = default(buffer_size, batch_size * 20)
+         assert buffer_size > batch_size, f'data buffer size must be greater than batch size'
+
+         assert divisible_by(num_reset, 2)
+         self.num_reset = num_reset
+
+         self.buffer_size = buffer_size
+
+         self.random_sequences_fn = partial(random_sequences, num_tokens, seq_len)
+
+         init_data_buffer = self.random_sequences_fn(buffer_size // 2, buffer_size // 2)
+
+         self.register_buffer('synth_data_buffer', init_data_buffer)
+
+     @property
+     def device(self):
+         return self.synth_data_buffer.device
+
+     def get_rand_sequences_from_buffer(self, size = None):
+         size = default(size, self.batch_size)
+         rand_indices = randperm(self.buffer_size, device = self.device)[:size]
+         return self.synth_data_buffer[rand_indices]
+
+     def forward(self):
+         # following algorithm 1.
+
+         conditions = self.get_rand_sequences_from_buffer()
+
+         # get seeds, which appears to be random sequences with random crops of seed length
+
+         seeds = self.get_rand_sequences_from_buffer()
+
+         seq_arange = torch.arange(self.seed_length)
+         rand_offset = torch.randint(0, self.seq_len - self.seed_length, (self.batch_size,))
+         seq_start_pos = rand_offset[:, None] + seq_arange
+
+         batch_arange = torch.arange(self.batch_size, device = self.device)[:, None]
+         seeds = seeds[batch_arange, seq_start_pos]
+
+         # seed, condition to turing machine
+
+         synthetic_data = self.data_generator.generate(
+             self.seq_len,
+             condition = conditions,
+             seed = seeds
+         )
+
+         # reset
+
+         if self.num_reset > 0:
+             buffer_to_reset = self.get_rand_sequences_from_buffer(self.num_reset)
+
+             with torch.no_grad():
+                 reset_sequences = self.random_sequences_fn(self.num_reset // 2, self.num_reset // 2, device = self.device)
+                 buffer_to_reset.copy_(reset_sequences)
+
+         # sample yet again according to pseudocode
+
+         data = self.get_rand_sequences_from_buffer()
+
+         return self.ar_wrapped(data)
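The headline addition in 2.4.2 is `x_transformers/up_wrapper.py`, implementing Peter Bloem's "Universal pre-training by iterated random computation" (arXiv:2506.20057): a randomly initialized LSTM (`SyntheticDataGenerator`) turns buffered seed and condition sequences into synthetic data, and `UniversalPretrainWrapper` trains the wrapped causal transformer on sequences drawn from that buffer with an ordinary autoregressive loss. A minimal usage sketch, adapted from the `test_up` test added in this release; the model sizes and `seq_len` below are illustrative, not prescribed:

```python
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.up_wrapper import UniversalPretrainWrapper

# any causal TransformerWrapper works; sizes kept small for a quick smoke test
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

# the wrapper builds its own LSTM data generator and internal sequence buffer
up_wrapper = UniversalPretrainWrapper(model, seq_len = 16)

# each call returns the autoregressive loss on a batch drawn from that buffer
loss = up_wrapper()
loss.backward()
```

In practice the call above would sit inside an ordinary optimizer loop; the wrapper handles sampling conditions and seeds, generating synthetic sequences, and refreshing part of its buffer on every forward pass, following Algorithm 1 of the paper.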
{x_transformers-2.4.0 → x_transformers-2.4.2}/x_transformers/x_transformers.py
@@ -3263,7 +3263,7 @@ class TransformerWrapper(Module):

          # attention pool

-         if exists(self.attn_pool):
+         if exists(self.attn_pool) and return_intermediates:

              queries = repeat(self.attn_pool_queries, 'n d -> b n d', b = x.shape[0])

              attn_pooled_tokens = self.attn_pool(queries, context = x, context_mask = mask)
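The single change to `x_transformers.py` gates the attention-pool computation on `return_intermediates`, since the pooled tokens are only surfaced through the returned intermediates; a plain forward call no longer pays for the pool. A small sketch of the two call paths, following the existing attention-pool test (token count and layer sizes are illustrative):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_pool = True,             # enable the attention pool
    num_attn_pool_queries = 3,    # number of pooled tokens
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8
    )
)

x = torch.randint(0, 256, (2, 1024))

# plain forward: as of 2.4.2 the attention pool is skipped entirely
logits = model(x)

# requesting intermediates runs the pool and exposes the pooled tokens
logits, intermediates = model(x, return_intermediates = True)
assert intermediates.attn_pooled_tokens.shape[1] == 3
```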