rxnn 0.2.52__py3-none-any.whl → 0.2.54__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- rxnn/experimental/attention.py +106 -60
- rxnn/experimental/models.py +7 -3
- rxnn/transformers/attention.py +3 -1
- {rxnn-0.2.52.dist-info → rxnn-0.2.54.dist-info}/METADATA +1 -1
- {rxnn-0.2.52.dist-info → rxnn-0.2.54.dist-info}/RECORD +7 -7
- {rxnn-0.2.52.dist-info → rxnn-0.2.54.dist-info}/LICENSE +0 -0
- {rxnn-0.2.52.dist-info → rxnn-0.2.54.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -318,76 +318,105 @@ class SparseQueryAttention(MultiHeadAttention):
 
 
 # Others
-
 class FlexAttention(MultiHeadAttention):
     def __init__(
-            ... (6 removed signature lines not captured in this diff view)
+            self,
+            embed_dim: int,
+            num_heads: int,
+            dropout: float = 0.0,
+            rope: RotaryPositionalEmbedding = None,
+            rope_only_for_query: bool = False,
+            rope_only_for_keys: bool = False,
+            use_relative_embeddings: bool = False,
+            max_seq_len: int = 1024,
+            use_flash_attention: bool = True,
+            is_causal: bool = False,
+            use_bias: bool = False,
+            num_global_tokens: int = 16,
+            window_size: int = 128,
     ):
-        super().__init__(
+        super(FlexAttention, self).__init__(
+            embed_dim,
+            num_heads,
+            dropout=dropout,
+            rope=rope,
+            rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+        )
+        self.head_dim = embed_dim // num_heads
         self.num_global_tokens = num_global_tokens
         self.window_size = window_size
-        self.global_tokens = nn.Parameter(torch.zeros(1, num_global_tokens, embed_dim))
 
-
+        # Learnable global tokens
+        self.global_tokens = nn.Parameter(torch.randn(1, num_global_tokens, embed_dim))
+
+
+    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
+        b, t, d = x.size()
+        return x.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
+
+    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
+        b, h, t, d = x.size()
+        return self._transpose_output(x, b, t, h * d)
+
+    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
         b, t, d = query.size()
-        head_dim = d // self.num_heads
 
-        ... (3 removed lines not captured in this diff view)
-        num_windows = (seq_len - self.num_global_tokens + self.window_size - 1) // self.window_size
+        # Prepend global tokens
+        global_tokens = self.global_tokens.expand(b, -1, -1)
+        x = torch.cat([global_tokens, query], dim=1)
 
         # Project Q, K, V
-        ... (5 removed lines not captured in this diff view)
-        global_v = v[:, :, :self.num_global_tokens]
-        global_attn = self._calculate_attn_weights(global_q, global_k, d) @ global_v
-
-        # Process Global-to-Local Attention
-        local_k = k[:, :, self.num_global_tokens:]  # [B, H, (num_windows * window_size), head_dim]
-        local_v = v[:, :, self.num_global_tokens:]
-        # Apply RoPE to local_k if needed
+        q = self._split_heads(self.q_proj(x))
+        k = self._split_heads(self.k_proj(key))
+        v = self._split_heads(self.v_proj(value))
+
+        # Apply RoPE
         if self.rope:
-            ... (32 removed lines not captured in this diff view)
+            q, k = self._apply_rope(q, k, separate=True)
+
+        # Split Q into global and local
+        global_q = q[:, :, :self.num_global_tokens]  # (B, H, G, D)
+        local_q = q[:, :, self.num_global_tokens:]  # (B, H, L, D)
+
+        # Global attention
+        global_attn = F.scaled_dot_product_attention(global_q, k, v, attn_mask=mask, dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal)
+
+        # Local attention with windowed slicing (no large masks)
+        L = local_q.size(2)
+        S = k.size(2)
+        windowed_attn = []
+
+        for i in range(0, L, self.window_size):
+            start = i
+            end = min(i + self.window_size, L)
+            window_q = local_q[:, :, start:end]  # (B, H, W, D)
+
+            # Use only relevant keys/values (same window)
+            k_window_start = max(0, start - self.window_size)
+            k_window_end = min(S, end + self.window_size)
+            window_k = k[:, :, k_window_start:k_window_end]
+            window_v = v[:, :, k_window_start:k_window_end]
+
+
+            window_attn = F.scaled_dot_product_attention(window_q, window_k, window_v, attn_mask=None, dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal)
+
+            windowed_attn.append(window_attn)
+
+        # Concat local attention
+        local_attn = torch.cat(windowed_attn, dim=2)
+
+        # Combine global and local
+        attn = torch.cat([global_attn, local_attn], dim=2)
+        output = self._merge_heads(attn)
+        output = self.out_proj(output)
+
+        return output[:, self.num_global_tokens:, :]
 
 
 class InfiniteAttention(MultiHeadAttention):
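The rewritten forward pass prepends the learnable global tokens to the query, attends over the full key set for that global slice, runs sliding-window attention for the remaining tokens, and strips the global tokens before returning. The sketch below is an illustrative smoke test only, not part of the diff: the import locations and the RotaryPositionalEmbedding constructor arguments are inferred from other files in this diff and should be treated as assumptions.

import torch
from rxnn.experimental.attention import FlexAttention
from rxnn.transformers.positional import RotaryPositionalEmbedding

embed_dim, num_heads, max_seq_len = 256, 8, 1024
rope = RotaryPositionalEmbedding(embed_dim // num_heads, max_seq_len)

attn = FlexAttention(
    embed_dim,
    num_heads,
    rope=rope,
    max_seq_len=max_seq_len,
    num_global_tokens=16,  # learnable tokens prepended to the query sequence
    window_size=128,       # slice length used by the local attention loop
)

x = torch.randn(2, 512, embed_dim)  # kept below max_seq_len so RoPE has headroom for the 16 extra query positions
out = attn(x, x, x)                 # forward strips the global tokens before returning
print(out.shape)                    # expected: torch.Size([2, 512, 256])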
@@ -467,8 +496,10 @@ def init_experimental_attention(
         num_experts: int = None,
         num_query_experts: int = None,
         num_query_groups: int = None,
+        num_global_tokens: int = 16,
+        window_size: int = 128,
 ) -> MultiHeadAttention:
-    assert attention_type in ['gma', 'dma', 'sqa'], "Error, attention type should be one of: 'gma', 'dma', 'sqa'"
+    assert attention_type in ['gma', 'dma', 'sqa', 'flex'], "Error, attention type should be one of: 'gma', 'dma', 'sqa', 'flex'"
 
     if attention_type == "gma":
         return GroupedMoeAttention(
@@ -504,6 +535,21 @@ def init_experimental_attention(
             num_query_experts=num_query_experts,
            num_query_groups=num_query_groups,
         )
+    elif attention_type == "flex":
+        return FlexAttention(
+            embed_dim,
+            num_heads,
+            dropout=dropout,
+            rope=rope,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_global_tokens=num_global_tokens,
+            window_size=window_size,
+        )
     else:
         return SparseQueryAttention(
             embed_dim,
rxnn/experimental/models.py
CHANGED
@@ -32,6 +32,8 @@ class ExperimentalAttentionTransformerConfig(TypedDict):
     att_num_experts: int
     att_num_query_experts: int
     att_num_query_groups: int
+    att_num_global_tokens: int
+    att_window_size: int
 
 
 class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="text-generation", license="apache-2.0"):
@@ -63,13 +65,15 @@ class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline
             att_num_experts: int = None,
             att_num_query_experts: int = None,
             att_num_query_groups: int = None,
+            att_num_global_tokens: int = 16,
+            att_window_size: int = 128,
             **kwargs
     ):
         super(ExperimentalAttentionTransformer, self).__init__(**kwargs)
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa', 'flex'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa", "flex".'
 
         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
@@ -84,8 +88,8 @@ class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline
             att_init = lambda: init_experimental_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
                                                            use_flash_attention=use_flash_attention, dropout=att_dropout,
                                                            max_seq_len=seq_len, is_causal=True, num_experts=att_num_experts,
-                                                           num_query_experts=att_num_query_experts,
-                                                           ... (1 removed line not captured in this diff view)
+                                                           num_query_experts=att_num_query_experts, num_query_groups=att_num_query_groups,
+                                                           num_global_tokens=att_num_global_tokens, window_size=att_window_size)
 
         use_moe_att = att_type in ['gma', 'dma']
 
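The two new config keys default to the same values as the attention classes and are only meaningful when att_type='flex'. A hedged sketch of the relevant config fragment is below; a real ExperimentalAttentionTransformerConfig also needs the remaining fields, which are unchanged by this diff.

config_update = {
    'att_type': 'flex',           # newly accepted alongside 'mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'
    'att_num_global_tokens': 16,  # forwarded to the attention factory as num_global_tokens
    'att_window_size': 128,       # forwarded to the attention factory as window_size
}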
rxnn/transformers/attention.py
CHANGED
@@ -69,12 +69,14 @@ class MultiHeadAttention(nn.Module):
         v = self.v_proj(value).view(b, -1, self.num_heads, d // self.num_heads).transpose(1, 2)
         return q, k, v
 
-    def _apply_rope(self, q: torch.Tensor, k: torch.Tensor):
+    def _apply_rope(self, q: torch.Tensor, k: torch.Tensor, separate: bool = False):
         if self.rope is not None:
             if self.rope_only_for_query:
                 q = self.rope.forward_one(q)
             elif self.rope_only_for_keys:
                 k = self.rope.forward_one(k)
+            elif separate:
+                q, k = self.rope.forward_one(q), self.rope.forward_one(k)
             else:
                 q, k = self.rope(q, k)
         return q, k
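FlexAttention is the caller of the new separate flag in this release: after the global tokens are prepended, the query is longer than the key, and the flag makes RoPE rotate each tensor over its own positions via forward_one instead of the paired rope(q, k) call. A hedged sketch of the equivalent behaviour, with hypothetical shapes:

import torch
from rxnn.transformers.positional import RotaryPositionalEmbedding

num_heads, head_dim = 8, 32
rope = RotaryPositionalEmbedding(head_dim, 1024)

q = torch.randn(2, num_heads, 16 + 256, head_dim)  # query with 16 prepended global tokens
k = torch.randn(2, num_heads, 256, head_dim)       # keys keep the original length

# Equivalent of MultiHeadAttention._apply_rope(q, k, separate=True):
q_rot = rope.forward_one(q)  # each tensor is rotated for its own sequence length
k_rot = rope.forward_one(k)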
{rxnn-0.2.52.dist-info → rxnn-0.2.54.dist-info}/RECORD
CHANGED
@@ -1,8 +1,8 @@
 rxnn/.DS_Store,sha256=BxZLo9tFs48JMq6jhumiCnCPLTeCwl619CFSg4ClRAY,6148
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
-rxnn/experimental/models.py,sha256=
+rxnn/experimental/attention.py,sha256=zQM6Og62IZVGogsDBReYrHSiRZmDaebl1FcH2e6sHyY,24589
+rxnn/experimental/models.py,sha256=HPOIRpnX_oiI10wsVC4J6rzo3T6dj10aNWGYpa9S1UU,5115
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/attention.py,sha256=wnYjd3UnmzAA79-7QxpMoEk3O1qRy8LmW1JcE_Fotck,3094
@@ -24,7 +24,7 @@ rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,14
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=C0OS2RAGQ3L7D_G3CWupu_BpAFhkovMByBKm355Ibfc,6087
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/transformers/attention.py,sha256=
+rxnn/transformers/attention.py,sha256=KRnKT6XUqAXElxV9y72mSpdTeiMgCKCCLqqxCFNTHmA,16372
 rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
 rxnn/transformers/layers.py,sha256=wG8C9doafpLUsGtUTg-xdrHt7EQMEdB10vcSD-O1nVg,7999
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.54.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.54.dist-info/METADATA,sha256=t6l1VezLNpdpgaXaqB-YhrfAhEUlWZm9-wwzBZ_Xk34,25997
+rxnn-0.2.54.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.54.dist-info/RECORD,,
{rxnn-0.2.52.dist-info → rxnn-0.2.54.dist-info}/LICENSE
File without changes
{rxnn-0.2.52.dist-info → rxnn-0.2.54.dist-info}/WHEEL
File without changes