x-transformers 2.6.7__tar.gz → 2.7.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {x_transformers-2.6.7 → x_transformers-2.7.0}/PKG-INFO +17 -6
  2. {x_transformers-2.6.7 → x_transformers-2.7.0}/README.md +16 -5
  3. {x_transformers-2.6.7 → x_transformers-2.7.0}/pyproject.toml +1 -1
  4. {x_transformers-2.6.7 → x_transformers-2.7.0}/tests/test_x_transformers.py +27 -0
  5. {x_transformers-2.6.7 → x_transformers-2.7.0}/x_transformers/nonautoregressive_wrapper.py +49 -12
  6. {x_transformers-2.6.7 → x_transformers-2.7.0}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.6.7 → x_transformers-2.7.0}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.6.7 → x_transformers-2.7.0}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.6.7 → x_transformers-2.7.0}/.gitignore +0 -0
  10. {x_transformers-2.6.7 → x_transformers-2.7.0}/LICENSE +0 -0
  11. {x_transformers-2.6.7 → x_transformers-2.7.0}/data/README.md +0 -0
  12. {x_transformers-2.6.7 → x_transformers-2.7.0}/data/enwik8.gz +0 -0
  13. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/all-attention.png +0 -0
  14. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/deepnorm.png +0 -0
  17. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/fcm.png +0 -0
  23. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/ffglu.png +0 -0
  24. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/flash-attention.png +0 -0
  25. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/gate_values.png +0 -0
  26. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/gating.png +0 -0
  27. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/macaron-1.png +0 -0
  29. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/macaron-2.png +0 -0
  30. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/normformer.png +0 -0
  32. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/pia.png +0 -0
  33. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/resi_dual.png +0 -0
  35. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/residual_attn.png +0 -0
  36. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/rezero.png +0 -0
  37. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/rotary.png +0 -0
  38. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/sandwich.png +0 -0
  40. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/scalenorm.png +0 -0
  42. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/talking-heads.png +0 -0
  43. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/topk-attention.png +0 -0
  44. {x_transformers-2.6.7 → x_transformers-2.7.0}/images/xval.png +0 -0
  45. {x_transformers-2.6.7 → x_transformers-2.7.0}/train_belief_state.py +0 -0
  46. {x_transformers-2.6.7 → x_transformers-2.7.0}/train_copy.py +0 -0
  47. {x_transformers-2.6.7 → x_transformers-2.7.0}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.6.7 → x_transformers-2.7.0}/train_enwik8.py +0 -0
  49. {x_transformers-2.6.7 → x_transformers-2.7.0}/train_length_extrapolate.py +0 -0
  50. {x_transformers-2.6.7 → x_transformers-2.7.0}/train_parity.py +0 -0
  51. {x_transformers-2.6.7 → x_transformers-2.7.0}/x_transformers/__init__.py +0 -0
  52. {x_transformers-2.6.7 → x_transformers-2.7.0}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.6.7 → x_transformers-2.7.0}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.6.7 → x_transformers-2.7.0}/x_transformers/belief_state_wrapper.py +0 -0
  55. {x_transformers-2.6.7 → x_transformers-2.7.0}/x_transformers/continuous.py +0 -0
  56. {x_transformers-2.6.7 → x_transformers-2.7.0}/x_transformers/dpo.py +0 -0
  57. {x_transformers-2.6.7 → x_transformers-2.7.0}/x_transformers/entropy_based_tokenizer.py +0 -0
  58. {x_transformers-2.6.7 → x_transformers-2.7.0}/x_transformers/multi_input.py +0 -0
  59. {x_transformers-2.6.7 → x_transformers-2.7.0}/x_transformers/neo_mlp.py +0 -0
  60. {x_transformers-2.6.7 → x_transformers-2.7.0}/x_transformers/up_wrapper.py +0 -0
  61. {x_transformers-2.6.7 → x_transformers-2.7.0}/x_transformers/x_transformers.py +0 -0
  62. {x_transformers-2.6.7 → x_transformers-2.7.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  63. {x_transformers-2.6.7 → x_transformers-2.7.0}/x_transformers/xval.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: x-transformers
- Version: 2.6.7
+ Version: 2.7.0
  Summary: X-Transformers
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2509,11 +2509,22 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)

  ```bibtex
  @misc{openai_gpt_oss,
- author = {OpenAI},
- title = {Introducing gpt-oss},
- howpublished = {https://openai.com/index/introducing-gpt-oss},
- month = {August},
- year = {2025}
+ author = {OpenAI},
+ title = {Introducing gpt-oss},
+ howpublished = {https://openai.com/index/introducing-gpt-oss},
+ month = {August},
+ year = {2025}
+ }
+ ```
+
+ ```bibtex
+ @article{Sahoo2024SimpleAE,
+ title = {Simple and Effective Masked Diffusion Language Models},
+ author = {Subham Sekhar Sahoo and Marianne Arriola and Yair Schiff and Aaron Gokaslan and Edgar Marroquin and Justin T Chiu and Alexander Rush and Volodymyr Kuleshov},
+ journal = {ArXiv},
+ year = {2024},
+ volume = {abs/2406.07524},
+ url = {https://api.semanticscholar.org/CorpusID:270380319}
  }
  ```

README.md
@@ -2461,11 +2461,22 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)

  ```bibtex
  @misc{openai_gpt_oss,
- author = {OpenAI},
- title = {Introducing gpt-oss},
- howpublished = {https://openai.com/index/introducing-gpt-oss},
- month = {August},
- year = {2025}
+ author = {OpenAI},
+ title = {Introducing gpt-oss},
+ howpublished = {https://openai.com/index/introducing-gpt-oss},
+ month = {August},
+ year = {2025}
+ }
+ ```
+
+ ```bibtex
+ @article{Sahoo2024SimpleAE,
+ title = {Simple and Effective Masked Diffusion Language Models},
+ author = {Subham Sekhar Sahoo and Marianne Arriola and Yair Schiff and Aaron Gokaslan and Edgar Marroquin and Justin T Chiu and Alexander Rush and Volodymyr Kuleshov},
+ journal = {ArXiv},
+ year = {2024},
+ volume = {abs/2406.07524},
+ url = {https://api.semanticscholar.org/CorpusID:270380319}
  }
  ```

pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "x-transformers"
- version = "2.6.7"
+ version = "2.7.0"
  description = "X-Transformers"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_x_transformers.py
@@ -1287,3 +1287,30 @@ def test_accept_layer_intermediates():
      )

      assert embeds.shape == (3, 32, 512)
+
+ @pytest.mark.parametrize('use_loss_weight', (False, True))
+ def test_simple_mdlm(
+     use_loss_weight
+ ):
+     from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper
+
+     model = TransformerWrapper(
+         num_tokens = 256 + 1,
+         max_seq_len = 1024,
+         attn_layers = Encoder(
+             dim = 512,
+             depth = 4,
+             rotary_pos_emb = True
+         )
+     )
+
+     nar = NonAutoregressiveWrapper(
+         model,
+         mask_id = 256,
+         use_simple_mdlm_loss_weight = use_loss_weight
+     )
+
+     seq = torch.randint(0, 256, (1, 1024))
+
+     loss = nar(seq)
+     loss.loss.backward()
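The new test doubles as a usage recipe for the 2.7.0 flag. Below is a minimal training-step sketch assembled from it; the optimizer choice, learning rate, and random stand-in batch are illustrative assumptions and not part of the package diff.

```python
import torch
from x_transformers import TransformerWrapper, Encoder
from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper

# reserve one extra token id (256) to serve as the mask token
model = TransformerWrapper(
    num_tokens = 256 + 1,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 4,
        rotary_pos_emb = True
    )
)

nar = NonAutoregressiveWrapper(
    model,
    mask_id = 256,
    use_simple_mdlm_loss_weight = True  # the flag introduced in 2.7.0
)

optim = torch.optim.Adam(nar.parameters(), lr = 3e-4)

seq = torch.randint(0, 256, (1, 1024))  # stand-in for a real batch of token ids

out = nar(seq)        # returns a Losses namedtuple, as in the test
out.loss.backward()
optim.step()
optim.zero_grad()
```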
x_transformers/nonautoregressive_wrapper.py
@@ -1,16 +1,20 @@
+ from __future__ import annotations
+
  import math
  from random import random
  from contextlib import nullcontext
  from collections import namedtuple

  import torch
+ from torch import nn, pi
+ from torch.nn import Module
+ from torch.func import grad_and_value, vmap
  import torch.nn.functional as F
- from torch import nn

+ import einx
  from einops import rearrange, repeat, pack, unpack

  from x_transformers.x_transformers import TransformerWrapper
- from typing import Optional

  # constants

@@ -75,12 +79,12 @@ def linear_schedule(t):

  def cosine_schedule(t):
      """ https://arxiv.org/abs/2202.04200 """
-     return torch.cos(t * math.pi / 2)
+     return torch.cos(t * pi / 2)

  # self token critic
  # inspired by Nijkamp et al. - https://aclanthology.org/2021.naacl-main.409/

- class SelfCritic(nn.Module):
+ class SelfCritic(Module):
      def __init__(self, net):
          super().__init__()
          self.net = net
@@ -92,7 +96,7 @@ class SelfCritic(nn.Module):
          embed = self.net(x, return_embeddings = True)
          return self.to_logits(embed)

- class NonAutoregressiveWrapper(nn.Module):
+ class NonAutoregressiveWrapper(Module):
      """
      https://arxiv.org/abs/1904.09324
      https://arxiv.org/abs/2202.04200
@@ -110,9 +114,10 @@ class NonAutoregressiveWrapper(nn.Module):
          random_token_prob = 0.1,        # which percentage of tokens to be replaced with random token, done in original MLM paper
          schedule = 'linear',
          can_mask_prev_unmasked = False, # when unmasking, whether it can remask previously unmasked
-         token_critic: Optional[TransformerWrapper] = None,
+         token_critic: TransformerWrapper | None = None,
          self_token_critic = False,
-         critic_loss_weight = 1.
+         critic_loss_weight = 1.,
+         use_simple_mdlm_loss_weight = True # Sahoo et al. https://arxiv.org/abs/2406.07524
      ):
          super().__init__()
          assert not (self_token_critic and exists(token_critic))
@@ -143,6 +148,23 @@ class NonAutoregressiveWrapper(nn.Module):
          else:
              raise ValueError(f'invalid schedule {schedule}')

+         # whether to use the loss weighting proposed in simple diffusion lm paper
+
+         self.loss_weight_fn = None
+
+         if use_simple_mdlm_loss_weight:
+             grad_and_value_schedule_fn = vmap(grad_and_value(self.schedule_fn))
+
+             # eq (10)
+
+             def loss_weight_fn(times):
+                 grad, value = grad_and_value_schedule_fn(times)
+                 return grad / (1. - value)
+
+             self.loss_weight_fn = loss_weight_fn
+
+         # whether to mask previous - in the simple mdlm paper, they chose not to
+
          self.can_mask_prev_unmasked = can_mask_prev_unmasked

          # self conditioning
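The loss weight added here follows eq (10) of Sahoo et al.: the derivative of the noise schedule divided by one minus its value, evaluated per sample via `torch.func`'s `grad_and_value` under `vmap`. The standalone sketch below reproduces that computation for the `cosine_schedule` defined earlier in this file and checks it against the hand-derived derivative of `cos(t * pi / 2)`; the sanity check and the sample times are illustrative, not part of the package.

```python
import math
import torch
from torch.func import grad_and_value, vmap

def cosine_schedule(t):
    # same schedule as defined earlier in nonautoregressive_wrapper.py
    return torch.cos(t * math.pi / 2)

# per-sample (derivative, value) of the schedule, as wired up in __init__ above
grad_and_value_fn = vmap(grad_and_value(cosine_schedule))

def loss_weight_fn(times):
    # eq (10): schedule'(t) / (1 - schedule(t))
    grad, value = grad_and_value_fn(times)
    return grad / (1. - value)

times = torch.rand(8).clamp(min = 0.1)  # keep away from t = 0, where 1 - schedule(t) -> 0
weights = loss_weight_fn(times)

# sanity check against the hand-derived derivative of cos(t * pi / 2)
expected = (-math.pi / 2) * torch.sin(times * math.pi / 2) / (1. - torch.cos(times * math.pi / 2))
assert torch.allclose(weights, expected, atol = 1e-5)
```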
@@ -311,12 +333,27 @@ class NonAutoregressiveWrapper(nn.Module):

          loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss

-         # cross entropy loss
+         # loss

-         loss = loss_fn(
-             logits[mask],
-             orig_seq[mask]
-         )
+         if exists(self.loss_weight_fn):
+             # using simple mdlm loss weighting
+
+             loss = loss_fn(
+                 rearrange(logits, 'b n l -> b l n'),
+                 orig_seq,
+                 reduction = 'none'
+             )
+
+             loss_weights = self.loss_weight_fn(rand_times) # calculate loss weight
+             loss = einx.multiply('b n, b', loss, loss_weights) # apply loss weights
+
+             loss = loss[mask].mean()
+
+         else:
+             loss = loss_fn(
+                 logits[mask],
+                 orig_seq[mask],
+             )

          if not exists(self.token_critic) or only_train_generator:
              return Losses(loss, loss, None)
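In the weighted branch above, the unreduced loss has shape `(batch, seq)` and `loss_weights` has shape `(batch,)`, so `einx.multiply('b n, b', ...)` broadcasts one weight per sequence across its tokens before the masked mean. A small sketch of that broadcast with stand-in tensors, compared against the plain-PyTorch equivalent (the shapes and random tensors here are illustrative):

```python
import torch
import einx

b, n = 2, 8
per_token_loss = torch.rand(b, n)  # unreduced cross entropy, shape (batch, seq)
loss_weights = torch.rand(b)       # one weight per sequence, from loss_weight_fn

# einx broadcasts the per-sequence weight across the token dimension
weighted = einx.multiply('b n, b', per_token_loss, loss_weights)

# equivalent plain-PyTorch broadcast
assert torch.allclose(weighted, per_token_loss * loss_weights[:, None])
```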