rxnn 0.1.78__tar.gz → 0.1.80__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {rxnn-0.1.78 → rxnn-0.1.80}/PKG-INFO +1 -1
  2. {rxnn-0.1.78 → rxnn-0.1.80}/pyproject.toml +1 -1
  3. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/transformers/positional.py +10 -17
  4. {rxnn-0.1.78 → rxnn-0.1.80}/LICENSE +0 -0
  5. {rxnn-0.1.78 → rxnn-0.1.80}/README.md +0 -0
  6. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/__init__.py +0 -0
  7. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/experimental/__init__.py +0 -0
  8. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/experimental/attention.py +0 -0
  9. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/experimental/models.py +0 -0
  10. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/experimental/moe.py +0 -0
  11. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/memory/__init__.py +0 -0
  12. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/memory/norm.py +0 -0
  13. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/memory/stm.py +0 -0
  14. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/rxt/__init__.py +0 -0
  15. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/rxt/models.py +0 -0
  16. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/training/__init__.py +0 -0
  17. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/training/base.py +0 -0
  18. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/training/bml.py +0 -0
  19. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/training/callbacks.py +0 -0
  20. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/training/dataset.py +0 -0
  21. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/training/scheduler.py +0 -0
  22. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/training/tokenizer.py +0 -0
  23. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/transformers/__init__.py +0 -0
  24. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/transformers/attention.py +0 -0
  25. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/transformers/ff.py +0 -0
  26. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/transformers/layers.py +0 -0
  27. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/transformers/mask.py +0 -0
  28. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/transformers/models.py +0 -0
  29. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/transformers/moe.py +0 -0
  30. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/transformers/sampler.py +0 -0
  31. {rxnn-0.1.78 → rxnn-0.1.80}/src/rxnn/utils.py +0 -0
--- rxnn-0.1.78/PKG-INFO
+++ rxnn-0.1.80/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.1.78
+Version: 0.1.80
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
--- rxnn-0.1.78/pyproject.toml
+++ rxnn-0.1.80/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.1.78"
+version = "0.1.80"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
--- rxnn-0.1.78/src/rxnn/transformers/positional.py
+++ rxnn-0.1.80/src/rxnn/transformers/positional.py
@@ -12,14 +12,17 @@ class RotaryPositionalEmbedding(nn.Module):
         self.max_seq_len = max_seq_len
         self.base = base
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
-        self.register_buffer('inv_freq', inv_freq)
-        self.register_buffer('cache', None, persistent=False)
+        self.register_buffer('inv_freq', inv_freq)  # must stay for models compatibility
+        # Pre-cache freqs for max_len
+        t = torch.arange(max_seq_len).type_as(self.inv_freq)
+        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+        self.register_buffer('cache', freqs)
+
 
     def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        device = q.device
         seq_len = q.size(-2)
         # Prepare RoPE Frequencies
-        freqs = self._prepare_freqs(seq_len, device)
+        freqs = self._prepare_freqs(seq_len)
 
         # Apply the rotation to the queries
         q_embed = self._rotate(q, freqs)
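The constructor change above moves the RoPE frequency computation out of the forward path: the full (max_seq_len, dim/2) frequency table is built once with an outer product and registered as a buffer, so it follows the module through .to(device)/.to(dtype) and the forward methods no longer need a device argument. A minimal standalone sketch of that pre-caching step follows; dim=64, max_seq_len=1024, and base=10000 are assumed example values, not taken from the package:

    import torch

    # Standalone sketch of the pre-caching step; sizes are assumed examples.
    dim, max_seq_len, base = 64, 1024, 10000
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # (dim/2,)
    t = torch.arange(max_seq_len).type_as(inv_freq)                     # (max_seq_len,)
    freqs = torch.einsum('i,j->ij', t, inv_freq)  # outer product -> (max_seq_len, dim/2)
    print(freqs.shape)  # torch.Size([1024, 32])

The second hunk in the same file simplifies the lookup side accordingly: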
@@ -29,27 +32,17 @@ class RotaryPositionalEmbedding(nn.Module):
         return q_embed, k_embed
 
     def forward_one(self, q: torch.Tensor) -> torch.Tensor:
-        device = q.device
         seq_len = q.size(-2)
         # Prepare RoPE Frequencies
-        freqs = self._prepare_freqs(seq_len, device)
+        freqs = self._prepare_freqs(seq_len)
 
         # Apply the rotation to the queries
         q_embed = self._rotate(q, freqs)
 
         return q_embed
 
-    def _prepare_freqs(self, seq_len: int, device: torch.device) -> torch.Tensor:
-        cache_len = self.cache.size(1)
-        if self.cache is None or cache_len < seq_len:
-            t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
-            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
-            self.cache = freqs
-            return freqs[None, None, :, :]
-        elif cache_len == seq_len:
-            return self.cache[None, None, :, :]
-        else:
-            return self.cache[:seq_len][None, None, :, :]
+    def _prepare_freqs(self, seq_len: int) -> torch.Tensor:
+        return self.cache[:seq_len][None, None, :, :]
 
     def _rotate(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
         x1 = x[..., 0::2]
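With the cache always populated at construction, _prepare_freqs reduces to a slice plus two broadcast dimensions. The removed lazy version, as written, would also have failed on first use: it read self.cache.size(1) before the self.cache is None check (the buffer started out as None), and size(1) indexes the frequency axis rather than the sequence axis. A self-contained sketch of the new lookup (sizes again assumed examples, not the library API):

    import torch

    # Sketch of the simplified _prepare_freqs lookup; sizes are assumed examples.
    dim, max_seq_len, base = 64, 1024, 10000
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    cache = torch.einsum('i,j->ij', torch.arange(max_seq_len).float(), inv_freq)

    seq_len = 128  # current sequence length; must stay <= max_seq_len
    rope_freqs = cache[:seq_len][None, None, :, :]  # -> (1, 1, 128, 32)
    # The singleton dims broadcast over batch and heads when the frequencies
    # are used to rotate q/k of shape (batch, heads, seq_len, head_dim).

The trade-off is that sequences longer than max_seq_len are no longer handled by lazy regrowth: cache[:seq_len] simply stops at max_seq_len rows, so longer inputs would surface as a shape mismatch downstream.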