x-transformers 2.4.6__tar.gz → 2.4.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {x_transformers-2.4.6 → x_transformers-2.4.7}/PKG-INFO +1 -1
  2. {x_transformers-2.4.6 → x_transformers-2.4.7}/pyproject.toml +1 -1
  3. {x_transformers-2.4.6 → x_transformers-2.4.7}/tests/test_x_transformers.py +9 -2
  4. {x_transformers-2.4.6 → x_transformers-2.4.7}/x_transformers/up_wrapper.py +11 -6
  5. {x_transformers-2.4.6 → x_transformers-2.4.7}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.4.6 → x_transformers-2.4.7}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.4.6 → x_transformers-2.4.7}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.4.6 → x_transformers-2.4.7}/.gitignore +0 -0
  9. {x_transformers-2.4.6 → x_transformers-2.4.7}/LICENSE +0 -0
  10. {x_transformers-2.4.6 → x_transformers-2.4.7}/README.md +0 -0
  11. {x_transformers-2.4.6 → x_transformers-2.4.7}/data/README.md +0 -0
  12. {x_transformers-2.4.6 → x_transformers-2.4.7}/data/enwik8.gz +0 -0
  13. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/all-attention.png +0 -0
  14. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/deepnorm.png +0 -0
  17. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/fcm.png +0 -0
  23. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/ffglu.png +0 -0
  24. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/flash-attention.png +0 -0
  25. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/gate_values.png +0 -0
  26. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/gating.png +0 -0
  27. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/macaron-1.png +0 -0
  29. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/macaron-2.png +0 -0
  30. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/normformer.png +0 -0
  32. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/pia.png +0 -0
  33. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/resi_dual.png +0 -0
  35. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/residual_attn.png +0 -0
  36. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/rezero.png +0 -0
  37. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/rotary.png +0 -0
  38. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/sandwich.png +0 -0
  40. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/scalenorm.png +0 -0
  42. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/talking-heads.png +0 -0
  43. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/topk-attention.png +0 -0
  44. {x_transformers-2.4.6 → x_transformers-2.4.7}/images/xval.png +0 -0
  45. {x_transformers-2.4.6 → x_transformers-2.4.7}/train_belief_state.py +0 -0
  46. {x_transformers-2.4.6 → x_transformers-2.4.7}/train_copy.py +0 -0
  47. {x_transformers-2.4.6 → x_transformers-2.4.7}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.4.6 → x_transformers-2.4.7}/train_enwik8.py +0 -0
  49. {x_transformers-2.4.6 → x_transformers-2.4.7}/train_length_extrapolate.py +0 -0
  50. {x_transformers-2.4.6 → x_transformers-2.4.7}/train_parity.py +0 -0
  51. {x_transformers-2.4.6 → x_transformers-2.4.7}/x_transformers/__init__.py +0 -0
  52. {x_transformers-2.4.6 → x_transformers-2.4.7}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.4.6 → x_transformers-2.4.7}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.4.6 → x_transformers-2.4.7}/x_transformers/belief_state_wrapper.py +0 -0
  55. {x_transformers-2.4.6 → x_transformers-2.4.7}/x_transformers/continuous.py +0 -0
  56. {x_transformers-2.4.6 → x_transformers-2.4.7}/x_transformers/dpo.py +0 -0
  57. {x_transformers-2.4.6 → x_transformers-2.4.7}/x_transformers/entropy_based_tokenizer.py +0 -0
  58. {x_transformers-2.4.6 → x_transformers-2.4.7}/x_transformers/multi_input.py +0 -0
  59. {x_transformers-2.4.6 → x_transformers-2.4.7}/x_transformers/neo_mlp.py +0 -0
  60. {x_transformers-2.4.6 → x_transformers-2.4.7}/x_transformers/nonautoregressive_wrapper.py +0 -0
  61. {x_transformers-2.4.6 → x_transformers-2.4.7}/x_transformers/x_transformers.py +0 -0
  62. {x_transformers-2.4.6 → x_transformers-2.4.7}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  63. {x_transformers-2.4.6 → x_transformers-2.4.7}/x_transformers/xval.py +0 -0
{x_transformers-2.4.6 → x_transformers-2.4.7}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.4.6
+Version: 2.4.7
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
{x_transformers-2.4.6 → x_transformers-2.4.7}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.4.6"
+version = "2.4.7"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.4.6 → x_transformers-2.4.7}/tests/test_x_transformers.py

@@ -1100,7 +1100,10 @@ def add_attn_pool():
 
     assert intermediates.attn_pooled_tokens.shape[1] == 3
 
-def test_up():
+@pytest.mark.parametrize('keep_buffer_on_cpu', (False, True))
+def test_up(
+    keep_buffer_on_cpu
+):
     from x_transformers.up_wrapper import UniversalPretrainWrapper
 
     model = TransformerWrapper(
@@ -1115,7 +1118,11 @@ def test_up():
         ),
     )
 
-    up_wrapper = UniversalPretrainWrapper(model, seq_len = 16)
+    up_wrapper = UniversalPretrainWrapper(
+        model,
+        seq_len = 16,
+        keep_buffer_on_cpu = keep_buffer_on_cpu
+    )
 
     loss = up_wrapper()
     loss.backward()
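For reference, a minimal usage sketch of the new option outside the test suite (not taken from the diff; the model hyperparameters below are illustrative assumptions, only seq_len and keep_buffer_on_cpu mirror the parametrized test):

    import torch
    from x_transformers import TransformerWrapper, Decoder
    from x_transformers.up_wrapper import UniversalPretrainWrapper

    # small model purely for illustration; sizes are arbitrary
    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 16,
        attn_layers = Decoder(
            dim = 64,
            depth = 2,
            heads = 4
        )
    )

    # new in 2.4.7: optionally keep the synthetic data buffer on CPU
    up_wrapper = UniversalPretrainWrapper(
        model,
        seq_len = 16,
        keep_buffer_on_cpu = True
    )

    if torch.cuda.is_available():
        # parameters and registered buffers move; the plain-attribute data buffer stays on CPU
        up_wrapper = up_wrapper.cuda()

    loss = up_wrapper()
    loss.backward()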
{x_transformers-2.4.6 → x_transformers-2.4.7}/x_transformers/up_wrapper.py

@@ -146,7 +146,8 @@ class UniversalPretrainWrapper(Module):
         batch_size = 32,
         seq_len = 512,
         seed_length = 8,
-        reset_turing_machine_every = 0
+        reset_turing_machine_every = 0,
+        keep_buffer_on_cpu = False
     ):
         super().__init__()
 
@@ -185,12 +186,16 @@ class UniversalPretrainWrapper(Module):
 
         init_data_buffer = self.random_sequences_fn(buffer_size // 2, buffer_size // 2)
 
-        self.register_buffer('synth_data_buffer', init_data_buffer)
+        if keep_buffer_on_cpu:
+            self.synth_data_buffer = init_data_buffer
+        else:
+            self.register_buffer('synth_data_buffer', init_data_buffer)
+
         self.register_buffer('step', tensor(0))
 
     @property
     def device(self):
-        return self.synth_data_buffer.device
+        return self.step.device
 
     def get_rand_sequences_from_buffer(self, size = None):
         size = default(size, self.batch_size)
@@ -217,8 +222,8 @@ class UniversalPretrainWrapper(Module):
 
         generated = self.data_generator.generate(
            self.seq_len,
-            condition = conditions,
-            seed = seeds
+            condition = conditions.to(self.device),
+            seed = seeds.to(self.device)
         )
 
         self.step.add_(1)
@@ -244,6 +249,6 @@ class UniversalPretrainWrapper(Module):
 
         # sample yet again according to pseudocode
 
-        data = self.get_rand_sequences_from_buffer()
+        data = self.get_rand_sequences_from_buffer().to(self.device)
 
         return self.ar_wrapped(data)
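The change hinges on the difference between a registered buffer and a plain tensor attribute in PyTorch: a registered buffer is moved by .to() / .cuda() along with the parameters, while a plain attribute stays where it was created, which is also why the device property now reads from the small step buffer instead of the data buffer. A minimal sketch of that behavior, using a toy module rather than the library's actual class:

    import torch
    from torch import nn, tensor

    class Toy(nn.Module):
        def __init__(self, keep_buffer_on_cpu = False):
            super().__init__()
            data = torch.zeros(1024)

            if keep_buffer_on_cpu:
                # plain attribute: ignored by .to() / .cuda(), stays on CPU
                self.synth_data_buffer = data
            else:
                # registered buffer: moved together with the module's parameters
                self.register_buffer('synth_data_buffer', data)

            self.register_buffer('step', tensor(0))

        @property
        def device(self):
            # `step` always follows the module, so it is a reliable device indicator
            return self.step.device

    toy = Toy(keep_buffer_on_cpu = True)

    if torch.cuda.is_available():
        toy = toy.cuda()
        assert toy.synth_data_buffer.device.type == 'cpu'  # large buffer stays on CPU
        assert toy.device.type == 'cuda'                    # module device tracked via `step`

Keeping the buffer on CPU also explains the added .to(self.device) calls: sequences drawn from the CPU-resident buffer are moved to the module's device only when they are actually fed to the generator or the autoregressive wrapper.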