x-transformers 2.4.5__py3-none-any.whl → 2.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/up_wrapper.py +30 -7
- {x_transformers-2.4.5.dist-info → x_transformers-2.4.7.dist-info}/METADATA +1 -1
- {x_transformers-2.4.5.dist-info → x_transformers-2.4.7.dist-info}/RECORD +5 -5
- {x_transformers-2.4.5.dist-info → x_transformers-2.4.7.dist-info}/WHEEL +0 -0
- {x_transformers-2.4.5.dist-info → x_transformers-2.4.7.dist-info}/licenses/LICENSE +0 -0
x_transformers/up_wrapper.py
CHANGED
@@ -6,7 +6,7 @@ from functools import partial
 from random import randrange, uniform
 
 import torch
-from torch import nn, cat, randperm
+from torch import nn, cat, tensor, randperm
 from torch.nn import LSTM, Module
 
 from x_transformers.x_transformers import (
@@ -76,6 +76,13 @@ class SyntheticDataGenerator(Module):
 
         self.apply(self.init_)
 
+    def reset_(self):
+        for m in self.modules():
+            if hasattr(m, 'reset_parameters'):
+                m.reset_parameters()
+
+        self.apply(self.init_)
+
     @torch.no_grad()
     def init_(self, m):
         if isinstance(m, nn.Linear):
@@ -138,7 +145,9 @@ class UniversalPretrainWrapper(Module):
         num_reset = 20,
         batch_size = 32,
         seq_len = 512,
-        seed_length = 8
+        seed_length = 8,
+        reset_turing_machine_every = 0,
+        keep_buffer_on_cpu = False
     ):
         super().__init__()
 
@@ -157,6 +166,8 @@ class UniversalPretrainWrapper(Module):
             max_seq_len = seq_len
         )
 
+        self.reset_turing_machine_every = reset_turing_machine_every
+
         self.seq_len = seq_len
         self.data_generator = data_generator
 
@@ -175,11 +186,16 @@ class UniversalPretrainWrapper(Module):
 
         init_data_buffer = self.random_sequences_fn(buffer_size // 2, buffer_size // 2)
 
-
+        if keep_buffer_on_cpu:
+            self.synth_data_buffer = init_data_buffer
+        else:
+            self.register_buffer('synth_data_buffer', init_data_buffer)
+
+        self.register_buffer('step', tensor(0))
 
     @property
     def device(self):
-        return self.
+        return self.step.device
 
     def get_rand_sequences_from_buffer(self, size = None):
         size = default(size, self.batch_size)
@@ -206,10 +222,17 @@ class UniversalPretrainWrapper(Module):
 
         generated = self.data_generator.generate(
             self.seq_len,
-            condition = conditions,
-            seed = seeds
+            condition = conditions.to(self.device),
+            seed = seeds.to(self.device)
         )
 
+        self.step.add_(1)
+
+        # maybe reset turing machine
+
+        if self.reset_turing_machine_every > 0 and divisible_by(self.step.item(), self.reset_turing_machine_every):
+            self.data_generator.reset_()
+
         # reset
 
         if self.num_reset > 0:
@@ -226,6 +249,6 @@ class UniversalPretrainWrapper(Module):
 
         # sample yet again according to pseudocode
 
-        data = self.get_rand_sequences_from_buffer()
+        data = self.get_rand_sequences_from_buffer().to(self.device)
 
         return self.ar_wrapped(data)
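In summary, the up_wrapper.py changes in 2.4.7 give SyntheticDataGenerator a reset_() method, track a persistent step counter, and add two keyword arguments to UniversalPretrainWrapper: reset_turing_machine_every (re-initializes the synthetic data generator every N steps, 0 disables) and keep_buffer_on_cpu (stores the synthetic data buffer as a plain CPU tensor instead of a registered module buffer, moving batches to the model device via .to(self.device)). The following is a minimal usage sketch based only on what this diff shows; the model construction, the dimension values, the 1000-step interval, and the assumptions that the wrapper takes a TransformerWrapper as its first argument and that its forward call returns an autoregressive loss are illustrative, not confirmed by the diff.

from x_transformers import TransformerWrapper, Decoder
from x_transformers.up_wrapper import UniversalPretrainWrapper

# hypothetical toy model, sizes chosen only for illustration
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

# assumption: the wrapper accepts the model as its first positional argument
wrapper = UniversalPretrainWrapper(
    model,
    seq_len = 512,
    batch_size = 32,
    reset_turing_machine_every = 1000,  # new in 2.4.7: reset the synthetic data generator every 1000 steps (0 = never)
    keep_buffer_on_cpu = True           # new in 2.4.7: keep the synthetic data buffer on CPU; batches are moved to the model device per step
)

# assumption: one forward call samples synthetic data and returns the autoregressive loss
loss = wrapper()
loss.backward()

keep_buffer_on_cpu trades a per-step host-to-device copy for lower GPU memory use, which is why the diff adds the .to(self.device) calls on the conditions, seeds, and sampled data.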
{x_transformers-2.4.5.dist-info → x_transformers-2.4.7.dist-info}/RECORD
CHANGED

@@ -8,11 +8,11 @@ x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaY
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/up_wrapper.py,sha256=
+x_transformers/up_wrapper.py,sha256=UDaLCQSeoFtC2b9ZPDbof4pzpbEBPO-FjQuLiA2kjP0,6839
 x_transformers/x_transformers.py,sha256=9Fi0HvzpeIJqM6HlAd2M6JqsfjhTN1zEH9iFIimyjS4,117608
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.4.
-x_transformers-2.4.
-x_transformers-2.4.
-x_transformers-2.4.
+x_transformers-2.4.7.dist-info/METADATA,sha256=W-Er0p8u3zxuJmYnA8qA6u1GDUqB8p2yTlUS6Xhqsb0,90223
+x_transformers-2.4.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.4.7.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.4.7.dist-info/RECORD,,
{x_transformers-2.4.5.dist-info → x_transformers-2.4.7.dist-info}/WHEEL
File without changes

{x_transformers-2.4.5.dist-info → x_transformers-2.4.7.dist-info}/licenses/LICENSE
File without changes