x-transformers 2.4.6__py3-none-any.whl → 2.4.8__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
@@ -50,7 +50,7 @@ def random_sequences(
50
50
 
51
51
  # shuffle with randperm
52
52
 
53
- rand_indices = randperm(all_seq.shape[0])
53
+ rand_indices = randperm(all_seq.shape[0], device = all_seq.device)
54
54
  return all_seq[rand_indices]
55
55
 
56
56
  # synthetic data generator
@@ -146,7 +146,8 @@ class UniversalPretrainWrapper(Module):
146
146
  batch_size = 32,
147
147
  seq_len = 512,
148
148
  seed_length = 8,
149
- reset_turing_machine_every = 0
149
+ reset_turing_machine_every = 0,
150
+ keep_buffer_on_cpu = False
150
151
  ):
151
152
  super().__init__()
152
153
 
@@ -185,12 +186,16 @@ class UniversalPretrainWrapper(Module):
185
186
 
186
187
  init_data_buffer = self.random_sequences_fn(buffer_size // 2, buffer_size // 2)
187
188
 
188
- self.register_buffer('synth_data_buffer', init_data_buffer)
189
+ if keep_buffer_on_cpu:
190
+ self.synth_data_buffer = init_data_buffer
191
+ else:
192
+ self.register_buffer('synth_data_buffer', init_data_buffer)
193
+
189
194
  self.register_buffer('step', tensor(0))
190
195
 
191
196
  @property
192
197
  def device(self):
193
- return self.synth_data_buffer.device
198
+ return self.step.device
194
199
 
195
200
  def get_rand_sequences_from_buffer(self, size = None):
196
201
  size = default(size, self.batch_size)
@@ -217,8 +222,8 @@ class UniversalPretrainWrapper(Module):
217
222
 
218
223
  generated = self.data_generator.generate(
219
224
  self.seq_len,
220
- condition = conditions,
221
- seed = seeds
225
+ condition = conditions.to(self.device),
226
+ seed = seeds.to(self.device)
222
227
  )
223
228
 
224
229
  self.step.add_(1)
@@ -244,6 +249,6 @@ class UniversalPretrainWrapper(Module):
244
249
 
245
250
  # sample yet again according to pseudocode
246
251
 
247
- data = self.get_rand_sequences_from_buffer()
252
+ data = self.get_rand_sequences_from_buffer().to(self.device)
248
253
 
249
254
  return self.ar_wrapped(data)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.4.6
3
+ Version: 2.4.8
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -8,11 +8,11 @@ x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaY
8
8
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
9
9
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
10
10
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
11
- x_transformers/up_wrapper.py,sha256=mUTbuTPrN_Z1j2qZ0NGX96djioANNGVdfCdIlg8nCsM,6664
11
+ x_transformers/up_wrapper.py,sha256=J4cIqqEDpm9pkG-QsdlsNzUHbovOxB1iR1zYGqnDxAM,6864
12
12
  x_transformers/x_transformers.py,sha256=9Fi0HvzpeIJqM6HlAd2M6JqsfjhTN1zEH9iFIimyjS4,117608
13
13
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
14
14
  x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
15
- x_transformers-2.4.6.dist-info/METADATA,sha256=RlLYhhqc-kLKG4PRDHqbPeVmac4oVQjnqAANxq_nXrg,90223
16
- x_transformers-2.4.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
- x_transformers-2.4.6.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
18
- x_transformers-2.4.6.dist-info/RECORD,,
15
+ x_transformers-2.4.8.dist-info/METADATA,sha256=tqhX2F0TzwpBfPtla3n2_6SATVJGtQKfoBbfHk9Fvzw,90223
16
+ x_transformers-2.4.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
+ x_transformers-2.4.8.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
18
+ x_transformers-2.4.8.dist-info/RECORD,,