rxnn 0.1.23__py3-none-any.whl → 0.1.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/attention.py +11 -0
- rxnn/transformers/attention.py +6 -4
- rxnn/transformers/layers.py +28 -2
- rxnn/transformers/moe.py +0 -1
- {rxnn-0.1.23.dist-info → rxnn-0.1.24.dist-info}/METADATA +1 -1
- {rxnn-0.1.23.dist-info → rxnn-0.1.24.dist-info}/RECORD +8 -8
- {rxnn-0.1.23.dist-info → rxnn-0.1.24.dist-info}/LICENSE +0 -0
- {rxnn-0.1.23.dist-info → rxnn-0.1.24.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -61,6 +61,9 @@ class GroupedMoeAttention(GroupedQueryAttention):
             **kwargs,
         )
 
+    def router_loss(self):
+        return self.router.aux_loss
+
     def _init_kv(self, embed_dim: int):
         self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
 
@@ -194,6 +197,9 @@ class DeepMoeAttention(GroupedMoeAttention):
             **kwargs,
         )
 
+    def router_loss(self):
+        return (self.router.aux_loss + self.query_router.aux_loss) / 2
+
     def _init_q(self, embed_dim: int):
         self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
 
@@ -212,6 +218,11 @@ class DeepMoeAttention(GroupedMoeAttention):
         hidden_dim = embed_dim // (self.num_heads // self.num_query_groups)
         self.out_proj = nn.Linear(hidden_dim, embed_dim)
 
+    def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
+        """Transpose attention output back to (B, T, D) shape"""
+        hidden_dim = d // self.num_heads * self.num_query_groups
+        return attn_output.transpose(1, 2).contiguous().view(b, t, hidden_dim)
+
     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int, skip_query_processing: bool = False):
         # Query processing
         B, T, D = query.shape
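For reference, the new DeepMoeAttention._transpose_output reshapes the grouped attention output back to (B, T, hidden_dim), where hidden_dim is derived from the number of query groups rather than the full head count. A minimal standalone shape check of that reshape (the dimension values below are illustrative assumptions, not taken from the package):

import torch

# illustrative sizes (assumed for this sketch, not from rxnn)
b, t = 2, 8                                      # batch, sequence length
num_heads, num_query_groups = 8, 2
d = 64                                           # full embedding dim
head_dim = d // num_heads                        # 8

# per-group attention output: (B, groups, T, head_dim)
attn_output = torch.randn(b, num_query_groups, t, head_dim)

# same arithmetic as DeepMoeAttention._transpose_output
hidden_dim = d // num_heads * num_query_groups   # 16
out = attn_output.transpose(1, 2).contiguous().view(b, t, hidden_dim)
assert out.shape == (b, t, hidden_dim)

This is consistent with the out_proj = nn.Linear(hidden_dim, embed_dim) initialization above, since embed_dim // (num_heads // num_query_groups) yields the same hidden_dim.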
rxnn/transformers/attention.py
CHANGED
@@ -91,9 +91,13 @@ class MultiHeadAttention(nn.Module):
             attn_logits = attn_logits.masked_fill(mask == 0, float('-inf'))
         return F.softmax(attn_logits, dim=-1)
 
+    def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
+        """Transpose attention output back to (B, T, D) shape"""
+        return attn_output.transpose(1, 2).contiguous().view(b, t, d)
+
     def _calculate_output(self, attn_weights: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int):
         """Calculate the output by multiplying attention weights with values and concatenating heads"""
-        return torch.matmul(attn_weights, v)
+        return self._transpose_output(torch.matmul(attn_weights, v), b, t, d)
 
     def _flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
                          mask: torch.Tensor = None, enable_gqa: bool = False):
@@ -104,9 +108,7 @@ class MultiHeadAttention(nn.Module):
             is_causal=self.is_causal,
             enable_gqa=enable_gqa,
         )
-
-        # Reshape back to (B, T, D)
-        return attn_output.transpose(1, 2).contiguous().view(b, t, d)
+        return self._transpose_output(attn_output, b, t, d)
 
     def _calculate_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
                                    mask: torch.Tensor = None):
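This change factors the final (B, T, D) reshape out of both attention paths into a shared _transpose_output hook, so the eager path (_calculate_output) and the SDPA/flash path return identically shaped tensors, and subclasses such as DeepMoeAttention can override just the reshape. A toy sketch of the pattern, assuming a simplified class rather than the actual MultiHeadAttention:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyAttention(nn.Module):
    """Simplified illustration of the shared reshape hook (not the rxnn class)."""

    def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
        # (B, H, T, head_dim) -> (B, T, D)
        return attn_output.transpose(1, 2).contiguous().view(b, t, d)

    def _calculate_output(self, attn_weights, v, b, t, d):
        # eager path: weights @ values, then the shared reshape
        return self._transpose_output(torch.matmul(attn_weights, v), b, t, d)

    def _flash_attention(self, q, k, v, b, t, d):
        # SDPA path reuses the same reshape instead of duplicating it
        attn_output = F.scaled_dot_product_attention(q, k, v)
        return self._transpose_output(attn_output, b, t, d)

# quick shape check of the eager path
toy = ToyAttention()
b, h, t, hd = 2, 4, 8, 16
weights = torch.softmax(torch.randn(b, h, t, t), dim=-1)
values = torch.randn(b, h, t, hd)
assert toy._calculate_output(weights, values, b, t, h * hd).shape == (b, t, h * hd)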
rxnn/transformers/layers.py
CHANGED
@@ -60,7 +60,23 @@ class ReactiveTransformerLayer(nn.Module):
             param.requires_grad_(is_trainable)
 
     def moe_router_loss(self):
-
+        ff_router_loss = self.ff.router_loss() if self.use_moe else None
+        att_router_loss = None
+        if self.attention.router_loss is not None and self.memory_cross_attention.router_loss is not None:
+            att_router_loss = (self.attention.router_loss() + self.memory_cross_attention.router_loss()) / 2
+        elif self.attention.router_loss is not None:
+            att_router_loss = self.attention.router_loss()
+        elif self.memory_cross_attention.router_loss is not None:
+            att_router_loss = self.memory_cross_attention.router_loss()
+
+        if ff_router_loss is not None and att_router_loss is not None:
+            return (ff_router_loss + att_router_loss) / 2
+        elif ff_router_loss is not None:
+            return ff_router_loss
+        elif att_router_loss is not None:
+            return att_router_loss
+        else:
+            return None
 
     def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
         # First step, self-attention
@@ -136,7 +152,17 @@ class ClassicTransformerLayer(nn.Module):
         self.use_moe = use_moe
 
     def moe_router_loss(self):
-
+        ff_router_loss = self.ff.router_loss() if self.use_moe else None
+        att_router_loss = self.attention.router_loss() if self.attention.router_loss is not None else None
+
+        if ff_router_loss is not None and att_router_loss is not None:
+            return (ff_router_loss + att_router_loss) / 2
+        elif ff_router_loss is not None:
+            return ff_router_loss
+        elif att_router_loss is not None:
+            return att_router_loss
+        else:
+            return None
 
     def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
         # First step, self-attention
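In both layer classes the new moe_router_loss reduces to "average the router losses that are actually present, otherwise return None", with the reactive layer applying that rule twice (self-attention and memory cross-attention first, then the result together with the feed-forward MoE loss). A compact standalone helper expressing the same rule, offered as an illustrative equivalent rather than code from the package:

def average_present(*losses):
    """Average the losses that are not None; return None if none are present."""
    present = [loss for loss in losses if loss is not None]
    return sum(present) / len(present) if present else None

# Reactive layer: the two attention router losses are averaged first, then the
# result is averaged with the feed-forward MoE loss (values here are made up).
att_loss = average_present(0.25, 0.75)   # -> 0.5
total = average_present(0.5, att_loss)   # -> 0.5
assert total == 0.5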
{rxnn-0.1.23.dist-info → rxnn-0.1.24.dist-info}/RECORD
CHANGED
@@ -1,6 +1,6 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
+rxnn/experimental/attention.py,sha256=r3y0_BaweONrkm6Z-9zn56U9jlBMfYBs8NlWNm7rR90,32424
 rxnn/experimental/models.py,sha256=-XkEHsyT8iNAjhZbgC7N_5nzP4ENVJLwxSoLHgMfA0I,4668
 rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,16 +16,16 @@ rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
 rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
 rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/transformers/attention.py,sha256=
+rxnn/transformers/attention.py,sha256=FHATZVf_kt3OHnG02zEeG9QdUXLncKDjrhyT28Pk0E4,14185
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
-rxnn/transformers/layers.py,sha256=
+rxnn/transformers/layers.py,sha256=ZJfNdgCv9dzrKqsWIMf99Ryzgs494ZhkwK4zSBYLvQ4,6880
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=w-zB_8QB9-Fae-GkGgmVDNY-Ts_0gBeWcevpl9qzZVM,7169
-rxnn/transformers/moe.py,sha256=
+rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.24.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.24.dist-info/METADATA,sha256=_h4mqmSKPEr0mxc2CaMn-yzvmZ5Lqlk_H4parGt-eHk,16627
+rxnn-0.1.24.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.24.dist-info/RECORD,,
{rxnn-0.1.23.dist-info → rxnn-0.1.24.dist-info}/LICENSE
File without changes

{rxnn-0.1.23.dist-info → rxnn-0.1.24.dist-info}/WHEEL
File without changes