rxnn 0.1.51__py3-none-any.whl → 0.1.52__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/transformers/models.py +6 -2
- rxnn/transformers/moe.py +30 -4
- {rxnn-0.1.51.dist-info → rxnn-0.1.52.dist-info}/METADATA +1 -1
- {rxnn-0.1.51.dist-info → rxnn-0.1.52.dist-info}/RECORD +6 -6
- {rxnn-0.1.51.dist-info → rxnn-0.1.52.dist-info}/LICENSE +0 -0
- {rxnn-0.1.51.dist-info → rxnn-0.1.52.dist-info}/WHEEL +0 -0
rxnn/transformers/models.py
CHANGED
@@ -16,6 +16,7 @@ class ReactiveTransformerBase(nn.Module):
             shared_layers: nn.ModuleList = None,
             absolute_embedding: AbsolutePositionalEmbedding = None,
             use_flash_attention: bool = False,
+            use_relative_embedding: bool = False,
             *args,
             **kwargs,
     ):
@@ -25,6 +26,7 @@ class ReactiveTransformerBase(nn.Module):
         self.stm = stm
         self.pos_embedding = absolute_embedding
         self.use_flash_attention = use_flash_attention
+        self.use_relative_embedding = use_relative_embedding
 
         self.shared_layers = shared_layers
         self.layers = own_layers
@@ -59,7 +61,7 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x = super().forward(x)  # apply embeddings
         seq_len = x.size(1)
-        if not self.use_flash_attention:
+        if not self.use_flash_attention and self.use_relative_embedding:
             mask = create_causal_mask(seq_len, device=x.device)
             if attention_mask is not None:
                 mask &= attention_mask.unsqueeze(1).unsqueeze(1).bool()
@@ -111,6 +113,7 @@ class ClassicTransformerBase(nn.Module):
             layers: nn.ModuleList,
             absolute_embedding: AbsolutePositionalEmbedding = None,
             use_flash_attention: bool = False,
+            use_relative_embedding: bool = False,
             *args,
             **kwargs,
     ):
@@ -119,6+122,7 @@ class ClassicTransformerBase(nn.Module):
         self.embedding = embedding
         self.pos_embedding = absolute_embedding
         self.use_flash_attention = use_flash_attention
+        self.use_relative_embedding = use_relative_embedding
 
         self.layers = layers
         self.num_layers = len(layers) if layers else 0
@@ -144,7 +148,7 @@ class ClassicTransformerDecoder(ClassicTransformerBase):
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x = super().forward(x)  # apply embeddings
         seq_len = x.size(1)
-        if not self.use_flash_attention:
+        if not self.use_flash_attention and self.use_relative_embedding:
            mask = create_causal_mask(seq_len, device=x.device)
            if attention_mask is not None:
                mask &= attention_mask.unsqueeze(1).unsqueeze(1).bool()
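In 0.1.52 the explicit causal mask is built only when flash attention is off and the new use_relative_embedding flag is set, presumably leaving causal masking to the attention backend in every other case. The following is a minimal sketch of that branching logic, not rxnn's implementation: build_causal_mask and decoder_mask are hypothetical helpers (build_causal_mask stands in for rxnn's create_causal_mask), and the broadcasting shapes are assumptions.

import torch

def build_causal_mask(seq_len: int, device=None) -> torch.Tensor:
    # Lower-triangular boolean mask shaped (1, 1, S, S) so it broadcasts over (B, H, S, S) scores
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)).view(1, 1, seq_len, seq_len)

def decoder_mask(x: torch.Tensor, attention_mask: torch.Tensor = None,
                 use_flash_attention: bool = False, use_relative_embedding: bool = False):
    seq_len = x.size(1)
    # 0.1.51 built the mask whenever flash attention was off; 0.1.52 additionally
    # requires relative embeddings, otherwise no dense mask is materialized here.
    if not use_flash_attention and use_relative_embedding:
        mask = build_causal_mask(seq_len, device=x.device)
        if attention_mask is not None:
            # Combine with the padding mask: (B, S) -> (B, 1, 1, S)
            mask = mask & attention_mask.unsqueeze(1).unsqueeze(1).bool()
        return mask
    return None

With use_relative_embedding defaulting to False, both decoder classes now skip building the O(seq_len²) mask unless it is explicitly requested.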
rxnn/transformers/moe.py
CHANGED
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+
 from .ff import FeedForward, GatedFeedForward
 
 class MoeRouter(nn.Module):
@@ -14,11 +15,36 @@ class MoeRouter(nn.Module):
         # For expert load balancing
         self.register_buffer('aux_loss', torch.tensor(0.0), persistent=False)
 
+    # def calculate_aux_loss(self, top_k_indices: torch.Tensor, probs: torch.Tensor) -> torch.Tensor:
+    #     expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
+    #     expert_usage = expert_mask.sum(dim=0).mean(dim=0)
+    #     mean_probs = probs.mean(dim=0)
+    #     return (expert_usage * mean_probs).sum() * self.num_experts
+
     def calculate_aux_loss(self, top_k_indices: torch.Tensor, probs: torch.Tensor) -> torch.Tensor:
-        expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
-        expert_usage = expert_mask.sum(dim=0).mean(dim=0)
-        mean_probs = probs.mean(dim=0)
-        return (expert_usage * mean_probs).sum() * self.num_experts
+        # Get shapes
+        B, S, K = top_k_indices.shape  # Batch, Sequence length, Top-K
+
+        # 1. Compute expert selection mask (one-hot encoded)
+        expert_mask = F.one_hot(top_k_indices, self.num_experts).float()  # (B, S, K, E)
+
+        # 2. Total number of times each expert is selected
+        expert_usage = expert_mask.sum(dim=(0, 1, 2))  # (E,)
+
+        # 3. Fraction of tokens assigned to each expert
+        total_tokens = B * S * K
+        fraction_expert = expert_usage / total_tokens  # (E,)
+
+        # 4. Sum of probabilities for each expert's selected tokens
+        sum_probs = (probs.unsqueeze(-1) * expert_mask).sum(dim=(0, 1, 2))  # (E,)
+
+        # 5. Average probability per expert (avoid division by zero)
+        avg_probs = sum_probs / expert_usage.clamp(min=1e-6)  # (E,)
+
+        # 6. Compute load balancing loss
+        loss = (fraction_expert * avg_probs).sum() * self.num_experts
+
+        return loss
 
     def forward(self, x: torch.Tensor):
         # Input shape: [batch*seq_len, embed_dim]
{rxnn-0.1.51.dist-info → rxnn-0.1.52.dist-info}/RECORD
CHANGED
@@ -20,12 +20,12 @@ rxnn/transformers/attention.py,sha256=dC0UmC-_kjX8US6Sf0Fi5zw5kJ-P6orH3JDHeBB5gI
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
 rxnn/transformers/layers.py,sha256=OX8CsFY9A7uqH1SLwyexR_5BNlwheYrJHCGXjF8Q7HU,7186
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=
-rxnn/transformers/moe.py,sha256=
+rxnn/transformers/models.py,sha256=QFzBrOR7tDp9d_T0HoIukBMfEbLxsCictV5p3e2ilxg,7552
+rxnn/transformers/moe.py,sha256=88-w4cQhYNcebdq4zBsdkaoFa4VxJi1LFXDKAAkfVLk,5791
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.52.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.52.dist-info/METADATA,sha256=aae9Bt0SpsDgugeHY-7Bi6SN3wWhXneD3Kbz1NMtxJo,16627
+rxnn-0.1.52.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.52.dist-info/RECORD,,
{rxnn-0.1.51.dist-info → rxnn-0.1.52.dist-info}/LICENSE
File without changes
{rxnn-0.1.51.dist-info → rxnn-0.1.52.dist-info}/WHEEL
File without changes