rxnn 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/attention.py +442 -88
- rxnn/experimental/models.py +117 -0
- rxnn/experimental/moe.py +206 -0
- rxnn/transformers/moe.py +42 -86
- {rxnn-0.1.14.dist-info → rxnn-0.1.16.dist-info}/METADATA +1 -1
- {rxnn-0.1.14.dist-info → rxnn-0.1.16.dist-info}/RECORD +8 -6
- {rxnn-0.1.14.dist-info → rxnn-0.1.16.dist-info}/LICENSE +0 -0
- {rxnn-0.1.14.dist-info → rxnn-0.1.16.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -1,5 +1,6 @@
 import torch
-
+import torch.nn as nn
+import torch.nn.functional as F
 from ..transformers.attention import MultiHeadAttention, GroupedQueryAttention
 from ..transformers.positional import RotaryPositionalEmbedding
 from ..transformers.moe import MoeRouter
@@ -9,6 +10,7 @@ from ..transformers.moe import MoeRouter
 class GroupedMoeAttention(GroupedQueryAttention):
     """
     Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
+
     Instead of mapping keys/values to static head groups, it dynamically selects head expert groups. It has the same
     number of total keys/values heads as query heads, but uses only a selected group for attention calculation.
     - with num_groups set to 1, it will be MoE MultiQueryAttention
@@ -20,8 +22,11 @@ class GroupedMoeAttention(GroupedQueryAttention):

     Optionally, it could use even more expert heads than attention heads - in example:
     - 512 dim divided into 16 heads with 32 dim, using 4 head groups - may use i.e., 24 total expert heads - still only
-
+    4 will be used for attention calculation, while 16 is used to split dimensions (in that case it will have 16 query heads)
+
+    © 2025 Adam Filipek
     """
+
     def __init__(
         self,
         embed_dim: int,
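Editor's note: the docstring above describes the core GMA idea - a router picks, per token, which expert key/value heads form the active group. Below is a minimal, standalone sketch of that routing step in plain PyTorch (all names and sizes are illustrative, not the library's API); the actual `GroupedMoeAttention` implementation follows in the next hunks.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

embed_dim, num_experts, num_groups, head_dim = 512, 16, 4, 32
gate = nn.Linear(embed_dim, num_experts, bias=False)        # illustrative routing gate
wk = torch.randn(num_experts, embed_dim, head_dim) * 0.02   # per-expert key projections

x = torch.randn(2, 10, embed_dim)                           # (batch, seq, embed_dim)
probs = F.softmax(gate(x), dim=-1)                          # (B, S, num_experts)
weights, indices = probs.topk(num_groups, dim=-1)           # per-token expert-head group

# For clarity this computes every expert's projection and then gathers the selected
# ones; the non-vectorized class below instead projects only the routed tokens.
k_all = torch.einsum('bsd,edh->bseh', x, wk)                # (B, S, num_experts, head_dim)
k_sel = torch.gather(k_all, 2, indices.unsqueeze(-1).expand(-1, -1, -1, head_dim))
k = k_sel * weights.unsqueeze(-1)                           # (B, S, num_groups, head_dim)
print(k.shape)                                              # torch.Size([2, 10, 4, 32])
```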
@@ -39,7 +44,7 @@ class GroupedMoeAttention(GroupedQueryAttention):
         *args,
         **kwargs,
     ):
-        self.num_experts = num_experts
+        self.num_experts = num_experts or num_heads
         super(GroupedMoeAttention, self).__init__(
             embed_dim,
             num_heads,
@@ -58,7 +63,228 @@ class GroupedMoeAttention(GroupedQueryAttention):

     def _init_kv(self, embed_dim: int):
         self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
-
+
+        hidden_dim = embed_dim // self.num_heads
+        self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+        self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+        self._init_experts()
+
+    def _init_experts(self):
+        nn.init.xavier_uniform_(self.wk)
+        nn.init.xavier_uniform_(self.wv)
+        if self.use_bias:
+            nn.init.zeros_(self.bk)
+            nn.init.zeros_(self.bv)
+
+    def _process_grouped_experts(self, x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, weights: torch.Tensor, indices: torch.Tensor):
+        B, S, G = indices.shape
+        x_flat = x.view(-1, x.size(-1))
+
+        # Flatten batch and sequence dimensions
+        indices_flat = indices.view(-1, G)
+        weights_flat = weights.view(-1, G, 1)
+
+        # Create expanded indices for expert processing
+        mask = torch.zeros(B * S, self.num_experts, device=x.device, dtype=torch.bool)
+        for g in range(G):
+            mask.scatter_(1, indices_flat[:, g].unsqueeze(1), True)
+
+        output = torch.zeros(B * S, G, w.size(2), device=x.device, dtype=x.dtype)
+
+        for e in range(self.num_experts):
+            token_mask = mask[:, e]
+            if not token_mask.any():
+                continue
+
+            # Get positions where expert e is used in any group
+            x_slice = x_flat[token_mask]
+            proj = F.linear(x_slice, w[e], b[e] if b is not None else None)
+
+            # Find which groups use this expert for selected tokens
+            group_mask = (indices_flat[token_mask] == e)
+
+            # Accumulate projections for relevant groups
+            weighted_proj = proj.unsqueeze(1) * weights_flat[token_mask] * group_mask.unsqueeze(-1).float()
+            output[token_mask] += weighted_proj.sum(dim=1)
+
+        return output.view(B, S, G, -1)
+
+    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
+                     skip_query_processing: bool = False):
+        q = self.q_proj(query).view(b, t, self.num_heads, -1).transpose(1, 2) if not skip_query_processing else query
+
+        # Key/Value processing
+        B, S, _ = key.shape
+        weights_k, indices_k = self.router(key)
+        k = self._process_grouped_experts(key, self.wk, self.bk, weights_k, indices_k)
+        v = self._process_grouped_experts(value, self.wv, self.bv, weights_k, indices_k)
+
+        # Expand to GQA format
+        k = k.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
+        v = v.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
+
+        if not self.use_flash_attention:
+            group_heads = self.num_heads // self.num_groups
+
+            k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+            v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+
+            k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+            v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+
+        return q, k, v
+
+
+class DeepMoeAttention(GroupedMoeAttention):
+    """
+    Deep MoE Attention (SMA) - Grouped MoE Attention extended even more for sublinear computational efficiency.
+
+    In addition to using Mixture-of-Experts (MoE) for key/value head groups, SMA is also using dynamically selected
+    query heads - with that approach, each token could attend to every other token, but only partially - only some part of
+    information from each token is used to identify related information parts from other tokens. So, DMA is not spatially
+    sparse (has access to all tokens), but rather structurally sparse (has access only to the part of token's information).
+
+    This solution could reduce the computational complexity of attention operation to sublinear level (<O(N)) and provide
+    a viable and efficient alternative to spatial sparse attention mechanisms like Flex Attention.
+
+    © 2025 Adam Filipek
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        num_groups: int,
+        dropout: float = 0.0,
+        rope: RotaryPositionalEmbedding = None,
+        rope_only_for_query: bool = False,
+        use_relative_embeddings: bool = False,
+        max_seq_len: int = 1024,
+        use_flash_attention: bool = False,
+        is_causal: bool = False,
+        use_bias: bool = False,
+        num_experts: int = None,
+        num_query_experts: int = None,
+        num_query_groups: int = None,
+        *args,
+        **kwargs,
+    ):
+        self.num_query_experts = num_query_experts if num_query_experts is not None else num_heads
+        self.num_query_groups = num_query_groups if num_query_groups is not None else num_groups
+        super(DeepMoeAttention, self).__init__(
+            embed_dim,
+            num_heads,
+            num_groups=num_groups,
+            dropout=dropout,
+            rope=rope,
+            rope_only_for_query=rope_only_for_query,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_experts=num_experts,
+            *args,
+            **kwargs,
+        )
+
+    def _init_q(self, embed_dim: int):
+        self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
+
+        hidden_dim = embed_dim // self.num_heads
+        self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
+        self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
+        self._init_query_experts()
+
+    def _init_query_experts(self):
+        nn.init.xavier_uniform_(self.wq)
+        if self.use_bias:
+            nn.init.zeros_(self.bq)
+
+    def _init_out(self, embed_dim: int):
+        """Initialize output projection"""
+        hidden_dim = embed_dim // (self.num_heads // self.num_query_groups)
+        self.out_proj = nn.Linear(hidden_dim, embed_dim)
+
+    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int, skip_query_processing: bool = False):
+        # Query processing
+        B, T, _ = query.shape
+        weights_q, indices_q = self.query_router(query)
+        q = self._process_grouped_experts(query, self.wq, self.bq, weights_q, indices_q)
+        q = q.permute(0, 2, 1, 3).reshape(B, self.num_query_groups, T, -1)
+
+        # Expand query groups to match head count
+        group_heads = self.num_heads // self.num_query_groups
+        q = q.unsqueeze(2).expand(-1, -1, group_heads, -1, -1).flatten(1, 2).transpose(1, 2)
+
+        # Key/Value processing
+        return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
+
+# Vectorized
+
+class GroupedMoeAttentionVectorized(GroupedQueryAttention):
+    """
+    Vectorized implementation calculates all expert heads for each token and selecting active tokens later. Linear layers
+    for Attention are rather small, compared to MoE Feed Forward layers, so it's possible that it will be faster than filtering
+    experts - it has to be tested.
+
+    Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
+
+    Instead of mapping keys/values to static head groups, it dynamically selects head expert groups. It has the same
+    number of total keys/values heads as query heads, but uses only a selected group for attention calculation.
+    - with num_groups set to 1, it will be MoE MultiQueryAttention
+
+    Compared to traditional GQA/MQA, it should provide better performance, because lot less data could be lost using
+    this approach - we are training the full number of keys/values heads, while using only a group.
+
+    In case of efficiency, it should be close to GQA/MQA linear performance, but with a small MoE routing overhead.
+
+    Optionally, it could use even more expert heads than attention heads - in example:
+    - 512 dim divided into 16 heads with 32 dim, using 4 head groups - may use i.e., 24 total expert heads - still only
+    4 will be used for attention calculation, while 16 is used to split dimensions (in that case it will have 16 query heads)
+
+    © 2025 Adam Filipek
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        num_groups: int,
+        dropout: float = 0.0,
+        rope: RotaryPositionalEmbedding = None,
+        rope_only_for_query: bool = False,
+        use_relative_embeddings: bool = False,
+        max_seq_len: int = 1024,
+        use_flash_attention: bool = False,
+        is_causal: bool = False,
+        use_bias: bool = False,
+        num_experts: int = None,
+        *args,
+        **kwargs,
+    ):
+        self.num_experts = num_experts if num_experts is not None else num_heads
+        super(GroupedMoeAttentionVectorized, self).__init__(
+            embed_dim,
+            num_heads,
+            num_groups=num_groups,
+            dropout=dropout,
+            rope=rope,
+            rope_only_for_query=rope_only_for_query,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            *args,
+            **kwargs,
+        )
+
+    def _init_kv(self, embed_dim: int):
+        self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
+        hidden_dim = embed_dim // self.num_heads
         self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
         self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
         self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
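Editor's note: `_process_grouped_experts` above is the heart of the non-vectorized path - each expert projects only the tokens that routed to it. The sketch below reproduces that masked per-expert loop on plain tensors; it is a simplified illustration (hypothetical shapes, weighted slots written directly), not a drop-in replacement for the method.

```python
import torch

num_experts, G, embed_dim, head_dim, n_tokens = 8, 2, 64, 16, 5
tokens = torch.randn(n_tokens, embed_dim)                  # flattened (batch*seq, embed_dim)
w = torch.randn(num_experts, embed_dim, head_dim) * 0.02   # per-expert projection weights
indices = torch.randint(0, num_experts, (n_tokens, G))     # routed experts per token/slot
weights = torch.rand(n_tokens, G, 1)                       # routing weights per slot

output = torch.zeros(n_tokens, G, head_dim)
for e in range(num_experts):
    token_mask = (indices == e).any(dim=-1)                # tokens that selected expert e
    if not token_mask.any():
        continue
    proj = tokens[token_mask] @ w[e]                       # project only those tokens
    slot_mask = (indices[token_mask] == e).unsqueeze(-1).float()
    # write the weighted projection into the slots that actually picked expert e
    output[token_mask] += proj.unsqueeze(1) * weights[token_mask] * slot_mask
print(output.shape)                                        # torch.Size([5, 2, 16])
```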
@@ -72,47 +298,37 @@ class GroupedMoeAttention(GroupedQueryAttention):
             torch.nn.init.zeros_(self.bk)
             torch.nn.init.zeros_(self.bv)

-    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
+    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
+                     skip_query_processing: bool = False):
+        # Indexed version may cause memory overflow
+        #
         # head_dim = d // self.num_heads
-        # group_heads = self.num_heads // self.num_groups
         #
         # # Process Query as in GQA
-        # q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1,
+        # q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1,
+        #     2) if not skip_query_processing else query
         #
         # # Process Key and Value with MoE routing
-        # key_flat = key.view(-1, d)
-        #
-        # weights = weights.view(b, key.size(1), self.num_groups, 1)
-        # indices = indices.view(b, key.size(1), self.num_groups)
+        # key_flat = key.view(-1, d)  # (B*S, d)
+        # value_flat = value.view(-1, d)  # (B*S, d)
         #
-        # #
-        #
-        #
-        #
-        #     key_flat,
-        #     self.wk.view(self.num_experts, d, -1)
-        # ).view(b, key.size(1), self.num_experts, -1)
+        # # Get routing indices and weights for K
+        # weights_k, indices_k = self.router(key_flat)
+        # indices_k = indices_k.view(-1, self.top_k)  # (B*S, top_k)
+        # weights_k = weights_k.view(-1, self.top_k, 1)  # (B*S, top_k, 1)
         #
-        #
-        #
-        #
-        #
-        #
+        # # Select and compute K projections for only the top_k experts
+        # selected_k_weights = self.k_experts[indices_k]  # (B*S, top_k, d, k_out_dim)
+        # k_proj = torch.einsum('bd, behd -> beh', key_flat.unsqueeze(1), selected_k_weights)
+        # selected_k = (k_proj * weights_k).sum(dim=1)  # (B*S, k_out_dim)
+        # selected_k = selected_k.view(b, key.size(1), -1)  # (B, S, k_out_dim)
         #
-        # #
-        #
-        #
-        #
-        #
-        # )
-        # selected_v = torch.gather(
-        #     v_all,
-        #     2,
-        #     indices.unsqueeze(-1).expand(-1, -1, -1, v_all.size(-1))
-        # )
+        # # Compute V using the same indices as K (since they share the same router)
+        # selected_v_weights = self.v_experts[indices_k]
+        # v_proj = torch.einsum('bd, behd -> beh', value_flat.unsqueeze(1), selected_v_weights)
+        # selected_v = (v_proj * weights_k).sum(dim=1)
+        # selected_v = selected_v.view(b, value.size(1), -1)  # (B, S, k_out_dim)
         #
-        # selected_k = (selected_k * weights).sum(dim=2)
-        # selected_v = (selected_v * weights).sum(dim=2)
         # # Reshape to GQA format: (B, G, S, head_dim)
         # k = selected_k.view(b, key.size(1), self.num_groups, head_dim).transpose(1, 2)
         # v = selected_v.view(b, value.size(1), self.num_groups, head_dim).transpose(1, 2)
@@ -127,32 +343,46 @@ class GroupedMoeAttention(GroupedQueryAttention):
         # v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
         #
         # return q, k, v
+
         head_dim = d // self.num_heads

         # Process Query as in GQA
-        q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2)
+        q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2)

         # Process Key and Value with MoE routing
-        key_flat = key.view(-1, d)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        key_flat = key.view(-1, d)
+        weights, indices = self.router(key_flat)
+        weights = weights.view(b, key.size(1), self.num_groups, 1)
+        indices = indices.view(b, key.size(1), self.num_groups)
+
+        # Compute all experts' K and V projections
+        # Shape: (batch_size, seq_len, num_experts, head_dim * num_groups)
+        k_all = torch.einsum(
+            'be, ehd -> bedh',
+            key_flat,
+            self.wk.view(self.num_experts, d, -1)
+        ).view(b, key.size(1), self.num_experts, -1)
+
+        v_all = torch.einsum(
+            'be, ehd -> bedh',
+            value.view(-1, d),
+            self.wv.view(self.num_experts, d, -1)
+        ).view(b, value.size(1), self.num_experts, -1)
+
+        # Select top_k experts and compute weighted sum
+        selected_k = torch.gather(
+            k_all,
+            2,
+            indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
+        )
+        selected_v = torch.gather(
+            v_all,
+            2,
+            indices.unsqueeze(-1).expand(-1, -1, -1, v_all.size(-1))
+        )

+        selected_k = (selected_k * weights).sum(dim=2)
+        selected_v = (selected_v * weights).sum(dim=2)
         # Reshape to GQA format: (B, G, S, head_dim)
         k = selected_k.view(b, key.size(1), self.num_groups, head_dim).transpose(1, 2)
         v = selected_v.view(b, value.size(1), self.num_groups, head_dim).transpose(1, 2)
@@ -168,15 +398,26 @@ class GroupedMoeAttention(GroupedQueryAttention):

         return q, k, v

-
+
+class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
     """
-
+    Vectorized implementation calculates all expert heads for each token and selecting active tokens later. Linear layers
+    for Attention are rather small, compared to MoE Feed Forward layers, so it's possible that it will be faster than filtering
+    experts - it has to be tested.
+
+    Deep MoE Attention (SMA) - Grouped MoE Attention extended even more for sublinear computational efficiency.
+
     In addition to using Mixture-of-Experts (MoE) for key/value head groups, SMA is also using dynamically selected
     query heads - with that approach, each token could attend to every other token, but only partially - only some part of
-    information from each token is used to identify related information parts from other tokens.
+    information from each token is used to identify related information parts from other tokens. So, DMA is not spatially
+    sparse (has access to all tokens), but rather structurally sparse (has access only to the part of token's information).
+
+    This solution could reduce the computational complexity of attention operation to sublinear level (<O(N)) and provide
+    a viable and efficient alternative to spatial sparse attention mechanisms like Flex Attention.

-
+    © 2025 Adam Filipek
     """
+
     def __init__(
         self,
         embed_dim: int,
@@ -192,13 +433,13 @@ class SparseMoeAttention(GroupedMoeAttention):
         use_bias: bool = False,
         num_experts: int = None,
         num_query_experts: int = None,
-
+        num_query_groups: int = None,
         *args,
         **kwargs,
     ):
         self.num_query_experts = num_query_experts if num_query_experts is not None else num_heads
-        self.
-        super(
+        self.num_query_groups = num_query_groups if num_query_groups is not None else num_groups
+        super(DeepMoeAttentionVectorized, self).__init__(
             embed_dim,
             num_heads,
             num_groups=num_groups,
@@ -216,8 +457,8 @@ class SparseMoeAttention(GroupedMoeAttention):
         )

     def _init_q(self, embed_dim: int):
-        self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.
-        hidden_dim = embed_dim //
+        self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
+        hidden_dim = embed_dim // self.num_heads
         self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
         self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
         self._init_query_experts()
@@ -227,20 +468,47 @@ class SparseMoeAttention(GroupedMoeAttention):
         if self.use_bias:
             torch.nn.init.zeros_(self.bq)

+    def _init_out(self, embed_dim: int):
+        """Initialize output projection"""
+        self.out_proj = nn.Linear(embed_dim // (self.num_heads // self.num_groups), embed_dim)
+
     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
+        # Indexed version may cause memory overflow
+        #
+        # head_dim = d // self.num_heads
+        #
+        # # Process Query with MoE routing
+        # query_flat = query.view(-1, d)  # (B*T, d)
+        # weights_q, indices_q = self.query_router(query_flat)
+        # indices_q = indices_q.view(-1, self.num_query_groups)  # (B*T, top_k_q)
+        # weights_q = weights_q.view(-1, self.num_query_groups, 1)  # (B*T, top_k_q, 1)
+        #
+        # # Select and compute Q projections for top_k experts
+        # selected_q_weights = self.wq[indices_q]  # (B*T, top_k_q, d, head_dim*num_heads)
+        # q_proj = torch.einsum('bd, behd -> beh', query_flat.unsqueeze(1), selected_q_weights)
+        # selected_q = (q_proj * weights_q).sum(dim=1)  # (B*T, head_dim*num_heads)
+        # selected_q = selected_q.view(b, t, -1)  # (B, T, head_dim*num_heads)
         head_dim = d // self.num_heads

         # Process Query with MoE routing
-        query_flat = query.view(
-        weights_q, indices_q = self.
-
-
-
-        #
-
-
-
-
+        query_flat = query.view(b * t, d)
+        weights_q, indices_q = self.query_router(query_flat)
+        weights_q = weights_q.view(b, t, self.num_query_groups, 1)
+        indices_q = indices_q.view(b, t, self.num_query_groups)
+
+        # Compute all experts' Q projections
+        q_all = torch.einsum(
+            'be, ehd -> bedh',
+            query_flat,
+            self.wq.view(self.num_query_experts, d, -1)
+        ).view(b, t, self.num_query_experts, -1)
+
+        selected_q = torch.gather(
+            q_all,
+            2,
+            indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.shape[-1])
+        )
+        selected_q = (selected_q * weights_q).sum(dim=2)

         q = selected_q.view(b, t, self.num_heads, head_dim).transpose(1, 2)  # (B, H, T, head_dim)

@@ -251,12 +519,12 @@ class SparseMoeAttention(GroupedMoeAttention):

 class FlexAttention(MultiHeadAttention):
     def __init__(
-
-
-
-
-
-
+        self,
+        embed_dim: int,
+        num_heads: int,
+        num_global_tokens: int = 16,
+        window_size: int = 128,
+        **kwargs
     ):
         super().__init__(embed_dim, num_heads, **kwargs)
         self.num_global_tokens = num_global_tokens
@@ -319,14 +587,15 @@ class FlexAttention(MultiHeadAttention):
         output = self._calculate_output(combined_attn, v, b, t, d)
         return self.out_proj(output)

+
 class InfiniteAttention(MultiHeadAttention):
     def __init__(
-
-
-
-
-
-
+        self,
+        embed_dim: int,
+        num_heads: int,
+        kernel_size: int = 128,
+        use_rotary: bool = True,
+        **kwargs
     ):
         super().__init__(embed_dim, num_heads, **kwargs)
         self.kernel_size = kernel_size
@@ -377,4 +646,89 @@ class InfiniteAttention(MultiHeadAttention):
         q = q / (q.shape[-1] ** 0.5)
         attn = torch.einsum('b h i d, b h j d -> b h i j', q, k)
         attn = torch.softmax(attn, dim=-1)
-        return torch.einsum('b h i j, b h j d -> b h i d', attn, v)
+        return torch.einsum('b h i j, b h j d -> b h i d', attn, v)
+
+def init_moe_attention(
+        embed_dim: int,
+        num_heads: int,
+        attention_type: str,
+        gqa_groups: int = 1,
+        dropout: float = 0.0,
+        rope: RotaryPositionalEmbedding = None,
+        rope_only_for_query: bool = False,
+        use_relative_embeddings: bool = False,
+        max_seq_len: int = 1024,
+        use_flash_attention: bool = False,
+        is_causal: bool = False,
+        use_bias: bool = False,
+        num_experts: int = None,
+        num_query_experts: int = None,
+        num_query_groups: int = None,
+) -> GroupedQueryAttention:
+    assert attention_type == 'gma' or attention_type == 'dma' or attention_type == 'gma_v' or attention_type == 'dma_v', \
+        "Error, attention type should be one of: 'gma', 'dma', 'gma_v', 'dma_v'"
+
+    if attention_type == "gma":
+        return GroupedMoeAttention(
+            embed_dim,
+            num_heads,
+            gqa_groups,
+            dropout=dropout,
+            rope=rope,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_experts=num_experts,
+        )
+    elif attention_type == "dma":
+        return DeepMoeAttention(
+            embed_dim,
+            num_heads,
+            gqa_groups,
+            dropout=dropout,
+            rope=rope,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_experts=num_experts,
+            num_query_experts=num_query_experts,
+            num_query_groups=num_query_groups,
+        )
+    elif attention_type == "gma_v":
+        return GroupedMoeAttentionVectorized(
+            embed_dim,
+            num_heads,
+            gqa_groups,
+            dropout=dropout,
+            rope=rope,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_experts=num_experts,
+        )
+    else:
+        return DeepMoeAttentionVectorized(
+            embed_dim,
+            num_heads,
+            gqa_groups,
+            dropout=dropout,
+            rope=rope,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_experts=num_experts,
+            num_query_experts=num_query_experts,
+            num_query_groups=num_query_groups,
+        )
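Editor's note: a hedged usage sketch of the `init_moe_attention` factory added above, using only argument names visible in this hunk; the values are illustrative and the RoPE construction mirrors the pattern used in `rxnn/experimental/models.py` later in this diff.

```python
from rxnn.transformers.positional import RotaryPositionalEmbedding
from rxnn.experimental.attention import init_moe_attention

embed_dim, num_heads, seq_len = 512, 16, 1024
rope = RotaryPositionalEmbedding(embed_dim // num_heads, seq_len)  # head_dim, max_seq_len

attn = init_moe_attention(
    embed_dim,
    num_heads,
    'dma',                  # one of 'gma', 'dma', 'gma_v', 'dma_v'
    gqa_groups=4,           # active key/value head groups
    rope=rope,
    max_seq_len=seq_len,
    is_causal=True,
    num_experts=24,         # more expert heads than attention heads is allowed
    num_query_experts=16,
    num_query_groups=8,
)
# The returned module extends GroupedQueryAttention, so it plugs into the same
# transformer layers as the standard attention variants.
```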
rxnn/experimental/models.py
ADDED
@@ -0,0 +1,117 @@
+import torch
+from torch import nn
+from typing import TypedDict, Union
+from huggingface_hub import PyTorchModelHubMixin
+from ..transformers.positional import RotaryPositionalEmbedding
+from ..transformers.attention import init_attention
+from ..transformers.layers import ClassicTransformerLayer
+from ..transformers.models import ClassicTransformerDecoder
+from ..transformers.ff import get_activation_layer
+from ..memory.stm import ShortTermMemory
+from ..utils import get_model_size
+from .attention import init_moe_attention
+
+
+class MoeAttentionTransformerConfig(TypedDict):
+    num_layers: int
+    vocab_size: int
+    embed_dim: int
+    ff_dim: int
+    att_heads: int
+    seq_len: int
+    use_flash_attention: bool
+    use_gated: bool
+    ff_activation: str
+    ff_dropout: float
+    att_dropout: float
+    use_rms_norm: bool
+    att_groups: int
+    use_moe_ff: bool
+    ff_num_experts: int
+    ff_moe_top_k: int
+    att_type: str
+    att_num_experts: int
+    att_num_query_experts: int
+    att_num_query_groups: int
+
+
+class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="text-generation", license="apache-2.0"):
+    """Research model for experiments with Mixture-of-Experts Attention"""
+
+    def __init__(
+        self,
+        num_layers: int = 6,
+        vocab_size: int = 5000,
+        embed_dim: int = 128,
+        ff_dim: int = 384,
+        att_heads: int = 16,
+        seq_len: int = 256,
+        use_flash_attention: bool = True,
+        use_gated: bool = True,
+        ff_activation: str = "swish",
+        ff_dropout: float = 0.0,
+        att_dropout: float = 0.0,
+        use_rms_norm: bool = True,
+        att_groups: int = 1,
+        use_moe_ff: bool = False,
+        ff_num_experts: int = 1,
+        ff_moe_top_k: int = 1,
+        att_type: str = 'gma',
+        att_num_experts: int = None,
+        att_num_query_experts: int = None,
+        att_num_query_groups: int = None,
+        **kwargs
+    ):
+        super(MoeAttentionTransformer, self).__init__(**kwargs)
+        assert ff_activation in ['relu', 'gelu',
+                                 'swish', 'silu', 'linear',
+                                 'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'gma_v',
+                            'dma_v'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "gma_v", "dma_v"'
+
+        embedding = nn.Embedding(vocab_size, embed_dim)
+        rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
+
+        ff_activation = get_activation_layer(ff_activation)
+
+        if att_type in ['mha', 'gqa', 'mqa']:
+            att_init = lambda: init_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
+                                              use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                              max_seq_len=seq_len, is_causal=True)
+        else:
+            att_init = lambda: init_moe_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
+                                                  use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                                  max_seq_len=seq_len, is_causal=True, num_experts=att_num_experts,
+                                                  num_query_experts=att_num_query_experts,
+                                                  num_query_groups=att_num_query_groups)
+
+        self.model = ClassicTransformerDecoder(
+            embed_dim,
+            vocab_size,
+            embedding=embedding,
+            layers=nn.ModuleList([
+                ClassicTransformerLayer(
+                    embed_dim,
+                    ff_dim,
+                    use_gated=use_gated,
+                    use_moe=use_moe_ff,
+                    num_experts=ff_num_experts,
+                    moe_top_k=ff_moe_top_k,
+                    ff_activation=ff_activation,
+                    ff_dropout=ff_dropout,
+                    use_rms_norm=use_rms_norm,
+                    self_attention=att_init(),
+                ) for _ in range(num_layers)
+            ]),
+            use_flash_attention=use_flash_attention,
+        )
+
+    def params_count(self):
+        return get_model_size(self.model)
+
+    def load_shared_embedding(self, embedding: nn.Embedding):
+        self.model.embedding = embedding
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> Union[
+        torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        return self.model(x, attention_mask=attention_mask)
rxnn/experimental/moe.py
ADDED
@@ -0,0 +1,206 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ..transformers.moe import MoeRouter
+
+class DynamicMoeRouter(nn.Module):
+    """Dynamic Mixture-of-Experts Router layer - dynamically selects top-k experts for each token."""
+
+    def __init__(self, embed_dim: int, num_experts: int, top_ks: tuple[int] = (1, 2, 3), *args, **kwargs):
+        super(DynamicMoeRouter, self).__init__(*args, **kwargs)
+        self.top_ks = top_ks
+        self.num_options = len(top_ks)
+        self.num_experts = num_experts
+        self.gate = nn.Linear(embed_dim, num_experts + self.num_options, bias=False)
+        # For expert load balancing
+        self.register_buffer('aux_loss', torch.tensor(0.0), persistent=False)
+
+    def calculate_aux_loss(self, top_k_indices: torch.Tensor, routing_probs: torch.Tensor) -> torch.Tensor:
+        expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
+        expert_usage = expert_mask.sum(dim=0).mean(dim=0)
+        mean_probs = routing_probs.mean(dim=0)
+        return (expert_usage * mean_probs).sum() * self.num_experts
+
+    def forward(self, x: torch.Tensor):
+        # Input shape: [batch*seq_len, embed_dim]
+        all_logits = self.gate(x)
+        routing_logits = all_logits[:, :-self.num_options]
+        options_logits = all_logits[:, -self.num_options:]
+
+        routing_probs = F.softmax(routing_logits, dim=-1)
+        top_k_id = torch.argmax(options_logits, dim=-1).item()
+
+        top_k = self.top_ks[top_k_id]
+
+        # Get top-k experts for each token
+        top_k_weights, top_k_indices = routing_probs.topk(top_k, dim=-1)
+
+        # Normalize weights (sum to 1 for each token)
+        top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)
+
+        # Load Balance Loss
+        self.aux_loss = self.calculate_aux_loss(top_k_indices, routing_probs)
+
+        return top_k_weights, top_k_indices, top_k
+
+class MoeFeedForwardVectorized(nn.Module):
+    """
+    Vectorized MoE - current implementation is incorrect - it calculates all the experts, then selects the correct ones.
+
+    Commented out implementation is fixing this problem, but is causing memory overflows, because of experts weights
+    indexing - it's using ~15x more memory, than dense model of similar size, so it's currently not viable.
+
+    It's recommended to use standard MoE from rxnn.transformers.moe instead.
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        hidden_dim: int,
+        num_experts: int,
+        activation: nn.Module,
+        top_k: int = 1,
+        dropout: float = 0.0,
+        *args,
+        **kwargs
+    ):
+        super(MoeFeedForwardVectorized, self).__init__(*args, **kwargs)
+        self.embed_dim = embed_dim
+        self.num_experts = num_experts
+        self.top_k = top_k
+
+        self.router = MoeRouter(embed_dim, num_experts, top_k)
+
+        # Batch all expert parameters together
+        self.w1 = nn.Parameter(torch.empty(num_experts, embed_dim, self._w1_dim_factor(hidden_dim)))
+        self.b1 = nn.Parameter(torch.zeros(num_experts, self._w1_dim_factor(hidden_dim)))
+        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, embed_dim))
+        self.b2 = nn.Parameter(torch.zeros(num_experts, embed_dim))
+        self.activation = activation
+        self.dropout = nn.Dropout(dropout)
+
+        # Initialize parameters
+        self._init_linear_parameters()
+        nn.init.zeros_(self.b1)
+        nn.init.zeros_(self.b2)
+
+    def _init_linear_parameters(self):
+        nn.init.kaiming_normal_(self.w1, nonlinearity='relu')
+        nn.init.kaiming_normal_(self.w2, nonlinearity='relu')
+
+    def _w1_dim_factor(self, hidden_dim: int) -> int:
+        return hidden_dim
+
+    def _activate(self, h: torch.Tensor):
+        return self.activation(h)
+
+    def router_loss(self):
+        return self.router.aux_loss
+
+    def forward(self, x: torch.Tensor):
+        orig_shape = x.shape
+        x = x.view(-1, self.embed_dim)  # [batch*seq_len, embed_dim]
+
+        # Get routing weights and indices
+        weights, indices = self.router(x)  # [batch*seq_len, top_k]
+
+        # Create expert masks and combine it with masks
+        mask = F.one_hot(indices, self.num_experts).float()  # [batch*seq_len, top_k, num_experts]
+        weights = (weights.unsqueeze(-1) * mask).sum(dim=1)  # [batch*seq_len, num_experts]
+
+        # Expert computation
+        x = x.unsqueeze(1).expand(-1, self.num_experts, -1)  # [batch*seq_len, num_experts, embed_dim]
+
+        # First linear layer
+        h = torch.einsum('bie,ieh->bih', x, self.w1) + self.b1  # [batch*seq_len, num_experts, hidden_dim]
+        h = self._activate(h)
+        h = self.dropout(h)
+
+        # Second linear layer (projection back to embed_dim)
+        out = torch.einsum('bih,ihe->bie', h, self.w2) + self.b2  # [batch*seq_len, num_experts, embed_dim]
+
+        # Weighted sum of expert outputs
+        out = (out * weights.unsqueeze(-1)).sum(dim=1)  # [batch*seq_len, embed_dim]
+
+        return out.view(*orig_shape)
+        # orig_shape = x.shape
+        # x = x.view(-1, self.embed_dim)  # [batch*seq_len, embed_dim]
+        #
+        # # Get routing weights and indices
+        # weights, indices = self.router(x)  # [B*T, top_k], [B*T, top_k]
+        #
+        # # Flatten indices and weights
+        # batch_size = x.shape[0]
+        # top_k = indices.shape[1]
+        # indices_flat = indices.view(-1)  # [B*T * top_k]
+        #
+        # # Compute contributions for selected experts without materializing large tensors
+        # # First Layer:
+        # # Compute all expert contributions first (but this may still be memory-heavy)
+        # # Alternative: Compute contributions for selected experts directly
+        # # ... (see detailed steps below)
+        #
+        # # Alternative approach using gather and batched operations
+        # x_expanded = x.unsqueeze(1).repeat(1, top_k, 1).view(-1, self.embed_dim)  # [B*T*top_k, D]
+        #
+        # # Compute first layer contributions using gather
+        # # indices_flat has shape [B*T*top_k]
+        # # selected_w1 is self.w1[indices_flat], but we compute the product inline
+        # h = torch.einsum(
+        #     'be, eih -> bh',
+        #     x_expanded,
+        #     self.w1[indices_flat]
+        # ) + self.b1[indices_flat]
+        # h = self._activate(h)
+        # h = self.dropout(h)
+        #
+        # # Second layer:
+        # out = torch.einsum(
+        #     'bh, eho -> beo',
+        #     h,
+        #     self.w2[indices_flat]
+        # ).squeeze(-1) + self.b2[indices_flat]
+        #
+        # # Reshape and apply weights
+        # out = out.view(batch_size, top_k, -1)
+        # weights = weights.view(batch_size, top_k, 1)
+        # out = (out * weights).sum(dim=1)
+        #
+        # return out.view(*orig_shape)
+
+
+class GatedMoeFeedForwardVectorized(MoeFeedForwardVectorized):
+    """Gated Mixture-of-Experts Feed-Forward layer - enable GLU-based activations for MoE"""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        hidden_dim: int,
+        num_experts: int,
+        activation: nn.Module = nn.SiLU(),
+        top_k: int = 1,
+        dropout: float = 0.1,
+        *args,
+        **kwargs
+    ):
+        super(GatedMoeFeedForwardVectorized, self).__init__(
+            embed_dim=embed_dim,
+            hidden_dim=hidden_dim,
+            num_experts=num_experts,
+            activation=activation,
+            top_k=top_k,
+            dropout=dropout,
+            *args,
+            **kwargs
+        )
+
+    def _init_linear_parameters(self):
+        nn.init.kaiming_normal_(self.w1, nonlinearity='relu')
+        nn.init.kaiming_normal_(self.w2, nonlinearity='linear')
+
+    def _w1_dim_factor(self, hidden_dim: int) -> int:
+        return 2 * hidden_dim
+
+    def _activate(self, h: torch.Tensor):
+        a, b = h.chunk(2, dim=-1)
+        return a * self.activation(b)
rxnn/transformers/moe.py
CHANGED
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
+from .ff import FeedForward, GatedFeedForward

 class MoeRouter(nn.Module):
     """Mixture-of-Experts Router layer - computes routing weights for each expert."""
@@ -14,18 +14,27 @@ class MoeRouter(nn.Module):
         # For expert load balancing
         self.register_buffer('aux_loss', torch.tensor(0.0), persistent=False)

+    def calculate_aux_loss(self, top_k_indices: torch.Tensor, probs: torch.Tensor) -> torch.Tensor:
+        expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
+        expert_usage = expert_mask.sum(dim=0).mean(dim=0)
+        mean_probs = probs.mean(dim=0)
+        return (expert_usage * mean_probs).sum() * self.num_experts
+
+
     def forward(self, x: torch.Tensor):
-        #
+        # Input shape: [batch*seq_len, embed_dim]
         logits = self.gate(x)
         probs = F.softmax(logits, dim=-1)

-        #
-        mean_probs = probs.mean(dim=0)  # Mean probability per expert across batch
-        self.aux_loss = (mean_probs * torch.log(mean_probs + 1e-9)).sum()  # Entropy-based loss
-
+        # Get top-k experts for each token
         top_k_weights, top_k_indices = probs.topk(self.top_k, dim=-1)
+
+        # Normalize weights (sum to 1 for each token)
         top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)

+        # Load Balance Loss
+        self.aux_loss = self.calculate_aux_loss(top_k_indices, probs)
+
         return top_k_weights, top_k_indices

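Editor's note: the entropy-based auxiliary loss is replaced above by a load-balancing term, `sum_e(usage_e * mean_prob_e) * num_experts`. The toy computation below reproduces `calculate_aux_loss` step by step so the shapes are easy to follow; it is a standalone illustration, not library code.

```python
import torch
import torch.nn.functional as F

num_experts, top_k = 4, 2
probs = torch.tensor([[0.7, 0.1, 0.1, 0.1],      # router softmax for two tokens
                      [0.6, 0.2, 0.1, 0.1]])
top_k_indices = probs.topk(top_k, dim=-1).indices             # [tokens, top_k]

expert_mask = F.one_hot(top_k_indices, num_experts).float()   # [tokens, top_k, num_experts]
expert_usage = expert_mask.sum(dim=0).mean(dim=0)             # tokens per expert, averaged over slots
mean_probs = probs.mean(dim=0)                                # mean routing probability per expert
aux_loss = (expert_usage * mean_probs).sum() * num_experts
print(aux_loss)  # larger when routing mass concentrates on a few experts
```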
@@ -51,91 +60,43 @@ class MoeFeedForward(nn.Module):
         self.router = MoeRouter(embed_dim, num_experts, top_k)

         # Batch all expert parameters together
-        self.
-        self.b1 = nn.Parameter(torch.zeros(num_experts, self._w1_dim_factor(hidden_dim)))
-        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, embed_dim))
-        self.b2 = nn.Parameter(torch.zeros(num_experts, embed_dim))
-        self.activation = activation
-        self.dropout = nn.Dropout(dropout)
-
-        # Initialize parameters
-        self._init_linear_parameters()
-        nn.init.zeros_(self.b1)
-        nn.init.zeros_(self.b2)
+        self._init_experts(num_experts, embed_dim, hidden_dim, activation, dropout)

-    def
-
-
-
-
-        return hidden_dim
-
-    def _activate(self, h: torch.Tensor):
-        return self.activation(h)
+    def _init_experts(self, num_experts: int, embed_dim: int, hidden_dim: int, activation: nn.Module, dropout: float):
+        self.experts = nn.ModuleList([
+            FeedForward(embed_dim, hidden_dim, activation, dropout)
+            for _ in range(num_experts)
+        ])

     def router_loss(self):
         return self.router.aux_loss

     def forward(self, x: torch.Tensor):
-        # orig_shape = x.shape
-        # x = x.view(-1, self.embed_dim)  # [batch*seq_len, embed_dim]
-        #
-        # # Get routing weights and indices
-        # weights, indices = self.router(x)  # [batch*seq_len, top_k]
-        #
-        # # Create expert masks and combine it with masks
-        # mask = F.one_hot(indices, self.num_experts).float()  # [batch*seq_len, top_k, num_experts]
-        # weights = (weights.unsqueeze(-1) * mask).sum(dim=1)  # [batch*seq_len, num_experts]
-        #
-        # # Expert computation
-        # x = x.unsqueeze(1).expand(-1, self.num_experts, -1)  # [batch*seq_len, num_experts, embed_dim]
-        #
-        # # First linear layer
-        # h = torch.einsum('bie,ieh->bih', x, self.w1) + self.b1  # [batch*seq_len, num_experts, hidden_dim]
-        # h = self._activate(h)
-        # h = self.dropout(h)
-        #
-        # # Second linear layer (projection back to embed_dim)
-        # out = torch.einsum('bih,ihe->bie', h, self.w2) + self.b2  # [batch*seq_len, num_experts, embed_dim]
-        #
-        # # Weighted sum of expert outputs
-        # out = (out * weights.unsqueeze(-1)).sum(dim=1)  # [batch*seq_len, embed_dim]
-        #
-        # return out.view(*orig_shape)
         orig_shape = x.shape
         x = x.view(-1, self.embed_dim)  # [batch*seq_len, embed_dim]

         # Get routing weights and indices
-        weights, indices = self.router(x)  # [
-
-        # Flatten indices and weights
-        batch_size = x.size(0)
-        top_k = indices.size(1)
-        indices = indices.view(-1)  # [batch*seq_len * top_k]
-        weights = weights.view(-1, 1)  # [batch*seq_len * top_k, 1]
+        weights, indices = self.router(x)  # [B*T, top_k], [B*T, top_k]

-        #
-
-
-        selected_w2 = self.w2[indices]  # [batch*seq_len * top_k, hidden_dim, embed_dim]
-        selected_b2 = self.b2[indices]  # [batch*seq_len * top_k, embed_dim]
+        # Create mask for expert contributions (B*T, num_experts)
+        expert_mask = F.one_hot(indices, self.num_experts).float()  # [B*T, top_k, num_experts]
+        expert_weights = (weights.unsqueeze(-1) * expert_mask).sum(dim=1)  # [B*T, num_experts]

-
-
+        output = torch.zeros_like(x)
+        for expert_idx in range(self.num_experts):
+            # Mask for tokens where this expert is in top_k
+            mask = expert_weights[:, expert_idx] > 0
+            if not mask.any():
+                continue

-
-
-
-        h = self.dropout(h)
+            # Compute expert output for selected tokens
+            expert_input = x[mask]
+            expert_output = self.experts[expert_idx](expert_input)

-
+            # Apply combined weights for this expert
+            output[mask] += expert_output * expert_weights[mask, expert_idx].unsqueeze(-1)

-
-        out = out.view(batch_size, top_k, -1)  # [batch*seq_len, top_k, embed_dim]
-        weights = weights.view(batch_size, top_k, 1)  # [batch*seq_len, top_k, 1]
-        out = (out * weights).sum(dim=1)  # Weighted sum over top_k experts
-
-        return out.view(*orig_shape)
+        return output.view(*orig_shape)


 class GatedMoeFeedForward(MoeFeedForward):
@@ -163,13 +124,8 @@ class GatedMoeFeedForward(MoeFeedForward):
             **kwargs
         )

-    def
-
-
-
-
-        return 2 * hidden_dim
-
-    def _activate(self, h: torch.Tensor):
-        a, b = h.chunk(2, dim=-1)
-        return a * self.activation(b)
+    def _init_experts(self, num_experts: int, embed_dim: int, hidden_dim: int, activation: nn.Module, dropout: float):
+        self.experts = nn.ModuleList([
+            GatedFeedForward(embed_dim, hidden_dim, activation, dropout)
+            for _ in range(num_experts)
+        ])
{rxnn-0.1.14.dist-info → rxnn-0.1.16.dist-info}/RECORD
CHANGED
@@ -1,6 +1,8 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
+rxnn/experimental/attention.py,sha256=qly-Lf9UsYC9JB945JcLnt27ZbF0vFvfyS5iUm-Rsak,31644
+rxnn/experimental/models.py,sha256=ioYtbJDxJ4zASiKs9dFY4WvAJn7eVqFf7zid-65pbUU,4709
+rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
 rxnn/memory/stm.py,sha256=EsD8slSP4_9dLuq6aFPDmuFe8PWilxh90so5Z3nm-ig,2057
@@ -19,11 +21,11 @@ rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
 rxnn/transformers/layers.py,sha256=HhIiykmrBgdsV4AbMQXr9t0cSo4gSIeN0dPtc8mDyOo,5629
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=w-zB_8QB9-Fae-GkGgmVDNY-Ts_0gBeWcevpl9qzZVM,7169
-rxnn/transformers/moe.py,sha256=
+rxnn/transformers/moe.py,sha256=FeaQR7hTX1dE74YdMOcuyZHSkGiV_0JwF8fw-GnfNOQ,4741
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.16.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.16.dist-info/METADATA,sha256=Cr_8OPHWlf2LHYlZEmc_NaUkIiE3ShJ01Z5B5ZhI6G8,14629
+rxnn-0.1.16.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.16.dist-info/RECORD,,
{rxnn-0.1.14.dist-info → rxnn-0.1.16.dist-info}/LICENSE
File without changes
{rxnn-0.1.14.dist-info → rxnn-0.1.16.dist-info}/WHEEL
File without changes