rxnn 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,6 @@
1
1
  import torch
2
- from torch import nn
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
3
4
  from ..transformers.attention import MultiHeadAttention, GroupedQueryAttention
4
5
  from ..transformers.positional import RotaryPositionalEmbedding
5
6
  from ..transformers.moe import MoeRouter
@@ -9,6 +10,7 @@ from ..transformers.moe import MoeRouter
9
10
  class GroupedMoeAttention(GroupedQueryAttention):
10
11
  """
11
12
  Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
13
+
12
14
  Instead of mapping keys/values to static head groups, it dynamically selects head expert groups. It has the same
13
15
  number of total key/value heads as query heads, but uses only a selected group for attention calculation.
14
16
  - with num_groups set to 1, it will be MoE MultiQueryAttention
@@ -20,8 +22,11 @@ class GroupedMoeAttention(GroupedQueryAttention):
20
22
 
21
23
  Optionally, it could use even more expert heads than attention heads - for example:
22
24
  - 512 dim divided into 16 heads with 32 dim, using 4 head groups - may use e.g. 24 total expert heads - still only
23
- 4 will be used for attention calculation, while 16 is used to split dimensions (in that case it will have 16 query heads)
25
+ 4 will be used for attention calculation, while 16 is used to split dimensions (in that case it will have 16 query heads)
26
+
27
+ © 2025 Adam Filipek
24
28
  """
29
+
25
30
  def __init__(
26
31
  self,
27
32
  embed_dim: int,
@@ -39,7 +44,7 @@ class GroupedMoeAttention(GroupedQueryAttention):
39
44
  *args,
40
45
  **kwargs,
41
46
  ):
42
- self.num_experts = num_experts if num_experts is not None else num_heads
47
+ self.num_experts = num_experts or num_heads
43
48
  super(GroupedMoeAttention, self).__init__(
44
49
  embed_dim,
45
50
  num_heads,
@@ -58,7 +63,228 @@ class GroupedMoeAttention(GroupedQueryAttention):
58
63
 
59
64
  def _init_kv(self, embed_dim: int):
60
65
  self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
61
- hidden_dim = embed_dim // (self.num_heads // self.num_groups)
66
+
67
+ hidden_dim = embed_dim // self.num_heads
68
+ self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
69
+ self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
70
+ self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
71
+ self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
72
+ self._init_experts()
73
+
74
+ def _init_experts(self):
75
+ nn.init.xavier_uniform_(self.wk)
76
+ nn.init.xavier_uniform_(self.wv)
77
+ if self.use_bias:
78
+ nn.init.zeros_(self.bk)
79
+ nn.init.zeros_(self.bv)
80
+
81
+ def _process_grouped_experts(self, x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, weights: torch.Tensor, indices: torch.Tensor):
82
+ B, S, G = indices.shape
83
+ x_flat = x.view(-1, x.size(-1))
84
+
85
+ # Flatten batch and sequence dimensions
86
+ indices_flat = indices.view(-1, G)
87
+ weights_flat = weights.view(-1, G, 1)
88
+
89
+ # Create expanded indices for expert processing
90
+ mask = torch.zeros(B * S, self.num_experts, device=x.device, dtype=torch.bool)
91
+ for g in range(G):
92
+ mask.scatter_(1, indices_flat[:, g].unsqueeze(1), True)
93
+
94
+ output = torch.zeros(B * S, G, w.size(2), device=x.device, dtype=x.dtype)
95
+
96
+ for e in range(self.num_experts):
97
+ token_mask = mask[:, e]
98
+ if not token_mask.any():
99
+ continue
100
+
101
+ # Get positions where expert e is used in any group
102
+ x_slice = x_flat[token_mask]
103
+ proj = F.linear(x_slice, w[e], b[e] if b is not None else None)
104
+
105
+ # Find which groups use this expert for selected tokens
106
+ group_mask = (indices_flat[token_mask] == e)
107
+
108
+ # Accumulate projections for relevant groups
109
+ weighted_proj = proj.unsqueeze(1) * weights_flat[token_mask] * group_mask.unsqueeze(-1).float()
110
+ output[token_mask] += weighted_proj.sum(dim=1)
111
+
112
+ return output.view(B, S, G, -1)
113
+
114
+ def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
115
+ skip_query_processing: bool = False):
116
+ q = self.q_proj(query).view(b, t, self.num_heads, -1).transpose(1, 2) if not skip_query_processing else query
117
+
118
+ # Key/Value processing
119
+ B, S, _ = key.shape
120
+ weights_k, indices_k = self.router(key)
121
+ k = self._process_grouped_experts(key, self.wk, self.bk, weights_k, indices_k)
122
+ v = self._process_grouped_experts(value, self.wv, self.bv, weights_k, indices_k)
123
+
124
+ # Expand to GQA format
125
+ k = k.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
126
+ v = v.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
127
+
128
+ if not self.use_flash_attention:
129
+ group_heads = self.num_heads // self.num_groups
130
+
131
+ k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1) # (B, G, group_heads, S, head_dim)
132
+ v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1) # (B, G, group_heads, S, head_dim)
133
+
134
+ k = k.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
135
+ v = v.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
136
+
137
+ return q, k, v
138
+
139
+
140
+ class DeepMoeAttention(GroupedMoeAttention):
141
+ """
142
+ Deep MoE Attention (DMA) - Grouped MoE Attention extended further for sublinear computational efficiency.
143
+
144
+ In addition to using Mixture-of-Experts (MoE) for key/value head groups, DMA also uses dynamically selected
145
+ query heads - with that approach, each token could attend to every other token, but only partially - only some part of
146
+ information from each token is used to identify related information parts from other tokens. So, DMA is not spatially
147
+ sparse (it has access to all tokens), but structurally sparse (it has access only to part of each token's information).
148
+
149
+ This solution could reduce the computational complexity of the attention operation to a sublinear level (< O(N)) and provide
150
+ a viable and efficient alternative to spatially sparse attention mechanisms like Flex Attention.
151
+
152
+ © 2025 Adam Filipek
153
+ """
154
+
155
+ def __init__(
156
+ self,
157
+ embed_dim: int,
158
+ num_heads: int,
159
+ num_groups: int,
160
+ dropout: float = 0.0,
161
+ rope: RotaryPositionalEmbedding = None,
162
+ rope_only_for_query: bool = False,
163
+ use_relative_embeddings: bool = False,
164
+ max_seq_len: int = 1024,
165
+ use_flash_attention: bool = False,
166
+ is_causal: bool = False,
167
+ use_bias: bool = False,
168
+ num_experts: int = None,
169
+ num_query_experts: int = None,
170
+ num_query_groups: int = None,
171
+ *args,
172
+ **kwargs,
173
+ ):
174
+ self.num_query_experts = num_query_experts if num_query_experts is not None else num_heads
175
+ self.num_query_groups = num_query_groups if num_query_groups is not None else num_groups
176
+ super(DeepMoeAttention, self).__init__(
177
+ embed_dim,
178
+ num_heads,
179
+ num_groups=num_groups,
180
+ dropout=dropout,
181
+ rope=rope,
182
+ rope_only_for_query=rope_only_for_query,
183
+ use_relative_embeddings=use_relative_embeddings,
184
+ max_seq_len=max_seq_len,
185
+ use_flash_attention=use_flash_attention,
186
+ is_causal=is_causal,
187
+ use_bias=use_bias,
188
+ num_experts=num_experts,
189
+ *args,
190
+ **kwargs,
191
+ )
192
+
193
+ def _init_q(self, embed_dim: int):
194
+ self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
195
+
196
+ hidden_dim = embed_dim // self.num_heads
197
+ self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
198
+ self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
199
+ self._init_query_experts()
200
+
201
+ def _init_query_experts(self):
202
+ nn.init.xavier_uniform_(self.wq)
203
+ if self.use_bias:
204
+ nn.init.zeros_(self.bq)
205
+
206
+ def _init_out(self, embed_dim: int):
207
+ """Initialize output projection"""
208
+ hidden_dim = embed_dim // (self.num_heads // self.num_query_groups)
209
+ self.out_proj = nn.Linear(hidden_dim, embed_dim)
210
+
211
+ def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int, skip_query_processing: bool = False):
212
+ # Query processing
213
+ B, T, _ = query.shape
214
+ weights_q, indices_q = self.query_router(query)
215
+ q = self._process_grouped_experts(query, self.wq, self.bq, weights_q, indices_q)
216
+ q = q.permute(0, 2, 1, 3).reshape(B, self.num_query_groups, T, -1)
217
+
218
+ # Expand query groups to match head count
219
+ group_heads = self.num_heads // self.num_query_groups
220
+ q = q.unsqueeze(2).expand(-1, -1, group_heads, -1, -1).flatten(1, 2).transpose(1, 2)
221
+
222
+ # Key/Value processing
223
+ return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
224
+
225
+ # Vectorized
226
+
227
+ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
228
+ """
229
+ The vectorized implementation calculates all expert heads for each token and selects the active experts afterwards. Linear layers
230
+ for attention are rather small compared to MoE Feed Forward layers, so it may be faster than filtering
231
+ experts - this has to be tested.
232
+
233
+ Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
234
+
235
+ Instead of mapping keys/values to static head groups, it dynamically selects head expert groups. It has the same
236
+ number of total key/value heads as query heads, but uses only a selected group for attention calculation.
237
+ - with num_groups set to 1, it will be MoE MultiQueryAttention
238
+
239
+ Compared to traditional GQA/MQA, it should provide better performance, because far less information is lost using
240
+ this approach - we are training the full number of key/value heads, while using only a group.
241
+
242
+ In terms of efficiency, it should be close to GQA/MQA linear performance, but with a small MoE routing overhead.
243
+
244
+ Optionally, it could use even more expert heads than attention heads - for example:
245
+ - 512 dim divided into 16 heads with 32 dim, using 4 head groups - may use e.g. 24 total expert heads - still only
246
+ 4 will be used for attention calculation, while 16 is used to split dimensions (in that case it will have 16 query heads)
247
+
248
+ © 2025 Adam Filipek
249
+ """
250
+
251
+ def __init__(
252
+ self,
253
+ embed_dim: int,
254
+ num_heads: int,
255
+ num_groups: int,
256
+ dropout: float = 0.0,
257
+ rope: RotaryPositionalEmbedding = None,
258
+ rope_only_for_query: bool = False,
259
+ use_relative_embeddings: bool = False,
260
+ max_seq_len: int = 1024,
261
+ use_flash_attention: bool = False,
262
+ is_causal: bool = False,
263
+ use_bias: bool = False,
264
+ num_experts: int = None,
265
+ *args,
266
+ **kwargs,
267
+ ):
268
+ self.num_experts = num_experts if num_experts is not None else num_heads
269
+ super(GroupedMoeAttentionVectorized, self).__init__(
270
+ embed_dim,
271
+ num_heads,
272
+ num_groups=num_groups,
273
+ dropout=dropout,
274
+ rope=rope,
275
+ rope_only_for_query=rope_only_for_query,
276
+ use_relative_embeddings=use_relative_embeddings,
277
+ max_seq_len=max_seq_len,
278
+ use_flash_attention=use_flash_attention,
279
+ is_causal=is_causal,
280
+ use_bias=use_bias,
281
+ *args,
282
+ **kwargs,
283
+ )
284
+
285
+ def _init_kv(self, embed_dim: int):
286
+ self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
287
+ hidden_dim = embed_dim // self.num_heads
62
288
  self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
63
289
  self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
64
290
  self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
@@ -72,47 +298,37 @@ class GroupedMoeAttention(GroupedQueryAttention):
72
298
  torch.nn.init.zeros_(self.bk)
73
299
  torch.nn.init.zeros_(self.bv)
74
300
 
75
- def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int, skip_query_processing: bool = False):
301
+ def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
302
+ skip_query_processing: bool = False):
303
+ # Indexed version may cause memory overflow
304
+ #
76
305
  # head_dim = d // self.num_heads
77
- # group_heads = self.num_heads // self.num_groups
78
306
  #
79
307
  # # Process Query as in GQA
80
- # q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2)
308
+ # q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1,
309
+ # 2) if not skip_query_processing else query
81
310
  #
82
311
  # # Process Key and Value with MoE routing
83
- # key_flat = key.view(-1, d)
84
- # weights, indices = self.router(key_flat)
85
- # weights = weights.view(b, key.size(1), self.num_groups, 1)
86
- # indices = indices.view(b, key.size(1), self.num_groups)
312
+ # key_flat = key.view(-1, d) # (B*S, d)
313
+ # value_flat = value.view(-1, d) # (B*S, d)
87
314
  #
88
- # # Compute all experts' K and V projections
89
- # # Shape: (batch_size, seq_len, num_experts, head_dim * num_groups)
90
- # k_all = torch.einsum(
91
- # 'be, ehd -> bedh',
92
- # key_flat,
93
- # self.wk.view(self.num_experts, d, -1)
94
- # ).view(b, key.size(1), self.num_experts, -1)
315
+ # # Get routing indices and weights for K
316
+ # weights_k, indices_k = self.router(key_flat)
317
+ # indices_k = indices_k.view(-1, self.top_k) # (B*S, top_k)
318
+ # weights_k = weights_k.view(-1, self.top_k, 1) # (B*S, top_k, 1)
95
319
  #
96
- # v_all = torch.einsum(
97
- # 'be, ehd -> bedh',
98
- # value.view(-1, d),
99
- # self.wv.view(self.num_experts, d, -1)
100
- # ).view(b, value.size(1), self.num_experts, -1)
320
+ # # Select and compute K projections for only the top_k experts
321
+ # selected_k_weights = self.k_experts[indices_k] # (B*S, top_k, d, k_out_dim)
322
+ # k_proj = torch.einsum('bd, behd -> beh', key_flat.unsqueeze(1), selected_k_weights)
323
+ # selected_k = (k_proj * weights_k).sum(dim=1) # (B*S, k_out_dim)
324
+ # selected_k = selected_k.view(b, key.size(1), -1) # (B, S, k_out_dim)
101
325
  #
102
- # # Select top_k experts and compute weighted sum
103
- # selected_k = torch.gather(
104
- # k_all,
105
- # 2,
106
- # indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
107
- # )
108
- # selected_v = torch.gather(
109
- # v_all,
110
- # 2,
111
- # indices.unsqueeze(-1).expand(-1, -1, -1, v_all.size(-1))
112
- # )
326
+ # # Compute V using the same indices as K (since they share the same router)
327
+ # selected_v_weights = self.v_experts[indices_k]
328
+ # v_proj = torch.einsum('bd, behd -> beh', value_flat.unsqueeze(1), selected_v_weights)
329
+ # selected_v = (v_proj * weights_k).sum(dim=1)
330
+ # selected_v = selected_v.view(b, value.size(1), -1) # (B, S, k_out_dim)
113
331
  #
114
- # selected_k = (selected_k * weights).sum(dim=2)
115
- # selected_v = (selected_v * weights).sum(dim=2)
116
332
  # # Reshape to GQA format: (B, G, S, head_dim)
117
333
  # k = selected_k.view(b, key.size(1), self.num_groups, head_dim).transpose(1, 2)
118
334
  # v = selected_v.view(b, value.size(1), self.num_groups, head_dim).transpose(1, 2)
@@ -127,32 +343,46 @@ class GroupedMoeAttention(GroupedQueryAttention):
127
343
  # v = v.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
128
344
  #
129
345
  # return q, k, v
346
+
130
347
  head_dim = d // self.num_heads
131
348
 
132
349
  # Process Query as in GQA
133
- q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2) if not skip_query_processing else query
350
+ q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2)
134
351
 
135
352
  # Process Key and Value with MoE routing
136
- key_flat = key.view(-1, d) # (B*S, d)
137
- value_flat = value.view(-1, d) # (B*S, d)
138
-
139
- # Get routing indices and weights for K
140
- weights_k, indices_k = self.router(key_flat)
141
- indices_k = indices_k.view(-1, self.top_k) # (B*S, top_k)
142
- weights_k = weights_k.view(-1, self.top_k, 1) # (B*S, top_k, 1)
143
-
144
- # Select and compute K projections for only the top_k experts
145
- selected_k_weights = self.k_experts[indices_k] # (B*S, top_k, d, k_out_dim)
146
- k_proj = torch.einsum('bd, behd -> beh', key_flat.unsqueeze(1), selected_k_weights)
147
- selected_k = (k_proj * weights_k).sum(dim=1) # (B*S, k_out_dim)
148
- selected_k = selected_k.view(b, key.size(1), -1) # (B, S, k_out_dim)
149
-
150
- # Compute V using the same indices as K (since they share the same router)
151
- selected_v_weights = self.v_experts[indices_k]
152
- v_proj = torch.einsum('bd, behd -> beh', value_flat.unsqueeze(1), selected_v_weights)
153
- selected_v = (v_proj * weights_k).sum(dim=1)
154
- selected_v = selected_v.view(b, value.size(1), -1) # (B, S, k_out_dim)
353
+ key_flat = key.view(-1, d)
354
+ weights, indices = self.router(key_flat)
355
+ weights = weights.view(b, key.size(1), self.num_groups, 1)
356
+ indices = indices.view(b, key.size(1), self.num_groups)
357
+
358
+ # Compute all experts' K and V projections
359
+ # Shape: (batch_size, seq_len, num_experts, head_dim * num_groups)
360
+ k_all = torch.einsum(
361
+ 'be, ehd -> bedh',
362
+ key_flat,
363
+ self.wk.view(self.num_experts, d, -1)
364
+ ).view(b, key.size(1), self.num_experts, -1)
365
+
366
+ v_all = torch.einsum(
367
+ 'be, ehd -> bedh',
368
+ value.view(-1, d),
369
+ self.wv.view(self.num_experts, d, -1)
370
+ ).view(b, value.size(1), self.num_experts, -1)
371
+
372
+ # Select top_k experts and compute weighted sum
373
+ selected_k = torch.gather(
374
+ k_all,
375
+ 2,
376
+ indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
377
+ )
378
+ selected_v = torch.gather(
379
+ v_all,
380
+ 2,
381
+ indices.unsqueeze(-1).expand(-1, -1, -1, v_all.size(-1))
382
+ )
155
383
 
384
+ selected_k = (selected_k * weights).sum(dim=2)
385
+ selected_v = (selected_v * weights).sum(dim=2)
156
386
  # Reshape to GQA format: (B, G, S, head_dim)
157
387
  k = selected_k.view(b, key.size(1), self.num_groups, head_dim).transpose(1, 2)
158
388
  v = selected_v.view(b, value.size(1), self.num_groups, head_dim).transpose(1, 2)
@@ -168,15 +398,26 @@ class GroupedMoeAttention(GroupedQueryAttention):
168
398
 
169
399
  return q, k, v
170
400
 
171
- class SparseMoeAttention(GroupedMoeAttention):
401
+
402
+ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
172
403
  """
173
- Sparse MoE Attention (SMA) - Grouped MoE Attention extended even more for sublinear computational efficiency.
404
+ The vectorized implementation calculates all expert heads for each token and selects the active experts afterwards. Linear layers
405
+ for attention are rather small compared to MoE Feed Forward layers, so it may be faster than filtering
406
+ experts - this has to be tested.
407
+
408
+ Deep MoE Attention (DMA) - Grouped MoE Attention extended further for sublinear computational efficiency.
409
+
174
410
  In addition to using Mixture-of-Experts (MoE) for key/value head groups, SMA is also using dynamically selected
175
411
  query heads - with that approach, each token could attend to every other token, but only partially - only some part of
176
- information from each token is used to identify related information parts from other tokens.
412
+ information from each token is used to identify related information parts from other tokens. So, DMA is not spatially
413
+ sparse (it has access to all tokens), but structurally sparse (it has access only to part of each token's information).
414
+
415
+ This solution could reduce the computational complexity of the attention operation to a sublinear level (< O(N)) and provide
416
+ a viable and efficient alternative to spatially sparse attention mechanisms like Flex Attention.
177
417
 
178
- This solution could reduce the computational complexity of attention operation to sublinear level (<O(N))
418
+ © 2025 Adam Filipek
179
419
  """
420
+
180
421
  def __init__(
181
422
  self,
182
423
  embed_dim: int,
@@ -192,13 +433,13 @@ class SparseMoeAttention(GroupedMoeAttention):
192
433
  use_bias: bool = False,
193
434
  num_experts: int = None,
194
435
  num_query_experts: int = None,
195
- num_active_query_heads: int = None,
436
+ num_query_groups: int = None,
196
437
  *args,
197
438
  **kwargs,
198
439
  ):
199
440
  self.num_query_experts = num_query_experts if num_query_experts is not None else num_heads
200
- self.num_active_query_heads = num_active_query_heads if num_active_query_heads is not None else num_groups
201
- super(SparseMoeAttention, self).__init__(
441
+ self.num_query_groups = num_query_groups if num_query_groups is not None else num_groups
442
+ super(DeepMoeAttentionVectorized, self).__init__(
202
443
  embed_dim,
203
444
  num_heads,
204
445
  num_groups=num_groups,
@@ -216,8 +457,8 @@ class SparseMoeAttention(GroupedMoeAttention):
216
457
  )
217
458
 
218
459
  def _init_q(self, embed_dim: int):
219
- self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_active_query_heads)
220
- hidden_dim = embed_dim // (self.num_heads // self.num_groups)
460
+ self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
461
+ hidden_dim = embed_dim // self.num_heads
221
462
  self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
222
463
  self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
223
464
  self._init_query_experts()
@@ -227,20 +468,47 @@ class SparseMoeAttention(GroupedMoeAttention):
227
468
  if self.use_bias:
228
469
  torch.nn.init.zeros_(self.bq)
229
470
 
471
+ def _init_out(self, embed_dim: int):
472
+ """Initialize output projection"""
473
+ self.out_proj = nn.Linear(embed_dim // (self.num_heads // self.num_groups), embed_dim)
474
+
230
475
  def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
476
+ # Indexed version may cause memory overflow
477
+ #
478
+ # head_dim = d // self.num_heads
479
+ #
480
+ # # Process Query with MoE routing
481
+ # query_flat = query.view(-1, d) # (B*T, d)
482
+ # weights_q, indices_q = self.query_router(query_flat)
483
+ # indices_q = indices_q.view(-1, self.num_query_groups) # (B*T, top_k_q)
484
+ # weights_q = weights_q.view(-1, self.num_query_groups, 1) # (B*T, top_k_q, 1)
485
+ #
486
+ # # Select and compute Q projections for top_k experts
487
+ # selected_q_weights = self.wq[indices_q] # (B*T, top_k_q, d, head_dim*num_heads)
488
+ # q_proj = torch.einsum('bd, behd -> beh', query_flat.unsqueeze(1), selected_q_weights)
489
+ # selected_q = (q_proj * weights_q).sum(dim=1) # (B*T, head_dim*num_heads)
490
+ # selected_q = selected_q.view(b, t, -1) # (B, T, head_dim*num_heads)
231
491
  head_dim = d // self.num_heads
232
492
 
233
493
  # Process Query with MoE routing
234
- query_flat = query.view(-1, d) # (B*T, d)
235
- weights_q, indices_q = self.router_q(query_flat)
236
- indices_q = indices_q.view(-1, self.top_k_q) # (B*T, top_k_q)
237
- weights_q = weights_q.view(-1, self.top_k_q, 1) # (B*T, top_k_q, 1)
238
-
239
- # Select and compute Q projections for top_k experts
240
- selected_q_weights = self.q_experts[indices_q] # (B*T, top_k_q, d, head_dim*num_heads)
241
- q_proj = torch.einsum('bd, behd -> beh', query_flat.unsqueeze(1), selected_q_weights)
242
- selected_q = (q_proj * weights_q).sum(dim=1) # (B*T, head_dim*num_heads)
243
- selected_q = selected_q.view(b, t, -1) # (B, T, head_dim*num_heads)
494
+ query_flat = query.view(b * t, d)
495
+ weights_q, indices_q = self.query_router(query_flat)
496
+ weights_q = weights_q.view(b, t, self.num_query_groups, 1)
497
+ indices_q = indices_q.view(b, t, self.num_query_groups)
498
+
499
+ # Compute all experts' Q projections
500
+ q_all = torch.einsum(
501
+ 'be, ehd -> bedh',
502
+ query_flat,
503
+ self.wq.view(self.num_query_experts, d, -1)
504
+ ).view(b, t, self.num_query_experts, -1)
505
+
506
+ selected_q = torch.gather(
507
+ q_all,
508
+ 2,
509
+ indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.shape[-1])
510
+ )
511
+ selected_q = (selected_q * weights_q).sum(dim=2)
244
512
 
245
513
  q = selected_q.view(b, t, self.num_heads, head_dim).transpose(1, 2) # (B, H, T, head_dim)
246
514
 
@@ -251,12 +519,12 @@ class SparseMoeAttention(GroupedMoeAttention):
251
519
 
252
520
  class FlexAttention(MultiHeadAttention):
253
521
  def __init__(
254
- self,
255
- embed_dim: int,
256
- num_heads: int,
257
- num_global_tokens: int = 16,
258
- window_size: int = 128,
259
- **kwargs
522
+ self,
523
+ embed_dim: int,
524
+ num_heads: int,
525
+ num_global_tokens: int = 16,
526
+ window_size: int = 128,
527
+ **kwargs
260
528
  ):
261
529
  super().__init__(embed_dim, num_heads, **kwargs)
262
530
  self.num_global_tokens = num_global_tokens
@@ -319,14 +587,15 @@ class FlexAttention(MultiHeadAttention):
319
587
  output = self._calculate_output(combined_attn, v, b, t, d)
320
588
  return self.out_proj(output)
321
589
 
590
+
322
591
  class InfiniteAttention(MultiHeadAttention):
323
592
  def __init__(
324
- self,
325
- embed_dim: int,
326
- num_heads: int,
327
- kernel_size: int = 128,
328
- use_rotary: bool = True,
329
- **kwargs
593
+ self,
594
+ embed_dim: int,
595
+ num_heads: int,
596
+ kernel_size: int = 128,
597
+ use_rotary: bool = True,
598
+ **kwargs
330
599
  ):
331
600
  super().__init__(embed_dim, num_heads, **kwargs)
332
601
  self.kernel_size = kernel_size
@@ -377,4 +646,89 @@ class InfiniteAttention(MultiHeadAttention):
377
646
  q = q / (q.shape[-1] ** 0.5)
378
647
  attn = torch.einsum('b h i d, b h j d -> b h i j', q, k)
379
648
  attn = torch.softmax(attn, dim=-1)
380
- return torch.einsum('b h i j, b h j d -> b h i d', attn, v)
649
+ return torch.einsum('b h i j, b h j d -> b h i d', attn, v)
650
+
651
+ def init_moe_attention(
652
+ embed_dim: int,
653
+ num_heads: int,
654
+ attention_type: str,
655
+ gqa_groups: int = 1,
656
+ dropout: float = 0.0,
657
+ rope: RotaryPositionalEmbedding = None,
658
+ rope_only_for_query: bool = False,
659
+ use_relative_embeddings: bool = False,
660
+ max_seq_len: int = 1024,
661
+ use_flash_attention: bool = False,
662
+ is_causal: bool = False,
663
+ use_bias: bool = False,
664
+ num_experts: int = None,
665
+ num_query_experts: int = None,
666
+ num_query_groups: int = None,
667
+ ) -> GroupedQueryAttention:
668
+ assert attention_type == 'gma' or attention_type == 'dma' or attention_type == 'gma_v' or attention_type == 'dma_v', \
669
+ "Error, attention type should be one of: 'gma', 'dma', 'gma_v', 'dma_v'"
670
+
671
+ if attention_type == "gma":
672
+ return GroupedMoeAttention(
673
+ embed_dim,
674
+ num_heads,
675
+ gqa_groups,
676
+ dropout=dropout,
677
+ rope=rope,
678
+ use_relative_embeddings=use_relative_embeddings,
679
+ max_seq_len=max_seq_len,
680
+ rope_only_for_query=rope_only_for_query,
681
+ use_flash_attention=use_flash_attention,
682
+ is_causal=is_causal,
683
+ use_bias=use_bias,
684
+ num_experts=num_experts,
685
+ )
686
+ elif attention_type == "dma":
687
+ return DeepMoeAttention(
688
+ embed_dim,
689
+ num_heads,
690
+ gqa_groups,
691
+ dropout=dropout,
692
+ rope=rope,
693
+ use_relative_embeddings=use_relative_embeddings,
694
+ max_seq_len=max_seq_len,
695
+ rope_only_for_query=rope_only_for_query,
696
+ use_flash_attention=use_flash_attention,
697
+ is_causal=is_causal,
698
+ use_bias=use_bias,
699
+ num_experts=num_experts,
700
+ num_query_experts=num_query_experts,
701
+ num_query_groups=num_query_groups,
702
+ )
703
+ elif attention_type == "gma_v":
704
+ return GroupedMoeAttentionVectorized(
705
+ embed_dim,
706
+ num_heads,
707
+ gqa_groups,
708
+ dropout=dropout,
709
+ rope=rope,
710
+ use_relative_embeddings=use_relative_embeddings,
711
+ max_seq_len=max_seq_len,
712
+ rope_only_for_query=rope_only_for_query,
713
+ use_flash_attention=use_flash_attention,
714
+ is_causal=is_causal,
715
+ use_bias=use_bias,
716
+ num_experts=num_experts,
717
+ )
718
+ else:
719
+ return DeepMoeAttentionVectorized(
720
+ embed_dim,
721
+ num_heads,
722
+ gqa_groups,
723
+ dropout=dropout,
724
+ rope=rope,
725
+ use_relative_embeddings=use_relative_embeddings,
726
+ max_seq_len=max_seq_len,
727
+ rope_only_for_query=rope_only_for_query,
728
+ use_flash_attention=use_flash_attention,
729
+ is_causal=is_causal,
730
+ use_bias=use_bias,
731
+ num_experts=num_experts,
732
+ num_query_experts=num_query_experts,
733
+ num_query_groups=num_query_groups,
734
+ )
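
For orientation, a minimal construction sketch for the new init_moe_attention factory (import paths follow the package layout in this diff; hyperparameter values are illustrative, and the forward signature of the returned module comes from GroupedQueryAttention in rxnn.transformers.attention, which is not shown here):

import torch
from rxnn.transformers.positional import RotaryPositionalEmbedding
from rxnn.experimental.attention import init_moe_attention

embed_dim, num_heads, seq_len = 512, 16, 256
rope = RotaryPositionalEmbedding(embed_dim // num_heads, seq_len)

# 'dma' selects DeepMoeAttention: MoE routing for both query and key/value head groups
attn = init_moe_attention(
    embed_dim,
    num_heads,
    'dma',
    gqa_groups=4,             # 4 active key/value head groups
    rope=rope,
    is_causal=True,
    num_experts=24,           # more expert heads than attention heads, as in the docstring example
    num_query_experts=24,
    num_query_groups=8,
)

print(sum(p.numel() for p in attn.parameters()))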
@@ -0,0 +1,117 @@
1
+ import torch
2
+ from torch import nn
3
+ from typing import TypedDict, Union
4
+ from huggingface_hub import PyTorchModelHubMixin
5
+ from ..transformers.positional import RotaryPositionalEmbedding
6
+ from ..transformers.attention import init_attention
7
+ from ..transformers.layers import ClassicTransformerLayer
8
+ from ..transformers.models import ClassicTransformerDecoder
9
+ from ..transformers.ff import get_activation_layer
10
+ from ..memory.stm import ShortTermMemory
11
+ from ..utils import get_model_size
12
+ from .attention import init_moe_attention
13
+
14
+
15
+ class MoeAttentionTransformerConfig(TypedDict):
16
+ num_layers: int
17
+ vocab_size: int
18
+ embed_dim: int
19
+ ff_dim: int
20
+ att_heads: int
21
+ seq_len: int
22
+ use_flash_attention: bool
23
+ use_gated: bool
24
+ ff_activation: str
25
+ ff_dropout: float
26
+ att_dropout: float
27
+ use_rms_norm: bool
28
+ att_groups: int
29
+ use_moe_ff: bool
30
+ ff_num_experts: int
31
+ ff_moe_top_k: int
32
+ att_type: str
33
+ att_num_experts: int
34
+ att_num_query_experts: int
35
+ att_num_query_groups: int
36
+
37
+
38
+ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="text-generation", license="apache-2.0"):
39
+ """Research model for experiments with Mixture-of-Experts Attention"""
40
+
41
+ def __init__(
42
+ self,
43
+ num_layers: int = 6,
44
+ vocab_size: int = 5000,
45
+ embed_dim: int = 128,
46
+ ff_dim: int = 384,
47
+ att_heads: int = 16,
48
+ seq_len: int = 256,
49
+ use_flash_attention: bool = True,
50
+ use_gated: bool = True,
51
+ ff_activation: str = "swish",
52
+ ff_dropout: float = 0.0,
53
+ att_dropout: float = 0.0,
54
+ use_rms_norm: bool = True,
55
+ att_groups: int = 1,
56
+ use_moe_ff: bool = False,
57
+ ff_num_experts: int = 1,
58
+ ff_moe_top_k: int = 1,
59
+ att_type: str = 'gma',
60
+ att_num_experts: int = None,
61
+ att_num_query_experts: int = None,
62
+ att_num_query_groups: int = None,
63
+ **kwargs
64
+ ):
65
+ super(MoeAttentionTransformer, self).__init__(**kwargs)
66
+ assert ff_activation in ['relu', 'gelu',
67
+ 'swish', 'silu', 'linear',
68
+ 'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
69
+ assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'gma_v',
70
+ 'dma_v'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "gma_v", "dma_v"'
71
+
72
+ embedding = nn.Embedding(vocab_size, embed_dim)
73
+ rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
74
+
75
+ ff_activation = get_activation_layer(ff_activation)
76
+
77
+ if att_type in ['mha', 'gqa', 'mqa']:
78
+ att_init = lambda: init_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
79
+ use_flash_attention=use_flash_attention, dropout=att_dropout,
80
+ max_seq_len=seq_len, is_causal=True)
81
+ else:
82
+ att_init = lambda: init_moe_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
83
+ use_flash_attention=use_flash_attention, dropout=att_dropout,
84
+ max_seq_len=seq_len, is_causal=True, num_experts=att_num_experts,
85
+ num_query_experts=att_num_query_experts,
86
+ num_query_groups=att_num_query_groups)
87
+
88
+ self.model = ClassicTransformerDecoder(
89
+ embed_dim,
90
+ vocab_size,
91
+ embedding=embedding,
92
+ layers=nn.ModuleList([
93
+ ClassicTransformerLayer(
94
+ embed_dim,
95
+ ff_dim,
96
+ use_gated=use_gated,
97
+ use_moe=use_moe_ff,
98
+ num_experts=ff_num_experts,
99
+ moe_top_k=ff_moe_top_k,
100
+ ff_activation=ff_activation,
101
+ ff_dropout=ff_dropout,
102
+ use_rms_norm=use_rms_norm,
103
+ self_attention=att_init(),
104
+ ) for _ in range(num_layers)
105
+ ]),
106
+ use_flash_attention=use_flash_attention,
107
+ )
108
+
109
+ def params_count(self):
110
+ return get_model_size(self.model)
111
+
112
+ def load_shared_embedding(self, embedding: nn.Embedding):
113
+ self.model.embedding = embedding
114
+
115
+ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> Union[
116
+ torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
117
+ return self.model(x, attention_mask=attention_mask)
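
A hypothetical smoke test for the new experimental model (hyperparameters are illustrative and deliberately small; the forward signature is the one defined just above):

import torch
from rxnn.experimental.models import MoeAttentionTransformer

model = MoeAttentionTransformer(
    num_layers=6,
    vocab_size=5000,
    embed_dim=128,
    ff_dim=384,
    att_heads=16,
    seq_len=256,
    use_flash_attention=False,
    att_type='gma',           # Grouped MoE Attention
    att_groups=4,
    att_num_experts=16,
)

tokens = torch.randint(0, 5000, (2, 256))
logits = model(tokens)        # attention_mask is optional; expected output shape (2, 256, 5000)
print(model.params_count())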
@@ -0,0 +1,206 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from ..transformers.moe import MoeRouter
5
+
6
+ class DynamicMoeRouter(nn.Module):
7
+ """Dynamic Mixture-of-Experts Router layer - dynamically selects top-k experts for each token."""
8
+
9
+ def __init__(self, embed_dim: int, num_experts: int, top_ks: tuple[int] = (1, 2, 3), *args, **kwargs):
10
+ super(DynamicMoeRouter, self).__init__(*args, **kwargs)
11
+ self.top_ks = top_ks
12
+ self.num_options = len(top_ks)
13
+ self.num_experts = num_experts
14
+ self.gate = nn.Linear(embed_dim, num_experts + self.num_options, bias=False)
15
+ # For expert load balancing
16
+ self.register_buffer('aux_loss', torch.tensor(0.0), persistent=False)
17
+
18
+ def calculate_aux_loss(self, top_k_indices: torch.Tensor, routing_probs: torch.Tensor) -> torch.Tensor:
19
+ expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
20
+ expert_usage = expert_mask.sum(dim=0).mean(dim=0)
21
+ mean_probs = routing_probs.mean(dim=0)
22
+ return (expert_usage * mean_probs).sum() * self.num_experts
23
+
24
+ def forward(self, x: torch.Tensor):
25
+ # Input shape: [batch*seq_len, embed_dim]
26
+ all_logits = self.gate(x)
27
+ routing_logits = all_logits[:, :-self.num_options]
28
+ options_logits = all_logits[:, -self.num_options:]
29
+
30
+ routing_probs = F.softmax(routing_logits, dim=-1)
31
+ top_k_id = torch.argmax(options_logits, dim=-1).item()
32
+
33
+ top_k = self.top_ks[top_k_id]
34
+
35
+ # Get top-k experts for each token
36
+ top_k_weights, top_k_indices = routing_probs.topk(top_k, dim=-1)
37
+
38
+ # Normalize weights (sum to 1 for each token)
39
+ top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)
40
+
41
+ # Load Balance Loss
42
+ self.aux_loss = self.calculate_aux_loss(top_k_indices, routing_probs)
43
+
44
+ return top_k_weights, top_k_indices, top_k
45
+
46
+ class MoeFeedForwardVectorized(nn.Module):
47
+ """
48
+ Vectorized MoE - the current implementation is not truly sparse: it calculates all the experts, then selects the correct ones.
49
+
50
+ The commented-out implementation fixes this problem, but causes memory overflows because of expert weights
51
+ indexing - it uses ~15x more memory than a dense model of similar size, so it's currently not viable.
52
+
53
+ It's recommended to use standard MoE from rxnn.transformers.moe instead.
54
+ """
55
+
56
+ def __init__(
57
+ self,
58
+ embed_dim: int,
59
+ hidden_dim: int,
60
+ num_experts: int,
61
+ activation: nn.Module,
62
+ top_k: int = 1,
63
+ dropout: float = 0.0,
64
+ *args,
65
+ **kwargs
66
+ ):
67
+ super(MoeFeedForwardVectorized, self).__init__(*args, **kwargs)
68
+ self.embed_dim = embed_dim
69
+ self.num_experts = num_experts
70
+ self.top_k = top_k
71
+
72
+ self.router = MoeRouter(embed_dim, num_experts, top_k)
73
+
74
+ # Batch all expert parameters together
75
+ self.w1 = nn.Parameter(torch.empty(num_experts, embed_dim, self._w1_dim_factor(hidden_dim)))
76
+ self.b1 = nn.Parameter(torch.zeros(num_experts, self._w1_dim_factor(hidden_dim)))
77
+ self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, embed_dim))
78
+ self.b2 = nn.Parameter(torch.zeros(num_experts, embed_dim))
79
+ self.activation = activation
80
+ self.dropout = nn.Dropout(dropout)
81
+
82
+ # Initialize parameters
83
+ self._init_linear_parameters()
84
+ nn.init.zeros_(self.b1)
85
+ nn.init.zeros_(self.b2)
86
+
87
+ def _init_linear_parameters(self):
88
+ nn.init.kaiming_normal_(self.w1, nonlinearity='relu')
89
+ nn.init.kaiming_normal_(self.w2, nonlinearity='relu')
90
+
91
+ def _w1_dim_factor(self, hidden_dim: int) -> int:
92
+ return hidden_dim
93
+
94
+ def _activate(self, h: torch.Tensor):
95
+ return self.activation(h)
96
+
97
+ def router_loss(self):
98
+ return self.router.aux_loss
99
+
100
+ def forward(self, x: torch.Tensor):
101
+ orig_shape = x.shape
102
+ x = x.view(-1, self.embed_dim) # [batch*seq_len, embed_dim]
103
+
104
+ # Get routing weights and indices
105
+ weights, indices = self.router(x) # [batch*seq_len, top_k]
106
+
107
+ # Create expert mask and combine it with routing weights
108
+ mask = F.one_hot(indices, self.num_experts).float() # [batch*seq_len, top_k, num_experts]
109
+ weights = (weights.unsqueeze(-1) * mask).sum(dim=1) # [batch*seq_len, num_experts]
110
+
111
+ # Expert computation
112
+ x = x.unsqueeze(1).expand(-1, self.num_experts, -1) # [batch*seq_len, num_experts, embed_dim]
113
+
114
+ # First linear layer
115
+ h = torch.einsum('bie,ieh->bih', x, self.w1) + self.b1 # [batch*seq_len, num_experts, hidden_dim]
116
+ h = self._activate(h)
117
+ h = self.dropout(h)
118
+
119
+ # Second linear layer (projection back to embed_dim)
120
+ out = torch.einsum('bih,ihe->bie', h, self.w2) + self.b2 # [batch*seq_len, num_experts, embed_dim]
121
+
122
+ # Weighted sum of expert outputs
123
+ out = (out * weights.unsqueeze(-1)).sum(dim=1) # [batch*seq_len, embed_dim]
124
+
125
+ return out.view(*orig_shape)
126
+ # orig_shape = x.shape
127
+ # x = x.view(-1, self.embed_dim) # [batch*seq_len, embed_dim]
128
+ #
129
+ # # Get routing weights and indices
130
+ # weights, indices = self.router(x) # [B*T, top_k], [B*T, top_k]
131
+ #
132
+ # # Flatten indices and weights
133
+ # batch_size = x.shape[0]
134
+ # top_k = indices.shape[1]
135
+ # indices_flat = indices.view(-1) # [B*T * top_k]
136
+ #
137
+ # # Compute contributions for selected experts without materializing large tensors
138
+ # # First Layer:
139
+ # # Compute all expert contributions first (but this may still be memory-heavy)
140
+ # # Alternative: Compute contributions for selected experts directly
141
+ # # ... (see detailed steps below)
142
+ #
143
+ # # Alternative approach using gather and batched operations
144
+ # x_expanded = x.unsqueeze(1).repeat(1, top_k, 1).view(-1, self.embed_dim) # [B*T*top_k, D]
145
+ #
146
+ # # Compute first layer contributions using gather
147
+ # # indices_flat has shape [B*T*top_k]
148
+ # # selected_w1 is self.w1[indices_flat], but we compute the product inline
149
+ # h = torch.einsum(
150
+ # 'be, eih -> bh',
151
+ # x_expanded,
152
+ # self.w1[indices_flat]
153
+ # ) + self.b1[indices_flat]
154
+ # h = self._activate(h)
155
+ # h = self.dropout(h)
156
+ #
157
+ # # Second layer:
158
+ # out = torch.einsum(
159
+ # 'bh, eho -> beo',
160
+ # h,
161
+ # self.w2[indices_flat]
162
+ # ).squeeze(-1) + self.b2[indices_flat]
163
+ #
164
+ # # Reshape and apply weights
165
+ # out = out.view(batch_size, top_k, -1)
166
+ # weights = weights.view(batch_size, top_k, 1)
167
+ # out = (out * weights).sum(dim=1)
168
+ #
169
+ # return out.view(*orig_shape)
170
+
171
+
172
+ class GatedMoeFeedForwardVectorized(MoeFeedForwardVectorized):
173
+ """Gated Mixture-of-Experts Feed-Forward layer - enable GLU-based activations for MoE"""
174
+
175
+ def __init__(
176
+ self,
177
+ embed_dim: int,
178
+ hidden_dim: int,
179
+ num_experts: int,
180
+ activation: nn.Module = nn.SiLU(),
181
+ top_k: int = 1,
182
+ dropout: float = 0.1,
183
+ *args,
184
+ **kwargs
185
+ ):
186
+ super(GatedMoeFeedForwardVectorized, self).__init__(
187
+ embed_dim=embed_dim,
188
+ hidden_dim=hidden_dim,
189
+ num_experts=num_experts,
190
+ activation=activation,
191
+ top_k=top_k,
192
+ dropout=dropout,
193
+ *args,
194
+ **kwargs
195
+ )
196
+
197
+ def _init_linear_parameters(self):
198
+ nn.init.kaiming_normal_(self.w1, nonlinearity='relu')
199
+ nn.init.kaiming_normal_(self.w2, nonlinearity='linear')
200
+
201
+ def _w1_dim_factor(self, hidden_dim: int) -> int:
202
+ return 2 * hidden_dim
203
+
204
+ def _activate(self, h: torch.Tensor):
205
+ a, b = h.chunk(2, dim=-1)
206
+ return a * self.activation(b)
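
A rough usage sketch for DynamicMoeRouter (shapes follow the comments in the code above; note that forward() resolves the top-k option with .item(), so this sketch makes a single routing decision per call):

import torch
from rxnn.experimental.moe import DynamicMoeRouter

router = DynamicMoeRouter(embed_dim=128, num_experts=8, top_ks=(1, 2, 3))

x = torch.randn(1, 128)              # [batch*seq_len, embed_dim]
weights, indices, top_k = router(x)

print(top_k)                         # one of 1, 2 or 3, chosen from the extra gate logits
print(weights.shape, indices.shape)  # [1, top_k] normalized weights and selected expert ids
print(router.aux_loss)               # load-balancing auxiliary loss, refreshed on every forward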
rxnn/transformers/moe.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
3
  import torch.nn.functional as F
4
-
4
+ from .ff import FeedForward, GatedFeedForward
5
5
 
6
6
  class MoeRouter(nn.Module):
7
7
  """Mixture-of-Experts Router layer - computes routing weights for each expert."""
@@ -14,18 +14,27 @@ class MoeRouter(nn.Module):
14
14
  # For expert load balancing
15
15
  self.register_buffer('aux_loss', torch.tensor(0.0), persistent=False)
16
16
 
17
+ def calculate_aux_loss(self, top_k_indices: torch.Tensor, probs: torch.Tensor) -> torch.Tensor:
18
+ expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
19
+ expert_usage = expert_mask.sum(dim=0).mean(dim=0)
20
+ mean_probs = probs.mean(dim=0)
21
+ return (expert_usage * mean_probs).sum() * self.num_experts
22
+
23
+
17
24
  def forward(self, x: torch.Tensor):
18
- # x shape: [batch_size*seq_len, embed_dim]
25
+ # Input shape: [batch*seq_len, embed_dim]
19
26
  logits = self.gate(x)
20
27
  probs = F.softmax(logits, dim=-1)
21
28
 
22
- # Expert load balancing loss
23
- mean_probs = probs.mean(dim=0) # Mean probability per expert across batch
24
- self.aux_loss = (mean_probs * torch.log(mean_probs + 1e-9)).sum() # Entropy-based loss
25
-
29
+ # Get top-k experts for each token
26
30
  top_k_weights, top_k_indices = probs.topk(self.top_k, dim=-1)
31
+
32
+ # Normalize weights (sum to 1 for each token)
27
33
  top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)
28
34
 
35
+ # Load Balance Loss
36
+ self.aux_loss = self.calculate_aux_loss(top_k_indices, probs)
37
+
29
38
  return top_k_weights, top_k_indices
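
The entropy-based penalty is replaced by the usage-weighted load-balancing loss above; a standalone numeric sketch of that formula (random values, shapes as in the comments):

import torch
import torch.nn.functional as F

num_experts, top_k, tokens = 8, 2, 64
probs = torch.softmax(torch.randn(tokens, num_experts), dim=-1)  # router probabilities
top_k_indices = probs.topk(top_k, dim=-1).indices                # selected experts per token

expert_mask = F.one_hot(top_k_indices, num_experts).float()      # [tokens, top_k, num_experts]
expert_usage = expert_mask.sum(dim=0).mean(dim=0)                # per-expert selection counts, averaged over top-k slots
mean_probs = probs.mean(dim=0)                                   # mean routing probability per expert
aux_loss = (expert_usage * mean_probs).sum() * num_experts
print(aux_loss)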
30
39
 
31
40
 
@@ -51,91 +60,43 @@ class MoeFeedForward(nn.Module):
51
60
  self.router = MoeRouter(embed_dim, num_experts, top_k)
52
61
 
53
62
  # Batch all expert parameters together
54
- self.w1 = nn.Parameter(torch.empty(num_experts, embed_dim, self._w1_dim_factor(hidden_dim)))
55
- self.b1 = nn.Parameter(torch.zeros(num_experts, self._w1_dim_factor(hidden_dim)))
56
- self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, embed_dim))
57
- self.b2 = nn.Parameter(torch.zeros(num_experts, embed_dim))
58
- self.activation = activation
59
- self.dropout = nn.Dropout(dropout)
60
-
61
- # Initialize parameters
62
- self._init_linear_parameters()
63
- nn.init.zeros_(self.b1)
64
- nn.init.zeros_(self.b2)
63
+ self._init_experts(num_experts, embed_dim, hidden_dim, activation, dropout)
65
64
 
66
- def _init_linear_parameters(self):
67
- nn.init.kaiming_normal_(self.w1, nonlinearity='relu')
68
- nn.init.kaiming_normal_(self.w2, nonlinearity='relu')
69
-
70
- def _w1_dim_factor(self, hidden_dim: int) -> int:
71
- return hidden_dim
72
-
73
- def _activate(self, h: torch.Tensor):
74
- return self.activation(h)
65
+ def _init_experts(self, num_experts: int, embed_dim: int, hidden_dim: int, activation: nn.Module, dropout: float):
66
+ self.experts = nn.ModuleList([
67
+ FeedForward(embed_dim, hidden_dim, activation, dropout)
68
+ for _ in range(num_experts)
69
+ ])
75
70
 
76
71
  def router_loss(self):
77
72
  return self.router.aux_loss
78
73
 
79
74
  def forward(self, x: torch.Tensor):
80
- # orig_shape = x.shape
81
- # x = x.view(-1, self.embed_dim) # [batch*seq_len, embed_dim]
82
- #
83
- # # Get routing weights and indices
84
- # weights, indices = self.router(x) # [batch*seq_len, top_k]
85
- #
86
- # # Create expert masks and combine it with masks
87
- # mask = F.one_hot(indices, self.num_experts).float() # [batch*seq_len, top_k, num_experts]
88
- # weights = (weights.unsqueeze(-1) * mask).sum(dim=1) # [batch*seq_len, num_experts]
89
- #
90
- # # Expert computation
91
- # x = x.unsqueeze(1).expand(-1, self.num_experts, -1) # [batch*seq_len, num_experts, embed_dim]
92
- #
93
- # # First linear layer
94
- # h = torch.einsum('bie,ieh->bih', x, self.w1) + self.b1 # [batch*seq_len, num_experts, hidden_dim]
95
- # h = self._activate(h)
96
- # h = self.dropout(h)
97
- #
98
- # # Second linear layer (projection back to embed_dim)
99
- # out = torch.einsum('bih,ihe->bie', h, self.w2) + self.b2 # [batch*seq_len, num_experts, embed_dim]
100
- #
101
- # # Weighted sum of expert outputs
102
- # out = (out * weights.unsqueeze(-1)).sum(dim=1) # [batch*seq_len, embed_dim]
103
- #
104
- # return out.view(*orig_shape)
105
75
  orig_shape = x.shape
106
76
  x = x.view(-1, self.embed_dim) # [batch*seq_len, embed_dim]
107
77
 
108
78
  # Get routing weights and indices
109
- weights, indices = self.router(x) # [batch*seq_len, top_k], [batch*seq_len, top_k]
110
-
111
- # Flatten indices and weights
112
- batch_size = x.size(0)
113
- top_k = indices.size(1)
114
- indices = indices.view(-1) # [batch*seq_len * top_k]
115
- weights = weights.view(-1, 1) # [batch*seq_len * top_k, 1]
79
+ weights, indices = self.router(x) # [B*T, top_k], [B*T, top_k]
116
80
 
117
- # Select only the relevant experts for each token
118
- selected_w1 = self.w1[indices] # [batch*seq_len * top_k, embed_dim, hidden_dim]
119
- selected_b1 = self.b1[indices] # [batch*seq_len * top_k, hidden_dim]
120
- selected_w2 = self.w2[indices] # [batch*seq_len * top_k, hidden_dim, embed_dim]
121
- selected_b2 = self.b2[indices] # [batch*seq_len * top_k, embed_dim]
81
+ # Create mask for expert contributions (B*T, num_experts)
82
+ expert_mask = F.one_hot(indices, self.num_experts).float() # [B*T, top_k, num_experts]
83
+ expert_weights = (weights.unsqueeze(-1) * expert_mask).sum(dim=1) # [B*T, num_experts]
122
84
 
123
- # Reshape x for batched computation
124
- x_expanded = x.unsqueeze(1).repeat(1, top_k, 1).view(-1, self.embed_dim) # [batch*seq_len * top_k, embed_dim]
85
+ output = torch.zeros_like(x)
86
+ for expert_idx in range(self.num_experts):
87
+ # Mask for tokens where this expert is in top_k
88
+ mask = expert_weights[:, expert_idx] > 0
89
+ if not mask.any():
90
+ continue
125
91
 
126
- # Compute only the selected experts
127
- h = torch.einsum('be, beh -> bh', x_expanded, selected_w1) + selected_b1
128
- h = self._activate(h)
129
- h = self.dropout(h)
92
+ # Compute expert output for selected tokens
93
+ expert_input = x[mask]
94
+ expert_output = self.experts[expert_idx](expert_input)
130
95
 
131
- out = torch.einsum('bh, bhe -> be', h, selected_w2) + selected_b2
96
+ # Apply combined weights for this expert
97
+ output[mask] += expert_output * expert_weights[mask, expert_idx].unsqueeze(-1)
132
98
 
133
- # Reshape back and apply weights
134
- out = out.view(batch_size, top_k, -1) # [batch*seq_len, top_k, embed_dim]
135
- weights = weights.view(batch_size, top_k, 1) # [batch*seq_len, top_k, 1]
136
- out = (out * weights).sum(dim=1) # Weighted sum over top_k experts
137
-
138
- return out.view(*orig_shape)
99
+ return output.view(*orig_shape)
139
100
 
140
101
 
141
102
  class GatedMoeFeedForward(MoeFeedForward):
@@ -163,13 +124,8 @@ class GatedMoeFeedForward(MoeFeedForward):
163
124
  **kwargs
164
125
  )
165
126
 
166
- def _init_linear_parameters(self):
167
- nn.init.kaiming_normal_(self.w1, nonlinearity='relu')
168
- nn.init.kaiming_normal_(self.w2, nonlinearity='linear')
169
-
170
- def _w1_dim_factor(self, hidden_dim: int) -> int:
171
- return 2 * hidden_dim
172
-
173
- def _activate(self, h: torch.Tensor):
174
- a, b = h.chunk(2, dim=-1)
175
- return a * self.activation(b)
127
+ def _init_experts(self, num_experts: int, embed_dim: int, hidden_dim: int, activation: nn.Module, dropout: float):
128
+ self.experts = nn.ModuleList([
129
+ GatedFeedForward(embed_dim, hidden_dim, activation, dropout)
130
+ for _ in range(num_experts)
131
+ ])
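
A minimal usage sketch for the refactored gated MoE feed-forward (constructor arguments are assumed to mirror the vectorized variant shown earlier in this diff; values are illustrative):

import torch
import torch.nn as nn
from rxnn.transformers.moe import GatedMoeFeedForward

moe_ff = GatedMoeFeedForward(
    embed_dim=128,
    hidden_dim=384,
    num_experts=8,
    activation=nn.SiLU(),
    top_k=2,
)

x = torch.randn(2, 32, 128)        # [batch, seq_len, embed_dim]
y = moe_ff(x)                      # each expert runs only on the tokens routed to it
aux = moe_ff.router_loss()         # load-balancing auxiliary loss from MoeRouter
print(y.shape, aux)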
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.1.14
3
+ Version: 0.1.16
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -1,6 +1,8 @@
1
1
  rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- rxnn/experimental/attention.py,sha256=HahcWU37FTfW8kwSTW8z_l7EtAVkJgvDDxLU8k3miHo,17101
3
+ rxnn/experimental/attention.py,sha256=qly-Lf9UsYC9JB945JcLnt27ZbF0vFvfyS5iUm-Rsak,31644
4
+ rxnn/experimental/models.py,sha256=ioYtbJDxJ4zASiKs9dFY4WvAJn7eVqFf7zid-65pbUU,4709
5
+ rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
4
6
  rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
7
  rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
6
8
  rxnn/memory/stm.py,sha256=EsD8slSP4_9dLuq6aFPDmuFe8PWilxh90so5Z3nm-ig,2057
@@ -19,11 +21,11 @@ rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
19
21
  rxnn/transformers/layers.py,sha256=HhIiykmrBgdsV4AbMQXr9t0cSo4gSIeN0dPtc8mDyOo,5629
20
22
  rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
21
23
  rxnn/transformers/models.py,sha256=w-zB_8QB9-Fae-GkGgmVDNY-Ts_0gBeWcevpl9qzZVM,7169
22
- rxnn/transformers/moe.py,sha256=fFPTRcctCSc9OwHd0PhNb0nwHgNJY7dXfUtGreXtaho,6720
24
+ rxnn/transformers/moe.py,sha256=FeaQR7hTX1dE74YdMOcuyZHSkGiV_0JwF8fw-GnfNOQ,4741
23
25
  rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
24
26
  rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
25
27
  rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
26
- rxnn-0.1.14.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
27
- rxnn-0.1.14.dist-info/METADATA,sha256=YQDNMaHDrfVdOk44qEUczgLaNcrXApoqVmNX50yQDdM,14629
28
- rxnn-0.1.14.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
29
- rxnn-0.1.14.dist-info/RECORD,,
28
+ rxnn-0.1.16.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
29
+ rxnn-0.1.16.dist-info/METADATA,sha256=Cr_8OPHWlf2LHYlZEmc_NaUkIiE3ShJ01Z5B5ZhI6G8,14629
30
+ rxnn-0.1.16.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
31
+ rxnn-0.1.16.dist-info/RECORD,,