x-transformers 2.6.6__py3-none-any.whl → 2.7.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
x_transformers/attend.py CHANGED
@@ -23,7 +23,7 @@ class Intermediates:
     pre_softmax_attn: Tensor | None = None
     post_softmax_attn: Tensor | None = None
     values: Tensor | None = None
-    cached_kv: Tuple[Tensor, Tensor] | None = None
+    cached_kv: tuple[Tensor, Tensor] | None = None
     layer_type: str | None = None
     hybrid_hidden: Tensor | None = None
 
x_transformers/nonautoregressive_wrapper.py CHANGED
@@ -1,16 +1,20 @@
+from __future__ import annotations
+
 import math
 from random import random
 from contextlib import nullcontext
 from collections import namedtuple
 
 import torch
+from torch import nn, pi
+from torch.nn import Module
+from torch.func import grad_and_value, vmap
 import torch.nn.functional as F
-from torch import nn
 
+import einx
 from einops import rearrange, repeat, pack, unpack
 
 from x_transformers.x_transformers import TransformerWrapper
-from typing import Optional
 
 # constants
 
@@ -75,12 +79,12 @@ def linear_schedule(t):
 
 def cosine_schedule(t):
     """ https://arxiv.org/abs/2202.04200 """
-    return torch.cos(t * math.pi / 2)
+    return torch.cos(t * pi / 2)
 
 # self token critic
 # inspired by Nijkamp et al. - https://aclanthology.org/2021.naacl-main.409/
 
-class SelfCritic(nn.Module):
+class SelfCritic(Module):
     def __init__(self, net):
         super().__init__()
         self.net = net
@@ -92,7 +96,7 @@ class SelfCritic(nn.Module):
         embed = self.net(x, return_embeddings = True)
         return self.to_logits(embed)
 
-class NonAutoregressiveWrapper(nn.Module):
+class NonAutoregressiveWrapper(Module):
     """
     https://arxiv.org/abs/1904.09324
     https://arxiv.org/abs/2202.04200
@@ -110,9 +114,10 @@ class NonAutoregressiveWrapper(nn.Module):
         random_token_prob = 0.1, # which percentage of tokens to be replaced with random token, done in original MLM paper
         schedule = 'linear',
         can_mask_prev_unmasked = False, # when unmasking, whether it can remask previously unmasked
-        token_critic: Optional[TransformerWrapper] = None,
+        token_critic: TransformerWrapper | None = None,
         self_token_critic = False,
-        critic_loss_weight = 1.
+        critic_loss_weight = 1.,
+        use_simple_mdlm_loss_weight = True # Sahoo et al. https://arxiv.org/abs/2406.07524
     ):
         super().__init__()
         assert not (self_token_critic and exists(token_critic))
@@ -143,6 +148,23 @@ class NonAutoregressiveWrapper(nn.Module):
         else:
             raise ValueError(f'invalid schedule {schedule}')
 
+        # whether to use the loss weighting proposed in simple diffusion lm paper
+
+        self.loss_weight_fn = None
+
+        if use_simple_mdlm_loss_weight:
+            grad_and_value_schedule_fn = vmap(grad_and_value(self.schedule_fn))
+
+            # eq (10)
+
+            def loss_weight_fn(times):
+                grad, value = grad_and_value_schedule_fn(times)
+                return grad / (1. - value)
+
+            self.loss_weight_fn = loss_weight_fn
+
+        # whether to mask previous - in the simple mdlm paper, they chose not to
+
         self.can_mask_prev_unmasked = can_mask_prev_unmasked
 
         # self conditioning
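
The new `use_simple_mdlm_loss_weight` path derives the per-sample loss weight of eq (10) in Sahoo et al. directly from the masking schedule with `torch.func`. A minimal standalone sketch of that derivation, using the same `cosine_schedule` as above (batch size and printed values are illustrative only):

```python
import torch
from torch import pi
from torch.func import grad_and_value, vmap

def cosine_schedule(t):
    return torch.cos(t * pi / 2)

# differentiate the schedule and evaluate it in one pass, batched over times
grad_and_value_schedule_fn = vmap(grad_and_value(cosine_schedule))

def loss_weight_fn(times):
    # eq (10): schedule'(t) / (1 - schedule(t)), one weight per sample
    grad, value = grad_and_value_schedule_fn(times)
    return grad / (1. - value)

times = torch.rand(4)        # one diffusion time per batch element
print(loss_weight_fn(times)) # tensor of shape (4,)
```
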
@@ -311,12 +333,27 @@ class NonAutoregressiveWrapper(nn.Module):
 
         loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss
 
-        # cross entropy loss
+        # loss
 
-        loss = loss_fn(
-            logits[mask],
-            orig_seq[mask]
-        )
+        if exists(self.loss_weight_fn):
+            # using simple mdlm loss weighting
+
+            loss = loss_fn(
+                rearrange(logits, 'b n l -> b l n'),
+                orig_seq,
+                reduction = 'none'
+            )
+
+            loss_weights = self.loss_weight_fn(rand_times) # calculate loss weight
+            loss = einx.multiply('b n, b', loss, loss_weights) # apply loss weights
+
+            loss = loss[mask].mean()
+
+        else:
+            loss = loss_fn(
+                logits[mask],
+                orig_seq[mask],
+            )
 
         if not exists(self.token_critic) or only_train_generator:
             return Losses(loss, loss, None)
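
The `einx.multiply('b n, b', ...)` call above broadcasts one weight per sample across that sample's per-token losses. A tiny shape-only sketch (random tensors, purely illustrative):

```python
import torch
import einx

per_token_loss = torch.rand(2, 5)  # (batch, seq) unreduced cross entropy
loss_weights = torch.rand(2)       # (batch,) weights from loss_weight_fn

# equivalent to per_token_loss * loss_weights[:, None]
weighted = einx.multiply('b n, b', per_token_loss, loss_weights)
assert weighted.shape == (2, 5)
```
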
x_transformers/x_transformers.py CHANGED
@@ -10,7 +10,7 @@ import torch
 from torch.amp import autocast
 import torch.nn.functional as F
 from torch import nn, einsum, tensor, Tensor, cat, stack, arange, is_tensor
-from torch.utils._pytree import tree_flatten, tree_unflatten
+from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
 from torch.nn import Module, ModuleList, ModuleDict
 
 from functools import partial, wraps
@@ -81,6 +81,9 @@ def cast_tuple(val, depth = 1):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def detach_all(obj):
+    return tree_map(lambda t: t.detach() if is_tensor(t) and t.requires_grad else t, obj)
+
 def maybe(fn = None):
     if not exists(fn):
         fn = identity
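
`detach_all` walks an arbitrary pytree and detaches any tensor that still carries gradient history, which is what later lets additional key / values be consumed without backpropagating into the network that produced them. A small self-contained sketch (the nested structure is made up for illustration):

```python
import torch
from torch import is_tensor
from torch.utils._pytree import tree_map

def detach_all(obj):
    # detach tensors that require grad, pass everything else through unchanged
    return tree_map(lambda t: t.detach() if is_tensor(t) and t.requires_grad else t, obj)

cache = [(torch.randn(2, 4, requires_grad = True), torch.randn(2, 4))]
detached = detach_all(cache)

assert not detached[0][0].requires_grad
assert not detached[0][1].requires_grad
```
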
@@ -157,6 +160,19 @@ def or_reduce(masks):
         head = head | rest
     return head
 
+# cache helpers
+
+def get_cached_kvs(
+    cache: LayerIntermediates
+) -> list[tuple[Tensor, Tensor]]:
+
+    cached_kvs = []
+
+    for attn_intermediate in cache.attn_intermediates:
+        cached_kvs.append(attn_intermediate.cached_kv)
+
+    return cached_kvs
+
 # entropy
 
 def calc_entropy(
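
`get_cached_kvs` simply collects each attention layer's cached `(k, v)` pair out of a `LayerIntermediates`. A hedged sketch, constructing the intermediates by hand purely for illustration (it assumes the remaining dataclass fields all have defaults):

```python
import torch
from x_transformers.attend import Intermediates
from x_transformers.x_transformers import LayerIntermediates, get_cached_kvs

k = torch.randn(1, 8, 16, 64)  # hypothetical cached keys for one layer
v = torch.randn(1, 8, 16, 64)  # hypothetical cached values for one layer

cache = LayerIntermediates(attn_intermediates = [Intermediates(cached_kv = (k, v))])

cached_kvs = get_cached_kvs(cache)
assert len(cached_kvs) == 1 and cached_kvs[0][0] is k
```
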
@@ -2441,8 +2457,13 @@ class AttentionLayers(Module):
         context_pos = None,
         attn_bias = None,
         deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
-        self_attn_additional_kv: list[tuple[Tensor, Tensor]] | None = None,
+        self_attn_additional_kv: (
+            LayerIntermediates |
+            list[tuple[Tensor, Tensor]]
+            | None
+        ) = None,
         additional_kv_mask = None,
+        detach_additional_kv = False,
         route_additional_kv_to_top = True,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
@@ -2590,6 +2611,13 @@ class AttentionLayers(Module):
         # additional self attn key / values - say coming from vlm
 
         if exists(self_attn_additional_kv) and route_additional_kv_to_top:
+
+            if isinstance(self_attn_additional_kv, LayerIntermediates):
+                self_attn_additional_kv = get_cached_kvs(self_attn_additional_kv)
+
+            if detach_additional_kv:
+                self_attn_additional_kv = detach_all(self_attn_additional_kv)
+
             num_self_attns = sum([layer_type == 'a' for layer_type in first(layer_variables)])
 
             self_attn_additional_kv = self_attn_additional_kv[-num_self_attns:]
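
Taken together, `AttentionLayers.forward` can now be handed either an explicit list of `(k, v)` pairs or a whole `LayerIntermediates` for `self_attn_additional_kv`, and `detach_additional_kv = True` detaches those tensors before attending. A hedged usage sketch; the `Decoder` hyperparameters and the `(batch, heads, seq, dim_head)` shape of the extra key / values are assumptions for illustration, not taken from the diff:

```python
import torch
from x_transformers import Decoder

decoder = Decoder(dim = 512, depth = 4, heads = 8)  # attn_dim_head assumed to default to 64

x = torch.randn(1, 1024, 512)

# additional key / values, e.g. produced by another network (the "vlm" use case above),
# one (k, v) pair per self-attention layer - shape assumed (batch, heads, seq, dim_head)
additional_kv = [
    (torch.randn(1, 8, 256, 64), torch.randn(1, 8, 256, 64))
    for _ in range(4)
]

out = decoder(
    x,
    self_attn_additional_kv = additional_kv,
    detach_additional_kv = True  # new flag: detach the extra kv before attending
)
```
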
x_transformers-2.6.6.dist-info/METADATA → x_transformers-2.7.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.6.6
+Version: 2.7.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2509,11 +2509,22 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 
 ```bibtex
 @misc{openai_gpt_oss,
-    author = {OpenAI},
-    title = {Introducing gpt-oss},
-    howpublished = {https://openai.com/index/introducing-gpt-oss},
-    month = {August},
-    year = {2025}
+    author = {OpenAI},
+    title = {Introducing gpt-oss},
+    howpublished = {https://openai.com/index/introducing-gpt-oss},
+    month = {August},
+    year = {2025}
+}
+```
+
+```bibtex
+@article{Sahoo2024SimpleAE,
+    title = {Simple and Effective Masked Diffusion Language Models},
+    author = {Subham Sekhar Sahoo and Marianne Arriola and Yair Schiff and Aaron Gokaslan and Edgar Marroquin and Justin T Chiu and Alexander Rush and Volodymyr Kuleshov},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2406.07524},
+    url = {https://api.semanticscholar.org/CorpusID:270380319}
 }
 ```
 
x_transformers-2.6.6.dist-info/RECORD → x_transformers-2.7.0.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
 x_transformers/__init__.py,sha256=aVuhUU0572TJHW88BVc4yA2tla0Zb8l3NH7W4RZ1AEs,1005
-x_transformers/attend.py,sha256=JJv6ypJbZIFmH1LQ49hFg6hD0Wf9Z7Im1AP2ekm9hVI,18091
+x_transformers/attend.py,sha256=jzOwrtCIdAt1dRQBO68htDsgtjdTx6TAQQVB2xflS1w,18091
 x_transformers/autoregressive_wrapper.py,sha256=BsGO9xfVYkvynqbU1__tu_S_cxl7gss0YwnkhIa2baY,18401
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=hpb1sSbt3k2LNzzjrjSd8F5xOIbKj7IluV9MBEAFLkw,13031
@@ -7,12 +7,12 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
-x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
+x_transformers/nonautoregressive_wrapper.py,sha256=hMQqNimGtchNIe13cR5LZule1V7I1qM5LmY8VQfVdnA,11698
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=vjRMEMA12Js94YwLVeZksYMEoRgK6CSKT6TJViMPp7U,122186
+x_transformers/x_transformers.py,sha256=txdFN5266Tu-lQVMgyICMWt8azslAkxG5YL4n9tOUIo,122944
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.6.6.dist-info/METADATA,sha256=95CKrJ98X7R0hpb5D8GHSfi372UtxXDSeDaO2qB0Lrs,90445
-x_transformers-2.6.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.6.6.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.6.6.dist-info/RECORD,,
+x_transformers-2.7.0.dist-info/METADATA,sha256=HFH-y2lnS8T-KZkv27z7hBGECYipDSkgtXj9LJbLMHo,90888
+x_transformers-2.7.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.7.0.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.7.0.dist-info/RECORD,,