x-transformers 2.4.2__py3-none-any.whl → 2.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/up_wrapper.py +9 -2
- {x_transformers-2.4.2.dist-info → x_transformers-2.4.3.dist-info}/METADATA +1 -1
- {x_transformers-2.4.2.dist-info → x_transformers-2.4.3.dist-info}/RECORD +5 -5
- {x_transformers-2.4.2.dist-info → x_transformers-2.4.3.dist-info}/WHEEL +0 -0
- {x_transformers-2.4.2.dist-info → x_transformers-2.4.3.dist-info}/licenses/LICENSE +0 -0
x_transformers/up_wrapper.py
CHANGED
@@ -153,7 +153,8 @@ class UniversalPretrainWrapper(Module):
|
|
153
153
|
if not exists(data_generator):
|
154
154
|
data_generator = SyntheticDataGenerator(
|
155
155
|
num_tokens = num_tokens,
|
156
|
-
dim = dim
|
156
|
+
dim = dim,
|
157
|
+
max_seq_len = seq_len
|
157
158
|
)
|
158
159
|
|
159
160
|
self.seq_len = seq_len
|
@@ -203,7 +204,7 @@ class UniversalPretrainWrapper(Module):
|
|
203
204
|
|
204
205
|
# seed, condition to turing machine
|
205
206
|
|
206
|
-
|
207
|
+
generated = self.data_generator.generate(
|
207
208
|
self.seq_len,
|
208
209
|
condition = conditions,
|
209
210
|
seed = seeds
|
@@ -218,6 +219,12 @@ class UniversalPretrainWrapper(Module):
|
|
218
219
|
reset_sequences = self.random_sequences_fn(self.num_reset // 2, self.num_reset // 2, device = self.device)
|
219
220
|
buffer_to_reset.copy_(reset_sequences)
|
220
221
|
|
222
|
+
# place "enriched" random generated sequences back
|
223
|
+
|
224
|
+
with torch.no_grad():
|
225
|
+
print(conditions.shape, generated.shape)
|
226
|
+
conditions.copy_(generated)
|
227
|
+
|
221
228
|
# sample yet again according to pseudocode
|
222
229
|
|
223
230
|
data = self.get_rand_sequences_from_buffer()
|
@@ -8,11 +8,11 @@ x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaY
|
|
8
8
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
9
9
|
x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
|
10
10
|
x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
|
11
|
-
x_transformers/up_wrapper.py,sha256=
|
11
|
+
x_transformers/up_wrapper.py,sha256=qMUi7Ahoz9elf2SYLERMSYygI3rznJxVd9nR4jJM4ZA,6147
|
12
12
|
x_transformers/x_transformers.py,sha256=9Fi0HvzpeIJqM6HlAd2M6JqsfjhTN1zEH9iFIimyjS4,117608
|
13
13
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
14
14
|
x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
|
15
|
-
x_transformers-2.4.
|
16
|
-
x_transformers-2.4.
|
17
|
-
x_transformers-2.4.
|
18
|
-
x_transformers-2.4.
|
15
|
+
x_transformers-2.4.3.dist-info/METADATA,sha256=XbY4s1Cu9E6lbWa5xyTc0drX4Vf5r3tOCYhnN9htNS4,90223
|
16
|
+
x_transformers-2.4.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
17
|
+
x_transformers-2.4.3.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
18
|
+
x_transformers-2.4.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|