x-transformers 2.11.8__py3-none-any.whl → 2.11.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of x-transformers might be problematic.

@@ -197,7 +197,9 @@ class FreeTransformer(Module):
             pre_norm_has_final_norm = False,
             **kwargs,
             **dec_kwargs
-        )
+        ) if dec_head_depth > 0 else None
+
+        assert dec_tail_depth > 0
 
         self.decoder_tail = Decoder(
             dim = dim,
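The change above makes the decoder head optional: when `dec_head_depth` is zero the head is simply `None`, and downstream code guards on it with the usual `exists` helper (`val is not None`). Below is a minimal sketch of the same optional-submodule pattern, using a made-up `TwoStageDecoder` stand-in rather than the real FreeTransformer:

import torch
from torch import nn

def exists(val):
    # same convention as the `exists` helper used in x-transformers
    return val is not None

class TwoStageDecoder(nn.Module):
    # illustrative stand-in, not the real FreeTransformer
    def __init__(self, dim, head_depth = 0, tail_depth = 2):
        super().__init__()

        # head is only built when it has at least one layer, otherwise it is None
        self.head = nn.Sequential(*[nn.Linear(dim, dim) for _ in range(head_depth)]) if head_depth > 0 else None

        assert tail_depth > 0
        self.tail = nn.Sequential(*[nn.Linear(dim, dim) for _ in range(tail_depth)])

    def forward(self, x):
        # guard the optional stage, mirroring `if exists(self.decoder_head): ...`
        if exists(self.head):
            x = self.head(x)

        return self.tail(x)

x = torch.randn(2, 16, 64)
out = TwoStageDecoder(64, head_depth = 0)(x)  # head skipped entirely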
@@ -266,7 +268,8 @@ class FreeTransformer(Module):
         seq_len,
         latents = None,
         filter_logits_fn = top_p,
-        logit_filter_kwargs: dict = dict(thres = 0.9)
+        logit_filter_kwargs: dict = dict(thres = 0.9),
+        use_kv_cache = True
     ):
         prompts, inverse_pack = pack_with_inverse(prompts, '* n')
 
@@ -280,10 +283,16 @@ class FreeTransformer(Module):
             latents = tensor(latents, device = self.device)
 
         if latents.ndim == 1: # repeat latents
-            latents = repeat(latents, 'd -> b d', b = batch)
+            latents = repeat(latents, 'd -> b 1 d', b = batch)
+        elif latents.ndim == 2:
+            latents = rearrange(latents, 'b d -> b 1 d')
 
         condition = self.from_latent_to_condition(latents)
 
+        # kv cache
+
+        head_cache = tail_cache = None
+
         # generated
 
         prompt_len = prompts.shape[-1]
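The latent handling above normalizes `latents` to shape `(batch, 1, dim)`, so a single latent vector and a per-batch latent both gain an explicit length-1 sequence axis before being projected to the condition. A small, self-contained illustration of those two einops transforms (the tensor sizes here are made up):

import torch
from einops import repeat, rearrange

batch, dim = 4, 8

# case 1: a single latent vector, broadcast to every batch element
single = torch.randn(dim)
print(repeat(single, 'd -> b 1 d', b = batch).shape)   # torch.Size([4, 1, 8])

# case 2: one latent per batch element, which just gains a length-1 sequence axis
per_batch = torch.randn(batch, dim)
print(rearrange(per_batch, 'b d -> b 1 d').shape)      # torch.Size([4, 1, 8])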
@@ -294,9 +303,20 @@ class FreeTransformer(Module):
 
         for _ in range(max(0, seq_len - prompt_len)):
 
-            head_embed = self.decoder_head(tokens)
+            # head, which may not exist
+
+            if exists(self.decoder_head):
+                head_embed, next_head_cache = self.decoder_head(tokens, cache = head_cache, return_hiddens = True)
+            else:
+                head_embed, next_head_cache = tokens, None
 
-            tail_embed = self.decoder_tail(head_embed, self_attn_kv_residuals = condition)
+            # handle one token being given to the decoder tail when doing kv caching - rotary embedding needs to know the seq position offset
+
+            seq_pos_offset = head_cache.cache_length if exists(head_cache) else 0
+
+            # tail
+
+            tail_embed, next_tail_cache = self.decoder_tail(head_embed, cache = tail_cache, seq_pos_offset = seq_pos_offset, self_attn_kv_residuals = condition, return_hiddens = True)
 
             tail_embed = tail_embed[:, -1]
 
@@ -309,6 +329,10 @@ class FreeTransformer(Module):
             generated, _ = pack((generated, sampled), 'b *')
             tokens, _ = pack((tokens, self.token_emb(sampled)), 'b * d')
 
+            if use_kv_cache:
+                head_cache = next_head_cache
+                tail_cache = next_tail_cache
+
         return inverse_pack(generated)
 
     def forward(
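Taken together, the generation-loop changes thread a KV cache through both decoder stages: each step reuses the cache returned by the previous step (only when `use_kv_cache` is enabled), and `seq_pos_offset` tells rotary embeddings how many positions are already cached when only the newest token is fed in. A generic sketch of that cache-threading pattern, using a made-up `decoder_step` function rather than the real Decoder API:

import torch

def decoder_step(x, cache, seq_pos_offset):
    # stand-in for a cached decoder call: a real decoder would reuse cached keys/values
    # for the first `seq_pos_offset` positions and only process the tokens in `x`
    new_cache = dict(length = seq_pos_offset + x.shape[1])
    return torch.randn(x.shape[0], x.shape[1], 8), new_cache

prompt = torch.randn(2, 5, 8)   # (batch, seq, dim) embeddings - sizes made up
tokens, cache, use_kv_cache = prompt, None, True

for _ in range(3):
    # rotary embeddings need to know how many positions are already cached
    seq_pos_offset = cache['length'] if cache is not None else 0

    # once a cache exists, only the positions past it need to be fed in
    inp = tokens[:, seq_pos_offset:]

    out, next_cache = decoder_step(inp, cache, seq_pos_offset)

    # append the newly decoded embedding and carry the cache forward only when enabled
    tokens = torch.cat((tokens, out[:, -1:]), dim = 1)

    if use_kv_cache:
        cache = next_cache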
@@ -326,7 +350,8 @@ class FreeTransformer(Module):
 
         # decoder head
 
-        tokens = self.decoder_head(tokens)
+        if exists(self.decoder_head):
+            tokens = self.decoder_head(tokens)
 
         # get latent Z
 
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.11.8
+Version: 2.11.10
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -5,7 +5,7 @@ x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTN
 x_transformers/continuous.py,sha256=WwpQCjyVY4PtuEAOFY68zqgklbF9I7AL5w6874YlDe8,13249
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
-x_transformers/free_transformer.py,sha256=kfl_MIZxv4TARRQbq3NroGwZSBVHdYoJNu1hfWMloco,9555
+x_transformers/free_transformer.py,sha256=9vpZLTy1hWrX_wZpxmCsoJjsoKisu7HO4MTzP48oZoQ,10535
 x_transformers/gpt_vae.py,sha256=4QdznXZcU7pmMXUeEocAOKpcTkREYS-zDHktN5ADtNk,5981
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
@@ -14,7 +14,7 @@ x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,
 x_transformers/x_transformers.py,sha256=bYnVtkcfr082ALprIGgYIUx53lLADGYpi9t6QEJp1Kc,126907
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.11.8.dist-info/METADATA,sha256=NTTtQVh5bRCnk7RDpma7JHairCWIvaO2euEE2djXUFA,96011
-x_transformers-2.11.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.11.8.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.11.8.dist-info/RECORD,,
+x_transformers-2.11.10.dist-info/METADATA,sha256=xcHidmoWV-DKOo65NAd84GKA4kRQwSGMvWh69Rwh_w8,96012
+x_transformers-2.11.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.11.10.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.11.10.dist-info/RECORD,,