x-transformers 2.11.8__py3-none-any.whl → 2.11.10__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
- x_transformers/free_transformer.py +31 -6
- {x_transformers-2.11.8.dist-info → x_transformers-2.11.10.dist-info}/METADATA +1 -1
- {x_transformers-2.11.8.dist-info → x_transformers-2.11.10.dist-info}/RECORD +5 -5
- {x_transformers-2.11.8.dist-info → x_transformers-2.11.10.dist-info}/WHEEL +0 -0
- {x_transformers-2.11.8.dist-info → x_transformers-2.11.10.dist-info}/licenses/LICENSE +0 -0
x_transformers/free_transformer.py

@@ -197,7 +197,9 @@ class FreeTransformer(Module):
             pre_norm_has_final_norm = False,
             **kwargs,
             **dec_kwargs
-        )
+        ) if dec_head_depth > 0 else None
+
+        assert dec_tail_depth > 0
 
         self.decoder_tail = Decoder(
             dim = dim,
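
In this hunk the head decoder is only constructed when dec_head_depth is greater than zero; otherwise self.decoder_head is left as None, and a non-empty tail is asserted. A minimal sketch of that optional-submodule pattern, using an invented TwoStageDecoder with plain nn.Linear stacks as stand-ins; only the conditional construction, the assert, and the exists() guard are taken from the diff:

import torch
import torch.nn as nn

def exists(v):
    return v is not None

class TwoStageDecoder(nn.Module):
    # hypothetical stand-in for FreeTransformer's head / tail setup
    def __init__(self, dim, dec_head_depth = 0, dec_tail_depth = 2):
        super().__init__()

        # the head is optional - it is None whenever dec_head_depth == 0
        self.decoder_head = nn.Sequential(*[nn.Linear(dim, dim) for _ in range(dec_head_depth)]) if dec_head_depth > 0 else None

        # the tail must always exist
        assert dec_tail_depth > 0
        self.decoder_tail = nn.Sequential(*[nn.Linear(dim, dim) for _ in range(dec_tail_depth)])

    def forward(self, tokens):
        if exists(self.decoder_head):
            tokens = self.decoder_head(tokens)

        return self.decoder_tail(tokens)

# a headless variant now constructs and runs cleanly
headless = TwoStageDecoder(dim = 16, dec_head_depth = 0)
out = headless(torch.randn(2, 8, 16))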
@@ -266,7 +268,8 @@ class FreeTransformer(Module):
         seq_len,
         latents = None,
         filter_logits_fn = top_p,
-        logit_filter_kwargs: dict = dict(thres = 0.9)
+        logit_filter_kwargs: dict = dict(thres = 0.9),
+        use_kv_cache = True
     ):
         prompts, inverse_pack = pack_with_inverse(prompts, '* n')
 
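
The generate signature keeps nucleus filtering as its default (filter_logits_fn = top_p with thres = 0.9) and gains a use_kv_cache flag that defaults to True. For readers unfamiliar with what that filter does to a step's logits, here is a small self-contained re-implementation of nucleus (top-p) filtering with the same thres = 0.9 default; it is written from scratch for illustration and is not the library's own top_p:

import torch

def nucleus_filter(logits, thres = 0.9):
    # keep the smallest set of tokens whose cumulative probability reaches thres,
    # mask everything else to -inf so sampling can never pick it
    sorted_logits, sorted_indices = logits.sort(dim = -1, descending = True)
    cum_probs = sorted_logits.softmax(dim = -1).cumsum(dim = -1)

    remove = cum_probs > thres
    remove[..., 1:] = remove[..., :-1].clone()   # always keep the top token
    remove[..., 0] = False

    sorted_logits[remove] = float('-inf')
    return sorted_logits.scatter(-1, sorted_indices, sorted_logits)

logits = torch.tensor([[3.0, 2.5, 1.0, -1.0, -2.0]])
print(nucleus_filter(logits))   # the low-probability tail is masked to -inf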
@@ -280,10 +283,16 @@ class FreeTransformer(Module):
             latents = tensor(latents, device = self.device)
 
         if latents.ndim == 1: # repeat latents
-            latents = repeat(latents, 'd -> b d', b = batch)
+            latents = repeat(latents, 'd -> b 1 d', b = batch)
+        elif latents.ndim == 2:
+            latents = rearrange(latents, 'b d -> b 1 d')
 
         condition = self.from_latent_to_condition(latents)
 
+        # kv cache
+
+        head_cache = tail_cache = None
+
         # generated
 
         prompt_len = prompts.shape[-1]
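
The latents handling now normalizes both a single latent vector and a per-sample batch of latents to shape (batch, 1, dim) before the condition is computed, and generate starts out with empty head and tail caches. The shape handling in isolation, with made-up sizes (this is just einops, not the library's code):

import torch
from einops import repeat, rearrange

batch, dim = 4, 8

single_latent = torch.randn(dim)            # one latent shared across the batch
per_sample = torch.randn(batch, dim)        # one latent per sample

# mirrors the two branches above - both end up as (batch, 1, dim)
print(repeat(single_latent, 'd -> b 1 d', b = batch).shape)   # torch.Size([4, 1, 8])
print(rearrange(per_sample, 'b d -> b 1 d').shape)            # torch.Size([4, 1, 8])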
@@ -294,9 +303,20 @@ class FreeTransformer(Module):
 
         for _ in range(max(0, seq_len - prompt_len)):
 
-
+            # head, which may not exist
+
+            if exists(self.decoder_head):
+                head_embed, next_head_cache = self.decoder_head(tokens, cache = head_cache, return_hiddens = True)
+            else:
+                head_embed, next_head_cache = tokens, None
 
-
+            # handle one token being given to the decoder tail when doing kv caching - rotary embedding needs to know the seq position offset
+
+            seq_pos_offset = head_cache.cache_length if exists(head_cache) else 0
+
+            # tail
+
+            tail_embed, next_tail_cache = self.decoder_tail(head_embed, cache = tail_cache, seq_pos_offset = seq_pos_offset, self_attn_kv_residuals = condition, return_hiddens = True)
 
             tail_embed = tail_embed[:, -1]
 
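
This hunk is the core of the kv-cache support: each decoder stage returns its updated cache, the tail is told how many positions are already cached (seq_pos_offset) so its rotary embeddings line up when it only receives the newest token, and, as the next hunk shows, the caches are carried over only when use_kv_cache is enabled. A model-free sketch of just that bookkeeping; ToyCache and toy_stage are invented stand-ins (the real cache does expose cache_length, per the diff), and only the control flow mirrors the change:

def exists(v):
    return v is not None

class ToyCache:
    def __init__(self, cache_length = 0):
        self.cache_length = cache_length

def toy_stage(x, cache = None, seq_pos_offset = 0):
    # pretend decoder stage: just records how many positions it has processed so far
    seen = (cache.cache_length if exists(cache) else 0) + len(x)
    return x, ToyCache(seen)

tokens = ['t0', 't1', 't2']                  # stand-in for the prompt embeddings
head_cache = tail_cache = None
use_kv_cache = True

for step in range(3):
    # with a cache in hand, only the not-yet-cached suffix needs processing
    inp = tokens[head_cache.cache_length:] if exists(head_cache) else tokens

    head_out, next_head_cache = toy_stage(inp, cache = head_cache)

    # offset for the tail = number of positions already sitting in the cache,
    # so positional information for the single new token starts at the right place
    seq_pos_offset = head_cache.cache_length if exists(head_cache) else 0

    tail_out, next_tail_cache = toy_stage(head_out, cache = tail_cache, seq_pos_offset = seq_pos_offset)

    tokens.append(f't{len(tokens)}')         # pretend a new token was sampled

    if use_kv_cache:
        head_cache, tail_cache = next_head_cache, next_tail_cache

print(head_cache.cache_length, tail_cache.cache_length)   # both caches now cover 5 positions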
@@ -309,6 +329,10 @@ class FreeTransformer(Module):
             generated, _ = pack((generated, sampled), 'b *')
             tokens, _ = pack((tokens, self.token_emb(sampled)), 'b * d')
 
+            if use_kv_cache:
+                head_cache = next_head_cache
+                tail_cache = next_tail_cache
+
         return inverse_pack(generated)
 
     def forward(
@@ -326,7 +350,8 @@ class FreeTransformer(Module):
 
         # decoder head
 
-
+        if exists(self.decoder_head):
+            tokens = self.decoder_head(tokens)
 
         # get latent Z
 
{x_transformers-2.11.8.dist-info → x_transformers-2.11.10.dist-info}/RECORD

@@ -5,7 +5,7 @@ x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTN
 x_transformers/continuous.py,sha256=WwpQCjyVY4PtuEAOFY68zqgklbF9I7AL5w6874YlDe8,13249
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
-x_transformers/free_transformer.py,sha256=
+x_transformers/free_transformer.py,sha256=9vpZLTy1hWrX_wZpxmCsoJjsoKisu7HO4MTzP48oZoQ,10535
 x_transformers/gpt_vae.py,sha256=4QdznXZcU7pmMXUeEocAOKpcTkREYS-zDHktN5ADtNk,5981
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
@@ -14,7 +14,7 @@ x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,
 x_transformers/x_transformers.py,sha256=bYnVtkcfr082ALprIGgYIUx53lLADGYpi9t6QEJp1Kc,126907
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.11.
-x_transformers-2.11.
-x_transformers-2.11.
-x_transformers-2.11.
+x_transformers-2.11.10.dist-info/METADATA,sha256=xcHidmoWV-DKOo65NAd84GKA4kRQwSGMvWh69Rwh_w8,96012
+x_transformers-2.11.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.11.10.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.11.10.dist-info/RECORD,,
{x_transformers-2.11.8.dist-info → x_transformers-2.11.10.dist-info}/WHEEL
File without changes

{x_transformers-2.11.8.dist-info → x_transformers-2.11.10.dist-info}/licenses/LICENSE
File without changes