x-transformers 2.11.18__tar.gz → 2.11.20__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of x-transformers has been flagged as potentially problematic by the registry scanner.
- {x_transformers-2.11.18 → x_transformers-2.11.20}/PKG-INFO +8 -7
- {x_transformers-2.11.18 → x_transformers-2.11.20}/README.md +7 -6
- {x_transformers-2.11.18 → x_transformers-2.11.20}/pyproject.toml +1 -1
- {x_transformers-2.11.18 → x_transformers-2.11.20}/tests/test_x_transformers.py +16 -8
- {x_transformers-2.11.18 → x_transformers-2.11.20}/train_enwik8.py +19 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/x_transformers/x_transformers.py +35 -3
- {x_transformers-2.11.18 → x_transformers-2.11.20}/.github/FUNDING.yml +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/.gitignore +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/LICENSE +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/data/README.md +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/data/enwik8.gz +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/all-attention.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/attention-on-attention.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/deepnorm.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/fcm.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/ffglu.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/flash-attention.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/gate_values.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/gating.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/macaron-1.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/macaron-2.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/memory-transformer.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/normformer.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/pia.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/resi_dual.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/residual_attn.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/rezero.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/rotary.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/sandwich-2.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/sandwich.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/sandwich_norm.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/scalenorm.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/talking-heads.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/topk-attention.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/images/xval.png +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/train_belief_state.py +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/train_copy.py +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/train_free.py +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/train_gpt_vae.py +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/train_length_extrapolate.py +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/train_parity.py +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/train_with_muon.py +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/x_transformers/__init__.py +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/x_transformers/attend.py +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/x_transformers/continuous.py +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/x_transformers/dpo.py +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/x_transformers/free_transformer.py +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/x_transformers/gpt_vae.py +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.11.18 → x_transformers-2.11.20}/x_transformers/xval.py +0 -0

{x_transformers-2.11.18 → x_transformers-2.11.20}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.11.18
+Version: 2.11.20
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers

@@ -2608,12 +2608,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 ```
 
 ```bibtex
-@
-title = {
-author = {
-
-
-
+@inproceedings{anonymous2025beliefformer,
+    title = {BeliefFormer: Belief Attention in Transformer},
+    author = {Anonymous},
+    booktitle = {Submitted to The Fourteenth International Conference on Learning Representations},
+    year = {2025},
+    url = {https://openreview.net/forum?id=Ard2QzPAUK},
+    note = {under review}
 }
 ```
 

{x_transformers-2.11.18 → x_transformers-2.11.20}/README.md

@@ -2559,12 +2559,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 ```
 
 ```bibtex
-@
-title = {
-author = {
-
-
-
+@inproceedings{anonymous2025beliefformer,
+    title = {BeliefFormer: Belief Attention in Transformer},
+    author = {Anonymous},
+    booktitle = {Submitted to The Fourteenth International Conference on Learning Representations},
+    year = {2025},
+    url = {https://openreview.net/forum?id=Ard2QzPAUK},
+    note = {under review}
 }
 ```
 

{x_transformers-2.11.18 → x_transformers-2.11.20}/tests/test_x_transformers.py

@@ -1463,13 +1463,21 @@ def test_kv_input_residual():
 
     assert tokens.shape == out.shape
 
-def
-
-
-
-
-
+def test_belief_attn():
+    from x_transformers import TransformerWrapper, Decoder
+
+    model = TransformerWrapper(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 6,
+            heads = 8,
+            rotary_pos_emb = True,
+            attn_orthog_projected_values = True
+        )
     )
 
-
-
+    x = torch.randint(0, 256, (1, 10))
+
+    logits = model(x)

{x_transformers-2.11.18 → x_transformers-2.11.20}/train_enwik8.py

@@ -1,3 +1,11 @@
+# /// script
+# dependencies = [
+#   "tqdm",
+#   "x-transformers",
+#   "wandb"
+# ]
+# ///
+
 from x_transformers import TransformerWrapper, Decoder
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
 

@@ -20,6 +28,7 @@ VALIDATE_EVERY = 100
 GENERATE_EVERY = 500
 GENERATE_LENGTH = 1024
 SEQ_LEN = 1024
+TRACK_EXPERIMENT_ONLINE = False
 
 # helpers
 

@@ -80,6 +89,12 @@ val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last
 
 optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
 
+# experiment
+
+import wandb
+wandb.init(project = 'enwik8', mode = 'online' if TRACK_EXPERIMENT_ONLINE else 'disabled')
+wandb.run.name = 'baseline'
+
 # training
 
 for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):

@@ -90,6 +105,8 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
         (loss / GRADIENT_ACCUMULATE_EVERY).backward()
 
     print(f'training loss: {loss.item()}')
+    wandb.log(dict(loss = loss.item()))
+
     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
     optim.step()
     optim.zero_grad()

@@ -98,7 +115,9 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
         model.eval()
         with torch.no_grad():
             loss = model(next(val_loader))
+
             print(f'validation loss: {loss.item()}')
+            wandb.log(dict(valid_loss = loss.item()))
 
     if i % GENERATE_EVERY == 0:
         model.eval()

{x_transformers-2.11.18 → x_transformers-2.11.20}/x_transformers/x_transformers.py

@@ -161,6 +161,21 @@ def or_reduce(masks):
         head = head | rest
     return head
 
+def orthog_project(x, y):
+    x, packed_shape = pack([x], 'b *')
+    y, _ = pack([y], 'b *')
+
+    dtype = x.dtype
+    x, y = x.double(), y.double()
+    unit = F.normalize(y, dim = -1)
+
+    parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
+    orthog = x - parallel
+
+    orthog, = unpack(orthog, packed_shape, 'b *')
+
+    return orthog.to(dtype)
+
 # cache helpers
 
 def get_cached_kvs(
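
For readers skimming the diff, the new `orthog_project` helper removes from `x` the component parallel to `y`, so the result has (numerically) zero dot product with `y`. Below is a minimal sketch of the same computation in plain PyTorch, with hypothetical names and without the einops `pack`/`unpack` bookkeeping used above; it is an illustration, not the package's implementation.

```python
import torch
import torch.nn.functional as F

def orthog_project_sketch(x, y):
    # flatten everything after the batch dimension, mirroring pack([x], 'b *')
    b = x.shape[0]
    x_flat, y_flat = x.reshape(b, -1).double(), y.reshape(b, -1).double()

    # unit vector along y, then subtract the component of x parallel to it
    unit = F.normalize(y_flat, dim = -1)
    parallel = (x_flat * unit).sum(dim = -1, keepdim = True) * unit
    orthog = x_flat - parallel

    return orthog.reshape_as(x).to(x.dtype)

x, y = torch.randn(2, 8, 64), torch.randn(2, 8, 64)
out = orthog_project_sketch(x, y)

# per batch element, the result should be orthogonal to y (up to float32 rounding)
dot = (out.reshape(2, -1).double() * y.reshape(2, -1).double()).sum(dim = -1)
assert dot.abs().max() < 1e-3
```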

@@ -276,8 +291,13 @@ class ReluSquared(Module):
         return F.relu(x) ** 2
 
 class SoLU(Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.norm = LayerNorm(dim)
+
     def forward(self, x):
-        return x.softmax(dim = -1) * x
+        activated = x.softmax(dim = -1) * x
+        return self.norm(activated)
 
 # embedding
 
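
The SoLU change above gives the activation its own LayerNorm, so it now computes `layernorm(x * softmax(x))` rather than the bare softmax-linear unit, which is why `FeedForward` passes `inner_dim` in the next hunk. A minimal sketch, assuming a standard `nn.LayerNorm` in place of the library's own `LayerNorm` helper:

```python
import torch
from torch import nn

class SoLUSketch(nn.Module):
    # softmax-linear unit followed by layer norm, mirroring the updated SoLU module
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        activated = x.softmax(dim = -1) * x
        return self.norm(activated)

x = torch.randn(2, 10, 512)
assert SoLUSketch(512)(x).shape == x.shape
```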

@@ -1262,7 +1282,7 @@ class FeedForward(Module):
         elif relu_squared:
             activation = ReluSquared()
         elif solu:
-            activation = SoLU()
+            activation = SoLU(inner_dim)
         elif swish:
             activation = nn.SiLU()
         else:

@@ -1376,7 +1396,8 @@ class Attention(Module):
         softclamp_logits = False,
         logit_softclamp_value = 50.,
         learned_value_residual_mix = False,
-        laser = False, # https://arxiv.org/abs/2411.03493v1
+        orthog_projected_values = False, # https://openreview.net/forum?id=Ard2QzPAUK
+        laser = False, # https://arxiv.org/abs/2411.03493v1
         laser_softclamp_value = 15.,
         qkv_receive_diff_residuals = False,
         use_latent_q = False,

@@ -1602,6 +1623,11 @@ class Attention(Module):
 
         self.attn_on_attn = on_attn
 
+        # return orthogonal projected weighted values on original values
+        # "belief attention" - iclr 2026
+
+        self.orthog_projected_values = orthog_projected_values
+
         # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
 
         hybrid_mix = None

@@ -2043,6 +2069,12 @@ class Attention(Module):
             gates = self.to_v_gate(x)
             out = out * self.to_v_gate_activation(gates)
 
+        # maybe return orthogonal projected - "belief" attention
+
+        if self.orthog_projected_values:
+            merged_v = self.merge_heads(orig_values)
+            out = orthog_project(out, merged_v)
+
         # combine the heads
 
         out = self.to_out(out)
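
Taken together, the feature is opt-in: passing `attn_orthog_projected_values = True` to `Decoder` forwards `orthog_projected_values = True` to each `Attention` layer, which then projects the attention output orthogonally to the merged values. A usage sketch mirroring the new test, assuming the 2.11.20 release is installed:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rotary_pos_emb = True,
        attn_orthog_projected_values = True  # "belief attention" from the OpenReview submission cited above
    )
)

x = torch.randint(0, 256, (1, 10))
logits = model(x)  # (1, 10, 256)
```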