x-transformers 2.4.6__py3-none-any.whl → 2.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/up_wrapper.py +11 -6
- {x_transformers-2.4.6.dist-info → x_transformers-2.4.7.dist-info}/METADATA +1 -1
- {x_transformers-2.4.6.dist-info → x_transformers-2.4.7.dist-info}/RECORD +5 -5
- {x_transformers-2.4.6.dist-info → x_transformers-2.4.7.dist-info}/WHEEL +0 -0
- {x_transformers-2.4.6.dist-info → x_transformers-2.4.7.dist-info}/licenses/LICENSE +0 -0
x_transformers/up_wrapper.py
CHANGED

@@ -146,7 +146,8 @@ class UniversalPretrainWrapper(Module):
         batch_size = 32,
         seq_len = 512,
         seed_length = 8,
-        reset_turing_machine_every = 0
+        reset_turing_machine_every = 0,
+        keep_buffer_on_cpu = False
     ):
         super().__init__()
 
@@ -185,12 +186,16 @@ class UniversalPretrainWrapper(Module):
 
         init_data_buffer = self.random_sequences_fn(buffer_size // 2, buffer_size // 2)
 
-        self.register_buffer('synth_data_buffer', init_data_buffer)
+        if keep_buffer_on_cpu:
+            self.synth_data_buffer = init_data_buffer
+        else:
+            self.register_buffer('synth_data_buffer', init_data_buffer)
+
         self.register_buffer('step', tensor(0))
 
     @property
     def device(self):
-        return self.synth_data_buffer.device
+        return self.step.device
 
     def get_rand_sequences_from_buffer(self, size = None):
         size = default(size, self.batch_size)
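The two hunks above are the heart of the release: a new keep_buffer_on_cpu flag decides whether the synthetic data buffer is registered as a module buffer or kept as a plain attribute, and the device property now reads off the always-registered step buffer instead. A self-contained sketch of the PyTorch behavior this relies on (a hypothetical class distilling the pattern, not the wrapper's actual code):

import torch
from torch import nn, tensor

class BufferPlacementSketch(nn.Module):
    # hypothetical class illustrating the pattern in the diff above
    def __init__(self, buffer_size = 1024, seq_len = 512, keep_buffer_on_cpu = False):
        super().__init__()
        init_data_buffer = torch.randint(0, 256, (buffer_size, seq_len))

        if keep_buffer_on_cpu:
            # plain attribute: .to() / .cuda() leave it alone, so it stays on CPU
            # and is also excluded from state_dict()
            self.synth_data_buffer = init_data_buffer
        else:
            # registered buffer: moved together with the module's parameters
            self.register_buffer('synth_data_buffer', init_data_buffer)

        self.register_buffer('step', tensor(0))

    @property
    def device(self):
        # `step` is always registered, so it reliably tracks the module's
        # device even when `synth_data_buffer` deliberately stays on the CPU
        return self.step.device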
@@ -217,8 +222,8 @@ class UniversalPretrainWrapper(Module):
 
         generated = self.data_generator.generate(
             self.seq_len,
-            condition = conditions,
-            seed = seeds
+            condition = conditions.to(self.device),
+            seed = seeds.to(self.device)
         )
 
         self.step.add_(1)
|
@@ -244,6 +249,6 @@ class UniversalPretrainWrapper(Module):
 
         # sample yet again according to pseudocode
 
-        data = self.get_rand_sequences_from_buffer()
+        data = self.get_rand_sequences_from_buffer().to(self.device)
 
         return self.ar_wrapped(data)
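The remaining two hunks are the consequence of that choice: anything sampled from a possibly-CPU-resident buffer must be moved to the module's device before reaching the model. A minimal, self-contained illustration of the failure mode being avoided (pure PyTorch; the layer and shapes are illustrative assumptions):

import torch
from torch import nn

if torch.cuda.is_available():
    layer = nn.Embedding(256, 16).cuda()           # model lives on the GPU
    cpu_buffer = torch.randint(0, 256, (32, 512))  # data buffer kept on CPU

    device = next(layer.parameters()).device
    # layer(cpu_buffer) would raise a device-mismatch RuntimeError;
    # the diff instead moves each sample explicitly, as below
    out = layer(cpu_buffer.to(device))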
{x_transformers-2.4.6.dist-info → x_transformers-2.4.7.dist-info}/RECORD
CHANGED

@@ -8,11 +8,11 @@ x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaY…
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/up_wrapper.py,sha256=…
+x_transformers/up_wrapper.py,sha256=UDaLCQSeoFtC2b9ZPDbof4pzpbEBPO-FjQuLiA2kjP0,6839
 x_transformers/x_transformers.py,sha256=9Fi0HvzpeIJqM6HlAd2M6JqsfjhTN1zEH9iFIimyjS4,117608
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.4.6.dist-info/METADATA,sha256=…
-x_transformers-2.4.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.4.6.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.4.6.dist-info/RECORD,,
+x_transformers-2.4.7.dist-info/METADATA,sha256=W-Er0p8u3zxuJmYnA8qA6u1GDUqB8p2yTlUS6Xhqsb0,90223
+x_transformers-2.4.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.4.7.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.4.7.dist-info/RECORD,,
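For reference, each RECORD entry has the form path,sha256=digest,size, where the digest is an unpadded urlsafe-base64 SHA-256 of the file's bytes, per the wheel specification. A small sketch of how such a line can be recomputed for verification:

import base64, hashlib
from pathlib import Path

def record_line(path):
    # rebuild a RECORD entry: path, unpadded urlsafe-b64 sha256, size in bytes
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b'=').decode()
    return f'{path},sha256={digest},{len(data)}'

# e.g. record_line('x_transformers/up_wrapper.py') run against an unpacked
# 2.4.7 wheel should reproduce the new up_wrapper.py entry above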
{x_transformers-2.4.6.dist-info → x_transformers-2.4.7.dist-info}/WHEEL
File without changes

{x_transformers-2.4.6.dist-info → x_transformers-2.4.7.dist-info}/licenses/LICENSE
File without changes