x-transformers 2.6.6__tar.gz → 2.7.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-2.6.6 → x_transformers-2.7.0}/PKG-INFO +17 -6
- {x_transformers-2.6.6 → x_transformers-2.7.0}/README.md +16 -5
- {x_transformers-2.6.6 → x_transformers-2.7.0}/pyproject.toml +1 -1
- {x_transformers-2.6.6 → x_transformers-2.7.0}/tests/test_x_transformers.py +62 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/x_transformers/attend.py +1 -1
- {x_transformers-2.6.6 → x_transformers-2.7.0}/x_transformers/nonautoregressive_wrapper.py +49 -12
- {x_transformers-2.6.6 → x_transformers-2.7.0}/x_transformers/x_transformers.py +30 -2
- {x_transformers-2.6.6 → x_transformers-2.7.0}/.github/FUNDING.yml +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/.gitignore +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/LICENSE +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/data/README.md +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/data/enwik8.gz +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/all-attention.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/attention-on-attention.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/deepnorm.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/fcm.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/ffglu.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/flash-attention.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/gate_values.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/gating.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/macaron-1.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/macaron-2.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/memory-transformer.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/normformer.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/pia.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/resi_dual.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/residual_attn.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/rezero.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/rotary.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/sandwich-2.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/sandwich.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/sandwich_norm.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/scalenorm.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/talking-heads.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/topk-attention.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/images/xval.png +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/train_belief_state.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/train_copy.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/train_enwik8.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/train_length_extrapolate.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/train_parity.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/x_transformers/__init__.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/x_transformers/continuous.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/x_transformers/dpo.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.6.6 → x_transformers-2.7.0}/x_transformers/xval.py +0 -0
{x_transformers-2.6.6 → x_transformers-2.7.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.6.6
+Version: 2.7.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2509,11 +2509,22 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 
 ```bibtex
 @misc{openai_gpt_oss,
-
-
-
-
-
+    author = {OpenAI},
+    title = {Introducing gpt-oss},
+    howpublished = {https://openai.com/index/introducing-gpt-oss},
+    month = {August},
+    year = {2025}
+}
+```
+
+```bibtex
+@article{Sahoo2024SimpleAE,
+    title = {Simple and Effective Masked Diffusion Language Models},
+    author = {Subham Sekhar Sahoo and Marianne Arriola and Yair Schiff and Aaron Gokaslan and Edgar Marroquin and Justin T Chiu and Alexander Rush and Volodymyr Kuleshov},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2406.07524},
+    url = {https://api.semanticscholar.org/CorpusID:270380319}
 }
 ```
 
{x_transformers-2.6.6 → x_transformers-2.7.0}/README.md

@@ -2461,11 +2461,22 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 
 ```bibtex
 @misc{openai_gpt_oss,
-
-
-
-
-
+    author = {OpenAI},
+    title = {Introducing gpt-oss},
+    howpublished = {https://openai.com/index/introducing-gpt-oss},
+    month = {August},
+    year = {2025}
+}
+```
+
+```bibtex
+@article{Sahoo2024SimpleAE,
+    title = {Simple and Effective Masked Diffusion Language Models},
+    author = {Subham Sekhar Sahoo and Marianne Arriola and Yair Schiff and Aaron Gokaslan and Edgar Marroquin and Justin T Chiu and Alexander Rush and Volodymyr Kuleshov},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2406.07524},
+    url = {https://api.semanticscholar.org/CorpusID:270380319}
 }
 ```
 
{x_transformers-2.6.6 → x_transformers-2.7.0}/tests/test_x_transformers.py

@@ -1252,3 +1252,65 @@ def test_learned_head_attn_sink():
     seq = torch.randint(0, 20000, (3, 1024))
 
     logits = model(seq)
+
+def test_accept_layer_intermediates():
+    from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper
+
+    vlm = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 3,
+            heads = 4,
+        )
+    )
+
+    seq = torch.randint(0, 20000, (3, 1024))
+    mask = torch.randint(0, 2, (3, 1024)).bool()
+
+    _, intermediates = vlm(seq, return_intermediates = True)
+
+    action_model = Decoder(
+        dim = 512,
+        depth = 6,
+        heads = 8,
+    )
+
+    seq = torch.randn(3, 32, 512)
+
+    embeds = action_model(
+        seq,
+        self_attn_additional_kv = intermediates,
+        detach_additional_kv = True,
+        additional_kv_mask = mask
+    )
+
+    assert embeds.shape == (3, 32, 512)
+
+@pytest.mark.parametrize('use_loss_weight', (False, True))
+def test_simple_mdlm(
+    use_loss_weight
+):
+    from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper
+
+    model = TransformerWrapper(
+        num_tokens = 256 + 1,
+        max_seq_len = 1024,
+        attn_layers = Encoder(
+            dim = 512,
+            depth = 4,
+            rotary_pos_emb = True
+        )
+    )
+
+    nar = NonAutoregressiveWrapper(
+        model,
+        mask_id = 256,
+        use_simple_mdlm_loss_weight = use_loss_weight
+    )
+
+    seq = torch.randint(0, 256, (1, 1024))
+
+    loss = nar(seq)
+    loss.loss.backward()
{x_transformers-2.6.6 → x_transformers-2.7.0}/x_transformers/attend.py

