x-transformers 2.4.5.tar.gz → 2.4.7.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {x_transformers-2.4.5 → x_transformers-2.4.7}/PKG-INFO +1 -1
  2. {x_transformers-2.4.5 → x_transformers-2.4.7}/pyproject.toml +1 -1
  3. {x_transformers-2.4.5 → x_transformers-2.4.7}/tests/test_x_transformers.py +9 -2
  4. {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/up_wrapper.py +30 -7
  5. {x_transformers-2.4.5 → x_transformers-2.4.7}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.4.5 → x_transformers-2.4.7}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.4.5 → x_transformers-2.4.7}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.4.5 → x_transformers-2.4.7}/.gitignore +0 -0
  9. {x_transformers-2.4.5 → x_transformers-2.4.7}/LICENSE +0 -0
  10. {x_transformers-2.4.5 → x_transformers-2.4.7}/README.md +0 -0
  11. {x_transformers-2.4.5 → x_transformers-2.4.7}/data/README.md +0 -0
  12. {x_transformers-2.4.5 → x_transformers-2.4.7}/data/enwik8.gz +0 -0
  13. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/all-attention.png +0 -0
  14. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/deepnorm.png +0 -0
  17. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/fcm.png +0 -0
  23. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/ffglu.png +0 -0
  24. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/flash-attention.png +0 -0
  25. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/gate_values.png +0 -0
  26. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/gating.png +0 -0
  27. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/macaron-1.png +0 -0
  29. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/macaron-2.png +0 -0
  30. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/normformer.png +0 -0
  32. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/pia.png +0 -0
  33. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/resi_dual.png +0 -0
  35. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/residual_attn.png +0 -0
  36. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/rezero.png +0 -0
  37. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/rotary.png +0 -0
  38. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/sandwich.png +0 -0
  40. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/scalenorm.png +0 -0
  42. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/talking-heads.png +0 -0
  43. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/topk-attention.png +0 -0
  44. {x_transformers-2.4.5 → x_transformers-2.4.7}/images/xval.png +0 -0
  45. {x_transformers-2.4.5 → x_transformers-2.4.7}/train_belief_state.py +0 -0
  46. {x_transformers-2.4.5 → x_transformers-2.4.7}/train_copy.py +0 -0
  47. {x_transformers-2.4.5 → x_transformers-2.4.7}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.4.5 → x_transformers-2.4.7}/train_enwik8.py +0 -0
  49. {x_transformers-2.4.5 → x_transformers-2.4.7}/train_length_extrapolate.py +0 -0
  50. {x_transformers-2.4.5 → x_transformers-2.4.7}/train_parity.py +0 -0
  51. {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/__init__.py +0 -0
  52. {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/belief_state_wrapper.py +0 -0
  55. {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/continuous.py +0 -0
  56. {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/dpo.py +0 -0
  57. {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/entropy_based_tokenizer.py +0 -0
  58. {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/multi_input.py +0 -0
  59. {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/neo_mlp.py +0 -0
  60. {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/nonautoregressive_wrapper.py +0 -0
  61. {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/x_transformers.py +0 -0
  62. {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  63. {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/xval.py +0 -0
{x_transformers-2.4.5 → x_transformers-2.4.7}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.4.5
+Version: 2.4.7
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
{x_transformers-2.4.5 → x_transformers-2.4.7}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.4.5"
+version = "2.4.7"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.4.5 → x_transformers-2.4.7}/tests/test_x_transformers.py
@@ -1100,7 +1100,10 @@ def add_attn_pool():
 
     assert intermediates.attn_pooled_tokens.shape[1] == 3
 
-def test_up():
+@pytest.mark.parametrize('keep_buffer_on_cpu', (False, True))
+def test_up(
+    keep_buffer_on_cpu
+):
     from x_transformers.up_wrapper import UniversalPretrainWrapper
 
     model = TransformerWrapper(
@@ -1115,7 +1118,11 @@ def test_up():
         ),
     )
 
-    up_wrapper = UniversalPretrainWrapper(model, seq_len = 16)
+    up_wrapper = UniversalPretrainWrapper(
+        model,
+        seq_len = 16,
+        keep_buffer_on_cpu = keep_buffer_on_cpu
+    )
 
     loss = up_wrapper()
     loss.backward()
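
For context on what the new parametrization covers: with keep_buffer_on_cpu = True the wrapper stores its synthetic data buffer as a plain attribute rather than a registered buffer, so moving the module to GPU leaves that buffer on CPU, while the newly registered step counter is what the device property now tracks. A minimal sketch of that behavior (not code from the package; the model hyperparameters below are illustrative assumptions):

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.up_wrapper import UniversalPretrainWrapper

# illustrative model, mirroring the shape of the test above
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 16,
    attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
)

up_wrapper = UniversalPretrainWrapper(model, seq_len = 16, keep_buffer_on_cpu = True)

if torch.cuda.is_available():
    up_wrapper.cuda()

    # the un-registered data buffer stays on CPU ...
    assert up_wrapper.synth_data_buffer.device.type == 'cpu'

    # ... while wrapper.device follows the registered `step` buffer
    assert up_wrapper.device.type == 'cuda'
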
{x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/up_wrapper.py
@@ -6,7 +6,7 @@ from functools import partial
 from random import randrange, uniform
 
 import torch
-from torch import nn, cat, randperm
+from torch import nn, cat, tensor, randperm
 from torch.nn import LSTM, Module
 
 from x_transformers.x_transformers import (
@@ -76,6 +76,13 @@ class SyntheticDataGenerator(Module):
 
         self.apply(self.init_)
 
+    def reset_(self):
+        for m in self.modules():
+            if hasattr(m, 'reset_parameters'):
+                m.reset_parameters()
+
+        self.apply(self.init_)
+
     @torch.no_grad()
     def init_(self, m):
         if isinstance(m, nn.Linear):
@@ -138,7 +145,9 @@ class UniversalPretrainWrapper(Module):
         num_reset = 20,
         batch_size = 32,
         seq_len = 512,
-        seed_length = 8
+        seed_length = 8,
+        reset_turing_machine_every = 0,
+        keep_buffer_on_cpu = False
     ):
         super().__init__()
 
@@ -157,6 +166,8 @@ class UniversalPretrainWrapper(Module):
             max_seq_len = seq_len
         )
 
+        self.reset_turing_machine_every = reset_turing_machine_every
+
         self.seq_len = seq_len
         self.data_generator = data_generator
 
@@ -175,11 +186,16 @@ class UniversalPretrainWrapper(Module):
 
         init_data_buffer = self.random_sequences_fn(buffer_size // 2, buffer_size // 2)
 
-        self.register_buffer('synth_data_buffer', init_data_buffer)
+        if keep_buffer_on_cpu:
+            self.synth_data_buffer = init_data_buffer
+        else:
+            self.register_buffer('synth_data_buffer', init_data_buffer)
+
+        self.register_buffer('step', tensor(0))
 
     @property
     def device(self):
-        return self.synth_data_buffer.device
+        return self.step.device
 
     def get_rand_sequences_from_buffer(self, size = None):
         size = default(size, self.batch_size)
@@ -206,10 +222,17 @@ class UniversalPretrainWrapper(Module):
 
         generated = self.data_generator.generate(
             self.seq_len,
-            condition = conditions,
-            seed = seeds
+            condition = conditions.to(self.device),
+            seed = seeds.to(self.device)
         )
 
+        self.step.add_(1)
+
+        # maybe reset turing machine
+
+        if self.reset_turing_machine_every > 0 and divisible_by(self.step.item(), self.reset_turing_machine_every):
+            self.data_generator.reset_()
+
         # reset
 
         if self.num_reset > 0:
@@ -226,6 +249,6 @@ class UniversalPretrainWrapper(Module):
 
         # sample yet again according to pseudocode
 
-        data = self.get_rand_sequences_from_buffer()
+        data = self.get_rand_sequences_from_buffer().to(self.device)
 
         return self.ar_wrapped(data)
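
Taken together, the two new constructor arguments can be exercised as in the sketch below. This is not code from the package; the optimizer, learning rate, step count, and reset interval are illustrative assumptions. reset_turing_machine_every = 0 (the default) disables the periodic re-initialization of the synthetic data generator.

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.up_wrapper import UniversalPretrainWrapper

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 16,
    attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
)

up_wrapper = UniversalPretrainWrapper(
    model,
    seq_len = 16,
    keep_buffer_on_cpu = True,          # synthetic data buffer stays as a plain CPU attribute
    reset_turing_machine_every = 100    # re-initialize the data generator every 100 steps
)

optim = torch.optim.Adam(up_wrapper.parameters(), lr = 3e-4)

for _ in range(1000):
    loss = up_wrapper()     # generates synthetic data and returns the autoregressive loss
    loss.backward()
    optim.step()
    optim.zero_grad()
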