rxnn 0.1.78__py3-none-any.whl → 0.1.80__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,14 +12,17 @@ class RotaryPositionalEmbedding(nn.Module):
12
12
  self.max_seq_len = max_seq_len
13
13
  self.base = base
14
14
  inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
15
- self.register_buffer('inv_freq', inv_freq)
16
- self.register_buffer('cache', None, persistent=False)
15
+ self.register_buffer('inv_freq', inv_freq) # must stay for models compatibility
16
+ # Pre-cache freqs for max_len
17
+ t = torch.arange(max_seq_len).type_as(self.inv_freq)
18
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
19
+ self.register_buffer('cache', freqs)
20
+
17
21
 
18
22
  def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
19
- device = q.device
20
23
  seq_len = q.size(-2)
21
24
  # Prepare RoPE Frequencies
22
- freqs = self._prepare_freqs(seq_len, device)
25
+ freqs = self._prepare_freqs(seq_len)
23
26
 
24
27
  # Apply the rotation to the queries
25
28
  q_embed = self._rotate(q, freqs)
@@ -29,27 +32,17 @@ class RotaryPositionalEmbedding(nn.Module):
29
32
  return q_embed, k_embed
30
33
 
31
34
  def forward_one(self, q: torch.Tensor) -> torch.Tensor:
32
- device = q.device
33
35
  seq_len = q.size(-2)
34
36
  # Prepare RoPE Frequencies
35
- freqs = self._prepare_freqs(seq_len, device)
37
+ freqs = self._prepare_freqs(seq_len)
36
38
 
37
39
  # Apply the rotation to the queries
38
40
  q_embed = self._rotate(q, freqs)
39
41
 
40
42
  return q_embed
41
43
 
42
- def _prepare_freqs(self, seq_len: int, device: torch.device) -> torch.Tensor:
43
- cache_len = self.cache.size(1)
44
- if self.cache is None or cache_len < seq_len:
45
- t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
46
- freqs = torch.einsum('i,j->ij', t, self.inv_freq)
47
- self.cache = freqs
48
- return freqs[None, None, :, :]
49
- elif cache_len == seq_len:
50
- return self.cache[None, None, :, :]
51
- else:
52
- return self.cache[:seq_len][None, None, :, :]
44
+ def _prepare_freqs(self, seq_len: int) -> torch.Tensor:
45
+ return self.cache[:seq_len][None, None, :, :]
53
46
 
54
47
  def _rotate(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
55
48
  x1 = x[..., 0::2]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.1.78
3
+ Version: 0.1.80
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -22,10 +22,10 @@ rxnn/transformers/layers.py,sha256=OX8CsFY9A7uqH1SLwyexR_5BNlwheYrJHCGXjF8Q7HU,7
22
22
  rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
23
23
  rxnn/transformers/models.py,sha256=xbnn3FTNZFhaqq9A0XEM12ie_WL_58pPeq0qFXIgve0,7656
24
24
  rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
25
- rxnn/transformers/positional.py,sha256=DE1TP3D6ikBPg3Ym0sP9F666LHuE70H0w-JEH5DfKPw,4415
25
+ rxnn/transformers/positional.py,sha256=ge-kaS6WnWnPGnWVp25ZK5bVkmhBUNCaELaN2rN_fSY,4097
26
26
  rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
27
27
  rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
28
- rxnn-0.1.78.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
29
- rxnn-0.1.78.dist-info/METADATA,sha256=559E3b22oEiu6vXNnsi7xLCw0GeuYQmcdmOgHkcdlL0,16589
30
- rxnn-0.1.78.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
31
- rxnn-0.1.78.dist-info/RECORD,,
28
+ rxnn-0.1.80.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
29
+ rxnn-0.1.80.dist-info/METADATA,sha256=Voy_a7EI9nC1kEqzxHcLCYpZRJntWnoKaFpF7XyiKCE,16589
30
+ rxnn-0.1.80.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
31
+ rxnn-0.1.80.dist-info/RECORD,,
File without changes
File without changes