x-transformers 2.11.9.tar.gz → 2.11.10.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of x-transformers might be problematic.
- {x_transformers-2.11.9 → x_transformers-2.11.10}/PKG-INFO +1 -1
- {x_transformers-2.11.9 → x_transformers-2.11.10}/pyproject.toml +1 -1
- {x_transformers-2.11.9 → x_transformers-2.11.10}/x_transformers/free_transformer.py +29 -6
- {x_transformers-2.11.9 → x_transformers-2.11.10}/.github/FUNDING.yml +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/.gitignore +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/LICENSE +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/README.md +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/data/README.md +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/data/enwik8.gz +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/all-attention.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/attention-on-attention.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/deepnorm.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/fcm.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/ffglu.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/flash-attention.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/gate_values.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/gating.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/macaron-1.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/macaron-2.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/memory-transformer.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/normformer.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/pia.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/resi_dual.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/residual_attn.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/rezero.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/rotary.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/sandwich-2.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/sandwich.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/sandwich_norm.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/scalenorm.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/talking-heads.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/topk-attention.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/images/xval.png +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/tests/test_x_transformers.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/train_belief_state.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/train_copy.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/train_enwik8.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/train_free.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/train_gpt_vae.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/train_length_extrapolate.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/train_parity.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/train_with_muon.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/x_transformers/__init__.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/x_transformers/attend.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/x_transformers/continuous.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/x_transformers/dpo.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/x_transformers/gpt_vae.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/x_transformers/x_transformers.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.11.9 → x_transformers-2.11.10}/x_transformers/xval.py +0 -0
{x_transformers-2.11.9 → x_transformers-2.11.10}/x_transformers/free_transformer.py

@@ -197,7 +197,7 @@ class FreeTransformer(Module):
 197  197           pre_norm_has_final_norm = False,
 198  198           **kwargs,
 199  199           **dec_kwargs
 200       -    ) if dec_head_depth > 0 else
      200  +    ) if dec_head_depth > 0 else None
 201  201
 202  202       assert dec_tail_depth > 0
 203  203
@@ -268,7 +268,8 @@ class FreeTransformer(Module):
 268  268       seq_len,
 269  269       latents = None,
 270  270       filter_logits_fn = top_p,
 271       -    logit_filter_kwargs: dict = dict(thres = 0.9)
      271  +    logit_filter_kwargs: dict = dict(thres = 0.9),
      272  +    use_kv_cache = True
 272  273   ):
 273  274       prompts, inverse_pack = pack_with_inverse(prompts, '* n')
 274  275
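The hunk above only widens the generate signature: logit_filter_kwargs gains a trailing comma and a new use_kv_cache flag, defaulting to True, is appended. A hedged usage sketch follows, assuming `model` is an already-constructed FreeTransformer and `prompts` holds integer token ids of shape (batch, prompt_length); only the keyword names shown in this hunk are taken from the source.

import torch

def sample(model, prompts: torch.Tensor, seq_len: int = 64):
    # use_kv_cache defaults to True; pass False to recompute the full
    # prefix on every decoding step instead of reusing cached key/values
    return model.generate(
        prompts,
        seq_len,
        logit_filter_kwargs = dict(thres = 0.9),
        use_kv_cache = True
    )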
@@ -282,10 +283,16 @@ class FreeTransformer(Module):
 282  283           latents = tensor(latents, device = self.device)
 283  284
 284  285       if latents.ndim == 1: # repeat latents
 285       -        latents = repeat(latents, 'd -> b d', b = batch)
      286  +        latents = repeat(latents, 'd -> b 1 d', b = batch)
      287  +    elif latents.ndim == 2:
      288  +        latents = rearrange(latents, 'b d -> b 1 d')
 286  289
 287  290       condition = self.from_latent_to_condition(latents)
 288  291
      292  +    # kv cache
      293  +
      294  +    head_cache = tail_cache = None
      295  +
 289  296       # generated
 290  297
 291  298       prompt_len = prompts.shape[-1]
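In the latents handling just above, both branches now leave the latents with an explicit length-1 sequence axis before they are mapped to the conditioning. A quick shape check using the same einops calls (torch and einops assumed installed):

import torch
from einops import repeat, rearrange

batch, dim = 4, 8

single = torch.randn(dim)             # one latent shared across the batch
per_sample = torch.randn(batch, dim)  # one latent per batch element

print(repeat(single, 'd -> b 1 d', b = batch).shape)  # torch.Size([4, 1, 8])
print(rearrange(per_sample, 'b d -> b 1 d').shape)    # torch.Size([4, 1, 8])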
@@ -296,9 +303,20 @@ class FreeTransformer(Module):
 296  303
 297  304       for _ in range(max(0, seq_len - prompt_len)):
 298  305
 299       -
      306  +        # head, which may not exist
      307  +
      308  +        if exists(self.decoder_head):
      309  +            head_embed, next_head_cache = self.decoder_head(tokens, cache = head_cache, return_hiddens = True)
      310  +        else:
      311  +            head_embed, next_head_cache = tokens, None
      312  +
      313  +        # handle one token being given to the decoder tail when doing kv caching - rotary embedding needs to know the seq position offset
 300  314
 301       -
      315  +        seq_pos_offset = head_cache.cache_length if exists(head_cache) else 0
      316  +
      317  +        # tail
      318  +
      319  +        tail_embed, next_tail_cache = self.decoder_tail(head_embed, cache = tail_cache, seq_pos_offset = seq_pos_offset, self_attn_kv_residuals = condition, return_hiddens = True)
 302  320
 303  321           tail_embed = tail_embed[:, -1]
 304  322
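The loop above now carries a key/value cache for both the optional decoder head and the decoder tail, and uses the cache length as the sequence-position offset so rotary embeddings stay correct when only the newest token is fed in. The toy sketch below illustrates that pattern only; the cache class, its fields, and the step function are made up for illustration and are not the x-transformers API.

from dataclasses import dataclass, field

import torch

@dataclass
class ToyCache:
    # cached per-step states; a real cache would hold keys/values per layer
    states: list = field(default_factory = list)

    @property
    def cache_length(self):
        return len(self.states)

def toy_decoder_step(token_embed, cache = None):
    # how many positions are already cached - this is the offset a
    # position-dependent component (e.g. rotary embedding) would need
    seq_pos_offset = cache.cache_length if cache is not None else 0

    cache = cache if cache is not None else ToyCache()
    cache.states.append(token_embed)

    # stand-in for attending over everything cached so far
    out = torch.stack(cache.states).mean(dim = 0)
    return out, cache, seq_pos_offset

cache = None
for step in range(4):
    out, cache, offset = toy_decoder_step(torch.randn(8), cache)
    print(step, offset, cache.cache_length)  # offset lags the cache length by one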
@@ -311,6 +329,10 @@ class FreeTransformer(Module):
 311  329           generated, _ = pack((generated, sampled), 'b *')
 312  330           tokens, _ = pack((tokens, self.token_emb(sampled)), 'b * d')
 313  331
      332  +        if use_kv_cache:
      333  +            head_cache = next_head_cache
      334  +            tail_cache = next_tail_cache
      335  +
 314  336       return inverse_pack(generated)
 315  337
 316  338   def forward(
@@ -328,7 +350,8 @@ class FreeTransformer(Module):
 328  350
 329  351       # decoder head
 330  352
 331       -
      353  +    if exists(self.decoder_head):
      354  +        tokens = self.decoder_head(tokens)
 332  355
 333  356       # get latent Z
 334  357
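Taken together with the constructor change in the first hunk (`) if dec_head_depth > 0 else None`), the decoder head is now optional and every call site guards on its existence. A minimal sketch of that pattern with made-up module names, not the FreeTransformer internals:

import torch
from torch import nn

def exists(v):
    return v is not None

class ToyModel(nn.Module):
    def __init__(self, dim, head_depth, tail_depth):
        super().__init__()
        assert tail_depth > 0

        # the head becomes None when no head layers are requested
        self.head = nn.Sequential(*[nn.Linear(dim, dim) for _ in range(head_depth)]) if head_depth > 0 else None
        self.tail = nn.Sequential(*[nn.Linear(dim, dim) for _ in range(tail_depth)])

    def forward(self, x):
        if exists(self.head):
            x = self.head(x)

        return self.tail(x)

model = ToyModel(dim = 8, head_depth = 0, tail_depth = 2)
print(model(torch.randn(2, 8)).shape)  # torch.Size([2, 8])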