x-transformers 2.1.7__tar.gz → 2.1.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. {x_transformers-2.1.7 → x_transformers-2.1.10}/PKG-INFO +1 -1
  2. {x_transformers-2.1.7 → x_transformers-2.1.10}/pyproject.toml +1 -1
  3. {x_transformers-2.1.7 → x_transformers-2.1.10}/tests/test_x_transformers.py +5 -2
  4. {x_transformers-2.1.7 → x_transformers-2.1.10}/x_transformers/belief_state_wrapper.py +15 -6
  5. {x_transformers-2.1.7 → x_transformers-2.1.10}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.1.7 → x_transformers-2.1.10}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.1.7 → x_transformers-2.1.10}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.1.7 → x_transformers-2.1.10}/.gitignore +0 -0
  9. {x_transformers-2.1.7 → x_transformers-2.1.10}/LICENSE +0 -0
  10. {x_transformers-2.1.7 → x_transformers-2.1.10}/README.md +0 -0
  11. {x_transformers-2.1.7 → x_transformers-2.1.10}/data/README.md +0 -0
  12. {x_transformers-2.1.7 → x_transformers-2.1.10}/data/enwik8.gz +0 -0
  13. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/all-attention.png +0 -0
  14. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/deepnorm.png +0 -0
  17. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/fcm.png +0 -0
  23. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/ffglu.png +0 -0
  24. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/flash-attention.png +0 -0
  25. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/gate_values.png +0 -0
  26. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/gating.png +0 -0
  27. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/macaron-1.png +0 -0
  29. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/macaron-2.png +0 -0
  30. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/normformer.png +0 -0
  32. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/pia.png +0 -0
  33. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/resi_dual.png +0 -0
  35. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/residual_attn.png +0 -0
  36. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/rezero.png +0 -0
  37. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/rotary.png +0 -0
  38. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/sandwich.png +0 -0
  40. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/scalenorm.png +0 -0
  42. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/talking-heads.png +0 -0
  43. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/topk-attention.png +0 -0
  44. {x_transformers-2.1.7 → x_transformers-2.1.10}/images/xval.png +0 -0
  45. {x_transformers-2.1.7 → x_transformers-2.1.10}/train_copy.py +0 -0
  46. {x_transformers-2.1.7 → x_transformers-2.1.10}/train_enwik8.py +0 -0
  47. {x_transformers-2.1.7 → x_transformers-2.1.10}/train_length_extrapolate.py +0 -0
  48. {x_transformers-2.1.7 → x_transformers-2.1.10}/train_parity.py +0 -0
  49. {x_transformers-2.1.7 → x_transformers-2.1.10}/x_transformers/__init__.py +0 -0
  50. {x_transformers-2.1.7 → x_transformers-2.1.10}/x_transformers/attend.py +0 -0
  51. {x_transformers-2.1.7 → x_transformers-2.1.10}/x_transformers/autoregressive_wrapper.py +0 -0
  52. {x_transformers-2.1.7 → x_transformers-2.1.10}/x_transformers/continuous.py +0 -0
  53. {x_transformers-2.1.7 → x_transformers-2.1.10}/x_transformers/dpo.py +0 -0
  54. {x_transformers-2.1.7 → x_transformers-2.1.10}/x_transformers/multi_input.py +0 -0
  55. {x_transformers-2.1.7 → x_transformers-2.1.10}/x_transformers/neo_mlp.py +0 -0
  56. {x_transformers-2.1.7 → x_transformers-2.1.10}/x_transformers/nonautoregressive_wrapper.py +0 -0
  57. {x_transformers-2.1.7 → x_transformers-2.1.10}/x_transformers/x_transformers.py +0 -0
  58. {x_transformers-2.1.7 → x_transformers-2.1.10}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  59. {x_transformers-2.1.7 → x_transformers-2.1.10}/x_transformers/xval.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.7
+Version: 2.1.10
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.1.7"
+version = "2.1.10"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_x_transformers.py
@@ -694,7 +694,10 @@ def test_lime(
 
     model(x)
 
-def test_belief_state_wrapper():
+@pytest.mark.parametrize('backward_ar_loss_weight', (1., 0.5))
+def test_belief_state_wrapper(
+    backward_ar_loss_weight
+):
     from x_transformers.belief_state_wrapper import BeliefStateWrapper
 
     forward_model = TransformerWrapper(
@@ -722,7 +725,7 @@ def test_belief_state_wrapper():
     model = BeliefStateWrapper(
         forward_decoder = forward_model,
         backward_decoder = backward_model,
-        backward_ar_loss_weight = 0.5
+        backward_ar_loss_weight = backward_ar_loss_weight
     )
 
     seq = torch.randint(0, 20000, (2, 16))
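
For context, pytest's parametrize decorator runs the decorated test once per listed value, so the wrapper is now exercised twice: once with the backward autoregressive loss weighted equally and once down-weighted to 0.5. A minimal standalone sketch of the same pattern, with an illustrative test name that is not taken from the package:

import pytest

@pytest.mark.parametrize('backward_ar_loss_weight', (1., 0.5))
def test_weight_values(backward_ar_loss_weight):
    # pytest generates one test case per parametrized value,
    # so this body runs once with 1. and once with 0.5
    assert backward_ar_loss_weight in (1., 0.5)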
x_transformers/belief_state_wrapper.py
@@ -1,8 +1,11 @@
+
 # Belief State Transformer
 
 # https://arxiv.org/abs/2410.23506
 # https://www.youtube.com/watch?v=aqhbRtB2Fyg
 
+from __future__ import annotations
+
 import torch
 from torch.autograd import Function
 from torch.nn import Module, ModuleList
@@ -35,11 +38,14 @@ class BeliefStateWrapper(Module):
     def __init__(
         self,
         forward_decoder: TransformerWrapper,
-        backward_decoder: TransformerWrapper,
+        backward_decoder: TransformerWrapper | None = None,
         train_frac_forward_backward_pairs: float = 1.,
+        text_head: Module | None = None,
         backward_ar_loss_weight: float = 1. # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
     ):
         super().__init__()
+        backward_decoder = default(backward_decoder, forward_decoder) # if backward decoder not set, use the same transformer, assume it knows how to switch gears based on suffix token
+
         assert forward_decoder.emb_dim == backward_decoder.emb_dim, 'forward and backwards model must have the same embedding dimension'
         assert forward_decoder.num_tokens == backward_decoder.num_tokens, 'forward and backwards model must have the same number of tokens'
 
@@ -53,11 +59,14 @@ class BeliefStateWrapper(Module):
 
         # the text prediction head, which predicts for the combinations of prefix and suffix the next and previous token for forwards and backward sequences
 
-        self.text_head = nn.Sequential(
-            nn.Linear(dim * 2, dim),
-            nn.LeakyReLU(),
-            nn.Linear(dim, num_tokens * 2),
-        )
+        if not exists(text_head):
+            text_head = nn.Sequential(
+                nn.Linear(dim * 2, dim),
+                nn.LeakyReLU(),
+                nn.Linear(dim, num_tokens * 2),
+            )
+
+        self.text_head = text_head
 
         # the two decoders, one which is causal forward, the other causal backwards
 
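Taken together, the wrapper changes above make both backward_decoder and text_head optional: omitting the backward decoder reuses the forward decoder (shared weights for both directions), and omitting the text head falls back to the built-in Linear/LeakyReLU/Linear head. A minimal usage sketch under those assumptions, mirroring the constructor shown in the diff; the TransformerWrapper hyperparameters below are illustrative, and the forward call is assumed to return a training loss as in the test:

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

# a single decoder used for both directions, since backward_decoder
# now defaults to the forward decoder when omitted
decoder = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

model = BeliefStateWrapper(
    forward_decoder = decoder,        # backward_decoder omitted -> shared weights
    backward_ar_loss_weight = 0.5     # text_head omitted -> default two-layer head
)

seq = torch.randint(0, 20000, (2, 16))

loss = model(seq)   # assumed to return a scalar training loss, per the test above
loss.backward()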