x-transformers 2.1.20__tar.gz → 2.1.21__tar.gz

This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (60)
  1. {x_transformers-2.1.20 → x_transformers-2.1.21}/PKG-INFO +1 -1
  2. {x_transformers-2.1.20 → x_transformers-2.1.21}/pyproject.toml +1 -1
  3. {x_transformers-2.1.20 → x_transformers-2.1.21}/train_belief_state.py +15 -0
  4. {x_transformers-2.1.20 → x_transformers-2.1.21}/x_transformers/belief_state_wrapper.py +41 -25
  5. {x_transformers-2.1.20 → x_transformers-2.1.21}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.1.20 → x_transformers-2.1.21}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.1.20 → x_transformers-2.1.21}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.1.20 → x_transformers-2.1.21}/.gitignore +0 -0
  9. {x_transformers-2.1.20 → x_transformers-2.1.21}/LICENSE +0 -0
  10. {x_transformers-2.1.20 → x_transformers-2.1.21}/README.md +0 -0
  11. {x_transformers-2.1.20 → x_transformers-2.1.21}/data/README.md +0 -0
  12. {x_transformers-2.1.20 → x_transformers-2.1.21}/data/enwik8.gz +0 -0
  13. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/all-attention.png +0 -0
  14. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/deepnorm.png +0 -0
  17. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/fcm.png +0 -0
  23. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/ffglu.png +0 -0
  24. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/flash-attention.png +0 -0
  25. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/gate_values.png +0 -0
  26. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/gating.png +0 -0
  27. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/macaron-1.png +0 -0
  29. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/macaron-2.png +0 -0
  30. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/normformer.png +0 -0
  32. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/pia.png +0 -0
  33. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/resi_dual.png +0 -0
  35. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/residual_attn.png +0 -0
  36. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/rezero.png +0 -0
  37. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/rotary.png +0 -0
  38. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/sandwich.png +0 -0
  40. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/scalenorm.png +0 -0
  42. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/talking-heads.png +0 -0
  43. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/topk-attention.png +0 -0
  44. {x_transformers-2.1.20 → x_transformers-2.1.21}/images/xval.png +0 -0
  45. {x_transformers-2.1.20 → x_transformers-2.1.21}/tests/test_x_transformers.py +0 -0
  46. {x_transformers-2.1.20 → x_transformers-2.1.21}/train_copy.py +0 -0
  47. {x_transformers-2.1.20 → x_transformers-2.1.21}/train_enwik8.py +0 -0
  48. {x_transformers-2.1.20 → x_transformers-2.1.21}/train_length_extrapolate.py +0 -0
  49. {x_transformers-2.1.20 → x_transformers-2.1.21}/train_parity.py +0 -0
  50. {x_transformers-2.1.20 → x_transformers-2.1.21}/x_transformers/__init__.py +0 -0
  51. {x_transformers-2.1.20 → x_transformers-2.1.21}/x_transformers/attend.py +0 -0
  52. {x_transformers-2.1.20 → x_transformers-2.1.21}/x_transformers/autoregressive_wrapper.py +0 -0
  53. {x_transformers-2.1.20 → x_transformers-2.1.21}/x_transformers/continuous.py +0 -0
  54. {x_transformers-2.1.20 → x_transformers-2.1.21}/x_transformers/dpo.py +0 -0
  55. {x_transformers-2.1.20 → x_transformers-2.1.21}/x_transformers/multi_input.py +0 -0
  56. {x_transformers-2.1.20 → x_transformers-2.1.21}/x_transformers/neo_mlp.py +0 -0
  57. {x_transformers-2.1.20 → x_transformers-2.1.21}/x_transformers/nonautoregressive_wrapper.py +0 -0
  58. {x_transformers-2.1.20 → x_transformers-2.1.21}/x_transformers/x_transformers.py +0 -0
  59. {x_transformers-2.1.20 → x_transformers-2.1.21}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  60. {x_transformers-2.1.20 → x_transformers-2.1.21}/x_transformers/xval.py +0 -0
{x_transformers-2.1.20 → x_transformers-2.1.21}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.20
+Version: 2.1.21
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
{x_transformers-2.1.20 → x_transformers-2.1.21}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.1.20"
+version = "2.1.21"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.1.20 → x_transformers-2.1.21}/train_belief_state.py

@@ -121,8 +121,11 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
         model.eval()
         inp = random.choice(val_dataset)[:-1]
         prime = decode_tokens(inp)
+
         print(f'%s \n\n %s', (prime, '*' * 100))
 
+        print('forwards:\n')
+
         sample = model.generate_with_suffix_cond(
             prompts = inp,
             seq_len = GENERATE_LENGTH,
@@ -131,3 +134,15 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
 
         output_str = decode_tokens(sample)
         print(output_str)
+
+        print('\nbackwards:\n')
+
+        sample = model.generate_with_suffix_cond(
+            prompts = inp,
+            seq_len = GENERATE_LENGTH,
+            cache_kv = True,
+            decode_backwards = True
+        )
+
+        output_str = decode_tokens(sample.flip(0))
+        print(output_str)
{x_transformers-2.1.20 → x_transformers-2.1.21}/x_transformers/belief_state_wrapper.py

@@ -34,15 +34,6 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
-def eval_decorator(fn):
-    def inner(self, *args, **kwargs):
-        was_training = self.training
-        self.eval()
-        out = fn(self, *args, **kwargs)
-        self.train(was_training)
-        return out
-    return inner
-
 # wrappers
 
 class BeliefStateWrapper(Module):
@@ -69,6 +60,8 @@ class BeliefStateWrapper(Module):
         dim = forward_decoder.emb_dim
         num_tokens = forward_decoder.num_tokens
 
+        self.num_tokens = num_tokens
+
         # the suffix token
 
         self.suffix_token = nn.Parameter(torch.zeros(dim))
@@ -132,6 +125,7 @@ class BeliefStateWrapper(Module):
         filter_kwargs = dict(
             min_p = 0.1
         ),
+        decode_backwards = False,
         **kwargs
     ):
         max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
@@ -140,6 +134,14 @@
 
         batch, orig_seq_len = prompts.shape
 
+        # allow for decoding backwards, to make sure it is working
+
+        main_decoder = self.forward_decoder
+
+        if decode_backwards:
+            prompts = prompts.flip(1)
+            main_decoder = self.backward_decoder
+
         out = prompts
 
         # kv caches
@@ -148,31 +150,39 @@
 
         # get the encoded suffix token once
 
-        if exists(suffix):
-            if suffix.ndim == 1:
-                suffix = repeat(suffix, 'n -> b n', b = batch)
+        if not decode_backwards:
+            if exists(suffix):
+                if suffix.ndim == 1:
+                    suffix = repeat(suffix, 'n -> b n', b = batch)
 
-            suffix = suffix.flip(1) # reverse autoregressive
+                suffix = suffix.flip(1) # reverse autoregressive
 
-        suffix_sos_tokens = rearrange(self.suffix_token, 'd -> 1 1 d')
+            suffix_sos_tokens = rearrange(self.suffix_token, 'd -> 1 1 d')
 
-        suffix_sos_tokens = repeat(suffix_sos_tokens, '1 1 d -> b 1 d', b = batch)
+            suffix_sos_tokens = repeat(suffix_sos_tokens, '1 1 d -> b 1 d', b = batch)
 
-        suffix_embed = self.backward_decoder(
-            suffix,
-            prepend_embeds = suffix_sos_tokens,
-            return_embeddings = True
-        )
+            suffix_embed = self.backward_decoder(
+                suffix,
+                prepend_embeds = suffix_sos_tokens,
+                return_embeddings = True
+            )
+
+            # pick out the last embedding for fill in the middle
+
+            suffix_embed = suffix_embed[:, -1:]
+
+        else:
+            # just grab a random token for now for prefix
 
-        # pick out the last embedding for fill in the middle
+            prefix_embed = torch.randint(0, self.num_tokens, (batch, 1), device = device)
 
-        suffix_embed = suffix_embed[:, -1:]
+            prefix_embed = self.forward_decoder(prefix_embed, return_embeddings = True)
 
         # sampling up to seq_len
 
         for _ in range(seq_len):
 
-            embeds, new_cache = self.forward_decoder(
+            embeds, new_cache = main_decoder(
                 out,
                 return_intermediates = True,
                 return_embeddings = True,
@@ -181,12 +191,18 @@
             )
 
             last_embeds = embeds[:, -1:]
-            embeds = cat((last_embeds, suffix_embed), dim = -1)
+
+            if not decode_backwards:
+                embeds = cat((last_embeds, suffix_embed), dim = -1)
+            else:
+                embeds = cat((prefix_embed, last_embeds), dim = -1)
 
             if cache_kv and self.forward_decoder.can_cache_kv:
                 cache = new_cache
 
-            logits, _ = self.text_head(embeds).chunk(2, dim = -1)
+            forward_logits, backward_logits = self.text_head(embeds).chunk(2, dim = -1)
+
+            logits = forward_logits if not decode_backwards else backward_logits
 
             logits = logits[:, -1]
 
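In short, this release gives generate_with_suffix_cond a decode_backwards flag: when set, the prompt is flipped, decoding runs through the backward decoder, the backward half of the text_head logits is used, and a random prefix-token embedding stands in for the suffix embedding. The sketch below mirrors the calls added to train_belief_state.py; the model sizes and the explicit two-decoder construction are illustrative assumptions, not taken from this diff, so treat it as a usage sketch rather than the package's canonical example.

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

# two small decoders; hyperparameters are placeholders for illustration
forward_decoder = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 64,
    attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
)

backward_decoder = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 64,
    attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
)

# keyword names follow how the wrapper refers to its decoders in the diff above
model = BeliefStateWrapper(
    forward_decoder = forward_decoder,
    backward_decoder = backward_decoder
)

prompts = torch.randint(0, 256, (1, 16))

# usual forward (left-to-right) sampling, as in the existing train script
forward_sample = model.generate_with_suffix_cond(
    prompts = prompts,
    seq_len = 32,
    cache_kv = True
)

# new in 2.1.21: sample with the backward decoder instead
backward_sample = model.generate_with_suffix_cond(
    prompts = prompts,
    seq_len = 32,
    cache_kv = True,
    decode_backwards = True
)

print(forward_sample.shape, backward_sample.flip(-1).shape)

The backwards sample is produced right-to-left, which is why the updated train script flips it (sample.flip(0) for its unbatched prompt) before decoding it to text; the sketch flips along the last dimension for the batched case.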