x-transformers 2.4.4__py3-none-any.whl → 2.4.6__py3-none-any.whl
This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- x_transformers/up_wrapper.py +21 -3
- {x_transformers-2.4.4.dist-info → x_transformers-2.4.6.dist-info}/METADATA +1 -1
- {x_transformers-2.4.4.dist-info → x_transformers-2.4.6.dist-info}/RECORD +5 -5
- {x_transformers-2.4.4.dist-info → x_transformers-2.4.6.dist-info}/WHEEL +0 -0
- {x_transformers-2.4.4.dist-info → x_transformers-2.4.6.dist-info}/licenses/LICENSE +0 -0
x_transformers/up_wrapper.py
CHANGED
@@ -6,7 +6,7 @@ from functools import partial
 from random import randrange, uniform
 
 import torch
-from torch import nn, cat, randperm
+from torch import nn, cat, tensor, randperm
 from torch.nn import LSTM, Module
 
 from x_transformers.x_transformers import (
@@ -76,6 +76,13 @@ class SyntheticDataGenerator(Module):
 
         self.apply(self.init_)
 
+    def reset_(self):
+        for m in self.modules():
+            if hasattr(m, 'reset_parameters'):
+                m.reset_parameters()
+
+        self.apply(self.init_)
+
     @torch.no_grad()
     def init_(self, m):
         if isinstance(m, nn.Linear):
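The new reset_ method relies on the standard PyTorch convention that built-in layers such as nn.Linear and nn.LSTM expose a reset_parameters() method. A minimal sketch of that idiom outside this package (the toy module below is illustrative only):

from torch import nn

net = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

# every submodule that defines reset_parameters() is freshly reinitialized
# in place; modules without it (like nn.ReLU) are skipped, which is the same
# walk SyntheticDataGenerator.reset_() performs before re-applying init_
for m in net.modules():
    if hasattr(m, 'reset_parameters'):
        m.reset_parameters()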
@@ -133,12 +140,13 @@ class UniversalPretrainWrapper(Module):
     def __init__(
         self,
         model: TransformerWrapper,
-        data_generator: SyntheticDataGenerator | None = None,
+        data_generator: SyntheticDataGenerator | Module | None = None,
         buffer_size = None,
         num_reset = 20,
         batch_size = 32,
         seq_len = 512,
-        seed_length = 8
+        seed_length = 8,
+        reset_turing_machine_every = 0
     ):
         super().__init__()
 
@@ -157,6 +165,8 @@ class UniversalPretrainWrapper(Module):
             max_seq_len = seq_len
         )
 
+        self.reset_turing_machine_every = reset_turing_machine_every
+
         self.seq_len = seq_len
         self.data_generator = data_generator
 
@@ -176,6 +186,7 @@ class UniversalPretrainWrapper(Module):
         init_data_buffer = self.random_sequences_fn(buffer_size // 2, buffer_size // 2)
 
         self.register_buffer('synth_data_buffer', init_data_buffer)
+        self.register_buffer('step', tensor(0))
 
     @property
     def device(self):
@@ -210,6 +221,13 @@ class UniversalPretrainWrapper(Module):
             seed = seeds
         )
 
+        self.step.add_(1)
+
+        # maybe reset turing machine
+
+        if self.reset_turing_machine_every > 0 and divisible_by(self.step.item(), self.reset_turing_machine_every):
+            self.data_generator.reset_()
+
         # reset
 
         if self.num_reset > 0:
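Net effect of the changes in this file: a persistent step counter, registered as a buffer so it is saved and restored with the module's state_dict, counts forward passes, and a positive reset_turing_machine_every reinitializes the synthetic data generator (the "turing machine") every that-many steps via reset_(); divisible_by is the package's usual (num % den) == 0 helper. A hedged usage sketch follows; the model dimensions, the interval of 1000, and the assumption that the wrapper's forward() takes no arguments and returns a scalar pretraining loss are illustrative, not taken from this diff:

from x_transformers import TransformerWrapper, Decoder
from x_transformers.up_wrapper import UniversalPretrainWrapper

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

# new in 2.4.6: 0 (the default) keeps the previous behavior, a positive
# value reinitializes the synthetic data generator every that-many steps
wrapper = UniversalPretrainWrapper(
    model,
    seq_len = 512,
    batch_size = 32,
    reset_turing_machine_every = 1000   # hypothetical interval
)

loss = wrapper()   # each call advances the internal `step` buffer
loss.backward()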
{x_transformers-2.4.4.dist-info → x_transformers-2.4.6.dist-info}/RECORD
CHANGED

@@ -8,11 +8,11 @@ x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaY
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/up_wrapper.py,sha256=
+x_transformers/up_wrapper.py,sha256=mUTbuTPrN_Z1j2qZ0NGX96djioANNGVdfCdIlg8nCsM,6664
 x_transformers/x_transformers.py,sha256=9Fi0HvzpeIJqM6HlAd2M6JqsfjhTN1zEH9iFIimyjS4,117608
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.4.
-x_transformers-2.4.
-x_transformers-2.4.
-x_transformers-2.4.
+x_transformers-2.4.6.dist-info/METADATA,sha256=RlLYhhqc-kLKG4PRDHqbPeVmac4oVQjnqAANxq_nXrg,90223
+x_transformers-2.4.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.4.6.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.4.6.dist-info/RECORD,,
{x_transformers-2.4.4.dist-info → x_transformers-2.4.6.dist-info}/WHEEL
File without changes

{x_transformers-2.4.4.dist-info → x_transformers-2.4.6.dist-info}/licenses/LICENSE
File without changes