rxnn 0.1.49__py3-none-any.whl → 0.1.50__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/models.py +3 -0
- rxnn/transformers/layers.py +12 -7
- rxnn/transformers/models.py +3 -3
- {rxnn-0.1.49.dist-info → rxnn-0.1.50.dist-info}/METADATA +1 -1
- {rxnn-0.1.49.dist-info → rxnn-0.1.50.dist-info}/RECORD +7 -7
- {rxnn-0.1.49.dist-info → rxnn-0.1.50.dist-info}/LICENSE +0 -0
- {rxnn-0.1.49.dist-info → rxnn-0.1.50.dist-info}/WHEEL +0 -0
rxnn/experimental/models.py
CHANGED
@@ -83,6 +83,8 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex
                               num_query_experts=att_num_query_experts,
                               num_query_groups=att_num_query_groups)
 
+        use_moe_att = att_type in ['gma', 'dma', 'gma_s', 'dma_s']
+
         self.model = ClassicTransformerDecoder(
             embed_dim,
             vocab_size,
@@ -99,6 +101,7 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex
                     ff_dropout=ff_dropout,
                     use_rms_norm=use_rms_norm,
                     self_attention=att_init(),
+                    use_moe_att=use_moe_att,
                 ) for _ in range(num_layers)
             ]),
             use_flash_attention=use_flash_attention,
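In short, the constructor now checks whether the requested attention variant is one of the expert-routed types ('gma', 'dma', 'gma_s', 'dma_s') and threads that flag into every decoder layer. Below is a minimal, hypothetical sketch of that wiring; the class and parameter names outside the two changed lines are simplified stand-ins, not the package's real signatures.

# Minimal sketch of the new use_moe_att wiring (hypothetical stand-in classes;
# only the flag derivation and the per-layer kwarg mirror the actual diff).
import torch.nn as nn

MOE_ATT_TYPES = ('gma', 'dma', 'gma_s', 'dma_s')  # attention variants with expert routing

class TinyLayer(nn.Module):  # stand-in for ClassicTransformerLayer
    def __init__(self, use_moe_att: bool = False):
        super().__init__()
        self.use_moe_att = use_moe_att  # lets the base model collect this layer's attention router loss

class TinyMoeAttentionModel(nn.Module):  # stand-in for MoeAttentionTransformer
    def __init__(self, att_type: str = 'gma', num_layers: int = 2):
        super().__init__()
        use_moe_att = att_type in MOE_ATT_TYPES  # derived once, passed to every layer
        self.layers = nn.ModuleList(TinyLayer(use_moe_att=use_moe_att) for _ in range(num_layers))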
rxnn/transformers/layers.py
CHANGED
@@ -22,6 +22,7 @@ class ReactiveTransformerLayer(nn.Module):
             use_moe: bool = False,
             num_experts: int = 1,
             moe_top_k: int = 1,
+            use_moe_att: bool = False,
             *args,
             **kwargs,
     ):
@@ -54,6 +55,7 @@ class ReactiveTransformerLayer(nn.Module):
         self.norm3 = nn.LayerNorm(embed_dim)
         self.use_post_norm = use_post_norm
         self.use_moe = use_moe
+        self.use_moe_att = use_moe_att
 
     def trainable_cross_attention_(self, is_trainable: bool):
         for param in self.memory_cross_attention.parameters():
@@ -62,12 +64,13 @@ class ReactiveTransformerLayer(nn.Module):
     def moe_router_loss(self):
         ff_router_loss = self.ff.router_loss() if self.use_moe else None
         att_router_loss = None
-        if self.attention.router_loss is not None and self.memory_cross_attention.router_loss is not None:
-            att_router_loss = (self.attention.router_loss() + self.memory_cross_attention.router_loss()) / 2
-        elif self.attention.router_loss is not None:
-            att_router_loss = self.attention.router_loss()
-        elif self.memory_cross_attention.router_loss is not None:
-            att_router_loss = self.memory_cross_attention.router_loss()
+        if self.use_moe_att:
+            if self.attention.router_loss is not None and self.memory_cross_attention.router_loss is not None:
+                att_router_loss = (self.attention.router_loss() + self.memory_cross_attention.router_loss()) / 2
+            elif self.attention.router_loss is not None:
+                att_router_loss = self.attention.router_loss()
+            elif self.memory_cross_attention.router_loss is not None:
+                att_router_loss = self.memory_cross_attention.router_loss()
 
         if ff_router_loss is not None and att_router_loss is not None:
             return (ff_router_loss + att_router_loss) / 2
@@ -123,6 +126,7 @@ class ClassicTransformerLayer(nn.Module):
             use_moe: bool = False,
             num_experts: int = 1,
             moe_top_k: int = 1,
+            use_moe_att: bool = False,
             *args,
             **kwargs,
     ):
@@ -151,10 +155,11 @@ class ClassicTransformerLayer(nn.Module):
         self.norm2 = nn.LayerNorm(embed_dim)
         self.use_post_norm = use_post_norm
         self.use_moe = use_moe
+        self.use_moe_att = use_moe_att
 
     def moe_router_loss(self):
         ff_router_loss = self.ff.router_loss() if self.use_moe else None
-        att_router_loss = self.attention.router_loss() if self.attention.router_loss is not None else None
+        att_router_loss = self.attention.router_loss() if self.use_moe_att and self.attention.router_loss is not None else None
 
         if ff_router_loss is not None and att_router_loss is not None:
             return (ff_router_loss + att_router_loss) / 2
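The net effect in both layer classes is that attention router losses are only queried when the layer was constructed with use_moe_att=True, so layers with dense attention never touch a non-existent router. A condensed, hypothetical rendering of the ClassicTransformerLayer version as a free function follows; attribute names mirror the diff, while the final fallback line is an assumption since it lies outside the shown hunk.

# Hypothetical free-function rendering of ClassicTransformerLayer.moe_router_loss
# after this change. `layer` is assumed to expose use_moe, use_moe_att, ff and
# attention, where `router_loss` is None on modules without an expert router.
def classic_layer_router_loss(layer):
    ff_loss = layer.ff.router_loss() if layer.use_moe else None
    # The attention term is now gated on the new use_moe_att flag.
    att_loss = (
        layer.attention.router_loss()
        if layer.use_moe_att and layer.attention.router_loss is not None
        else None
    )
    if ff_loss is not None and att_loss is not None:
        return (ff_loss + att_loss) / 2
    return ff_loss if ff_loss is not None else att_loss  # assumed fallback, outside the diffed hunk

The ReactiveTransformerLayer variant follows the same pattern but, per the diff above, averages the self-attention and memory cross-attention router losses when both are available.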
rxnn/transformers/models.py
CHANGED
@@ -38,8 +38,8 @@ class ReactiveTransformerBase(nn.Module):
             self.layers[i].trainable_cross_attention_(is_trainable)
 
     def moe_router_loss(self):
-        return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe] + [
-            self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe]).mean()
+        return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe or self.layers[i].use_moe_att] + [
+            self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe or self.shared_layers[i].use_moe_att]).mean()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Shared logic for encoders and decoders - apply embeddings and positional encoding
@@ -124,7 +124,7 @@ class ClassicTransformerBase(nn.Module):
         self.num_layers = len(layers) if layers else 0
 
     def moe_router_loss(self):
-        return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_layers) if self.layers[i].use_moe]).mean()
+        return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_layers) if self.layers[i].use_moe or self.layers[i].use_moe_att]).mean()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Shared logic for encoders and decoders - apply embeddings and positional encoding
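The base classes then stack the per-layer losses and average them; the filter now admits layers whose attention (not only feed-forward) uses MoE routing. A toy, self-contained illustration of that aggregation follows; FakeLayer and the loss values are made up for demonstration.

# Toy illustration of the aggregation in ClassicTransformerBase.moe_router_loss():
# per-layer router losses are stacked and averaged, and layers whose attention is
# MoE (use_moe_att) are now included alongside MoE feed-forward layers.
import torch

class FakeLayer:
    def __init__(self, loss, use_moe=False, use_moe_att=False):
        self._loss = torch.tensor(loss)
        self.use_moe = use_moe
        self.use_moe_att = use_moe_att

    def moe_router_loss(self):
        return self._loss

layers = [
    FakeLayer(0.10, use_moe=True),      # MoE feed-forward: counted as before
    FakeLayer(0.30, use_moe_att=True),  # MoE attention only: now counted too
    FakeLayer(0.00),                    # dense layer: still skipped
]

loss = torch.stack([
    l.moe_router_loss() for l in layers if l.use_moe or l.use_moe_att
]).mean()
print(loss)  # tensor(0.2000)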
{rxnn-0.1.49.dist-info → rxnn-0.1.50.dist-info}/RECORD
CHANGED
@@ -1,7 +1,7 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/attention.py,sha256=22Qb4jYN6QaqibTU8bwD8x2FaOKCxvWglM2eK9EuOlo,29468
-rxnn/experimental/models.py,sha256
+rxnn/experimental/models.py,sha256=-BQn7gWlSHLpkAQdthPW5L9ZNzIBqSJS9tkm2N88jgw,4711
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
@@ -18,14 +18,14 @@ rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,80
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=dC0UmC-_kjX8US6Sf0Fi5zw5kJ-P6orH3JDHeBB5gI8,15695
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
-rxnn/transformers/layers.py,sha256=
+rxnn/transformers/layers.py,sha256=OX8CsFY9A7uqH1SLwyexR_5BNlwheYrJHCGXjF8Q7HU,7186
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=
+rxnn/transformers/models.py,sha256=_w5C7xvjT4-BFeMfzi57BQ51_fgaYZ4UK0SqUDE5Ooo,7266
 rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.50.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.50.dist-info/METADATA,sha256=bIeDbrlcclSfD9oHf26i_sYepOTvTkpcwQMWpOm2jWc,16627
+rxnn-0.1.50.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.50.dist-info/RECORD,,
{rxnn-0.1.49.dist-info → rxnn-0.1.50.dist-info}/LICENSE
File without changes
{rxnn-0.1.49.dist-info → rxnn-0.1.50.dist-info}/WHEEL
File without changes