x-transformers 1.42.19.tar.gz → 1.42.21.tar.gz

Files changed (22)
  1. {x_transformers-1.42.19/x_transformers.egg-info → x_transformers-1.42.21}/PKG-INFO +1 -1
  2. {x_transformers-1.42.19 → x_transformers-1.42.21}/setup.py +1 -1
  3. {x_transformers-1.42.19 → x_transformers-1.42.21}/tests/test_x_transformers.py +28 -3
  4. {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/attend.py +1 -1
  5. {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/x_transformers.py +6 -3
  6. {x_transformers-1.42.19 → x_transformers-1.42.21/x_transformers.egg-info}/PKG-INFO +1 -1
  7. {x_transformers-1.42.19 → x_transformers-1.42.21}/LICENSE +0 -0
  8. {x_transformers-1.42.19 → x_transformers-1.42.21}/README.md +0 -0
  9. {x_transformers-1.42.19 → x_transformers-1.42.21}/setup.cfg +0 -0
  10. {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/__init__.py +0 -0
  11. {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/neo_mlp.py +0 -0
  16. {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/nonautoregressive_wrapper.py +0 -0
  17. {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  18. {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/xval.py +0 -0
  19. {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers.egg-info/SOURCES.txt +0 -0
  20. {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers.egg-info/dependency_links.txt +0 -0
  21. {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers.egg-info/requires.txt +0 -0
  22. {x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.42.19/x_transformers.egg-info → x_transformers-1.42.21}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.19
+Version: 1.42.21
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
{x_transformers-1.42.19 → x_transformers-1.42.21}/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.42.19',
+  version = '1.42.21',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
{x_transformers-1.42.19 → x_transformers-1.42.21}/tests/test_x_transformers.py
@@ -388,7 +388,8 @@ def test_neo_mlp():
     out = mlp(x)
     assert out.shape == (3, 7)
 
-def test_custom_alibi():
+@pytest.mark.parametrize('flash', (True, False))
+def test_custom_alibi(flash: bool):
 
     model = TransformerWrapper(
         num_tokens = 20_000,
@@ -397,7 +398,8 @@ def test_custom_alibi():
             dim = 512,
             depth = 2,
             heads = 8,
-            alibi_pos_bias = True
+            alibi_pos_bias = True,
+            attn_flash = flash
         )
     )
 
@@ -407,8 +409,30 @@ def test_custom_alibi():
 
     logits = model(x, pos = pos)
 
-def test_custom_alibi_across_heads():
+def test_custom_rotary_pos_emb():
+    from einops import repeat
 
+    model = TransformerWrapper(
+        num_tokens = 20_000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 2,
+            heads = 8,
+            rotary_pos_emb = True
+        )
+    )
+
+    x = torch.randint(0, 20000, (4, 4))
+
+    pos = repeat(torch.arange(0, 4), "n -> b n", b=4)
+
+    logits1 = model(x, pos = pos)
+    logits2 = model(x)
+    assert torch.allclose(logits1, logits2)
+
+@pytest.mark.parametrize('flash', (True, False))
+def test_custom_alibi_across_heads(flash: bool):
     model = Decoder(
         dim = 512,
         depth = 2,
@@ -417,6 +441,7 @@ def test_custom_alibi_across_heads():
         rel_pos_kwargs = dict(
             slopes = [1, 1]
         ),
+        attn_flash = flash
     )
 
     x = torch.randn(2, 4, 512)
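The `flash` parametrization runs the custom-alibi tests through both the flash and non-flash attention paths, and the new `test_custom_rotary_pos_emb` checks that passing contiguous positions explicitly via `pos` reproduces the default behaviour. As a minimal sketch (not part of the diff) of what the batched position support is aimed at, a per-sample position tensor of shape (batch, seq_len) can also be passed, e.g. to offset sequences individually; the keyword names mirror the tests above:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

# sketch, assuming the same API as the tests above: rotary embeddings
# driven by a per-sample position tensor of shape (batch, seq_len)
model = TransformerWrapper(
    num_tokens = 20_000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        rotary_pos_emb = True
    )
)

x = torch.randint(0, 20_000, (2, 4))

# each row supplies its own positions, e.g. offset to account for left padding
pos = torch.tensor([
    [0, 1, 2, 3],
    [2, 3, 4, 5]
])

logits = model(x, pos = pos)  # (2, 4, 20_000)
```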
{x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/attend.py
@@ -370,7 +370,7 @@ class Attend(Module):
         # convert from bool to float
 
         if exists(attn_bias):
-            attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1)
+            attn_bias = attn_bias.expand(batch, heads, -1, -1)
 
         # if mask given, the mask would already contain the causal mask from above logic
         # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
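For reference, the dropped `rearrange` only added a leading singleton batch dimension; `Tensor.expand` can prepend new leading dimensions on its own, and the simpler call also passes through a bias that already carries a batch dimension (as a per-sample positional bias would). A standalone sketch of the broadcasting behaviour in plain torch, separate from the library code:

```python
import torch

batch, heads, i, j = 2, 8, 4, 4

# head-only bias, shape (h, i, j): expand prepends the batch dimension itself,
# so the explicit rearrange('h i j -> 1 h i j') step is not required
bias_h = torch.randn(heads, i, j)
assert bias_h.expand(batch, heads, -1, -1).shape == (batch, heads, i, j)

# bias that already has a batch dimension, shape (b, h, i, j): expand leaves it as-is
bias_bh = torch.randn(batch, heads, i, j)
assert bias_bh.expand(batch, heads, -1, -1).shape == (batch, heads, i, j)
```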
{x_transformers-1.42.19 → x_transformers-1.42.21}/x_transformers/x_transformers.py
@@ -655,7 +655,10 @@ class RotaryEmbedding(Module):
     def forward(self, t):
         max_pos = t.max() + 1
 
-        freqs = torch.einsum('i , j -> i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
+        if t.ndim == 1:
+            t = rearrange(t, 'n -> 1 n')
+
+        freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
         freqs = torch.stack((freqs, freqs), dim = -1)
         freqs = rearrange(freqs, '... d r -> ... (d r)')
 
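`RotaryEmbedding.forward` now treats the position tensor as batched: a 1-d input of shape (n,) is lifted to (1, n), and the einsum yields frequencies of shape (b, n, dim) rather than (n, dim). A standalone sketch of that computation, assuming the usual theta = 10000 inverse-frequency schedule and a head dimension of 64:

```python
import torch
from einops import rearrange

dim, theta, interpolation_factor = 64, 10_000, 1.
inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))  # (dim / 2,)

t = torch.tensor([[0, 1, 2, 3], [2, 3, 4, 5]]).float()  # (b, n) positions

if t.ndim == 1:
    t = rearrange(t, 'n -> 1 n')  # lift 1-d positions to (1, n)

freqs = torch.einsum('b i , j -> b i j', t, inv_freq) / interpolation_factor
freqs = torch.stack((freqs, freqs), dim = -1)     # (b, n, dim / 2, 2)
freqs = rearrange(freqs, '... d r -> ... (d r)')  # (b, n, dim)

assert freqs.shape == (2, 4, dim)
```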
@@ -679,8 +682,8 @@ def rotate_half(x):
 def apply_rotary_pos_emb(t, freqs, scale = 1):
     rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
 
-    freqs = freqs[-seq_len:, :]
-    scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale
+    freqs = freqs[:, -seq_len:, :]
+    scale = scale[:, -seq_len:, :] if isinstance(scale, torch.Tensor) else scale
 
     if t.ndim == 4 and freqs.ndim == 3:
         freqs = rearrange(freqs, 'b n d -> b 1 n d')
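Because the frequencies now always carry a leading batch dimension, the trim to the current sequence length indexes dimension 1 rather than dimension 0, still keeping the last `seq_len` positions so cached frequencies line up with the most recent queries/keys. A small shape sketch of the new slicing:

```python
import torch

b, cached_len, dim_head = 2, 8, 64
freqs = torch.randn(b, cached_len, dim_head)  # (b, n_cached, d), as produced above

seq_len = 4                                   # t.shape[-2] for the current queries/keys
trimmed = freqs[:, -seq_len:, :]              # keep the last seq_len positions per batch
assert trimmed.shape == (b, seq_len, dim_head)
```

The `rearrange(freqs, 'b n d -> b 1 n d')` in the context lines above then broadcasts the batched frequencies across the head dimension when `t` is 4-d.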
{x_transformers-1.42.19 → x_transformers-1.42.21/x_transformers.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.19
+Version: 1.42.21
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang