x-transformers 1.42.7__tar.gz → 1.42.9__tar.gz

Files changed (22)
  1. {x_transformers-1.42.7/x_transformers.egg-info → x_transformers-1.42.9}/PKG-INFO +1 -1
  2. {x_transformers-1.42.7 → x_transformers-1.42.9}/setup.py +1 -1
  3. {x_transformers-1.42.7 → x_transformers-1.42.9}/tests/test_x_transformers.py +115 -1
  4. {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/x_transformers.py +39 -18
  5. {x_transformers-1.42.7 → x_transformers-1.42.9/x_transformers.egg-info}/PKG-INFO +1 -1
  6. {x_transformers-1.42.7 → x_transformers-1.42.9}/LICENSE +0 -0
  7. {x_transformers-1.42.7 → x_transformers-1.42.9}/README.md +0 -0
  8. {x_transformers-1.42.7 → x_transformers-1.42.9}/setup.cfg +0 -0
  9. {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/neo_mlp.py +0 -0
  16. {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/nonautoregressive_wrapper.py +0 -0
  17. {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  18. {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/xval.py +0 -0
  19. {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers.egg-info/SOURCES.txt +0 -0
  20. {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers.egg-info/dependency_links.txt +0 -0
  21. {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers.egg-info/requires.txt +0 -0
  22. {x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.42.7/x_transformers.egg-info → x_transformers-1.42.9}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.7
+Version: 1.42.9
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
{x_transformers-1.42.7 → x_transformers-1.42.9}/setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.42.7',
+  version = '1.42.9',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
{x_transformers-1.42.7 → x_transformers-1.42.9}/tests/test_x_transformers.py

@@ -1,12 +1,15 @@
 import pytest
+
 import torch
+from torch import nn
+from torch.nn import Module
 
 from x_transformers.x_transformers import (
     XTransformer,
     TransformerWrapper,
     Encoder,
     Decoder,
-    AutoregressiveWrapper,
+    LinearNoBias,
 )
 
 from x_transformers.neo_mlp import (
@@ -378,6 +381,7 @@ def test_neo_mlp():
     assert out.shape == (3, 7)
 
 def test_custom_alibi():
+
     model = TransformerWrapper(
         num_tokens = 20_000,
         max_seq_len = 1024,
@@ -394,3 +398,113 @@ def test_custom_alibi():
     pos = torch.tensor([[0, 1, 2, 4], [1, 3, 5, 7]])
 
     logits = model(x, pos = pos)
+
+def test_custom_alibi_across_heads():
+
+    model = Decoder(
+        dim = 512,
+        depth = 2,
+        heads = 2,
+        alibi_pos_bias = True,
+        rel_pos_kwargs = dict(
+            slopes = [1, 1]
+        ),
+    )
+
+    x = torch.randn(2, 4, 512)
+
+    pos = torch.tensor([
+        [[0, 1, 2, 4], [1, 3, 5, 7]],
+        [[2, 3, 4, 5], [6, 8, 9, 10]]
+    ])
+
+    embed = model(x, pos = pos)
+
+@pytest.mark.parametrize('embedder_type', ('embedding', 'none', 'custom'))
+def test_embedder(embedder_type):
+    num_tokens = 20000
+    dim = 128
+    token_emb_kwargs = {}
+
+    if embedder_type == 'embedding':
+        embedder = nn.Embedding(num_tokens, dim)
+    elif embedder_type == 'none':
+        embedder = None
+    else:
+        class CustomEmbedder(Module):
+            """
+            Made up embedder that sums two embeddings. Just to check if we can pass additional input to the embedder's
+            forward pass without breaking the model.
+            """
+            def __init__(self, num_tokens, dim):
+                super().__init__()
+                self.embed_x = nn.Embedding(num_tokens, dim)
+                self.embed_y = nn.Embedding(num_tokens, dim)
+
+            def forward(self, x, y):
+                return self.embed_x(x) + self.embed_y(y)
+
+            def init_(self):
+                pass
+
+        embedder = CustomEmbedder(num_tokens, dim)
+        token_emb_kwargs['y'] = torch.randint(0, num_tokens, (2, 1024))
+
+    model = TransformerWrapper(
+        num_tokens = num_tokens,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = dim,
+            depth = 6,
+            heads = 8,
+        ),
+        token_emb = embedder,
+    )
+
+    x = torch.randint(0, 20000, (2, 1024))
+
+    output = model(x, token_emb_kwargs=token_emb_kwargs)
+    assert output.shape == (2, 1024, 20000)
+
+
+@pytest.mark.parametrize("to_logits", ('linear', 'none', 'pointer'))
+def test_to_logits(to_logits):
+    num_tokens = 20000
+    dim = 128
+
+    to_logits_kwargs = {}
+
+    if to_logits == 'linear':
+        logit_mapper = LinearNoBias(dim, num_tokens)
+    elif to_logits == 'none':
+        logit_mapper = None
+    else:
+        class PointerNetworkLogits(Module):
+            def __init__(self, dim):
+                super().__init__()
+                self.proj_to_pointers = nn.Linear(dim, dim)
+
+            def forward(self, model_embeddings, input_embeddings):
+                pointers = self.proj_to_pointers(model_embeddings)
+                logits = torch.matmul(pointers, input_embeddings.permute(0, 2, 1))
+                return logits
+
+        logit_mapper = PointerNetworkLogits(dim)
+        to_logits_kwargs['input_embeddings'] = torch.randn(2, 20000, dim)
+
+    model = TransformerWrapper(
+        num_tokens = num_tokens,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = dim,
+            depth = 6,
+            heads = 8,
+        ),
+        to_logits = logit_mapper,
+    )
+
+    x = torch.randint(0, num_tokens, (2, 1024))
+
+    output = model(x, to_logits_kwargs=to_logits_kwargs)
+
+    assert output.shape == (2, 1024, 20000)
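
Usage note: `test_custom_alibi_across_heads` above exercises the `rel_pos_kwargs` pass-through and the new `slopes` argument on `AlibiPositionalBias` introduced in x_transformers/x_transformers.py below. A minimal sketch of supplying custom per-head ALiBi slopes, assuming x-transformers >= 1.42.9; the slope values here are illustrative only:

    import torch
    from x_transformers.x_transformers import Decoder

    # rel_pos_kwargs is forwarded to the relative position module
    # (here AlibiPositionalBias), so `slopes` overrides the default geometric slopes
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 2,
        alibi_pos_bias = True,
        rel_pos_kwargs = dict(slopes = [1, 1]),  # one slope per ALiBi head
    )

    x = torch.randn(2, 4, 512)   # (batch, seq, dim) embeddings fed straight to the attention layers
    embed = attn_layers(x)       # (2, 4, 512)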
{x_transformers-1.42.7 → x_transformers-1.42.9}/x_transformers/x_transformers.py

@@ -238,6 +238,13 @@ class TokenEmbedding(Module):
         token_emb = self.emb(x.long())
         return l2norm(token_emb) if self.l2norm_embed else token_emb
 
+    def init_(self):
+        if self.l2norm_embed:
+            nn.init.normal_(self.emb.weight, std=1e-5)
+            return
+        nn.init.kaiming_normal_(self.emb.weight)
+
+
 # positional embeddings
 
 class AbsolutePositionalEmbedding(Module):
@@ -445,13 +452,20 @@ class DynamicPositionBias(Module):
         return bias
 
 class AlibiPositionalBias(Module):
-    def __init__(self, heads, total_heads = None, **kwargs):
+    def __init__(
+        self,
+        heads,
+        total_heads = None,
+        slopes: list[int] | None = None,
+        **kwargs
+    ):
         super().__init__()
         self.heads = heads
         self.total_heads = default(total_heads, heads)
 
-        slopes = Tensor(self._get_slopes(heads))
+        slopes = Tensor(default(slopes, self._get_slopes(heads)))
         slopes = rearrange(slopes, 'h -> h 1 1')
+
         self.register_buffer('slopes', slopes, persistent = False)
         self.register_buffer('bias', None, persistent = False)
 
@@ -480,7 +494,10 @@ class AlibiPositionalBias(Module):
         h, device = self.total_heads, self.device
 
         pos_j = default(pos_j, pos_i)
-        bias = -einx.subtract('... j, ... i -> ... 1 i j', pos_j, pos_i).abs()
+        bias = -einx.subtract('... j, ... i -> ... i j', pos_j, pos_i).abs()
+
+        if bias.ndim == 3:
+            bias = rearrange(bias, 'b i j -> b 1 i j')
 
         bias = bias * self.slopes
         num_heads_unalibied = h - bias.shape[-3]
@@ -1524,8 +1541,9 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
-        reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
-        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
+        reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
+        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
+        rel_pos_kwargs: dict = dict(),
         **kwargs
     ):
         super().__init__()
@@ -1566,14 +1584,14 @@ class AttentionLayers(Module):
 
         if rel_pos_bias:
             assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
-            self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
+            self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance, **rel_pos_kwargs)
         elif dynamic_pos_bias:
             assert not flash_attn, 'flash attention not compatible with dynamic positional bias'
-            self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm)
+            self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm, **rel_pos_kwargs)
         elif alibi_pos_bias:
             alibi_num_heads = default(alibi_num_heads, heads)
             assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
-            self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads)
+            self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads, **rel_pos_kwargs)
 
         assert at_most_one_of(sandwich_norm, resi_dual), 'either sandwich norm or resiDual is selected, but not both'
         assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
@@ -2261,7 +2279,8 @@ class TransformerWrapper(Module):
         token_emb: TokenEmbedding | None = None,
         mixture_of_softmax = False,
         mixture_of_softmax_k = 4,
-        sigsoftmax_logits = False
+        sigsoftmax_logits = False,
+        to_logits: Module | None = None,
     ):
         super().__init__()
 
@@ -2363,11 +2382,12 @@ class TransformerWrapper(Module):
         if return_only_embed:
             self.to_logits = None
         elif tie_embedding:
+            assert isinstance(token_emb, TokenEmbedding), 'can only tie embedding if using `TokenEmbedding`'
             self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
         elif num_output_heads > 1:
             self.to_logits = ModuleList([LinearNoBias(dim, logits_dim) for _ in range(num_output_heads)])
         else:
-            self.to_logits = LinearNoBias(dim, logits_dim)
+            self.to_logits = LinearNoBias(dim, logits_dim) if not exists(to_logits) else to_logits
 
         # memory tokens (like [cls]) from Memory Transformers paper
 
@@ -2388,13 +2408,12 @@ class TransformerWrapper(Module):
         self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb
 
     def init_(self):
+        if hasattr(self.token_emb, 'init_'):
+            self.token_emb.init_()
+
         if self.l2norm_embed:
-            nn.init.normal_(self.token_emb.emb.weight, std = 1e-5)
             if not isinstance(self.pos_emb, always):
                 nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
-            return
-
-        nn.init.kaiming_normal_(self.token_emb.emb.weight)
 
     def forward(
         self,
@@ -2417,7 +2436,9 @@ class TransformerWrapper(Module):
         attn_z_loss_weight = 1e-4,
         seq_start_pos = None,
         cache: LayerIntermediates | None = None,
-        **kwargs
+        token_emb_kwargs = dict(),
+        to_logits_kwargs = dict(),
+        **kwargs,
     ):
         b, n, device, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask
 
@@ -2428,7 +2449,7 @@ class TransformerWrapper(Module):
 
         external_pos_emb = exists(pos) and pos.dtype != torch.long
         pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
-        x = self.token_emb(x) + pos_emb
+        x = self.token_emb(x, **token_emb_kwargs) + pos_emb
 
         # add additional embeddings
 
@@ -2583,9 +2604,9 @@ class TransformerWrapper(Module):
 
         if not return_embeddings:
             if self.has_multiple_heads:
-                logits = tuple(fn(x) for fn in self.to_logits)
+                logits = tuple(fn(x, **to_logits_kwargs) for fn in self.to_logits)
             else:
-                logits = self.to_logits(x)
+                logits = self.to_logits(x, **to_logits_kwargs)
 
         # maybe sig softmax
 
{x_transformers-1.42.7 → x_transformers-1.42.9/x_transformers.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.7
+Version: 1.42.9
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
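
Usage note: the TransformerWrapper changes above make the input embedder and the output projection pluggable (`token_emb = ...`, `to_logits = ...`), with extra tensors routed to them at call time via `token_emb_kwargs` / `to_logits_kwargs`. A minimal sketch mirroring the new tests, assuming x-transformers >= 1.42.9; `SumEmbedder` and the sizes are illustrative:

    import torch
    from torch import nn
    from x_transformers.x_transformers import TransformerWrapper, Decoder

    num_tokens, dim = 256, 64

    class SumEmbedder(nn.Module):
        # toy embedder taking an extra tensor `y`, modeled on the CustomEmbedder test above
        def __init__(self, num_tokens, dim):
            super().__init__()
            self.embed_x = nn.Embedding(num_tokens, dim)
            self.embed_y = nn.Embedding(num_tokens, dim)

        def forward(self, x, y):
            return self.embed_x(x) + self.embed_y(y)

    model = TransformerWrapper(
        num_tokens = num_tokens,
        max_seq_len = 128,
        attn_layers = Decoder(dim = dim, depth = 2, heads = 4),
        token_emb = SumEmbedder(num_tokens, dim),              # custom input embedding
        to_logits = nn.Linear(dim, num_tokens, bias = False),  # custom output head
    )

    x = torch.randint(0, num_tokens, (2, 128))
    y = torch.randint(0, num_tokens, (2, 128))

    # extra inputs are routed to token_emb / to_logits at call time
    logits = model(x, token_emb_kwargs = dict(y = y))
    assert logits.shape == (2, 128, num_tokens)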