x-transformers 2.11.8__tar.gz → 2.11.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of x-transformers has been flagged by the registry as possibly problematic.
- {x_transformers-2.11.8 → x_transformers-2.11.10}/PKG-INFO +1 -1
- {x_transformers-2.11.8 → x_transformers-2.11.10}/pyproject.toml +1 -1
- {x_transformers-2.11.8 → x_transformers-2.11.10}/tests/test_x_transformers.py +4 -2
- {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/free_transformer.py +31 -6
- {x_transformers-2.11.8 → x_transformers-2.11.10}/.github/FUNDING.yml +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/.gitignore +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/LICENSE +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/README.md +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/data/README.md +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/data/enwik8.gz +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/all-attention.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/attention-on-attention.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/deepnorm.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/fcm.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/ffglu.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/flash-attention.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/gate_values.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/gating.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/macaron-1.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/macaron-2.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/memory-transformer.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/normformer.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/pia.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/resi_dual.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/residual_attn.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/rezero.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/rotary.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/sandwich-2.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/sandwich.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/sandwich_norm.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/scalenorm.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/talking-heads.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/topk-attention.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/images/xval.png +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/train_belief_state.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/train_copy.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/train_enwik8.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/train_free.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/train_gpt_vae.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/train_length_extrapolate.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/train_parity.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/train_with_muon.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/__init__.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/attend.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/continuous.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/dpo.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/gpt_vae.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/x_transformers.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/xval.py +0 -0
{x_transformers-2.11.8 → x_transformers-2.11.10}/tests/test_x_transformers.py +4 -2

@@ -1410,7 +1410,9 @@ def test_attn_negative_weights(
     logits = model(x)
 
 @param('per_token_latents', (False, True))
+@param('dec_head_depth', (0, 4))
 def test_free(
+    dec_head_depth,
     per_token_latents
 ):
     from x_transformers.free_transformer import FreeTransformer
@@ -1420,9 +1422,9 @@ def test_free(
         max_seq_len = 1024,
         dim = 512,
         heads = 8,
-        dec_head_depth =
+        dec_head_depth = dec_head_depth,
         dec_tail_depth = 4,
-        enc_depth =
+        enc_depth = 2,
         kl_loss_weight = 1.,
         per_token_latents = per_token_latents,
         latent_bits = 8
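The new dec_head_depth parametrization exercises the case where the decoder head is skipped entirely (dec_head_depth = 0), which the free_transformer.py change below supports by leaving self.decoder_head as None. A minimal sketch of that configuration, assuming the constructor keywords visible in the hunk above; num_tokens, the input shape, and the loss-returning forward pass are assumptions not shown in this diff:

import torch
from x_transformers.free_transformer import FreeTransformer

# hypothetical configuration mirroring the parametrized test above;
# num_tokens and the sequence shape are assumptions, as they fall
# outside the lines shown in this diff
model = FreeTransformer(
    num_tokens = 256,
    max_seq_len = 1024,
    dim = 512,
    heads = 8,
    dec_head_depth = 0,   # new test case: no decoder head is built (it stays None)
    dec_tail_depth = 4,   # must remain > 0, per the new assert in free_transformer.py
    enc_depth = 2,
    kl_loss_weight = 1.,
    per_token_latents = True,
    latent_bits = 8
)

seq = torch.randint(0, 256, (2, 1024))

# assumption: the forward pass returns a trainable loss, as the
# kl_loss_weight argument suggests
loss = model(seq)
loss.backward()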
{x_transformers-2.11.8 → x_transformers-2.11.10}/x_transformers/free_transformer.py +31 -6

@@ -197,7 +197,9 @@ class FreeTransformer(Module):
             pre_norm_has_final_norm = False,
             **kwargs,
             **dec_kwargs
-        )
+        ) if dec_head_depth > 0 else None
+
+        assert dec_tail_depth > 0
 
         self.decoder_tail = Decoder(
             dim = dim,
@@ -266,7 +268,8 @@ class FreeTransformer(Module):
         seq_len,
         latents = None,
         filter_logits_fn = top_p,
-        logit_filter_kwargs: dict = dict(thres = 0.9)
+        logit_filter_kwargs: dict = dict(thres = 0.9),
+        use_kv_cache = True
     ):
         prompts, inverse_pack = pack_with_inverse(prompts, '* n')
 
@@ -280,10 +283,16 @@ class FreeTransformer(Module):
             latents = tensor(latents, device = self.device)
 
         if latents.ndim == 1: # repeat latents
-            latents = repeat(latents, 'd -> b d', b = batch)
+            latents = repeat(latents, 'd -> b 1 d', b = batch)
+        elif latents.ndim == 2:
+            latents = rearrange(latents, 'b d -> b 1 d')
 
         condition = self.from_latent_to_condition(latents)
 
+        # kv cache
+
+        head_cache = tail_cache = None
+
         # generated
 
         prompt_len = prompts.shape[-1]
@@ -294,9 +303,20 @@ class FreeTransformer(Module):
 
         for _ in range(max(0, seq_len - prompt_len)):
 
-
+            # head, which may not exist
+
+            if exists(self.decoder_head):
+                head_embed, next_head_cache = self.decoder_head(tokens, cache = head_cache, return_hiddens = True)
+            else:
+                head_embed, next_head_cache = tokens, None
 
-
+            # handle one token being given to the decoder tail when doing kv caching - rotary embedding needs to know the seq position offset
+
+            seq_pos_offset = head_cache.cache_length if exists(head_cache) else 0
+
+            # tail
+
+            tail_embed, next_tail_cache = self.decoder_tail(head_embed, cache = tail_cache, seq_pos_offset = seq_pos_offset, self_attn_kv_residuals = condition, return_hiddens = True)
 
             tail_embed = tail_embed[:, -1]
 
@@ -309,6 +329,10 @@ class FreeTransformer(Module):
             generated, _ = pack((generated, sampled), 'b *')
             tokens, _ = pack((tokens, self.token_emb(sampled)), 'b * d')
 
+            if use_kv_cache:
+                head_cache = next_head_cache
+                tail_cache = next_tail_cache
+
         return inverse_pack(generated)
 
     def forward(
@@ -326,7 +350,8 @@ class FreeTransformer(Module):
 
         # decoder head
 
-
+        if exists(self.decoder_head):
+            tokens = self.decoder_head(tokens)
 
         # get latent Z
 
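Taken together, the generate changes add optional KV caching (head_cache / tail_cache threaded through decoder_head and decoder_tail, with seq_pos_offset keeping rotary embeddings aligned when only one new token is fed per step) and broaden how latents can be passed: a single vector is broadcast across the batch, and a (batch, dim) tensor gains a singleton sequence axis. A rough usage sketch, reusing the hypothetical model from the earlier sketch; the latent dimensionality is not shown in this diff, so latents is left at its default:

import torch

# rough sketch of the updated generate call, reusing the hypothetical `model`
# from the sketch above; argument names follow the parameters visible in the hunks
prompts = torch.randint(0, 256, (2, 16))

# new flag: KV caching defaults to on, so the head/tail decoders reuse cached
# attention keys/values instead of re-processing the whole prefix each step
fast = model.generate(prompts, seq_len = 256)

# caching can be switched off to fall back to full recomputation per step
slow = model.generate(prompts, seq_len = 256, use_kv_cache = False)

# per the latents handling above, `latents` may also be passed as a 1-D vector
# (broadcast to 'b 1 d' across the batch) or as a (batch, dim) tensor
# (rearranged to 'b 1 d'); the latent dimensionality itself is not shown here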