rxnn 0.1.22__py3-none-any.whl → 0.1.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/attention.py +12 -13
- rxnn/transformers/attention.py +6 -4
- rxnn/transformers/layers.py +28 -2
- rxnn/transformers/moe.py +0 -1
- {rxnn-0.1.22.dist-info → rxnn-0.1.24.dist-info}/METADATA +1 -1
- {rxnn-0.1.22.dist-info → rxnn-0.1.24.dist-info}/RECORD +8 -8
- {rxnn-0.1.22.dist-info → rxnn-0.1.24.dist-info}/LICENSE +0 -0
- {rxnn-0.1.22.dist-info → rxnn-0.1.24.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -61,6 +61,9 @@ class GroupedMoeAttention(GroupedQueryAttention):
|
|
61
61
|
**kwargs,
|
62
62
|
)
|
63
63
|
|
64
|
+
def router_loss(self):
|
65
|
+
return self.router.aux_loss
|
66
|
+
|
64
67
|
def _init_kv(self, embed_dim: int):
|
65
68
|
self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
|
66
69
|
|
@@ -125,9 +128,6 @@ class GroupedMoeAttention(GroupedQueryAttention):
|
|
125
128
|
k = self._process_grouped_experts(key, self.wk, self.bk, weights_k, indices_k)
|
126
129
|
v = self._process_grouped_experts(value, self.wv, self.bv, weights_k, indices_k)
|
127
130
|
|
128
|
-
print('processed k', k.size())
|
129
|
-
print('processed v', v.size())
|
130
|
-
|
131
131
|
# Expand to GQA format
|
132
132
|
k = k.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
|
133
133
|
v = v.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
|
@@ -141,10 +141,6 @@ class GroupedMoeAttention(GroupedQueryAttention):
|
|
141
141
|
k = k.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
|
142
142
|
v = v.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
|
143
143
|
|
144
|
-
print('q', q.size())
|
145
|
-
print('k', k.size())
|
146
|
-
print('v', v.size())
|
147
|
-
|
148
144
|
return q, k, v
|
149
145
|
|
150
146
|
|
@@ -201,6 +197,9 @@ class DeepMoeAttention(GroupedMoeAttention):
|
|
201
197
|
**kwargs,
|
202
198
|
)
|
203
199
|
|
200
|
+
def router_loss(self):
|
201
|
+
return (self.router.aux_loss + self.query_router.aux_loss) / 2
|
202
|
+
|
204
203
|
def _init_q(self, embed_dim: int):
|
205
204
|
self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
|
206
205
|
|
@@ -219,6 +218,11 @@ class DeepMoeAttention(GroupedMoeAttention):
|
|
219
218
|
hidden_dim = embed_dim // (self.num_heads // self.num_query_groups)
|
220
219
|
self.out_proj = nn.Linear(hidden_dim, embed_dim)
|
221
220
|
|
221
|
+
def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
|
222
|
+
"""Transpose attention output back to (B, T, D) shape"""
|
223
|
+
hidden_dim = d // self.num_heads * self.num_query_groups
|
224
|
+
return attn_output.transpose(1, 2).contiguous().view(b, t, hidden_dim)
|
225
|
+
|
222
226
|
def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int, skip_query_processing: bool = False):
|
223
227
|
# Query processing
|
224
228
|
B, T, D = query.shape
|
@@ -229,13 +233,8 @@ class DeepMoeAttention(GroupedMoeAttention):
|
|
229
233
|
weights_q = weights_q_flat.view(B, T, -1)
|
230
234
|
indices_q = indices_q_flat.view(B, T, -1)
|
231
235
|
q = self._process_grouped_experts(query, self.wq, self.bq, weights_q, indices_q)
|
232
|
-
print('processed q', q.size())
|
233
|
-
q = q.permute(0, 2, 1, 3).reshape(B, self.num_query_groups, T, -1)
|
234
|
-
|
235
|
-
# Expand query groups to match head count
|
236
|
-
group_heads = self.num_heads // self.num_query_groups
|
237
|
-
q = q.unsqueeze(2).expand(-1, -1, group_heads, -1, -1).flatten(1, 2)
|
238
236
|
|
237
|
+
q = q.permute(0, 2, 1, 3).reshape(B, self.num_query_groups, T, -1)
|
239
238
|
# Key/Value processing
|
240
239
|
return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
|
241
240
|
|
rxnn/transformers/attention.py
CHANGED
@@ -91,9 +91,13 @@ class MultiHeadAttention(nn.Module):
|
|
91
91
|
attn_logits = attn_logits.masked_fill(mask == 0, float('-inf'))
|
92
92
|
return F.softmax(attn_logits, dim=-1)
|
93
93
|
|
94
|
+
def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
|
95
|
+
"""Transpose attention output back to (B, T, D) shape"""
|
96
|
+
return attn_output.transpose(1, 2).contiguous().view(b, t, d)
|
97
|
+
|
94
98
|
def _calculate_output(self, attn_weights: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int):
|
95
99
|
"""Calculate the output by multiplying attention weights with values and concatenating heads"""
|
96
|
-
return torch.matmul(attn_weights, v)
|
100
|
+
return self._transpose_output(torch.matmul(attn_weights, v), b, t, d)
|
97
101
|
|
98
102
|
def _flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
|
99
103
|
mask: torch.Tensor = None, enable_gqa: bool = False):
|
@@ -104,9 +108,7 @@ class MultiHeadAttention(nn.Module):
|
|
104
108
|
is_causal=self.is_causal,
|
105
109
|
enable_gqa=enable_gqa,
|
106
110
|
)
|
107
|
-
|
108
|
-
# Reshape back to (B, T, D)
|
109
|
-
return attn_output.transpose(1, 2).contiguous().view(b, t, d)
|
111
|
+
return self._transpose_output(attn_output, b, t, d)
|
110
112
|
|
111
113
|
def _calculate_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
|
112
114
|
mask: torch.Tensor = None):
|
rxnn/transformers/layers.py
CHANGED
@@ -60,7 +60,23 @@ class ReactiveTransformerLayer(nn.Module):
|
|
60
60
|
param.requires_grad_(is_trainable)
|
61
61
|
|
62
62
|
def moe_router_loss(self):
|
63
|
-
|
63
|
+
ff_router_loss = self.ff.router_loss() if self.use_moe else None
|
64
|
+
att_router_loss = None
|
65
|
+
if self.attention.router_loss is not None and self.memory_cross_attention.router_loss is not None:
|
66
|
+
att_router_loss = (self.attention.router_loss() + self.memory_cross_attention.router_loss()) / 2
|
67
|
+
elif self.attention.router_loss is not None:
|
68
|
+
att_router_loss = self.attention.router_loss()
|
69
|
+
elif self.memory_cross_attention.router_loss is not None:
|
70
|
+
att_router_loss = self.memory_cross_attention.router_loss()
|
71
|
+
|
72
|
+
if ff_router_loss is not None and att_router_loss is not None:
|
73
|
+
return (ff_router_loss + att_router_loss) / 2
|
74
|
+
elif ff_router_loss is not None:
|
75
|
+
return ff_router_loss
|
76
|
+
elif att_router_loss is not None:
|
77
|
+
return att_router_loss
|
78
|
+
else:
|
79
|
+
return None
|
64
80
|
|
65
81
|
def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
|
66
82
|
# First step, self-attention
|
@@ -136,7 +152,17 @@ class ClassicTransformerLayer(nn.Module):
|
|
136
152
|
self.use_moe = use_moe
|
137
153
|
|
138
154
|
def moe_router_loss(self):
|
139
|
-
|
155
|
+
ff_router_loss = self.ff.router_loss() if self.use_moe else None
|
156
|
+
att_router_loss = self.attention.router_loss() if self.attention.router_loss is not None else None
|
157
|
+
|
158
|
+
if ff_router_loss is not None and att_router_loss is not None:
|
159
|
+
return (ff_router_loss + att_router_loss) / 2
|
160
|
+
elif ff_router_loss is not None:
|
161
|
+
return ff_router_loss
|
162
|
+
elif att_router_loss is not None:
|
163
|
+
return att_router_loss
|
164
|
+
else:
|
165
|
+
return None
|
140
166
|
|
141
167
|
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
|
142
168
|
# First step, self-attention
|
rxnn/transformers/moe.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
rxnn/experimental/attention.py,sha256=
|
3
|
+
rxnn/experimental/attention.py,sha256=r3y0_BaweONrkm6Z-9zn56U9jlBMfYBs8NlWNm7rR90,32424
|
4
4
|
rxnn/experimental/models.py,sha256=-XkEHsyT8iNAjhZbgC7N_5nzP4ENVJLwxSoLHgMfA0I,4668
|
5
5
|
rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
|
6
6
|
rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -16,16 +16,16 @@ rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
|
|
16
16
|
rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
|
17
17
|
rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
|
18
18
|
rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
19
|
-
rxnn/transformers/attention.py,sha256=
|
19
|
+
rxnn/transformers/attention.py,sha256=FHATZVf_kt3OHnG02zEeG9QdUXLncKDjrhyT28Pk0E4,14185
|
20
20
|
rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
|
21
|
-
rxnn/transformers/layers.py,sha256=
|
21
|
+
rxnn/transformers/layers.py,sha256=ZJfNdgCv9dzrKqsWIMf99Ryzgs494ZhkwK4zSBYLvQ4,6880
|
22
22
|
rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
|
23
23
|
rxnn/transformers/models.py,sha256=w-zB_8QB9-Fae-GkGgmVDNY-Ts_0gBeWcevpl9qzZVM,7169
|
24
|
-
rxnn/transformers/moe.py,sha256=
|
24
|
+
rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
|
25
25
|
rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
|
26
26
|
rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
|
27
27
|
rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
|
28
|
-
rxnn-0.1.
|
29
|
-
rxnn-0.1.
|
30
|
-
rxnn-0.1.
|
31
|
-
rxnn-0.1.
|
28
|
+
rxnn-0.1.24.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
29
|
+
rxnn-0.1.24.dist-info/METADATA,sha256=_h4mqmSKPEr0mxc2CaMn-yzvmZ5Lqlk_H4parGt-eHk,16627
|
30
|
+
rxnn-0.1.24.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
|
31
|
+
rxnn-0.1.24.dist-info/RECORD,,
|
File without changes
|
File without changes
|