rxnn 0.1.33__py3-none-any.whl → 0.1.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,6 +9,10 @@ from ..transformers.moe import MoeRouter
 
  class GroupedMoeAttention(GroupedQueryAttention):
  """
+ Vectorized implementation calculates all expert heads for each token and selecting active tokens later. Linear layers
+ for Attention are rather small, compared to MoE Feed Forward layers, so it's possible that it will be faster than filtering
+ experts - it has to be tested.
+
  Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
 
  Instead of mapping keys/values to static head groups, it dynamically selects head expert groups. It has the same
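The docstring added above describes the vectorized approach: project every token through every expert head, then gather only the experts the router selected. Below is a minimal, self-contained sketch of that compute-all-then-gather pattern; the shapes, the fake router output, and all variable names are illustrative stand-ins, not code taken from rxnn.

```python
import torch

# Sketch of the compute-all-then-gather pattern, assuming a router that returns
# top-k expert indices and weights per token (faked here with random values).
B, S, D = 2, 8, 64                          # batch, sequence length, embed dim
num_experts, num_groups, head_dim = 4, 2, 16

x = torch.randn(B * S, D)                   # flattened tokens
w = torch.randn(num_experts, D, head_dim)   # one projection matrix per expert

# 1. Project every token through every expert in a single batched matmul
all_proj = torch.einsum('td,edh->teh', x, w)                   # [B*S, num_experts, head_dim]

# 2. Stand-in for the router output: top-k indices and weights per token
indices = torch.randint(0, num_experts, (B * S, num_groups))   # [B*S, num_groups]
weights = torch.rand(B * S, num_groups, 1)                     # [B*S, num_groups, 1]

# 3. Gather only the routed experts and apply the routing weights
gathered = torch.gather(all_proj, 1, indices.unsqueeze(-1).expand(-1, -1, head_dim))
out = (gathered * weights).view(B, S, num_groups, head_dim)    # [B, S, num_groups, head_dim]
```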
@@ -44,7 +48,7 @@ class GroupedMoeAttention(GroupedQueryAttention):
  *args,
  **kwargs,
  ):
- self.num_experts = num_experts or num_heads
+ self.num_experts = num_experts if num_experts is not None else num_heads
  super(GroupedMoeAttention, self).__init__(
  embed_dim,
  num_heads,
@@ -61,78 +65,63 @@ class GroupedMoeAttention(GroupedQueryAttention):
  **kwargs,
  )
 
- def router_loss(self):
- return self.router.aux_loss
-
  def _init_kv(self, embed_dim: int):
  self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
-
  hidden_dim = embed_dim // self.num_heads
- self.wk = nn.Parameter(torch.empty(self.num_experts, hidden_dim, embed_dim))
+ self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
  self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
- self.wv = nn.Parameter(torch.empty(self.num_experts, hidden_dim, embed_dim))
+ self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
  self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
  self._init_experts()
 
  def _init_experts(self):
- nn.init.xavier_uniform_(self.wk)
- nn.init.xavier_uniform_(self.wv)
+ torch.nn.init.xavier_uniform_(self.wk)
+ torch.nn.init.xavier_uniform_(self.wv)
  if self.use_bias:
- nn.init.zeros_(self.bk)
- nn.init.zeros_(self.bv)
-
- def _process_grouped_experts(self, x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, weights: torch.Tensor, indices: torch.Tensor):
- B, S, G = indices.shape
- x_flat = x.view(-1, x.size(-1)) # [B*S, D]
-
- indices_flat = indices.view(-1, G) # [B*S, G]
- weights_flat = weights.view(-1, G) # [B*S, G]
+ torch.nn.init.zeros_(self.bk)
+ torch.nn.init.zeros_(self.bv)
 
- output = torch.zeros(B * S, G, w.size(1), device=x.device, dtype=x.dtype) # [B*S, G, hidden_dim]
+ def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
+ skip_query_processing: bool = False):
+ head_dim = d // self.num_heads
 
- for e in range(self.num_experts):
- # 1. Find tokens where expert `e` is used in ANY group
- expert_mask = (indices_flat == e).any(dim=1) # [B*S]
- if not expert_mask.any():
- continue
+ # Process Query as in GQA
+ q = self.q_proj(query).view(b, t, self.num_heads, -1).transpose(1, 2) if not skip_query_processing else query
 
- # 2. Project tokens using expert `e`
- x_slice = x_flat[expert_mask] # [num_selected, D]
- proj = F.linear(x_slice, w[e], b[e] if b is not None else None) # [num_selected, hidden_dim]
+ # Key/Value MoE routing
+ B, S, D = key.shape
+ key_flat = key.reshape(-1, D)
+ weights, indices = self.router(key_flat) # (B*S, num_groups), (B*S, num_groups)
+ weights = weights.view(B, S, self.num_groups, 1)
+ indices = indices.view(B, S, self.num_groups)
 
- # 3. Scatter projections into correct groups
- for g in range(G):
- group_mask = indices_flat[expert_mask, g] == e # [num_selected]
- if not group_mask.any():
- continue
+ # Compute all experts' projections
+ # Shape: (B*S, num_experts, head_dim)
+ k_all = torch.einsum('bd,edh->beh', key_flat, self.wk) # [B*S, num_experts, head_dim]
+ v_all = torch.einsum('bd,edh->beh', value.view(-1, D), self.wv)
 
- # Get tokens in this group using expert `e`
- group_tokens = expert_mask.nonzero()[group_mask].squeeze(1)
- # Weight and scatter
- weighted_proj = proj[group_mask] * weights_flat[group_tokens, g].unsqueeze(-1)
- output[group_tokens, g] += weighted_proj
+ if self.use_bias:
+ k_all += self.bk
+ v_all += self.bv
 
- return output.view(B, S, G, -1)
+ # Get results for all heads
+ k_all = k_all.view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
+ v_all = v_all.view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
 
- def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
- skip_query_processing: bool = False):
- q = self.q_proj(query).view(b, t, self.num_heads, -1).transpose(1, 2) if not skip_query_processing else query
+ # Gather top-k experts using expanded indices
+ expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1)) # [B, S, num_groups, head_dim]
+ selected_k = torch.gather(k_all, 2, expanded_indices) # [B, S, num_groups, head_dim]
+ selected_v = torch.gather(v_all, 2, expanded_indices) # [B, S, num_groups, head_dim]
 
- # Key/Value processing
- B, S, D = key.shape
- key_flat = key.view(-1, D)
- weights_k_flat, indices_k_flat = self.router(key_flat)
- # Reshape back to original dimensions
- weights_k = weights_k_flat.view(B, S, -1)
- indices_k = indices_k_flat.view(B, S, -1)
- k = self._process_grouped_experts(key, self.wk, self.bk, weights_k, indices_k)
- v = self._process_grouped_experts(value, self.wv, self.bv, weights_k, indices_k)
+ # Weighted
+ weighted_k = selected_k * weights # [B, S, num_groups, head_dim]
+ weighted_v = selected_v * weights # [B, S, num_groups, head_dim]
 
- # Expand to GQA format
- k = k.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
- v = v.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
+ # Reshape to GQA format
+ k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3) # [B, num_groups, S, head_dim]
+ v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3) # [B, num_groups, S, head_dim]
 
- if not self.use_flash_attention:
+ if not self.rel_embed:
  group_heads = self.num_heads // self.num_groups
 
  k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1) # (B, G, group_heads, S, head_dim)
@@ -197,55 +186,54 @@ class DeepMoeAttention(GroupedMoeAttention):
  **kwargs,
  )
 
- def router_loss(self):
- return (self.router.aux_loss + self.query_router.aux_loss) / 2
-
  def _init_q(self, embed_dim: int):
  self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
-
  hidden_dim = embed_dim // self.num_heads
- self.wq = nn.Parameter(torch.empty(self.num_query_experts, hidden_dim, embed_dim))
+ self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
  self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
  self._init_query_experts()
 
  def _init_query_experts(self):
- nn.init.xavier_uniform_(self.wq)
+ torch.nn.init.xavier_uniform_(self.wq)
  if self.use_bias:
- nn.init.zeros_(self.bq)
+ torch.nn.init.zeros_(self.bq)
 
  def _init_out(self, embed_dim: int):
  """Initialize output projection"""
- hidden_dim = embed_dim // (self.num_heads // self.num_query_groups)
- self.out_proj = nn.Linear(hidden_dim, embed_dim)
+ out_hidden_dim = embed_dim // self.num_heads * self.num_query_groups
+ self.out_proj = nn.Linear(out_hidden_dim, embed_dim)
 
  def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
  """Transpose attention output back to (B, T, D) shape"""
- hidden_dim = d // self.num_heads * self.num_query_groups
- return attn_output.transpose(1, 2).contiguous().view(b, t, hidden_dim)
+ out_hidden_dim = d // self.num_heads * self.num_query_groups
+ return attn_output.transpose(1, 2).contiguous().view(b, t, out_hidden_dim)
 
- def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int, skip_query_processing: bool = False):
- # Query processing
+ def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
  B, T, D = query.shape
- # Flatten for query routing
- query_flat = query.view(-1, D)
- weights_q_flat, indices_q_flat = self.query_router(query_flat)
- # Reshape back
- weights_q = weights_q_flat.view(B, T, -1)
- indices_q = indices_q_flat.view(B, T, -1)
- q = self._process_grouped_experts(query, self.wq, self.bq, weights_q, indices_q)
+ query_flat = query.reshape(-1, D)
+ weights_q, indices_q = self.query_router(query_flat)
+ weights_q = weights_q.view(B, T, self.num_query_groups, 1)
+ indices_q = indices_q.view(B, T, self.num_query_groups)
 
- q = q.permute(0, 2, 1, 3).reshape(B, self.num_query_groups, T, -1)
- # Key/Value processing
- return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
+ # Compute all query experts
+ q_all = torch.einsum('bd,edh->beh', query_flat, self.wq) # [B*T, num_query_experts, head_dim]
+ if self.use_bias:
+ q_all += self.bq
 
- # Vectorized
+ q_all = q_all.view(B, T, self.num_query_experts, -1)
 
- class GroupedMoeAttentionVectorized(GroupedQueryAttention):
- """
- Vectorized implementation calculates all expert heads for each token and selecting active tokens later. Linear layers
- for Attention are rather small, compared to MoE Feed Forward layers, so it's possible that it will be faster than filtering
- experts - it has to be tested.
+ # Gather top-k experts
+ expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1))
+ selected_q = torch.gather(q_all, 2, expanded_indices) # [B, T, num_query_groups, head_dim]
+
+ # Weighted sum
+ q = selected_q * weights_q # [B, T, num_query_groups, head_dim]
+ q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3) # [B, num_query_groups, T, head_dim]
+
+ return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
 
+ class GroupedMoeAttentionSimplified(GroupedQueryAttention):
+ """
  Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
 
  Instead of mapping keys/values to static head groups, it dynamically selects head expert groups. It has the same
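The `_init_out`/`_transpose_output` changes in the hunk above rewrite the output projection width from `embed_dim // (num_heads // num_query_groups)` to `embed_dim // num_heads * num_query_groups`. A quick numeric check with illustrative values (not the package defaults) shows both expressions agree whenever the dimensions divide evenly:

```python
# Illustrative values only: head_dim = 512 // 16 = 32
embed_dim, num_heads, num_query_groups = 512, 16, 4

old = embed_dim // (num_heads // num_query_groups)   # 512 // 4 = 128
new = embed_dim // num_heads * num_query_groups      # 32 * 4  = 128
assert old == new == 128
```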
@@ -281,8 +269,8 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
  *args,
  **kwargs,
  ):
- self.num_experts = num_experts if num_experts is not None else num_heads
- super(GroupedMoeAttentionVectorized, self).__init__(
+ self.num_experts = num_experts or num_heads
+ super(GroupedMoeAttentionSimplified, self).__init__(
  embed_dim,
  num_heads,
  num_groups=num_groups,
@@ -298,63 +286,78 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
  **kwargs,
  )
 
+ def router_loss(self):
+ return self.router.aux_loss
+
  def _init_kv(self, embed_dim: int):
  self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
+
  hidden_dim = embed_dim // self.num_heads
- self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+ self.wk = nn.Parameter(torch.empty(self.num_experts, hidden_dim, embed_dim))
  self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
- self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+ self.wv = nn.Parameter(torch.empty(self.num_experts, hidden_dim, embed_dim))
  self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
  self._init_experts()
 
  def _init_experts(self):
- torch.nn.init.xavier_uniform_(self.wk)
- torch.nn.init.xavier_uniform_(self.wv)
+ nn.init.xavier_uniform_(self.wk)
+ nn.init.xavier_uniform_(self.wv)
  if self.use_bias:
- torch.nn.init.zeros_(self.bk)
- torch.nn.init.zeros_(self.bv)
+ nn.init.zeros_(self.bk)
+ nn.init.zeros_(self.bv)
 
- def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
- skip_query_processing: bool = False):
- head_dim = d // self.num_heads
+ def _process_grouped_experts(self, x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, weights: torch.Tensor, indices: torch.Tensor):
+ B, S, G = indices.shape
+ x_flat = x.view(-1, x.size(-1)) # [B*S, D]
 
- # Process Query as in GQA
- q = self.q_proj(query).view(b, t, self.num_heads, -1).transpose(1, 2) if not skip_query_processing else query
+ indices_flat = indices.view(-1, G) # [B*S, G]
+ weights_flat = weights.view(-1, G) # [B*S, G]
 
- # Key/Value MoE routing
- B, S, D = key.shape
- key_flat = key.reshape(-1, D)
- weights, indices = self.router(key_flat) # (B*S, num_groups), (B*S, num_groups)
- weights = weights.view(B, S, self.num_groups, 1)
- indices = indices.view(B, S, self.num_groups)
+ output = torch.zeros(B * S, G, w.size(1), device=x.device, dtype=x.dtype) # [B*S, G, hidden_dim]
 
- # Compute all experts' projections
- # Shape: (B*S, num_experts, head_dim)
- k_all = torch.einsum('bd,edh->beh', key_flat, self.wk) # [B*S, num_experts, head_dim]
- v_all = torch.einsum('bd,edh->beh', value.view(-1, D), self.wv)
+ for e in range(self.num_experts):
+ # 1. Find tokens where expert `e` is used in ANY group
+ expert_mask = (indices_flat == e).any(dim=1) # [B*S]
+ if not expert_mask.any():
+ continue
 
- if self.use_bias:
- k_all += self.bk
- v_all += self.bv
+ # 2. Project tokens using expert `e`
+ x_slice = x_flat[expert_mask] # [num_selected, D]
+ proj = F.linear(x_slice, w[e], b[e] if b is not None else None) # [num_selected, hidden_dim]
 
- # Get results for all heads
- k_all = k_all.view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
- v_all = v_all.view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
+ # 3. Scatter projections into correct groups
+ for g in range(G):
+ group_mask = indices_flat[expert_mask, g] == e # [num_selected]
+ if not group_mask.any():
+ continue
 
- # Gather top-k experts using expanded indices
- expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1)) # [B, S, num_groups, head_dim]
- selected_k = torch.gather(k_all, 2, expanded_indices) # [B, S, num_groups, head_dim]
- selected_v = torch.gather(v_all, 2, expanded_indices) # [B, S, num_groups, head_dim]
+ # Get tokens in this group using expert `e`
+ group_tokens = expert_mask.nonzero()[group_mask].squeeze(1)
+ # Weight and scatter
+ weighted_proj = proj[group_mask] * weights_flat[group_tokens, g].unsqueeze(-1)
+ output[group_tokens, g] += weighted_proj
 
- # Weighted
- weighted_k = selected_k * weights # [B, S, num_groups, head_dim]
- weighted_v = selected_v * weights # [B, S, num_groups, head_dim]
+ return output.view(B, S, G, -1)
 
- # Reshape to GQA format
- k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3) # [B, num_groups, S, head_dim]
- v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3) # [B, num_groups, S, head_dim]
+ def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
+ skip_query_processing: bool = False):
+ q = self.q_proj(query).view(b, t, self.num_heads, -1).transpose(1, 2) if not skip_query_processing else query
+
+ # Key/Value processing
+ B, S, D = key.shape
+ key_flat = key.view(-1, D)
+ weights_k_flat, indices_k_flat = self.router(key_flat)
+ # Reshape back to original dimensions
+ weights_k = weights_k_flat.view(B, S, -1)
+ indices_k = indices_k_flat.view(B, S, -1)
+ k = self._process_grouped_experts(key, self.wk, self.bk, weights_k, indices_k)
+ v = self._process_grouped_experts(value, self.wv, self.bv, weights_k, indices_k)
+
+ # Expand to GQA format
+ k = k.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
+ v = v.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
 
- if not self.use_flash_attention:
+ if not self.rel_embed:
  group_heads = self.num_heads // self.num_groups
 
  k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1) # (B, G, group_heads, S, head_dim)
@@ -366,12 +369,8 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
  return q, k, v
 
 
- class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
+ class DeepMoeAttentionSimplified(GroupedMoeAttentionSimplified):
  """
- Vectorized implementation calculates all expert heads for each token and selecting active tokens later. Linear layers
- for Attention are rather small, compared to MoE Feed Forward layers, so it's possible that it will be faster than filtering
- experts - it has to be tested.
-
  Deep MoE Attention (SMA) - Grouped MoE Attention extended even more for sublinear computational efficiency.
 
  In addition to using Mixture-of-Experts (MoE) for key/value head groups, SMA is also using dynamically selected
@@ -406,7 +405,7 @@ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
  ):
  self.num_query_experts = num_query_experts if num_query_experts is not None else num_heads
  self.num_query_groups = num_query_groups if num_query_groups is not None else num_groups
- super(DeepMoeAttentionVectorized, self).__init__(
+ super(DeepMoeAttentionSimplified, self).__init__(
  embed_dim,
  num_heads,
  num_groups=num_groups,
@@ -423,52 +422,48 @@ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
  **kwargs,
  )
 
+ def router_loss(self):
+ return (self.router.aux_loss + self.query_router.aux_loss) / 2
+
  def _init_q(self, embed_dim: int):
  self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
+
  hidden_dim = embed_dim // self.num_heads
- self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
+ self.wq = nn.Parameter(torch.empty(self.num_query_experts, hidden_dim, embed_dim))
  self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
  self._init_query_experts()
 
  def _init_query_experts(self):
- torch.nn.init.xavier_uniform_(self.wq)
+ nn.init.xavier_uniform_(self.wq)
  if self.use_bias:
- torch.nn.init.zeros_(self.bq)
+ nn.init.zeros_(self.bq)
 
  def _init_out(self, embed_dim: int):
  """Initialize output projection"""
- out_hidden_dim = embed_dim // self.num_heads * self.num_query_groups
- self.out_proj = nn.Linear(out_hidden_dim, embed_dim)
+ hidden_dim = embed_dim // (self.num_heads // self.num_query_groups)
+ self.out_proj = nn.Linear(hidden_dim, embed_dim)
 
  def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
  """Transpose attention output back to (B, T, D) shape"""
- out_hidden_dim = d // self.num_heads * self.num_query_groups
- return attn_output.transpose(1, 2).contiguous().view(b, t, out_hidden_dim)
+ hidden_dim = d // self.num_heads * self.num_query_groups
+ return attn_output.transpose(1, 2).contiguous().view(b, t, hidden_dim)
 
- def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
+ def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int, skip_query_processing: bool = False):
+ # Query processing
  B, T, D = query.shape
- query_flat = query.reshape(-1, D)
- weights_q, indices_q = self.query_router(query_flat)
- weights_q = weights_q.view(B, T, self.num_query_groups, 1)
- indices_q = indices_q.view(B, T, self.num_query_groups)
-
- # Compute all query experts
- q_all = torch.einsum('bd,edh->beh', query_flat, self.wq) # [B*T, num_query_experts, head_dim]
- if self.use_bias:
- q_all += self.bq
-
- q_all = q_all.view(B, T, self.num_query_experts, -1)
-
- # Gather top-k experts
- expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1))
- selected_q = torch.gather(q_all, 2, expanded_indices) # [B, T, num_query_groups, head_dim]
-
- # Weighted sum
- q = selected_q * weights_q # [B, T, num_query_groups, head_dim]
- q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3) # [B, num_query_groups, T, head_dim]
+ # Flatten for query routing
+ query_flat = query.view(-1, D)
+ weights_q_flat, indices_q_flat = self.query_router(query_flat)
+ # Reshape back
+ weights_q = weights_q_flat.view(B, T, -1)
+ indices_q = indices_q_flat.view(B, T, -1)
+ q = self._process_grouped_experts(query, self.wq, self.bq, weights_q, indices_q)
 
+ q = q.permute(0, 2, 1, 3).reshape(B, self.num_query_groups, T, -1)
+ # Key/Value processing
  return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
 
+
  # Others
 
  class FlexAttention(MultiHeadAttention):
rxnn/experimental/moe.py CHANGED
@@ -3,46 +3,6 @@ import torch.nn as nn
  import torch.nn.functional as F
  from ..transformers.moe import MoeRouter
 
- class DynamicMoeRouter(nn.Module):
- """Dynamic Mixture-of-Experts Router layer - dynamically selects top-k experts for each token."""
-
- def __init__(self, embed_dim: int, num_experts: int, top_ks: tuple[int] = (1, 2, 3), *args, **kwargs):
- super(DynamicMoeRouter, self).__init__(*args, **kwargs)
- self.top_ks = top_ks
- self.num_options = len(top_ks)
- self.num_experts = num_experts
- self.gate = nn.Linear(embed_dim, num_experts + self.num_options, bias=False)
- # For expert load balancing
- self.register_buffer('aux_loss', torch.tensor(0.0), persistent=False)
-
- def calculate_aux_loss(self, top_k_indices: torch.Tensor, routing_probs: torch.Tensor) -> torch.Tensor:
- expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
- expert_usage = expert_mask.sum(dim=0).mean(dim=0)
- mean_probs = routing_probs.mean(dim=0)
- return (expert_usage * mean_probs).sum() * self.num_experts
-
- def forward(self, x: torch.Tensor):
- # Input shape: [batch*seq_len, embed_dim]
- all_logits = self.gate(x)
- routing_logits = all_logits[:, :-self.num_options]
- options_logits = all_logits[:, -self.num_options:]
-
- routing_probs = F.softmax(routing_logits, dim=-1)
- top_k_id = torch.argmax(options_logits, dim=-1).item()
-
- top_k = self.top_ks[top_k_id]
-
- # Get top-k experts for each token
- top_k_weights, top_k_indices = routing_probs.topk(top_k, dim=-1)
-
- # Normalize weights (sum to 1 for each token)
- top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)
-
- # Load Balance Loss
- self.aux_loss = self.calculate_aux_loss(top_k_indices, routing_probs)
-
- return top_k_weights, top_k_indices, top_k
-
  class MoeFeedForwardVectorized(nn.Module):
  """
  Vectorized MoE - current implementation is incorrect - it calculates all the experts, then selects the correct ones.
rxnn/transformers/attention.py CHANGED
@@ -1,7 +1,7 @@
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- from torch.backends.cuda import sdp_kernel, SDPBackend
+ from torch.nn.attention import sdpa_kernel, SDPBackend
  import math
  from .positional import RotaryPositionalEmbedding, RelativePositionalEmbedding
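The import swap above moves from the `torch.backends.cuda.sdp_kernel` context manager, which is deprecated in recent PyTorch releases, to `torch.nn.attention.sdpa_kernel`. A minimal standalone usage sketch of the new API (assumes PyTorch 2.3+ and a CUDA device with flash-attention support; tensors and shapes are illustrative, not taken from rxnn):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend  # available in PyTorch >= 2.3

# Illustrative half-precision CUDA tensors: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 8, 128, 64, device='cuda', dtype=torch.float16)
k = torch.randn(1, 8, 128, 64, device='cuda', dtype=torch.float16)
v = torch.randn(1, 8, 128, 64, device='cuda', dtype=torch.float16)

# Restrict scaled_dot_product_attention to the FlashAttention backend,
# mirroring the context-manager pattern in the diff above.
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```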
@@ -102,36 +102,41 @@ class MultiHeadAttention(nn.Module):
 
  def _flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
  mask: torch.Tensor = None, enable_gqa: bool = False):
- with sdp_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
- attn_output = F.scaled_dot_product_attention(
- q, k, v,
- attn_mask=mask if not self.is_causal else None,
- dropout_p=self.dropout.p if self.training else 0.0,
- is_causal=self.is_causal,
- enable_gqa=enable_gqa,
- )
+ with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
+ return self._torch_attention(q, k, v, b, t, d, mask=mask, enable_gqa=enable_gqa)
+
+ def _torch_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
+ mask: torch.Tensor = None, enable_gqa: bool = False):
+ attn_output = F.scaled_dot_product_attention(
+ q, k, v,
+ attn_mask=mask if not self.is_causal else None,
+ dropout_p=self.dropout.p if self.training else 0.0,
+ is_causal=self.is_causal,
+ enable_gqa=enable_gqa,
+ )
  return self._transpose_output(attn_output, b, t, d)
 
- def _calculate_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
- mask: torch.Tensor = None):
- # Compute attention with FlashAttention
- return self._flash_attention(q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask)
+ def _calculate_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int, mask: torch.Tensor = None):
+ if self.use_flash_attention:
+ # Compute attention with FlashAttention
+ return self._flash_attention(q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask)
+ else:
+ # Compute attention using optimized PyTorch implementation
+ return self._torch_attention(q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask)
+
+ def _calculate_attention_with_relative_embedding(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int, mask: torch.Tensor = None):
+ attn_weights = self._calculate_attn_weight_with_relative_embeddings(q, k, mask=mask)
+ attn_weights = self.dropout(attn_weights)
+ return self._calculate_output(attn_weights, v, b, t, d)
 
  def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
  b, t, d = query.size()
  q, k, v = self._forward_qkv(query, key, value, b, t, d)
- if self.use_flash_attention:
+ if not self.rel_embed:
  q, k = self._apply_rope(q, k)
- attn_output = self._calculate_flash_attention(q, k, v, b, t, d, mask=mask)
+ attn_output = self._calculate_attention(q, k, v, b, t, d, mask=mask)
  else:
- if not self.rel_embed:
- attn_weights = self._calculate_attn_weights(q, k, d, mask=mask)
- else:
- attn_weights = self._calculate_attn_weight_with_relative_embeddings(q, k, mask=mask)
-
- attn_weights = self.dropout(attn_weights)
-
- attn_output = self._calculate_output(attn_weights, v, b, t, d)
+ attn_output = self._calculate_attention_with_relative_embedding(q, k, v, b, t, d, mask=mask)
  return self.out_proj(attn_output)
 
 
@@ -178,7 +183,7 @@ class GroupedQueryAttention(MultiHeadAttention):
  def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
  """Override query, key, and value projections for GQA case - split data into heads and groups"""
  head_dim = d // self.num_heads
- if self.use_flash_attention:
+ if not self.rel_embed:
  q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2)
  k = self.k_proj(key).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
  v = self.v_proj(value).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
@@ -202,12 +207,14 @@ class GroupedQueryAttention(MultiHeadAttention):
  v = v.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
  return q, k, v
 
- def _calculate_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
- mask: torch.Tensor = None):
- return self._flash_attention(
- q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask,
- enable_gqa=(self.num_heads != self.num_groups)
- )
+ def _calculate_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int, mask: torch.Tensor = None):
+ is_gqa = self.num_heads != self.num_groups
+ if self.use_flash_attention:
+ # Compute attention with FlashAttention
+ return self._flash_attention(q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask, enable_gqa=is_gqa)
+ else:
+ # Compute attention using optimized PyTorch implementation
+ return self._torch_attention(q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask, enable_gqa=is_gqa)
 
 
  class MultiQueryAttention(MultiHeadAttention):
@@ -251,7 +258,7 @@ class MultiQueryAttention(MultiHeadAttention):
  def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
  """Override query, key, and value projections for GQA case - use multiple heads
  for query and single for key/values"""
- if self.use_flash_attention:
+ if not self.rel_embed:
  q = self.q_proj(query).view(b, t, self.num_heads, d // self.num_heads).transpose(1, 2)
  k = self.k_proj(key).view(b, -1, 1, d // self.num_heads).transpose(1, 2)
  v = self.v_proj(value).view(b, -1, 1, d // self.num_heads).transpose(1, 2)
@@ -261,12 +268,13 @@ class MultiQueryAttention(MultiHeadAttention):
  v = self.v_proj(value).unsqueeze(1).expand(-1, self.num_heads, -1, -1)
  return q, k, v
 
- def _calculate_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
- mask: torch.Tensor = None):
- return self._flash_attention(
- q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask,
- enable_gqa=True
- )
+ def _calculate_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int, mask: torch.Tensor = None):
+ if self.use_flash_attention:
+ # Compute attention with FlashAttention
+ return self._flash_attention(q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask, enable_gqa=True)
+ else:
+ # Compute attention using optimized PyTorch implementation
+ return self._torch_attention(q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask, enable_gqa=True)
 
 
  def init_attention(
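The `_calculate_attention` overrides above pass `enable_gqa` to `F.scaled_dot_product_attention`, so the grouped key/value heads no longer have to be expanded to the full query head count by hand. A small standalone sketch of that argument (requires PyTorch 2.5+ for `enable_gqa`; shapes are illustrative, not taken from rxnn):

```python
import torch
import torch.nn.functional as F

# Query has 8 heads, key/value keep only 2 grouped heads; with enable_gqa=True
# SDPA repeats the kv heads internally (8 % 2 == 0), as in grouped-query attention.
B, H, G, S, head_dim = 2, 8, 2, 16, 32
q = torch.randn(B, H, S, head_dim)
k = torch.randn(B, G, S, head_dim)
v = torch.randn(B, G, S, head_dim)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
print(out.shape)  # torch.Size([2, 8, 16, 32])
```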
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rxnn
- Version: 0.1.33
+ Version: 0.1.35
  Summary: RxNN: Reactive Neural Networks Platform
  License: Apache-2.0
  Keywords: deep-learning,ai,machine-learning
@@ -1,8 +1,8 @@
  rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- rxnn/experimental/attention.py,sha256=h9pEv_70NKpD5KOQOFP3h-IJKzh7Wbnaxka4Bd3rdd8,29745
+ rxnn/experimental/attention.py,sha256=GxbLmOTBvUiYU0Rc_0ju1n_ocJciHC6i3neDGe-rZZc,29426
  rxnn/experimental/models.py,sha256=QEuFBB9iEg5AbKQLwGJkAwPjMfaVeTqazhKDWPRkm7o,4598
- rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
+ rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
  rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
  rxnn/memory/stm.py,sha256=EsD8slSP4_9dLuq6aFPDmuFe8PWilxh90so5Z3nm-ig,2057
@@ -16,7 +16,7 @@ rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
  rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
  rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
  rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- rxnn/transformers/attention.py,sha256=Nox986BH9qq4rDYLiYmfj1DeMeULF3akexIl99MPccM,14331
+ rxnn/transformers/attention.py,sha256=zv0uH3_L39tVmpiwNdmEf6Cp602uqdbr3UQj8Z3hIIk,15349
  rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
  rxnn/transformers/layers.py,sha256=n_jZTqEF_vLkF31AkB5XGErfm2sQFd9CRqJUHKRFkKI,6956
  rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
  rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
  rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
  rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
- rxnn-0.1.33.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
- rxnn-0.1.33.dist-info/METADATA,sha256=m3DWDnTu7Lx1kHYPIAQCdKU8t4QZBdqG0QcSIFvB924,16627
- rxnn-0.1.33.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
- rxnn-0.1.33.dist-info/RECORD,,
+ rxnn-0.1.35.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+ rxnn-0.1.35.dist-info/METADATA,sha256=aziCzqOeetdE3gMV2i15QoB5O31bGpiZgzcpGM97QPk,16627
+ rxnn-0.1.35.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+ rxnn-0.1.35.dist-info/RECORD,,