x-transformers 2.1.21__tar.gz → 2.1.23__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {x_transformers-2.1.21 → x_transformers-2.1.23}/PKG-INFO +12 -1
  2. {x_transformers-2.1.21 → x_transformers-2.1.23}/README.md +11 -0
  3. {x_transformers-2.1.21 → x_transformers-2.1.23}/pyproject.toml +1 -1
  4. {x_transformers-2.1.21 → x_transformers-2.1.23}/train_belief_state.py +14 -9
  5. {x_transformers-2.1.21 → x_transformers-2.1.23}/x_transformers/belief_state_wrapper.py +21 -7
  6. {x_transformers-2.1.21 → x_transformers-2.1.23}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.1.21 → x_transformers-2.1.23}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.1.21 → x_transformers-2.1.23}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.1.21 → x_transformers-2.1.23}/.gitignore +0 -0
  10. {x_transformers-2.1.21 → x_transformers-2.1.23}/LICENSE +0 -0
  11. {x_transformers-2.1.21 → x_transformers-2.1.23}/data/README.md +0 -0
  12. {x_transformers-2.1.21 → x_transformers-2.1.23}/data/enwik8.gz +0 -0
  13. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/all-attention.png +0 -0
  14. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/deepnorm.png +0 -0
  17. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/fcm.png +0 -0
  23. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/ffglu.png +0 -0
  24. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/flash-attention.png +0 -0
  25. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/gate_values.png +0 -0
  26. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/gating.png +0 -0
  27. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/macaron-1.png +0 -0
  29. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/macaron-2.png +0 -0
  30. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/normformer.png +0 -0
  32. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/pia.png +0 -0
  33. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/resi_dual.png +0 -0
  35. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/residual_attn.png +0 -0
  36. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/rezero.png +0 -0
  37. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/rotary.png +0 -0
  38. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/sandwich.png +0 -0
  40. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/scalenorm.png +0 -0
  42. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/talking-heads.png +0 -0
  43. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/topk-attention.png +0 -0
  44. {x_transformers-2.1.21 → x_transformers-2.1.23}/images/xval.png +0 -0
  45. {x_transformers-2.1.21 → x_transformers-2.1.23}/tests/test_x_transformers.py +0 -0
  46. {x_transformers-2.1.21 → x_transformers-2.1.23}/train_copy.py +0 -0
  47. {x_transformers-2.1.21 → x_transformers-2.1.23}/train_enwik8.py +0 -0
  48. {x_transformers-2.1.21 → x_transformers-2.1.23}/train_length_extrapolate.py +0 -0
  49. {x_transformers-2.1.21 → x_transformers-2.1.23}/train_parity.py +0 -0
  50. {x_transformers-2.1.21 → x_transformers-2.1.23}/x_transformers/__init__.py +0 -0
  51. {x_transformers-2.1.21 → x_transformers-2.1.23}/x_transformers/attend.py +0 -0
  52. {x_transformers-2.1.21 → x_transformers-2.1.23}/x_transformers/autoregressive_wrapper.py +0 -0
  53. {x_transformers-2.1.21 → x_transformers-2.1.23}/x_transformers/continuous.py +0 -0
  54. {x_transformers-2.1.21 → x_transformers-2.1.23}/x_transformers/dpo.py +0 -0
  55. {x_transformers-2.1.21 → x_transformers-2.1.23}/x_transformers/multi_input.py +0 -0
  56. {x_transformers-2.1.21 → x_transformers-2.1.23}/x_transformers/neo_mlp.py +0 -0
  57. {x_transformers-2.1.21 → x_transformers-2.1.23}/x_transformers/nonautoregressive_wrapper.py +0 -0
  58. {x_transformers-2.1.21 → x_transformers-2.1.23}/x_transformers/x_transformers.py +0 -0
  59. {x_transformers-2.1.21 → x_transformers-2.1.23}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  60. {x_transformers-2.1.21 → x_transformers-2.1.23}/x_transformers/xval.py +0 -0
--- x_transformers-2.1.21/PKG-INFO
+++ x_transformers-2.1.23/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.21
+Version: 2.1.23
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2444,4 +2444,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@article{Charpentier2024GPTOB,
+    title = {GPT or BERT: why not both?},
+    author = {Lucas Georges Gabriel Charpentier and David Samuel},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2410.24159},
+    url = {https://api.semanticscholar.org/CorpusID:273707069}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
--- x_transformers-2.1.21/README.md
+++ x_transformers-2.1.23/README.md
@@ -2396,4 +2396,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@article{Charpentier2024GPTOB,
+    title = {GPT or BERT: why not both?},
+    author = {Lucas Georges Gabriel Charpentier and David Samuel},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2410.24159},
+    url = {https://api.semanticscholar.org/CorpusID:273707069}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
--- x_transformers-2.1.21/pyproject.toml
+++ x_transformers-2.1.23/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.1.21"
+version = "2.1.23"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
--- x_transformers-2.1.21/train_belief_state.py
+++ x_transformers-2.1.23/train_belief_state.py
@@ -21,6 +21,8 @@ GENERATE_EVERY = 500
 GENERATE_LENGTH = 256
 SEQ_LEN = 256
 
+FORWARD_BACKWARD_SAME_MODEL = True
+
 # helpers
 
 def cycle(loader):
@@ -47,16 +49,19 @@ forward_model = TransformerWrapper(
     )
 )
 
-backward_model = TransformerWrapper(
-    num_tokens = 256,
-    max_seq_len = SEQ_LEN,
-    attn_layers = Decoder(
-        dim = 512,
-        depth = 4, # do a smaller backwards
-        heads = 8,
-        rotary_pos_emb = True
+backward_model = None
+
+if not FORWARD_BACKWARD_SAME_MODEL:
+    backward_model = TransformerWrapper(
+        num_tokens = 256,
+        max_seq_len = SEQ_LEN,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 4, # do a smaller backwards
+            heads = 8,
+            rotary_pos_emb = True
+        )
     )
-)
 
 model = BeliefStateWrapper(
     forward_decoder = forward_model,
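
The new `FORWARD_BACKWARD_SAME_MODEL` flag collapses the training setup to a single decoder. Below is a minimal sketch of the resulting configuration, assuming (as the script implies by leaving `backward_model = None`) that `BeliefStateWrapper` reuses the forward decoder for the backward pass when no backward decoder is supplied; the layer sizes are illustrative rather than the script's exact forward config:

```python
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

SEQ_LEN = 256

# a single decoder that will serve both directions (illustrative sizes)
forward_model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = SEQ_LEN,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rotary_pos_emb = True
    )
)

# with FORWARD_BACKWARD_SAME_MODEL = True the script passes backward_model = None,
# so only the forward decoder is handed to the wrapper (assumed to be shared for both directions)
model = BeliefStateWrapper(
    forward_decoder = forward_model,
    backward_decoder = None
)
```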
--- x_transformers-2.1.21/x_transformers/belief_state_wrapper.py
+++ x_transformers-2.1.23/x_transformers/belief_state_wrapper.py
@@ -150,6 +150,10 @@ class BeliefStateWrapper(Module):
 
         # get the encoded suffix token once
 
+        suffix_sos_tokens = rearrange(self.suffix_token, 'd -> 1 1 d')
+
+        suffix_sos_tokens = repeat(suffix_sos_tokens, '1 1 d -> b 1 d', b = batch)
+
         if not decode_backwards:
             if exists(suffix):
                 if suffix.ndim == 1:
@@ -157,10 +161,6 @@ class BeliefStateWrapper(Module):
 
                 suffix = suffix.flip(1) # reverse autoregressive
 
-            suffix_sos_tokens = rearrange(self.suffix_token, 'd -> 1 1 d')
-
-            suffix_sos_tokens = repeat(suffix_sos_tokens, '1 1 d -> b 1 d', b = batch)
-
             suffix_embed = self.backward_decoder(
                 suffix,
                 prepend_embeds = suffix_sos_tokens,
@@ -184,6 +184,7 @@ class BeliefStateWrapper(Module):
 
             embeds, new_cache = main_decoder(
                 out,
+                prepend_embeds = suffix_sos_tokens if decode_backwards else None,
                 return_intermediates = True,
                 return_embeddings = True,
                 cache = cache,
@@ -227,7 +228,8 @@ class BeliefStateWrapper(Module):
         self,
         seq,
         return_loss_only = False,
-        loss_scale = 1.
+        loss_scale = 1.,
+        loss_weight_by_fb_indices: callable | None = None
     ):
         batch, seq_len, device = *seq.shape, seq.device
 
@@ -335,9 +337,21 @@ class BeliefStateWrapper(Module):
 
         # maybe loss weighting
 
-        if self.needs_loss_weight:
+        needs_loss_weight = default(self.needs_loss_weight, exists(loss_weight_by_fb_indices))
+
+        if needs_loss_weight:
             loss = rearrange(loss, 'b (fb n) -> b fb n', fb = 2)
-            loss = einx.multiply('b fb n, fb', loss, self.loss_weights)
+
+            if self.needs_loss_weight:
+                loss = einx.multiply('b fb n, fb', loss, self.loss_weights)
+
+            # allow researcher to pass in a function that acts on the the forward backward indices Int['n fb']
+            # the reason this may be needed is because the earlier tokens will have more eligible pairs for training, and perhaps this could be normalized
+
+            if exists(loss_weight_by_fb_indices):
+                loss_weight = loss_weight_by_fb_indices(fb_pairs)
+                loss = einx.multiply('b fb n, n', loss, loss_weight)
+
         loss = loss.mean()
 
         # backwards
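
The added `loss_weight_by_fb_indices` hook lets a caller rescale the loss per (forward, backward) pair, e.g. to normalize for the fact that earlier tokens belong to many more eligible pairs. A hedged sketch of the intended usage, assuming `fb_pairs` is an integer tensor of shape `(num_pairs, 2)` holding the forward and backward indices (as the `Int['n fb']` comment suggests), that the wrapper accepts only a forward decoder as in the sketch above, and that its forward pass returns the loss; the weighting scheme itself is invented for illustration:

```python
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

model = BeliefStateWrapper(
    forward_decoder = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 256,
        attn_layers = Decoder(dim = 512, depth = 2, heads = 8, rotary_pos_emb = True)
    )
)

# hypothetical weighting: scale each pair by the inverse of its forward/backward distance;
# any function returning one weight per pair (shape (num_pairs,)) fits the einx.multiply('b fb n, n', ...) call
def weight_by_gap(fb_pairs):                          # assumed shape (num_pairs, 2)
    forward_idx, backward_idx = fb_pairs.unbind(dim = -1)
    gap = (backward_idx - forward_idx).abs().clamp(min = 1)
    return 1. / gap.float()

seq = torch.randint(0, 256, (2, 256))

# no explicit loss.backward() here: the wrapper's forward handles the backward pass itself
loss = model(seq, loss_weight_by_fb_indices = weight_by_gap)
```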