x-transformers 2.4.8.tar.gz → 2.4.9.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {x_transformers-2.4.8 → x_transformers-2.4.9}/PKG-INFO +1 -1
  2. {x_transformers-2.4.8 → x_transformers-2.4.9}/pyproject.toml +1 -1
  3. {x_transformers-2.4.8 → x_transformers-2.4.9}/x_transformers/up_wrapper.py +10 -4
  4. {x_transformers-2.4.8 → x_transformers-2.4.9}/.github/FUNDING.yml +0 -0
  5. {x_transformers-2.4.8 → x_transformers-2.4.9}/.github/workflows/python-publish.yml +0 -0
  6. {x_transformers-2.4.8 → x_transformers-2.4.9}/.github/workflows/python-test.yaml +0 -0
  7. {x_transformers-2.4.8 → x_transformers-2.4.9}/.gitignore +0 -0
  8. {x_transformers-2.4.8 → x_transformers-2.4.9}/LICENSE +0 -0
  9. {x_transformers-2.4.8 → x_transformers-2.4.9}/README.md +0 -0
  10. {x_transformers-2.4.8 → x_transformers-2.4.9}/data/README.md +0 -0
  11. {x_transformers-2.4.8 → x_transformers-2.4.9}/data/enwik8.gz +0 -0
  12. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/all-attention.png +0 -0
  13. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/attention-on-attention.png +0 -0
  14. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/cosine-sim-attention.png +0 -0
  15. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/deepnorm.png +0 -0
  16. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/dynamic-pos-bias-linear.png +0 -0
  17. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/dynamic-pos-bias-log.png +0 -0
  18. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  19. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/dynamic-pos-bias.png +0 -0
  20. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/enhanced-recurrence.png +0 -0
  21. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/fcm.png +0 -0
  22. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/ffglu.png +0 -0
  23. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/flash-attention.png +0 -0
  24. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/gate_values.png +0 -0
  25. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/gating.png +0 -0
  26. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/length-extrapolation-scale.png +0 -0
  27. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/macaron-1.png +0 -0
  28. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/macaron-2.png +0 -0
  29. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/memory-transformer.png +0 -0
  30. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/normformer.png +0 -0
  31. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/pia.png +0 -0
  32. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/qknorm-analysis.png +0 -0
  33. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/resi_dual.png +0 -0
  34. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/residual_attn.png +0 -0
  35. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/rezero.png +0 -0
  36. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/rotary.png +0 -0
  37. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/sandwich-2.png +0 -0
  38. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/sandwich.png +0 -0
  39. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/sandwich_norm.png +0 -0
  40. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/scalenorm.png +0 -0
  41. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/talking-heads.png +0 -0
  42. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/topk-attention.png +0 -0
  43. {x_transformers-2.4.8 → x_transformers-2.4.9}/images/xval.png +0 -0
  44. {x_transformers-2.4.8 → x_transformers-2.4.9}/tests/test_x_transformers.py +0 -0
  45. {x_transformers-2.4.8 → x_transformers-2.4.9}/train_belief_state.py +0 -0
  46. {x_transformers-2.4.8 → x_transformers-2.4.9}/train_copy.py +0 -0
  47. {x_transformers-2.4.8 → x_transformers-2.4.9}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.4.8 → x_transformers-2.4.9}/train_enwik8.py +0 -0
  49. {x_transformers-2.4.8 → x_transformers-2.4.9}/train_length_extrapolate.py +0 -0
  50. {x_transformers-2.4.8 → x_transformers-2.4.9}/train_parity.py +0 -0
  51. {x_transformers-2.4.8 → x_transformers-2.4.9}/x_transformers/__init__.py +0 -0
  52. {x_transformers-2.4.8 → x_transformers-2.4.9}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.4.8 → x_transformers-2.4.9}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.4.8 → x_transformers-2.4.9}/x_transformers/belief_state_wrapper.py +0 -0
  55. {x_transformers-2.4.8 → x_transformers-2.4.9}/x_transformers/continuous.py +0 -0
  56. {x_transformers-2.4.8 → x_transformers-2.4.9}/x_transformers/dpo.py +0 -0
  57. {x_transformers-2.4.8 → x_transformers-2.4.9}/x_transformers/entropy_based_tokenizer.py +0 -0
  58. {x_transformers-2.4.8 → x_transformers-2.4.9}/x_transformers/multi_input.py +0 -0
  59. {x_transformers-2.4.8 → x_transformers-2.4.9}/x_transformers/neo_mlp.py +0 -0
  60. {x_transformers-2.4.8 → x_transformers-2.4.9}/x_transformers/nonautoregressive_wrapper.py +0 -0
  61. {x_transformers-2.4.8 → x_transformers-2.4.9}/x_transformers/x_transformers.py +0 -0
  62. {x_transformers-2.4.8 → x_transformers-2.4.9}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  63. {x_transformers-2.4.8 → x_transformers-2.4.9}/x_transformers/xval.py +0 -0
--- x_transformers-2.4.8/PKG-INFO
+++ x_transformers-2.4.9/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.4.8
+Version: 2.4.9
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
--- x_transformers-2.4.8/pyproject.toml
+++ x_transformers-2.4.9/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.4.8"
+version = "2.4.9"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
--- x_transformers-2.4.8/x_transformers/up_wrapper.py
+++ x_transformers-2.4.9/x_transformers/up_wrapper.py
@@ -7,7 +7,7 @@ from random import randrange, uniform
 
 import torch
 from torch import nn, cat, tensor, randperm
-from torch.nn import LSTM, Module
+from torch.nn import LSTM, GRU, Module
 
 from x_transformers.x_transformers import (
     TransformerWrapper,
@@ -61,7 +61,9 @@ class SyntheticDataGenerator(Module):
         dim,
         num_tokens,
         max_seq_len = 512,
-        hidden_size = None
+        hidden_size = None,
+        use_gru = False,
+        network_klass = None
     ):
         super().__init__()
 
@@ -70,7 +72,11 @@ class SyntheticDataGenerator(Module):
         self.embed = nn.Embedding(num_tokens, dim)
 
         hidden_size = default(hidden_size, dim)
-        self.lstm = LSTM(dim, hidden_size, batch_first = True)
+
+        default_network_klass = partial(LSTM if not use_gru else GRU, batch_first = True)
+        network_klass = default(network_klass, default_network_klass)
+
+        self.net = network_klass(dim, hidden_size)
 
         self.to_logits = nn.Linear(dim, num_tokens, bias = False)
 
@@ -128,7 +134,7 @@ class SyntheticDataGenerator(Module):
 
         tokens = self.embed(input)
 
-        embed, hidden = self.lstm(tokens, hiddens)
+        embed, hidden = self.net(tokens, hiddens)
 
         logits = self.to_logits(embed)
 