x-transformers 2.11.19__tar.gz → 2.11.22__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of x-transformers might be problematic.
- {x_transformers-2.11.19 → x_transformers-2.11.22}/PKG-INFO +8 -7
- {x_transformers-2.11.19 → x_transformers-2.11.22}/README.md +7 -6
- {x_transformers-2.11.19 → x_transformers-2.11.22}/pyproject.toml +1 -1
- {x_transformers-2.11.19 → x_transformers-2.11.22}/tests/test_x_transformers.py +22 -8
- {x_transformers-2.11.19 → x_transformers-2.11.22}/train_enwik8.py +22 -1
- {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/x_transformers.py +45 -1
- {x_transformers-2.11.19 → x_transformers-2.11.22}/.github/FUNDING.yml +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/.gitignore +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/LICENSE +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/data/README.md +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/data/enwik8.gz +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/all-attention.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/attention-on-attention.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/deepnorm.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/fcm.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/ffglu.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/flash-attention.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/gate_values.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/gating.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/macaron-1.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/macaron-2.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/memory-transformer.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/normformer.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/pia.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/resi_dual.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/residual_attn.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/rezero.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/rotary.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/sandwich-2.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/sandwich.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/sandwich_norm.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/scalenorm.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/talking-heads.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/topk-attention.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/images/xval.png +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/train_belief_state.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/train_copy.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/train_free.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/train_gpt_vae.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/train_length_extrapolate.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/train_parity.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/train_with_muon.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/__init__.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/attend.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/continuous.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/dpo.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/free_transformer.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/gpt_vae.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/xval.py +0 -0
{x_transformers-2.11.19 → x_transformers-2.11.22}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.11.19
+Version: 2.11.22
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2608,12 +2608,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 ```
 
 ```bibtex
-@
-title = {
-author = {
-
-
-
+@inproceedings{anonymous2025beliefformer,
+    title = {BeliefFormer: Belief Attention in Transformer},
+    author = {Anonymous},
+    booktitle = {Submitted to The Fourteenth International Conference on Learning Representations},
+    year = {2025},
+    url = {https://openreview.net/forum?id=Ard2QzPAUK},
+    note = {under review}
 }
 ```
 
{x_transformers-2.11.19 → x_transformers-2.11.22}/README.md

@@ -2559,12 +2559,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 ```
 
 ```bibtex
-@
-title = {
-author = {
-
-
-
+@inproceedings{anonymous2025beliefformer,
+    title = {BeliefFormer: Belief Attention in Transformer},
+    author = {Anonymous},
+    booktitle = {Submitted to The Fourteenth International Conference on Learning Representations},
+    year = {2025},
+    url = {https://openreview.net/forum?id=Ard2QzPAUK},
+    note = {under review}
 }
 ```
 
{x_transformers-2.11.19 → x_transformers-2.11.22}/tests/test_x_transformers.py

@@ -1463,13 +1463,27 @@ def test_kv_input_residual():
 
     assert tokens.shape == out.shape
 
-
-
-
-
-
-
+@param('orthog_project', (False, True))
+@param('orthog_project_per_head', (False, True))
+def test_belief_attn(
+    orthog_project,
+    orthog_project_per_head
+):
+    from x_transformers import TransformerWrapper, Decoder
+
+    model = TransformerWrapper(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 6,
+            heads = 8,
+            rotary_pos_emb = True,
+            attn_orthog_projected_values = orthog_project,
+            attn_orthog_projected_values_per_head = orthog_project_per_head
+        )
     )
 
-
-
+    x = torch.randint(0, 256, (1, 10))
+
+    logits = model(x)
{x_transformers-2.11.19 → x_transformers-2.11.22}/train_enwik8.py

@@ -1,3 +1,11 @@
+# /// script
+# dependencies = [
+#   "tqdm",
+#   "x-transformers",
+#   "wandb"
+# ]
+# ///
+
 from x_transformers import TransformerWrapper, Decoder
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
 
@@ -20,6 +28,7 @@ VALIDATE_EVERY = 100
 GENERATE_EVERY = 500
 GENERATE_LENGTH = 1024
 SEQ_LEN = 1024
+TRACK_EXPERIMENT_ONLINE = False
 
 # helpers
 
@@ -43,7 +52,9 @@ model = TransformerWrapper(
         dim = 512,
         depth = 6,
         heads = 8,
-        rotary_pos_emb = True
+        rotary_pos_emb = True,
+        attn_orthog_projected_values = True,
+        attn_orthog_projected_values_per_head = True
     )
 )
 
@@ -80,6 +91,12 @@ val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last
 
 optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
 
+# experiment
+
+import wandb
+wandb.init(project = 'enwik8', mode = 'online' if TRACK_EXPERIMENT_ONLINE else 'disabled')
+wandb.run.name = 'baseline'
+
 # training
 
 for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
@@ -90,6 +107,8 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
         (loss / GRADIENT_ACCUMULATE_EVERY).backward()
 
     print(f'training loss: {loss.item()}')
+    wandb.log(dict(loss = loss.item()))
+
     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
     optim.step()
     optim.zero_grad()
@@ -98,7 +117,9 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
         model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
+
            print(f'validation loss: {loss.item()}')
+            wandb.log(dict(valid_loss = loss.item()))
 
    if i % GENERATE_EVERY == 0:
        model.eval()
{x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/x_transformers.py

@@ -161,6 +161,21 @@ def or_reduce(masks):
         head = head | rest
     return head
 
+def orthog_project(x, y):
+    x, packed_shape = pack([x], 'b *')
+    y, _ = pack([y], 'b *')
+
+    dtype = x.dtype
+    x, y = x.double(), y.double()
+    unit = F.normalize(y, dim = -1)
+
+    parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
+    orthog = x - parallel
+
+    orthog, = unpack(orthog, packed_shape, 'b *')
+
+    return orthog.to(dtype)
+
 # cache helpers
 
 def get_cached_kvs(
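The new `orthog_project` helper flattens everything but the batch dimension, then subtracts from `x` its component along `y`, so the result is orthogonal to `y`. Below is a minimal standalone sketch of the same computation with a check of that defining property; it assumes only `torch` and `einops` are installed, and the tensor shapes are arbitrary example values, not anything taken from the diff.

```python
import torch
import torch.nn.functional as F
from einops import pack, unpack

def orthog_project(x, y):
    # flatten all non-batch dimensions, so the projection is over one long vector per batch element
    x, packed_shape = pack([x], 'b *')
    y, _ = pack([y], 'b *')

    dtype = x.dtype
    x, y = x.double(), y.double()     # do the projection in float64 for stability
    unit = F.normalize(y, dim = -1)   # unit vector along y

    parallel = (x * unit).sum(dim = -1, keepdim = True) * unit  # component of x parallel to y
    orthog = x - parallel                                       # remove it

    orthog, = unpack(orthog, packed_shape, 'b *')
    return orthog.to(dtype)

# check the defining property: the output has ~zero inner product with y
x = torch.randn(2, 8, 64)
y = torch.randn(2, 8, 64)

out = orthog_project(x, y)
dots = (out.reshape(2, -1) * y.reshape(2, -1)).sum(dim = -1)
assert torch.allclose(dots, torch.zeros_like(dots), atol = 1e-4)
```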
@@ -1381,7 +1396,9 @@ class Attention(Module):
         softclamp_logits = False,
         logit_softclamp_value = 50.,
         learned_value_residual_mix = False,
-
+        orthog_projected_values = False, # https://openreview.net/forum?id=Ard2QzPAUK
+        orthog_projected_values_per_head = False,
+        laser = False, # https://arxiv.org/abs/2411.03493v1
         laser_softclamp_value = 15.,
         qkv_receive_diff_residuals = False,
         use_latent_q = False,
@@ -1607,6 +1624,14 @@ class Attention(Module):
 
         self.attn_on_attn = on_attn
 
+        # return orthogonal projected weighted values on original values
+        # "belief attention" - iclr 2026
+
+        self.orthog_projected_values = orthog_projected_values
+        self.orthog_projected_values_per_head = orthog_projected_values_per_head
+
+        out_dim *= max(1, int(orthog_projected_values) + int(orthog_projected_values_per_head))
+
         # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
 
         hybrid_mix = None
@@ -2048,6 +2073,25 @@ class Attention(Module):
             gates = self.to_v_gate(x)
             out = out * self.to_v_gate_activation(gates)
 
+        # maybe orthogonal projected weighted values - "belief" attention
+
+        if self.orthog_projected_values or self.orthog_projected_values_per_head:
+            orthog_projected = []
+            v_for_proj = self.merge_heads(orig_values)
+
+            if self.orthog_projected_values:
+                projected = orthog_project(out, v_for_proj)
+                orthog_projected.append(projected)
+
+            if self.orthog_projected_values_per_head:
+                v_for_proj = rearrange(v_for_proj, 'b n (h d) -> b n h d', h = h)
+                out = rearrange(out, 'b n (h d) -> b n h d', h = h)
+                projected = orthog_project(out, v_for_proj)
+                projected = rearrange(projected, 'b n h d -> b n (h d)')
+                orthog_projected.append(projected)
+
+            out = cat(orthog_projected, dim = -1)
+
         # combine the heads
 
         out = self.to_out(out)
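Taken end to end, the new options are threaded through with the usual `attn_` kwarg prefix on `Decoder`, exactly as the new test and the updated `train_enwik8.py` use them. A minimal usage sketch, mirroring the new test (the vocabulary size, sequence length and model dimensions are just example values); it assumes x-transformers >= 2.11.22 is installed:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rotary_pos_emb = True,
        # the two flags added in this release
        attn_orthog_projected_values = True,           # project attention output orthogonally to the merged values
        attn_orthog_projected_values_per_head = True   # same, but head-wise; with both on, the two results are concatenated before to_out
    )
)

x = torch.randint(0, 256, (1, 10))
logits = model(x)  # (1, 10, 256)
```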
|