x-transformers 2.11.19.tar.gz → 2.11.20.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this version of x-transformers has been marked as a potentially problematic release.
- {x_transformers-2.11.19 → x_transformers-2.11.20}/PKG-INFO +8 -7
- {x_transformers-2.11.19 → x_transformers-2.11.20}/README.md +7 -6
- {x_transformers-2.11.19 → x_transformers-2.11.20}/pyproject.toml +1 -1
- {x_transformers-2.11.19 → x_transformers-2.11.20}/tests/test_x_transformers.py +16 -8
- {x_transformers-2.11.19 → x_transformers-2.11.20}/train_enwik8.py +19 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/x_transformers/x_transformers.py +28 -1
- {x_transformers-2.11.19 → x_transformers-2.11.20}/.github/FUNDING.yml +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/.gitignore +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/LICENSE +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/data/README.md +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/data/enwik8.gz +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/all-attention.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/attention-on-attention.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/deepnorm.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/fcm.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/ffglu.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/flash-attention.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/gate_values.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/gating.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/macaron-1.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/macaron-2.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/memory-transformer.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/normformer.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/pia.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/resi_dual.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/residual_attn.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/rezero.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/rotary.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/sandwich-2.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/sandwich.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/sandwich_norm.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/scalenorm.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/talking-heads.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/topk-attention.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/images/xval.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/train_belief_state.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/train_copy.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/train_free.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/train_gpt_vae.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/train_length_extrapolate.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/train_parity.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/train_with_muon.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/x_transformers/__init__.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/x_transformers/attend.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/x_transformers/continuous.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/x_transformers/dpo.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/x_transformers/free_transformer.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/x_transformers/gpt_vae.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.20}/x_transformers/xval.py +0 -0
{x_transformers-2.11.19 → x_transformers-2.11.20}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.11.19
+Version: 2.11.20
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2608,12 +2608,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 ```
 
 ```bibtex
-@
-title = {
-author = {
-
-
-
+@inproceedings{anonymous2025beliefformer,
+    title = {BeliefFormer: Belief Attention in Transformer},
+    author = {Anonymous},
+    booktitle = {Submitted to The Fourteenth International Conference on Learning Representations},
+    year = {2025},
+    url = {https://openreview.net/forum?id=Ard2QzPAUK},
+    note = {under review}
 }
 ```
{x_transformers-2.11.19 → x_transformers-2.11.20}/README.md

@@ -2559,12 +2559,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 ```
 
 ```bibtex
-@
-title = {
-author = {
-
-
-
+@inproceedings{anonymous2025beliefformer,
+    title = {BeliefFormer: Belief Attention in Transformer},
+    author = {Anonymous},
+    booktitle = {Submitted to The Fourteenth International Conference on Learning Representations},
+    year = {2025},
+    url = {https://openreview.net/forum?id=Ard2QzPAUK},
+    note = {under review}
 }
 ```
{x_transformers-2.11.19 → x_transformers-2.11.20}/tests/test_x_transformers.py

@@ -1463,13 +1463,21 @@ def test_kv_input_residual():
 
     assert tokens.shape == out.shape
 
-def
-
-
-
-
-
+def test_belief_attn():
+    from x_transformers import TransformerWrapper, Decoder
+
+    model = TransformerWrapper(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 6,
+            heads = 8,
+            rotary_pos_emb = True,
+            attn_orthog_projected_values = True
+        )
     )
 
-
-
+    x = torch.randint(0, 256, (1, 10))
+
+    logits = model(x)
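The new test only runs a single forward pass. As orientation, here is a minimal sketch (not part of the diff) that reuses the same configuration and additionally drives it through `AutoregressiveWrapper`, the way `train_enwik8.py` does. The `attn_`-prefixed flag and shapes are taken from the test above; the wrapper usage and the shape assertion are assumptions.

```python
# minimal sketch, assuming torch and x-transformers >= 2.11.20 are installed;
# mirrors the test configuration above and adds a training-style loss call
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rotary_pos_emb = True,
        attn_orthog_projected_values = True  # the new flag exercised by test_belief_attn
    )
)

x = torch.randint(0, 256, (1, 10))
logits = model(x)                        # next-token logits, one per position
assert logits.shape == (1, 10, 256)

# autoregressive training step, as in train_enwik8.py
wrapped = AutoregressiveWrapper(model)
seq = torch.randint(0, 256, (1, 11))     # wrapper trains on seq[:, :-1] -> seq[:, 1:]
loss = wrapped(seq)
loss.backward()
```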
{x_transformers-2.11.19 → x_transformers-2.11.20}/train_enwik8.py

@@ -1,3 +1,11 @@
+# /// script
+# dependencies = [
+# "tqdm",
+# "x-transformers",
+# "wandb"
+# ]
+# ///
+
 from x_transformers import TransformerWrapper, Decoder
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
 
@@ -20,6 +28,7 @@ VALIDATE_EVERY = 100
 GENERATE_EVERY = 500
 GENERATE_LENGTH = 1024
 SEQ_LEN = 1024
+TRACK_EXPERIMENT_ONLINE = False
 
 # helpers
 
@@ -80,6 +89,12 @@ val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last
 
 optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
 
+# experiment
+
+import wandb
+wandb.init(project = 'enwik8', mode = 'online' if TRACK_EXPERIMENT_ONLINE else 'disabled')
+wandb.run.name = 'baseline'
+
 # training
 
 for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
@@ -90,6 +105,8 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
         (loss / GRADIENT_ACCUMULATE_EVERY).backward()
 
     print(f'training loss: {loss.item()}')
+    wandb.log(dict(loss = loss.item()))
+
     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
     optim.step()
     optim.zero_grad()
@@ -98,7 +115,9 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
         model.eval()
         with torch.no_grad():
             loss = model(next(val_loader))
+
             print(f'validation loss: {loss.item()}')
+            wandb.log(dict(valid_loss = loss.item()))
 
     if i % GENERATE_EVERY == 0:
         model.eval()
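The net effect of the `train_enwik8.py` changes is experiment tracking that is off by default: `wandb.init(mode = 'disabled')` turns every later `wandb.log` call into a no-op, so the script still runs without a W&B account. A minimal sketch of the pattern in isolation, assuming wandb is installed; the constant and run name follow the diff, the loop is a stand-in:

```python
# minimal sketch of the gated logging pattern added above
import wandb

TRACK_EXPERIMENT_ONLINE = False   # flip to True to actually send metrics to W&B

wandb.init(project = 'enwik8', mode = 'online' if TRACK_EXPERIMENT_ONLINE else 'disabled')
wandb.run.name = 'baseline'

for step in range(3):
    loss = 1.0 / (step + 1)           # stand-in for the real training loss
    wandb.log(dict(loss = loss))      # silently dropped when mode = 'disabled'

wandb.finish()
```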
{x_transformers-2.11.19 → x_transformers-2.11.20}/x_transformers/x_transformers.py

@@ -161,6 +161,21 @@ def or_reduce(masks):
         head = head | rest
     return head
 
+def orthog_project(x, y):
+    x, packed_shape = pack([x], 'b *')
+    y, _ = pack([y], 'b *')
+
+    dtype = x.dtype
+    x, y = x.double(), y.double()
+    unit = F.normalize(y, dim = -1)
+
+    parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
+    orthog = x - parallel
+
+    orthog, = unpack(orthog, packed_shape, 'b *')
+
+    return orthog.to(dtype)
+
 # cache helpers
 
 def get_cached_kvs(
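For readers skimming the hunk: `orthog_project` removes from `x` its component along `y`, i.e. it returns x − (x·ŷ)ŷ, with everything but the batch axis flattened by the einops `pack` and the arithmetic done in float64 for stability. Below is a self-contained sketch of the same computation with a check of the defining property; the tensor shapes and the assertion are illustrative assumptions, not from the library.

```python
# standalone sketch of the projection above; assumes torch and einops are installed
import torch
import torch.nn.functional as F
from einops import pack, unpack

def orthog_project(x, y):
    # flatten everything except the batch dimension, mirroring the library helper
    x, packed_shape = pack([x], 'b *')
    y, _ = pack([y], 'b *')

    dtype = x.dtype
    x, y = x.double(), y.double()
    unit = F.normalize(y, dim = -1)                              # y / ||y||

    parallel = (x * unit).sum(dim = -1, keepdim = True) * unit   # (x . unit) * unit
    orthog = x - parallel                                        # component of x orthogonal to y

    orthog, = unpack(orthog, packed_shape, 'b *')
    return orthog.to(dtype)

x = torch.randn(2, 10, 512)
y = torch.randn(2, 10, 512)
out = orthog_project(x, y)

# defining property: the result has (numerically) negligible inner product with y per batch element
dot = (out.flatten(1).double() * y.flatten(1).double()).sum(dim = -1)
assert torch.allclose(dot, torch.zeros_like(dot), atol = 1e-3)
```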
@@ -1381,7 +1396,8 @@ class Attention(Module):
         softclamp_logits = False,
         logit_softclamp_value = 50.,
         learned_value_residual_mix = False,
-        laser = False, # https://arxiv.org/abs/2411.03493v1
+        orthog_projected_values = False, # https://openreview.net/forum?id=Ard2QzPAUK
+        laser = False, # https://arxiv.org/abs/2411.03493v1
         laser_softclamp_value = 15.,
         qkv_receive_diff_residuals = False,
         use_latent_q = False,
@@ -1607,6 +1623,11 @@
 
         self.attn_on_attn = on_attn
 
+        # return orthogonal projected weighted values on original values
+        # "belief attention" - iclr 2026
+
+        self.orthog_projected_values = orthog_projected_values
+
         # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
 
         hybrid_mix = None
@@ -2048,6 +2069,12 @@
             gates = self.to_v_gate(x)
             out = out * self.to_v_gate_activation(gates)
 
+        # maybe return orthogonal projected - "belief" attention
+
+        if self.orthog_projected_values:
+            merged_v = self.merge_heads(orig_values)
+            out = orthog_project(out, merged_v)
+
         # combine the heads
 
         out = self.to_out(out)
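Putting the `x_transformers.py` pieces together: the new kwarg is stored on each `Attention` module, and in the forward pass the merged attention output is replaced by its component orthogonal to the merged value vectors, per the BeliefFormer submission linked in the kwarg comment. A quick sanity sketch of the wiring, assuming (as the new test does) that `attn_`-prefixed kwargs on `Decoder` are forwarded to every attention layer; the attribute scan below is an illustrative check, not library API.

```python
# sanity sketch: confirm the flag reaches the attention layers
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        attn_orthog_projected_values = True   # new in 2.11.20
    )
)

# every module exposing the new attribute (i.e. each Attention layer) should have it enabled
attn_modules = [m for m in model.modules() if hasattr(m, 'orthog_projected_values')]
assert len(attn_modules) > 0 and all(m.orthog_projected_values for m in attn_modules)
```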