rxnn 0.1.50__py3-none-any.whl → 0.1.52__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -65,6 +65,9 @@ class GroupedMoeAttention(GroupedQueryAttention):
65
65
  **kwargs,
66
66
  )
67
67
 
68
+ def router_loss(self):
69
+ return self.router.aux_loss
70
+
68
71
  def _init_kv(self, embed_dim: int):
69
72
  self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
70
73
  hidden_dim = embed_dim // self.num_heads
@@ -16,6 +16,7 @@ class ReactiveTransformerBase(nn.Module):
16
16
  shared_layers: nn.ModuleList = None,
17
17
  absolute_embedding: AbsolutePositionalEmbedding = None,
18
18
  use_flash_attention: bool = False,
19
+ use_relative_embedding: bool = False,
19
20
  *args,
20
21
  **kwargs,
21
22
  ):
@@ -25,6 +26,7 @@ class ReactiveTransformerBase(nn.Module):
25
26
  self.stm = stm
26
27
  self.pos_embedding = absolute_embedding
27
28
  self.use_flash_attention = use_flash_attention
29
+ self.use_relative_embedding = use_relative_embedding
28
30
 
29
31
  self.shared_layers = shared_layers
30
32
  self.layers = own_layers
@@ -59,7 +61,7 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
59
61
  def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
60
62
  x = super().forward(x) # apply embeddings
61
63
  seq_len = x.size(1)
62
- if not self.use_flash_attention:
64
+ if not self.use_flash_attention and self.use_relative_embedding:
63
65
  mask = create_causal_mask(seq_len, device=x.device)
64
66
  if attention_mask is not None:
65
67
  mask &= attention_mask.unsqueeze(1).unsqueeze(1).bool()
@@ -111,6 +113,7 @@ class ClassicTransformerBase(nn.Module):
111
113
  layers: nn.ModuleList,
112
114
  absolute_embedding: AbsolutePositionalEmbedding = None,
113
115
  use_flash_attention: bool = False,
116
+ use_relative_embedding: bool = False,
114
117
  *args,
115
118
  **kwargs,
116
119
  ):
@@ -119,6 +122,7 @@ class ClassicTransformerBase(nn.Module):
119
122
  self.embedding = embedding
120
123
  self.pos_embedding = absolute_embedding
121
124
  self.use_flash_attention = use_flash_attention
125
+ self.use_relative_embedding = use_relative_embedding
122
126
 
123
127
  self.layers = layers
124
128
  self.num_layers = len(layers) if layers else 0
@@ -144,7 +148,7 @@ class ClassicTransformerDecoder(ClassicTransformerBase):
144
148
  def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
145
149
  x = super().forward(x) # apply embeddings
146
150
  seq_len = x.size(1)
147
- if not self.use_flash_attention:
151
+ if not self.use_flash_attention and self.use_relative_embedding:
148
152
  mask = create_causal_mask(seq_len, device=x.device)
149
153
  if attention_mask is not None:
150
154
  mask &= attention_mask.unsqueeze(1).unsqueeze(1).bool()
rxnn/transformers/moe.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
3
  import torch.nn.functional as F
4
+
4
5
  from .ff import FeedForward, GatedFeedForward
5
6
 
6
7
  class MoeRouter(nn.Module):
@@ -14,11 +15,36 @@ class MoeRouter(nn.Module):
14
15
  # For expert load balancing
15
16
  self.register_buffer('aux_loss', torch.tensor(0.0), persistent=False)
16
17
 
18
+ # def calculate_aux_loss(self, top_k_indices: torch.Tensor, probs: torch.Tensor) -> torch.Tensor:
19
+ # expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
20
+ # expert_usage = expert_mask.sum(dim=0).mean(dim=0)
21
+ # mean_probs = probs.mean(dim=0)
22
+ # return (expert_usage * mean_probs).sum() * self.num_experts
23
+
17
24
  def calculate_aux_loss(self, top_k_indices: torch.Tensor, probs: torch.Tensor) -> torch.Tensor:
