x-transformers 2.6.7__tar.gz → 2.7.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-2.6.7 → x_transformers-2.7.1}/PKG-INFO +17 -6
- {x_transformers-2.6.7 → x_transformers-2.7.1}/README.md +16 -5
- {x_transformers-2.6.7 → x_transformers-2.7.1}/pyproject.toml +1 -1
- {x_transformers-2.6.7 → x_transformers-2.7.1}/tests/test_x_transformers.py +27 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/x_transformers/continuous.py +10 -1
- {x_transformers-2.6.7 → x_transformers-2.7.1}/x_transformers/nonautoregressive_wrapper.py +49 -12
- {x_transformers-2.6.7 → x_transformers-2.7.1}/.github/FUNDING.yml +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/.gitignore +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/LICENSE +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/data/README.md +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/data/enwik8.gz +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/all-attention.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/attention-on-attention.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/deepnorm.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/fcm.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/ffglu.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/flash-attention.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/gate_values.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/gating.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/macaron-1.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/macaron-2.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/memory-transformer.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/normformer.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/pia.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/resi_dual.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/residual_attn.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/rezero.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/rotary.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/sandwich-2.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/sandwich.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/sandwich_norm.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/scalenorm.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/talking-heads.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/topk-attention.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/images/xval.png +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/train_belief_state.py +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/train_copy.py +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/train_enwik8.py +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/train_length_extrapolate.py +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/train_parity.py +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/x_transformers/__init__.py +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/x_transformers/attend.py +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/x_transformers/dpo.py +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/x_transformers/x_transformers.py +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.6.7 → x_transformers-2.7.1}/x_transformers/xval.py +0 -0
{x_transformers-2.6.7 → x_transformers-2.7.1}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.6.7
+Version: 2.7.1
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers

@@ -2509,11 +2509,22 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 
 ```bibtex
 @misc{openai_gpt_oss,
-
-
-
-
-
+    author = {OpenAI},
+    title = {Introducing gpt-oss},
+    howpublished = {https://openai.com/index/introducing-gpt-oss},
+    month = {August},
+    year = {2025}
+}
+```
+
+```bibtex
+@article{Sahoo2024SimpleAE,
+    title = {Simple and Effective Masked Diffusion Language Models},
+    author = {Subham Sekhar Sahoo and Marianne Arriola and Yair Schiff and Aaron Gokaslan and Edgar Marroquin and Justin T Chiu and Alexander Rush and Volodymyr Kuleshov},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2406.07524},
+    url = {https://api.semanticscholar.org/CorpusID:270380319}
 }
 ```
 
{x_transformers-2.6.7 → x_transformers-2.7.1}/README.md

@@ -2461,11 +2461,22 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 
 ```bibtex
 @misc{openai_gpt_oss,
-
-
-
-
-
+    author = {OpenAI},
+    title = {Introducing gpt-oss},
+    howpublished = {https://openai.com/index/introducing-gpt-oss},
+    month = {August},
+    year = {2025}
+}
+```
+
+```bibtex
+@article{Sahoo2024SimpleAE,
+    title = {Simple and Effective Masked Diffusion Language Models},
+    author = {Subham Sekhar Sahoo and Marianne Arriola and Yair Schiff and Aaron Gokaslan and Edgar Marroquin and Justin T Chiu and Alexander Rush and Volodymyr Kuleshov},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2406.07524},
+    url = {https://api.semanticscholar.org/CorpusID:270380319}
 }
 ```
 
{x_transformers-2.6.7 → x_transformers-2.7.1}/tests/test_x_transformers.py

@@ -1287,3 +1287,30 @@ def test_accept_layer_intermediates():
     )
 
     assert embeds.shape == (3, 32, 512)
+
+@pytest.mark.parametrize('use_loss_weight', (False, True))
+def test_simple_mdlm(
+    use_loss_weight
+):
+    from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper
+
+    model = TransformerWrapper(
+        num_tokens = 256 + 1,
+        max_seq_len = 1024,
+        attn_layers = Encoder(
+            dim = 512,
+            depth = 4,
+            rotary_pos_emb = True
+        )
+    )
+
+    nar = NonAutoregressiveWrapper(
+        model,
+        mask_id = 256,
+        use_simple_mdlm_loss_weight = use_loss_weight
+    )
+
+    seq = torch.randint(0, 256, (1, 1024))
+
+    loss = nar(seq)
+    loss.loss.backward()
{x_transformers-2.6.7 → x_transformers-2.7.1}/x_transformers/continuous.py

@@ -241,6 +241,7 @@ class ContinuousAutoregressiveWrapper(Module):
         self,
         net: ContinuousTransformerWrapper,
         loss_fn: Module | None = None,
+        use_l1_loss = False,
         equal_loss_weight_batch = False, # setting this to True, if the mask is passed in and sequences are variable in length, each sequence will be weighted the same (as opposed to each token)
     ):
         super().__init__()

@@ -250,7 +251,15 @@ class ContinuousAutoregressiveWrapper(Module):
         probabilistic = net.probabilistic
         self.probabilistic = probabilistic
 
-
+        # default loss function
+
+        if not exists(loss_fn):
+            if probabilistic:
+                loss_fn = GaussianNLL()
+            elif use_l1_loss:
+                loss_fn = nn.L1Loss(reduction = 'none')
+            else:
+                loss_fn = nn.MSELoss(reduction = 'none')
 
         self.loss_fn = loss_fn
         self.equal_loss_weight_batch = equal_loss_weight_batch
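For orientation, a minimal usage sketch of the new `use_l1_loss` flag on `ContinuousAutoregressiveWrapper`, modeled on the continuous-transformer example in the project README. The model dimensions and mock data below are illustrative placeholders, not values from this diff, and the call pattern assumes the wrapper's forward still returns a scalar training loss as in that README example.

```python
import torch
from x_transformers import Decoder
from x_transformers.continuous import (
    ContinuousTransformerWrapper,
    ContinuousAutoregressiveWrapper
)

# illustrative model sizes, not taken from the diff
model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

# with loss_fn left as None and use_l1_loss = True,
# the wrapper now defaults to nn.L1Loss instead of nn.MSELoss
wrapper = ContinuousAutoregressiveWrapper(model, use_l1_loss = True)

seq = torch.randn(1, 1024, 32)   # mock continuous sequence

loss = wrapper(seq)
loss.backward()
```

Per the diff, the flag only matters when no explicit `loss_fn` is supplied and the network is not probabilistic (the probabilistic case still defaults to `GaussianNLL`).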
{x_transformers-2.6.7 → x_transformers-2.7.1}/x_transformers/nonautoregressive_wrapper.py

@@ -1,16 +1,20 @@
+from __future__ import annotations
+
 import math
 from random import random
 from contextlib import nullcontext
 from collections import namedtuple
 
 import torch
+from torch import nn, pi
+from torch.nn import Module
+from torch.func import grad_and_value, vmap
 import torch.nn.functional as F
-from torch import nn
 
+import einx
 from einops import rearrange, repeat, pack, unpack
 
 from x_transformers.x_transformers import TransformerWrapper
-from typing import Optional
 
 # constants
 

@@ -75,12 +79,12 @@ def linear_schedule(t):
 
 def cosine_schedule(t):
     """ https://arxiv.org/abs/2202.04200 """
-    return torch.cos(t *
+    return torch.cos(t * pi / 2)
 
 # self token critic
 # inspired by Nijkamp et al. - https://aclanthology.org/2021.naacl-main.409/
 
-class SelfCritic(nn.Module):
+class SelfCritic(Module):
     def __init__(self, net):
         super().__init__()
         self.net = net

@@ -92,7 +96,7 @@ class SelfCritic(nn.Module):
         embed = self.net(x, return_embeddings = True)
         return self.to_logits(embed)
 
-class NonAutoregressiveWrapper(nn.Module):
+class NonAutoregressiveWrapper(Module):
     """
     https://arxiv.org/abs/1904.09324
     https://arxiv.org/abs/2202.04200

@@ -110,9 +114,10 @@ class NonAutoregressiveWrapper(nn.Module):
         random_token_prob = 0.1,         # which percentage of tokens to be replaced with random token, done in original MLM paper
         schedule = 'linear',
         can_mask_prev_unmasked = False,  # when unmasking, whether it can remask previously unmasked
-        token_critic:
+        token_critic: TransformerWrapper | None = None,
         self_token_critic = False,
-        critic_loss_weight = 1
+        critic_loss_weight = 1.,
+        use_simple_mdlm_loss_weight = True # Sahoo et al. https://arxiv.org/abs/2406.07524
     ):
         super().__init__()
         assert not (self_token_critic and exists(token_critic))

@@ -143,6 +148,23 @@ class NonAutoregressiveWrapper(nn.Module):
         else:
             raise ValueError(f'invalid schedule {schedule}')
 
+        # whether to use the loss weighting proposed in simple diffusion lm paper
+
+        self.loss_weight_fn = None
+
+        if use_simple_mdlm_loss_weight:
+            grad_and_value_schedule_fn = vmap(grad_and_value(self.schedule_fn))
+
+            # eq (10)
+
+            def loss_weight_fn(times):
+                grad, value = grad_and_value_schedule_fn(times)
+                return grad / (1. - value)
+
+            self.loss_weight_fn = loss_weight_fn
+
+        # whether to mask previous - in the simple mdlm paper, they chose not to
+
         self.can_mask_prev_unmasked = can_mask_prev_unmasked
 
         # self conditioning

@@ -311,12 +333,27 @@ class NonAutoregressiveWrapper(nn.Module):
 
         loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss
 
-        #
+        # loss
 
-
-
-
-
+        if exists(self.loss_weight_fn):
+            # using simple mdlm loss weighting
+
+            loss = loss_fn(
+                rearrange(logits, 'b n l -> b l n'),
+                orig_seq,
+                reduction = 'none'
+            )
+
+            loss_weights = self.loss_weight_fn(rand_times) # calculate loss weight
+            loss = einx.multiply('b n, b', loss, loss_weights) # apply loss weights
+
+            loss = loss[mask].mean()
+
+        else:
+            loss = loss_fn(
+                logits[mask],
+                orig_seq[mask],
+            )
 
         if not exists(self.token_critic) or only_train_generator:
             return Losses(loss, loss, None)
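To make the eq (10) weighting above concrete outside the wrapper: `grad_and_value` composed with `vmap` turns the scalar masking schedule into a per-sample weight schedule'(t) / (1 - schedule(t)), evaluated at the batch of sampled times. The sketch below uses a toy increasing schedule purely for illustration, not the wrapper's registered `schedule_fn`.

```python
import torch
from torch.func import grad_and_value, vmap

# toy schedule for illustration only: maps time t in [0, 1] to a mask fraction
def toy_schedule(t):
    return t

# mirrors the construction in the diff:
# weight(t) = d/dt schedule(t) / (1 - schedule(t))
grad_and_value_schedule_fn = vmap(grad_and_value(toy_schedule))

def loss_weight_fn(times):
    grad, value = grad_and_value_schedule_fn(times)
    return grad / (1. - value)

times = torch.tensor([0.25, 0.5, 0.75])   # one sampled time per batch element

weights = loss_weight_fn(times)           # 1 / (1 - t) for this toy schedule
print(weights)                            # ~ [1.33, 2.00, 4.00]
```

In the wrapper itself these per-sample weights are broadcast over tokens with `einx.multiply('b n, b', ...)` and applied to the unreduced cross entropy before the masked-position mean is taken.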