rxnn-0.1.56-py3-none-any.whl → rxnn-0.1.58-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/attention.py +4 -270
- rxnn/experimental/models.py +6 -6
- rxnn/rxt/models.py +33 -10
- {rxnn-0.1.56.dist-info → rxnn-0.1.58.dist-info}/METADATA +1 -1
- {rxnn-0.1.56.dist-info → rxnn-0.1.58.dist-info}/RECORD +7 -7
- {rxnn-0.1.56.dist-info → rxnn-0.1.58.dist-info}/LICENSE +0 -0
- {rxnn-0.1.56.dist-info → rxnn-0.1.58.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -269,7 +269,7 @@ class SparseQueryAttention(MultiHeadAttention):
         self.v_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_groups), bias=self.use_bias)

     def _init_q(self, embed_dim: int):
-        self.q_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_query_groups))
+        self.q_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_query_groups), bias=self.use_bias)

     def _init_out(self, embed_dim: int):
         """Initialize output projection"""
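Note: the only functional change in this hunk is that SparseQueryAttention's query projection now honours `use_bias`, matching the key/value projections above it. A minimal, illustrative sketch of the effect (values are arbitrary; only names taken from the diff are used):

import torch.nn as nn

embed_dim, num_heads, num_query_groups, use_bias = 512, 16, 2, False
# Before 0.1.58: the bias kwarg was omitted, so nn.Linear defaulted to bias=True even with use_bias=False.
# After 0.1.58: the query projection follows use_bias, consistent with k_proj and v_proj.
q_proj = nn.Linear(embed_dim, embed_dim // (num_heads // num_query_groups), bias=use_bias)
assert q_proj.bias is None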
@@ -316,240 +316,6 @@ class SparseQueryAttention(MultiHeadAttention):
         return self._torch_attention(q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask, enable_gqa=is_gqa)


-
-
-class GroupedMoeAttentionSimplified(GroupedQueryAttention):
-    """
-    Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
-
-    Instead of mapping keys/values to static head groups, it dynamically selects head expert groups. It has the same
-    number of total keys/values heads as query heads, but uses only a selected group for attention calculation.
-    - with num_groups set to 1, it will be MoE MultiQueryAttention
-
-    Compared to traditional GQA/MQA, it should provide better performance, because lot less data could be lost using
-    this approach - we are training the full number of keys/values heads, while using only a group.
-
-    In case of efficiency, it should be close to GQA/MQA linear performance, but with a small MoE routing overhead.
-
-    Optionally, it could use even more expert heads than attention heads - in example:
-    - 512 dim divided into 16 heads with 32 dim, using 4 head groups - may use i.e., 24 total expert heads - still only
-    4 will be used for attention calculation, while 16 is used to split dimensions (in that case it will have 16 query heads)
-
-    © 2025 Adam Filipek
-    """
-
-    def __init__(
-            self,
-            embed_dim: int,
-            num_heads: int,
-            num_groups: int,
-            dropout: float = 0.0,
-            rope: RotaryPositionalEmbedding = None,
-            rope_only_for_query: bool = False,
-            use_relative_embeddings: bool = False,
-            max_seq_len: int = 1024,
-            use_flash_attention: bool = False,
-            is_causal: bool = False,
-            use_bias: bool = False,
-            num_experts: int = None,
-            *args,
-            **kwargs,
-    ):
-        self.num_experts = num_experts or num_heads
-        super(GroupedMoeAttentionSimplified, self).__init__(
-            embed_dim,
-            num_heads,
-            num_groups=num_groups,
-            dropout=dropout,
-            rope=rope,
-            rope_only_for_query=rope_only_for_query,
-            use_relative_embeddings=use_relative_embeddings,
-            max_seq_len=max_seq_len,
-            use_flash_attention=use_flash_attention,
-            is_causal=is_causal,
-            use_bias=use_bias,
-            *args,
-            **kwargs,
-        )
-
-    def router_loss(self):
-        return self.router.aux_loss
-
-    def _init_kv(self, embed_dim: int):
-        self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
-
-        hidden_dim = embed_dim // self.num_heads
-        self.wk = nn.Parameter(torch.empty(self.num_experts, hidden_dim, embed_dim))
-        self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self.wv = nn.Parameter(torch.empty(self.num_experts, hidden_dim, embed_dim))
-        self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self._init_experts()
-
-    def _init_experts(self):
-        nn.init.xavier_uniform_(self.wk)
-        nn.init.xavier_uniform_(self.wv)
-        if self.use_bias:
-            nn.init.zeros_(self.bk)
-            nn.init.zeros_(self.bv)
-
-    def _process_grouped_experts(self, x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, weights: torch.Tensor, indices: torch.Tensor):
-        B, S, G = indices.shape
-        x_flat = x.view(-1, x.size(-1)) # [B*S, D]
-
-        indices_flat = indices.view(-1, G) # [B*S, G]
-        weights_flat = weights.view(-1, G) # [B*S, G]
-
-        output = torch.zeros(B * S, G, w.size(1), device=x.device, dtype=x.dtype) # [B*S, G, hidden_dim]
-
-        for e in range(self.num_experts):
-            # 1. Find tokens where expert `e` is used in ANY group
-            expert_mask = (indices_flat == e).any(dim=1) # [B*S]
-            if not expert_mask.any():
-                continue
-
-            # 2. Project tokens using expert `e`
-            x_slice = x_flat[expert_mask] # [num_selected, D]
-            proj = F.linear(x_slice, w[e], b[e] if b is not None else None) # [num_selected, hidden_dim]
-
-            # 3. Scatter projections into correct groups
-            for g in range(G):
-                group_mask = indices_flat[expert_mask, g] == e # [num_selected]
-                if not group_mask.any():
-                    continue
-
-                # Get tokens in this group using expert `e`
-                group_tokens = expert_mask.nonzero()[group_mask].squeeze(1)
-                # Weight and scatter
-                weighted_proj = proj[group_mask] * weights_flat[group_tokens, g].unsqueeze(-1)
-                output[group_tokens, g] += weighted_proj
-
-        return output.view(B, S, G, -1)
-
-    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
-                     skip_query_processing: bool = False):
-        q = self.q_proj(query).view(b, t, self.num_heads, -1).transpose(1, 2) if not skip_query_processing else query
-
-        # Key/Value processing
-        B, S, D = key.shape
-        key_flat = key.view(-1, D)
-        weights_k_flat, indices_k_flat = self.router(key_flat)
-        # Reshape back to original dimensions
-        weights_k = weights_k_flat.view(B, S, -1)
-        indices_k = indices_k_flat.view(B, S, -1)
-        k = self._process_grouped_experts(key, self.wk, self.bk, weights_k, indices_k)
-        v = self._process_grouped_experts(value, self.wv, self.bv, weights_k, indices_k)
-
-        # Expand to GQA format
-        k = k.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
-        v = v.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
-
-        if not self.rel_embed:
-            group_heads = self.num_heads // self.num_groups
-
-            k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1) # (B, G, group_heads, S, head_dim)
-            v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1) # (B, G, group_heads, S, head_dim)
-
-            k = k.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
-            v = v.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
-
-        return q, k, v
-
-
-class DeepMoeAttentionSimplified(GroupedMoeAttentionSimplified):
-    """
-    Deep MoE Attention (SMA) - Grouped MoE Attention extended even more for sublinear computational efficiency.
-
-    In addition to using Mixture-of-Experts (MoE) for key/value head groups, SMA is also using dynamically selected
-    query heads - with that approach, each token could attend to every other token, but only partially - only some part of
-    information from each token is used to identify related information parts from other tokens. So, DMA is not spatially
-    sparse (has access to all tokens), but rather structurally sparse (has access only to the part of token's information).
-
-    This solution could reduce the computational complexity of attention operation to sublinear level (<O(N)) and provide
-    a viable and efficient alternative to spatial sparse attention mechanisms like Flex Attention.
-
-    © 2025 Adam Filipek
-    """
-
-    def __init__(
-            self,
-            embed_dim: int,
-            num_heads: int,
-            num_groups: int,
-            dropout: float = 0.0,
-            rope: RotaryPositionalEmbedding = None,
-            rope_only_for_query: bool = False,
-            use_relative_embeddings: bool = False,
-            max_seq_len: int = 1024,
-            use_flash_attention: bool = False,
-            is_causal: bool = False,
-            use_bias: bool = False,
-            num_experts: int = None,
-            num_query_experts: int = None,
-            num_query_groups: int = None,
-            *args,
-            **kwargs,
-    ):
-        self.num_query_experts = num_query_experts if num_query_experts is not None else num_heads
-        self.num_query_groups = num_query_groups if num_query_groups is not None else num_groups
-        super(DeepMoeAttentionSimplified, self).__init__(
-            embed_dim,
-            num_heads,
-            num_groups=num_groups,
-            dropout=dropout,
-            rope=rope,
-            rope_only_for_query=rope_only_for_query,
-            use_relative_embeddings=use_relative_embeddings,
-            max_seq_len=max_seq_len,
-            use_flash_attention=use_flash_attention,
-            is_causal=is_causal,
-            use_bias=use_bias,
-            num_experts=num_experts,
-            *args,
-            **kwargs,
-        )
-
-    def router_loss(self):
-        return (self.router.aux_loss + self.query_router.aux_loss) / 2
-
-    def _init_q(self, embed_dim: int):
-        self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
-
-        hidden_dim = embed_dim // self.num_heads
-        self.wq = nn.Parameter(torch.empty(self.num_query_experts, hidden_dim, embed_dim))
-        self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
-        self._init_query_experts()
-
-    def _init_query_experts(self):
-        nn.init.xavier_uniform_(self.wq)
-        if self.use_bias:
-            nn.init.zeros_(self.bq)
-
-    def _init_out(self, embed_dim: int):
-        """Initialize output projection"""
-        hidden_dim = embed_dim // (self.num_heads // self.num_query_groups)
-        self.out_proj = nn.Linear(hidden_dim, embed_dim)
-
-    def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
-        """Transpose attention output back to (B, T, D) shape"""
-        hidden_dim = d // self.num_heads * self.num_query_groups
-        return attn_output.transpose(1, 2).contiguous().view(b, t, hidden_dim)
-
-    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int, skip_query_processing: bool = False):
-        # Query processing
-        B, T, D = query.shape
-        # Flatten for query routing
-        query_flat = query.view(-1, D)
-        weights_q_flat, indices_q_flat = self.query_router(query_flat)
-        # Reshape back
-        weights_q = weights_q_flat.view(B, T, -1)
-        indices_q = indices_q_flat.view(B, T, -1)
-        q = self._process_grouped_experts(query, self.wq, self.bq, weights_q, indices_q)
-
-        q = q.permute(0, 2, 1, 3).reshape(B, self.num_query_groups, T, -1)
-        # Key/Value processing
-        return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
-
-
 # Others

 class FlexAttention(MultiHeadAttention):
@@ -683,7 +449,7 @@ class InfiniteAttention(MultiHeadAttention):
         attn = torch.softmax(attn, dim=-1)
         return torch.einsum('b h i j, b h j d -> b h i d', attn, v)

-def init_moe_attention(
+def init_experimental_attention(
         embed_dim: int,
         num_heads: int,
         attention_type: str,
@@ -700,7 +466,7 @@ def init_moe_attention(
         num_query_experts: int = None,
         num_query_groups: int = None,
 ) -> MultiHeadAttention:
-    assert attention_type in ['gma', 'dma', '
+    assert attention_type in ['gma', 'dma', 'sqa'], "Error, attention type should be one of: 'gma', 'dma', 'sqa'"

     if attention_type == "gma":
         return GroupedMoeAttention(
@@ -734,7 +500,7 @@ def init_moe_attention(
             num_query_experts=num_query_experts,
             num_query_groups=num_query_groups,
         )
-
+    else:
         return SparseQueryAttention(
             embed_dim,
             num_heads,
@@ -749,35 +515,3 @@ def init_moe_attention(
             is_causal=is_causal,
             use_bias=use_bias,
         )
-    elif attention_type == "gma_s":
-        return GroupedMoeAttentionSimplified(
-            embed_dim,
-            num_heads,
-            gqa_groups,
-            dropout=dropout,
-            rope=rope,
-            use_relative_embeddings=use_relative_embeddings,
-            max_seq_len=max_seq_len,
-            rope_only_for_query=rope_only_for_query,
-            use_flash_attention=use_flash_attention,
-            is_causal=is_causal,
-            use_bias=use_bias,
-            num_experts=num_experts,
-        )
-    else:
-        return DeepMoeAttentionSimplified(
-            embed_dim,
-            num_heads,
-            gqa_groups,
-            dropout=dropout,
-            rope=rope,
-            use_relative_embeddings=use_relative_embeddings,
-            max_seq_len=max_seq_len,
-            rope_only_for_query=rope_only_for_query,
-            use_flash_attention=use_flash_attention,
-            is_causal=is_causal,
-            use_bias=use_bias,
-            num_experts=num_experts,
-            num_query_experts=num_query_experts,
-            num_query_groups=num_query_groups,
-        )
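Taken together, the changes in this file rename the factory from init_moe_attention to init_experimental_attention, narrow the accepted types to 'gma', 'dma' and 'sqa', and remove the simplified GMA/DMA classes along with their 'gma_s'/'dma_s' branches. A usage sketch based only on the signatures and call sites visible in this diff; the concrete values and the RotaryPositionalEmbedding import path are assumptions, not documented API:

from rxnn.experimental.attention import init_experimental_attention
from rxnn.transformers.positional import RotaryPositionalEmbedding  # module path assumed from the RECORD listing below

embed_dim, num_heads, seq_len = 512, 16, 1024
rope = RotaryPositionalEmbedding(embed_dim // num_heads, seq_len)

# 'sqa' now reaches the final else branch of the factory, which returns SparseQueryAttention.
attention = init_experimental_attention(
    embed_dim, num_heads, 'sqa', 4,  # positional args as in the call sites in experimental/models.py
    rope=rope, dropout=0.1, max_seq_len=seq_len,
    is_causal=True, num_query_groups=2,
)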
rxnn/experimental/models.py
CHANGED
@@ -8,7 +8,7 @@ from ..transformers.layers import ClassicTransformerLayer
 from ..transformers.models import ClassicTransformerDecoder
 from ..transformers.ff import get_activation_layer
 from ..utils import get_model_size
-from .attention import
+from .attention import init_experimental_attention


 class MoeAttentionTransformerConfig(TypedDict):
@@ -77,11 +77,11 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex
                                               use_flash_attention=use_flash_attention, dropout=att_dropout,
                                               max_seq_len=seq_len, is_causal=True)
         else:
-            att_init = lambda:
-
-
-
-
+            att_init = lambda: init_experimental_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
+                                                           use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                                           max_seq_len=seq_len, is_causal=True, num_experts=att_num_experts,
+                                                           num_query_experts=att_num_query_experts,
+                                                           num_query_groups=att_num_query_groups)

         use_moe_att = att_type in ['gma', 'dma', 'gma_s', 'dma_s']

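The else branch above wires init_experimental_attention into MoeAttentionTransformer through an att_init lambda. The lambda matters because it is called once per layer, so every layer receives its own attention module instead of sharing one set of weights. A package-independent illustration of that distinction (nn.Linear stands in for the attention factory; this is not code from rxnn):

import torch.nn as nn

make_att = lambda: nn.Linear(512, 512)  # stand-in for init_experimental_attention(...)
shared = nn.Linear(512, 512)

fresh_layers = [make_att() for _ in range(3)]  # three independent parameter sets
shared_layers = [shared for _ in range(3)]     # one parameter set reused three times

assert fresh_layers[0].weight.data_ptr() != fresh_layers[1].weight.data_ptr()
assert shared_layers[0].weight.data_ptr() == shared_layers[1].weight.data_ptr()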
rxnn/rxt/models.py
CHANGED
@@ -9,7 +9,7 @@ from ..transformers.models import ReactiveTransformerBase, ReactiveTransformerEn
 from ..transformers.ff import get_activation_layer
 from ..memory.stm import ShortTermMemory
 from ..utils import get_model_size
-
+from ..experimental.attention import init_experimental_attention

 class RxTAlphaComponentConfig(TypedDict):
     num_layers: int
@@ -31,6 +31,9 @@ class RxTAlphaComponentConfig(TypedDict):
     moe_top_k: int
     self_att_type: str
     cross_att_type: str
+    att_num_experts: int
+    att_num_query_experts: int
+    att_num_query_groups: int


 class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
@@ -58,14 +61,17 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
             moe_top_k: int = 1,
             self_att_type: str = 'gqa',
             cross_att_type: str = 'mqa',
+            att_num_experts: int = None,
+            att_num_query_experts: int = None,
+            att_num_query_groups: int = None,
             **kwargs
     ):
         super(RxTAlphaComponentBase, self).__init__(**kwargs)
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert self_att_type in ['mha', 'gqa', 'mqa'], 'Self-attention type could be "mha", "gqa", "mqa"'
-        assert cross_att_type in ['mha', 'gqa', 'mqa'], 'Memory cross-attention type could be "mha", "gqa", "mqa"'
+        assert self_att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+        assert cross_att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Memory cross-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'

         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
@@ -73,6 +79,28 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):

         ff_activation = get_activation_layer(ff_activation)

+        if self_att_type in ['mha', 'gqa', 'mqa']:
+            att_init = lambda: init_attention(embed_dim, att_heads, self_att_type, att_groups, rope=rope,
+                                              use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                              max_seq_len=seq_len, is_causal=is_causal)
+        else:
+            att_init = lambda: init_experimental_attention(embed_dim, att_heads, self_att_type, att_groups, rope=rope,
+                                                           use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                                           max_seq_len=seq_len, is_causal=is_causal, num_experts=att_num_experts,
+                                                           num_query_experts=att_num_query_experts,
+                                                           num_query_groups=att_num_query_groups)
+
+        if cross_att_type in ['mha', 'gqa', 'mqa']:
+            cross_att_init = lambda: init_attention(embed_dim, att_heads, cross_att_type, att_groups, rope=rope,
+                                                    use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                                    max_seq_len=seq_len, is_causal=is_causal)
+        else:
+            cross_att_init = lambda: init_experimental_attention(embed_dim, att_heads, cross_att_type, att_groups, rope=rope,
+                                                                 use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                                                 max_seq_len=seq_len, is_causal=is_causal, num_experts=att_num_experts,
+                                                                 num_query_experts=att_num_query_experts,
+                                                                 num_query_groups=att_num_query_groups)
+
         layers = nn.ModuleList([
             ReactiveTransformerLayer(
                 embed_dim,
@@ -84,13 +112,8 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
                 ff_activation=ff_activation,
                 ff_dropout=ff_dropout,
                 use_rms_norm=use_rms_norm,
-                self_attention=
-
-                                              max_seq_len=seq_len, is_causal=is_causal),
-                memory_cross_attention=init_attention(embed_dim, att_heads, cross_att_type, att_groups, rope=rope,
-                                                      use_flash_attention=use_flash_attention, dropout=att_dropout,
-                                                      max_seq_len=seq_len, rope_only_for_query=True,
-                                                      is_causal=is_causal)
+                self_attention=att_init(),
+                memory_cross_attention=cross_att_init(),
             ) for _ in range(num_layers)
         ])
         self.model = self._init_model(stm, layers, embedding, use_flash_attention, embed_dim, vocab_size)
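For users of RxTAlphaComponentBase, the practical effect of this file's changes is three new optional arguments (att_num_experts, att_num_query_experts, att_num_query_groups) and the acceptance of 'gma', 'dma' and 'sqa' for self_att_type/cross_att_type. A hypothetical config fragment showing only the keys visible in this diff (which key feeds which variant is inferred from the factory parameters in experimental/attention.py above):

config_fragment = {
    'self_att_type': 'sqa',          # experimental types route through init_experimental_attention
    'cross_att_type': 'mqa',         # classic types still route through init_attention
    'att_num_experts': None,         # forwarded as num_experts ('gma'/'dma')
    'att_num_query_experts': None,   # forwarded as num_query_experts ('dma')
    'att_num_query_groups': 2,       # forwarded as num_query_groups ('dma', and presumably 'sqa')
}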
{rxnn-0.1.56.dist-info → rxnn-0.1.58.dist-info}/RECORD
CHANGED
@@ -1,13 +1,13 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
-rxnn/experimental/models.py,sha256=
+rxnn/experimental/attention.py,sha256=bpZQiRXdQ8gJPwYRp3LBr2oELmrysB6-SWiD2F7UQrk,23127
+rxnn/experimental/models.py,sha256=_i9kvQsAYPyMQo2VfMUTmtBs-mE2w75j1X-OHx03IJk,4743
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
 rxnn/memory/stm.py,sha256=EsD8slSP4_9dLuq6aFPDmuFe8PWilxh90so5Z3nm-ig,2057
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=
+rxnn/rxt/models.py,sha256=87KBLbZB7V3NXW_uO2qAQyrPjf2gA2WJrNIFe-e4jdU,8565
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=gEWASLSuWR8UF8b2e-DYqkBZ1lBx0VsIm4kGf9eWSHM,11678
 rxnn/training/bml.py,sha256=S1ZaXTybzeJH7uVFamCr4TPl2bLyZ5xmn_lSsjThTiM,19162
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.58.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.58.dist-info/METADATA,sha256=6aamtiDsToIFsNhpO73cacZMFmCPLMCMNluCTWcwWrE,16627
+rxnn-0.1.58.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.58.dist-info/RECORD,,
{rxnn-0.1.56.dist-info → rxnn-0.1.58.dist-info}/LICENSE
File without changes
{rxnn-0.1.56.dist-info → rxnn-0.1.58.dist-info}/WHEEL
File without changes