x-transformers 2.11.15.tar.gz → 2.11.16.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of x-transformers might be problematic.
- {x_transformers-2.11.15 → x_transformers-2.11.16}/PKG-INFO +1 -1
- {x_transformers-2.11.15 → x_transformers-2.11.16}/pyproject.toml +1 -1
- {x_transformers-2.11.15 → x_transformers-2.11.16}/tests/test_x_transformers.py +6 -2
- {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/free_transformer.py +23 -4
- {x_transformers-2.11.15 → x_transformers-2.11.16}/.github/FUNDING.yml +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/.gitignore +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/LICENSE +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/README.md +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/data/README.md +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/data/enwik8.gz +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/all-attention.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/attention-on-attention.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/deepnorm.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/fcm.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/ffglu.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/flash-attention.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/gate_values.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/gating.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/macaron-1.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/macaron-2.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/memory-transformer.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/normformer.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/pia.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/resi_dual.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/residual_attn.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/rezero.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/rotary.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/sandwich-2.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/sandwich.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/sandwich_norm.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/scalenorm.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/talking-heads.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/topk-attention.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/images/xval.png +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/train_belief_state.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/train_copy.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/train_enwik8.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/train_free.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/train_gpt_vae.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/train_length_extrapolate.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/train_parity.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/train_with_muon.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/__init__.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/attend.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/continuous.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/dpo.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/gpt_vae.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/x_transformers.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/xval.py +0 -0
{x_transformers-2.11.15 → x_transformers-2.11.16}/tests/test_x_transformers.py
@@ -1411,9 +1411,11 @@ def test_attn_negative_weights(
 
 @param('per_token_latents', (False, True))
 @param('dec_head_depth', (0, 4))
+@param('separate_seq_for_latents', (False, True))
 def test_free(
     dec_head_depth,
-    per_token_latents
+    per_token_latents,
+    separate_seq_for_latents
 ):
     from x_transformers.free_transformer import FreeTransformer
 
@@ -1432,7 +1434,9 @@ def test_free(
 
     seq = torch.randint(0, 256, (1, 1024))
 
-    loss, (ar_loss, aux_loss) = model(seq, return_all_losses = True)
+    separate_seq_for_latents = torch.randint(0, 256, (1, 32)) if separate_seq_for_latents else None
+
+    loss, (ar_loss, aux_loss) = model(seq, separate_seq_for_latents, return_all_losses = True)
     loss.backward()
 
     assert aux_loss.numel() == 1
{x_transformers-2.11.15 → x_transformers-2.11.16}/x_transformers/free_transformer.py
@@ -225,8 +225,11 @@ class FreeTransformer(Module):
         self,
         decoder_head_embeds,
         mask = None,
-        return_kl_loss = False
+        return_kl_loss = False,
+        per_token_latents = None
     ):
+        per_token_latents = default(per_token_latents, self.per_token_latents)
+
         batch, seq_len, device = *decoder_head_embeds.shape[:2], decoder_head_embeds.device
 
         query_tokens = repeat(self.query_token_for_latents, 'd -> b 1 d', b = batch)
@@ -235,7 +238,7 @@ class FreeTransformer(Module):
 
         # handle the interesting per query token latents, as in the paper
 
-        if self.per_token_latents:
+        if per_token_latents:
             query_tokens = repeat(query_tokens, 'b 1 d -> b n d', n = seq_len)
 
             rotary_pos = torch.arange(seq_len, device = device)
@@ -342,13 +345,13 @@ class FreeTransformer(Module):
     def forward(
         self,
         seq,
+        seq_for_latents = None,
         return_all_losses = False
     ):
         batch, device = seq.shape[0], seq.device
 
         seq, labels = seq[:, :-1], seq[:, 1:]
 
-        encoder_mask = seq != self.pad_id
 
         tokens = self.token_emb(seq)
 
@@ -357,9 +360,25 @@ class FreeTransformer(Module):
         if exists(self.decoder_head):
             tokens = self.decoder_head(tokens)
 
+        # determine whether to use a separate sequence for encoding latents
+
+        if exists(seq_for_latents):
+            tokens_for_latents = self.token_emb(seq_for_latents)
+
+            if exists(self.decoder_head):
+                tokens_for_latents = self.decoder_head(tokens_for_latents)
+
+            encoder_mask = seq_for_latents != self.pad_id
+            per_token_latents = False
+        else:
+
+            tokens_for_latents = tokens
+            encoder_mask = seq != self.pad_id
+            per_token_latents = None
+
         # get latent Z
 
-        latents, kl_loss = self.encode_to_latents(tokens, mask = encoder_mask, return_kl_loss = True)
+        latents, kl_loss = self.encode_to_latents(tokens_for_latents, mask = encoder_mask, per_token_latents = per_token_latents, return_kl_loss = True)
 
         condition = self.from_latent_to_condition(latents)
 
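Taken together, the 2.11.16 change lets FreeTransformer.forward accept an optional second positional argument, seq_for_latents. When it is supplied, the latent Z is encoded from that sequence (with its own padding mask, and with per-token latents forced off, since the auxiliary sequence length need not match the decoded sequence); when it is omitted, the latents are encoded from seq as before. The sketch below mirrors the forward call in the updated test; the constructor hyperparameters (num_tokens, dim, depth) are illustrative assumptions rather than values taken from this diff, so consult free_transformer.py or train_free.py in this release for a known-good configuration.

import torch
from x_transformers.free_transformer import FreeTransformer

# hypothetical hyperparameters, for illustration only; the call pattern below is
# what this release changes and follows the updated test_free test
model = FreeTransformer(
    num_tokens = 256,   # assumed byte-level vocab, matching torch.randint(0, 256, ...) in the test
    dim = 512,
    depth = 4
)

seq = torch.randint(0, 256, (1, 1024))

# new in 2.11.16: an optional, typically shorter sequence that the encoder uses to
# produce the latent Z; pass None to fall back to encoding the latents from seq itself
seq_for_latents = torch.randint(0, 256, (1, 32))

loss, (ar_loss, aux_loss) = model(seq, seq_for_latents, return_all_losses = True)
loss.backward()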