@@ -23,7 +23,7 @@ class Intermediates:
     pre_softmax_attn: Tensor | None = None
     post_softmax_attn: Tensor | None = None
     values: Tensor | None = None
-    cached_kv:
+    cached_kv: tuple[Tensor, Tensor] | None = None
     layer_type: str | None = None
     hybrid_hidden: Tensor | None = None
 
{x_transformers-2.6.6 → x_transformers-2.7.0}/x_transformers/nonautoregressive_wrapper.py

@@ -1,16 +1,20 @@
+from __future__ import annotations
+
 import math
 from random import random
 from contextlib import nullcontext
 from collections import namedtuple
 
 import torch
+from torch import nn, pi
+from torch.nn import Module
+from torch.func import grad_and_value, vmap
 import torch.nn.functional as F
-from torch import nn
 
+import einx
 from einops import rearrange, repeat, pack, unpack
 
 from x_transformers.x_transformers import TransformerWrapper
-from typing import Optional
 
 # constants
 
@@ -75,12 +79,12 @@ def linear_schedule(t):
 
 def cosine_schedule(t):
     """ https://arxiv.org/abs/2202.04200 """
-    return torch.cos(t *
+    return torch.cos(t * pi / 2)
 
 # self token critic
 # inspired by Nijkamp et al. - https://aclanthology.org/2021.naacl-main.409/
 
-class SelfCritic(nn.Module):
+class SelfCritic(Module):
     def __init__(self, net):
         super().__init__()
         self.net = net
@@ -92,7 +96,7 @@ class SelfCritic(nn.Module):
         embed = self.net(x, return_embeddings = True)
         return self.to_logits(embed)
 
-class NonAutoregressiveWrapper(nn.Module):
+class NonAutoregressiveWrapper(Module):
     """
     https://arxiv.org/abs/1904.09324
     https://arxiv.org/abs/2202.04200
@@ -110,9 +114,10 @@ class NonAutoregressiveWrapper(nn.Module):
         random_token_prob = 0.1,        # which percentage of tokens to be replaced with random token, done in original MLM paper
         schedule = 'linear',
         can_mask_prev_unmasked = False, # when unmasking, whether it can remask previously unmasked
-        token_critic:
+        token_critic: TransformerWrapper | None = None,
         self_token_critic = False,
-        critic_loss_weight = 1
+        critic_loss_weight = 1.,
+        use_simple_mdlm_loss_weight = True # Sahoo et al. https://arxiv.org/abs/2406.07524
     ):
         super().__init__()
         assert not (self_token_critic and exists(token_critic))
@@ -143,6 +148,23 @@ class NonAutoregressiveWrapper(nn.Module):
         else:
             raise ValueError(f'invalid schedule {schedule}')
 
+        # whether to use the loss weighting proposed in simple diffusion lm paper
+
+        self.loss_weight_fn = None
+
+        if use_simple_mdlm_loss_weight:
+            grad_and_value_schedule_fn = vmap(grad_and_value(self.schedule_fn))
+
+            # eq (10)
+
+            def loss_weight_fn(times):
+                grad, value = grad_and_value_schedule_fn(times)
+                return grad / (1. - value)
+
+            self.loss_weight_fn = loss_weight_fn
+
+        # whether to mask previous - in the simple mdlm paper, they chose not to
+
         self.can_mask_prev_unmasked = can_mask_prev_unmasked
 
         # self conditioning
@@ -311,12 +333,27 @@ class NonAutoregressiveWrapper(nn.Module):
 
         loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss
 
-        #
+        # loss
 
-
-
-
-
+        if exists(self.loss_weight_fn):
+            # using simple mdlm loss weighting
+
+            loss = loss_fn(
+                rearrange(logits, 'b n l -> b l n'),
+                orig_seq,
+                reduction = 'none'
+            )
+
+            loss_weights = self.loss_weight_fn(rand_times)     # calculate loss weight
+            loss = einx.multiply('b n, b', loss, loss_weights) # apply loss weights
+
+            loss = loss[mask].mean()
+
+        else:
+            loss = loss_fn(
+                logits[mask],
+                orig_seq[mask],
+            )
 
         if not exists(self.token_critic) or only_train_generator:
             return Losses(loss, loss, None)
{x_transformers-2.6.6 → x_transformers-2.7.0}/x_transformers/x_transformers.py

@@ -10,7 +10,7 @@ import torch
 from torch.amp import autocast
 import torch.nn.functional as F
 from torch import nn, einsum, tensor, Tensor, cat, stack, arange, is_tensor
-from torch.utils._pytree import tree_flatten, tree_unflatten
+from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
 from torch.nn import Module, ModuleList, ModuleDict
 
 from functools import partial, wraps
@@ -81,6 +81,9 @@ def cast_tuple(val, depth = 1):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def detach_all(obj):
+    return tree_map(lambda t: t.detach() if is_tensor(t) and t.requires_grad else t, obj)
+
 def maybe(fn = None):
     if not exists(fn):
         fn = identity
@@ -157,6 +160,19 @@ def or_reduce(masks):
         head = head | rest
     return head
 
+# cache helpers
+
+def get_cached_kvs(
+    cache: LayerIntermediates
+) -> list[tuple[Tensor, Tensor]]:
+
+    cached_kvs = []
+
+    for attn_intermediate in cache.attn_intermediates:
+        cached_kvs.append(attn_intermediate.cached_kv)
+
+    return cached_kvs
+
 # entropy
 
 def calc_entropy(
@@ -2441,8 +2457,13 @@ class AttentionLayers(Module):
         context_pos = None,
         attn_bias = None,
         deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
-        self_attn_additional_kv:
+        self_attn_additional_kv: (
+            LayerIntermediates |
+            list[tuple[Tensor, Tensor]]
+            | None
+        ) = None,
         additional_kv_mask = None,
+        detach_additional_kv = False,
         route_additional_kv_to_top = True,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
@@ -2590,6 +2611,13 @@ class AttentionLayers(Module):
         # additional self attn key / values - say coming from vlm
 
         if exists(self_attn_additional_kv) and route_additional_kv_to_top:
+
+            if isinstance(self_attn_additional_kv, LayerIntermediates):
+                self_attn_additional_kv = get_cached_kvs(self_attn_additional_kv)
+
+            if detach_additional_kv:
+                self_attn_additional_kv = detach_all(self_attn_additional_kv)
+
             num_self_attns = sum([layer_type == 'a' for layer_type in first(layer_variables)])
 
             self_attn_additional_kv = self_attn_additional_kv[-num_self_attns:]