x-transformers 2.1.34__tar.gz → 2.1.35__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {x_transformers-2.1.34 → x_transformers-2.1.35}/PKG-INFO +1 -1
  2. {x_transformers-2.1.34 → x_transformers-2.1.35}/pyproject.toml +1 -1
  3. {x_transformers-2.1.34 → x_transformers-2.1.35}/tests/test_x_transformers.py +1 -0
  4. {x_transformers-2.1.34 → x_transformers-2.1.35}/train_belief_state.py +3 -5
  5. {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/belief_state_wrapper.py +21 -19
  6. {x_transformers-2.1.34 → x_transformers-2.1.35}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.1.34 → x_transformers-2.1.35}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.1.34 → x_transformers-2.1.35}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.1.34 → x_transformers-2.1.35}/.gitignore +0 -0
  10. {x_transformers-2.1.34 → x_transformers-2.1.35}/LICENSE +0 -0
  11. {x_transformers-2.1.34 → x_transformers-2.1.35}/README.md +0 -0
  12. {x_transformers-2.1.34 → x_transformers-2.1.35}/data/README.md +0 -0
  13. {x_transformers-2.1.34 → x_transformers-2.1.35}/data/enwik8.gz +0 -0
  14. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/all-attention.png +0 -0
  15. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/deepnorm.png +0 -0
  18. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/fcm.png +0 -0
  24. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/ffglu.png +0 -0
  25. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/flash-attention.png +0 -0
  26. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/gate_values.png +0 -0
  27. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/gating.png +0 -0
  28. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/macaron-1.png +0 -0
  30. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/macaron-2.png +0 -0
  31. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/normformer.png +0 -0
  33. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/pia.png +0 -0
  34. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/resi_dual.png +0 -0
  36. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/residual_attn.png +0 -0
  37. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/rezero.png +0 -0
  38. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/rotary.png +0 -0
  39. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/sandwich.png +0 -0
  41. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/scalenorm.png +0 -0
  43. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/talking-heads.png +0 -0
  44. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/topk-attention.png +0 -0
  45. {x_transformers-2.1.34 → x_transformers-2.1.35}/images/xval.png +0 -0
  46. {x_transformers-2.1.34 → x_transformers-2.1.35}/train_copy.py +0 -0
  47. {x_transformers-2.1.34 → x_transformers-2.1.35}/train_enwik8.py +0 -0
  48. {x_transformers-2.1.34 → x_transformers-2.1.35}/train_length_extrapolate.py +0 -0
  49. {x_transformers-2.1.34 → x_transformers-2.1.35}/train_parity.py +0 -0
  50. {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/__init__.py +0 -0
  51. {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/attend.py +0 -0
  52. {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/autoregressive_wrapper.py +0 -0
  53. {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/continuous.py +0 -0
  54. {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/dpo.py +0 -0
  55. {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/multi_input.py +0 -0
  56. {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/neo_mlp.py +0 -0
  57. {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/nonautoregressive_wrapper.py +0 -0
  58. {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/x_transformers.py +0 -0
  59. {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  60. {x_transformers-2.1.34 → x_transformers-2.1.35}/x_transformers/xval.py +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.34
+Version: 2.1.35
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.1.34"
+version = "2.1.35"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_x_transformers.py

@@ -743,6 +743,7 @@ def test_belief_state_wrapper(
     lens = torch.randint(4, 16, (2,))

     loss = model(seq, lens = lens) # backwards happen automatically
+    loss.backward()

     suffix = None
     if goal_suffix:
train_belief_state.py

@@ -106,10 +106,8 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
     model.train()

     for __ in range(GRADIENT_ACCUMULATE_EVERY):
-        loss = model(
-            next(train_loader),
-            loss_scale = 1./ GRADIENT_ACCUMULATE_EVERY
-        )
+        loss = model(next(train_loader))
+        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

     print(f'training loss: {loss.item()}')
     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
@@ -119,7 +117,7 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
     if i % VALIDATE_EVERY == 0:
         model.eval()
         with torch.no_grad():
-            loss = model(next(val_loader), return_loss_only = True)
+            loss = model(next(val_loader))
             print(f'validation loss: {loss.item()}')

     if i % GENERATE_EVERY == 0:
x_transformers/belief_state_wrapper.py

@@ -54,6 +54,26 @@ def flip(x, dim = 1, lens = None):

     return x.gather(dim, flip_indices)

+# detach multiple tensors and backward the gradients once
+
+class DetachMultiple(Function):
+
+    @classmethod
+    def forward(self, ctx, *tensors):
+        detached_tensors = tuple(t.detach() for t in tensors)
+
+        for detached_tensor in detached_tensors:
+            detached_tensor.requires_grad_()
+
+        return detached_tensors
+
+    @classmethod
+    def backward(self, ctx, *grads):
+
+        return grads
+
+detach_multiple = DetachMultiple.apply
+
 # wrappers

 class BeliefStateWrapper(Module):
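The DetachMultiple helper added above is a straight-through autograd Function: forward hands back detached copies of its inputs, so the belief-state objective builds its graph on fresh tensors, while backward returns the incoming gradients unchanged, letting a single backward call reach the original embedding tensors. That is what allows the later hunks to drop the manual orig_forward_embeds / orig_backward_embeds backward bookkeeping. A self-contained sketch of the same pattern, written here in the more common staticmethod style (the diff uses classmethod, which works the same way in practice):

    import torch
    from torch.autograd import Function

    class StraightThroughDetach(Function):
        @staticmethod
        def forward(ctx, *tensors):
            # return detached leaves; downstream ops build their graph from here
            detached = tuple(t.detach() for t in tensors)
            for t in detached:
                t.requires_grad_()
            return detached

        @staticmethod
        def backward(ctx, *grads):
            # pass the output gradients straight through to the original inputs
            return grads

    a = torch.randn(3, requires_grad = True)
    b = torch.randn(3, requires_grad = True)

    da, db = StraightThroughDetach.apply(a, b)
    (da * db).sum().backward()

    print(a.grad is not None, b.grad is not None)  # True True: one backward reached the originals

Because the detach point now participates in autograd, the wrapper's forward can simply return the loss and rely on the caller's loss.backward(), which is what the remaining hunks do.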
@@ -268,8 +288,6 @@ class BeliefStateWrapper(Module):
         self,
         seq,
         lens: Tensor | None = None, # Int['b']
-        return_loss_only = False,
-        loss_scale = 1.,
         loss_weight_by_fb_indices: callable | None = None
     ):
         batch, seq_len, device = *seq.shape, seq.device
@@ -302,11 +320,7 @@ class BeliefStateWrapper(Module):

         # trick to reduce memory on backwards pass

-        orig_forward_embeds, forward_embeds = forward_embeds, forward_embeds.detach()
-        orig_backward_embeds, backward_embeds = backward_embeds, backward_embeds.detach()
-
-        forward_embeds.requires_grad_()
-        backward_embeds.requires_grad_()
+        forward_embeds, backward_embeds = detach_multiple(forward_embeds, backward_embeds)

         # belief state objective

@@ -390,11 +404,6 @@ class BeliefStateWrapper(Module):
             pred_dist_loss * self.pred_distance_loss_weight
         )

-        # maybe early return loss
-
-        if return_loss_only:
-            return loss
-
         # maybe loss weighting

         needs_loss_weight = default(self.needs_loss_weight, exists(loss_weight_by_fb_indices))
@@ -420,11 +429,4 @@ class BeliefStateWrapper(Module):

         loss = loss.mean()

-        # backwards
-
-        (loss * loss_scale).backward()
-
-        orig_forward_embeds.backward(forward_embeds.grad)
-        orig_backward_embeds.backward(backward_embeds.grad)
-
         return loss
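Taken together, these hunks change the wrapper's calling convention: forward no longer accepts return_loss_only or loss_scale and no longer calls backward itself; it simply returns a differentiable loss, which the updated test exercises with an explicit loss.backward(). The resulting caller-side pattern, sketched with model, seq, lens and optim standing in for whatever the caller defines (the optimizer step is the usual PyTorch pattern, not something shown in this diff):

    loss = model(seq, lens = lens)   # 2.1.35: no implicit backward inside forward
    loss.backward()                  # gradients flow through detach_multiple into the forward and backward embeddings

    optim.step()
    optim.zero_grad()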