x-transformers 2.4.4__tar.gz → 2.4.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {x_transformers-2.4.4 → x_transformers-2.4.6}/PKG-INFO +1 -1
  2. {x_transformers-2.4.4 → x_transformers-2.4.6}/pyproject.toml +1 -1
  3. {x_transformers-2.4.4 → x_transformers-2.4.6}/x_transformers/up_wrapper.py +21 -3
  4. {x_transformers-2.4.4 → x_transformers-2.4.6}/.github/FUNDING.yml +0 -0
  5. {x_transformers-2.4.4 → x_transformers-2.4.6}/.github/workflows/python-publish.yml +0 -0
  6. {x_transformers-2.4.4 → x_transformers-2.4.6}/.github/workflows/python-test.yaml +0 -0
  7. {x_transformers-2.4.4 → x_transformers-2.4.6}/.gitignore +0 -0
  8. {x_transformers-2.4.4 → x_transformers-2.4.6}/LICENSE +0 -0
  9. {x_transformers-2.4.4 → x_transformers-2.4.6}/README.md +0 -0
  10. {x_transformers-2.4.4 → x_transformers-2.4.6}/data/README.md +0 -0
  11. {x_transformers-2.4.4 → x_transformers-2.4.6}/data/enwik8.gz +0 -0
  12. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/all-attention.png +0 -0
  13. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/attention-on-attention.png +0 -0
  14. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/cosine-sim-attention.png +0 -0
  15. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/deepnorm.png +0 -0
  16. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/dynamic-pos-bias-linear.png +0 -0
  17. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/dynamic-pos-bias-log.png +0 -0
  18. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  19. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/dynamic-pos-bias.png +0 -0
  20. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/enhanced-recurrence.png +0 -0
  21. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/fcm.png +0 -0
  22. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/ffglu.png +0 -0
  23. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/flash-attention.png +0 -0
  24. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/gate_values.png +0 -0
  25. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/gating.png +0 -0
  26. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/length-extrapolation-scale.png +0 -0
  27. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/macaron-1.png +0 -0
  28. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/macaron-2.png +0 -0
  29. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/memory-transformer.png +0 -0
  30. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/normformer.png +0 -0
  31. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/pia.png +0 -0
  32. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/qknorm-analysis.png +0 -0
  33. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/resi_dual.png +0 -0
  34. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/residual_attn.png +0 -0
  35. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/rezero.png +0 -0
  36. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/rotary.png +0 -0
  37. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/sandwich-2.png +0 -0
  38. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/sandwich.png +0 -0
  39. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/sandwich_norm.png +0 -0
  40. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/scalenorm.png +0 -0
  41. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/talking-heads.png +0 -0
  42. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/topk-attention.png +0 -0
  43. {x_transformers-2.4.4 → x_transformers-2.4.6}/images/xval.png +0 -0
  44. {x_transformers-2.4.4 → x_transformers-2.4.6}/tests/test_x_transformers.py +0 -0
  45. {x_transformers-2.4.4 → x_transformers-2.4.6}/train_belief_state.py +0 -0
  46. {x_transformers-2.4.4 → x_transformers-2.4.6}/train_copy.py +0 -0
  47. {x_transformers-2.4.4 → x_transformers-2.4.6}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.4.4 → x_transformers-2.4.6}/train_enwik8.py +0 -0
  49. {x_transformers-2.4.4 → x_transformers-2.4.6}/train_length_extrapolate.py +0 -0
  50. {x_transformers-2.4.4 → x_transformers-2.4.6}/train_parity.py +0 -0
  51. {x_transformers-2.4.4 → x_transformers-2.4.6}/x_transformers/__init__.py +0 -0
  52. {x_transformers-2.4.4 → x_transformers-2.4.6}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.4.4 → x_transformers-2.4.6}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.4.4 → x_transformers-2.4.6}/x_transformers/belief_state_wrapper.py +0 -0
  55. {x_transformers-2.4.4 → x_transformers-2.4.6}/x_transformers/continuous.py +0 -0
  56. {x_transformers-2.4.4 → x_transformers-2.4.6}/x_transformers/dpo.py +0 -0
  57. {x_transformers-2.4.4 → x_transformers-2.4.6}/x_transformers/entropy_based_tokenizer.py +0 -0
  58. {x_transformers-2.4.4 → x_transformers-2.4.6}/x_transformers/multi_input.py +0 -0
  59. {x_transformers-2.4.4 → x_transformers-2.4.6}/x_transformers/neo_mlp.py +0 -0
  60. {x_transformers-2.4.4 → x_transformers-2.4.6}/x_transformers/nonautoregressive_wrapper.py +0 -0
  61. {x_transformers-2.4.4 → x_transformers-2.4.6}/x_transformers/x_transformers.py +0 -0
  62. {x_transformers-2.4.4 → x_transformers-2.4.6}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  63. {x_transformers-2.4.4 → x_transformers-2.4.6}/x_transformers/xval.py +0 -0
{x_transformers-2.4.4 → x_transformers-2.4.6}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.4.4
+Version: 2.4.6
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
{x_transformers-2.4.4 → x_transformers-2.4.6}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.4.4"
+version = "2.4.6"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.4.4 → x_transformers-2.4.6}/x_transformers/up_wrapper.py
@@ -6,7 +6,7 @@ from functools import partial
 from random import randrange, uniform
 
 import torch
-from torch import nn, cat, randperm
+from torch import nn, cat, tensor, randperm
 from torch.nn import LSTM, Module
 
 from x_transformers.x_transformers import (
@@ -76,6 +76,13 @@ class SyntheticDataGenerator(Module):
 
         self.apply(self.init_)
 
+    def reset_(self):
+        for m in self.modules():
+            if hasattr(m, 'reset_parameters'):
+                m.reset_parameters()
+
+        self.apply(self.init_)
+
     @torch.no_grad()
     def init_(self, m):
         if isinstance(m, nn.Linear):
@@ -133,12 +140,13 @@ class UniversalPretrainWrapper(Module):
     def __init__(
         self,
         model: TransformerWrapper,
-        data_generator: SyntheticDataGenerator | None = None,
+        data_generator: SyntheticDataGenerator | Module | None = None,
         buffer_size = None,
         num_reset = 20,
         batch_size = 32,
         seq_len = 512,
-        seed_length = 8
+        seed_length = 8,
+        reset_turing_machine_every = 0
     ):
         super().__init__()
 
@@ -157,6 +165,8 @@ class UniversalPretrainWrapper(Module):
             max_seq_len = seq_len
         )
 
+        self.reset_turing_machine_every = reset_turing_machine_every
+
         self.seq_len = seq_len
         self.data_generator = data_generator
 
@@ -176,6 +186,7 @@ class UniversalPretrainWrapper(Module):
         init_data_buffer = self.random_sequences_fn(buffer_size // 2, buffer_size // 2)
 
         self.register_buffer('synth_data_buffer', init_data_buffer)
+        self.register_buffer('step', tensor(0))
 
     @property
     def device(self):
@@ -210,6 +221,13 @@ class UniversalPretrainWrapper(Module):
             seed = seeds
         )
 
+        self.step.add_(1)
+
+        # maybe reset turing machine
+
+        if self.reset_turing_machine_every > 0 and divisible_by(self.step.item(), self.reset_turing_machine_every):
+            self.data_generator.reset_()
+
         # reset
 
         if self.num_reset > 0:
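
For context, the up_wrapper.py changes above add a step counter and an optional reset_turing_machine_every argument that periodically re-initializes the synthetic data generator through its new reset_() method (0, the default, keeps the old behavior of a fixed generator). Below is a minimal usage sketch, not taken from this diff: it assumes UniversalPretrainWrapper is imported from x_transformers.up_wrapper (the file shown above) and that calling the wrapper with no arguments samples a synthetic batch internally and returns a scalar loss, as is the convention for the other training wrappers in this package.

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.up_wrapper import UniversalPretrainWrapper

# small decoder-only model to be pretrained on synthetic data

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Decoder(dim = 512, depth = 4, heads = 8)
)

# reset_turing_machine_every is the option added in this release:
# every N training steps the synthetic data generator is re-initialized via reset_()

up = UniversalPretrainWrapper(
    model,
    seq_len = 512,
    batch_size = 32,
    reset_turing_machine_every = 1000
)

# only the transformer is optimized; the data generator stays randomly initialized

optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)

for _ in range(10_000):
    loss = up()   # assumed to return a scalar loss on an internally generated batch
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()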