x-transformers 2.1.20__tar.gz → 2.1.22__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {x_transformers-2.1.20 → x_transformers-2.1.22}/PKG-INFO +12 -1
  2. {x_transformers-2.1.20 → x_transformers-2.1.22}/README.md +11 -0
  3. {x_transformers-2.1.20 → x_transformers-2.1.22}/pyproject.toml +1 -1
  4. {x_transformers-2.1.20 → x_transformers-2.1.22}/train_belief_state.py +29 -9
  5. {x_transformers-2.1.20 → x_transformers-2.1.22}/x_transformers/belief_state_wrapper.py +42 -25
  6. {x_transformers-2.1.20 → x_transformers-2.1.22}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.1.20 → x_transformers-2.1.22}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.1.20 → x_transformers-2.1.22}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.1.20 → x_transformers-2.1.22}/.gitignore +0 -0
  10. {x_transformers-2.1.20 → x_transformers-2.1.22}/LICENSE +0 -0
  11. {x_transformers-2.1.20 → x_transformers-2.1.22}/data/README.md +0 -0
  12. {x_transformers-2.1.20 → x_transformers-2.1.22}/data/enwik8.gz +0 -0
  13. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/all-attention.png +0 -0
  14. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/deepnorm.png +0 -0
  17. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/fcm.png +0 -0
  23. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/ffglu.png +0 -0
  24. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/flash-attention.png +0 -0
  25. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/gate_values.png +0 -0
  26. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/gating.png +0 -0
  27. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/macaron-1.png +0 -0
  29. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/macaron-2.png +0 -0
  30. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/normformer.png +0 -0
  32. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/pia.png +0 -0
  33. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/resi_dual.png +0 -0
  35. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/residual_attn.png +0 -0
  36. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/rezero.png +0 -0
  37. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/rotary.png +0 -0
  38. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/sandwich.png +0 -0
  40. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/scalenorm.png +0 -0
  42. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/talking-heads.png +0 -0
  43. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/topk-attention.png +0 -0
  44. {x_transformers-2.1.20 → x_transformers-2.1.22}/images/xval.png +0 -0
  45. {x_transformers-2.1.20 → x_transformers-2.1.22}/tests/test_x_transformers.py +0 -0
  46. {x_transformers-2.1.20 → x_transformers-2.1.22}/train_copy.py +0 -0
  47. {x_transformers-2.1.20 → x_transformers-2.1.22}/train_enwik8.py +0 -0
  48. {x_transformers-2.1.20 → x_transformers-2.1.22}/train_length_extrapolate.py +0 -0
  49. {x_transformers-2.1.20 → x_transformers-2.1.22}/train_parity.py +0 -0
  50. {x_transformers-2.1.20 → x_transformers-2.1.22}/x_transformers/__init__.py +0 -0
  51. {x_transformers-2.1.20 → x_transformers-2.1.22}/x_transformers/attend.py +0 -0
  52. {x_transformers-2.1.20 → x_transformers-2.1.22}/x_transformers/autoregressive_wrapper.py +0 -0
  53. {x_transformers-2.1.20 → x_transformers-2.1.22}/x_transformers/continuous.py +0 -0
  54. {x_transformers-2.1.20 → x_transformers-2.1.22}/x_transformers/dpo.py +0 -0
  55. {x_transformers-2.1.20 → x_transformers-2.1.22}/x_transformers/multi_input.py +0 -0
  56. {x_transformers-2.1.20 → x_transformers-2.1.22}/x_transformers/neo_mlp.py +0 -0
  57. {x_transformers-2.1.20 → x_transformers-2.1.22}/x_transformers/nonautoregressive_wrapper.py +0 -0
  58. {x_transformers-2.1.20 → x_transformers-2.1.22}/x_transformers/x_transformers.py +0 -0
  59. {x_transformers-2.1.20 → x_transformers-2.1.22}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  60. {x_transformers-2.1.20 → x_transformers-2.1.22}/x_transformers/xval.py +0 -0

{x_transformers-2.1.20 → x_transformers-2.1.22}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.20
+Version: 2.1.22
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2444,4 +2444,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@article{Charpentier2024GPTOB,
+    title   = {GPT or BERT: why not both?},
+    author  = {Lucas Georges Gabriel Charpentier and David Samuel},
+    journal = {ArXiv},
+    year    = {2024},
+    volume  = {abs/2410.24159},
+    url     = {https://api.semanticscholar.org/CorpusID:273707069}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis

{x_transformers-2.1.20 → x_transformers-2.1.22}/README.md
@@ -2396,4 +2396,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@article{Charpentier2024GPTOB,
+    title   = {GPT or BERT: why not both?},
+    author  = {Lucas Georges Gabriel Charpentier and David Samuel},
+    journal = {ArXiv},
+    year    = {2024},
+    volume  = {abs/2410.24159},
+    url     = {https://api.semanticscholar.org/CorpusID:273707069}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis

{x_transformers-2.1.20 → x_transformers-2.1.22}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.1.20"
+version = "2.1.22"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

{x_transformers-2.1.20 → x_transformers-2.1.22}/train_belief_state.py
@@ -21,6 +21,8 @@ GENERATE_EVERY = 500
 GENERATE_LENGTH = 256
 SEQ_LEN = 256
 
+FORWARD_BACKWARD_SAME_MODEL = True
+
 # helpers
 
 def cycle(loader):
@@ -47,16 +49,19 @@ forward_model = TransformerWrapper(
     )
 )
 
-backward_model = TransformerWrapper(
-    num_tokens = 256,
-    max_seq_len = SEQ_LEN,
-    attn_layers = Decoder(
-        dim = 512,
-        depth = 4,          # do a smaller backwards
-        heads = 8,
-        rotary_pos_emb = True
+backward_model = None
+
+if not FORWARD_BACKWARD_SAME_MODEL:
+    backward_model = TransformerWrapper(
+        num_tokens = 256,
+        max_seq_len = SEQ_LEN,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 4,          # do a smaller backwards
+            heads = 8,
+            rotary_pos_emb = True
+        )
     )
-)
 
 model = BeliefStateWrapper(
     forward_decoder = forward_model,
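
For context, here is a minimal sketch of the two configurations the new FORWARD_BACKWARD_SAME_MODEL flag toggles between, outside the training script. The `backward_decoder` keyword, the wrapper falling back to the forward decoder when no backward decoder is passed, and the loss-returning `forward` are assumptions not shown in this diff; model sizes are illustrative.

```python
# sketch only: construct BeliefStateWrapper with shared vs. separate backward decoder
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

SEQ_LEN = 256

forward_model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = SEQ_LEN,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8, rotary_pos_emb = True)
)

# FORWARD_BACKWARD_SAME_MODEL = True: one decoder presumably serves both directions
shared = BeliefStateWrapper(forward_decoder = forward_model)

# FORWARD_BACKWARD_SAME_MODEL = False: a separate, smaller backwards decoder
backward_model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = SEQ_LEN,
    attn_layers = Decoder(dim = 512, depth = 4, heads = 8, rotary_pos_emb = True)
)

separate = BeliefStateWrapper(
    forward_decoder = forward_model,
    backward_decoder = backward_model   # assumed keyword, not shown in this diff
)

seq = torch.randint(0, 256, (1, SEQ_LEN))
loss = shared(seq)   # assuming forward() returns the belief-state training loss
loss.backward()
```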
@@ -121,8 +126,11 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
         model.eval()
         inp = random.choice(val_dataset)[:-1]
         prime = decode_tokens(inp)
+
         print(f'%s \n\n %s', (prime, '*' * 100))
 
+        print('forwards:\n')
+
         sample = model.generate_with_suffix_cond(
             prompts = inp,
             seq_len = GENERATE_LENGTH,
@@ -131,3 +139,15 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
 
         output_str = decode_tokens(sample)
         print(output_str)
+
+        print('\nbackwards:\n')
+
+        sample = model.generate_with_suffix_cond(
+            prompts = inp,
+            seq_len = GENERATE_LENGTH,
+            cache_kv = True,
+            decode_backwards = True
+        )
+
+        output_str = decode_tokens(sample.flip(0))
+        print(output_str)
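
The same `decode_backwards` path can be exercised stand-alone. A minimal sketch follows, assuming the wrapper reuses the forward decoder for the backward direction when none is given (as the FORWARD_BACKWARD_SAME_MODEL path above implies) and using illustrative model sizes; tokens are produced right-to-left, so the sample is flipped along the sequence dimension before reading, mirroring the `sample.flip(0)` in the hunk above.

```python
# sketch only: sample to the left of a prompt with the new decode_backwards flag
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

model = BeliefStateWrapper(
    forward_decoder = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 256,
        attn_layers = Decoder(dim = 512, depth = 6, heads = 8, rotary_pos_emb = True)
    )
)

prompt = torch.randint(0, 256, (1, 32))   # (batch, seq)

backwards = model.generate_with_suffix_cond(
    prompts = prompt,
    seq_len = 64,
    cache_kv = True,
    decode_backwards = True
)

readable = backwards.flip(-1)   # flip back into left-to-right order
```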

{x_transformers-2.1.20 → x_transformers-2.1.22}/x_transformers/belief_state_wrapper.py
@@ -34,15 +34,6 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
-def eval_decorator(fn):
-    def inner(self, *args, **kwargs):
-        was_training = self.training
-        self.eval()
-        out = fn(self, *args, **kwargs)
-        self.train(was_training)
-        return out
-    return inner
-
 # wrappers
 
 class BeliefStateWrapper(Module):
@@ -69,6 +60,8 @@ class BeliefStateWrapper(Module):
         dim = forward_decoder.emb_dim
         num_tokens = forward_decoder.num_tokens
 
+        self.num_tokens = num_tokens
+
         # the suffix token
 
         self.suffix_token = nn.Parameter(torch.zeros(dim))
@@ -132,6 +125,7 @@ class BeliefStateWrapper(Module):
         filter_kwargs = dict(
             min_p = 0.1
         ),
+        decode_backwards = False,
         **kwargs
     ):
         max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
@@ -140,6 +134,14 @@ class BeliefStateWrapper(Module):
 
         batch, orig_seq_len = prompts.shape
 
+        # allow for decoding backwards, to make sure it is working
+
+        main_decoder = self.forward_decoder
+
+        if decode_backwards:
+            prompts = prompts.flip(1)
+            main_decoder = self.backward_decoder
+
         out = prompts
 
         # kv caches
@@ -148,32 +150,41 @@ class BeliefStateWrapper(Module):
 
         # get the encoded suffix token once
 
-        if exists(suffix):
-            if suffix.ndim == 1:
-                suffix = repeat(suffix, 'n -> b n', b = batch)
-
-            suffix = suffix.flip(1) # reverse autoregressive
-
         suffix_sos_tokens = rearrange(self.suffix_token, 'd -> 1 1 d')
 
         suffix_sos_tokens = repeat(suffix_sos_tokens, '1 1 d -> b 1 d', b = batch)
 
-        suffix_embed = self.backward_decoder(
-            suffix,
-            prepend_embeds = suffix_sos_tokens,
-            return_embeddings = True
-        )
+        if not decode_backwards:
+            if exists(suffix):
+                if suffix.ndim == 1:
+                    suffix = repeat(suffix, 'n -> b n', b = batch)
 
-        # pick out the last embedding for fill in the middle
+                suffix = suffix.flip(1) # reverse autoregressive
 
-        suffix_embed = suffix_embed[:, -1:]
+            suffix_embed = self.backward_decoder(
+                suffix,
+                prepend_embeds = suffix_sos_tokens,
+                return_embeddings = True
+            )
+
+            # pick out the last embedding for fill in the middle
+
+            suffix_embed = suffix_embed[:, -1:]
+
+        else:
+            # just grab a random token for now for prefix
+
+            prefix_embed = torch.randint(0, self.num_tokens, (batch, 1), device = device)
+
+            prefix_embed = self.forward_decoder(prefix_embed, return_embeddings = True)
 
         # sampling up to seq_len
 
         for _ in range(seq_len):
 
-            embeds, new_cache = self.forward_decoder(
+            embeds, new_cache = main_decoder(
                 out,
+                prepend_embeds = suffix_sos_tokens if decode_backwards else None,
                 return_intermediates = True,
                 return_embeddings = True,
                 cache = cache,
@@ -181,12 +192,18 @@ class BeliefStateWrapper(Module):
             )
 
             last_embeds = embeds[:, -1:]
-            embeds = cat((last_embeds, suffix_embed), dim = -1)
+
+            if not decode_backwards:
+                embeds = cat((last_embeds, suffix_embed), dim = -1)
+            else:
+                embeds = cat((prefix_embed, last_embeds), dim = -1)
 
             if cache_kv and self.forward_decoder.can_cache_kv:
                 cache = new_cache
 
-            logits, _ = self.text_head(embeds).chunk(2, dim = -1)
+            forward_logits, backward_logits = self.text_head(embeds).chunk(2, dim = -1)
+
+            logits = forward_logits if not decode_backwards else backward_logits
 
             logits = logits[:, -1]
 
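
The last two hunks are where the decoding direction is actually resolved: the main decoder's latest embedding is paired with either the encoded suffix (forwards) or a prefix embedding (backwards), and a single text head emits logits for both directions, split with `.chunk(2, dim = -1)`. Below is a minimal, self-contained tensor-level sketch of that selection logic; the internals of `text_head` are assumed (a plain linear layer here), and only the concatenation and the 2 * num_tokens split come from the diff.

```python
# sketch only: dual-direction text head, forward vs backward logit selection
import torch
from torch import nn, cat

dim, num_tokens, batch = 512, 256, 2

# stand-in for self.text_head: concatenated (left, right) embeddings -> both logit sets
text_head = nn.Linear(dim * 2, num_tokens * 2)

last_embeds  = torch.randn(batch, 1, dim)   # main decoder's latest embedding
suffix_embed = torch.randn(batch, 1, dim)   # backward decoder's encoded suffix
prefix_embed = torch.randn(batch, 1, dim)   # random-token prefix embedding, as in the diff

decode_backwards = False

if not decode_backwards:
    embeds = cat((last_embeds, suffix_embed), dim = -1)
else:
    embeds = cat((prefix_embed, last_embeds), dim = -1)

forward_logits, backward_logits = text_head(embeds).chunk(2, dim = -1)

logits = forward_logits if not decode_backwards else backward_logits
logits = logits[:, -1]   # (batch, num_tokens), ready for sampling
```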