x-transformers 2.1.5.tar.gz → 2.1.6.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. {x_transformers-2.1.5 → x_transformers-2.1.6}/PKG-INFO +1 -1
  2. {x_transformers-2.1.5 → x_transformers-2.1.6}/pyproject.toml +1 -1
  3. {x_transformers-2.1.5 → x_transformers-2.1.6}/tests/test_x_transformers.py +5 -3
  4. {x_transformers-2.1.5 → x_transformers-2.1.6}/x_transformers/__init__.py +1 -0
  5. x_transformers-2.1.5/x_transformers/belief_state.py → x_transformers-2.1.6/x_transformers/belief_state_wrapper.py +33 -7
  6. {x_transformers-2.1.5 → x_transformers-2.1.6}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.1.5 → x_transformers-2.1.6}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.1.5 → x_transformers-2.1.6}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.1.5 → x_transformers-2.1.6}/.gitignore +0 -0
  10. {x_transformers-2.1.5 → x_transformers-2.1.6}/LICENSE +0 -0
  11. {x_transformers-2.1.5 → x_transformers-2.1.6}/README.md +0 -0
  12. {x_transformers-2.1.5 → x_transformers-2.1.6}/data/README.md +0 -0
  13. {x_transformers-2.1.5 → x_transformers-2.1.6}/data/enwik8.gz +0 -0
  14. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/all-attention.png +0 -0
  15. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/deepnorm.png +0 -0
  18. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/fcm.png +0 -0
  24. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/ffglu.png +0 -0
  25. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/flash-attention.png +0 -0
  26. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/gate_values.png +0 -0
  27. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/gating.png +0 -0
  28. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/macaron-1.png +0 -0
  30. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/macaron-2.png +0 -0
  31. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/normformer.png +0 -0
  33. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/pia.png +0 -0
  34. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/resi_dual.png +0 -0
  36. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/residual_attn.png +0 -0
  37. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/rezero.png +0 -0
  38. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/rotary.png +0 -0
  39. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/sandwich.png +0 -0
  41. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/scalenorm.png +0 -0
  43. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/talking-heads.png +0 -0
  44. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/topk-attention.png +0 -0
  45. {x_transformers-2.1.5 → x_transformers-2.1.6}/images/xval.png +0 -0
  46. {x_transformers-2.1.5 → x_transformers-2.1.6}/train_copy.py +0 -0
  47. {x_transformers-2.1.5 → x_transformers-2.1.6}/train_enwik8.py +0 -0
  48. {x_transformers-2.1.5 → x_transformers-2.1.6}/train_length_extrapolate.py +0 -0
  49. {x_transformers-2.1.5 → x_transformers-2.1.6}/train_parity.py +0 -0
  50. {x_transformers-2.1.5 → x_transformers-2.1.6}/x_transformers/attend.py +0 -0
  51. {x_transformers-2.1.5 → x_transformers-2.1.6}/x_transformers/autoregressive_wrapper.py +0 -0
  52. {x_transformers-2.1.5 → x_transformers-2.1.6}/x_transformers/continuous.py +0 -0
  53. {x_transformers-2.1.5 → x_transformers-2.1.6}/x_transformers/dpo.py +0 -0
  54. {x_transformers-2.1.5 → x_transformers-2.1.6}/x_transformers/multi_input.py +0 -0
  55. {x_transformers-2.1.5 → x_transformers-2.1.6}/x_transformers/neo_mlp.py +0 -0
  56. {x_transformers-2.1.5 → x_transformers-2.1.6}/x_transformers/nonautoregressive_wrapper.py +0 -0
  57. {x_transformers-2.1.5 → x_transformers-2.1.6}/x_transformers/x_transformers.py +0 -0
  58. {x_transformers-2.1.5 → x_transformers-2.1.6}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  59. {x_transformers-2.1.5 → x_transformers-2.1.6}/x_transformers/xval.py +0 -0
--- x_transformers-2.1.5/PKG-INFO
+++ x_transformers-2.1.6/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.5
+Version: 2.1.6
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
--- x_transformers-2.1.5/pyproject.toml
+++ x_transformers-2.1.6/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.1.5"
+version = "2.1.6"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
--- x_transformers-2.1.5/tests/test_x_transformers.py
+++ x_transformers-2.1.6/tests/test_x_transformers.py
@@ -695,7 +695,7 @@ def test_lime(
     model(x)
 
 def test_belief_state_wrapper():
-    from x_transformers.belief_state import BeliefStateWrapper
+    from x_transformers.belief_state_wrapper import BeliefStateWrapper
 
     forward_model = TransformerWrapper(
         num_tokens = 20000,
@@ -721,9 +721,11 @@ def test_belief_state_wrapper():
 
     model = BeliefStateWrapper(
         forward_decoder = forward_model,
-        backward_decoder = backward_model
+        backward_decoder = backward_model,
+        backward_ar_loss_weight = 0.5
     )
 
     seq = torch.randint(0, 20000, (2, 16))
 
-    loss = model(seq)
+    loss = model(seq, backward = False)
+    loss.backward()
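For orientation, here is a minimal usage sketch of the API this test now exercises. The max_seq_len, dim, depth and heads values below are illustrative placeholders rather than values taken from this diff; only the BeliefStateWrapper keywords and the deferred backward call come from the changes above.

    import torch
    from x_transformers import TransformerWrapper, Decoder
    from x_transformers.belief_state_wrapper import BeliefStateWrapper

    # the wrapper asserts that both decoders share the same embedding dimension
    forward_model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(dim = 512, depth = 2, heads = 4)
    )

    backward_model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(dim = 512, depth = 2, heads = 4)
    )

    model = BeliefStateWrapper(
        forward_decoder = forward_model,
        backward_decoder = backward_model,
        backward_ar_loss_weight = 0.5       # new in 2.1.6: down-weight the backward decoder's loss
    )

    seq = torch.randint(0, 20000, (2, 16))

    loss = model(seq, backward = False)     # new in 2.1.6: defer the backward pass to the caller
    loss.backward()

With backward = True (the default), the wrapper still runs the backward pass internally, as it did in 2.1.5.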
--- x_transformers-2.1.5/x_transformers/__init__.py
+++ x_transformers-2.1.6/x_transformers/__init__.py
@@ -14,6 +14,7 @@ from x_transformers.x_transformers import (
 
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
 from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper
+from x_transformers.belief_state_wrapper import BeliefStateWrapper
 
 from x_transformers.continuous import (
     ContinuousTransformerWrapper,
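With this re-export, the wrapper should also be importable from the package root rather than only from its submodule:

    from x_transformers import BeliefStateWrapper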
--- x_transformers-2.1.5/x_transformers/belief_state.py
+++ x_transformers-2.1.6/x_transformers/belief_state_wrapper.py
@@ -6,7 +6,7 @@
 import torch
 from torch.autograd import Function
 from torch.nn import Module, ModuleList
-from torch import nn, cat, stack, arange, cartesian_prod
+from torch import nn, cat, stack, tensor, arange, cartesian_prod
 import torch.nn.functional as F
 
 from x_transformers.x_transformers import (
@@ -36,7 +36,8 @@ class BeliefStateWrapper(Module):
         self,
         forward_decoder: TransformerWrapper,
         backward_decoder: TransformerWrapper,
-        train_frac_forward_backward_pairs: float = 1.
+        train_frac_forward_backward_pairs: float = 1.,
+        backward_ar_loss_weight: float = 1. # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
     ):
         super().__init__()
         assert forward_decoder.emb_dim == backward_decoder.emb_dim, 'forward and backwards model must have the same embedding dimension'
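The new backward_ar_loss_weight is later turned into a two-element loss_weights buffer and applied to an unreduced cross-entropy (see the hunks that follow). In isolation, that weighting pattern looks roughly like the sketch below; the shapes and variable names here are toy stand-ins, not the wrapper's actual tensors.

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    batch, n, num_tokens = 2, 8, 100
    backward_ar_loss_weight = 0.5

    # stand-ins for logits/labels with the forward and backward streams
    # flattened together along the sequence axis as (fb n)
    logits = torch.randn(batch, 2 * n, num_tokens)
    labels = torch.randint(0, num_tokens, (batch, 2 * n))

    loss_weights = torch.tensor([1., backward_ar_loss_weight])

    loss = F.cross_entropy(
        rearrange(logits, 'b fn v -> b v fn'),  # cross_entropy expects classes on dim 1
        labels,
        reduction = 'none'                      # keep per-token losses, shape (batch, 2 * n)
    )

    loss = rearrange(loss, 'b (fb n) -> b fb n', fb = 2)
    loss = loss * rearrange(loss_weights, 'fb -> fb 1')  # scale forward vs backward stream
    loss = loss.mean()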
@@ -70,9 +71,17 @@ class BeliefStateWrapper(Module):
         self.train_frac_fb_pairs = train_frac_forward_backward_pairs
         self.needs_subsample_fb_pairs = train_frac_forward_backward_pairs < 1.
 
+        # loss weighting
+
+        self.backward_ar_loss_weight = backward_ar_loss_weight
+        self.needs_loss_weight = backward_ar_loss_weight != 1.
+
+        self.register_buffer('loss_weights', tensor([1., self.backward_ar_loss_weight]))
+
     def forward(
         self,
-        seq
+        seq,
+        backward = True
     ):
         batch, seq_len, device = *seq.shape, seq.device
 
@@ -149,14 +158,31 @@ class BeliefStateWrapper(Module):
 
         fb_loss = F.cross_entropy(
             rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
-            rearrange(labels, 'b n fb -> b (fb n)')
+            rearrange(labels, 'b n fb -> b (fb n)'),
+            reduction = 'none' if self.needs_loss_weight else 'mean'
         )
 
+        # maybe loss weighting
+
+        if self.needs_loss_weight:
+            fb_loss = rearrange(fb_loss, 'b (fb n) -> b fb n')
+            fb_loss = fb_loss * self.fwd_bwd_loss_weights
+            fb_loss = fb_loss.mean()
+
         # backwards
 
-        fb_loss.backward()
+        orig_backward = getattr(fb_loss, 'backward')
+
+        def patched_backward_fn(*args, **kwargs):
+            orig_backward(*args, **kwargs)
+            orig_forward_embeds.backward(forward_embeds.grad)
+            orig_backward_embeds.backward(backward_embeds.grad)
+
+        # can allow the researcher to call .backward from the outside
 
-        orig_forward_embeds.backward(forward_embeds.grad)
-        orig_backward_embeds.backward(backward_embeds.grad)
+        if backward:
+            patched_backward_fn()
+        else:
+            setattr(fb_loss, 'backward', patched_backward_fn)
 
         return fb_loss
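The patched .backward lets the wrapper either run the full backward pass itself (backward = True, the previous behaviour) or return a loss whose .backward() also relays the cached gradients into the detached embedding graphs. A minimal standalone sketch of the same trick, with toy tensors and names of my own rather than the wrapper's internals:

    import torch

    x = torch.randn(4, requires_grad = True)
    hidden = x * 2                                  # upstream graph we ultimately want gradients for
    detached = hidden.detach().requires_grad_()     # graph is cut here, but the copy records .grad
    loss = (detached ** 2).sum()

    orig_backward = loss.backward                   # keep the original bound method

    def patched_backward(*args, **kwargs):
        orig_backward(*args, **kwargs)              # populates detached.grad as usual
        hidden.backward(detached.grad)              # relay the gradient into the upstream graph

    # shadow the tensor's backward so an outside caller can trigger the whole chain
    loss.backward = patched_backward

    loss.backward()
    assert x.grad is not None

Patching the loss tensor in place keeps the caller-facing contract identical to a plain loss.backward() while still routing gradients into the decoder graphs that were split off earlier in forward (the orig_forward_embeds / orig_backward_embeds pair).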