x-transformers 2.1.32__tar.gz → 2.1.35__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {x_transformers-2.1.32 → x_transformers-2.1.35}/PKG-INFO +1 -1
  2. {x_transformers-2.1.32 → x_transformers-2.1.35}/pyproject.toml +1 -1
  3. {x_transformers-2.1.32 → x_transformers-2.1.35}/tests/test_x_transformers.py +1 -0
  4. {x_transformers-2.1.32 → x_transformers-2.1.35}/train_belief_state.py +3 -5
  5. {x_transformers-2.1.32 → x_transformers-2.1.35}/x_transformers/belief_state_wrapper.py +51 -21
  6. {x_transformers-2.1.32 → x_transformers-2.1.35}/x_transformers/x_transformers.py +1 -1
  7. {x_transformers-2.1.32 → x_transformers-2.1.35}/.github/FUNDING.yml +0 -0
  8. {x_transformers-2.1.32 → x_transformers-2.1.35}/.github/workflows/python-publish.yml +0 -0
  9. {x_transformers-2.1.32 → x_transformers-2.1.35}/.github/workflows/python-test.yaml +0 -0
  10. {x_transformers-2.1.32 → x_transformers-2.1.35}/.gitignore +0 -0
  11. {x_transformers-2.1.32 → x_transformers-2.1.35}/LICENSE +0 -0
  12. {x_transformers-2.1.32 → x_transformers-2.1.35}/README.md +0 -0
  13. {x_transformers-2.1.32 → x_transformers-2.1.35}/data/README.md +0 -0
  14. {x_transformers-2.1.32 → x_transformers-2.1.35}/data/enwik8.gz +0 -0
  15. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/all-attention.png +0 -0
  16. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/attention-on-attention.png +0 -0
  17. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/cosine-sim-attention.png +0 -0
  18. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/deepnorm.png +0 -0
  19. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/dynamic-pos-bias-linear.png +0 -0
  20. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/dynamic-pos-bias-log.png +0 -0
  21. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  22. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/dynamic-pos-bias.png +0 -0
  23. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/enhanced-recurrence.png +0 -0
  24. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/fcm.png +0 -0
  25. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/ffglu.png +0 -0
  26. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/flash-attention.png +0 -0
  27. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/gate_values.png +0 -0
  28. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/gating.png +0 -0
  29. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/length-extrapolation-scale.png +0 -0
  30. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/macaron-1.png +0 -0
  31. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/macaron-2.png +0 -0
  32. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/memory-transformer.png +0 -0
  33. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/normformer.png +0 -0
  34. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/pia.png +0 -0
  35. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/qknorm-analysis.png +0 -0
  36. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/resi_dual.png +0 -0
  37. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/residual_attn.png +0 -0
  38. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/rezero.png +0 -0
  39. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/rotary.png +0 -0
  40. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/sandwich-2.png +0 -0
  41. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/sandwich.png +0 -0
  42. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/sandwich_norm.png +0 -0
  43. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/scalenorm.png +0 -0
  44. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/talking-heads.png +0 -0
  45. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/topk-attention.png +0 -0
  46. {x_transformers-2.1.32 → x_transformers-2.1.35}/images/xval.png +0 -0
  47. {x_transformers-2.1.32 → x_transformers-2.1.35}/train_copy.py +0 -0
  48. {x_transformers-2.1.32 → x_transformers-2.1.35}/train_enwik8.py +0 -0
  49. {x_transformers-2.1.32 → x_transformers-2.1.35}/train_length_extrapolate.py +0 -0
  50. {x_transformers-2.1.32 → x_transformers-2.1.35}/train_parity.py +0 -0
  51. {x_transformers-2.1.32 → x_transformers-2.1.35}/x_transformers/__init__.py +0 -0
  52. {x_transformers-2.1.32 → x_transformers-2.1.35}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.1.32 → x_transformers-2.1.35}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.1.32 → x_transformers-2.1.35}/x_transformers/continuous.py +0 -0
  55. {x_transformers-2.1.32 → x_transformers-2.1.35}/x_transformers/dpo.py +0 -0
  56. {x_transformers-2.1.32 → x_transformers-2.1.35}/x_transformers/multi_input.py +0 -0
  57. {x_transformers-2.1.32 → x_transformers-2.1.35}/x_transformers/neo_mlp.py +0 -0
  58. {x_transformers-2.1.32 → x_transformers-2.1.35}/x_transformers/nonautoregressive_wrapper.py +0 -0
  59. {x_transformers-2.1.32 → x_transformers-2.1.35}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  60. {x_transformers-2.1.32 → x_transformers-2.1.35}/x_transformers/xval.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.32
+Version: 2.1.35
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.1.32"
+version = "2.1.35"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_x_transformers.py
@@ -743,6 +743,7 @@ def test_belief_state_wrapper(
     lens = torch.randint(4, 16, (2,))

     loss = model(seq, lens = lens) # backwards happen automatically
+    loss.backward()

     suffix = None
     if goal_suffix:
train_belief_state.py
@@ -106,10 +106,8 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
     model.train()

     for __ in range(GRADIENT_ACCUMULATE_EVERY):
-        loss = model(
-            next(train_loader),
-            loss_scale = 1./ GRADIENT_ACCUMULATE_EVERY
-        )
+        loss = model(next(train_loader))
+        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

     print(f'training loss: {loss.item()}')
     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
@@ -119,7 +117,7 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
     if i % VALIDATE_EVERY == 0:
         model.eval()
         with torch.no_grad():
-            loss = model(next(val_loader), return_loss_only = True)
+            loss = model(next(val_loader))
             print(f'validation loss: {loss.item()}')

     if i % GENERATE_EVERY == 0:
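The train_belief_state.py change above moves loss scaling and backpropagation out of the wrapper and into the script, which is ordinary PyTorch gradient accumulation. A minimal sketch of one full optimization step under that pattern; `model`, `train_loader` and `optimizer` are stand-ins, and the constant name follows the script:

import torch

GRADIENT_ACCUMULATE_EVERY = 4

def train_step(model, train_loader, optimizer):
    model.train()

    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))                      # wrapper returns the loss
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()         # caller scales and backprops

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)   # clipping as in the script
    optimizer.step()
    optimizer.zero_grad()
    return loss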
x_transformers/belief_state_wrapper.py
@@ -5,6 +5,7 @@
 # https://www.youtube.com/watch?v=aqhbRtB2Fyg

 from __future__ import annotations
+from random import random

 import torch
 from torch.autograd import Function
@@ -53,6 +54,26 @@ def flip(x, dim = 1, lens = None):

     return x.gather(dim, flip_indices)

+# detach multiple tensors and backward the gradients once
+
+class DetachMultiple(Function):
+
+    @classmethod
+    def forward(self, ctx, *tensors):
+        detached_tensors = tuple(t.detach() for t in tensors)
+
+        for detached_tensor in detached_tensors:
+            detached_tensor.requires_grad_()
+
+        return detached_tensors
+
+    @classmethod
+    def backward(self, ctx, *grads):
+
+        return grads
+
+detach_multiple = DetachMultiple.apply
+
 # wrappers

 class BeliefStateWrapper(Module):
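The DetachMultiple function added above is a straight-through autograd node: forward hands back detached, grad-requiring copies, and backward returns the incoming gradients unchanged, so a single external loss.backward() still reaches the graph that produced the inputs. A self-contained sketch of the same idea, written in the conventional @staticmethod form; PassThroughDetach and the toy tensors are illustrative and not part of the package:

import torch
from torch.autograd import Function

class PassThroughDetach(Function):
    @staticmethod
    def forward(ctx, *tensors):
        # detach inside forward, but keep the copies trainable
        detached = tuple(t.detach() for t in tensors)
        for t in detached:
            t.requires_grad_()
        return detached

    @staticmethod
    def backward(ctx, *grads):
        # pass the output gradients straight through to the inputs
        return grads

a = torch.randn(3, requires_grad = True)
b = torch.randn(3, requires_grad = True)

x, y = PassThroughDetach.apply(a * 2., b * 3.)
loss = (x * y).sum()
loss.backward()

print(a.grad is not None, b.grad is not None)   # True True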
@@ -69,6 +90,8 @@ class BeliefStateWrapper(Module):
         backward_ar_loss_weight: float = 1., # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
         pred_distance = False,
         pred_distance_loss_weight: float = 1.,
+        cond_on_distance = False,
+        cond_on_distance_prob = 0.5,
         max_pred_distance = None
     ):
         super().__init__()
@@ -111,6 +134,21 @@ class BeliefStateWrapper(Module):

         self.pred_distance_loss_weight = pred_distance_loss_weight

+        # conditioning on distance
+
+        assert 0. < cond_on_distance_prob < 1.
+
+        self.cond_on_distance = cond_on_distance
+        self.cond_on_distance_prob = cond_on_distance_prob
+
+        if cond_on_distance:
+            self.to_distance_cond = nn.Sequential(
+                Rearrange('... -> ... 1'),
+                nn.Linear(1, dim),
+                nn.LeakyReLU(),
+                nn.Linear(dim, dim * 2),
+            )
+
         # the two decoders, one which is causal forward, the other causal backwards

         self.forward_decoder = forward_decoder
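The to_distance_cond module introduced above maps a scalar forward/backward distance to a conditioning vector of twice the model width, matching the joined forward/backward embeddings it later scales. A standalone shape check, with an assumed dim = 8 and illustrative batch and pair counts:

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim = 8   # assumed model dimension, for illustration only

to_distance_cond = nn.Sequential(
    Rearrange('... -> ... 1'),     # scalar distance -> trailing feature axis of size 1
    nn.Linear(1, dim),
    nn.LeakyReLU(),
    nn.Linear(dim, dim * 2),       # doubled width, same as the concatenated fwd/bwd embeds
)

distance = torch.randint(0, 16, (4, 6)).float()   # hypothetical (batch, num fwd/bwd pairs)
cond = to_distance_cond(distance)

print(cond.shape)   # torch.Size([4, 6, 16])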
@@ -250,8 +288,6 @@ class BeliefStateWrapper(Module):
         self,
         seq,
         lens: Tensor | None = None, # Int['b']
-        return_loss_only = False,
-        loss_scale = 1.,
         loss_weight_by_fb_indices: callable | None = None
     ):
         batch, seq_len, device = *seq.shape, seq.device
@@ -284,11 +320,7 @@

         # trick to reduce memory on backwards pass

-        orig_forward_embeds, forward_embeds = forward_embeds, forward_embeds.detach()
-        orig_backward_embeds, backward_embeds = backward_embeds, backward_embeds.detach()
-
-        forward_embeds.requires_grad_()
-        backward_embeds.requires_grad_()
+        forward_embeds, backward_embeds = detach_multiple(forward_embeds, backward_embeds)

         # belief state objective

@@ -344,9 +376,19 @@
             ignore_index = -1
         )

-        # maybe predict terminal
+        # maybe condition on distance
+
+        cond_on_distance = self.cond_on_distance and (random() < self.cond_on_distance_prob)
+
+        if cond_on_distance:
+            distance = (bi - fi).float()
+            distance_cond = self.to_distance_cond(distance)
+
+            fb_embeds = fb_embeds * distance_cond

-        if exists(self.to_distance_logits):
+        # maybe predict distance
+
+        if exists(self.to_distance_logits) and not cond_on_distance:
             distance_logits = self.to_distance_logits(fb_embeds)

             distance_labels = (bi - fi).clamp(max = self.max_pred_distance - 1)
@@ -362,11 +404,6 @@
                 pred_dist_loss * self.pred_distance_loss_weight
             )

-        # maybe early return loss
-
-        if return_loss_only:
-            return loss
-
         # maybe loss weighting

         needs_loss_weight = default(self.needs_loss_weight, exists(loss_weight_by_fb_indices))
@@ -392,11 +429,4 @@

         loss = loss.mean()

-        # backwards
-
-        (loss * loss_scale).backward()
-
-        orig_forward_embeds.backward(forward_embeds.grad)
-        orig_backward_embeds.backward(backward_embeds.grad)
-
         return loss
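With the block above removed, forward() no longer calls backward() internally or accepts loss_scale; callers get the loss back and backpropagate it themselves, as the updated test and training script do. A hedged usage sketch, assuming the wrapper can still be built from a single forward TransformerWrapper as in earlier 2.1.x versions; the decoder hyperparameters are illustrative:

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

forward_model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 4, heads = 8)
)

model = BeliefStateWrapper(
    forward_decoder = forward_model,
    cond_on_distance = True,        # new option from this release
    cond_on_distance_prob = 0.5
)

seq = torch.randint(0, 256, (2, 128))

loss = model(seq)    # the loss is returned rather than backpropagated internally
loss.backward()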
x_transformers/x_transformers.py
@@ -2055,7 +2055,7 @@ class AttentionLayers(Module):
         elif use_simple_rmsnorm:
             norm_class = SimpleRMSNorm
         elif use_dynamic_tanh:
-            assert pre_norm, 'only tested for pre-norm'
+            assert pre_norm, 'dynamic tanh norm only tested for pre-norm'
             norm_class = partial(DynamicTanh, init_alpha = dynamic_tanh_init_alpha)
         elif use_adaptive_layernorm:
             norm_need_condition = True
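The reworded assertion above concerns the DynamicTanh normalization option, which is only wired up for pre-norm blocks. A hedged sketch of turning it on, assuming the use_dynamic_tanh and dynamic_tanh_init_alpha keyword arguments shown in the diff are accepted by Decoder and that pre_norm keeps its default of True; the remaining hyperparameters are illustrative:

from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 4,
        heads = 8,
        use_dynamic_tanh = True,         # selects the DynamicTanh norm class
        dynamic_tanh_init_alpha = 1.5    # illustrative initial alpha
        # passing pre_norm = False here would trip the assertion above
    )
)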