rxnn 0.1.26__py3-none-any.whl → 0.1.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -238,6 +238,231 @@ class DeepMoeAttention(GroupedMoeAttention):
          # Key/Value processing
          return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
 
+ # Vectorized
+
+ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
+     """
+     Vectorized implementation that computes all expert heads for every token and selects the active experts afterwards.
+     The linear layers used in attention are rather small compared to MoE feed-forward layers, so this may be faster
+     than filtering the experts first - it has to be tested.
+
+     Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
+
+     Instead of mapping keys/values to static head groups, it dynamically selects expert head groups. It has the same
+     total number of key/value heads as query heads, but uses only the selected group for the attention calculation.
+     - with num_groups set to 1, it becomes MoE Multi-Query Attention
+
+     Compared to traditional GQA/MQA, it should give better results, because far less information is lost with this
+     approach - the full number of key/value heads is trained, while only a group is used per token.
+
+     In terms of efficiency, it should be close to the linear performance of GQA/MQA, with only a small MoE routing overhead.
+
+     Optionally, it can use even more expert heads than attention heads - for example:
+     - 512 dims split into 16 heads of 32 dims, using 4 head groups - may use e.g. 24 total expert heads - still only
+       4 are used for the attention calculation, while 16 are used to split the dimensions (in that case it has 16 query heads)
+
+     © 2025 Adam Filipek
+     """
+
+     def __init__(
+             self,
+             embed_dim: int,
+             num_heads: int,
+             num_groups: int,
+             dropout: float = 0.0,
+             rope: RotaryPositionalEmbedding = None,
+             rope_only_for_query: bool = False,
+             use_relative_embeddings: bool = False,
+             max_seq_len: int = 1024,
+             use_flash_attention: bool = False,
+             is_causal: bool = False,
+             use_bias: bool = False,
+             num_experts: int = None,
+             *args,
+             **kwargs,
+     ):
+         self.num_experts = num_experts if num_experts is not None else num_heads
+         super(GroupedMoeAttentionVectorized, self).__init__(
+             embed_dim,
+             num_heads,
+             num_groups=num_groups,
+             dropout=dropout,
+             rope=rope,
+             rope_only_for_query=rope_only_for_query,
+             use_relative_embeddings=use_relative_embeddings,
+             max_seq_len=max_seq_len,
+             use_flash_attention=use_flash_attention,
+             is_causal=is_causal,
+             use_bias=use_bias,
+             *args,
+             **kwargs,
+         )
+
+     def _init_kv(self, embed_dim: int):
+         self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
+         hidden_dim = embed_dim // self.num_heads
+         self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+         self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+         self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+         self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+         self._init_experts()
+
+     def _init_experts(self):
+         torch.nn.init.xavier_uniform_(self.wk)
+         torch.nn.init.xavier_uniform_(self.wv)
+         if self.use_bias:
+             torch.nn.init.zeros_(self.bk)
+             torch.nn.init.zeros_(self.bv)
+
+     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
+                      skip_query_processing: bool = False):
+         head_dim = d // self.num_heads
+
+         # Process query as in GQA
+         q = self.q_proj(query).view(b, t, self.num_heads, -1).transpose(1, 2)
+
+         # Key/value MoE routing
+         B, S, D = key.shape
+         key_flat = key.reshape(B * S, D)
+         weights, indices = self.router(key_flat)  # (B*S, num_groups), (B*S, num_groups)
+         weights = weights.view(B, S, self.num_groups, 1)
+         indices = indices.view(B, S, self.num_groups)
+
+         # Compute all experts' projections
+         # Shape: (B*S, num_experts, head_dim)
+         k_all = torch.einsum('bd,edh->beh', key_flat, self.wk)  # [B*S, num_experts, head_dim]
+         v_all = torch.einsum('bd,edh->beh', value.view(B * S, D), self.wv)
+
+         # Reshape to [B, S, num_experts, head_dim]
+         k_all = k_all.view(B, S, self.num_experts, -1)
+         v_all = v_all.view(B, S, self.num_experts, -1)
+
+         # Gather top-k experts and weights
+         # Expand indices to [B, S, num_groups, head_dim]
+         expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
+         selected_k = torch.gather(k_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
+         selected_v = torch.gather(v_all, 2, expanded_indices)
+
+         # Weighted sum over the selected expert heads
+         weighted_k = (selected_k * weights).sum(dim=2)  # [B, S, head_dim]
+         weighted_v = (selected_v * weights).sum(dim=2)
+
+         # Reshape to GQA format
+         k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, G, S, head_dim]
+         v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)
+
+         if not self.use_flash_attention:
+             group_heads = self.num_heads // self.num_groups
+
+             k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+             v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+
+             k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+             v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+
+         return q, k, v
+
+
+ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
+     """
+     Vectorized implementation that computes all expert heads for every token and selects the active experts afterwards.
+     The linear layers used in attention are rather small compared to MoE feed-forward layers, so this may be faster
+     than filtering the experts first - it has to be tested.
+
+     Deep MoE Attention (DMA) - Grouped MoE Attention extended even further for sublinear computational efficiency.
+
+     In addition to using Mixture-of-Experts (MoE) routing for the key/value head groups, DMA also uses dynamically selected
+     query heads - with this approach, each token can attend to every other token, but only partially: only a part of the
+     information from each token is used to identify related information in other tokens. So DMA is not spatially
+     sparse (it has access to all tokens), but rather structurally sparse (it has access only to a part of each token's information).
+
+     This solution could reduce the computational complexity of the attention operation to a sublinear level (<O(N)) and provide
+     a viable and efficient alternative to spatially sparse attention mechanisms like Flex Attention.
+
+     © 2025 Adam Filipek
+     """
+
+     def __init__(
+             self,
+             embed_dim: int,
+             num_heads: int,
+             num_groups: int,
+             dropout: float = 0.0,
+             rope: RotaryPositionalEmbedding = None,
+             rope_only_for_query: bool = False,
+             use_relative_embeddings: bool = False,
+             max_seq_len: int = 1024,
+             use_flash_attention: bool = False,
+             is_causal: bool = False,
+             use_bias: bool = False,
+             num_experts: int = None,
+             num_query_experts: int = None,
+             num_query_groups: int = None,
+             *args,
+             **kwargs,
+     ):
+         self.num_query_experts = num_query_experts if num_query_experts is not None else num_heads
+         self.num_query_groups = num_query_groups if num_query_groups is not None else num_groups
+         super(DeepMoeAttentionVectorized, self).__init__(
+             embed_dim,
+             num_heads,
+             num_groups=num_groups,
+             dropout=dropout,
+             rope=rope,
+             rope_only_for_query=rope_only_for_query,
+             use_relative_embeddings=use_relative_embeddings,
+             max_seq_len=max_seq_len,
+             use_flash_attention=use_flash_attention,
+             is_causal=is_causal,
+             use_bias=use_bias,
+             num_experts=num_experts,
+             *args,
+             **kwargs,
+         )
+
+     def _init_q(self, embed_dim: int):
+         self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
+         hidden_dim = embed_dim // self.num_heads
+         self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
+         self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
+         self._init_query_experts()
+
+     def _init_query_experts(self):
+         torch.nn.init.xavier_uniform_(self.wq)
+         if self.use_bias:
+             torch.nn.init.zeros_(self.bq)
+
+     def _init_out(self, embed_dim: int):
+         """Initialize output projection"""
+         out_hidden_dim = embed_dim // self.num_heads * self.num_query_groups
+         self.out_proj = nn.Linear(out_hidden_dim, embed_dim)
+
+     def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
+         """Transpose attention output back to (B, T, D) shape"""
+         out_hidden_dim = d // self.num_heads * self.num_query_groups
+         return attn_output.transpose(1, 2).contiguous().view(b, t, out_hidden_dim)
+
+     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
+         B, T, D = query.shape
+         query_flat = query.reshape(B * T, D)
+         weights_q, indices_q = self.query_router(query_flat)
+         weights_q = weights_q.view(B, T, self.num_query_groups, 1)
+         indices_q = indices_q.view(B, T, self.num_query_groups)
+
+         # Compute all query experts
+         q_all = torch.einsum('bd,edh->beh', query_flat, self.wq)  # [B*T, num_query_experts, head_dim]
+         q_all = q_all.view(B, T, self.num_query_experts, -1)
+
+         # Gather top-k experts
+         expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1))
+         selected_q = torch.gather(q_all, 2, expanded_indices)  # [B, T, num_query_groups, head_dim]
+
+         # Weighted sum over the selected query experts
+         q = (selected_q * weights_q).sum(dim=2)  # [B, T, head_dim]
+         q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3)  # [B, H_q, T, head_dim]
+
+         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
+
 
  # Others
 
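Both vectorized classes added above follow the same routing pattern in _forward_qkv: project every token through every expert in one batched einsum, gather only the top-k expert outputs per token, and combine them with the router weights. Below is a minimal standalone sketch of that pattern; a plain softmax-plus-top-k stands in for rxnn's MoeRouter, and all shapes and names (B, S, num_experts, head_dim, ...) are illustrative rather than taken from the package.

import torch

# Minimal sketch of the vectorized expert routing used in _forward_qkv above.
# A plain softmax + top-k stands in for rxnn's MoeRouter; every name here is illustrative.
B, S, D, num_experts, num_groups, head_dim = 2, 8, 512, 24, 4, 32

key = torch.randn(B, S, D)
wk = torch.randn(num_experts, D, head_dim)        # per-expert projection weights
router_logits = torch.randn(B * S, num_experts)   # stand-in for MoeRouter scores

key_flat = key.reshape(B * S, D)
weights, indices = torch.softmax(router_logits, dim=-1).topk(num_groups, dim=-1)
weights = weights.view(B, S, num_groups, 1)
indices = indices.view(B, S, num_groups)

# 1) Project every token through every expert in one batched contraction
k_all = torch.einsum('bd,edh->beh', key_flat, wk).view(B, S, num_experts, head_dim)

# 2) Gather only the top-k expert outputs per token
expanded = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
selected_k = torch.gather(k_all, 2, expanded)     # [B, S, num_groups, head_dim]

# 3) Combine the selected experts with their routing weights
weighted_k = (selected_k * weights).sum(dim=2)    # [B, S, head_dim]
print(weighted_k.shape)                           # torch.Size([2, 8, 32])

The non-vectorized GroupedMoeAttention computes only the selected experts per token instead; as the docstrings note, which of the two is faster for these relatively small attention projections still has to be measured.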
@@ -389,7 +614,7 @@ def init_moe_attention(
          num_query_experts: int = None,
          num_query_groups: int = None,
  ) -> GroupedQueryAttention:
-     assert attention_type in ['gma', 'dma'], "Error, attention type should be one of: 'gma', 'dma', 'gma_v', 'dma_v'"
+     assert attention_type in ['gma', 'dma', 'gma_v', 'dma_v'], "Error, attention type should be one of: 'gma', 'dma', 'gma_v', 'dma_v'"
 
      if attention_type == "gma":
          return GroupedMoeAttention(
@@ -406,7 +631,7 @@ def init_moe_attention(
              use_bias=use_bias,
              num_experts=num_experts,
          )
-     else:
+     elif attention_type == "dma":
          return DeepMoeAttention(
              embed_dim,
              num_heads,
@@ -423,3 +648,35 @@ def init_moe_attention(
              num_query_experts=num_query_experts,
              num_query_groups=num_query_groups,
          )
+     elif attention_type == "gma_v":
+         return GroupedMoeAttentionVectorized(
+             embed_dim,
+             num_heads,
+             gqa_groups,
+             dropout=dropout,
+             rope=rope,
+             use_relative_embeddings=use_relative_embeddings,
+             max_seq_len=max_seq_len,
+             rope_only_for_query=rope_only_for_query,
+             use_flash_attention=use_flash_attention,
+             is_causal=is_causal,
+             use_bias=use_bias,
+             num_experts=num_experts,
+         )
+     else:
+         return DeepMoeAttentionVectorized(
+             embed_dim,
+             num_heads,
+             gqa_groups,
+             dropout=dropout,
+             rope=rope,
+             use_relative_embeddings=use_relative_embeddings,
+             max_seq_len=max_seq_len,
+             rope_only_for_query=rope_only_for_query,
+             use_flash_attention=use_flash_attention,
+             is_causal=is_causal,
+             use_bias=use_bias,
+             num_experts=num_experts,
+             num_query_experts=num_query_experts,
+             num_query_groups=num_query_groups,
+         )
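Since the full constructor signatures of the new classes appear above, they can be instantiated directly as well as through init_moe_attention with the new 'gma_v'/'dma_v' types. A hedged construction sketch follows, assuming the classes are importable from rxnn.experimental.attention (the module updated in this release) and that the defaults shown above are sufficient; the forward pass comes from the GroupedQueryAttention base class, which is not part of this diff, so it is not exercised here.

# Construction sketch only - argument values are illustrative.
from rxnn.experimental.attention import (
    GroupedMoeAttentionVectorized,
    DeepMoeAttentionVectorized,
)

# GMA-V: 16 query heads, 4 active key/value groups routed over 24 expert heads
gma_v = GroupedMoeAttentionVectorized(512, 16, 4, num_experts=24)

# DMA-V: additionally routes queries over 16 query experts, keeping 8 active query groups
dma_v = DeepMoeAttentionVectorized(
    512, 16, 4,
    num_experts=24,
    num_query_experts=16,
    num_query_groups=8,
)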
@@ -65,7 +65,7 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex
          assert ff_activation in ['relu', 'gelu',
                                   'swish', 'silu', 'linear',
                                   'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-         assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma"'
+         assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'gma_v', 'dma_v'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "gma_v", "dma_v".'
 
          embedding = nn.Embedding(vocab_size, embed_dim)
          rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rxnn
- Version: 0.1.26
+ Version: 0.1.28
  Summary: RxNN: Reactive Neural Networks Platform
  License: Apache-2.0
  Keywords: deep-learning,ai,machine-learning
@@ -1,7 +1,7 @@
  rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- rxnn/experimental/attention.py,sha256=rJSQjA7_9YqcM4Y8SyJSuZQjyz8j4XPhC5jcrcrRK2M,17891
- rxnn/experimental/models.py,sha256=8KAo7BtRkke9qRlzGRtQa9-EZ34roGWrn0N_T6L-6Wc,4561
+ rxnn/experimental/attention.py,sha256=GnK_J7o_4fJ5O50ETx4oG-p7dOCsPRMwVGv3BIbUIbg,29439
+ rxnn/experimental/models.py,sha256=QEuFBB9iEg5AbKQLwGJkAwPjMfaVeTqazhKDWPRkm7o,4598
  rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
  rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
  rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
  rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
  rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
- rxnn-0.1.26.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
- rxnn-0.1.26.dist-info/METADATA,sha256=qKx2OFl-ca7OCZoSf99l5lOmg2czSCN8UN_Esqmfco0,16627
- rxnn-0.1.26.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
- rxnn-0.1.26.dist-info/RECORD,,
+ rxnn-0.1.28.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+ rxnn-0.1.28.dist-info/METADATA,sha256=zpVetjl-0pFz7Z4e4GUlybS-rBHKFk3AYIS6fM46diU,16627
+ rxnn-0.1.28.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+ rxnn-0.1.28.dist-info/RECORD,,