rxnn 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/attention.py +451 -88
- rxnn/experimental/models.py +116 -0
- rxnn/experimental/moe.py +206 -0
- rxnn/transformers/moe.py +45 -99
- {rxnn-0.1.15.dist-info → rxnn-0.1.17.dist-info}/METADATA +24 -1
- {rxnn-0.1.15.dist-info → rxnn-0.1.17.dist-info}/RECORD +8 -6
- {rxnn-0.1.15.dist-info → rxnn-0.1.17.dist-info}/LICENSE +0 -0
- {rxnn-0.1.15.dist-info → rxnn-0.1.17.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -1,5 +1,6 @@
 import torch
-
+import torch.nn as nn
+import torch.nn.functional as F
 from ..transformers.attention import MultiHeadAttention, GroupedQueryAttention
 from ..transformers.positional import RotaryPositionalEmbedding
 from ..transformers.moe import MoeRouter
@@ -9,6 +10,7 @@ from ..transformers.moe import MoeRouter
 class GroupedMoeAttention(GroupedQueryAttention):
     """
     Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
+
     Instead of mapping keys/values to static head groups, it dynamically selects head expert groups. It has the same
     number of total keys/values heads as query heads, but uses only a selected group for attention calculation.
     - with num_groups set to 1, it will be MoE MultiQueryAttention
@@ -20,8 +22,11 @@ class GroupedMoeAttention(GroupedQueryAttention):
 
     Optionally, it could use even more expert heads than attention heads - in example:
     - 512 dim divided into 16 heads with 32 dim, using 4 head groups - may use i.e., 24 total expert heads - still only
-
+    4 will be used for attention calculation, while 16 is used to split dimensions (in that case it will have 16 query heads)
+
+    © 2025 Adam Filipek
     """
+
     def __init__(
             self,
             embed_dim: int,
@@ -39,7 +44,7 @@ class GroupedMoeAttention(GroupedQueryAttention):
             *args,
             **kwargs,
     ):
-        self.num_experts = num_experts
+        self.num_experts = num_experts or num_heads
         super(GroupedMoeAttention, self).__init__(
             embed_dim,
             num_heads,
@@ -58,7 +63,237 @@ class GroupedMoeAttention(GroupedQueryAttention):
 
     def _init_kv(self, embed_dim: int):
         self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
-
+
+        hidden_dim = embed_dim // self.num_heads
+        self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+        self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+        self._init_experts()
+
+    def _init_experts(self):
+        nn.init.xavier_uniform_(self.wk)
+        nn.init.xavier_uniform_(self.wv)
+        if self.use_bias:
+            nn.init.zeros_(self.bk)
+            nn.init.zeros_(self.bv)
+
+    def _process_grouped_experts(self, x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, weights: torch.Tensor, indices: torch.Tensor):
+        B, S, G = indices.shape
+        x_flat = x.view(-1, x.size(-1))
+
+        # Flatten batch and sequence dimensions
+        indices_flat = indices.view(-1, G)
+        weights_flat = weights.view(-1, G, 1)
+
+        # Create expanded indices for expert processing
+        mask = torch.zeros(B * S, self.num_experts, device=x.device, dtype=torch.bool)
+        for g in range(G):
+            mask.scatter_(1, indices_flat[:, g].unsqueeze(1), True)
+
+        output = torch.zeros(B * S, G, w.size(2), device=x.device, dtype=x.dtype)
+
+        for e in range(self.num_experts):
+            token_mask = mask[:, e]
+            if not token_mask.any():
+                continue
+
+            # Get positions where expert e is used in any group
+            x_slice = x_flat[token_mask]
+            proj = F.linear(x_slice, w[e], b[e] if b is not None else None)
+
+            # Find which groups use this expert for selected tokens
+            group_mask = (indices_flat[token_mask] == e)
+
+            # Accumulate projections for relevant groups
+            weighted_proj = proj.unsqueeze(1) * weights_flat[token_mask] * group_mask.unsqueeze(-1).float()
+            output[token_mask] += weighted_proj.sum(dim=1)
+
+        return output.view(B, S, G, -1)
+
+    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
+                     skip_query_processing: bool = False):
+        q = self.q_proj(query).view(b, t, self.num_heads, -1).transpose(1, 2) if not skip_query_processing else query
+
+        # Key/Value processing
+        B, S, D = key.shape
+        key_flat = key.view(-1, D)
+        weights_k_flat, indices_k_flat = self.router(key_flat)
+        # Reshape back to original dimensions
+        weights_k = weights_k_flat.view(B, S, -1)
+        indices_k = indices_k_flat.view(B, S, -1)
+        k = self._process_grouped_experts(key, self.wk, self.bk, weights_k, indices_k)
+        v = self._process_grouped_experts(value, self.wv, self.bv, weights_k, indices_k)
+
+        # Expand to GQA format
+        k = k.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
+        v = v.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
+
+        if not self.use_flash_attention:
+            group_heads = self.num_heads // self.num_groups
+
+            k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+            v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+
+            k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+            v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+
+        return q, k, v
+
+
+class DeepMoeAttention(GroupedMoeAttention):
+    """
+    Deep MoE Attention (SMA) - Grouped MoE Attention extended even more for sublinear computational efficiency.
+
+    In addition to using Mixture-of-Experts (MoE) for key/value head groups, SMA is also using dynamically selected
+    query heads - with that approach, each token could attend to every other token, but only partially - only some part of
+    information from each token is used to identify related information parts from other tokens. So, DMA is not spatially
+    sparse (has access to all tokens), but rather structurally sparse (has access only to the part of token's information).
+
+    This solution could reduce the computational complexity of attention operation to sublinear level (<O(N)) and provide
+    a viable and efficient alternative to spatial sparse attention mechanisms like Flex Attention.
+
+    © 2025 Adam Filipek
+    """
+
+    def __init__(
+            self,
+            embed_dim: int,
+            num_heads: int,
+            num_groups: int,
+            dropout: float = 0.0,
+            rope: RotaryPositionalEmbedding = None,
+            rope_only_for_query: bool = False,
+            use_relative_embeddings: bool = False,
+            max_seq_len: int = 1024,
+            use_flash_attention: bool = False,
+            is_causal: bool = False,
+            use_bias: bool = False,
+            num_experts: int = None,
+            num_query_experts: int = None,
+            num_query_groups: int = None,
+            *args,
+            **kwargs,
+    ):
+        self.num_query_experts = num_query_experts if num_query_experts is not None else num_heads
+        self.num_query_groups = num_query_groups if num_query_groups is not None else num_groups
+        super(DeepMoeAttention, self).__init__(
+            embed_dim,
+            num_heads,
+            num_groups=num_groups,
+            dropout=dropout,
+            rope=rope,
+            rope_only_for_query=rope_only_for_query,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_experts=num_experts,
+            *args,
+            **kwargs,
+        )
+
+    def _init_q(self, embed_dim: int):
+        self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
+
+        hidden_dim = embed_dim // self.num_heads
+        self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
+        self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
+        self._init_query_experts()
+
+    def _init_query_experts(self):
+        nn.init.xavier_uniform_(self.wq)
+        if self.use_bias:
+            nn.init.zeros_(self.bq)
+
+    def _init_out(self, embed_dim: int):
+        """Initialize output projection"""
+        hidden_dim = embed_dim // (self.num_heads // self.num_query_groups)
+        self.out_proj = nn.Linear(hidden_dim, embed_dim)
+
+    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int, skip_query_processing: bool = False):
+        # Query processing
+        B, T, D = query.shape
+        # Flatten for query routing
+        query_flat = query.view(B * T, D)
+        weights_q_flat, indices_q_flat = self.query_router(query_flat)
+        # Reshape back
+        weights_q = weights_q_flat.view(B, T, -1)
+        indices_q = indices_q_flat.view(B, T, -1)
+        q = self._process_grouped_experts(query, self.wq, self.bq, weights_q, indices_q)
+        q = q.permute(0, 2, 1, 3).reshape(B, self.num_query_groups, T, -1)
+
+        # Expand query groups to match head count
+        group_heads = self.num_heads // self.num_query_groups
+        q = q.unsqueeze(2).expand(-1, -1, group_heads, -1, -1).flatten(1, 2).transpose(1, 2)
+
+        # Key/Value processing
+        return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
+
+
+# Vectorized
+
+class GroupedMoeAttentionVectorized(GroupedQueryAttention):
+    """
+    Vectorized implementation calculates all expert heads for each token and selecting active tokens later. Linear layers
+    for Attention are rather small, compared to MoE Feed Forward layers, so it's possible that it will be faster than filtering
+    experts - it has to be tested.
+
+    Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
+
+    Instead of mapping keys/values to static head groups, it dynamically selects head expert groups. It has the same
+    number of total keys/values heads as query heads, but uses only a selected group for attention calculation.
+    - with num_groups set to 1, it will be MoE MultiQueryAttention
+
+    Compared to traditional GQA/MQA, it should provide better performance, because lot less data could be lost using
+    this approach - we are training the full number of keys/values heads, while using only a group.
+
+    In case of efficiency, it should be close to GQA/MQA linear performance, but with a small MoE routing overhead.
+
+    Optionally, it could use even more expert heads than attention heads - in example:
+    - 512 dim divided into 16 heads with 32 dim, using 4 head groups - may use i.e., 24 total expert heads - still only
+    4 will be used for attention calculation, while 16 is used to split dimensions (in that case it will have 16 query heads)
+
+    © 2025 Adam Filipek
+    """
+
+    def __init__(
+            self,
+            embed_dim: int,
+            num_heads: int,
+            num_groups: int,
+            dropout: float = 0.0,
+            rope: RotaryPositionalEmbedding = None,
+            rope_only_for_query: bool = False,
+            use_relative_embeddings: bool = False,
+            max_seq_len: int = 1024,
+            use_flash_attention: bool = False,
+            is_causal: bool = False,
+            use_bias: bool = False,
+            num_experts: int = None,
+            *args,
+            **kwargs,
+    ):
+        self.num_experts = num_experts if num_experts is not None else num_heads
+        super(GroupedMoeAttentionVectorized, self).__init__(
+            embed_dim,
+            num_heads,
+            num_groups=num_groups,
+            dropout=dropout,
+            rope=rope,
+            rope_only_for_query=rope_only_for_query,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            *args,
+            **kwargs,
+        )
+
+    def _init_kv(self, embed_dim: int):
+        self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
+        hidden_dim = embed_dim // self.num_heads
         self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
         self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
         self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
@@ -72,47 +307,37 @@ class GroupedMoeAttention(GroupedQueryAttention):
         torch.nn.init.zeros_(self.bk)
         torch.nn.init.zeros_(self.bv)
 
-    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
+    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
+                     skip_query_processing: bool = False):
+        # Indexed version may cause memory overflow
+        #
         # head_dim = d // self.num_heads
-        # group_heads = self.num_heads // self.num_groups
         #
         # # Process Query as in GQA
-        # q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1,
+        # q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1,
+        #     2) if not skip_query_processing else query
         #
         # # Process Key and Value with MoE routing
-        # key_flat = key.view(-1, d)
-        #
-        # weights = weights.view(b, key.size(1), self.num_groups, 1)
-        # indices = indices.view(b, key.size(1), self.num_groups)
+        # key_flat = key.view(-1, d) # (B*S, d)
+        # value_flat = value.view(-1, d) # (B*S, d)
         #
-        # #
-        #
-        #
-        #
-        # key_flat,
-        # self.wk.view(self.num_experts, d, -1)
-        # ).view(b, key.size(1), self.num_experts, -1)
+        # # Get routing indices and weights for K
+        # weights_k, indices_k = self.router(key_flat)
+        # indices_k = indices_k.view(-1, self.top_k) # (B*S, top_k)
+        # weights_k = weights_k.view(-1, self.top_k, 1) # (B*S, top_k, 1)
         #
-        #
-        #
-        #
-        #
-        #
+        # # Select and compute K projections for only the top_k experts
+        # selected_k_weights = self.k_experts[indices_k] # (B*S, top_k, d, k_out_dim)
+        # k_proj = torch.einsum('bd, behd -> beh', key_flat.unsqueeze(1), selected_k_weights)
+        # selected_k = (k_proj * weights_k).sum(dim=1) # (B*S, k_out_dim)
+        # selected_k = selected_k.view(b, key.size(1), -1) # (B, S, k_out_dim)
         #
-        # #
-        #
-        #
-        #
-        #
-        # )
-        # selected_v = torch.gather(
-        # v_all,
-        # 2,
-        # indices.unsqueeze(-1).expand(-1, -1, -1, v_all.size(-1))
-        # )
+        # # Compute V using the same indices as K (since they share the same router)
+        # selected_v_weights = self.v_experts[indices_k]
+        # v_proj = torch.einsum('bd, behd -> beh', value_flat.unsqueeze(1), selected_v_weights)
+        # selected_v = (v_proj * weights_k).sum(dim=1)
+        # selected_v = selected_v.view(b, value.size(1), -1) # (B, S, k_out_dim)
         #
-        # selected_k = (selected_k * weights).sum(dim=2)
-        # selected_v = (selected_v * weights).sum(dim=2)
         # # Reshape to GQA format: (B, G, S, head_dim)
         # k = selected_k.view(b, key.size(1), self.num_groups, head_dim).transpose(1, 2)
         # v = selected_v.view(b, value.size(1), self.num_groups, head_dim).transpose(1, 2)
@@ -127,32 +352,46 @@ class GroupedMoeAttention(GroupedQueryAttention):
         # v = v.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
         #
         # return q, k, v
+
         head_dim = d // self.num_heads
 
         # Process Query as in GQA
-        q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2)
+        q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2)
 
         # Process Key and Value with MoE routing
-        key_flat = key.view(-1, d)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        key_flat = key.view(-1, d)
+        weights, indices = self.router(key_flat)
+        weights = weights.view(b, key.size(1), self.num_groups, 1)
+        indices = indices.view(b, key.size(1), self.num_groups)
+
+        # Compute all experts' K and V projections
+        # Shape: (batch_size, seq_len, num_experts, head_dim * num_groups)
+        k_all = torch.einsum(
+            'be, ehd -> bedh',
+            key_flat,
+            self.wk.view(self.num_experts, d, -1)
+        ).view(b, key.size(1), self.num_experts, -1)
+
+        v_all = torch.einsum(
+            'be, ehd -> bedh',
+            value.view(-1, d),
+            self.wv.view(self.num_experts, d, -1)
+        ).view(b, value.size(1), self.num_experts, -1)
+
+        # Select top_k experts and compute weighted sum
+        selected_k = torch.gather(
+            k_all,
+            2,
+            indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
+        )
+        selected_v = torch.gather(
+            v_all,
+            2,
+            indices.unsqueeze(-1).expand(-1, -1, -1, v_all.size(-1))
+        )
 
+        selected_k = (selected_k * weights).sum(dim=2)
+        selected_v = (selected_v * weights).sum(dim=2)
         # Reshape to GQA format: (B, G, S, head_dim)
         k = selected_k.view(b, key.size(1), self.num_groups, head_dim).transpose(1, 2)
         v = selected_v.view(b, value.size(1), self.num_groups, head_dim).transpose(1, 2)
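The vectorized GMA/DMA variants above take the opposite approach to the indexed one: they compute every expert's K/V projection for every token and then `gather` the routed experts. A small self-contained sketch of that pattern (hypothetical names and toy shapes for illustration, not the rxnn code itself):

```python
import torch

# Toy sizes: batch, sequence, model dim, per-head dim, experts, active groups per token
B, S, D, H, E, G = 2, 8, 64, 16, 12, 4

x = torch.randn(B, S, D)                   # token representations (keys)
wk = torch.randn(E, D, H) * 0.02           # one K projection per expert head
weights = torch.rand(B, S, G, 1)           # router weights of the G selected experts
indices = torch.randint(0, E, (B, S, G))   # router indices of the G selected experts

k_all = torch.einsum('bsd,edh->bseh', x, wk)   # all experts computed: (B, S, E, H)
selected = torch.gather(k_all, 2, indices.unsqueeze(-1).expand(-1, -1, -1, H))
k = (selected * weights).sum(dim=2)            # weighted mix of the routed experts: (B, S, H)
print(k.shape)  # torch.Size([2, 8, 16])
```

The memory cost is the full `(B, S, E, H)` intermediate; the commented-out "indexed" path in the hunk above avoids it at the price of gathering expert weight matrices per token, which the docstrings note may overflow memory.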
@@ -168,15 +407,26 @@ class GroupedMoeAttention(GroupedQueryAttention):
 
         return q, k, v
 
-
+
+class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
     """
-
+    Vectorized implementation calculates all expert heads for each token and selecting active tokens later. Linear layers
+    for Attention are rather small, compared to MoE Feed Forward layers, so it's possible that it will be faster than filtering
+    experts - it has to be tested.
+
+    Deep MoE Attention (SMA) - Grouped MoE Attention extended even more for sublinear computational efficiency.
+
     In addition to using Mixture-of-Experts (MoE) for key/value head groups, SMA is also using dynamically selected
     query heads - with that approach, each token could attend to every other token, but only partially - only some part of
-    information from each token is used to identify related information parts from other tokens.
+    information from each token is used to identify related information parts from other tokens. So, DMA is not spatially
+    sparse (has access to all tokens), but rather structurally sparse (has access only to the part of token's information).
 
-    This solution could reduce the computational complexity of attention operation to sublinear level (<O(N))
+    This solution could reduce the computational complexity of attention operation to sublinear level (<O(N)) and provide
+    a viable and efficient alternative to spatial sparse attention mechanisms like Flex Attention.
+
+    © 2025 Adam Filipek
     """
+
     def __init__(
             self,
             embed_dim: int,
@@ -192,13 +442,13 @@ class SparseMoeAttention(GroupedMoeAttention):
             use_bias: bool = False,
             num_experts: int = None,
             num_query_experts: int = None,
-
+            num_query_groups: int = None,
             *args,
             **kwargs,
     ):
         self.num_query_experts = num_query_experts if num_query_experts is not None else num_heads
-        self.
-        super(
+        self.num_query_groups = num_query_groups if num_query_groups is not None else num_groups
+        super(DeepMoeAttentionVectorized, self).__init__(
             embed_dim,
             num_heads,
             num_groups=num_groups,
@@ -216,8 +466,8 @@ class SparseMoeAttention(GroupedMoeAttention):
         )
 
     def _init_q(self, embed_dim: int):
-        self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.
-        hidden_dim = embed_dim //
+        self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
+        hidden_dim = embed_dim // self.num_heads
         self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
         self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
         self._init_query_experts()
@@ -227,20 +477,47 @@ class SparseMoeAttention(GroupedMoeAttention):
         if self.use_bias:
             torch.nn.init.zeros_(self.bq)
 
+    def _init_out(self, embed_dim: int):
+        """Initialize output projection"""
+        self.out_proj = nn.Linear(embed_dim // (self.num_heads // self.num_groups), embed_dim)
+
     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
+        # Indexed version may cause memory overflow
+        #
+        # head_dim = d // self.num_heads
+        #
+        # # Process Query with MoE routing
+        # query_flat = query.view(-1, d) # (B*T, d)
+        # weights_q, indices_q = self.query_router(query_flat)
+        # indices_q = indices_q.view(-1, self.num_query_groups) # (B*T, top_k_q)
+        # weights_q = weights_q.view(-1, self.num_query_groups, 1) # (B*T, top_k_q, 1)
+        #
+        # # Select and compute Q projections for top_k experts
+        # selected_q_weights = self.wq[indices_q] # (B*T, top_k_q, d, head_dim*num_heads)
+        # q_proj = torch.einsum('bd, behd -> beh', query_flat.unsqueeze(1), selected_q_weights)
+        # selected_q = (q_proj * weights_q).sum(dim=1) # (B*T, head_dim*num_heads)
+        # selected_q = selected_q.view(b, t, -1) # (B, T, head_dim*num_heads)
         head_dim = d // self.num_heads
 
         # Process Query with MoE routing
-        query_flat = query.view(
-        weights_q, indices_q = self.
-
-
-
-        #
-
-
-
-
+        query_flat = query.view(b * t, d)
+        weights_q, indices_q = self.query_router(query_flat)
+        weights_q = weights_q.view(b, t, self.num_query_groups, 1)
+        indices_q = indices_q.view(b, t, self.num_query_groups)
+
+        # Compute all experts' Q projections
+        q_all = torch.einsum(
+            'be, ehd -> bedh',
+            query_flat,
+            self.wq.view(self.num_query_experts, d, -1)
+        ).view(b, t, self.num_query_experts, -1)
+
+        selected_q = torch.gather(
+            q_all,
+            2,
+            indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.shape[-1])
+        )
+        selected_q = (selected_q * weights_q).sum(dim=2)
 
         q = selected_q.view(b, t, self.num_heads, head_dim).transpose(1, 2)  # (B, H, T, head_dim)
 
@@ -251,12 +528,12 @@ class SparseMoeAttention(GroupedMoeAttention):
 
 class FlexAttention(MultiHeadAttention):
     def __init__(
-
-
-
-
-
-
+            self,
+            embed_dim: int,
+            num_heads: int,
+            num_global_tokens: int = 16,
+            window_size: int = 128,
+            **kwargs
     ):
         super().__init__(embed_dim, num_heads, **kwargs)
         self.num_global_tokens = num_global_tokens
@@ -319,14 +596,15 @@ class FlexAttention(MultiHeadAttention):
         output = self._calculate_output(combined_attn, v, b, t, d)
         return self.out_proj(output)
 
+
 class InfiniteAttention(MultiHeadAttention):
     def __init__(
-
-
-
-
-
-
+            self,
+            embed_dim: int,
+            num_heads: int,
+            kernel_size: int = 128,
+            use_rotary: bool = True,
+            **kwargs
     ):
         super().__init__(embed_dim, num_heads, **kwargs)
         self.kernel_size = kernel_size
@@ -377,4 +655,89 @@ class InfiniteAttention(MultiHeadAttention):
         q = q / (q.shape[-1] ** 0.5)
         attn = torch.einsum('b h i d, b h j d -> b h i j', q, k)
         attn = torch.softmax(attn, dim=-1)
-        return torch.einsum('b h i j, b h j d -> b h i d', attn, v)
+        return torch.einsum('b h i j, b h j d -> b h i d', attn, v)
+
+
+def init_moe_attention(
+        embed_dim: int,
+        num_heads: int,
+        attention_type: str,
+        gqa_groups: int = 1,
+        dropout: float = 0.0,
+        rope: RotaryPositionalEmbedding = None,
+        rope_only_for_query: bool = False,
+        use_relative_embeddings: bool = False,
+        max_seq_len: int = 1024,
+        use_flash_attention: bool = False,
+        is_causal: bool = False,
+        use_bias: bool = False,
+        num_experts: int = None,
+        num_query_experts: int = None,
+        num_query_groups: int = None,
+) -> GroupedQueryAttention:
+    assert attention_type == 'gma' or attention_type == 'dma' or attention_type == 'gma_v' or attention_type == 'dma_v', \
+        "Error, attention type should be one of: 'gma', 'dma', 'gma_v', 'dma_v'"
+
+    if attention_type == "gma":
+        return GroupedMoeAttention(
+            embed_dim,
+            num_heads,
+            gqa_groups,
+            dropout=dropout,
+            rope=rope,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_experts=num_experts,
+        )
+    elif attention_type == "dma":
+        return DeepMoeAttention(
+            embed_dim,
+            num_heads,
+            gqa_groups,
+            dropout=dropout,
+            rope=rope,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_experts=num_experts,
+            num_query_experts=num_query_experts,
+            num_query_groups=num_query_groups,
+        )
+    elif attention_type == "gma_v":
+        return GroupedMoeAttentionVectorized(
+            embed_dim,
+            num_heads,
+            gqa_groups,
+            dropout=dropout,
+            rope=rope,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_experts=num_experts,
+        )
+    else:
+        return DeepMoeAttentionVectorized(
+            embed_dim,
+            num_heads,
+            gqa_groups,
+            dropout=dropout,
+            rope=rope,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_experts=num_experts,
+            num_query_experts=num_query_experts,
+            num_query_groups=num_query_groups,
+        )
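The non-vectorized GMA/DMA path above routes each token to a few expert heads and only computes those projections. A minimal, self-contained sketch of that routing idea (names, shapes, and the simple gate are illustrative assumptions, not the rxnn API):

```python
import torch
import torch.nn.functional as F

def route_topk(x_flat, gate, k):
    # x_flat: (N, D); gate: (D, E). Softmax over experts, keep the top-k per token.
    probs = F.softmax(x_flat @ gate, dim=-1)
    w, idx = probs.topk(k, dim=-1)
    return w / (w.sum(dim=-1, keepdim=True) + 1e-9), idx

def grouped_expert_projection(x, experts, weights, indices):
    # x: (B, S, D); experts: (E, D, H); weights/indices: (B, S, G) from the router.
    B, S, G = indices.shape
    E, D, H = experts.shape
    x_flat = x.reshape(-1, D)
    idx_flat, w_flat = indices.reshape(-1, G), weights.reshape(-1, G, 1)
    out = torch.zeros(B * S, G, H, dtype=x.dtype)
    for e in range(E):                       # only touch tokens routed to expert e
        hit = idx_flat == e                  # (B*S, G)
        token_mask = hit.any(dim=-1)
        if not token_mask.any():
            continue
        proj = x_flat[token_mask] @ experts[e]                 # (n_e, H)
        contrib = proj.unsqueeze(1) * w_flat[token_mask] * hit[token_mask].unsqueeze(-1)
        out[token_mask] += contrib           # write into the groups that picked expert e
    return out.view(B, S, G, H)              # one projected K/V "head group" per routed expert

B, S, D, H, E, G = 2, 8, 64, 16, 12, 4
x = torch.randn(B, S, D)
experts = torch.randn(E, D, H) * 0.02
gate = torch.randn(D, E) * 0.02
w, idx = route_topk(x.view(-1, D), gate, G)
kv = grouped_expert_projection(x, experts, w.view(B, S, G), idx.view(B, S, G))
print(kv.shape)  # torch.Size([2, 8, 4, 16])
```

Only the routed experts' matrices are ever multiplied, which is the "filtering" strategy the docstrings contrast with the vectorized variants.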
rxnn/experimental/models.py
ADDED
@@ -0,0 +1,116 @@
+import torch
+from torch import nn
+from typing import TypedDict, Union
+from huggingface_hub import PyTorchModelHubMixin
+from ..transformers.positional import RotaryPositionalEmbedding
+from ..transformers.attention import init_attention
+from ..transformers.layers import ClassicTransformerLayer
+from ..transformers.models import ClassicTransformerDecoder
+from ..transformers.ff import get_activation_layer
+from ..utils import get_model_size
+from .attention import init_moe_attention
+
+
+class MoeAttentionTransformerConfig(TypedDict):
+    num_layers: int
+    vocab_size: int
+    embed_dim: int
+    ff_dim: int
+    att_heads: int
+    seq_len: int
+    use_flash_attention: bool
+    use_gated: bool
+    ff_activation: str
+    ff_dropout: float
+    att_dropout: float
+    use_rms_norm: bool
+    att_groups: int
+    use_moe_ff: bool
+    ff_num_experts: int
+    ff_moe_top_k: int
+    att_type: str
+    att_num_experts: int
+    att_num_query_experts: int
+    att_num_query_groups: int
+
+
+class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="text-generation", license="apache-2.0"):
+    """Research model for experiments with Mixture-of-Experts Attention"""
+
+    def __init__(
+            self,
+            num_layers: int = 6,
+            vocab_size: int = 5000,
+            embed_dim: int = 128,
+            ff_dim: int = 384,
+            att_heads: int = 16,
+            seq_len: int = 256,
+            use_flash_attention: bool = True,
+            use_gated: bool = True,
+            ff_activation: str = "swish",
+            ff_dropout: float = 0.0,
+            att_dropout: float = 0.0,
+            use_rms_norm: bool = True,
+            att_groups: int = 1,
+            use_moe_ff: bool = False,
+            ff_num_experts: int = 1,
+            ff_moe_top_k: int = 1,
+            att_type: str = 'gma',
+            att_num_experts: int = None,
+            att_num_query_experts: int = None,
+            att_num_query_groups: int = None,
+            **kwargs
+    ):
+        super(MoeAttentionTransformer, self).__init__(**kwargs)
+        assert ff_activation in ['relu', 'gelu',
+                                 'swish', 'silu', 'linear',
+                                 'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'gma_v',
+                            'dma_v'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "gma_v", "dma_v"'
+
+        embedding = nn.Embedding(vocab_size, embed_dim)
+        rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
+
+        ff_activation = get_activation_layer(ff_activation)
+
+        if att_type in ['mha', 'gqa', 'mqa']:
+            att_init = lambda: init_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
+                                              use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                              max_seq_len=seq_len, is_causal=True)
+        else:
+            att_init = lambda: init_moe_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
+                                                  use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                                  max_seq_len=seq_len, is_causal=True, num_experts=att_num_experts,
+                                                  num_query_experts=att_num_query_experts,
+                                                  num_query_groups=att_num_query_groups)
+
+        self.model = ClassicTransformerDecoder(
+            embed_dim,
+            vocab_size,
+            embedding=embedding,
+            layers=nn.ModuleList([
+                ClassicTransformerLayer(
+                    embed_dim,
+                    ff_dim,
+                    use_gated=use_gated,
+                    use_moe=use_moe_ff,
+                    num_experts=ff_num_experts,
+                    moe_top_k=ff_moe_top_k,
+                    ff_activation=ff_activation,
+                    ff_dropout=ff_dropout,
+                    use_rms_norm=use_rms_norm,
+                    self_attention=att_init(),
+                ) for _ in range(num_layers)
+            ]),
+            use_flash_attention=use_flash_attention,
+        )
+
+    def params_count(self):
+        return get_model_size(self.model)
+
+    def load_shared_embedding(self, embedding: nn.Embedding):
+        self.model.embedding = embedding
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> Union[
+        torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        return self.model(x, attention_mask=attention_mask)
rxnn/experimental/moe.py
ADDED
@@ -0,0 +1,206 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ..transformers.moe import MoeRouter
+
+class DynamicMoeRouter(nn.Module):
+    """Dynamic Mixture-of-Experts Router layer - dynamically selects top-k experts for each token."""
+
+    def __init__(self, embed_dim: int, num_experts: int, top_ks: tuple[int] = (1, 2, 3), *args, **kwargs):
+        super(DynamicMoeRouter, self).__init__(*args, **kwargs)
+        self.top_ks = top_ks
+        self.num_options = len(top_ks)
+        self.num_experts = num_experts
+        self.gate = nn.Linear(embed_dim, num_experts + self.num_options, bias=False)
+        # For expert load balancing
+        self.register_buffer('aux_loss', torch.tensor(0.0), persistent=False)
+
+    def calculate_aux_loss(self, top_k_indices: torch.Tensor, routing_probs: torch.Tensor) -> torch.Tensor:
+        expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
+        expert_usage = expert_mask.sum(dim=0).mean(dim=0)
+        mean_probs = routing_probs.mean(dim=0)
+        return (expert_usage * mean_probs).sum() * self.num_experts
+
+    def forward(self, x: torch.Tensor):
+        # Input shape: [batch*seq_len, embed_dim]
+        all_logits = self.gate(x)
+        routing_logits = all_logits[:, :-self.num_options]
+        options_logits = all_logits[:, -self.num_options:]
+
+        routing_probs = F.softmax(routing_logits, dim=-1)
+        top_k_id = torch.argmax(options_logits, dim=-1).item()
+
+        top_k = self.top_ks[top_k_id]
+
+        # Get top-k experts for each token
+        top_k_weights, top_k_indices = routing_probs.topk(top_k, dim=-1)
+
+        # Normalize weights (sum to 1 for each token)
+        top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)
+
+        # Load Balance Loss
+        self.aux_loss = self.calculate_aux_loss(top_k_indices, routing_probs)
+
+        return top_k_weights, top_k_indices, top_k
+
+class MoeFeedForwardVectorized(nn.Module):
+    """
+    Vectorized MoE - current implementation is incorrect - it calculates all the experts, then selects the correct ones.
+
+    Commented out implementation is fixing this problem, but is causing memory overflows, because of experts weights
+    indexing - it's using ~15x more memory, than dense model of similar size, so it's currently not viable.
+
+    It's recommended to use standard MoE from rxnn.transformers.moe instead.
+    """
+
+    def __init__(
+            self,
+            embed_dim: int,
+            hidden_dim: int,
+            num_experts: int,
+            activation: nn.Module,
+            top_k: int = 1,
+            dropout: float = 0.0,
+            *args,
+            **kwargs
+    ):
+        super(MoeFeedForwardVectorized, self).__init__(*args, **kwargs)
+        self.embed_dim = embed_dim
+        self.num_experts = num_experts
+        self.top_k = top_k
+
+        self.router = MoeRouter(embed_dim, num_experts, top_k)
+
+        # Batch all expert parameters together
+        self.w1 = nn.Parameter(torch.empty(num_experts, embed_dim, self._w1_dim_factor(hidden_dim)))
+        self.b1 = nn.Parameter(torch.zeros(num_experts, self._w1_dim_factor(hidden_dim)))
+        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, embed_dim))
+        self.b2 = nn.Parameter(torch.zeros(num_experts, embed_dim))
+        self.activation = activation
+        self.dropout = nn.Dropout(dropout)
+
+        # Initialize parameters
+        self._init_linear_parameters()
+        nn.init.zeros_(self.b1)
+        nn.init.zeros_(self.b2)
+
+    def _init_linear_parameters(self):
+        nn.init.kaiming_normal_(self.w1, nonlinearity='relu')
+        nn.init.kaiming_normal_(self.w2, nonlinearity='relu')
+
+    def _w1_dim_factor(self, hidden_dim: int) -> int:
+        return hidden_dim
+
+    def _activate(self, h: torch.Tensor):
+        return self.activation(h)
+
+    def router_loss(self):
+        return self.router.aux_loss
+
+    def forward(self, x: torch.Tensor):
+        orig_shape = x.shape
+        x = x.view(-1, self.embed_dim)  # [batch*seq_len, embed_dim]
+
+        # Get routing weights and indices
+        weights, indices = self.router(x)  # [batch*seq_len, top_k]
+
+        # Create expert masks and combine it with masks
+        mask = F.one_hot(indices, self.num_experts).float()  # [batch*seq_len, top_k, num_experts]
+        weights = (weights.unsqueeze(-1) * mask).sum(dim=1)  # [batch*seq_len, num_experts]
+
+        # Expert computation
+        x = x.unsqueeze(1).expand(-1, self.num_experts, -1)  # [batch*seq_len, num_experts, embed_dim]
+
+        # First linear layer
+        h = torch.einsum('bie,ieh->bih', x, self.w1) + self.b1  # [batch*seq_len, num_experts, hidden_dim]
+        h = self._activate(h)
+        h = self.dropout(h)
+
+        # Second linear layer (projection back to embed_dim)
+        out = torch.einsum('bih,ihe->bie', h, self.w2) + self.b2  # [batch*seq_len, num_experts, embed_dim]
+
+        # Weighted sum of expert outputs
+        out = (out * weights.unsqueeze(-1)).sum(dim=1)  # [batch*seq_len, embed_dim]
+
+        return out.view(*orig_shape)
+        # orig_shape = x.shape
+        # x = x.view(-1, self.embed_dim) # [batch*seq_len, embed_dim]
+        #
+        # # Get routing weights and indices
+        # weights, indices = self.router(x) # [B*T, top_k], [B*T, top_k]
+        #
+        # # Flatten indices and weights
+        # batch_size = x.shape[0]
+        # top_k = indices.shape[1]
+        # indices_flat = indices.view(-1) # [B*T * top_k]
+        #
+        # # Compute contributions for selected experts without materializing large tensors
+        # # First Layer:
+        # # Compute all expert contributions first (but this may still be memory-heavy)
+        # # Alternative: Compute contributions for selected experts directly
+        # # ... (see detailed steps below)
+        #
+        # # Alternative approach using gather and batched operations
+        # x_expanded = x.unsqueeze(1).repeat(1, top_k, 1).view(-1, self.embed_dim) # [B*T*top_k, D]
+        #
+        # # Compute first layer contributions using gather
+        # # indices_flat has shape [B*T*top_k]
+        # # selected_w1 is self.w1[indices_flat], but we compute the product inline
+        # h = torch.einsum(
+        #     'be, eih -> bh',
+        #     x_expanded,
+        #     self.w1[indices_flat]
+        # ) + self.b1[indices_flat]
+        # h = self._activate(h)
+        # h = self.dropout(h)
+        #
+        # # Second layer:
+        # out = torch.einsum(
+        #     'bh, eho -> beo',
+        #     h,
+        #     self.w2[indices_flat]
+        # ).squeeze(-1) + self.b2[indices_flat]
+        #
+        # # Reshape and apply weights
+        # out = out.view(batch_size, top_k, -1)
+        # weights = weights.view(batch_size, top_k, 1)
+        # out = (out * weights).sum(dim=1)
+        #
+        # return out.view(*orig_shape)
+
+
+class GatedMoeFeedForwardVectorized(MoeFeedForwardVectorized):
+    """Gated Mixture-of-Experts Feed-Forward layer - enable GLU-based activations for MoE"""
+
+    def __init__(
+            self,
+            embed_dim: int,
+            hidden_dim: int,
+            num_experts: int,
+            activation: nn.Module = nn.SiLU(),
+            top_k: int = 1,
+            dropout: float = 0.1,
+            *args,
+            **kwargs
+    ):
+        super(GatedMoeFeedForwardVectorized, self).__init__(
+            embed_dim=embed_dim,
+            hidden_dim=hidden_dim,
+            num_experts=num_experts,
+            activation=activation,
+            top_k=top_k,
+            dropout=dropout,
+            *args,
+            **kwargs
+        )
+
+    def _init_linear_parameters(self):
+        nn.init.kaiming_normal_(self.w1, nonlinearity='relu')
+        nn.init.kaiming_normal_(self.w2, nonlinearity='linear')
+
+    def _w1_dim_factor(self, hidden_dim: int) -> int:
+        return 2 * hidden_dim
+
+    def _activate(self, h: torch.Tensor):
+        a, b = h.chunk(2, dim=-1)
+        return a * self.activation(b)
rxnn/transformers/moe.py
CHANGED
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
+from .ff import FeedForward, GatedFeedForward
 
 class MoeRouter(nn.Module):
     """Mixture-of-Experts Router layer - computes routing weights for each expert."""
@@ -14,18 +14,27 @@ class MoeRouter(nn.Module):
         # For expert load balancing
         self.register_buffer('aux_loss', torch.tensor(0.0), persistent=False)
 
+    def calculate_aux_loss(self, top_k_indices: torch.Tensor, probs: torch.Tensor) -> torch.Tensor:
+        expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
+        expert_usage = expert_mask.sum(dim=0).mean(dim=0)
+        mean_probs = probs.mean(dim=0)
+        return (expert_usage * mean_probs).sum() * self.num_experts
+
+
     def forward(self, x: torch.Tensor):
-        #
+        # Input shape: [batch*seq_len, embed_dim]
         logits = self.gate(x)
         probs = F.softmax(logits, dim=-1)
 
-        #
-        mean_probs = probs.mean(dim=0) # Mean probability per expert across batch
-        self.aux_loss = (mean_probs * torch.log(mean_probs + 1e-9)).sum() # Entropy-based loss
-
+        # Get top-k experts for each token
         top_k_weights, top_k_indices = probs.topk(self.top_k, dim=-1)
+
+        # Normalize weights (sum to 1 for each token)
         top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)
 
+        # Load Balance Loss
+        self.aux_loss = self.calculate_aux_loss(top_k_indices, probs)
+
         return top_k_weights, top_k_indices
 
 
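The new `calculate_aux_loss` replaces the old entropy penalty with a load-balancing term: per-expert usage of the top-k assignments multiplied by the mean routing probability, summed and scaled by the number of experts. A tiny standalone check that mirrors the method above:

```python
import torch
import torch.nn.functional as F

def load_balance_loss(top_k_indices, probs, num_experts):
    # Mirrors MoeRouter.calculate_aux_loss from the hunk above.
    expert_mask = F.one_hot(top_k_indices, num_experts).float()   # (N, top_k, E)
    expert_usage = expert_mask.sum(dim=0).mean(dim=0)             # per-expert usage over tokens
    mean_probs = probs.mean(dim=0)                                # per-expert mean routing prob
    return (expert_usage * mean_probs).sum() * num_experts

probs = F.softmax(torch.randn(32, 4), dim=-1)
_, idx = probs.topk(2, dim=-1)
print(load_balance_loss(idx, probs, 4))
```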
@@ -51,101 +60,43 @@ class MoeFeedForward(nn.Module):
         self.router = MoeRouter(embed_dim, num_experts, top_k)
 
         # Batch all expert parameters together
-        self.
-        self.b1 = nn.Parameter(torch.zeros(num_experts, self._w1_dim_factor(hidden_dim)))
-        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, embed_dim))
-        self.b2 = nn.Parameter(torch.zeros(num_experts, embed_dim))
-        self.activation = activation
-        self.dropout = nn.Dropout(dropout)
-
-        # Initialize parameters
-        self._init_linear_parameters()
-        nn.init.zeros_(self.b1)
-        nn.init.zeros_(self.b2)
-
-    def _init_linear_parameters(self):
-        nn.init.kaiming_normal_(self.w1, nonlinearity='relu')
-        nn.init.kaiming_normal_(self.w2, nonlinearity='relu')
+        self._init_experts(num_experts, embed_dim, hidden_dim, activation, dropout)
 
-    def
-
-
-
-
+    def _init_experts(self, num_experts: int, embed_dim: int, hidden_dim: int, activation: nn.Module, dropout: float):
+        self.experts = nn.ModuleList([
+            FeedForward(embed_dim, hidden_dim, activation, dropout)
+            for _ in range(num_experts)
+        ])
 
     def router_loss(self):
         return self.router.aux_loss
 
     def forward(self, x: torch.Tensor):
-        # orig_shape = x.shape
-        # x = x.view(-1, self.embed_dim) # [batch*seq_len, embed_dim]
-        #
-        # # Get routing weights and indices
-        # weights, indices = self.router(x) # [batch*seq_len, top_k]
-        #
-        # # Create expert masks and combine it with masks
-        # mask = F.one_hot(indices, self.num_experts).float() # [batch*seq_len, top_k, num_experts]
-        # weights = (weights.unsqueeze(-1) * mask).sum(dim=1) # [batch*seq_len, num_experts]
-        #
-        # # Expert computation
-        # x = x.unsqueeze(1).expand(-1, self.num_experts, -1) # [batch*seq_len, num_experts, embed_dim]
-        #
-        # # First linear layer
-        # h = torch.einsum('bie,ieh->bih', x, self.w1) + self.b1 # [batch*seq_len, num_experts, hidden_dim]
-        # h = self._activate(h)
-        # h = self.dropout(h)
-        #
-        # # Second linear layer (projection back to embed_dim)
-        # out = torch.einsum('bih,ihe->bie', h, self.w2) + self.b2 # [batch*seq_len, num_experts, embed_dim]
-        #
-        # # Weighted sum of expert outputs
-        # out = (out * weights.unsqueeze(-1)).sum(dim=1) # [batch*seq_len, embed_dim]
-        #
-        # return out.view(*orig_shape)
         orig_shape = x.shape
         x = x.view(-1, self.embed_dim)  # [batch*seq_len, embed_dim]
 
         # Get routing weights and indices
         weights, indices = self.router(x)  # [B*T, top_k], [B*T, top_k]
 
-        #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        x_expanded,
-        self.w1[indices_flat]
-        ) + self.b1[indices_flat]
-        h = self._activate(h)
-        h = self.dropout(h)
-
-        # Second layer:
-        out = torch.einsum(
-        'bh, eho -> beo',
-        h,
-        self.w2[indices_flat]
-        ).squeeze(-1) + self.b2[indices_flat]
-
-        # Reshape and apply weights
-        out = out.view(batch_size, top_k, -1)
-        weights = weights.view(batch_size, top_k, 1)
-        out = (out * weights).sum(dim=1)
-
-        return out.view(*orig_shape)
+        # Create mask for expert contributions (B*T, num_experts)
+        expert_mask = F.one_hot(indices, self.num_experts).float()  # [B*T, top_k, num_experts]
+        expert_weights = (weights.unsqueeze(-1) * expert_mask).sum(dim=1)  # [B*T, num_experts]
+
+        output = torch.zeros_like(x)
+        for expert_idx in range(self.num_experts):
+            # Mask for tokens where this expert is in top_k
+            mask = expert_weights[:, expert_idx] > 0
+            if not mask.any():
+                continue
+
+            # Compute expert output for selected tokens
+            expert_input = x[mask]
+            expert_output = self.experts[expert_idx](expert_input)
+
+            # Apply combined weights for this expert
+            output[mask] += expert_output * expert_weights[mask, expert_idx].unsqueeze(-1)
+
+        return output.view(*orig_shape)
 
 
 class GatedMoeFeedForward(MoeFeedForward):
@@ -173,13 +124,8 @@ class GatedMoeFeedForward(MoeFeedForward):
             **kwargs
         )
 
-    def
-
-
-
-
-        return 2 * hidden_dim
-
-    def _activate(self, h: torch.Tensor):
-        a, b = h.chunk(2, dim=-1)
-        return a * self.activation(b)
+    def _init_experts(self, num_experts: int, embed_dim: int, hidden_dim: int, activation: nn.Module, dropout: float):
+        self.experts = nn.ModuleList([
+            GatedFeedForward(embed_dim, hidden_dim, activation, dropout)
+            for _ in range(num_experts)
+        ])
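A hedged usage sketch for the refactored MoE feed-forward: the constructor arguments are assumed to mirror the GatedMoeFeedForwardVectorized signature shown earlier in this diff, since the base `__init__` is not fully visible here.

```python
import torch
import torch.nn as nn
from rxnn.transformers.moe import GatedMoeFeedForward

moe_ff = GatedMoeFeedForward(
    embed_dim=128,
    hidden_dim=384,
    num_experts=8,
    activation=nn.SiLU(),
    top_k=2,                 # each token is processed by at most two experts
)

x = torch.randn(2, 64, 128)
y = moe_ff(x)                # routed through per-expert GatedFeedForward modules
aux = moe_ff.router_loss()   # load-balancing loss from MoeRouter.calculate_aux_loss
print(y.shape, float(aux))
```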
{rxnn-0.1.15.dist-info → rxnn-0.1.17.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.1.15
+Version: 0.1.17
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -53,6 +53,29 @@ that's generating Infinite Chain-of-Thoughts and is communicating in push-based
 Reactive communication patterns in RxNN models are adapted to handle asynchronous nature of model - after it finish generating
 sequence, it has to process it and save it in memory, but it could be done in background.
 
+## Release plan
+We are working on three new reactive architectures, that progressively advance from language models to awareness models:
+- Reactive Transformer: Reactive Language Model (RLM) with Short-Term Memory
+- Preactor: extending Reactive Transformer with additional Long-Term Memory, providing theoretically infinite context (only
+  single message length is limited) and the ability to learn from interactions (Live Learning)
+- Reactor: AGI awareness model & Strong Reactive Neural Network, that's working in infinite reasoning loop and doesn't require explicit human commands
+
+Each new architecture is based on the previous one and adding new features/abilities. They will be progressively
+released with next versions of **RxNN** framework:
+- 0.1.x: Reactive Transformer base models, Base Model Learning (pre-training/fine-tuning) & Transformers extensions (MoE Attention, Short-Term Memory, etc.)
+- 0.2.x: Memory Reinforcement Learning (MRL) for Short-Term Memory & Reactive Transformer, Attention-based Memory System details
+- 0.3.x: Reinforcement Learning from Human Feedback for Reactive models (RxRLHF), basic Tensor Reactive
+  Extensions (TRX/Rust) for full Reactive Transformer, RxT-Alpha release (+following models - RxT-Beta, etc.)
+- 0.4.x: Preactor base models, Tensor Database (TDB/Rust) for Long-Term Memory, mxRAG/revRAG subsystems
+- 0.5.x: MRL for Long-Term Memory & Preactor, Live Learning for Preactor, PRx-Alpha release (+following models - PRx-Beta, etc.)
+- 0.6.x: Reactor base models, TRX full implementation, Receptors & Effectors Reactive RNNs
+- 0.7.x: Behavioral Reinforcement Learning (BRL) for Reactor's Infinite Chain-of-Thoughts, Continuous Live Learning for Reactor
+- 0.8.x: Rx-Alpha release
+- 0.9.x: Rx-Beta release
+- 1.0.0: Reactor AGI official release (Expert, Assistant & Utility class models)
+- 1.x.x: Multimodal reactive models (could be released earlier, depending on progress)
+- 2.0.0: Real-Time Vision Reactor - Worker class models
+- x.x.x: ...and more!
 Apache License
 Version 2.0, January 2004
 http://www.apache.org/licenses/
{rxnn-0.1.15.dist-info → rxnn-0.1.17.dist-info}/RECORD
CHANGED
@@ -1,6 +1,8 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
+rxnn/experimental/attention.py,sha256=wjHrxfov3Ybg3iou8FlQtFvxNuHdcs_A7a6FTloosgA,32056
+rxnn/experimental/models.py,sha256=-XkEHsyT8iNAjhZbgC7N_5nzP4ENVJLwxSoLHgMfA0I,4668
+rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
 rxnn/memory/stm.py,sha256=EsD8slSP4_9dLuq6aFPDmuFe8PWilxh90so5Z3nm-ig,2057
@@ -19,11 +21,11 @@ rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
 rxnn/transformers/layers.py,sha256=HhIiykmrBgdsV4AbMQXr9t0cSo4gSIeN0dPtc8mDyOo,5629
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=w-zB_8QB9-Fae-GkGgmVDNY-Ts_0gBeWcevpl9qzZVM,7169
-rxnn/transformers/moe.py,sha256=
+rxnn/transformers/moe.py,sha256=FeaQR7hTX1dE74YdMOcuyZHSkGiV_0JwF8fw-GnfNOQ,4741
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.17.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.17.dist-info/METADATA,sha256=wId6o7JCcBjRD1plWzgJRmFAY5VlHN7-FIVySeVDqx8,16627
+rxnn-0.1.17.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.17.dist-info/RECORD,,
{rxnn-0.1.15.dist-info → rxnn-0.1.17.dist-info}/LICENSE
File without changes
{rxnn-0.1.15.dist-info → rxnn-0.1.17.dist-info}/WHEEL
File without changes