rxnn 0.2.54__py3-none-any.whl → 0.2.56__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- rxnn/experimental/attention.py +96 -7
- rxnn/experimental/models.py +1 -1
- {rxnn-0.2.54.dist-info → rxnn-0.2.56.dist-info}/METADATA +1 -1
- {rxnn-0.2.54.dist-info → rxnn-0.2.56.dist-info}/RECORD +6 -6
- {rxnn-0.2.54.dist-info → rxnn-0.2.56.dist-info}/LICENSE +0 -0
- {rxnn-0.2.54.dist-info → rxnn-0.2.56.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -319,6 +319,8 @@ class SparseQueryAttention(MultiHeadAttention):
 
 # Others
 class FlexAttention(MultiHeadAttention):
+    """Flex attention layer, with RoPE support"""
+
     def __init__(
             self,
             embed_dim: int,
@@ -355,8 +357,7 @@ class FlexAttention(MultiHeadAttention):
         # Learnable global tokens
         self.global_tokens = nn.Parameter(torch.randn(1, num_global_tokens, embed_dim))
 
-
-    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
+    def _split_heads(self, x: torch.Tensor, is_query: bool = False) -> torch.Tensor:
         b, t, d = x.size()
         return x.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
 
@@ -364,6 +365,10 @@ class FlexAttention(MultiHeadAttention):
         b, h, t, d = x.size()
         return self._transpose_output(x, b, t, h * d)
 
+    def _sdpa(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.dropout.p if self.training else 0,
+                                              is_causal=self.is_causal)
+
     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
         b, t, d = query.size()
 
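The new `_sdpa` helper centralizes the `F.scaled_dot_product_attention` call so subclasses can override just this step (the `flex-sqa` variant added below does so to enable GQA). A minimal standalone sketch of the call pattern it wraps; the tensor sizes are illustrative and not taken from the package:

# Sketch only: the torch.nn.functional call the _sdpa helper forwards to.
import torch
import torch.nn.functional as F

b, h, t, d = 2, 8, 64, 32                     # batch, heads, tokens, head_dim (assumed)
q, k, v = (torch.randn(b, h, t, d) for _ in range(3))

out = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=None,    # the helper passes its `mask` argument here
    dropout_p=0.0,     # the module uses self.dropout.p while training, 0 at eval
    is_causal=False,   # the module forwards self.is_causal
)
print(out.shape)       # torch.Size([2, 8, 64, 32])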
@@ -372,7 +377,7 @@ class FlexAttention(MultiHeadAttention):
         x = torch.cat([global_tokens, query], dim=1)
 
         # Project Q, K, V
-        q = self._split_heads(self.q_proj(x))
+        q = self._split_heads(self.q_proj(x), is_query=True)
         k = self._split_heads(self.k_proj(key))
         v = self._split_heads(self.v_proj(value))
 
@@ -385,7 +390,7 @@ class FlexAttention(MultiHeadAttention):
         local_q = q[:, :, self.num_global_tokens:]  # (B, H, L, D)
 
         # Global attention
-        global_attn =
+        global_attn = self._sdpa(global_q, k, v, mask=mask)
 
         # Local attention with windowed slicing (no large masks)
         L = local_q.size(2)
@@ -403,8 +408,7 @@ class FlexAttention(MultiHeadAttention):
             window_k = k[:, :, k_window_start:k_window_end]
             window_v = v[:, :, k_window_start:k_window_end]
 
-
-            window_attn = F.scaled_dot_product_attention(window_q, window_k, window_v, attn_mask=None, dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal)
+            window_attn = self._sdpa(window_q, window_k, window_v, mask=None)
 
             windowed_attn.append(window_attn)
 
@@ -419,6 +423,74 @@ class FlexAttention(MultiHeadAttention):
         return output[:, self.num_global_tokens:, :]
 
 
+class FlexSparseQueryAttention(FlexAttention):
+    """Combined Flex and Sparse Query attention layer, with RoPE support"""
+    def __init__(
+            self,
+            embed_dim: int,
+            num_heads: int,
+            num_groups: int,
+            num_query_groups: int,
+            dropout: float = 0.0,
+            rope: RotaryPositionalEmbedding = None,
+            rope_only_for_query: bool = False,
+            rope_only_for_keys: bool = False,
+            use_relative_embeddings: bool = False,
+            max_seq_len: int = 1024,
+            use_flash_attention: bool = True,
+            is_causal: bool = False,
+            use_bias: bool = False,
+            num_global_tokens: int = 16,
+            window_size: int = 128,
+    ):
+        self.num_groups = num_groups
+        self.num_query_groups = num_query_groups
+        super(FlexSparseQueryAttention, self).__init__(
+            embed_dim,
+            num_heads,
+            dropout=dropout,
+            rope=rope,
+            rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_global_tokens=num_global_tokens,
+            window_size=window_size,
+        )
+        assert num_heads % num_groups == 0, "num_heads must be divisible by num_groups"
+
+    def _init_kv(self, embed_dim: int):
+        self.k_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_groups), bias=self.use_bias)
+        self.v_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_groups), bias=self.use_bias)
+
+    def _init_q(self, embed_dim: int):
+        self.q_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_query_groups), bias=self.use_bias)
+
+    def _init_out(self, embed_dim: int):
+        """Initialize output projection"""
+        self.out_proj = nn.Linear(embed_dim // (self.num_heads // self.num_query_groups), embed_dim)
+
+    def _split_heads(self, x: torch.Tensor, is_query: bool = False) -> torch.Tensor:
+        b, t, d = x.size()
+        return x.view(b, t, self.num_query_groups if is_query else self.num_groups, self.head_dim).transpose(1, 2)
+
+    def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
+        """Transpose attention output back to (B, T, D) shape"""
+        return attn_output.transpose(1, 2).contiguous().view(b, t, d // (self.num_heads // self.num_query_groups))
+
+    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
+        b, h, t, d = x.size()
+        return self._transpose_output(x, b, t, self.embed_dim)
+
+    def _sdpa(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+        is_gqa = self.num_query_groups != self.num_groups
+        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.dropout.p if self.training else 0,
+                                              is_causal=self.is_causal, enable_gqa=is_gqa)
+
+
 class InfiniteAttention(MultiHeadAttention):
     def __init__(
             self,
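A rough usage sketch of the new layer for self-attention follows. The constructor arguments mirror the diff above; the import path, the forward-call pattern, and the concrete sizes are assumptions, not taken from the package docs.

# Hedged sketch: instantiate FlexSparseQueryAttention and run a self-attention pass.
import torch
from rxnn.experimental.attention import FlexSparseQueryAttention  # assumed import path

attn = FlexSparseQueryAttention(
    embed_dim=512,
    num_heads=16,
    num_groups=4,          # KV head groups; num_heads must be divisible by this
    num_query_groups=8,    # reduced (sparse) query head groups
    num_global_tokens=16,
    window_size=128,
)

x = torch.randn(2, 256, 512)   # (batch, seq_len, embed_dim)
out = attn(x, x, x)            # query, key, value; optional mask omitted
print(out.shape)               # expected: torch.Size([2, 256, 512])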
@@ -499,7 +571,7 @@ def init_experimental_attention(
         num_global_tokens: int = 16,
         window_size: int = 128,
 ) -> MultiHeadAttention:
-    assert attention_type in ['gma', 'dma', 'sqa', 'flex'], "Error, attention type should be one of: 'gma', 'dma', 'sqa', 'flex'"
+    assert attention_type in ['gma', 'dma', 'sqa', 'flex', 'flex-sqa'], "Error, attention type should be one of: 'gma', 'dma', 'sqa', 'flex', 'flex-sqa"
 
     if attention_type == "gma":
         return GroupedMoeAttention(
@@ -550,6 +622,23 @@ def init_experimental_attention(
             num_global_tokens=num_global_tokens,
             window_size=window_size,
         )
+    elif attention_type == "flex-sqa":
+        return FlexSparseQueryAttention(
+            embed_dim,
+            num_heads,
+            gqa_groups,
+            num_query_groups,
+            dropout=dropout,
+            rope=rope,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_global_tokens=num_global_tokens,
+            window_size=window_size,
+        )
     else:
         return SparseQueryAttention(
             embed_dim,
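The same variant can also be selected through the factory. A hedged sketch: the keyword names are inferred from the arguments visible in this diff, since the factory's full parameter list is not shown here.

# Hedged sketch: request the new 'flex-sqa' variant from the factory.
from rxnn.experimental.attention import init_experimental_attention  # assumed import path

attn = init_experimental_attention(
    embed_dim=512,
    num_heads=16,
    attention_type='flex-sqa',   # option added in this release
    gqa_groups=4,                # forwarded as num_groups
    num_query_groups=8,
    num_global_tokens=16,
    window_size=128,
)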
rxnn/experimental/models.py
CHANGED
@@ -73,7 +73,7 @@ class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa', 'flex'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa", "flex".'
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa', 'flex', 'flex-sqa'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa", "flex", "flex-sqa".'
 
         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
{rxnn-0.2.54.dist-info → rxnn-0.2.56.dist-info}/RECORD
CHANGED
@@ -1,8 +1,8 @@
 rxnn/.DS_Store,sha256=BxZLo9tFs48JMq6jhumiCnCPLTeCwl619CFSg4ClRAY,6148
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
-rxnn/experimental/models.py,sha256=
+rxnn/experimental/attention.py,sha256=jlNS82INjycNEfmk3HtkIacUvT_ELhaCO2g-kZTvhXg,28455
+rxnn/experimental/models.py,sha256=oJWd56LUsLc9S8eCZw-ShvuWjoQxj4C9GitbohlQ0ok,5139
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/attention.py,sha256=wnYjd3UnmzAA79-7QxpMoEk3O1qRy8LmW1JcE_Fotck,3094
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.56.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.56.dist-info/METADATA,sha256=qW9X-oP3LWHB0E6S0opHPzWjDmNBRy9DjOz5od4qutc,25997
+rxnn-0.2.56.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.56.dist-info/RECORD,,
{rxnn-0.2.54.dist-info → rxnn-0.2.56.dist-info}/LICENSE
File without changes
{rxnn-0.2.54.dist-info → rxnn-0.2.56.dist-info}/WHEEL
File without changes