rxnn 0.1.26__py3-none-any.whl → 0.1.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -238,6 +238,230 @@ class DeepMoeAttention(GroupedMoeAttention):
         # Key/Value processing
         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
 
+# Vectorized
+
+class GroupedMoeAttentionVectorized(GroupedQueryAttention):
+    """
+    Vectorized implementation that calculates all expert heads for every token and selects the active ones afterwards.
+    Linear layers for attention are rather small compared to MoE feed-forward layers, so this may turn out faster than
+    filtering experts first - it has to be tested.
+
+    Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
+
+    Instead of mapping keys/values to static head groups, it dynamically selects head expert groups. It has the same
+    total number of key/value heads as query heads, but uses only a selected group for the attention calculation.
+    - with num_groups set to 1, it becomes MoE MultiQueryAttention
+
+    Compared to traditional GQA/MQA, it should perform better, because much less information is lost with this
+    approach - the full number of key/value heads is trained, while only a group is used per token.
+
+    In terms of efficiency, it should be close to GQA/MQA linear performance, with only a small MoE routing overhead.
+
+    Optionally, it can use even more expert heads than attention heads - for example:
+    - 512 dims divided into 16 heads of 32 dims, using 4 head groups, may use e.g. 24 total expert heads - still only
+      4 are used for the attention calculation, while 16 heads split the dimensions (in that case there are 16 query heads)
+
+    © 2025 Adam Filipek
+    """
+
+    def __init__(
+            self,
+            embed_dim: int,
+            num_heads: int,
+            num_groups: int,
+            dropout: float = 0.0,
+            rope: RotaryPositionalEmbedding = None,
+            rope_only_for_query: bool = False,
+            use_relative_embeddings: bool = False,
+            max_seq_len: int = 1024,
+            use_flash_attention: bool = False,
+            is_causal: bool = False,
+            use_bias: bool = False,
+            num_experts: int = None,
+            *args,
+            **kwargs,
+    ):
+        self.num_experts = num_experts if num_experts is not None else num_heads
+        super(GroupedMoeAttentionVectorized, self).__init__(
+            embed_dim,
+            num_heads,
+            num_groups=num_groups,
+            dropout=dropout,
+            rope=rope,
+            rope_only_for_query=rope_only_for_query,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            *args,
+            **kwargs,
+        )
+
+    def _init_kv(self, embed_dim: int):
+        self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
+        hidden_dim = embed_dim // self.num_heads
+        self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+        self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+        self._init_experts()
+
+    def _init_experts(self):
+        torch.nn.init.xavier_uniform_(self.wk)
+        torch.nn.init.xavier_uniform_(self.wv)
+        if self.use_bias:
+            torch.nn.init.zeros_(self.bk)
+            torch.nn.init.zeros_(self.bv)
+
+    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
+                     skip_query_processing: bool = False):
+        head_dim = d // self.num_heads
+
+        # Process Query as in GQA
+        q = self.q_proj(query).view(b, t, self.num_heads, -1).transpose(1, 2)
+
+        # Key/Value MoE routing
+        B, S, D = key.shape
+        key_flat = key.reshape(B * S, D)
+        weights, indices = self.router(key_flat)  # (B*S, num_groups), (B*S, num_groups)
+        weights = weights.view(B, S, self.num_groups, 1)
+        indices = indices.view(B, S, self.num_groups)
+
+        # Compute all experts' projections: tokens (N, D) x experts (E, D, H) -> (N, E, H)
+        k_all = torch.einsum('nd,edh->neh', key_flat, self.wk)  # [B*S, num_experts, head_dim]
+        v_all = torch.einsum('nd,edh->neh', value.view(B * S, D), self.wv)
+        if self.use_bias:
+            # Apply per-expert biases (broadcast over tokens)
+            k_all = k_all + self.bk
+            v_all = v_all + self.bv
+
+        # Reshape to [B, S, num_experts, head_dim]
+        k_all = k_all.view(B, S, self.num_experts, -1)
+        v_all = v_all.view(B, S, self.num_experts, -1)
+
+        # Gather top-k experts and weights
+        # Expand indices to [B, S, num_groups, head_dim]
+        expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
+        selected_k = torch.gather(k_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
+        selected_v = torch.gather(v_all, 2, expanded_indices)
+
+        # Weighted sum
+        weighted_k = (selected_k * weights).sum(dim=2)  # [B, S, head_dim]
+        weighted_v = (selected_v * weights).sum(dim=2)
+
+        # Reshape to GQA format
+        k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, G, S, head_dim]
+        v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)
+
+        if not self.use_flash_attention:
+            group_heads = self.num_heads // self.num_groups
+
+            k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+            v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+
+            k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+            v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+
+        return q, k, v
+
+
+class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
+    """
+    Vectorized implementation that calculates all expert heads for every token and selects the active ones afterwards.
+    Linear layers for attention are rather small compared to MoE feed-forward layers, so this may turn out faster than
+    filtering experts first - it has to be tested.
+
+    Deep MoE Attention (DMA) - Grouped MoE Attention extended even further, for sublinear computational efficiency.
+
+    In addition to using Mixture-of-Experts (MoE) routing for key/value head groups, DMA also uses dynamically selected
+    query heads - with that approach, each token can attend to every other token, but only partially - only part of the
+    information from each token is used to identify related information in the other tokens. So DMA is not spatially
+    sparse (it has access to all tokens), but rather structurally sparse (it has access only to part of each token's information).
+
+    This solution could reduce the computational complexity of the attention operation to a sublinear level (<O(N)) and
+    provide a viable and efficient alternative to spatially sparse attention mechanisms like Flex Attention.
+
+    © 2025 Adam Filipek
+    """
+
+    def __init__(
+            self,
+            embed_dim: int,
+            num_heads: int,
+            num_groups: int,
+            dropout: float = 0.0,
+            rope: RotaryPositionalEmbedding = None,
+            rope_only_for_query: bool = False,
+            use_relative_embeddings: bool = False,
+            max_seq_len: int = 1024,
+            use_flash_attention: bool = False,
+            is_causal: bool = False,
+            use_bias: bool = False,
+            num_experts: int = None,
+            num_query_experts: int = None,
+            num_query_groups: int = None,
+            *args,
+            **kwargs,
+    ):
+        self.num_query_experts = num_query_experts if num_query_experts is not None else num_heads
+        self.num_query_groups = num_query_groups if num_query_groups is not None else num_groups
+        super(DeepMoeAttentionVectorized, self).__init__(
+            embed_dim,
+            num_heads,
+            num_groups=num_groups,
+            dropout=dropout,
+            rope=rope,
+            rope_only_for_query=rope_only_for_query,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_experts=num_experts,
+            *args,
+            **kwargs,
+        )
+
+    def _init_q(self, embed_dim: int):
+        self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
+        hidden_dim = embed_dim // self.num_heads
+        self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
+        self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
+        self._init_query_experts()
+
+    def _init_query_experts(self):
+        torch.nn.init.xavier_uniform_(self.wq)
+        if self.use_bias:
+            torch.nn.init.zeros_(self.bq)
+
+    def _init_out(self, embed_dim: int):
+        """Initialize output projection"""
+        self.out_proj = nn.Linear(embed_dim // (self.num_heads // self.num_groups), embed_dim)
+
+    def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
+        """Transpose attention output back to (B, T, D) shape"""
+        hidden_dim = d // self.num_heads * self.num_query_groups
+        return attn_output.transpose(1, 2).contiguous().view(b, t, hidden_dim)
+
+    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
+        B, T, D = query.shape
+        query_flat = query.reshape(B * T, D)
+        weights_q, indices_q = self.query_router(query_flat)
+        weights_q = weights_q.view(B, T, self.num_query_groups, 1)
+        indices_q = indices_q.view(B, T, self.num_query_groups)
+
+        # Compute all query experts: tokens (N, D) x experts (E, D, H) -> (N, E, H)
+        q_all = torch.einsum('nd,edh->neh', query_flat, self.wq)  # [B*T, num_query_experts, head_dim]
+        if self.use_bias:
+            q_all = q_all + self.bq
+        q_all = q_all.view(B, T, self.num_query_experts, -1)
+
+        # Gather top-k experts
+        expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1))
+        selected_q = torch.gather(q_all, 2, expanded_indices)  # [B, T, num_query_groups, head_dim]
+
+        # Weighted sum
+        q = (selected_q * weights_q).sum(dim=2)  # [B, T, head_dim]
+        q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3)  # [B, H_q, T, head_dim]
+
+        return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
+
 
 # Others
 
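A note on the mechanism (not part of the diff): the vectorized routing above boils down to three tensor operations - project every token with every expert at once, gather the top-k experts chosen by the router, and take their weighted sum. A minimal self-contained sketch of that pattern with dummy tensors follows; all names and sizes are illustrative, and the router is replaced by random weights/indices.

    import torch

    B, S, D, E, G, H = 2, 8, 512, 24, 4, 32      # batch, seq, embed, experts, groups, head_dim
    tokens = torch.randn(B * S, D)               # flattened keys/values
    w_experts = torch.randn(E, D, H)             # per-expert projection, like self.wk / self.wv

    # 1. All expert projections at once: (B*S, D) x (E, D, H) -> (B*S, E, H)
    all_proj = torch.einsum('nd,edh->neh', tokens, w_experts).view(B, S, E, H)

    # 2. Gather the (here: random) top-k expert indices and gate weights per token
    indices = torch.randint(0, E, (B, S, G))
    weights = torch.softmax(torch.randn(B, S, G), dim=-1).unsqueeze(-1)
    selected = torch.gather(all_proj, 2, indices.unsqueeze(-1).expand(-1, -1, -1, H))  # (B, S, G, H)

    # 3. Weighted sum over the selected experts
    combined = (selected * weights).sum(dim=2)   # (B, S, H)
    print(combined.shape)                        # torch.Size([2, 8, 32])
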
@@ -389,7 +613,7 @@ def init_moe_attention(
         num_query_experts: int = None,
         num_query_groups: int = None,
 ) -> GroupedQueryAttention:
-    assert attention_type in ['gma', 'dma'], "Error, attention type should be one of: 'gma', 'dma', 'gma_v', 'dma_v'"
+    assert attention_type in ['gma', 'dma', 'gma_v', 'dma_v'], "Error, attention type should be one of: 'gma', 'dma', 'gma_v', 'dma_v'"
 
     if attention_type == "gma":
         return GroupedMoeAttention(
@@ -406,7 +630,7 @@ def init_moe_attention(
             use_bias=use_bias,
             num_experts=num_experts,
         )
-    else:
+    elif attention_type == "dma":
         return DeepMoeAttention(
             embed_dim,
             num_heads,
@@ -423,3 +647,35 @@ def init_moe_attention(
             num_query_experts=num_query_experts,
             num_query_groups=num_query_groups,
         )
+    elif attention_type == "gma_v":
+        return GroupedMoeAttentionVectorized(
+            embed_dim,
+            num_heads,
+            gqa_groups,
+            dropout=dropout,
+            rope=rope,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_experts=num_experts,
+        )
+    else:
+        return DeepMoeAttentionVectorized(
+            embed_dim,
+            num_heads,
+            gqa_groups,
+            dropout=dropout,
+            rope=rope,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_experts=num_experts,
+            num_query_experts=num_query_experts,
+            num_query_groups=num_query_groups,
+        )
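For reference (not part of the diff): a hedged usage sketch of the extended factory. The parameter names below are inferred from the variables used inside init_moe_attention in this hunk, so treat the exact signature as an assumption; the import path follows the package layout listed in RECORD.

    from rxnn.experimental.attention import init_moe_attention

    # 'gma_v' / 'dma_v' now resolve to the vectorized classes added above
    attn = init_moe_attention(
        embed_dim=512,
        num_heads=16,
        attention_type='dma_v',
        gqa_groups=4,
        num_experts=24,
        num_query_experts=24,
        num_query_groups=4,
    )
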
@@ -65,7 +65,7 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma"'
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'gma_v', 'dma_v'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "gma_v", "dma_v".'
 
         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.1.26
+Version: 0.1.27
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -1,7 +1,7 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=rJSQjA7_9YqcM4Y8SyJSuZQjyz8j4XPhC5jcrcrRK2M,17891
-rxnn/experimental/models.py,sha256=8KAo7BtRkke9qRlzGRtQa9-EZ34roGWrn0N_T6L-6Wc,4561
+rxnn/experimental/attention.py,sha256=csasMRxL4nq2dS7pc9WdS4bvCB70ZVgsR7LTHV2jEJ0,29388
+rxnn/experimental/models.py,sha256=QEuFBB9iEg5AbKQLwGJkAwPjMfaVeTqazhKDWPRkm7o,4598
 rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.26.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.1.26.dist-info/METADATA,sha256=qKx2OFl-ca7OCZoSf99l5lOmg2czSCN8UN_Esqmfco0,16627
-rxnn-0.1.26.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
-rxnn-0.1.26.dist-info/RECORD,,
+rxnn-0.1.27.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.27.dist-info/METADATA,sha256=XjcqSdhjTRsCvvP-o981Ihp4k5PFRCQUSLVsPZ_NVPw,16627
+rxnn-0.1.27.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.27.dist-info/RECORD,,