18
- expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
19
- expert_usage = expert_mask.sum(dim=0).mean(dim=0)
20
- mean_probs = probs.mean(dim=0)
21
- return (expert_usage * mean_probs).sum() * self.num_experts
25
+ # Get shapes
26
+ B, S, K = top_k_indices.shape # Batch, Sequence length, Top-K
27
+
28
+ # 1. Compute expert selection mask (one-hot encoded)
29
+ expert_mask = F.one_hot(top_k_indices, self.num_experts).float() # (B, S, K, E)
30
+
31
+ # 2. Total number of times each expert is selected
32
+ expert_usage = expert_mask.sum(dim=(0, 1, 2)) # (E,)
33
+
34
+ # 3. Fraction of tokens assigned to each expert
35
+ total_tokens = B * S * K
36
+ fraction_expert = expert_usage / total_tokens # (E,)
37
+
38
+ # 4. Sum of probabilities for each expert's selected tokens
39
+ sum_probs = (probs.unsqueeze(-1) * expert_mask).sum(dim=(0, 1, 2)) # (E,)
40
+
41
+ # 5. Average probability per expert (avoid division by zero)
42
+ avg_probs = sum_probs / expert_usage.clamp(min=1e-6) # (E,)
43
+
44
+ # 6. Compute load balancing loss
45
+ loss = (fraction_expert * avg_probs).sum() * self.num_experts
46
+
47
+ return loss
22
48
 
23
49
  def forward(self, x: torch.Tensor):
24
50
  # Input shape: [batch*seq_len, embed_dim]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.1.50
3
+ Version: 0.1.52
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -1,6 +1,6 @@
1
1
  rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- rxnn/experimental/attention.py,sha256=22Qb4jYN6QaqibTU8bwD8x2FaOKCxvWglM2eK9EuOlo,29468
3
+ rxnn/experimental/attention.py,sha256=ZYdRxz4ik7knk3VS_9Opzy6ZqVF98FIhSNjsmIUhGfk,29532
4
4
  rxnn/experimental/models.py,sha256=-BQn7gWlSHLpkAQdthPW5L9ZNzIBqSJS9tkm2N88jgw,4711
5
5
  rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
6
6
  rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -20,12 +20,12 @@ rxnn/transformers/attention.py,sha256=dC0UmC-_kjX8US6Sf0Fi5zw5kJ-P6orH3JDHeBB5gI
20
20
  rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
21
21
  rxnn/transformers/layers.py,sha256=OX8CsFY9A7uqH1SLwyexR_5BNlwheYrJHCGXjF8Q7HU,7186
22
22
  rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
23
- rxnn/transformers/models.py,sha256=_w5C7xvjT4-BFeMfzi57BQ51_fgaYZ4UK0SqUDE5Ooo,7266
24
- rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
23
+ rxnn/transformers/models.py,sha256=QFzBrOR7tDp9d_T0HoIukBMfEbLxsCictV5p3e2ilxg,7552
24
+ rxnn/transformers/moe.py,sha256=88-w4cQhYNcebdq4zBsdkaoFa4VxJi1LFXDKAAkfVLk,5791
25
25
  rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
26
26
  rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
27
27
  rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
28
- rxnn-0.1.50.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
29
- rxnn-0.1.50.dist-info/METADATA,sha256=bIeDbrlcclSfD9oHf26i_sYepOTvTkpcwQMWpOm2jWc,16627
30
- rxnn-0.1.50.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
31
- rxnn-0.1.50.dist-info/RECORD,,
28
+ rxnn-0.1.52.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
29
+ rxnn-0.1.52.dist-info/METADATA,sha256=aae9Bt0SpsDgugeHY-7Bi6SN3wWhXneD3Kbz1NMtxJo,16627
30
+ rxnn-0.1.52.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
31
+ rxnn-0.1.52.dist-info/RECORD,,
File without changes
File without changes