x-transformers 2.1.11.tar.gz → 2.1.14.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. {x_transformers-2.1.11 → x_transformers-2.1.14}/PKG-INFO +1 -1
  2. {x_transformers-2.1.11 → x_transformers-2.1.14}/pyproject.toml +1 -1
  3. {x_transformers-2.1.11 → x_transformers-2.1.14}/tests/test_x_transformers.py +8 -2
  4. {x_transformers-2.1.11 → x_transformers-2.1.14}/x_transformers/belief_state_wrapper.py +8 -4
  5. {x_transformers-2.1.11 → x_transformers-2.1.14}/x_transformers/x_transformers.py +9 -0
  6. {x_transformers-2.1.11 → x_transformers-2.1.14}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.1.11 → x_transformers-2.1.14}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.1.11 → x_transformers-2.1.14}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.1.11 → x_transformers-2.1.14}/.gitignore +0 -0
  10. {x_transformers-2.1.11 → x_transformers-2.1.14}/LICENSE +0 -0
  11. {x_transformers-2.1.11 → x_transformers-2.1.14}/README.md +0 -0
  12. {x_transformers-2.1.11 → x_transformers-2.1.14}/data/README.md +0 -0
  13. {x_transformers-2.1.11 → x_transformers-2.1.14}/data/enwik8.gz +0 -0
  14. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/all-attention.png +0 -0
  15. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/deepnorm.png +0 -0
  18. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/fcm.png +0 -0
  24. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/ffglu.png +0 -0
  25. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/flash-attention.png +0 -0
  26. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/gate_values.png +0 -0
  27. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/gating.png +0 -0
  28. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/macaron-1.png +0 -0
  30. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/macaron-2.png +0 -0
  31. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/normformer.png +0 -0
  33. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/pia.png +0 -0
  34. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/resi_dual.png +0 -0
  36. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/residual_attn.png +0 -0
  37. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/rezero.png +0 -0
  38. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/rotary.png +0 -0
  39. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/sandwich.png +0 -0
  41. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/scalenorm.png +0 -0
  43. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/talking-heads.png +0 -0
  44. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/topk-attention.png +0 -0
  45. {x_transformers-2.1.11 → x_transformers-2.1.14}/images/xval.png +0 -0
  46. {x_transformers-2.1.11 → x_transformers-2.1.14}/train_copy.py +0 -0
  47. {x_transformers-2.1.11 → x_transformers-2.1.14}/train_enwik8.py +0 -0
  48. {x_transformers-2.1.11 → x_transformers-2.1.14}/train_length_extrapolate.py +0 -0
  49. {x_transformers-2.1.11 → x_transformers-2.1.14}/train_parity.py +0 -0
  50. {x_transformers-2.1.11 → x_transformers-2.1.14}/x_transformers/__init__.py +0 -0
  51. {x_transformers-2.1.11 → x_transformers-2.1.14}/x_transformers/attend.py +0 -0
  52. {x_transformers-2.1.11 → x_transformers-2.1.14}/x_transformers/autoregressive_wrapper.py +0 -0
  53. {x_transformers-2.1.11 → x_transformers-2.1.14}/x_transformers/continuous.py +0 -0
  54. {x_transformers-2.1.11 → x_transformers-2.1.14}/x_transformers/dpo.py +0 -0
  55. {x_transformers-2.1.11 → x_transformers-2.1.14}/x_transformers/multi_input.py +0 -0
  56. {x_transformers-2.1.11 → x_transformers-2.1.14}/x_transformers/neo_mlp.py +0 -0
  57. {x_transformers-2.1.11 → x_transformers-2.1.14}/x_transformers/nonautoregressive_wrapper.py +0 -0
  58. {x_transformers-2.1.11 → x_transformers-2.1.14}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  59. {x_transformers-2.1.11 → x_transformers-2.1.14}/x_transformers/xval.py +0 -0
--- x_transformers-2.1.11/PKG-INFO
+++ x_transformers-2.1.14/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.11
+Version: 2.1.14
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
--- x_transformers-2.1.11/pyproject.toml
+++ x_transformers-2.1.14/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.1.11"
+version = "2.1.14"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
--- x_transformers-2.1.11/tests/test_x_transformers.py
+++ x_transformers-2.1.14/tests/test_x_transformers.py
@@ -695,8 +695,10 @@ def test_lime(
     model(x)
 
 @pytest.mark.parametrize('backward_ar_loss_weight', (1., 0.5))
+@pytest.mark.parametrize('goal_suffix', (False, True))
 def test_belief_state_wrapper(
-    backward_ar_loss_weight
+    backward_ar_loss_weight,
+    goal_suffix
 ):
     from x_transformers.belief_state_wrapper import BeliefStateWrapper
 
@@ -733,5 +735,9 @@ def test_belief_state_wrapper(
     loss = model(seq, backward = False)
     loss.backward()
 
-    sampled = model.generate_with_suffix_token_only(seq[:, :1], 16)
+    suffix = None
+    if goal_suffix:
+        suffix = torch.randint(0, 20000, (2, 2))
+
+    sampled = model.generate_with_suffix_token_only(seq[:, :1], 16, suffix = suffix)
     assert sampled.shape == (2, 16)
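The updated test exercises the new `suffix` keyword on `generate_with_suffix_token_only`, which lets sampling be conditioned on goal-suffix tokens. A minimal standalone sketch of the same call follows; the wrapper construction (the `forward_decoder` / `backward_decoder` keywords and the decoder hyperparameters) is assumed from typical x-transformers usage and is not part of this diff.

# illustrative sketch; construction keywords and model sizes are assumptions, not taken from this diff
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

forward_model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8)
)

backward_model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8)
)

model = BeliefStateWrapper(
    forward_decoder = forward_model,
    backward_decoder = backward_model
)

prompt = torch.randint(0, 20000, (2, 1))   # (batch, 1) prompt token, as in the test
goal   = torch.randint(0, 20000, (2, 2))   # (batch, 2) goal-suffix tokens

sampled = model.generate_with_suffix_token_only(prompt, 16, suffix = goal)
assert sampled.shape == (2, 16)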
--- x_transformers-2.1.11/x_transformers/belief_state_wrapper.py
+++ x_transformers-2.1.14/x_transformers/belief_state_wrapper.py
@@ -132,11 +132,11 @@ class BeliefStateWrapper(Module):
 
         # get the encoded suffix token once
 
-        if not exists(suffix):
-            suffix = out[:, 0:0]
+        if exists(suffix):
+            if suffix.ndim == 1:
+                suffix = repeat(suffix, 'n -> b n', b = batch)
 
-        if suffix.ndim == 1:
-            suffix = repeat(suffix, 'n -> b n', b = batch)
+            suffix = suffix.flip(1) # reverse autoregressive
 
         suffix_sos_tokens = rearrange(self.suffix_token, 'd -> 1 1 d')
 
@@ -148,6 +148,10 @@ class BeliefStateWrapper(Module):
             return_embeddings = True
         )
 
+        # pick out the last embedding for fill in the model
+
+        suffix_embed = suffix_embed[:, -1:]
+
         # sampling up to seq_len
 
         for _ in range(seq_len):
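Taken together, the two hunks above change how a user-supplied goal suffix is prepared before sampling: a 1-D suffix is broadcast across the batch, then reversed so the backward (suffix-to-prefix) decoder consumes it in its own autoregressive order, and only the last suffix embedding is kept, with `[:, -1:]` preserving the sequence dimension. A tiny sketch of just those tensor manipulations, independent of the wrapper (shapes and values are illustrative):

import torch
from einops import repeat

batch = 2

suffix = torch.tensor([11, 12, 13])                  # 1-D goal suffix shared across the batch

if suffix.ndim == 1:
    suffix = repeat(suffix, 'n -> b n', b = batch)   # broadcast to shape (2, 3)

suffix = suffix.flip(1)                              # reversed: tensor([[13, 12, 11], [13, 12, 11]])

# stand-in for the backward decoder's output embeddings over the sos token plus suffix
suffix_embed = torch.randn(batch, 4, 512)

suffix_embed = suffix_embed[:, -1:]                  # keep only the last position, shape (2, 1, 512)
assert suffix_embed.shape == (2, 1, 512)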
--- x_transformers-2.1.11/x_transformers/x_transformers.py
+++ x_transformers-2.1.14/x_transformers/x_transformers.py
@@ -2898,6 +2898,15 @@ class TransformerWrapper(Module):
         to_logits_kwargs = dict(),
         **kwargs,
     ):
+
+        # if sequence is None, auto create an empty one if `prepend_embeds` was supplied
+
+        if not exists(x):
+            assert exists(prepend_embeds)
+            x = prepend_embeds.new_empty((prepend_embeds.shape[0], 0), dtype = torch.long)
+
+        # shapes and variables
+
         b, n, device, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask
 
         return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
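With this change, `TransformerWrapper.forward` accepts `x = None` as long as `prepend_embeds` is supplied: an empty `(batch, 0)` LongTensor is created on the same device as the prepended embeddings and the forward pass proceeds from there. A minimal sketch of the new calling convention, with illustrative model sizes and assuming the rest of the forward pass handles the zero-length token sequence as this change intends:

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8)
)

# embeddings to prepend; their feature dim should match the model's embedding dim
prepend = torch.randn(2, 4, 512)

# previously the caller had to build the empty (batch, 0) LongTensor themselves;
# now passing None is enough when prepend_embeds is given
logits = model(None, prepend_embeds = prepend)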