x-transformers 2.4.5__py3-none-any.whl → 2.4.7__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry; it is provided for informational purposes only.
x_transformers/up_wrapper.py

@@ -6,7 +6,7 @@ from functools import partial
 from random import randrange, uniform
 
 import torch
-from torch import nn, cat, randperm
+from torch import nn, cat, tensor, randperm
 from torch.nn import LSTM, Module
 
 from x_transformers.x_transformers import (
@@ -76,6 +76,13 @@ class SyntheticDataGenerator(Module):
 
         self.apply(self.init_)
 
+    def reset_(self):
+        for m in self.modules():
+            if hasattr(m, 'reset_parameters'):
+                m.reset_parameters()
+
+        self.apply(self.init_)
+
     @torch.no_grad()
     def init_(self, m):
         if isinstance(m, nn.Linear):
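
The new reset_ gives the synthetic data generator fresh weights in place: every submodule that implements reset_parameters is reinitialized, then the generator's own custom init_ is reapplied on top. A minimal standalone sketch of the same idiom; ToyNet and its init are hypothetical, only the reset pattern mirrors the diff:

import torch
from torch import nn

class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)
        self.rnn = nn.LSTM(16, 16)

    @torch.no_grad()
    def init_(self, m):
        # custom init layered on top of the default reset,
        # analogous to SyntheticDataGenerator.init_
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std = 0.02)

    def reset_(self):
        # every submodule that knows how to reset itself does so,
        # then the custom init is reapplied
        for m in self.modules():
            if hasattr(m, 'reset_parameters'):
                m.reset_parameters()

        self.apply(self.init_)

net = ToyNet()
net.reset_()  # fresh weights in place; any optimizer state is untouched
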
@@ -138,7 +145,9 @@ class UniversalPretrainWrapper(Module):
         num_reset = 20,
         batch_size = 32,
         seq_len = 512,
-        seed_length = 8
+        seed_length = 8,
+        reset_turing_machine_every = 0,
+        keep_buffer_on_cpu = False
     ):
         super().__init__()
 
@@ -157,6 +166,8 @@ class UniversalPretrainWrapper(Module):
             max_seq_len = seq_len
         )
 
+        self.reset_turing_machine_every = reset_turing_machine_every
+
         self.seq_len = seq_len
         self.data_generator = data_generator
 
@@ -175,11 +186,16 @@ class UniversalPretrainWrapper(Module):
 
         init_data_buffer = self.random_sequences_fn(buffer_size // 2, buffer_size // 2)
 
-        self.register_buffer('synth_data_buffer', init_data_buffer)
+        if keep_buffer_on_cpu:
+            self.synth_data_buffer = init_data_buffer
+        else:
+            self.register_buffer('synth_data_buffer', init_data_buffer)
+
+        self.register_buffer('step', tensor(0))
 
     @property
     def device(self):
-        return self.synth_data_buffer.device
+        return self.step.device
 
     def get_rand_sequences_from_buffer(self, size = None):
         size = default(size, self.batch_size)
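
Two things change here. With keep_buffer_on_cpu, the synthetic data buffer becomes a plain attribute, so module.to(device) and .cuda() leave it on CPU; and a scalar step buffer is registered so the device property still reports where the module actually lives. A minimal sketch of that device-tracking trick; BufferedModule is hypothetical, the idiom matches the diff:

import torch
from torch import nn, tensor

class BufferedModule(nn.Module):
    def __init__(self, keep_buffer_on_cpu = False):
        super().__init__()
        big = torch.randn(10_000, 512)

        if keep_buffer_on_cpu:
            self.data_buffer = big                    # plain attribute: ignored by .to() / .cuda()
        else:
            self.register_buffer('data_buffer', big)  # moves with the module

        self.register_buffer('step', tensor(0))       # tiny, always moves with the module

    @property
    def device(self):
        return self.step.device

m = BufferedModule(keep_buffer_on_cpu = True)
if torch.cuda.is_available():
    m = m.cuda()
    # the module reports cuda while the large buffer stays on cpu
    assert m.device.type == 'cuda' and m.data_buffer.device.type == 'cpu'
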
@@ -206,10 +222,17 @@ class UniversalPretrainWrapper(Module):
 
         generated = self.data_generator.generate(
            self.seq_len,
-            condition = conditions,
-            seed = seeds
+            condition = conditions.to(self.device),
+            seed = seeds.to(self.device)
         )
 
+        self.step.add_(1)
+
+        # maybe reset turing machine
+
+        if self.reset_turing_machine_every > 0 and divisible_by(self.step.item(), self.reset_turing_machine_every):
+            self.data_generator.reset_()
+
         # reset
 
         if self.num_reset > 0:
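
The forward pass now counts steps in the registered buffer and periodically reinitializes the generator. A small sketch of that schedule, assuming divisible_by is the usual (num % den) == 0 helper defined elsewhere in the package:

def divisible_by(num, den):
    return (num % den) == 0

reset_every = 100  # corresponds to reset_turing_machine_every

step = 0
for _ in range(500):
    # ... generate one batch of synthetic data ...
    step += 1

    if reset_every > 0 and divisible_by(step, reset_every):
        # wipe and reinitialize the generator, as reset_() does above
        print(f'reset at step {step}')
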
@@ -226,6 +249,6 @@ class UniversalPretrainWrapper(Module):
 
         # sample yet again according to pseudocode
 
-        data = self.get_rand_sequences_from_buffer()
+        data = self.get_rand_sequences_from_buffer().to(self.device)
 
         return self.ar_wrapped(data)
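
Putting the new knobs together, a hedged usage sketch. Only the constructor kwargs visible in the hunks above (seed_length, reset_turing_machine_every, keep_buffer_on_cpu, batch_size, seq_len) are confirmed by this diff; the transformer argument and the bare-call training step are assumptions inferred from the ar_wrapped line:

from x_transformers import TransformerWrapper, Decoder
from x_transformers.up_wrapper import UniversalPretrainWrapper

# standard x-transformers model construction; illustrative only
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

wrapper = UniversalPretrainWrapper(
    model,                              # assumed first positional argument
    batch_size = 32,
    seq_len = 512,
    seed_length = 8,
    reset_turing_machine_every = 100,   # reinit the synthetic generator every 100 generate calls
    keep_buffer_on_cpu = True           # large synthetic buffer stays off the accelerator
)

loss = wrapper()   # assumed: forward samples from the buffer and returns the AR loss
loss.backward()
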
x_transformers-2.4.7.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.4.5
+Version: 2.4.7
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
x_transformers-2.4.7.dist-info/RECORD

@@ -8,11 +8,11 @@ x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaY
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/up_wrapper.py,sha256=RW-0EmheBqJz7g4j-M37XxWMttz-QuPhmnRWMrD7_-0,6103
+x_transformers/up_wrapper.py,sha256=UDaLCQSeoFtC2b9ZPDbof4pzpbEBPO-FjQuLiA2kjP0,6839
 x_transformers/x_transformers.py,sha256=9Fi0HvzpeIJqM6HlAd2M6JqsfjhTN1zEH9iFIimyjS4,117608
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.4.5.dist-info/METADATA,sha256=J7sgvMp5pLgUWZ37qdDF4EBmECBqfgEZ0rSC7w12AAI,90223
-x_transformers-2.4.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.4.5.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.4.5.dist-info/RECORD,,
+x_transformers-2.4.7.dist-info/METADATA,sha256=W-Er0p8u3zxuJmYnA8qA6u1GDUqB8p2yTlUS6Xhqsb0,90223
+x_transformers-2.4.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.4.7.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.4.7.dist-info/RECORD,,