x-transformers 1.30.0__tar.gz → 1.30.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (19)
  1. {x_transformers-1.30.0/x_transformers.egg-info → x_transformers-1.30.2}/PKG-INFO +1 -1
  2. {x_transformers-1.30.0 → x_transformers-1.30.2}/README.md +1 -1
  3. {x_transformers-1.30.0 → x_transformers-1.30.2}/setup.py +1 -1
  4. {x_transformers-1.30.0 → x_transformers-1.30.2}/x_transformers/x_transformers.py +14 -8
  5. {x_transformers-1.30.0 → x_transformers-1.30.2/x_transformers.egg-info}/PKG-INFO +1 -1
  6. {x_transformers-1.30.0 → x_transformers-1.30.2}/LICENSE +0 -0
  7. {x_transformers-1.30.0 → x_transformers-1.30.2}/setup.cfg +0 -0
  8. {x_transformers-1.30.0 → x_transformers-1.30.2}/x_transformers/__init__.py +0 -0
  9. {x_transformers-1.30.0 → x_transformers-1.30.2}/x_transformers/attend.py +0 -0
  10. {x_transformers-1.30.0 → x_transformers-1.30.2}/x_transformers/autoregressive_wrapper.py +0 -0
  11. {x_transformers-1.30.0 → x_transformers-1.30.2}/x_transformers/continuous.py +0 -0
  12. {x_transformers-1.30.0 → x_transformers-1.30.2}/x_transformers/dpo.py +0 -0
  13. {x_transformers-1.30.0 → x_transformers-1.30.2}/x_transformers/nonautoregressive_wrapper.py +0 -0
  14. {x_transformers-1.30.0 → x_transformers-1.30.2}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  15. {x_transformers-1.30.0 → x_transformers-1.30.2}/x_transformers/xval.py +0 -0
  16. {x_transformers-1.30.0 → x_transformers-1.30.2}/x_transformers.egg-info/SOURCES.txt +0 -0
  17. {x_transformers-1.30.0 → x_transformers-1.30.2}/x_transformers.egg-info/dependency_links.txt +0 -0
  18. {x_transformers-1.30.0 → x_transformers-1.30.2}/x_transformers.egg-info/requires.txt +0 -0
  19. {x_transformers-1.30.0 → x_transformers-1.30.2}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.30.0/x_transformers.egg-info → x_transformers-1.30.2}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.30.0
+ Version: 1.30.2
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
{x_transformers-1.30.0 → x_transformers-1.30.2}/README.md
@@ -693,7 +693,7 @@ model = TransformerWrapper(
  )
  ```

- If you wish to do something more sophisticated, say 3 layers, with each layer recurrent 4 times before onto the next, that is possible as well. Be aware the `layers_execute_order` is 0-indexed
+ If you wish to do something more sophisticated, say 3 layers, with each layer recurrent 4 times before onto the next (similar to <a href="https://arxiv.org/abs/2405.15071">this paper</a>), that is possible as well. Be aware the `layers_execute_order` is 0-indexed

  ```python
  import torch
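The README's full example is truncated by the diff context above. Purely as an illustration of the 0-indexed ordering the changed sentence describes, a 3-layers, 4-repeats-each schedule could be written out as a plain tuple. Whether each index refers to a whole layer or to individual attention/feedforward blocks depends on the configuration, so treat this sketch as an assumption rather than the package's documented example.

```python
# Hypothetical 0-indexed execute order: 3 layers, each repeated 4 times
# before moving on to the next. Illustrative only; not taken from the README.
num_layers = 3
repeats = 4

layers_execute_order = tuple(
    layer_index
    for layer_index in range(num_layers)
    for _ in range(repeats)
)

print(layers_execute_order)  # (0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2)
```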
{x_transformers-1.30.0 → x_transformers-1.30.2}/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
  setup(
    name = 'x-transformers',
    packages = find_packages(exclude=['examples']),
-   version = '1.30.0',
+   version = '1.30.2',
    license='MIT',
    description = 'X-Transformers - Pytorch',
    author = 'Phil Wang',
{x_transformers-1.30.0 → x_transformers-1.30.2}/x_transformers/x_transformers.py
@@ -444,28 +444,32 @@ class RotaryEmbedding(Module):

      @autocast(enabled = False)
      def forward(self, t):
-         max_pos = t.max()+1
+         max_pos = t.max() + 1

          freqs = torch.einsum('i , j -> i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
-         freqs = torch.cat((freqs, freqs), dim = -1)
+         freqs = torch.stack((freqs, freqs), dim = -1)
+         freqs = rearrange(freqs, '... d r -> ... (d r)')

          if not exists(self.scale):
              return freqs, 1.

          power = (t - (max_pos // 2)) / self.scale_base
          scale = self.scale ** rearrange(power, 'n -> n 1')
-         scale = torch.cat((scale, scale), dim = -1)
+         scale = torch.stack((scale, scale), dim = -1)
+         scale = rearrange(scale, '... d r -> ... (d r)')

          return freqs, scale

  def rotate_half(x):
-     x = rearrange(x, '... (j d) -> ... j d', j = 2)
-     x1, x2 = x.unbind(dim = -2)
-     return torch.cat((-x2, x1), dim = -1)
+     x = rearrange(x, '... (d r) -> ... d r', r = 2)
+     x1, x2 = x.unbind(dim = -1)
+     x = torch.stack((-x2, x1), dim = -1)
+     return rearrange(x, '... d r -> ... (d r)')

  @autocast(enabled = False)
  def apply_rotary_pos_emb(t, freqs, scale = 1):
-     rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
+     rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
+
      freqs = freqs[-seq_len:, :]
      scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale

@@ -475,7 +479,9 @@ def apply_rotary_pos_emb(t, freqs, scale = 1):
      # partial rotary embeddings, Wang et al. GPT-J
      t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
      t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
-     return torch.cat((t, t_unrotated), dim = -1)
+     out = torch.cat((t, t_unrotated), dim = -1)
+
+     return out.type(orig_dtype)

  # norms

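The code change in this release is the switch in `rotate_half` (and the matching construction of `freqs` and `scale` in `RotaryEmbedding.forward`) from splitting the feature dimension into two halves to rotating adjacent pairs of features, plus casting the output of `apply_rotary_pos_emb` back to the input dtype. A minimal standalone sketch of the two rotation layouts, using made-up function names rather than anything exported by the package:

```python
import torch
from einops import rearrange

def rotate_half_block(x):
    # 1.30.0 layout: split the feature dim into halves (x1 | x2) and map to (-x2 | x1)
    x = rearrange(x, '... (j d) -> ... j d', j = 2)
    x1, x2 = x.unbind(dim = -2)
    return torch.cat((-x2, x1), dim = -1)

def rotate_half_interleaved(x):
    # 1.30.2 layout: rotate adjacent (even, odd) feature pairs in place
    x = rearrange(x, '... (d r) -> ... d r', r = 2)
    x1, x2 = x.unbind(dim = -1)
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d r -> ... (d r)')

x = torch.arange(8.)
print(rotate_half_block(x))        # tensor([-4., -5., -6., -7.,  0.,  1.,  2.,  3.])
print(rotate_half_interleaved(x))  # tensor([-1.,  0., -3.,  2., -5.,  4., -7.,  6.])
```

The `torch.stack` plus `rearrange('... d r -> ... (d r)')` change to `freqs` and `scale` duplicates each frequency into adjacent slots so it stays paired with the dimensions the new `rotate_half` actually rotates, and the added `out.type(orig_dtype)` restores the input dtype, which would otherwise be promoted to float32 when a half-precision `t` is multiplied by the float32 rotary frequencies.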
{x_transformers-1.30.0 → x_transformers-1.30.2/x_transformers.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.30.0
+ Version: 1.30.2
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang