x-transformers 2.3.17__tar.gz → 2.3.18__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. {x_transformers-2.3.17 → x_transformers-2.3.18}/PKG-INFO +1 -1
  2. {x_transformers-2.3.17 → x_transformers-2.3.18}/pyproject.toml +1 -1
  3. {x_transformers-2.3.17 → x_transformers-2.3.18}/tests/test_x_transformers.py +0 -3
  4. {x_transformers-2.3.17 → x_transformers-2.3.18}/x_transformers/continuous.py +8 -3
  5. {x_transformers-2.3.17 → x_transformers-2.3.18}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.3.17 → x_transformers-2.3.18}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.3.17 → x_transformers-2.3.18}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.3.17 → x_transformers-2.3.18}/.gitignore +0 -0
  9. {x_transformers-2.3.17 → x_transformers-2.3.18}/LICENSE +0 -0
  10. {x_transformers-2.3.17 → x_transformers-2.3.18}/README.md +0 -0
  11. {x_transformers-2.3.17 → x_transformers-2.3.18}/data/README.md +0 -0
  12. {x_transformers-2.3.17 → x_transformers-2.3.18}/data/enwik8.gz +0 -0
  13. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/all-attention.png +0 -0
  14. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/deepnorm.png +0 -0
  17. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/fcm.png +0 -0
  23. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/ffglu.png +0 -0
  24. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/flash-attention.png +0 -0
  25. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/gate_values.png +0 -0
  26. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/gating.png +0 -0
  27. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/macaron-1.png +0 -0
  29. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/macaron-2.png +0 -0
  30. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/normformer.png +0 -0
  32. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/pia.png +0 -0
  33. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/resi_dual.png +0 -0
  35. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/residual_attn.png +0 -0
  36. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/rezero.png +0 -0
  37. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/rotary.png +0 -0
  38. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/sandwich.png +0 -0
  40. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/scalenorm.png +0 -0
  42. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/talking-heads.png +0 -0
  43. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/topk-attention.png +0 -0
  44. {x_transformers-2.3.17 → x_transformers-2.3.18}/images/xval.png +0 -0
  45. {x_transformers-2.3.17 → x_transformers-2.3.18}/train_belief_state.py +0 -0
  46. {x_transformers-2.3.17 → x_transformers-2.3.18}/train_copy.py +0 -0
  47. {x_transformers-2.3.17 → x_transformers-2.3.18}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.3.17 → x_transformers-2.3.18}/train_enwik8.py +0 -0
  49. {x_transformers-2.3.17 → x_transformers-2.3.18}/train_length_extrapolate.py +0 -0
  50. {x_transformers-2.3.17 → x_transformers-2.3.18}/train_parity.py +0 -0
  51. {x_transformers-2.3.17 → x_transformers-2.3.18}/x_transformers/__init__.py +0 -0
  52. {x_transformers-2.3.17 → x_transformers-2.3.18}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.3.17 → x_transformers-2.3.18}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.3.17 → x_transformers-2.3.18}/x_transformers/belief_state_wrapper.py +0 -0
  55. {x_transformers-2.3.17 → x_transformers-2.3.18}/x_transformers/dpo.py +0 -0
  56. {x_transformers-2.3.17 → x_transformers-2.3.18}/x_transformers/entropy_based_tokenizer.py +0 -0
  57. {x_transformers-2.3.17 → x_transformers-2.3.18}/x_transformers/multi_input.py +0 -0
  58. {x_transformers-2.3.17 → x_transformers-2.3.18}/x_transformers/neo_mlp.py +0 -0
  59. {x_transformers-2.3.17 → x_transformers-2.3.18}/x_transformers/nonautoregressive_wrapper.py +0 -0
  60. {x_transformers-2.3.17 → x_transformers-2.3.18}/x_transformers/x_transformers.py +0 -0
  61. {x_transformers-2.3.17 → x_transformers-2.3.18}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  62. {x_transformers-2.3.17 → x_transformers-2.3.18}/x_transformers/xval.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.3.17
+Version: 2.3.18
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.3.17"
+version = "2.3.18"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_x_transformers.py
@@ -876,9 +876,6 @@ def test_continuous(
     cache_kv,
     rollout_steps
 ):
-    if probabilistic and rollout_steps > 1:
-        pytest.skip()
-
     from x_transformers import (
         ContinuousTransformerWrapper,
         Decoder,
x_transformers/continuous.py
@@ -298,7 +298,6 @@ class ContinuousAutoregressiveWrapper(Module):
         **kwargs
     ):
         assert rollout_steps > 1
-        assert not self.probabilistic, 'probabilistic not supported yet'

         steps = rollout_steps

x_transformers/continuous.py
@@ -369,8 +368,14 @@ class ContinuousAutoregressiveWrapper(Module):
                **kwargs
            )

-            last_pred = out[:, -1:]
-            inp = last_pred
+            last_pred = out[..., -1:, :]
+
+            if self.probabilistic:
+                mean, var = last_pred
+                std = var.clamp(min = 1e-5).sqrt()
+                inp = torch.normal(mean, std)
+            else:
+                inp = last_pred

            preds.append(last_pred)
