rxnn 0.2.54__py3-none-any.whl → 0.2.56__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

rxnn/experimental/attention.py

@@ -319,6 +319,8 @@ class SparseQueryAttention(MultiHeadAttention):
 
 # Others
 class FlexAttention(MultiHeadAttention):
+    """Flex attention layer, with RoPE support"""
+
     def __init__(
             self,
             embed_dim: int,
@@ -355,8 +357,7 @@ class FlexAttention(MultiHeadAttention):
         # Learnable global tokens
         self.global_tokens = nn.Parameter(torch.randn(1, num_global_tokens, embed_dim))
 
-
-    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
+    def _split_heads(self, x: torch.Tensor, is_query: bool = False) -> torch.Tensor:
         b, t, d = x.size()
         return x.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
 
@@ -364,6 +365,10 @@ class FlexAttention(MultiHeadAttention):
         b, h, t, d = x.size()
         return self._transpose_output(x, b, t, h * d)
 
+    def _sdpa(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.dropout.p if self.training else 0,
+                                              is_causal=self.is_causal)
+
     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
         b, t, d = query.size()
 
@@ -372,7 +377,7 @@ class FlexAttention(MultiHeadAttention):
         x = torch.cat([global_tokens, query], dim=1)
 
         # Project Q, K, V
-        q = self._split_heads(self.q_proj(x))
+        q = self._split_heads(self.q_proj(x), is_query=True)
         k = self._split_heads(self.k_proj(key))
         v = self._split_heads(self.v_proj(value))
 
@@ -385,7 +390,7 @@ class FlexAttention(MultiHeadAttention):
         local_q = q[:, :, self.num_global_tokens:]  # (B, H, L, D)
 
         # Global attention
-        global_attn = F.scaled_dot_product_attention(global_q, k, v, attn_mask=mask, dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal)
+        global_attn = self._sdpa(global_q, k, v, mask=mask)
 
         # Local attention with windowed slicing (no large masks)
         L = local_q.size(2)
@@ -403,8 +408,7 @@ class FlexAttention(MultiHeadAttention):
             window_k = k[:, :, k_window_start:k_window_end]
             window_v = v[:, :, k_window_start:k_window_end]
 
-
-            window_attn = F.scaled_dot_product_attention(window_q, window_k, window_v, attn_mask=None, dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal)
+            window_attn = self._sdpa(window_q, window_k, window_v, mask=None)
 
             windowed_attn.append(window_attn)
 
@@ -419,6 +423,74 @@ class FlexAttention(MultiHeadAttention):
         return output[:, self.num_global_tokens:, :]
 
 
+class FlexSparseQueryAttention(FlexAttention):
+    """Combined Flex and Sparse Query attention layer, with RoPE support"""
+    def __init__(
+            self,
+            embed_dim: int,
+            num_heads: int,
+            num_groups: int,
+            num_query_groups: int,
+            dropout: float = 0.0,
+            rope: RotaryPositionalEmbedding = None,
+            rope_only_for_query: bool = False,
+            rope_only_for_keys: bool = False,
+            use_relative_embeddings: bool = False,
+            max_seq_len: int = 1024,
+            use_flash_attention: bool = True,
+            is_causal: bool = False,
+            use_bias: bool = False,
+            num_global_tokens: int = 16,
+            window_size: int = 128,
+    ):
+        self.num_groups = num_groups
+        self.num_query_groups = num_query_groups
+        super(FlexSparseQueryAttention, self).__init__(
+            embed_dim,
+            num_heads,
+            dropout=dropout,
+            rope=rope,
+            rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_global_tokens=num_global_tokens,
+            window_size=window_size,
+        )
+        assert num_heads % num_groups == 0, "num_heads must be divisible by num_groups"
+
+    def _init_kv(self, embed_dim: int):
+        self.k_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_groups), bias=self.use_bias)
+        self.v_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_groups), bias=self.use_bias)
+
+    def _init_q(self, embed_dim: int):
+        self.q_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_query_groups), bias=self.use_bias)
+
+    def _init_out(self, embed_dim: int):
+        """Initialize output projection"""
+        self.out_proj = nn.Linear(embed_dim // (self.num_heads // self.num_query_groups), embed_dim)
+
+    def _split_heads(self, x: torch.Tensor, is_query: bool = False) -> torch.Tensor:
+        b, t, d = x.size()
+        return x.view(b, t, self.num_query_groups if is_query else self.num_groups, self.head_dim).transpose(1, 2)
+
+    def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
+        """Transpose attention output back to (B, T, D) shape"""
+        return attn_output.transpose(1, 2).contiguous().view(b, t, d // (self.num_heads // self.num_query_groups))
+
+    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
+        b, h, t, d = x.size()
+        return self._transpose_output(x, b, t, self.embed_dim)
+
+    def _sdpa(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+        is_gqa = self.num_query_groups != self.num_groups
+        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.dropout.p if self.training else 0,
+                                              is_causal=self.is_causal, enable_gqa=is_gqa)
+
+
 class InfiniteAttention(MultiHeadAttention):
     def __init__(
             self,
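
For orientation, a minimal usage sketch of the new layer follows (not part of the package diff). It assumes FlexSparseQueryAttention is importable from rxnn.experimental.attention and that, as in the other attention layers, head_dim is embed_dim // num_heads; all concrete values are hypothetical.

    import torch
    from rxnn.experimental.attention import FlexSparseQueryAttention

    # Hypothetical config: 8 heads, 2 KV groups, 4 query groups (so _sdpa runs
    # with enable_gqa=True), 16 learnable global tokens, local window of 128.
    attn = FlexSparseQueryAttention(
        embed_dim=256,
        num_heads=8,
        num_groups=2,
        num_query_groups=4,
        num_global_tokens=16,
        window_size=128,
    )

    x = torch.randn(2, 512, 256)    # (batch, seq_len, embed_dim)
    out = attn(x, x, x, mask=None)  # self-attention; output keeps shape (2, 512, 256)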
@@ -499,7 +571,7 @@ def init_experimental_attention(
         num_global_tokens: int = 16,
         window_size: int = 128,
 ) -> MultiHeadAttention:
-    assert attention_type in ['gma', 'dma', 'sqa', 'flex'], "Error, attention type should be one of: 'gma', 'dma', 'sqa', 'flex'"
+    assert attention_type in ['gma', 'dma', 'sqa', 'flex', 'flex-sqa'], "Error, attention type should be one of: 'gma', 'dma', 'sqa', 'flex', 'flex-sqa'"
 
     if attention_type == "gma":
         return GroupedMoeAttention(
@@ -550,6 +622,23 @@ def init_experimental_attention(
             num_global_tokens=num_global_tokens,
             window_size=window_size,
         )
+    elif attention_type == "flex-sqa":
+        return FlexSparseQueryAttention(
+            embed_dim,
+            num_heads,
+            gqa_groups,
+            num_query_groups,
+            dropout=dropout,
+            rope=rope,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_global_tokens=num_global_tokens,
+            window_size=window_size,
+        )
     else:
         return SparseQueryAttention(
            embed_dim,
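
The same layer can also be obtained through the factory above via the new 'flex-sqa' type. A hedged sketch, assuming init_experimental_attention accepts these arguments by keyword (the names are taken from the parameters and call sites visible in this diff; the values are hypothetical):

    from rxnn.experimental.attention import init_experimental_attention

    attention = init_experimental_attention(
        embed_dim=256,
        num_heads=8,
        attention_type='flex-sqa',  # option added in this diff (0.2.54 -> 0.2.56)
        gqa_groups=2,               # forwarded to FlexSparseQueryAttention as num_groups
        num_query_groups=4,
        num_global_tokens=16,
        window_size=128,
    )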

rxnn/experimental/models.py

@@ -73,7 +73,7 @@ class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa', 'flex'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa", "flex".'
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa', 'flex', 'flex-sqa'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa", "flex", "flex-sqa".'
 
         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)

rxnn-0.2.54.dist-info/METADATA → rxnn-0.2.56.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.54
+Version: 0.2.56
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning

rxnn-0.2.54.dist-info/RECORD → rxnn-0.2.56.dist-info/RECORD

@@ -1,8 +1,8 @@
 rxnn/.DS_Store,sha256=BxZLo9tFs48JMq6jhumiCnCPLTeCwl619CFSg4ClRAY,6148
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=zQM6Og62IZVGogsDBReYrHSiRZmDaebl1FcH2e6sHyY,24589
-rxnn/experimental/models.py,sha256=HPOIRpnX_oiI10wsVC4J6rzo3T6dj10aNWGYpa9S1UU,5115
+rxnn/experimental/attention.py,sha256=jlNS82INjycNEfmk3HtkIacUvT_ELhaCO2g-kZTvhXg,28455
+rxnn/experimental/models.py,sha256=oJWd56LUsLc9S8eCZw-ShvuWjoQxj4C9GitbohlQ0ok,5139
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/attention.py,sha256=wnYjd3UnmzAA79-7QxpMoEk3O1qRy8LmW1JcE_Fotck,3094
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.54.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.54.dist-info/METADATA,sha256=t6l1VezLNpdpgaXaqB-YhrfAhEUlWZm9-wwzBZ_Xk34,25997
-rxnn-0.2.54.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.54.dist-info/RECORD,,
+rxnn-0.2.56.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.56.dist-info/METADATA,sha256=qW9X-oP3LWHB0E6S0opHPzWjDmNBRy9DjOz5od4qutc,25997
+rxnn-0.2.56.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.56.dist-info/RECORD,,