rxnn 0.2.53__py3-none-any.whl → 0.2.55__py3-none-any.whl
This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
rxnn/experimental/attention.py
CHANGED
@@ -319,6 +319,8 @@ class SparseQueryAttention(MultiHeadAttention):
 
 # Others
 class FlexAttention(MultiHeadAttention):
+    """Flex attention layer, with RoPE support"""
+
     def __init__(
         self,
         embed_dim: int,
@@ -355,8 +357,7 @@ class FlexAttention(MultiHeadAttention):
         # Learnable global tokens
         self.global_tokens = nn.Parameter(torch.randn(1, num_global_tokens, embed_dim))
 
-
-    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
+    def _split_heads(self, x: torch.Tensor, is_query: bool = False) -> torch.Tensor:
         b, t, d = x.size()
         return x.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
 
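Note on the change above: _split_heads reshapes a packed (B, T, D) projection into the per-head (B, H, T, D/H) layout that scaled_dot_product_attention expects. The new is_query flag is unused in FlexAttention itself; it exists so that the FlexSparseQueryAttention subclass introduced further down can split queries into a different number of head groups than keys and values. A minimal standalone sketch of the reshape, with sizes chosen purely for illustration:

import torch

b, t, d, num_heads = 2, 10, 64, 8
head_dim = d // num_heads

x = torch.randn(b, t, d)                   # packed projection output, (B, T, D)
heads = x.view(b, t, num_heads, head_dim)  # (B, T, H, D/H)
heads = heads.transpose(1, 2)              # (B, H, T, D/H)

print(heads.shape)  # torch.Size([2, 8, 10, 8])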
@@ -364,15 +365,19 @@ class FlexAttention(MultiHeadAttention):
         b, h, t, d = x.size()
         return self._transpose_output(x, b, t, h * d)
 
+    def _sdpa(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.dropout.p if self.training else 0,
+                                              is_causal=self.is_causal)
+
     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
         b, t, d = query.size()
 
-        # Prepend global tokens
+        # Prepend global tokens
         global_tokens = self.global_tokens.expand(b, -1, -1)
         x = torch.cat([global_tokens, query], dim=1)
 
         # Project Q, K, V
-        q = self._split_heads(self.q_proj(x))
+        q = self._split_heads(self.q_proj(x), is_query=True)
         k = self._split_heads(self.k_proj(key))
         v = self._split_heads(self.v_proj(value))
 
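The _sdpa helper added above is a thin wrapper around PyTorch's fused F.scaled_dot_product_attention: the optional mask is passed straight through, dropout is applied only while the module is in training mode, and the layer's is_causal flag is forwarded. A standalone sketch of the same call pattern, outside the class (shapes and the dropout value are illustrative assumptions):

import torch
import torch.nn.functional as F

def sdpa(q, k, v, mask=None, dropout_p=0.1, training=False, is_causal=False):
    # Mirrors the wrapper above: dropout only during training, mask passed through.
    return F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=mask,
        dropout_p=dropout_p if training else 0.0,
        is_causal=is_causal,
    )

q = k = v = torch.randn(2, 8, 128, 32)  # (B, H, T, head_dim)
out = sdpa(q, k, v)                     # -> (2, 8, 128, 32)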
@@ -380,42 +385,112 @@ class FlexAttention(MultiHeadAttention):
         if self.rope:
             q, k = self._apply_rope(q, k, separate=True)
 
-        # Split Q into global and local
-        global_q = q[:, :, :self.num_global_tokens]  # (B, H, G,
-        local_q = q[:, :, self.num_global_tokens:]  # (B, H, L,
+        # Split Q into global and local
+        global_q = q[:, :, :self.num_global_tokens]  # (B, H, G, D)
+        local_q = q[:, :, self.num_global_tokens:]  # (B, H, L, D)
+
+        # Global attention
+        global_attn = self._sdpa(global_q, k, v, mask=mask)
+
+        # Local attention with windowed slicing (no large masks)
         L = local_q.size(2)
         S = k.size(2)
+        windowed_attn = []
 
-
-
-
-
-        )
+        for i in range(0, L, self.window_size):
+            start = i
+            end = min(i + self.window_size, L)
+            window_q = local_q[:, :, start:end]  # (B, H, W, D)
 
-
-
-
-
-
-
-
-
-
-            local_q, k, v, attn_mask=window_mask if not self.is_causal else None,
-            dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal
-        )
+            # Use only relevant keys/values (same window)
+            k_window_start = max(0, start - self.window_size)
+            k_window_end = min(S, end + self.window_size)
+            window_k = k[:, :, k_window_start:k_window_end]
+            window_v = v[:, :, k_window_start:k_window_end]
+
+            window_attn = self._sdpa(window_q, window_k, window_v, mask=None)
+
+            windowed_attn.append(window_attn)
 
-        #
-
+        # Concat local attention
+        local_attn = torch.cat(windowed_attn, dim=2)
 
-        #
-
-        output = self.
+        # Combine global and local
+        attn = torch.cat([global_attn, local_attn], dim=2)
+        output = self._merge_heads(attn)
+        output = self.out_proj(output)
 
-        # Return only the local tokens (original query tokens)
         return output[:, self.num_global_tokens:, :]
 
 
+class FlexSparseQueryAttention(FlexAttention):
+    """Combined Flex and Sparse Query attention layer, with RoPE support"""
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        num_groups: int,
+        num_query_groups: int,
+        dropout: float = 0.0,
+        rope: RotaryPositionalEmbedding = None,
+        rope_only_for_query: bool = False,
+        rope_only_for_keys: bool = False,
+        use_relative_embeddings: bool = False,
+        max_seq_len: int = 1024,
+        use_flash_attention: bool = True,
+        is_causal: bool = False,
+        use_bias: bool = False,
+        num_global_tokens: int = 16,
+        window_size: int = 128,
+    ):
+        self.num_groups = num_groups
+        self.num_query_groups = num_query_groups
+        super(FlexSparseQueryAttention, self).__init__(
+            embed_dim,
+            num_heads,
+            dropout=dropout,
+            rope=rope,
+            rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_global_tokens=num_global_tokens,
+            window_size=window_size,
+        )
+        assert num_heads % num_groups == 0, "num_heads must be divisible by num_groups"
+
+    def _init_kv(self, embed_dim: int):
+        self.k_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_groups), bias=self.use_bias)
+        self.v_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_groups), bias=self.use_bias)
+
+    def _init_q(self, embed_dim: int):
+        self.q_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_query_groups), bias=self.use_bias)
+
+    def _init_out(self, embed_dim: int):
+        """Initialize output projection"""
+        self.out_proj = nn.Linear(embed_dim // (self.num_heads // self.num_query_groups), embed_dim)
+
+    def _split_heads(self, x: torch.Tensor, is_query: bool = False) -> torch.Tensor:
+        b, t, d = x.size()
+        return x.view(b, t, self.num_query_groups if is_query else self.num_groups, self.head_dim).transpose(1, 2)
+
+    def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
+        """Transpose attention output back to (B, T, D) shape"""
+        return attn_output.transpose(1, 2).contiguous().view(b, t, d // (self.num_heads // self.num_query_groups))
+
+    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
+        b, h, t, d = x.size()
+        return self._transpose_output(x, b, t, self.embed_dim)
+
+    def _sdpa(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+        is_gqa = self.num_query_groups != self.num_groups
+        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.dropout.p if self.training else 0,
+                                              is_causal=self.is_causal, enable_gqa=is_gqa)
+
+
 class InfiniteAttention(MultiHeadAttention):
     def __init__(
         self,
rxnn-0.2.53.dist-info/RECORD → rxnn-0.2.55.dist-info/RECORD
CHANGED
@@ -1,7 +1,7 @@
 rxnn/.DS_Store,sha256=BxZLo9tFs48JMq6jhumiCnCPLTeCwl619CFSg4ClRAY,6148
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
+rxnn/experimental/attention.py,sha256=JMs6Wr2rRe5J5m0ULhudmhBrzPicGuOOyg5hO8aLFiQ,27846
 rxnn/experimental/models.py,sha256=HPOIRpnX_oiI10wsVC4J6rzo3T6dj10aNWGYpa9S1UU,5115
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.55.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.55.dist-info/METADATA,sha256=4XCpsJFv9dpetex6uDRLrzKlMYlQZFLK2H2j---WZmA,25997
+rxnn-0.2.55.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.55.dist-info/RECORD,,