rxnn 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/experimental/attention.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import torch
2
- from torch import nn
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
3
4
  from ..transformers.attention import MultiHeadAttention, GroupedQueryAttention
4
5
  from ..transformers.positional import RotaryPositionalEmbedding
5
6
  from ..transformers.moe import MoeRouter
@@ -9,6 +10,7 @@ from ..transformers.moe import MoeRouter
9
10
  class GroupedMoeAttention(GroupedQueryAttention):
10
11
  """
11
12
  Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
13
+
12
14
  Instead of mapping keys/values to static head groups, it dynamically selects expert head groups. It has the same
  total number of key/value heads as query heads, but uses only the selected group for the attention calculation.
  - with num_groups set to 1, it will be MoE MultiQueryAttention
@@ -20,8 +22,11 @@ class GroupedMoeAttention(GroupedQueryAttention):
20
22
 
21
23
  Optionally, it could use even more expert heads than attention heads - for example:
  - 512 dim divided into 16 heads of 32 dim, using 4 head groups - may use e.g. 24 total expert heads - still only
- 4 will be used for attention calculation, while 16 is used to split dimensions (in that case it will have 16 query heads)
+ 4 will be used for the attention calculation, while 16 are used to split the dimensions (in that case it will have 16 query heads)
+
+ © 2025 Adam Filipek
24
28
  """
29
+
25
30
  def __init__(
26
31
  self,
27
32
  embed_dim: int,
@@ -39,7 +44,7 @@ class GroupedMoeAttention(GroupedQueryAttention):
39
44
  *args,
40
45
  **kwargs,
41
46
  ):
42
- self.num_experts = num_experts if num_experts is not None else num_heads
47
+ self.num_experts = num_experts or num_heads
43
48
  super(GroupedMoeAttention, self).__init__(
44
49
  embed_dim,
45
50
  num_heads,
@@ -58,7 +63,237 @@ class GroupedMoeAttention(GroupedQueryAttention):
58
63
 
59
64
  def _init_kv(self, embed_dim: int):
60
65
  self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
61
- hidden_dim = embed_dim // (self.num_heads // self.num_groups)
66
+
67
+ hidden_dim = embed_dim // self.num_heads
68
+ self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
69
+ self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
70
+ self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
71
+ self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
72
+ self._init_experts()
73
+
74
+ def _init_experts(self):
75
+ nn.init.xavier_uniform_(self.wk)
76
+ nn.init.xavier_uniform_(self.wv)
77
+ if self.use_bias:
78
+ nn.init.zeros_(self.bk)
79
+ nn.init.zeros_(self.bv)
80
+
81
+ def _process_grouped_experts(self, x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, weights: torch.Tensor, indices: torch.Tensor):
82
+ B, S, G = indices.shape
83
+ x_flat = x.view(-1, x.size(-1))
84
+
85
+ # Flatten batch and sequence dimensions
86
+ indices_flat = indices.view(-1, G)
87
+ weights_flat = weights.view(-1, G, 1)
88
+
89
+ # Create expanded indices for expert processing
90
+ mask = torch.zeros(B * S, self.num_experts, device=x.device, dtype=torch.bool)
91
+ for g in range(G):
92
+ mask.scatter_(1, indices_flat[:, g].unsqueeze(1), True)
93
+
94
+ output = torch.zeros(B * S, G, w.size(2), device=x.device, dtype=x.dtype)
95
+
96
+ for e in range(self.num_experts):
97
+ token_mask = mask[:, e]
98
+ if not token_mask.any():
99
+ continue
100
+
101
+ # Get positions where expert e is used in any group
102
+ x_slice = x_flat[token_mask]
103
+ proj = F.linear(x_slice, w[e], b[e] if b is not None else None)
104
+
105
+ # Find which groups use this expert for selected tokens
106
+ group_mask = (indices_flat[token_mask] == e)
107
+
108
+ # Accumulate projections for relevant groups
109
+ weighted_proj = proj.unsqueeze(1) * weights_flat[token_mask] * group_mask.unsqueeze(-1).float()
110
+ output[token_mask] += weighted_proj.sum(dim=1)
111
+
112
+ return output.view(B, S, G, -1)
113
+
114
+ def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
115
+ skip_query_processing: bool = False):
116
+ q = self.q_proj(query).view(b, t, self.num_heads, -1).transpose(1, 2) if not skip_query_processing else query
117
+
118
+ # Key/Value processing
119
+ B, S, D = key.shape
120
+ key_flat = key.view(-1, D)
121
+ weights_k_flat, indices_k_flat = self.router(key_flat)
122
+ # Reshape back to original dimensions
123
+ weights_k = weights_k_flat.view(B, S, -1)
124
+ indices_k = indices_k_flat.view(B, S, -1)
125
+ k = self._process_grouped_experts(key, self.wk, self.bk, weights_k, indices_k)
126
+ v = self._process_grouped_experts(value, self.wv, self.bv, weights_k, indices_k)
127
+
128
+ # Expand to GQA format
129
+ k = k.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
130
+ v = v.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
131
+
132
+ if not self.use_flash_attention:
133
+ group_heads = self.num_heads // self.num_groups
134
+
135
+ k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1) # (B, G, group_heads, S, head_dim)
136
+ v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1) # (B, G, group_heads, S, head_dim)
137
+
138
+ k = k.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
139
+ v = v.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
140
+
141
+ return q, k, v
142
+
143
+
144
+ class DeepMoeAttention(GroupedMoeAttention):
145
+ """
146
+ Deep MoE Attention (DMA) - Grouped MoE Attention extended even further for sublinear computational efficiency.
+
+ In addition to using Mixture-of-Experts (MoE) routing for key/value head groups, DMA also uses dynamically selected
+ query heads - with that approach, each token can attend to every other token, but only partially - only some part of
+ the information from each token is used to identify related information in other tokens. So, DMA is not spatially
+ sparse (it has access to all tokens), but rather structurally sparse (it has access only to part of each token's information).
+
+ This solution could reduce the computational complexity of the attention operation to a sublinear level (<O(N)) and provide
+ a viable and efficient alternative to spatially sparse attention mechanisms like Flex Attention.
+
+ © 2025 Adam Filipek
+ """
158
+
159
+ def __init__(
160
+ self,
161
+ embed_dim: int,
162
+ num_heads: int,
163
+ num_groups: int,
164
+ dropout: float = 0.0,
165
+ rope: RotaryPositionalEmbedding = None,
166
+ rope_only_for_query: bool = False,
167
+ use_relative_embeddings: bool = False,
168
+ max_seq_len: int = 1024,
169
+ use_flash_attention: bool = False,
170
+ is_causal: bool = False,
171
+ use_bias: bool = False,
172
+ num_experts: int = None,
173
+ num_query_experts: int = None,
174
+ num_query_groups: int = None,
175
+ *args,
176
+ **kwargs,
177
+ ):
178
+ self.num_query_experts = num_query_experts if num_query_experts is not None else num_heads
179
+ self.num_query_groups = num_query_groups if num_query_groups is not None else num_groups
180
+ super(DeepMoeAttention, self).__init__(
181
+ embed_dim,
182
+ num_heads,
183
+ num_groups=num_groups,
184
+ dropout=dropout,
185
+ rope=rope,
186
+ rope_only_for_query=rope_only_for_query,
187
+ use_relative_embeddings=use_relative_embeddings,
188
+ max_seq_len=max_seq_len,
189
+ use_flash_attention=use_flash_attention,
190
+ is_causal=is_causal,
191
+ use_bias=use_bias,
192
+ num_experts=num_experts,
193
+ *args,
194
+ **kwargs,
195
+ )
196
+
197
+ def _init_q(self, embed_dim: int):
198
+ self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
199
+
200
+ hidden_dim = embed_dim // self.num_heads
201
+ self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
202
+ self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
203
+ self._init_query_experts()
204
+
205
+ def _init_query_experts(self):
206
+ nn.init.xavier_uniform_(self.wq)
207
+ if self.use_bias:
208
+ nn.init.zeros_(self.bq)
209
+
210
+ def _init_out(self, embed_dim: int):
211
+ """Initialize output projection"""
212
+ hidden_dim = embed_dim // (self.num_heads // self.num_query_groups)
213
+ self.out_proj = nn.Linear(hidden_dim, embed_dim)
214
+
215
+ def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int, skip_query_processing: bool = False):
216
+ # Query processing
217
+ B, T, D = query.shape
218
+ # Flatten for query routing
219
+ query_flat = query.view(B * T, D)
220
+ weights_q_flat, indices_q_flat = self.query_router(query_flat)
221
+ # Reshape back
222
+ weights_q = weights_q_flat.view(B, T, -1)
223
+ indices_q = indices_q_flat.view(B, T, -1)
224
+ q = self._process_grouped_experts(query, self.wq, self.bq, weights_q, indices_q)
225
+ q = q.permute(0, 2, 1, 3).reshape(B, self.num_query_groups, T, -1)
226
+
227
+ # Expand query groups to match head count
228
+ group_heads = self.num_heads // self.num_query_groups
229
+ q = q.unsqueeze(2).expand(-1, -1, group_heads, -1, -1).flatten(1, 2).transpose(1, 2)
230
+
231
+ # Key/Value processing
232
+ return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
233
+
234
+ # Vectorized
235
+
236
+ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
237
+ """
238
+ Vectorized implementation that computes all expert heads for every token and selects the active ones afterwards. The linear
+ layers used in attention are rather small compared to MoE feed-forward layers, so it may still be faster than filtering
+ experts - this has to be benchmarked.
+
+ Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
+
+ Instead of mapping keys/values to static head groups, it dynamically selects expert head groups. It has the same
+ total number of key/value heads as query heads, but uses only the selected group for the attention calculation.
+ - with num_groups set to 1, it will be MoE MultiQueryAttention
+
+ Compared to traditional GQA/MQA, it should provide better results, because far less information is lost with
+ this approach - we are training the full number of key/value heads, while using only a group at a time.
+
+ In terms of efficiency, it should be close to GQA/MQA linear performance, but with a small MoE routing overhead.
+
+ Optionally, it could use even more expert heads than attention heads - for example:
+ - 512 dim divided into 16 heads of 32 dim, using 4 head groups - may use e.g. 24 total expert heads - still only
+ 4 will be used for the attention calculation, while 16 are used to split the dimensions (in that case it will have 16 query heads)
+
+ © 2025 Adam Filipek
+ """
259
+
260
+ def __init__(
261
+ self,
262
+ embed_dim: int,
263
+ num_heads: int,
264
+ num_groups: int,
265
+ dropout: float = 0.0,
266
+ rope: RotaryPositionalEmbedding = None,
267
+ rope_only_for_query: bool = False,
268
+ use_relative_embeddings: bool = False,
269
+ max_seq_len: int = 1024,
270
+ use_flash_attention: bool = False,
271
+ is_causal: bool = False,
272
+ use_bias: bool = False,
273
+ num_experts: int = None,
274
+ *args,
275
+ **kwargs,
276
+ ):
277
+ self.num_experts = num_experts if num_experts is not None else num_heads
278
+ super(GroupedMoeAttentionVectorized, self).__init__(
279
+ embed_dim,
280
+ num_heads,
281
+ num_groups=num_groups,
282
+ dropout=dropout,
283
+ rope=rope,
284
+ rope_only_for_query=rope_only_for_query,
285
+ use_relative_embeddings=use_relative_embeddings,
286
+ max_seq_len=max_seq_len,
287
+ use_flash_attention=use_flash_attention,
288
+ is_causal=is_causal,
289
+ use_bias=use_bias,
290
+ *args,
291
+ **kwargs,
292
+ )
293
+
294
+ def _init_kv(self, embed_dim: int):
295
+ self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
296
+ hidden_dim = embed_dim // self.num_heads
62
297
  self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
63
298
  self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
64
299
  self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
@@ -72,47 +307,37 @@ class GroupedMoeAttention(GroupedQueryAttention):
72
307
  torch.nn.init.zeros_(self.bk)
73
308
  torch.nn.init.zeros_(self.bv)
74
309
 
75
- def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int, skip_query_processing: bool = False):
310
+ def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
311
+ skip_query_processing: bool = False):
312
+ # Indexed version may cause memory overflow
313
+ #
76
314
  # head_dim = d // self.num_heads
77
- # group_heads = self.num_heads // self.num_groups
78
315
  #
79
316
  # # Process Query as in GQA
80
- # q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2)
317
+ # q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1,
318
+ # 2) if not skip_query_processing else query
81
319
  #
82
320
  # # Process Key and Value with MoE routing
83
- # key_flat = key.view(-1, d)
84
- # weights, indices = self.router(key_flat)
85
- # weights = weights.view(b, key.size(1), self.num_groups, 1)
86
- # indices = indices.view(b, key.size(1), self.num_groups)
321
+ # key_flat = key.view(-1, d) # (B*S, d)
322
+ # value_flat = value.view(-1, d) # (B*S, d)
87
323
  #
88
- # # Compute all experts' K and V projections
89
- # # Shape: (batch_size, seq_len, num_experts, head_dim * num_groups)
90
- # k_all = torch.einsum(
91
- # 'be, ehd -> bedh',
92
- # key_flat,
93
- # self.wk.view(self.num_experts, d, -1)
94
- # ).view(b, key.size(1), self.num_experts, -1)
324
+ # # Get routing indices and weights for K
325
+ # weights_k, indices_k = self.router(key_flat)
326
+ # indices_k = indices_k.view(-1, self.top_k) # (B*S, top_k)
327
+ # weights_k = weights_k.view(-1, self.top_k, 1) # (B*S, top_k, 1)
95
328
  #
96
- # v_all = torch.einsum(
97
- # 'be, ehd -> bedh',
98
- # value.view(-1, d),
99
- # self.wv.view(self.num_experts, d, -1)
100
- # ).view(b, value.size(1), self.num_experts, -1)
329
+ # # Select and compute K projections for only the top_k experts
330
+ # selected_k_weights = self.k_experts[indices_k] # (B*S, top_k, d, k_out_dim)
331
+ # k_proj = torch.einsum('bd, behd -> beh', key_flat.unsqueeze(1), selected_k_weights)
332
+ # selected_k = (k_proj * weights_k).sum(dim=1) # (B*S, k_out_dim)
333
+ # selected_k = selected_k.view(b, key.size(1), -1) # (B, S, k_out_dim)
101
334
  #
102
- # # Select top_k experts and compute weighted sum
103
- # selected_k = torch.gather(
104
- # k_all,
105
- # 2,
106
- # indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
107
- # )
108
- # selected_v = torch.gather(
109
- # v_all,
110
- # 2,
111
- # indices.unsqueeze(-1).expand(-1, -1, -1, v_all.size(-1))
112
- # )
335
+ # # Compute V using the same indices as K (since they share the same router)
336
+ # selected_v_weights = self.v_experts[indices_k]
337
+ # v_proj = torch.einsum('bd, behd -> beh', value_flat.unsqueeze(1), selected_v_weights)
338
+ # selected_v = (v_proj * weights_k).sum(dim=1)
339
+ # selected_v = selected_v.view(b, value.size(1), -1) # (B, S, k_out_dim)
113
340
  #
114
- # selected_k = (selected_k * weights).sum(dim=2)
115
- # selected_v = (selected_v * weights).sum(dim=2)
116
341
  # # Reshape to GQA format: (B, G, S, head_dim)
117
342
  # k = selected_k.view(b, key.size(1), self.num_groups, head_dim).transpose(1, 2)
118
343
  # v = selected_v.view(b, value.size(1), self.num_groups, head_dim).transpose(1, 2)
@@ -127,32 +352,46 @@ class GroupedMoeAttention(GroupedQueryAttention):
127
352
  # v = v.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
128
353
  #
129
354
  # return q, k, v
355
+
130
356
  head_dim = d // self.num_heads
131
357
 
132
358
  # Process Query as in GQA
133
- q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2) if not skip_query_processing else query
359
+ q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2)
134
360
 
135
361
  # Process Key and Value with MoE routing
136
- key_flat = key.view(-1, d) # (B*S, d)
137
- value_flat = value.view(-1, d) # (B*S, d)
138
-
139
- # Get routing indices and weights for K
140
- weights_k, indices_k = self.router(key_flat)
141
- indices_k = indices_k.view(-1, self.top_k) # (B*S, top_k)
142
- weights_k = weights_k.view(-1, self.top_k, 1) # (B*S, top_k, 1)
143
-
144
- # Select and compute K projections for only the top_k experts
145
- selected_k_weights = self.k_experts[indices_k] # (B*S, top_k, d, k_out_dim)
146
- k_proj = torch.einsum('bd, behd -> beh', key_flat.unsqueeze(1), selected_k_weights)
147
- selected_k = (k_proj * weights_k).sum(dim=1) # (B*S, k_out_dim)
148
- selected_k = selected_k.view(b, key.size(1), -1) # (B, S, k_out_dim)
149
-
150
- # Compute V using the same indices as K (since they share the same router)
151
- selected_v_weights = self.v_experts[indices_k]
152
- v_proj = torch.einsum('bd, behd -> beh', value_flat.unsqueeze(1), selected_v_weights)
153
- selected_v = (v_proj * weights_k).sum(dim=1)
154
- selected_v = selected_v.view(b, value.size(1), -1) # (B, S, k_out_dim)
362
+ key_flat = key.view(-1, d)
363
+ weights, indices = self.router(key_flat)
364
+ weights = weights.view(b, key.size(1), self.num_groups, 1)
365
+ indices = indices.view(b, key.size(1), self.num_groups)
366
+
367
+ # Compute all experts' K and V projections
368
+ # Shape: (batch_size, seq_len, num_experts, head_dim * num_groups)
369
+ k_all = torch.einsum(
370
+ 'be, ehd -> bedh',
371
+ key_flat,
372
+ self.wk.view(self.num_experts, d, -1)
373
+ ).view(b, key.size(1), self.num_experts, -1)
374
+
375
+ v_all = torch.einsum(
376
+ 'be, ehd -> bedh',
377
+ value.view(-1, d),
378
+ self.wv.view(self.num_experts, d, -1)
379
+ ).view(b, value.size(1), self.num_experts, -1)
380
+
381
+ # Select top_k experts and compute weighted sum
382
+ selected_k = torch.gather(
383
+ k_all,
384
+ 2,
385
+ indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
386
+ )
387
+ selected_v = torch.gather(
388
+ v_all,
389
+ 2,
390
+ indices.unsqueeze(-1).expand(-1, -1, -1, v_all.size(-1))
391
+ )
155
392
 
393
+ selected_k = (selected_k * weights).sum(dim=2)
394
+ selected_v = (selected_v * weights).sum(dim=2)
156
395
  # Reshape to GQA format: (B, G, S, head_dim)
157
396
  k = selected_k.view(b, key.size(1), self.num_groups, head_dim).transpose(1, 2)
158
397
  v = selected_v.view(b, value.size(1), self.num_groups, head_dim).transpose(1, 2)
@@ -168,15 +407,26 @@ class GroupedMoeAttention(GroupedQueryAttention):
168
407
 
169
408
  return q, k, v
170
409
 
171
- class SparseMoeAttention(GroupedMoeAttention):
410
+
411
+ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
172
412
  """
173
- Sparse MoE Attention (SMA) - Grouped MoE Attention extended even more for sublinear computational efficiency.
413
+ Vectorized implementation that computes all expert heads for every token and selects the active ones afterwards. The linear
+ layers used in attention are rather small compared to MoE feed-forward layers, so it may still be faster than filtering
+ experts - this has to be benchmarked.
+
+ Deep MoE Attention (DMA) - Grouped MoE Attention extended even further for sublinear computational efficiency.
418
+
174
419
  In addition to using Mixture-of-Experts (MoE) for key/value head groups, SMA is also using dynamically selected
175
420
  query heads - with that approach, each token could attend to every other token, but only partially - only some part of
176
- information from each token is used to identify related information parts from other tokens.
421
+ information from each token is used to identify related information parts from other tokens. So, DMA is not spatially
422
+ sparse (has access to all tokens), but rather structurally sparse (has access only to the part of token's information).
177
423
 
178
- This solution could reduce the computational complexity of attention operation to sublinear level (<O(N))
424
+ This solution could reduce the computational complexity of attention operation to sublinear level (<O(N)) and provide
425
+ a viable and efficient alternative to spatial sparse attention mechanisms like Flex Attention.
426
+
427
+ © 2025 Adam Filipek
179
428
  """
429
+
180
430
  def __init__(
181
431
  self,
182
432
  embed_dim: int,
@@ -192,13 +442,13 @@ class SparseMoeAttention(GroupedMoeAttention):
192
442
  use_bias: bool = False,
193
443
  num_experts: int = None,
194
444
  num_query_experts: int = None,
195
- num_active_query_heads: int = None,
445
+ num_query_groups: int = None,
196
446
  *args,
197
447
  **kwargs,
198
448
  ):
199
449
  self.num_query_experts = num_query_experts if num_query_experts is not None else num_heads
200
- self.num_active_query_heads = num_active_query_heads if num_active_query_heads is not None else num_groups
201
- super(SparseMoeAttention, self).__init__(
450
+ self.num_query_groups = num_query_groups if num_query_groups is not None else num_groups
451
+ super(DeepMoeAttentionVectorized, self).__init__(
202
452
  embed_dim,
203
453
  num_heads,
204
454
  num_groups=num_groups,
@@ -216,8 +466,8 @@ class SparseMoeAttention(GroupedMoeAttention):
216
466
  )
217
467
 
218
468
  def _init_q(self, embed_dim: int):
219
- self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_active_query_heads)
220
- hidden_dim = embed_dim // (self.num_heads // self.num_groups)
469
+ self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
470
+ hidden_dim = embed_dim // self.num_heads
221
471
  self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
222
472
  self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
223
473
  self._init_query_experts()
@@ -227,20 +477,47 @@ class SparseMoeAttention(GroupedMoeAttention):
227
477
  if self.use_bias:
228
478
  torch.nn.init.zeros_(self.bq)
229
479
 
480
+ def _init_out(self, embed_dim: int):
481
+ """Initialize output projection"""
482
+ self.out_proj = nn.Linear(embed_dim // (self.num_heads // self.num_groups), embed_dim)
483
+
230
484
  def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
485
+ # Indexed version may cause memory overflow
486
+ #
487
+ # head_dim = d // self.num_heads
488
+ #
489
+ # # Process Query with MoE routing
490
+ # query_flat = query.view(-1, d) # (B*T, d)
491
+ # weights_q, indices_q = self.query_router(query_flat)
492
+ # indices_q = indices_q.view(-1, self.num_query_groups) # (B*T, top_k_q)
493
+ # weights_q = weights_q.view(-1, self.num_query_groups, 1) # (B*T, top_k_q, 1)
494
+ #
495
+ # # Select and compute Q projections for top_k experts
496
+ # selected_q_weights = self.wq[indices_q] # (B*T, top_k_q, d, head_dim*num_heads)
497
+ # q_proj = torch.einsum('bd, behd -> beh', query_flat.unsqueeze(1), selected_q_weights)
498
+ # selected_q = (q_proj * weights_q).sum(dim=1) # (B*T, head_dim*num_heads)
499
+ # selected_q = selected_q.view(b, t, -1) # (B, T, head_dim*num_heads)
231
500
  head_dim = d // self.num_heads
232
501
 
233
502
  # Process Query with MoE routing
234
- query_flat = query.view(-1, d) # (B*T, d)
235
- weights_q, indices_q = self.router_q(query_flat)
236
- indices_q = indices_q.view(-1, self.top_k_q) # (B*T, top_k_q)
237
- weights_q = weights_q.view(-1, self.top_k_q, 1) # (B*T, top_k_q, 1)
238
-
239
- # Select and compute Q projections for top_k experts
240
- selected_q_weights = self.q_experts[indices_q] # (B*T, top_k_q, d, head_dim*num_heads)
241
- q_proj = torch.einsum('bd, behd -> beh', query_flat.unsqueeze(1), selected_q_weights)
242
- selected_q = (q_proj * weights_q).sum(dim=1) # (B*T, head_dim*num_heads)
243
- selected_q = selected_q.view(b, t, -1) # (B, T, head_dim*num_heads)
503
+ query_flat = query.view(b * t, d)
504
+ weights_q, indices_q = self.query_router(query_flat)
505
+ weights_q = weights_q.view(b, t, self.num_query_groups, 1)
506
+ indices_q = indices_q.view(b, t, self.num_query_groups)
507
+
508
+ # Compute all experts' Q projections
509
+ q_all = torch.einsum(
510
+ 'be, ehd -> bedh',
511
+ query_flat,
512
+ self.wq.view(self.num_query_experts, d, -1)
513
+ ).view(b, t, self.num_query_experts, -1)
514
+
515
+ selected_q = torch.gather(
516
+ q_all,
517
+ 2,
518
+ indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.shape[-1])
519
+ )
520
+ selected_q = (selected_q * weights_q).sum(dim=2)
244
521
 
245
522
  q = selected_q.view(b, t, self.num_heads, head_dim).transpose(1, 2) # (B, H, T, head_dim)
246
523
 
@@ -251,12 +528,12 @@ class SparseMoeAttention(GroupedMoeAttention):
251
528
 
252
529
  class FlexAttention(MultiHeadAttention):
253
530
  def __init__(
254
- self,
255
- embed_dim: int,
256
- num_heads: int,
257
- num_global_tokens: int = 16,
258
- window_size: int = 128,
259
- **kwargs
531
+ self,
532
+ embed_dim: int,
533
+ num_heads: int,
534
+ num_global_tokens: int = 16,
535
+ window_size: int = 128,
536
+ **kwargs
260
537
  ):
261
538
  super().__init__(embed_dim, num_heads, **kwargs)
262
539
  self.num_global_tokens = num_global_tokens
@@ -319,14 +596,15 @@ class FlexAttention(MultiHeadAttention):
319
596
  output = self._calculate_output(combined_attn, v, b, t, d)
320
597
  return self.out_proj(output)
321
598
 
599
+
322
600
  class InfiniteAttention(MultiHeadAttention):
323
601
  def __init__(
324
- self,
325
- embed_dim: int,
326
- num_heads: int,
327
- kernel_size: int = 128,
328
- use_rotary: bool = True,
329
- **kwargs
602
+ self,
603
+ embed_dim: int,
604
+ num_heads: int,
605
+ kernel_size: int = 128,
606
+ use_rotary: bool = True,
607
+ **kwargs
330
608
  ):
331
609
  super().__init__(embed_dim, num_heads, **kwargs)
332
610
  self.kernel_size = kernel_size
@@ -377,4 +655,89 @@ class InfiniteAttention(MultiHeadAttention):
377
655
  q = q / (q.shape[-1] ** 0.5)
378
656
  attn = torch.einsum('b h i d, b h j d -> b h i j', q, k)
379
657
  attn = torch.softmax(attn, dim=-1)
380
- return torch.einsum('b h i j, b h j d -> b h i d', attn, v)
658
+ return torch.einsum('b h i j, b h j d -> b h i d', attn, v)
659
+
660
+ def init_moe_attention(
661
+ embed_dim: int,
662
+ num_heads: int,
663
+ attention_type: str,
664
+ gqa_groups: int = 1,
665
+ dropout: float = 0.0,
666
+ rope: RotaryPositionalEmbedding = None,
667
+ rope_only_for_query: bool = False,
668
+ use_relative_embeddings: bool = False,
669
+ max_seq_len: int = 1024,
670
+ use_flash_attention: bool = False,
671
+ is_causal: bool = False,
672
+ use_bias: bool = False,
673
+ num_experts: int = None,
674
+ num_query_experts: int = None,
675
+ num_query_groups: int = None,
676
+ ) -> GroupedQueryAttention:
677
+ assert attention_type == 'gma' or attention_type == 'dma' or attention_type == 'gma_v' or attention_type == 'dma_v', \
678
+ "Error, attention type should be one of: 'gma', 'dma', 'gma_v', 'dma_v'"
679
+
680
+ if attention_type == "gma":
681
+ return GroupedMoeAttention(
682
+ embed_dim,
683
+ num_heads,
684
+ gqa_groups,
685
+ dropout=dropout,
686
+ rope=rope,
687
+ use_relative_embeddings=use_relative_embeddings,
688
+ max_seq_len=max_seq_len,
689
+ rope_only_for_query=rope_only_for_query,
690
+ use_flash_attention=use_flash_attention,
691
+ is_causal=is_causal,
692
+ use_bias=use_bias,
693
+ num_experts=num_experts,
694
+ )
695
+ elif attention_type == "dma":
696
+ return DeepMoeAttention(
697
+ embed_dim,
698
+ num_heads,
699
+ gqa_groups,
700
+ dropout=dropout,
701
+ rope=rope,
702
+ use_relative_embeddings=use_relative_embeddings,
703
+ max_seq_len=max_seq_len,
704
+ rope_only_for_query=rope_only_for_query,
705
+ use_flash_attention=use_flash_attention,
706
+ is_causal=is_causal,
707
+ use_bias=use_bias,
708
+ num_experts=num_experts,
709
+ num_query_experts=num_query_experts,
710
+ num_query_groups=num_query_groups,
711
+ )
712
+ elif attention_type == "gma_v":
713
+ return GroupedMoeAttentionVectorized(
714
+ embed_dim,
715
+ num_heads,
716
+ gqa_groups,
717
+ dropout=dropout,
718
+ rope=rope,
719
+ use_relative_embeddings=use_relative_embeddings,
720
+ max_seq_len=max_seq_len,
721
+ rope_only_for_query=rope_only_for_query,
722
+ use_flash_attention=use_flash_attention,
723
+ is_causal=is_causal,
724
+ use_bias=use_bias,
725
+ num_experts=num_experts,
726
+ )
727
+ else:
728
+ return DeepMoeAttentionVectorized(
729
+ embed_dim,
730
+ num_heads,
731
+ gqa_groups,
732
+ dropout=dropout,
733
+ rope=rope,
734
+ use_relative_embeddings=use_relative_embeddings,
735
+ max_seq_len=max_seq_len,
736
+ rope_only_for_query=rope_only_for_query,
737
+ use_flash_attention=use_flash_attention,
738
+ is_causal=is_causal,
739
+ use_bias=use_bias,
740
+ num_experts=num_experts,
741
+ num_query_experts=num_query_experts,
742
+ num_query_groups=num_query_groups,
743
+ )
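# Editor's usage sketch (illustration only, not part of the package diff): constructing
# the experimental MoE attention variants through the init_moe_attention factory above.
# The argument values are arbitrary examples, not recommended settings.
from rxnn.experimental.attention import init_moe_attention

gma = init_moe_attention(
    embed_dim=256, num_heads=16, attention_type='gma', gqa_groups=4,
    num_experts=24, is_causal=True,
)
dma = init_moe_attention(
    embed_dim=256, num_heads=16, attention_type='dma', gqa_groups=4,
    num_experts=24, num_query_experts=32, num_query_groups=8, is_causal=True,
)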
rxnn/experimental/models.py ADDED
@@ -0,0 +1,116 @@
1
+ import torch
2
+ from torch import nn
3
+ from typing import TypedDict, Union
4
+ from huggingface_hub import PyTorchModelHubMixin
5
+ from ..transformers.positional import RotaryPositionalEmbedding
6
+ from ..transformers.attention import init_attention
7
+ from ..transformers.layers import ClassicTransformerLayer
8
+ from ..transformers.models import ClassicTransformerDecoder
9
+ from ..transformers.ff import get_activation_layer
10
+ from ..utils import get_model_size
11
+ from .attention import init_moe_attention
12
+
13
+
14
+ class MoeAttentionTransformerConfig(TypedDict):
15
+ num_layers: int
16
+ vocab_size: int
17
+ embed_dim: int
18
+ ff_dim: int
19
+ att_heads: int
20
+ seq_len: int
21
+ use_flash_attention: bool
22
+ use_gated: bool
23
+ ff_activation: str
24
+ ff_dropout: float
25
+ att_dropout: float
26
+ use_rms_norm: bool
27
+ att_groups: int
28
+ use_moe_ff: bool
29
+ ff_num_experts: int
30
+ ff_moe_top_k: int
31
+ att_type: str
32
+ att_num_experts: int
33
+ att_num_query_experts: int
34
+ att_num_query_groups: int
35
+
36
+
37
+ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="text-generation", license="apache-2.0"):
38
+ """Research model for experiments with Mixture-of-Experts Attention"""
39
+
40
+ def __init__(
41
+ self,
42
+ num_layers: int = 6,
43
+ vocab_size: int = 5000,
44
+ embed_dim: int = 128,
45
+ ff_dim: int = 384,
46
+ att_heads: int = 16,
47
+ seq_len: int = 256,
48
+ use_flash_attention: bool = True,
49
+ use_gated: bool = True,
50
+ ff_activation: str = "swish",
51
+ ff_dropout: float = 0.0,
52
+ att_dropout: float = 0.0,
53
+ use_rms_norm: bool = True,
54
+ att_groups: int = 1,
55
+ use_moe_ff: bool = False,
56
+ ff_num_experts: int = 1,
57
+ ff_moe_top_k: int = 1,
58
+ att_type: str = 'gma',
59
+ att_num_experts: int = None,
60
+ att_num_query_experts: int = None,
61
+ att_num_query_groups: int = None,
62
+ **kwargs
63
+ ):
64
+ super(MoeAttentionTransformer, self).__init__(**kwargs)
65
+ assert ff_activation in ['relu', 'gelu',
66
+ 'swish', 'silu', 'linear',
67
+ 'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
68
+ assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'gma_v',
69
+ 'dma_v'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "gma_v", "dma_v"'
70
+
71
+ embedding = nn.Embedding(vocab_size, embed_dim)
72
+ rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
73
+
74
+ ff_activation = get_activation_layer(ff_activation)
75
+
76
+ if att_type in ['mha', 'gqa', 'mqa']:
77
+ att_init = lambda: init_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
78
+ use_flash_attention=use_flash_attention, dropout=att_dropout,
79
+ max_seq_len=seq_len, is_causal=True)
80
+ else:
81
+ att_init = lambda: init_moe_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
82
+ use_flash_attention=use_flash_attention, dropout=att_dropout,
83
+ max_seq_len=seq_len, is_causal=True, num_experts=att_num_experts,
84
+ num_query_experts=att_num_query_experts,
85
+ num_query_groups=att_num_query_groups)
86
+
87
+ self.model = ClassicTransformerDecoder(
88
+ embed_dim,
89
+ vocab_size,
90
+ embedding=embedding,
91
+ layers=nn.ModuleList([
92
+ ClassicTransformerLayer(
93
+ embed_dim,
94
+ ff_dim,
95
+ use_gated=use_gated,
96
+ use_moe=use_moe_ff,
97
+ num_experts=ff_num_experts,
98
+ moe_top_k=ff_moe_top_k,
99
+ ff_activation=ff_activation,
100
+ ff_dropout=ff_dropout,
101
+ use_rms_norm=use_rms_norm,
102
+ self_attention=att_init(),
103
+ ) for _ in range(num_layers)
104
+ ]),
105
+ use_flash_attention=use_flash_attention,
106
+ )
107
+
108
+ def params_count(self):
109
+ return get_model_size(self.model)
110
+
111
+ def load_shared_embedding(self, embedding: nn.Embedding):
112
+ self.model.embedding = embedding
113
+
114
+ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> Union[
115
+ torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
116
+ return self.model(x, attention_mask=attention_mask)
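# Editor's usage sketch (illustration only, not part of the package diff): building the
# research model defined above with a MoE attention variant and running a forward pass
# on random token ids. Hyperparameters are arbitrary examples.
import torch
from rxnn.experimental.models import MoeAttentionTransformer

model = MoeAttentionTransformer(
    num_layers=2, vocab_size=5000, embed_dim=128, ff_dim=384, att_heads=16,
    seq_len=256, use_flash_attention=False, att_type='gma', att_groups=4,
    att_num_experts=24,
)
tokens = torch.randint(0, 5000, (2, 256))
logits = model(tokens, attention_mask=torch.ones_like(tokens))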
rxnn/experimental/moe.py ADDED
@@ -0,0 +1,206 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from ..transformers.moe import MoeRouter
5
+
6
+ class DynamicMoeRouter(nn.Module):
7
+ """Dynamic Mixture-of-Experts Router layer - dynamically selects top-k experts for each token."""
8
+
9
+ def __init__(self, embed_dim: int, num_experts: int, top_ks: tuple[int] = (1, 2, 3), *args, **kwargs):
10
+ super(DynamicMoeRouter, self).__init__(*args, **kwargs)
11
+ self.top_ks = top_ks
12
+ self.num_options = len(top_ks)
13
+ self.num_experts = num_experts
14
+ self.gate = nn.Linear(embed_dim, num_experts + self.num_options, bias=False)
15
+ # For expert load balancing
16
+ self.register_buffer('aux_loss', torch.tensor(0.0), persistent=False)
17
+
18
+ def calculate_aux_loss(self, top_k_indices: torch.Tensor, routing_probs: torch.Tensor) -> torch.Tensor:
19
+ expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
20
+ expert_usage = expert_mask.sum(dim=0).mean(dim=0)
21
+ mean_probs = routing_probs.mean(dim=0)
22
+ return (expert_usage * mean_probs).sum() * self.num_experts
23
+
24
+ def forward(self, x: torch.Tensor):
25
+ # Input shape: [batch*seq_len, embed_dim]
26
+ all_logits = self.gate(x)
27
+ routing_logits = all_logits[:, :-self.num_options]
28
+ options_logits = all_logits[:, -self.num_options:]
29
+
30
+ routing_probs = F.softmax(routing_logits, dim=-1)
31
+ top_k_id = torch.argmax(options_logits, dim=-1).item()
32
+
33
+ top_k = self.top_ks[top_k_id]
34
+
35
+ # Get top-k experts for each token
36
+ top_k_weights, top_k_indices = routing_probs.topk(top_k, dim=-1)
37
+
38
+ # Normalize weights (sum to 1 for each token)
39
+ top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)
40
+
41
+ # Load Balance Loss
42
+ self.aux_loss = self.calculate_aux_loss(top_k_indices, routing_probs)
43
+
44
+ return top_k_weights, top_k_indices, top_k
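# Editor's note (not part of the package diff): `argmax(options_logits, dim=-1).item()`
# above produces one value per token, so `.item()` only succeeds when a single routing
# decision is made; a batch-level choice would need a reduction over tokens first.
# Minimal standalone sketch of such a reduction, assuming top_ks=(1, 2, 3):
import torch

top_ks = (1, 2, 3)
options_logits = torch.randn(20, len(top_ks))           # one row per token
top_k_id = options_logits.mean(dim=0).argmax().item()   # single decision for the whole batch
top_k = top_ks[top_k_id]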
45
+
46
+ class MoeFeedForwardVectorized(nn.Module):
47
+ """
48
+ Vectorized MoE - the current implementation is inefficient - it computes all the experts and then selects the relevant ones.
+
+ The commented-out implementation fixes this problem, but causes memory overflows because of the expert-weight
+ indexing - it uses ~15x more memory than a dense model of similar size, so it's currently not viable.
+
+ It's recommended to use the standard MoE from rxnn.transformers.moe instead.
54
+ """
55
+
56
+ def __init__(
57
+ self,
58
+ embed_dim: int,
59
+ hidden_dim: int,
60
+ num_experts: int,
61
+ activation: nn.Module,
62
+ top_k: int = 1,
63
+ dropout: float = 0.0,
64
+ *args,
65
+ **kwargs
66
+ ):
67
+ super(MoeFeedForwardVectorized, self).__init__(*args, **kwargs)
68
+ self.embed_dim = embed_dim
69
+ self.num_experts = num_experts
70
+ self.top_k = top_k
71
+
72
+ self.router = MoeRouter(embed_dim, num_experts, top_k)
73
+
74
+ # Batch all expert parameters together
75
+ self.w1 = nn.Parameter(torch.empty(num_experts, embed_dim, self._w1_dim_factor(hidden_dim)))
76
+ self.b1 = nn.Parameter(torch.zeros(num_experts, self._w1_dim_factor(hidden_dim)))
77
+ self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, embed_dim))
78
+ self.b2 = nn.Parameter(torch.zeros(num_experts, embed_dim))
79
+ self.activation = activation
80
+ self.dropout = nn.Dropout(dropout)
81
+
82
+ # Initialize parameters
83
+ self._init_linear_parameters()
84
+ nn.init.zeros_(self.b1)
85
+ nn.init.zeros_(self.b2)
86
+
87
+ def _init_linear_parameters(self):
88
+ nn.init.kaiming_normal_(self.w1, nonlinearity='relu')
89
+ nn.init.kaiming_normal_(self.w2, nonlinearity='relu')
90
+
91
+ def _w1_dim_factor(self, hidden_dim: int) -> int:
92
+ return hidden_dim
93
+
94
+ def _activate(self, h: torch.Tensor):
95
+ return self.activation(h)
96
+
97
+ def router_loss(self):
98
+ return self.router.aux_loss
99
+
100
+ def forward(self, x: torch.Tensor):
101
+ orig_shape = x.shape
102
+ x = x.view(-1, self.embed_dim) # [batch*seq_len, embed_dim]
103
+
104
+ # Get routing weights and indices
105
+ weights, indices = self.router(x) # [batch*seq_len, top_k]
106
+
107
+ # Create expert masks and combine it with masks
108
+ mask = F.one_hot(indices, self.num_experts).float() # [batch*seq_len, top_k, num_experts]
109
+ weights = (weights.unsqueeze(-1) * mask).sum(dim=1) # [batch*seq_len, num_experts]
110
+
111
+ # Expert computation
112
+ x = x.unsqueeze(1).expand(-1, self.num_experts, -1) # [batch*seq_len, num_experts, embed_dim]
113
+
114
+ # First linear layer
115
+ h = torch.einsum('bie,ieh->bih', x, self.w1) + self.b1 # [batch*seq_len, num_experts, hidden_dim]
116
+ h = self._activate(h)
117
+ h = self.dropout(h)
118
+
119
+ # Second linear layer (projection back to embed_dim)
120
+ out = torch.einsum('bih,ihe->bie', h, self.w2) + self.b2 # [batch*seq_len, num_experts, embed_dim]
121
+
122
+ # Weighted sum of expert outputs
123
+ out = (out * weights.unsqueeze(-1)).sum(dim=1) # [batch*seq_len, embed_dim]
124
+
125
+ return out.view(*orig_shape)
126
+ # orig_shape = x.shape
127
+ # x = x.view(-1, self.embed_dim) # [batch*seq_len, embed_dim]
128
+ #
129
+ # # Get routing weights and indices
130
+ # weights, indices = self.router(x) # [B*T, top_k], [B*T, top_k]
131
+ #
132
+ # # Flatten indices and weights
133
+ # batch_size = x.shape[0]
134
+ # top_k = indices.shape[1]
135
+ # indices_flat = indices.view(-1) # [B*T * top_k]
136
+ #
137
+ # # Compute contributions for selected experts without materializing large tensors
138
+ # # First Layer:
139
+ # # Compute all expert contributions first (but this may still be memory-heavy)
140
+ # # Alternative: Compute contributions for selected experts directly
141
+ # # ... (see detailed steps below)
142
+ #
143
+ # # Alternative approach using gather and batched operations
144
+ # x_expanded = x.unsqueeze(1).repeat(1, top_k, 1).view(-1, self.embed_dim) # [B*T*top_k, D]
145
+ #
146
+ # # Compute first layer contributions using gather
147
+ # # indices_flat has shape [B*T*top_k]
148
+ # # selected_w1 is self.w1[indices_flat], but we compute the product inline
149
+ # h = torch.einsum(
150
+ # 'be, eih -> bh',
151
+ # x_expanded,
152
+ # self.w1[indices_flat]
153
+ # ) + self.b1[indices_flat]
154
+ # h = self._activate(h)
155
+ # h = self.dropout(h)
156
+ #
157
+ # # Second layer:
158
+ # out = torch.einsum(
159
+ # 'bh, eho -> beo',
160
+ # h,
161
+ # self.w2[indices_flat]
162
+ # ).squeeze(-1) + self.b2[indices_flat]
163
+ #
164
+ # # Reshape and apply weights
165
+ # out = out.view(batch_size, top_k, -1)
166
+ # weights = weights.view(batch_size, top_k, 1)
167
+ # out = (out * weights).sum(dim=1)
168
+ #
169
+ # return out.view(*orig_shape)
170
+
171
+
172
+ class GatedMoeFeedForwardVectorized(MoeFeedForwardVectorized):
173
+ """Gated Mixture-of-Experts Feed-Forward layer - enable GLU-based activations for MoE"""
174
+
175
+ def __init__(
176
+ self,
177
+ embed_dim: int,
178
+ hidden_dim: int,
179
+ num_experts: int,
180
+ activation: nn.Module = nn.SiLU(),
181
+ top_k: int = 1,
182
+ dropout: float = 0.1,
183
+ *args,
184
+ **kwargs
185
+ ):
186
+ super(GatedMoeFeedForwardVectorized, self).__init__(
187
+ embed_dim=embed_dim,
188
+ hidden_dim=hidden_dim,
189
+ num_experts=num_experts,
190
+ activation=activation,
191
+ top_k=top_k,
192
+ dropout=dropout,
193
+ *args,
194
+ **kwargs
195
+ )
196
+
197
+ def _init_linear_parameters(self):
198
+ nn.init.kaiming_normal_(self.w1, nonlinearity='relu')
199
+ nn.init.kaiming_normal_(self.w2, nonlinearity='linear')
200
+
201
+ def _w1_dim_factor(self, hidden_dim: int) -> int:
202
+ return 2 * hidden_dim
203
+
204
+ def _activate(self, h: torch.Tensor):
205
+ a, b = h.chunk(2, dim=-1)
206
+ return a * self.activation(b)
rxnn/transformers/moe.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
3
  import torch.nn.functional as F
4
-
4
+ from .ff import FeedForward, GatedFeedForward
5
5
 
6
6
  class MoeRouter(nn.Module):
7
7
  """Mixture-of-Experts Router layer - computes routing weights for each expert."""
@@ -14,18 +14,27 @@ class MoeRouter(nn.Module):
14
14
  # For expert load balancing
15
15
  self.register_buffer('aux_loss', torch.tensor(0.0), persistent=False)
16
16
 
17
+ def calculate_aux_loss(self, top_k_indices: torch.Tensor, probs: torch.Tensor) -> torch.Tensor:
18
+ expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
19
+ expert_usage = expert_mask.sum(dim=0).mean(dim=0)
20
+ mean_probs = probs.mean(dim=0)
21
+ return (expert_usage * mean_probs).sum() * self.num_experts
22
+
23
+
17
24
  def forward(self, x: torch.Tensor):
18
- # x shape: [batch_size*seq_len, embed_dim]
25
+ # Input shape: [batch*seq_len, embed_dim]
19
26
  logits = self.gate(x)
20
27
  probs = F.softmax(logits, dim=-1)
21
28
 
22
- # Expert load balancing loss
23
- mean_probs = probs.mean(dim=0) # Mean probability per expert across batch
24
- self.aux_loss = (mean_probs * torch.log(mean_probs + 1e-9)).sum() # Entropy-based loss
25
-
29
+ # Get top-k experts for each token
26
30
  top_k_weights, top_k_indices = probs.topk(self.top_k, dim=-1)
31
+
32
+ # Normalize weights (sum to 1 for each token)
27
33
  top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)
28
34
 
35
+ # Load Balance Loss
36
+ self.aux_loss = self.calculate_aux_loss(top_k_indices, probs)
37
+
29
38
  return top_k_weights, top_k_indices
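# Editor's sketch (illustration only, not part of the package diff): the load-balancing
# term computed in calculate_aux_loss above, reproduced standalone for a toy batch of
# 4 tokens, 2 experts, top_k=1.
import torch
import torch.nn.functional as F

num_experts = 2
probs = torch.tensor([[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4]])
top_k_indices = probs.topk(1, dim=-1).indices                  # every token picks expert 0

expert_mask = F.one_hot(top_k_indices, num_experts).float()    # (4, 1, 2)
expert_usage = expert_mask.sum(dim=0).mean(dim=0)              # how often each expert was chosen
mean_probs = probs.mean(dim=0)                                 # average routing probability
aux_loss = (expert_usage * mean_probs).sum() * num_experts
print(aux_loss)   # larger when the same experts dominate both usage and probability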
30
39
 
31
40
 
@@ -51,101 +60,43 @@ class MoeFeedForward(nn.Module):
51
60
  self.router = MoeRouter(embed_dim, num_experts, top_k)
52
61
 
53
62
  # Batch all expert parameters together
54
- self.w1 = nn.Parameter(torch.empty(num_experts, embed_dim, self._w1_dim_factor(hidden_dim)))
55
- self.b1 = nn.Parameter(torch.zeros(num_experts, self._w1_dim_factor(hidden_dim)))
56
- self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, embed_dim))
57
- self.b2 = nn.Parameter(torch.zeros(num_experts, embed_dim))
58
- self.activation = activation
59
- self.dropout = nn.Dropout(dropout)
60
-
61
- # Initialize parameters
62
- self._init_linear_parameters()
63
- nn.init.zeros_(self.b1)
64
- nn.init.zeros_(self.b2)
65
-
66
- def _init_linear_parameters(self):
67
- nn.init.kaiming_normal_(self.w1, nonlinearity='relu')
68
- nn.init.kaiming_normal_(self.w2, nonlinearity='relu')
63
+ self._init_experts(num_experts, embed_dim, hidden_dim, activation, dropout)
69
64
 
70
- def _w1_dim_factor(self, hidden_dim: int) -> int:
71
- return hidden_dim
72
-
73
- def _activate(self, h: torch.Tensor):
74
- return self.activation(h)
65
+ def _init_experts(self, num_experts: int, embed_dim: int, hidden_dim: int, activation: nn.Module, dropout: float):
66
+ self.experts = nn.ModuleList([
67
+ FeedForward(embed_dim, hidden_dim, activation, dropout)
68
+ for _ in range(num_experts)
69
+ ])
75
70
 
76
71
  def router_loss(self):
77
72
  return self.router.aux_loss
78
73
 
79
74
  def forward(self, x: torch.Tensor):
80
- # orig_shape = x.shape
81
- # x = x.view(-1, self.embed_dim) # [batch*seq_len, embed_dim]
82
- #
83
- # # Get routing weights and indices
84
- # weights, indices = self.router(x) # [batch*seq_len, top_k]
85
- #
86
- # # Create expert masks and combine it with masks
87
- # mask = F.one_hot(indices, self.num_experts).float() # [batch*seq_len, top_k, num_experts]
88
- # weights = (weights.unsqueeze(-1) * mask).sum(dim=1) # [batch*seq_len, num_experts]
89
- #
90
- # # Expert computation
91
- # x = x.unsqueeze(1).expand(-1, self.num_experts, -1) # [batch*seq_len, num_experts, embed_dim]
92
- #
93
- # # First linear layer
94
- # h = torch.einsum('bie,ieh->bih', x, self.w1) + self.b1 # [batch*seq_len, num_experts, hidden_dim]
95
- # h = self._activate(h)
96
- # h = self.dropout(h)
97
- #
98
- # # Second linear layer (projection back to embed_dim)
99
- # out = torch.einsum('bih,ihe->bie', h, self.w2) + self.b2 # [batch*seq_len, num_experts, embed_dim]
100
- #
101
- # # Weighted sum of expert outputs
102
- # out = (out * weights.unsqueeze(-1)).sum(dim=1) # [batch*seq_len, embed_dim]
103
- #
104
- # return out.view(*orig_shape)
105
75
  orig_shape = x.shape
106
76
  x = x.view(-1, self.embed_dim) # [batch*seq_len, embed_dim]
107
77
 
108
78
  # Get routing weights and indices
109
79
  weights, indices = self.router(x) # [B*T, top_k], [B*T, top_k]
110
80
 
111
- # Flatten indices and weights
112
- batch_size = x.shape[0]
113
- top_k = indices.shape[1]
114
- indices_flat = indices.view(-1) # [B*T * top_k]
115
-
116
- # Compute contributions for selected experts without materializing large tensors
117
- # First Layer:
118
- # Compute all expert contributions first (but this may still be memory-heavy)
119
- # Alternative: Compute contributions for selected experts directly
120
- # ... (see detailed steps below)
121
-
122
- # Alternative approach using gather and batched operations
123
- x_expanded = x.unsqueeze(1).repeat(1, top_k, 1).view(-1, self.embed_dim) # [B*T*top_k, D]
124
-
125
- # Compute first layer contributions using gather
126
- # indices_flat has shape [B*T*top_k]
127
- # selected_w1 is self.w1[indices_flat], but we compute the product inline
128
- h = torch.einsum(
129
- 'be, eih -> bh',
130
- x_expanded,
131
- self.w1[indices_flat]
132
- ) + self.b1[indices_flat]
133
- h = self._activate(h)
134
- h = self.dropout(h)
135
-
136
- # Second layer:
137
- out = torch.einsum(
138
- 'bh, eho -> beo',
139
- h,
140
- self.w2[indices_flat]
141
- ).squeeze(-1) + self.b2[indices_flat]
142
-
143
- # Reshape and apply weights
144
- out = out.view(batch_size, top_k, -1)
145
- weights = weights.view(batch_size, top_k, 1)
146
- out = (out * weights).sum(dim=1)
147
-
148
- return out.view(*orig_shape)
81
+ # Create mask for expert contributions (B*T, num_experts)
82
+ expert_mask = F.one_hot(indices, self.num_experts).float() # [B*T, top_k, num_experts]
83
+ expert_weights = (weights.unsqueeze(-1) * expert_mask).sum(dim=1) # [B*T, num_experts]
84
+
85
+ output = torch.zeros_like(x)
86
+ for expert_idx in range(self.num_experts):
87
+ # Mask for tokens where this expert is in top_k
88
+ mask = expert_weights[:, expert_idx] > 0
89
+ if not mask.any():
90
+ continue
91
+
92
+ # Compute expert output for selected tokens
93
+ expert_input = x[mask]
94
+ expert_output = self.experts[expert_idx](expert_input)
95
+
96
+ # Apply combined weights for this expert
97
+ output[mask] += expert_output * expert_weights[mask, expert_idx].unsqueeze(-1)
98
+
99
+ return output.view(*orig_shape)
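# Editor's sketch (illustration only, not part of the package diff): the token-dispatch
# pattern used in the forward above, reproduced standalone with tiny toy experts so the
# masking and weighting logic can be inspected in isolation.
import torch
import torch.nn as nn
import torch.nn.functional as F

num_experts, top_k, embed_dim = 4, 2, 8
experts = nn.ModuleList([nn.Linear(embed_dim, embed_dim) for _ in range(num_experts)])

x = torch.randn(6, embed_dim)                                      # 6 flattened tokens
probs = torch.softmax(torch.randn(6, num_experts), dim=-1)
weights, indices = probs.topk(top_k, dim=-1)

expert_mask = F.one_hot(indices, num_experts).float()              # (6, top_k, num_experts)
expert_weights = (weights.unsqueeze(-1) * expert_mask).sum(dim=1)  # (6, num_experts)

output = torch.zeros_like(x)
for e in range(num_experts):
    mask = expert_weights[:, e] > 0                                # tokens routed to expert e
    if mask.any():
        output[mask] += experts[e](x[mask]) * expert_weights[mask, e].unsqueeze(-1)
print(output.shape)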
149
100
 
150
101
 
151
102
  class GatedMoeFeedForward(MoeFeedForward):
@@ -173,13 +124,8 @@ class GatedMoeFeedForward(MoeFeedForward):
173
124
  **kwargs
174
125
  )
175
126
 
176
- def _init_linear_parameters(self):
177
- nn.init.kaiming_normal_(self.w1, nonlinearity='relu')
178
- nn.init.kaiming_normal_(self.w2, nonlinearity='linear')
179
-
180
- def _w1_dim_factor(self, hidden_dim: int) -> int:
181
- return 2 * hidden_dim
182
-
183
- def _activate(self, h: torch.Tensor):
184
- a, b = h.chunk(2, dim=-1)
185
- return a * self.activation(b)
127
+ def _init_experts(self, num_experts: int, embed_dim: int, hidden_dim: int, activation: nn.Module, dropout: float):
128
+ self.experts = nn.ModuleList([
129
+ GatedFeedForward(embed_dim, hidden_dim, activation, dropout)
130
+ for _ in range(num_experts)
131
+ ])
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.1.15
3
+ Version: 0.1.17
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -53,6 +53,29 @@ that's generating Infinite Chain-of-Thoughts and is communicating in push-based
53
53
  Reactive communication patterns in RxNN models are adapted to handle asynchronous nature of model - after it finish generating
54
54
  sequence, it has to process it and save it in memory, but it could be done in background.
55
55
 
56
+ ## Release plan
57
+ We are working on three new reactive architectures, that progressively advance from language models to awareness models:
58
+ - Reactive Transformer: Reactive Language Model (RLM) with Short-Term Memory
59
+ - Preactor: extending Reactive Transformer with additional Long-Term Memory, providing theoretically infinite context (only
60
+ single message length is limited) and the ability to learn from interactions (Live Learning)
61
+ - Reactor: AGI awareness model & Strong Reactive Neural Network, that's working in infinite reasoning loop and doesn't require explicit human commands
62
+
63
+ Each new architecture is based on the previous one and adding new features/abilities. They will be progressively
64
+ released with next versions of **RxNN** framework:
65
+ - 0.1.x: Reactive Transformer base models, Base Model Learning (pre-training/fine-tuning) & Transformers extensions (MoE Attention, Short-Term Memory, etc.)
66
+ - 0.2.x: Memory Reinforcement Learning (MRL) for Short-Term Memory & Reactive Transformer, Attention-based Memory System details
67
+ - 0.3.x: Reinforcement Learning from Human Feedback for Reactive models (RxRLHF), basic Tensor Reactive
68
+ Extensions (TRX/Rust) for full Reactive Transformer, RxT-Alpha release (+following models - RxT-Beta, etc.)
69
+ - 0.4.x: Preactor base models, Tensor Database (TDB/Rust) for Long-Term Memory, mxRAG/revRAG subsystems
70
+ - 0.5.x: MRL for Long-Term Memory & Preactor, Live Learning for Preactor, PRx-Alpha release (+following models - PRx-Beta, etc.)
71
+ - 0.6.x: Reactor base models, TRX full implementation, Receptors & Effectors Reactive RNNs
72
+ - 0.7.x: Behavioral Reinforcement Learning (BRL) for Reactor's Infinite Chain-of-Thoughts, Continuous Live Learning for Reactor
73
+ - 0.8.x: Rx-Alpha release
74
+ - 0.9.x: Rx-Beta release
75
+ - 1.0.0: Reactor AGI official release (Expert, Assistant & Utility class models)
76
+ - 1.x.x: Multimodal reactive models (could be released earlier, depending on progress)
77
+ - 2.0.0: Real-Time Vision Reactor - Worker class models
78
+ - x.x.x: ...and more!
56
79
  Apache License
57
80
  Version 2.0, January 2004
58
81
  http://www.apache.org/licenses/
@@ -1,6 +1,8 @@
1
1
  rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- rxnn/experimental/attention.py,sha256=HahcWU37FTfW8kwSTW8z_l7EtAVkJgvDDxLU8k3miHo,17101
3
+ rxnn/experimental/attention.py,sha256=wjHrxfov3Ybg3iou8FlQtFvxNuHdcs_A7a6FTloosgA,32056
4
+ rxnn/experimental/models.py,sha256=-XkEHsyT8iNAjhZbgC7N_5nzP4ENVJLwxSoLHgMfA0I,4668
5
+ rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
4
6
  rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
7
  rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
6
8
  rxnn/memory/stm.py,sha256=EsD8slSP4_9dLuq6aFPDmuFe8PWilxh90so5Z3nm-ig,2057
@@ -19,11 +21,11 @@ rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
19
21
  rxnn/transformers/layers.py,sha256=HhIiykmrBgdsV4AbMQXr9t0cSo4gSIeN0dPtc8mDyOo,5629
20
22
  rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
21
23
  rxnn/transformers/models.py,sha256=w-zB_8QB9-Fae-GkGgmVDNY-Ts_0gBeWcevpl9qzZVM,7169
22
- rxnn/transformers/moe.py,sha256=s2yeBsAg-JIqKp7tLlXPdLNar9FXZ14LgbHyXlUKk6o,6758
24
+ rxnn/transformers/moe.py,sha256=FeaQR7hTX1dE74YdMOcuyZHSkGiV_0JwF8fw-GnfNOQ,4741
23
25
  rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
24
26
  rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
25
27
  rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
26
- rxnn-0.1.15.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
27
- rxnn-0.1.15.dist-info/METADATA,sha256=r3sjBGoGAsIcNqrNEC1tDuG6blEuNRVrQ_3fyy-yWJY,14629
28
- rxnn-0.1.15.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
29
- rxnn-0.1.15.dist-info/RECORD,,
28
+ rxnn-0.1.17.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
29
+ rxnn-0.1.17.dist-info/METADATA,sha256=wId6o7JCcBjRD1plWzgJRmFAY5VlHN7-FIVySeVDqx8,16627
30
+ rxnn-0.1.17.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
31
+ rxnn-0.1.17.dist-info/RECORD,,