x-transformers 1.42.6__tar.gz → 1.42.8__tar.gz

Files changed (22)
  1. {x_transformers-1.42.6/x_transformers.egg-info → x_transformers-1.42.8}/PKG-INFO +1 -1
  2. {x_transformers-1.42.6 → x_transformers-1.42.8}/setup.py +1 -1
  3. {x_transformers-1.42.6 → x_transformers-1.42.8}/tests/test_x_transformers.py +112 -1
  4. {x_transformers-1.42.6 → x_transformers-1.42.8}/x_transformers/x_transformers.py +34 -13
  5. {x_transformers-1.42.6 → x_transformers-1.42.8/x_transformers.egg-info}/PKG-INFO +1 -1
  6. {x_transformers-1.42.6 → x_transformers-1.42.8}/LICENSE +0 -0
  7. {x_transformers-1.42.6 → x_transformers-1.42.8}/README.md +0 -0
  8. {x_transformers-1.42.6 → x_transformers-1.42.8}/setup.cfg +0 -0
  9. {x_transformers-1.42.6 → x_transformers-1.42.8}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.42.6 → x_transformers-1.42.8}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.42.6 → x_transformers-1.42.8}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.42.6 → x_transformers-1.42.8}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.42.6 → x_transformers-1.42.8}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.42.6 → x_transformers-1.42.8}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.42.6 → x_transformers-1.42.8}/x_transformers/neo_mlp.py +0 -0
  16. {x_transformers-1.42.6 → x_transformers-1.42.8}/x_transformers/nonautoregressive_wrapper.py +0 -0
  17. {x_transformers-1.42.6 → x_transformers-1.42.8}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  18. {x_transformers-1.42.6 → x_transformers-1.42.8}/x_transformers/xval.py +0 -0
  19. {x_transformers-1.42.6 → x_transformers-1.42.8}/x_transformers.egg-info/SOURCES.txt +0 -0
  20. {x_transformers-1.42.6 → x_transformers-1.42.8}/x_transformers.egg-info/dependency_links.txt +0 -0
  21. {x_transformers-1.42.6 → x_transformers-1.42.8}/x_transformers.egg-info/requires.txt +0 -0
  22. {x_transformers-1.42.6 → x_transformers-1.42.8}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.42.6/x_transformers.egg-info → x_transformers-1.42.8}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.6
+Version: 1.42.8
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
{x_transformers-1.42.6 → x_transformers-1.42.8}/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.42.6',
+  version = '1.42.8',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
{x_transformers-1.42.6 → x_transformers-1.42.8}/tests/test_x_transformers.py
@@ -1,12 +1,15 @@
 import pytest
+
 import torch
+from torch import nn
+from torch.nn import Module
 
 from x_transformers.x_transformers import (
     XTransformer,
     TransformerWrapper,
     Encoder,
     Decoder,
-    AutoregressiveWrapper,
+    LinearNoBias,
 )
 
 from x_transformers.neo_mlp import (
@@ -376,3 +379,111 @@ def test_neo_mlp():
 
     out = mlp(x)
     assert out.shape == (3, 7)
+
+def test_custom_alibi():
+    model = TransformerWrapper(
+        num_tokens = 20_000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 2,
+            heads = 8,
+            alibi_pos_bias = True
+        )
+    )
+
+    x = torch.randint(0, 20000, (2, 4))
+
+    pos = torch.tensor([[0, 1, 2, 4], [1, 3, 5, 7]])
+
+    logits = model(x, pos = pos)
+
+
+@pytest.mark.parametrize('embedder_type', ('embedding', 'none', 'custom'))
+def test_embedder(embedder_type):
+    num_tokens = 20000
+    dim = 128
+    token_emb_kwargs = {}
+
+    if embedder_type == 'embedding':
+        embedder = nn.Embedding(num_tokens, dim)
+    elif embedder_type == 'none':
+        embedder = None
+    else:
+        class CustomEmbedder(Module):
+            """
+            Made up embedder that sums two embeddings. Just to check if we can pass additional input to the embedder's
+            forward pass without breaking the model.
+            """
+            def __init__(self, num_tokens, dim):
+                super().__init__()
+                self.embed_x = nn.Embedding(num_tokens, dim)
+                self.embed_y = nn.Embedding(num_tokens, dim)
+
+            def forward(self, x, y):
+                return self.embed_x(x) + self.embed_y(y)
+
+            def init_(self):
+                pass
+
+        embedder = CustomEmbedder(num_tokens, dim)
+        token_emb_kwargs['y'] = torch.randint(0, num_tokens, (2, 1024))
+
+    model = TransformerWrapper(
+        num_tokens = num_tokens,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = dim,
+            depth = 6,
+            heads = 8,
+        ),
+        token_emb = embedder,
+    )
+
+    x = torch.randint(0, 20000, (2, 1024))
+
+    output = model(x, token_emb_kwargs=token_emb_kwargs)
+    assert output.shape == (2, 1024, 20000)
+
+
+@pytest.mark.parametrize("to_logits", ('linear', 'none', 'pointer'))
+def test_to_logits(to_logits):
+    num_tokens = 20000
+    dim = 128
+
+    to_logits_kwargs = {}
+
+    if to_logits == 'linear':
+        logit_mapper = LinearNoBias(dim, num_tokens)
+    elif to_logits == 'none':
+        logit_mapper = None
+    else:
+        class PointerNetworkLogits(Module):
+            def __init__(self, dim):
+                super().__init__()
+                self.proj_to_pointers = nn.Linear(dim, dim)
+
+            def forward(self, model_embeddings, input_embeddings):
+                pointers = self.proj_to_pointers(model_embeddings)
+                logits = torch.matmul(pointers, input_embeddings.permute(0, 2, 1))
+                return logits
+
+        logit_mapper = PointerNetworkLogits(dim)
+        to_logits_kwargs['input_embeddings'] = torch.randn(2, 20000, dim)
+
+    model = TransformerWrapper(
+        num_tokens = num_tokens,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = dim,
+            depth = 6,
+            heads = 8,
+        ),
+        to_logits = logit_mapper,
+    )
+
+    x = torch.randint(0, num_tokens, (2, 1024))
+
+    output = model(x, to_logits_kwargs=to_logits_kwargs)
+
+    assert output.shape == (2, 1024, 20000)
{x_transformers-1.42.6 → x_transformers-1.42.8}/x_transformers/x_transformers.py
@@ -238,6 +238,13 @@ class TokenEmbedding(Module):
         token_emb = self.emb(x.long())
         return l2norm(token_emb) if self.l2norm_embed else token_emb
 
+    def init_(self):
+        if self.l2norm_embed:
+            nn.init.normal_(self.emb.weight, std=1e-5)
+            return
+        nn.init.kaiming_normal_(self.emb.weight)
+
+
 # positional embeddings
 
 class AbsolutePositionalEmbedding(Module):
@@ -1246,6 +1253,7 @@ class Attention(Module):
         rel_pos = None,
         attn_bias = None,
         rotary_pos_emb = None,
+        pos = None, # for custom alibi positions
         prev_attn = None,
         mem = None,
         mem_mask = None,
@@ -1392,7 +1400,14 @@
 
         if exists(rel_pos):
             assert not exists(attn_bias)
-            attn_bias = rel_pos(i, j)
+
+            if exists(pos):
+                assert isinstance(rel_pos, AlibiPositionalBias), 'only alibi allowed for custom positions at the moment'
+                # allow for custom positions to be passed in
+                attn_bias = rel_pos.forward_custom_pos(pos)
+            else:
+                attn_bias = rel_pos(i, j)
+
             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values
 
         # prepare data dependent alibi from forgetting transformers paper, if needed
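For context on the forward_custom_pos call above: ALiBi biases attention logits by a per-head slope times the distance between query and key positions, so supplying explicit positions only changes how that distance is computed. A minimal sketch of the idea, assuming the standard ALiBi formulation rather than the library's actual AlibiPositionalBias implementation:

# illustrative sketch only - not the library's AlibiPositionalBias.forward_custom_pos,
# just the standard alibi formulation applied to explicit per-sequence positions
import torch

def alibi_bias_from_positions(pos, slopes):
    # pos: (batch, seq_len) integer positions, slopes: (heads,)
    rel_dist = (pos[:, None, :] - pos[:, :, None]).abs()           # (batch, i, j) pairwise distances
    return -slopes.view(1, -1, 1, 1) * rel_dist[:, None, :, :]     # (batch, heads, i, j) additive bias

pos = torch.tensor([[0, 1, 2, 4], [1, 3, 5, 7]])                   # the same custom positions as in the new test
slopes = torch.tensor([0.25, 0.5])                                 # per-head slopes (normally a geometric sequence)
bias = alibi_bias_from_positions(pos, slopes)                      # shape (2, 2, 4, 4)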
@@ -1843,6 +1858,7 @@ class AttentionLayers(Module):
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None,
+        pos = None,
         attn_bias = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
@@ -1906,7 +1922,9 @@ class AttentionLayers(Module):
            maybe_mem = mems[0] # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
            mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
 
-           pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+           if not exists(pos):
+               pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+
            rotary_pos_emb = self.rotary_pos_emb(pos)
 
        # assume cached key / values
@@ -2030,7 +2048,7 @@ class AttentionLayers(Module):
            # forward depending on layer type
 
            if layer_type == 'a':
-               out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+               out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
            elif layer_type == 'c':
                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, return_intermediates = True)
            elif layer_type == 'f':
@@ -2250,7 +2268,8 @@ class TransformerWrapper(Module):
         token_emb: TokenEmbedding | None = None,
         mixture_of_softmax = False,
         mixture_of_softmax_k = 4,
-        sigsoftmax_logits = False
+        sigsoftmax_logits = False,
+        to_logits: Module | None = None,
     ):
         super().__init__()
 
@@ -2352,11 +2371,12 @@ class TransformerWrapper(Module):
         if return_only_embed:
             self.to_logits = None
         elif tie_embedding:
+            assert isinstance(token_emb, TokenEmbedding), 'can only tie embedding if using `TokenEmbedding`'
             self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
         elif num_output_heads > 1:
             self.to_logits = ModuleList([LinearNoBias(dim, logits_dim) for _ in range(num_output_heads)])
         else:
-            self.to_logits = LinearNoBias(dim, logits_dim)
+            self.to_logits = LinearNoBias(dim, logits_dim) if not exists(to_logits) else to_logits
 
         # memory tokens (like [cls]) from Memory Transformers paper
 
@@ -2377,13 +2397,12 @@ class TransformerWrapper(Module):
         self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb
 
     def init_(self):
+        if hasattr(self.token_emb, 'init_'):
+            self.token_emb.init_()
+
         if self.l2norm_embed:
-            nn.init.normal_(self.token_emb.emb.weight, std = 1e-5)
             if not isinstance(self.pos_emb, always):
                 nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
-            return
-
-        nn.init.kaiming_normal_(self.token_emb.emb.weight)
 
     def forward(
         self,
@@ -2406,7 +2425,9 @@ class TransformerWrapper(Module):
         attn_z_loss_weight = 1e-4,
         seq_start_pos = None,
         cache: LayerIntermediates | None = None,
-        **kwargs
+        token_emb_kwargs = dict(),
+        to_logits_kwargs = dict(),
+        **kwargs,
     ):
         b, n, device, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask
 
@@ -2417,7 +2438,7 @@ class TransformerWrapper(Module):
 
         external_pos_emb = exists(pos) and pos.dtype != torch.long
         pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
-        x = self.token_emb(x) + pos_emb
+        x = self.token_emb(x, **token_emb_kwargs) + pos_emb
 
         # add additional embeddings
 
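Putting the token_emb / token_emb_kwargs pieces together: any module with a compatible forward can stand in for the token embedding, extra tensors reach it through token_emb_kwargs, and an optional init_ method is picked up by TransformerWrapper.init_ via the hasattr check above. A minimal sketch, assuming a hypothetical SegmentAwareEmbedding that is not part of the library:

# illustrative sketch only - a custom token embedder fed an extra tensor through token_emb_kwargs
import torch
from torch import nn
from x_transformers import TransformerWrapper, Decoder

class SegmentAwareEmbedding(nn.Module):          # hypothetical example module
    def __init__(self, num_tokens, num_segments, dim):
        super().__init__()
        self.tok = nn.Embedding(num_tokens, dim)
        self.seg = nn.Embedding(num_segments, dim)

    def forward(self, x, seg):                   # extra `seg` arrives via token_emb_kwargs
        return self.tok(x) + self.seg(seg)

    def init_(self):                             # optional - dispatched to by TransformerWrapper.init_ via hasattr
        nn.init.kaiming_normal_(self.tok.weight)

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    attn_layers = Decoder(dim = 64, depth = 1, heads = 4),
    token_emb = SegmentAwareEmbedding(256, 2, 64),
)

tokens   = torch.randint(0, 256, (2, 128))
segments = torch.randint(0, 2, (2, 128))
logits   = model(tokens, token_emb_kwargs = dict(seg = segments))   # (2, 128, 256)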
@@ -2572,9 +2593,9 @@ class TransformerWrapper(Module):
 
         if not return_embeddings:
             if self.has_multiple_heads:
-                logits = tuple(fn(x) for fn in self.to_logits)
+                logits = tuple(fn(x, **to_logits_kwargs) for fn in self.to_logits)
             else:
-                logits = self.to_logits(x)
+                logits = self.to_logits(x, **to_logits_kwargs)
 
         # maybe sig softmax
 
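Likewise on the output side: a custom module passed as to_logits replaces the default LinearNoBias head, and anything in to_logits_kwargs is forwarded to it at call time, as in the hunks above. A minimal sketch, assuming a hypothetical temperature-scaled head that is not part of the library:

# illustrative sketch only - a custom projection-to-logits module wired in through `to_logits`
# and given extra input at forward time via `to_logits_kwargs`
import torch
from torch import nn
from x_transformers import TransformerWrapper, Decoder

class ScaledLogits(nn.Module):                       # hypothetical example head
    def __init__(self, dim, num_tokens):
        super().__init__()
        self.proj = nn.Linear(dim, num_tokens, bias = False)

    def forward(self, embeds, temperature = 1.0):    # `temperature` arrives via to_logits_kwargs
        return self.proj(embeds) / temperature

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    attn_layers = Decoder(dim = 64, depth = 1, heads = 4),
    to_logits = ScaledLogits(64, 256),
)

tokens = torch.randint(0, 256, (2, 128))
logits = model(tokens, to_logits_kwargs = dict(temperature = 0.7))   # (2, 128, 256)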
{x_transformers-1.42.6 → x_transformers-1.42.8/x_transformers.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.6
+Version: 1.42.8
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang