rxnn 0.2.52__py3-none-any.whl → 0.2.53__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/attention.py +103 -60
- rxnn/experimental/models.py +7 -3
- rxnn/transformers/attention.py +3 -1
- {rxnn-0.2.52.dist-info → rxnn-0.2.53.dist-info}/METADATA +1 -1
- {rxnn-0.2.52.dist-info → rxnn-0.2.53.dist-info}/RECORD +7 -7
- {rxnn-0.2.52.dist-info → rxnn-0.2.53.dist-info}/LICENSE +0 -0
- {rxnn-0.2.52.dist-info → rxnn-0.2.53.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -318,76 +318,102 @@ class SparseQueryAttention(MultiHeadAttention):
 
 
 # Others
-
 class FlexAttention(MultiHeadAttention):
     def __init__(
-
-
-
-
-
-
+            self,
+            embed_dim: int,
+            num_heads: int,
+            dropout: float = 0.0,
+            rope: RotaryPositionalEmbedding = None,
+            rope_only_for_query: bool = False,
+            rope_only_for_keys: bool = False,
+            use_relative_embeddings: bool = False,
+            max_seq_len: int = 1024,
+            use_flash_attention: bool = True,
+            is_causal: bool = False,
+            use_bias: bool = False,
+            num_global_tokens: int = 16,
+            window_size: int = 128,
     ):
-        super().__init__(
+        super(FlexAttention, self).__init__(
+            embed_dim,
+            num_heads,
+            dropout=dropout,
+            rope=rope,
+            rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+        )
+        self.head_dim = embed_dim // num_heads
         self.num_global_tokens = num_global_tokens
         self.window_size = window_size
-        self.global_tokens = nn.Parameter(torch.zeros(1, num_global_tokens, embed_dim))
 
-
+        # Learnable global tokens
+        self.global_tokens = nn.Parameter(torch.randn(1, num_global_tokens, embed_dim))
+
+
+    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
+        b, t, d = x.size()
+        return x.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
+
+    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
+        b, h, t, d = x.size()
+        return self._transpose_output(x, b, t, h * d)
+
+    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
         b, t, d = query.size()
-        head_dim = d // self.num_heads
 
-        #
-
-
-        num_windows = (seq_len - self.num_global_tokens + self.window_size - 1) // self.window_size
+        # Prepend global tokens to the input query
+        global_tokens = self.global_tokens.expand(b, -1, -1)
+        x = torch.cat([global_tokens, query], dim=1)
 
         # Project Q, K, V
-        q
-
-
-
-
-        global_v = v[:, :, :self.num_global_tokens]
-        global_attn = self._calculate_attn_weights(global_q, global_k, d) @ global_v
-
-        # Process Global-to-Local Attention
-        local_k = k[:, :, self.num_global_tokens:] # [B, H, (num_windows * window_size), head_dim]
-        local_v = v[:, :, self.num_global_tokens:]
-        # Apply RoPE to local_k if needed
+        q = self._split_heads(self.q_proj(x))
+        k = self._split_heads(self.k_proj(key))
+        v = self._split_heads(self.v_proj(value))
+
+        # Apply RoPE
         if self.rope:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            #
-
-
-
+            q, k = self._apply_rope(q, k, separate=True)
+
+        # Split Q into global and local parts
+        global_q = q[:, :, :self.num_global_tokens]  # (B, H, G, head_dim)
+        local_q = q[:, :, self.num_global_tokens:]  # (B, H, L, head_dim)
+        L = local_q.size(2)
+        S = k.size(2)
+
+        # Global attention: global_q attends to all K/V
+        global_attn = F.scaled_dot_product_attention(
+            global_q, k, v, attn_mask=mask if not self.is_causal else None,
+            dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal
+        )
+
+        # Local attention: local_q attends to windowed K/V
+        # Vectorized window mask
+        indices = torch.arange(S, device=local_q.device)
+        local_pos = torch.arange(L, device=local_q.device)
+        local_window = (local_pos // self.window_size).unsqueeze(-1)  # (L, 1)
+        key_window = (indices // self.window_size).expand(L, -1)  # (L, S)
+        window_mask = (local_window == key_window).to(device=local_q.device)
+
+        local_attn = F.scaled_dot_product_attention(
+            local_q, k, v, attn_mask=window_mask if not self.is_causal else None,
+            dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal
+        )
+
+        # Combine global and local attention outputs
+        attn = torch.cat([global_attn, local_attn], dim=2)  # (B, H, G+L, head_dim)
+
+        # Merge heads and project back
+        output = self._merge_heads(attn)  # (B, G+L, D)
+        output = self.out_proj(output)  # (B, G+L, D)
+
+        # Return only the local tokens (original query tokens)
+        return output[:, self.num_global_tokens:, :]
 
 
 class InfiniteAttention(MultiHeadAttention):
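The new FlexAttention.forward prepends a handful of learnable global tokens to the query, lets those tokens attend to the full sequence, and restricts every other query position to a fixed-size window of keys. The self-contained sketch below (illustrative sizes, plain tensors instead of the rxnn module, projections and RoPE omitted) reproduces the same vectorized window mask and the two scaled_dot_product_attention calls to show the shape flow:

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only (G=16 and window_size=128 match the new defaults above).
B, H, G, L, head_dim, window_size = 2, 4, 16, 256, 32, 128
S = L  # in self-attention the keys/values cover only the original sequence

q = torch.randn(B, H, G + L, head_dim)  # global tokens already prepended to the query
k = torch.randn(B, H, S, head_dim)
v = torch.randn(B, H, S, head_dim)

global_q, local_q = q[:, :, :G], q[:, :, G:]

# Global tokens attend to every key/value position.
global_attn = F.scaled_dot_product_attention(global_q, k, v)

# Local positions attend only to keys in the same window:
# query i may see key j iff i // window_size == j // window_size.
local_pos = torch.arange(L)
key_pos = torch.arange(S)
window_mask = (local_pos // window_size).unsqueeze(-1) == (key_pos // window_size)  # (L, S) bool
local_attn = F.scaled_dot_product_attention(local_q, k, v, attn_mask=window_mask)

# Concatenate along the sequence axis; the real module then merges heads, applies
# out_proj, and drops the first G rows so only the original query positions remain.
out = torch.cat([global_attn, local_attn], dim=2)[:, :, G:]
print(out.shape)  # torch.Size([2, 4, 256, 32])
```

Building the mask once as an (L, S) boolean tensor keeps the local pass loop-free; the trade-off is that the mask and the attention scores are still materialized at L×S rather than L×window_size.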
@@ -467,8 +493,10 @@ def init_experimental_attention(
         num_experts: int = None,
         num_query_experts: int = None,
         num_query_groups: int = None,
+        num_global_tokens: int = 16,
+        window_size: int = 128,
 ) -> MultiHeadAttention:
-    assert attention_type in ['gma', 'dma', 'sqa'], "Error, attention type should be one of: 'gma', 'dma', 'sqa'"
+    assert attention_type in ['gma', 'dma', 'sqa', 'flex'], "Error, attention type should be one of: 'gma', 'dma', 'sqa', 'flex'"
 
     if attention_type == "gma":
         return GroupedMoeAttention(
@@ -504,6 +532,21 @@ def init_experimental_attention(
             num_query_experts=num_query_experts,
             num_query_groups=num_query_groups,
         )
+    elif attention_type == "flex":
+        return FlexAttention(
+            embed_dim,
+            num_heads,
+            dropout=dropout,
+            rope=rope,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_global_tokens=num_global_tokens,
+            window_size=window_size,
+        )
     else:
         return SparseQueryAttention(
             embed_dim,
rxnn/experimental/models.py
CHANGED
@@ -32,6 +32,8 @@ class ExperimentalAttentionTransformerConfig(TypedDict):
     att_num_experts: int
     att_num_query_experts: int
     att_num_query_groups: int
+    att_num_global_tokens: int
+    att_window_size: int
 
 
 class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="text-generation", license="apache-2.0"):
@@ -63,13 +65,15 @@ class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline
             att_num_experts: int = None,
             att_num_query_experts: int = None,
             att_num_query_groups: int = None,
+            att_num_global_tokens: int = 16,
+            att_window_size: int = 128,
             **kwargs
     ):
         super(ExperimentalAttentionTransformer, self).__init__(**kwargs)
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa', 'flex'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa", "flex".'
 
         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
@@ -84,8 +88,8 @@ class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline
             att_init = lambda: init_experimental_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
                                                            use_flash_attention=use_flash_attention, dropout=att_dropout,
                                                            max_seq_len=seq_len, is_causal=True, num_experts=att_num_experts,
-                                                           num_query_experts=att_num_query_experts,
-
+                                                           num_query_experts=att_num_query_experts, num_query_groups=att_num_query_groups,
+                                                           num_global_tokens=att_num_global_tokens, window_size=att_window_size)
 
         use_moe_att = att_type in ['gma', 'dma']
 
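At the model level the change is pure plumbing: two new constructor keywords that are forwarded to the attention factory through the att_init lambda. A small sketch of the settings a caller would add for the new attention type (keyword names from the diff, values illustrative; the rest of the model configuration is unchanged and omitted here):

```python
# Extra ExperimentalAttentionTransformer keyword arguments for the 'flex' attention type.
# These would be passed alongside the existing model hyperparameters (vocab size,
# embedding width, layer count, ...), which are omitted from this sketch.
flex_model_kwargs = dict(
    att_type='flex',            # now accepted by the assert in __init__
    att_num_global_tokens=16,   # forwarded as num_global_tokens
    att_window_size=128,        # forwarded as window_size
)
```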
rxnn/transformers/attention.py
CHANGED
@@ -69,12 +69,14 @@ class MultiHeadAttention(nn.Module):
         v = self.v_proj(value).view(b, -1, self.num_heads, d // self.num_heads).transpose(1, 2)
         return q, k, v
 
-    def _apply_rope(self, q: torch.Tensor, k: torch.Tensor):
+    def _apply_rope(self, q: torch.Tensor, k: torch.Tensor, separate: bool = False):
         if self.rope is not None:
             if self.rope_only_for_query:
                 q = self.rope.forward_one(q)
             elif self.rope_only_for_keys:
                 k = self.rope.forward_one(k)
+            elif separate:
+                q, k = self.rope.forward_one(q), self.rope.forward_one(k)
             else:
                 q, k = self.rope(q, k)
         return q, k
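The new separate flag rotates q and k independently through rope.forward_one instead of the paired rope(q, k) call; FlexAttention requests this presumably because its query (with global tokens prepended) and its keys no longer share a sequence length. A generic rotary-embedding sketch, not the rxnn implementation, showing independent application to tensors of different lengths:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rope_one(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # Standard rotary embedding applied to a single (B, H, T, D) tensor.
    t, d = x.size(-2), x.size(-1)
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float32) / d))
    angles = torch.outer(torch.arange(t, dtype=torch.float32), inv_freq)  # (T, D/2)
    emb = torch.cat((angles, angles), dim=-1)                             # (T, D)
    return x * emb.cos() + rotate_half(x) * emb.sin()

num_global, local_len, head_dim = 16, 128, 64
q = torch.randn(2, 8, num_global + local_len, head_dim)  # query includes global tokens
k = torch.randn(2, 8, local_len, head_dim)               # keys cover the real sequence only
q, k = rope_one(q), rope_one(k)  # rotated independently, mirroring separate=True
print(q.shape, k.shape)          # torch.Size([2, 8, 144, 64]) torch.Size([2, 8, 128, 64])
```

In this sketch positions are counted from the start of each tensor, so the prepended global tokens occupy the first rotary positions of the query.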
{rxnn-0.2.52.dist-info → rxnn-0.2.53.dist-info}/RECORD
CHANGED
@@ -1,8 +1,8 @@
 rxnn/.DS_Store,sha256=BxZLo9tFs48JMq6jhumiCnCPLTeCwl619CFSg4ClRAY,6148
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
-rxnn/experimental/models.py,sha256=
+rxnn/experimental/attention.py,sha256=z9zahN1KTDcVabUB0RwWsT6oK8MTuz65haVwYRHAzy4,24689
+rxnn/experimental/models.py,sha256=HPOIRpnX_oiI10wsVC4J6rzo3T6dj10aNWGYpa9S1UU,5115
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/attention.py,sha256=wnYjd3UnmzAA79-7QxpMoEk3O1qRy8LmW1JcE_Fotck,3094
@@ -24,7 +24,7 @@ rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,14
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=C0OS2RAGQ3L7D_G3CWupu_BpAFhkovMByBKm355Ibfc,6087
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/transformers/attention.py,sha256=
+rxnn/transformers/attention.py,sha256=KRnKT6XUqAXElxV9y72mSpdTeiMgCKCCLqqxCFNTHmA,16372
 rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
 rxnn/transformers/layers.py,sha256=wG8C9doafpLUsGtUTg-xdrHt7EQMEdB10vcSD-O1nVg,7999
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.53.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.53.dist-info/METADATA,sha256=CqOV8zz8qcGYyZXbeRv7jLMd40AeF8iU7vyfRsMuslE,25997
+rxnn-0.2.53.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.53.dist-info/RECORD,,
{rxnn-0.2.52.dist-info → rxnn-0.2.53.dist-info}/LICENSE
File without changes
{rxnn-0.2.52.dist-info → rxnn-0.2.53.dist-info}/WHEEL
File without changes