x-transformers 1.42.7__tar.gz → 1.42.9__tar.gz
- {x_transformers-1.42.7/x_transformers.egg-info → x_transformers-1.42.9}/PKG-INFO +1 -1
- {x_transformers-1.42.7 → x_transformers-1.42.9}/setup.py +1 -1
- {x_transformers-1.42.7 → x_transformers-1.42.9}/tests/test_x_transformers.py +115 -1
- {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/x_transformers.py +39 -18
- {x_transformers-1.42.7 → x_transformers-1.42.9/x_transformers.egg-info}/PKG-INFO +1 -1
- {x_transformers-1.42.7 → x_transformers-1.42.9}/LICENSE +0 -0
- {x_transformers-1.42.7 → x_transformers-1.42.9}/README.md +0 -0
- {x_transformers-1.42.7 → x_transformers-1.42.9}/setup.cfg +0 -0
- {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/__init__.py +0 -0
- {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/attend.py +0 -0
- {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/continuous.py +0 -0
- {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/dpo.py +0 -0
- {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/multi_input.py +0 -0
- {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/xval.py +0 -0
- {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.42.7 → x_transformers-1.42.9}/tests/test_x_transformers.py

@@ -1,12 +1,15 @@
 import pytest
+
 import torch
+from torch import nn
+from torch.nn import Module

 from x_transformers.x_transformers import (
     XTransformer,
     TransformerWrapper,
     Encoder,
     Decoder,
-
+    LinearNoBias,
 )

 from x_transformers.neo_mlp import (
@@ -378,6 +381,7 @@ def test_neo_mlp():
     assert out.shape == (3, 7)

 def test_custom_alibi():
+
     model = TransformerWrapper(
         num_tokens = 20_000,
         max_seq_len = 1024,
@@ -394,3 +398,113 @@ def test_custom_alibi():
     pos = torch.tensor([[0, 1, 2, 4], [1, 3, 5, 7]])

     logits = model(x, pos = pos)
+
+def test_custom_alibi_across_heads():
+
+    model = Decoder(
+        dim = 512,
+        depth = 2,
+        heads = 2,
+        alibi_pos_bias = True,
+        rel_pos_kwargs = dict(
+            slopes = [1, 1]
+        ),
+    )
+
+    x = torch.randn(2, 4, 512)
+
+    pos = torch.tensor([
+        [[0, 1, 2, 4], [1, 3, 5, 7]],
+        [[2, 3, 4, 5], [6, 8, 9, 10]]
+    ])
+
+    embed = model(x, pos = pos)
+
+@pytest.mark.parametrize('embedder_type', ('embedding', 'none', 'custom'))
+def test_embedder(embedder_type):
+    num_tokens = 20000
+    dim = 128
+    token_emb_kwargs = {}
+
+    if embedder_type == 'embedding':
+        embedder = nn.Embedding(num_tokens, dim)
+    elif embedder_type == 'none':
+        embedder = None
+    else:
+        class CustomEmbedder(Module):
+            """
+            Made up embedder that sums two embeddings. Just to check if we can pass additional input to the embedder's
+            forward pass without breaking the model.
+            """
+            def __init__(self, num_tokens, dim):
+                super().__init__()
+                self.embed_x = nn.Embedding(num_tokens, dim)
+                self.embed_y = nn.Embedding(num_tokens, dim)
+
+            def forward(self, x, y):
+                return self.embed_x(x) + self.embed_y(y)
+
+            def init_(self):
+                pass
+
+        embedder = CustomEmbedder(num_tokens, dim)
+        token_emb_kwargs['y'] = torch.randint(0, num_tokens, (2, 1024))
+
+    model = TransformerWrapper(
+        num_tokens = num_tokens,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = dim,
+            depth = 6,
+            heads = 8,
+        ),
+        token_emb = embedder,
+    )
+
+    x = torch.randint(0, 20000, (2, 1024))
+
+    output = model(x, token_emb_kwargs=token_emb_kwargs)
+    assert output.shape == (2, 1024, 20000)
+
+
+@pytest.mark.parametrize("to_logits", ('linear', 'none', 'pointer'))
+def test_to_logits(to_logits):
+    num_tokens = 20000
+    dim = 128
+
+    to_logits_kwargs = {}
+
+    if to_logits == 'linear':
+        logit_mapper = LinearNoBias(dim, num_tokens)
+    elif to_logits == 'none':
+        logit_mapper = None
+    else:
+        class PointerNetworkLogits(Module):
+            def __init__(self, dim):
+                super().__init__()
+                self.proj_to_pointers = nn.Linear(dim, dim)
+
+            def forward(self, model_embeddings, input_embeddings):
+                pointers = self.proj_to_pointers(model_embeddings)
+                logits = torch.matmul(pointers, input_embeddings.permute(0, 2, 1))
+                return logits
+
+        logit_mapper = PointerNetworkLogits(dim)
+        to_logits_kwargs['input_embeddings'] = torch.randn(2, 20000, dim)
+
+    model = TransformerWrapper(
+        num_tokens = num_tokens,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = dim,
+            depth = 6,
+            heads = 8,
+        ),
+        to_logits = logit_mapper,
+    )
+
+    x = torch.randint(0, num_tokens, (2, 1024))
+
+    output = model(x, to_logits_kwargs=to_logits_kwargs)
+
+    assert output.shape == (2, 1024, 20000)
{x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/x_transformers.py

@@ -238,6 +238,13 @@ class TokenEmbedding(Module):
         token_emb = self.emb(x.long())
         return l2norm(token_emb) if self.l2norm_embed else token_emb

+    def init_(self):
+        if self.l2norm_embed:
+            nn.init.normal_(self.emb.weight, std=1e-5)
+            return
+        nn.init.kaiming_normal_(self.emb.weight)
+
+
 # positional embeddings

 class AbsolutePositionalEmbedding(Module):
@@ -445,13 +452,20 @@ class DynamicPositionBias(Module):
         return bias

 class AlibiPositionalBias(Module):
-    def __init__(self, heads, total_heads = None, **kwargs):
+    def __init__(
+        self,
+        heads,
+        total_heads = None,
+        slopes: list[int] | None = None,
+        **kwargs
+    ):
         super().__init__()
         self.heads = heads
         self.total_heads = default(total_heads, heads)

-        slopes = Tensor(self._get_slopes(heads))
+        slopes = Tensor(default(slopes, self._get_slopes(heads)))
         slopes = rearrange(slopes, 'h -> h 1 1')
+
         self.register_buffer('slopes', slopes, persistent = False)
         self.register_buffer('bias', None, persistent = False)

@@ -480,7 +494,10 @@ class AlibiPositionalBias(Module):
         h, device = self.total_heads, self.device

         pos_j = default(pos_j, pos_i)
-        bias = -einx.subtract('... j, ... i -> ... 1 i j', pos_j, pos_i).abs()
+        bias = -einx.subtract('... j, ... i -> ... i j', pos_j, pos_i).abs()
+
+        if bias.ndim == 3:
+            bias = rearrange(bias, 'b i j -> b 1 i j')

         bias = bias * self.slopes
         num_heads_unalibied = h - bias.shape[-3]
@@ -1524,8 +1541,9 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
-        reinject_input = False,
-        add_value_residual = False,
+        reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
+        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
+        rel_pos_kwargs: dict = dict(),
         **kwargs
     ):
         super().__init__()
@@ -1566,14 +1584,14 @@ class AttentionLayers(Module):

         if rel_pos_bias:
             assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
-            self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
+            self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance, **rel_pos_kwargs)
         elif dynamic_pos_bias:
             assert not flash_attn, 'flash attention not compatible with dynamic positional bias'
-            self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm)
+            self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm, **rel_pos_kwargs)
         elif alibi_pos_bias:
             alibi_num_heads = default(alibi_num_heads, heads)
             assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
-            self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads)
+            self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads, **rel_pos_kwargs)

         assert at_most_one_of(sandwich_norm, resi_dual), 'either sandwich norm or resiDual is selected, but not both'
         assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
@@ -2261,7 +2279,8 @@ class TransformerWrapper(Module):
         token_emb: TokenEmbedding | None = None,
         mixture_of_softmax = False,
         mixture_of_softmax_k = 4,
-        sigsoftmax_logits = False
+        sigsoftmax_logits = False,
+        to_logits: Module | None = None,
     ):
         super().__init__()

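A hedged sketch of the new `to_logits` hook (the replacement head below is illustrative, not part of the library): any module mapping the final hidden states to logits can stand in for the default `LinearNoBias` head, provided the other output branches (`tie_embedding`, `num_output_heads > 1`, `return_only_embed`) are left at their defaults.

```python
import torch
from torch import nn
from x_transformers import TransformerWrapper, Decoder

# illustrative replacement head; anything mapping dim -> num_tokens works
custom_head = nn.Sequential(nn.LayerNorm(128), nn.Linear(128, 256, bias = False))

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 128, depth = 2, heads = 4),
    to_logits = custom_head,   # replaces the default LinearNoBias(dim, num_tokens)
)

x = torch.randint(0, 256, (2, 64))
logits = model(x)              # (2, 64, 256)
```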
@@ -2363,11 +2382,12 @@ class TransformerWrapper(Module):
         if return_only_embed:
             self.to_logits = None
         elif tie_embedding:
+            assert isinstance(token_emb, TokenEmbedding), 'can only tie embedding if using `TokenEmbedding`'
             self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
         elif num_output_heads > 1:
             self.to_logits = ModuleList([LinearNoBias(dim, logits_dim) for _ in range(num_output_heads)])
         else:
-            self.to_logits = LinearNoBias(dim, logits_dim)
+            self.to_logits = LinearNoBias(dim, logits_dim) if not exists(to_logits) else to_logits

         # memory tokens (like [cls]) from Memory Transformers paper

@@ -2388,13 +2408,12 @@ class TransformerWrapper(Module):
         self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb

     def init_(self):
+        if hasattr(self.token_emb, 'init_'):
+            self.token_emb.init_()
+
         if self.l2norm_embed:
-            nn.init.normal_(self.token_emb.emb.weight, std = 1e-5)
             if not isinstance(self.pos_emb, always):
                 nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
-            return
-
-        nn.init.kaiming_normal_(self.token_emb.emb.weight)

     def forward(
         self,
@@ -2417,7 +2436,9 @@ class TransformerWrapper(Module):
         attn_z_loss_weight = 1e-4,
         seq_start_pos = None,
         cache: LayerIntermediates | None = None,
-        **kwargs
+        token_emb_kwargs = dict(),
+        to_logits_kwargs = dict(),
+        **kwargs,
     ):
         b, n, device, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask

@@ -2428,7 +2449,7 @@ class TransformerWrapper(Module):

         external_pos_emb = exists(pos) and pos.dtype != torch.long
         pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
-        x = self.token_emb(x) + pos_emb
+        x = self.token_emb(x, **token_emb_kwargs) + pos_emb

         # add additional embeddings

@@ -2583,9 +2604,9 @@ class TransformerWrapper(Module):

         if not return_embeddings:
             if self.has_multiple_heads:
-                logits = tuple(fn(x) for fn in self.to_logits)
+                logits = tuple(fn(x, **to_logits_kwargs) for fn in self.to_logits)
             else:
-                logits = self.to_logits(x)
+                logits = self.to_logits(x, **to_logits_kwargs)

             # maybe sig softmax

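Finally, a sketch of the call-time side of this change, assuming a custom head that takes an extra forward argument (the head below is made up; the new test uses a pointer-network head instead): whatever is placed in `to_logits_kwargs` is passed through to `self.to_logits` at the end of `forward`.

```python
import torch
from torch import nn
from torch.nn import Module
from x_transformers import TransformerWrapper, Decoder

class ScaledLogits(Module):
    # illustrative head with an extra forward-time argument
    def __init__(self, dim, num_tokens):
        super().__init__()
        self.proj = nn.Linear(dim, num_tokens, bias = False)

    def forward(self, embeds, temperature = 1.):
        return self.proj(embeds) / temperature

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    attn_layers = Decoder(dim = 64, depth = 1, heads = 2),
    to_logits = ScaledLogits(64, 256),
)

x = torch.randint(0, 256, (2, 32))

# extra arguments reach the custom head through to_logits_kwargs
logits = model(x, to_logits_kwargs = dict(temperature = 2.))
assert logits.shape == (2, 32, 256)
```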
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers.egg-info/dependency_links.txt
RENAMED
File without changes
|
File without changes
|
File without changes
|