x-transformers 1.42.20.tar.gz → 1.42.21.tar.gz

Files changed (22)
  1. {x_transformers-1.42.20/x_transformers.egg-info → x_transformers-1.42.21}/PKG-INFO +1 -1
  2. {x_transformers-1.42.20 → x_transformers-1.42.21}/setup.py +1 -1
  3. {x_transformers-1.42.20 → x_transformers-1.42.21}/tests/test_x_transformers.py +22 -1
  4. {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/x_transformers.py +6 -3
  5. {x_transformers-1.42.20 → x_transformers-1.42.21/x_transformers.egg-info}/PKG-INFO +1 -1
  6. {x_transformers-1.42.20 → x_transformers-1.42.21}/LICENSE +0 -0
  7. {x_transformers-1.42.20 → x_transformers-1.42.21}/README.md +0 -0
  8. {x_transformers-1.42.20 → x_transformers-1.42.21}/setup.cfg +0 -0
  9. {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/neo_mlp.py +0 -0
  16. {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/nonautoregressive_wrapper.py +0 -0
  17. {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  18. {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers/xval.py +0 -0
  19. {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers.egg-info/SOURCES.txt +0 -0
  20. {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers.egg-info/dependency_links.txt +0 -0
  21. {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers.egg-info/requires.txt +0 -0
  22. {x_transformers-1.42.20 → x_transformers-1.42.21}/x_transformers.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.20
+Version: 1.42.21
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.42.20',
+  version = '1.42.21',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
@@ -409,9 +409,30 @@ def test_custom_alibi(flash: bool):
 
     logits = model(x, pos = pos)
 
+def test_custom_rotary_pos_emb():
+    from einops import repeat
+
+    model = TransformerWrapper(
+        num_tokens = 20_000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 2,
+            heads = 8,
+            rotary_pos_emb = True
+        )
+    )
+
+    x = torch.randint(0, 20000, (4, 4))
+
+    pos = repeat(torch.arange(0, 4), "n -> b n", b=4)
+
+    logits1 = model(x, pos = pos)
+    logits2 = model(x)
+    assert torch.allclose(logits1, logits2)
+
 @pytest.mark.parametrize('flash', (True, False))
 def test_custom_alibi_across_heads(flash: bool):
-
     model = Decoder(
         dim = 512,
         depth = 2,
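The new test above passes an explicit `pos` tensor to a decoder that uses rotary embeddings and checks that default positions reproduce the usual output. As a rough usage sketch (not part of this release's test suite, and the per-sequence offsets below are purely hypothetical), the same `pos` keyword should also accept different positions per batch element:

import torch
from einops import repeat
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20_000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        rotary_pos_emb = True
    )
)

x = torch.randint(0, 20_000, (2, 8))

# hypothetical per-sequence offsets: treat the second sequence as starting at position 100
offsets = torch.tensor([0, 100])
pos = repeat(torch.arange(8), 'n -> b n', b = 2) + offsets[:, None]

logits = model(x, pos = pos)  # (2, 8, 20_000)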
@@ -655,7 +655,10 @@ class RotaryEmbedding(Module):
     def forward(self, t):
         max_pos = t.max() + 1
 
-        freqs = torch.einsum('i , j -> i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
+        if t.ndim == 1:
+            t = rearrange(t, 'n -> 1 n')
+
+        freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
         freqs = torch.stack((freqs, freqs), dim = -1)
         freqs = rearrange(freqs, '... d r -> ... (d r)')
 
@@ -679,8 +682,8 @@ def rotate_half(x):
 def apply_rotary_pos_emb(t, freqs, scale = 1):
     rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
 
-    freqs = freqs[-seq_len:, :]
-    scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale
+    freqs = freqs[:, -seq_len:, :]
+    scale = scale[:, -seq_len:, :] if isinstance(scale, torch.Tensor) else scale
 
     if t.ndim == 4 and freqs.ndim == 3:
         freqs = rearrange(freqs, 'b n d -> b 1 n d')
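The change above makes RotaryEmbedding.forward accept batched position tensors: a 1-D `t` of shape (n,) is promoted to (1, n), the einsum gains a leading batch axis, and apply_rotary_pos_emb slices the frequencies along dim 1 instead of dim 0. A standalone shape sketch of what the new forward computes (illustrative only; assumes the conventional RoPE base of 10000 and omits the interpolation factor):

import torch
from einops import rearrange

dim_head = 64
inv_freq = 1. / (10000 ** (torch.arange(0, dim_head, 2).float() / dim_head))  # (dim_head // 2,)

t = torch.arange(8).float()       # positions; may also already be (b, n)
if t.ndim == 1:
    t = rearrange(t, 'n -> 1 n')  # promote to (1, n), as in the new forward

freqs = torch.einsum('b i , j -> b i j', t, inv_freq)  # (b, n, dim_head // 2)
freqs = torch.stack((freqs, freqs), dim = -1)
freqs = rearrange(freqs, '... d r -> ... (d r)')       # (b, n, dim_head)

print(freqs.shape)  # torch.Size([1, 8, 64])

# downstream, apply_rotary_pos_emb therefore slices the sequence axis at dim 1:
# freqs = freqs[:, -seq_len:, :]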
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.20
+Version: 1.42.21
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang