x-transformers 2.11.15__py3-none-any.whl → 2.11.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of x-transformers might be problematic.
- x_transformers/free_transformer.py +23 -4
- {x_transformers-2.11.15.dist-info → x_transformers-2.11.16.dist-info}/METADATA +1 -1
- {x_transformers-2.11.15.dist-info → x_transformers-2.11.16.dist-info}/RECORD +5 -5
- {x_transformers-2.11.15.dist-info → x_transformers-2.11.16.dist-info}/WHEEL +0 -0
- {x_transformers-2.11.15.dist-info → x_transformers-2.11.16.dist-info}/licenses/LICENSE +0 -0
x_transformers/free_transformer.py

@@ -225,8 +225,11 @@ class FreeTransformer(Module):
         self,
         decoder_head_embeds,
         mask = None,
-        return_kl_loss = False
+        return_kl_loss = False,
+        per_token_latents = None
     ):
+        per_token_latents = default(per_token_latents, self.per_token_latents)
+
         batch, seq_len, device = *decoder_head_embeds.shape[:2], decoder_head_embeds.device

         query_tokens = repeat(self.query_token_for_latents, 'd -> b 1 d', b = batch)
@@ -235,7 +238,7 @@ class FreeTransformer(Module):

         # handle the interesting per query token latents, as in the paper

-        if self.per_token_latents:
+        if per_token_latents:
             query_tokens = repeat(query_tokens, 'b 1 d -> b n d', n = seq_len)

             rotary_pos = torch.arange(seq_len, device = device)
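The signature change above lets callers override the per-token-latent behavior per call: passing per_token_latents = None (the default) falls back to the module-level self.per_token_latents, while an explicit True or False takes precedence. Below is a minimal self-contained sketch of that None-fallback pattern only; the Encoder class is an illustration, not the library's, and the default helper is redefined locally so the snippet stands alone.

def default(val, d):
    # mirror of the library's helper: use `val` unless it is None
    return val if val is not None else d

class Encoder:
    def __init__(self, per_token_latents = True):
        self.per_token_latents = per_token_latents

    def encode_to_latents(self, embeds, per_token_latents = None):
        # None -> fall back to the configured module setting; True / False -> per-call override
        per_token_latents = default(per_token_latents, self.per_token_latents)
        return per_token_latents

enc = Encoder(per_token_latents = True)
assert enc.encode_to_latents('embeds') is True                               # falls back to module setting
assert enc.encode_to_latents('embeds', per_token_latents = False) is False   # explicit per-call override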
@@ -342,13 +345,13 @@ class FreeTransformer(Module):
     def forward(
         self,
         seq,
+        seq_for_latents = None,
         return_all_losses = False
     ):
         batch, device = seq.shape[0], seq.device

         seq, labels = seq[:, :-1], seq[:, 1:]

-        encoder_mask = seq != self.pad_id

         tokens = self.token_emb(seq)

@@ -357,9 +360,25 @@ class FreeTransformer(Module):
         if exists(self.decoder_head):
             tokens = self.decoder_head(tokens)

+        # determine whether to use a separate sequence for encoding latents
+
+        if exists(seq_for_latents):
+            tokens_for_latents = self.token_emb(seq_for_latents)
+
+            if exists(self.decoder_head):
+                tokens_for_latents = self.decoder_head(tokens_for_latents)
+
+            encoder_mask = seq_for_latents != self.pad_id
+            per_token_latents = False
+        else:
+
+            tokens_for_latents = tokens
+            encoder_mask = seq != self.pad_id
+            per_token_latents = None
+
         # get latent Z

-        latents, kl_loss = self.encode_to_latents(
+        latents, kl_loss = self.encode_to_latents(tokens_for_latents, mask = encoder_mask, per_token_latents = per_token_latents, return_kl_loss = True)

         condition = self.from_latent_to_condition(latents)

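In forward, the new seq_for_latents keyword lets the latent Z be encoded from a different token sequence than the one being decoded; when it is given, the encoder padding mask is derived from that sequence and per-token latents are disabled for that encoding call, otherwise behavior is unchanged. A hedged usage sketch follows: the train_step wrapper and the assumption that the default forward return is a scalar training loss (suggested by the return_all_losses flag) are illustrations, not part of this diff; only the seq and seq_for_latents arguments are shown by the change above.

import torch

def train_step(model, optimizer, seq, seq_for_latents = None):
    # `model` is assumed to be an already-constructed FreeTransformer from
    # x_transformers.free_transformer; constructing one is out of scope here
    loss = model(seq, seq_for_latents = seq_for_latents)  # new keyword in 2.11.16; omit it to keep the old behavior
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss

# hypothetical token-id batches: decode / predict on `seq`, derive the latent Z from `context`
seq = torch.randint(0, 256, (2, 1024))
context = torch.randint(0, 256, (2, 512))
# loss = train_step(model, optimizer, seq, seq_for_latents = context)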
{x_transformers-2.11.15.dist-info → x_transformers-2.11.16.dist-info}/RECORD

@@ -5,7 +5,7 @@ x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTN
 x_transformers/continuous.py,sha256=WwpQCjyVY4PtuEAOFY68zqgklbF9I7AL5w6874YlDe8,13249
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
-x_transformers/free_transformer.py,sha256=
+x_transformers/free_transformer.py,sha256=_hYYkaro3xei3MC3rwtuCWi9gSnciXyAT91_7SrA0nw,11396
 x_transformers/gpt_vae.py,sha256=1zyjwgfZr6CRDsh5VMCPSdoCPg-sdX5mXmZ_mn4VyYQ,6082
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
@@ -14,7 +14,7 @@ x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,
 x_transformers/x_transformers.py,sha256=5ctPu8tvlbUMrtW360e_LPnoGv6xcgQFsyWdbvLo6Tk,127002
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.11.
-x_transformers-2.11.
-x_transformers-2.11.
-x_transformers-2.11.
+x_transformers-2.11.16.dist-info/METADATA,sha256=cvhm5LnIRCdqLuv25iSU4vj0a6Np9j2lv2O9W-V48-k,96012
+x_transformers-2.11.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.11.16.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.11.16.dist-info/RECORD,,