x-transformers 2.1.4.tar.gz → 2.1.6.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. {x_transformers-2.1.4 → x_transformers-2.1.6}/PKG-INFO +1 -1
  2. {x_transformers-2.1.4 → x_transformers-2.1.6}/pyproject.toml +1 -1
  3. {x_transformers-2.1.4 → x_transformers-2.1.6}/tests/test_x_transformers.py +5 -3
  4. {x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/__init__.py +1 -0
  5. x_transformers-2.1.4/x_transformers/belief_state.py → x_transformers-2.1.6/x_transformers/belief_state_wrapper.py +54 -8
  6. {x_transformers-2.1.4 → x_transformers-2.1.6}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.1.4 → x_transformers-2.1.6}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.1.4 → x_transformers-2.1.6}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.1.4 → x_transformers-2.1.6}/.gitignore +0 -0
  10. {x_transformers-2.1.4 → x_transformers-2.1.6}/LICENSE +0 -0
  11. {x_transformers-2.1.4 → x_transformers-2.1.6}/README.md +0 -0
  12. {x_transformers-2.1.4 → x_transformers-2.1.6}/data/README.md +0 -0
  13. {x_transformers-2.1.4 → x_transformers-2.1.6}/data/enwik8.gz +0 -0
  14. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/all-attention.png +0 -0
  15. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/deepnorm.png +0 -0
  18. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/fcm.png +0 -0
  24. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/ffglu.png +0 -0
  25. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/flash-attention.png +0 -0
  26. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/gate_values.png +0 -0
  27. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/gating.png +0 -0
  28. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/macaron-1.png +0 -0
  30. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/macaron-2.png +0 -0
  31. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/normformer.png +0 -0
  33. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/pia.png +0 -0
  34. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/resi_dual.png +0 -0
  36. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/residual_attn.png +0 -0
  37. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/rezero.png +0 -0
  38. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/rotary.png +0 -0
  39. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/sandwich.png +0 -0
  41. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/scalenorm.png +0 -0
  43. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/talking-heads.png +0 -0
  44. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/topk-attention.png +0 -0
  45. {x_transformers-2.1.4 → x_transformers-2.1.6}/images/xval.png +0 -0
  46. {x_transformers-2.1.4 → x_transformers-2.1.6}/train_copy.py +0 -0
  47. {x_transformers-2.1.4 → x_transformers-2.1.6}/train_enwik8.py +0 -0
  48. {x_transformers-2.1.4 → x_transformers-2.1.6}/train_length_extrapolate.py +0 -0
  49. {x_transformers-2.1.4 → x_transformers-2.1.6}/train_parity.py +0 -0
  50. {x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/attend.py +0 -0
  51. {x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/autoregressive_wrapper.py +0 -0
  52. {x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/continuous.py +0 -0
  53. {x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/dpo.py +0 -0
  54. {x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/multi_input.py +0 -0
  55. {x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/neo_mlp.py +0 -0
  56. {x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/nonautoregressive_wrapper.py +0 -0
  57. {x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/x_transformers.py +0 -0
  58. {x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  59. {x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/xval.py +0 -0
{x_transformers-2.1.4 → x_transformers-2.1.6}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.4
+Version: 2.1.6
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
{x_transformers-2.1.4 → x_transformers-2.1.6}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.1.4"
+version = "2.1.6"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.1.4 → x_transformers-2.1.6}/tests/test_x_transformers.py
@@ -695,7 +695,7 @@ def test_lime(
     model(x)

 def test_belief_state_wrapper():
-    from x_transformers.belief_state import BeliefStateWrapper
+    from x_transformers.belief_state_wrapper import BeliefStateWrapper

     forward_model = TransformerWrapper(
         num_tokens = 20000,
@@ -721,9 +721,11 @@ def test_belief_state_wrapper():

     model = BeliefStateWrapper(
         forward_decoder = forward_model,
-        backward_decoder = backward_model
+        backward_decoder = backward_model,
+        backward_ar_loss_weight = 0.5
     )

     seq = torch.randint(0, 20000, (2, 16))

-    loss = model(seq)
+    loss = model(seq, backward = False)
+    loss.backward()
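Taken together, the test exercises both new knobs: a per-direction loss weight set at construction time, and a `backward = False` flag that defers backpropagation to the caller. A minimal training-step sketch along the same lines; the decoder hyperparameters and the optimizer are illustrative placeholders, not taken from the package:

    import torch
    from x_transformers import TransformerWrapper, Decoder
    from x_transformers.belief_state_wrapper import BeliefStateWrapper

    # forward and backward decoders must share embedding dim and vocab size
    forward_model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(dim = 512, depth = 2)
    )

    backward_model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(dim = 512, depth = 2)
    )

    model = BeliefStateWrapper(
        forward_decoder = forward_model,
        backward_decoder = backward_model,
        backward_ar_loss_weight = 0.5       # weigh the backwards decoder's loss at half
    )

    optim = torch.optim.Adam(model.parameters(), lr = 3e-4)

    seq = torch.randint(0, 20000, (2, 16))

    loss = model(seq, backward = False)     # defer backprop to the caller
    loss.backward()                         # patched to also push grads into both decoders
    optim.step()
    optim.zero_grad()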
{x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/__init__.py
@@ -14,6 +14,7 @@ from x_transformers.x_transformers import (

 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
 from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper
+from x_transformers.belief_state_wrapper import BeliefStateWrapper

 from x_transformers.continuous import (
     ContinuousTransformerWrapper,
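With the added re-export in `x_transformers/__init__.py`, the wrapper is now importable from the package root as well as from its (renamed) submodule:

    # either import path resolves to the same class in 2.1.6
    from x_transformers import BeliefStateWrapper
    from x_transformers.belief_state_wrapper import BeliefStateWrapper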
x_transformers-2.1.4/x_transformers/belief_state.py → x_transformers-2.1.6/x_transformers/belief_state_wrapper.py
@@ -6,7 +6,7 @@
 import torch
 from torch.autograd import Function
 from torch.nn import Module, ModuleList
-from torch import nn, cat, stack, arange, cartesian_prod
+from torch import nn, cat, stack, tensor, arange, cartesian_prod
 import torch.nn.functional as F

 from x_transformers.x_transformers import (
@@ -35,11 +35,13 @@ class BeliefStateWrapper(Module):
     def __init__(
         self,
         forward_decoder: TransformerWrapper,
-        backward_decoder: TransformerWrapper
+        backward_decoder: TransformerWrapper,
+        train_frac_forward_backward_pairs: float = 1.,
+        backward_ar_loss_weight: float = 1. # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
     ):
         super().__init__()
         assert forward_decoder.emb_dim == backward_decoder.emb_dim, 'forward and backwards model must have the same embedding dimension'
-        assert forward_decoder.num_tokens == backward_decoder.num_tokens, 'forward and backwards model must have the same embedding dimension'
+        assert forward_decoder.num_tokens == backward_decoder.num_tokens, 'forward and backwards model must have the same number of tokens'

         dim = forward_decoder.emb_dim
         num_tokens = forward_decoder.num_tokens
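Both new constructor arguments default to values that reproduce the previous behaviour (all prefix/suffix pairs trained, both directions weighted equally). A hedged construction sketch using both knobs, with the decoders assumed to be built as in the test above:

    model = BeliefStateWrapper(
        forward_decoder = forward_model,
        backward_decoder = backward_model,
        train_frac_forward_backward_pairs = 0.25,  # train on a random quarter of the pairs each step
        backward_ar_loss_weight = 0.5              # down-weight the backwards autoregressive loss
    )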
@@ -47,6 +49,7 @@ class BeliefStateWrapper(Module):
         # the suffix token

         self.suffix_token = nn.Parameter(torch.zeros(dim))
+        nn.init.normal_(self.suffix_token, std = 0.02)

         # the text prediction head, which predicts for the combinations of prefix and suffix the next and previous token for forwards and backward sequences

@@ -61,9 +64,24 @@ class BeliefStateWrapper(Module):
         self.forward_decoder = forward_decoder
         self.backward_decoder = backward_decoder

+        # what fraction of forward backward pairs to train on
+        # for further memory efficiency
+
+        assert 0 < train_frac_forward_backward_pairs <= 1.
+        self.train_frac_fb_pairs = train_frac_forward_backward_pairs
+        self.needs_subsample_fb_pairs = train_frac_forward_backward_pairs < 1.
+
+        # loss weighting
+
+        self.backward_ar_loss_weight = backward_ar_loss_weight
+        self.needs_loss_weight = backward_ar_loss_weight != 1.
+
+        self.register_buffer('loss_weights', tensor([1., self.backward_ar_loss_weight]))
+
     def forward(
         self,
-        seq
+        seq,
+        backward = True
     ):
         batch, seq_len, device = *seq.shape, seq.device

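The registered buffer holds one weight per direction, `[1., backward_ar_loss_weight]`, so the forward loss keeps unit weight while the backward loss can be scaled. A standalone sketch of that weighting pattern on a per-element loss laid out as (batch, direction, pairs); the tensor names here are illustrative, not the wrapper's own:

    import torch

    backward_ar_loss_weight = 0.5
    loss_weights = torch.tensor([1., backward_ar_loss_weight])  # (2,) one weight per direction

    # per-element cross entropy, e.g. from F.cross_entropy(..., reduction = 'none')
    per_elem_loss = torch.rand(4, 2, 10)                        # (batch, fwd/bwd, num pairs)

    weighted = per_elem_loss * loss_weights[None, :, None]      # broadcast over batch and pairs
    loss = weighted.mean()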
@@ -107,6 +125,17 @@ class BeliefStateWrapper(Module):

         fb_pairs = fb_pairs[valid_mask]

+        # maybe subsample fb pairs
+
+        if self.needs_subsample_fb_pairs:
+            num_pairs = fb_pairs.shape[0]
+
+            num_subsampled = max(int(num_pairs * self.train_frac_fb_pairs), 1)
+
+            rand_subsampled_indices = torch.randperm(num_pairs, device = device)[:num_subsampled]
+
+            fb_pairs = fb_pairs[rand_subsampled_indices]
+
         # get labels for both

         fi, bi = fb_pairs.unbind(dim = -1)
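Subsampling keeps only a random fraction of the candidate (forward, backward) index pairs each step, trading coverage for memory. A standalone sketch of the same randperm-and-slice pattern, with illustrative tensor names:

    import torch

    fb_pairs = torch.cartesian_prod(torch.arange(8), torch.arange(8))  # (64, 2) candidate pairs
    train_frac = 0.25

    num_pairs = fb_pairs.shape[0]
    num_kept = max(int(num_pairs * train_frac), 1)                     # always keep at least one pair

    perm = torch.randperm(num_pairs)                                   # shuffle pair indices
    fb_pairs = fb_pairs[perm[:num_kept]]                               # (16, 2) subsampled pairs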
@@ -129,14 +158,31 @@ class BeliefStateWrapper(Module):

         fb_loss = F.cross_entropy(
             rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
-            rearrange(labels, 'b n fb -> b (fb n)')
+            rearrange(labels, 'b n fb -> b (fb n)'),
+            reduction = 'none' if self.needs_loss_weight else 'mean'
         )

+        # maybe loss weighting
+
+        if self.needs_loss_weight:
+            fb_loss = rearrange(fb_loss, 'b (fb n) -> b fb n')
+            fb_loss = fb_loss * self.fwd_bwd_loss_weights
+            fb_loss = fb_loss.mean()
+
         # backwards

-        fb_loss.backward()
+        orig_backward = getattr(fb_loss, 'backward')
+
+        def patched_backward_fn(*args, **kwargs):
+            orig_backward(*args, **kwargs)
+            orig_forward_embeds.backward(forward_embeds.grad)
+            orig_backward_embeds.backward(backward_embeds.grad)
+
+        # can allow the researcher to call .backward from the outside

-        orig_forward_embeds.backward(forward_embeds.grad)
-        orig_backward_embeds.backward(backward_embeds.grad)
+        if backward:
+            patched_backward_fn()
+        else:
+            setattr(fb_loss, 'backward', patched_backward_fn)

         return fb_loss
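The patched `.backward` preserves the wrapper's two-stage backpropagation through detached embeddings (presumably so the combinatorial prefix/suffix head does not keep the full decoder graphs alive), while the new `backward` flag lets the caller decide when that chain runs. A standalone sketch of the detach-then-reinject pattern itself, independent of the wrapper and with illustrative modules:

    import torch
    from torch import nn

    x = torch.randn(4, 16, requires_grad = True)

    encoder = nn.Linear(16, 32)
    head = nn.Linear(32, 10)

    orig_embeds = encoder(x)             # part of the encoder graph
    embeds = orig_embeds.detach()        # cut the graph here
    embeds.requires_grad_()              # but accumulate grads on the detached copy

    loss = head(embeds).sum()

    loss.backward()                      # stage 1: grads land in embeds.grad only
    orig_embeds.backward(embeds.grad)    # stage 2: reinject them into the encoder graph

    assert x.grad is not None            # the encoder and its inputs received gradients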