x-transformers 2.1.12__tar.gz → 2.1.14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. {x_transformers-2.1.12 → x_transformers-2.1.14}/PKG-INFO +1 -1
  2. {x_transformers-2.1.12 → x_transformers-2.1.14}/pyproject.toml +1 -1
  3. {x_transformers-2.1.12 → x_transformers-2.1.14}/tests/test_x_transformers.py +8 -2
  4. {x_transformers-2.1.12 → x_transformers-2.1.14}/x_transformers/belief_state_wrapper.py +9 -2
  5. {x_transformers-2.1.12 → x_transformers-2.1.14}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.1.12 → x_transformers-2.1.14}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.1.12 → x_transformers-2.1.14}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.1.12 → x_transformers-2.1.14}/.gitignore +0 -0
  9. {x_transformers-2.1.12 → x_transformers-2.1.14}/LICENSE +0 -0
  10. {x_transformers-2.1.12 → x_transformers-2.1.14}/README.md +0 -0
  11. {x_transformers-2.1.12 → x_transformers-2.1.14}/data/README.md +0 -0
  12. {x_transformers-2.1.12 → x_transformers-2.1.14}/data/enwik8.gz +0 -0
  13. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/all-attention.png +0 -0
  14. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/deepnorm.png +0 -0
  17. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/fcm.png +0 -0
  23. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/ffglu.png +0 -0
  24. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/flash-attention.png +0 -0
  25. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/gate_values.png +0 -0
  26. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/gating.png +0 -0
  27. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/macaron-1.png +0 -0
  29. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/macaron-2.png +0 -0
  30. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/normformer.png +0 -0
  32. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/pia.png +0 -0
  33. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/resi_dual.png +0 -0
  35. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/residual_attn.png +0 -0
  36. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/rezero.png +0 -0
  37. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/rotary.png +0 -0
  38. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/sandwich.png +0 -0
  40. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/scalenorm.png +0 -0
  42. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/talking-heads.png +0 -0
  43. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/topk-attention.png +0 -0
  44. {x_transformers-2.1.12 → x_transformers-2.1.14}/images/xval.png +0 -0
  45. {x_transformers-2.1.12 → x_transformers-2.1.14}/train_copy.py +0 -0
  46. {x_transformers-2.1.12 → x_transformers-2.1.14}/train_enwik8.py +0 -0
  47. {x_transformers-2.1.12 → x_transformers-2.1.14}/train_length_extrapolate.py +0 -0
  48. {x_transformers-2.1.12 → x_transformers-2.1.14}/train_parity.py +0 -0
  49. {x_transformers-2.1.12 → x_transformers-2.1.14}/x_transformers/__init__.py +0 -0
  50. {x_transformers-2.1.12 → x_transformers-2.1.14}/x_transformers/attend.py +0 -0
  51. {x_transformers-2.1.12 → x_transformers-2.1.14}/x_transformers/autoregressive_wrapper.py +0 -0
  52. {x_transformers-2.1.12 → x_transformers-2.1.14}/x_transformers/continuous.py +0 -0
  53. {x_transformers-2.1.12 → x_transformers-2.1.14}/x_transformers/dpo.py +0 -0
  54. {x_transformers-2.1.12 → x_transformers-2.1.14}/x_transformers/multi_input.py +0 -0
  55. {x_transformers-2.1.12 → x_transformers-2.1.14}/x_transformers/neo_mlp.py +0 -0
  56. {x_transformers-2.1.12 → x_transformers-2.1.14}/x_transformers/nonautoregressive_wrapper.py +0 -0
  57. {x_transformers-2.1.12 → x_transformers-2.1.14}/x_transformers/x_transformers.py +0 -0
  58. {x_transformers-2.1.12 → x_transformers-2.1.14}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  59. {x_transformers-2.1.12 → x_transformers-2.1.14}/x_transformers/xval.py +0 -0
--- x_transformers-2.1.12/PKG-INFO
+++ x_transformers-2.1.14/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.12
+Version: 2.1.14
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
--- x_transformers-2.1.12/pyproject.toml
+++ x_transformers-2.1.14/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.1.12"
+version = "2.1.14"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
--- x_transformers-2.1.12/tests/test_x_transformers.py
+++ x_transformers-2.1.14/tests/test_x_transformers.py
@@ -695,8 +695,10 @@ def test_lime(
     model(x)
 
 @pytest.mark.parametrize('backward_ar_loss_weight', (1., 0.5))
+@pytest.mark.parametrize('goal_suffix', (False, True))
 def test_belief_state_wrapper(
-    backward_ar_loss_weight
+    backward_ar_loss_weight,
+    goal_suffix
 ):
     from x_transformers.belief_state_wrapper import BeliefStateWrapper
 
@@ -733,5 +735,9 @@ def test_belief_state_wrapper(
     loss = model(seq, backward = False)
     loss.backward()
 
-    sampled = model.generate_with_suffix_token_only(seq[:, :1], 16)
+    suffix = None
+    if goal_suffix:
+        suffix = torch.randint(0, 20000, (2, 2))
+
+    sampled = model.generate_with_suffix_token_only(seq[:, :1], 16, suffix = suffix)
     assert sampled.shape == (2, 16)
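For context, the updated test exercises the new suffix argument roughly as sketched below. This is a usage sketch, not code from this diff: the TransformerWrapper / Decoder configuration and the BeliefStateWrapper constructor keyword are assumptions based on the surrounding test suite; only the generate_with_suffix_token_only call is taken from the change above.

# Hedged usage sketch of the new goal-suffix decoding path (not from this diff).
# The model configuration below is assumed, not shown in the package diff.
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

forward_model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8)
)

model = BeliefStateWrapper(forward_decoder = forward_model)  # constructor keyword assumed

seq = torch.randint(0, 20000, (2, 16))
loss = model(seq)   # belief-state training loss
loss.backward()

# new in 2.1.14: condition decoding on a goal suffix (here a batch of 2-token suffixes)
suffix = torch.randint(0, 20000, (2, 2))
sampled = model.generate_with_suffix_token_only(seq[:, :1], 16, suffix = suffix)
assert sampled.shape == (2, 16)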
--- x_transformers-2.1.12/x_transformers/belief_state_wrapper.py
+++ x_transformers-2.1.14/x_transformers/belief_state_wrapper.py
@@ -132,8 +132,11 @@ class BeliefStateWrapper(Module):
 
         # get the encoded suffix token once
 
-        if exists(suffix) and suffix.ndim == 1:
-            suffix = repeat(suffix, 'n -> b n', b = batch)
+        if exists(suffix):
+            if suffix.ndim == 1:
+                suffix = repeat(suffix, 'n -> b n', b = batch)
+
+            suffix = suffix.flip(1) # reverse autoregressive
 
         suffix_sos_tokens = rearrange(self.suffix_token, 'd -> 1 1 d')
 
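The hunk above broadens the suffix handling: a 1-D suffix is first broadcast across the batch, and any suffix is then reversed, since the backward decoder consumes it in reverse-autoregressive order. A minimal standalone illustration of just those tensor operations (shapes invented for demonstration):

# Standalone illustration of the suffix preprocessing in the hunk above.
# Only einops.repeat and Tensor.flip are shown; values and shapes are made up.
import torch
from einops import repeat

batch = 2
suffix = torch.tensor([5, 7, 9])                    # single goal suffix, shape (n,)

if suffix.ndim == 1:
    suffix = repeat(suffix, 'n -> b n', b = batch)  # broadcast to (batch, n) -> (2, 3)

suffix = suffix.flip(1)                             # reverse token order along the sequence dim
print(suffix)                                       # tensor([[9, 7, 5], [9, 7, 5]])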
@@ -145,6 +148,10 @@ class BeliefStateWrapper(Module):
             return_embeddings = True
         )
 
+        # pick out the last embedding for fill in the model
+
+        suffix_embed = suffix_embed[:, -1:]
+
         # sampling up to seq_len
 
         for _ in range(seq_len):
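The second hunk keeps only the final position of the encoded suffix before the sampling loop; slicing with [:, -1:] preserves the sequence dimension so the single embedding can still be broadcast or concatenated later. A small shape check, with an invented embedding size:

# Shape illustration for the slice above; the 512 dimension is invented.
import torch

suffix_embed = torch.randn(2, 3, 512)   # (batch, suffix length, dim) from the encoding pass
last = suffix_embed[:, -1:]             # (2, 1, 512): last position only, sequence dim kept
assert last.shape == (2, 1, 512)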