rxnn-0.1.12-py3-none-any.whl → rxnn-0.1.13-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/transformers/layers.py +2 -2
- rxnn/transformers/models.py +2 -2
- {rxnn-0.1.12.dist-info → rxnn-0.1.13.dist-info}/METADATA +1 -1
- {rxnn-0.1.12.dist-info → rxnn-0.1.13.dist-info}/RECORD +6 -6
- {rxnn-0.1.12.dist-info → rxnn-0.1.13.dist-info}/LICENSE +0 -0
- {rxnn-0.1.12.dist-info → rxnn-0.1.13.dist-info}/WHEEL +0 -0
rxnn/transformers/layers.py
CHANGED
@@ -59,7 +59,7 @@ class ReactiveTransformerLayer(nn.Module):
         for param in self.memory_cross_attention.parameters():
             param.requires_grad_(is_trainable)

-    def
+    def moe_router_loss(self):
         return self.ff.router_loss() if self.use_moe else None

     def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
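The added `moe_router_loss` is only a thin wrapper around `ff.router_loss()`. For readers unfamiliar with the pattern, the sketch below shows one common form of a MoE router auxiliary (load-balancing) loss; it illustrates the general technique under assumed shapes and is not the implementation in `rxnn/transformers/moe.py`.

```python
import torch
import torch.nn.functional as F

def load_balancing_loss(router_probs: torch.Tensor,
                        expert_indices: torch.Tensor,
                        num_experts: int) -> torch.Tensor:
    """Illustrative Switch-style load-balancing loss (not the rxnn code).

    router_probs:   (tokens, num_experts) softmax outputs of the router
    expert_indices: (tokens,) index of the expert each token was routed to
    """
    # Fraction of tokens dispatched to each expert
    one_hot = F.one_hot(expert_indices, num_experts).float()
    tokens_per_expert = one_hot.mean(dim=0)
    # Mean router probability assigned to each expert
    prob_per_expert = router_probs.mean(dim=0)
    # Small when both distributions are close to uniform over experts
    return num_experts * torch.sum(tokens_per_expert * prob_per_expert)
```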
@@ -135,7 +135,7 @@ class ClassicTransformerLayer(nn.Module):
         self.use_post_norm = use_post_norm
         self.use_moe = use_moe

-    def
+    def moe_router_loss(self):
         return self.ff.router_loss() if self.use_moe else torch.tensor(0.0)

     def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
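Note the asymmetry between the two layer classes: `ReactiveTransformerLayer` returns `None` when MoE is off, while `ClassicTransformerLayer` returns a zero tensor. A caller that collects per-layer losses therefore has to guard against the `None` case, roughly as sketched below (illustrative code, not from the package):

```python
import torch

def collect_router_losses(layers) -> torch.Tensor:
    """Gather per-layer router losses, skipping layers that report None."""
    losses = []
    for layer in layers:
        loss = layer.moe_router_loss()
        if loss is not None:
            losses.append(loss)
    # torch.stack raises on an empty list, so fall back to a plain zero
    return torch.stack(losses).mean() if losses else torch.tensor(0.0)
```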
rxnn/transformers/models.py
CHANGED
@@ -37,7 +37,7 @@ class ReactiveTransformerBase(nn.Module):
         for i in range(self.num_own_layers):
             self.layers[i].trainable_cross_attention_(is_trainable)

-    def
+    def moe_router_loss(self):
         return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe] + [
             self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe]).mean()
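The model-level method stacks the per-layer scalars from both the own and the shared layer stacks and averages them. A minimal, self-contained illustration of that aggregation with stand-in scalar losses is shown below; since `torch.stack` raises on an empty list, the mean is only well defined when at least one layer actually uses MoE:

```python
import torch

# Stand-in per-layer router losses (scalars), e.g. from own and shared layers
own_losses = [torch.tensor(0.12), torch.tensor(0.08)]
shared_losses = [torch.tensor(0.10)]

# Same aggregation pattern as ReactiveTransformerBase.moe_router_loss:
# stack all scalars into one tensor and take the mean
router_loss = torch.stack(own_losses + shared_losses).mean()
print(router_loss)  # roughly tensor(0.1000)
```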
@@ -123,7 +123,7 @@ class ClassicTransformerBase(nn.Module):
         self.layers = layers
         self.num_layers = len(layers) if layers else 0

-    def
+    def moe_router_loss(self):
         return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_layers) if self.layers[i].use_moe]).mean()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
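In training code, a model-level router loss like this is typically added to the task loss with a small weight. The sketch below shows that pattern with hypothetical names (`aux_loss_weight`, `criterion`); it is not taken from rxnn's training modules:

```python
import torch

aux_loss_weight = 0.01  # hypothetical coefficient, tuned per setup

def training_step(model, batch, targets, criterion, optimizer):
    optimizer.zero_grad()
    logits = model(batch)
    loss = criterion(logits, targets)
    # Add the averaged router loss as an auxiliary term when MoE is used
    loss = loss + aux_loss_weight * model.moe_router_loss()
    loss.backward()
    optimizer.step()
    return loss.detach()
```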
{rxnn-0.1.12.dist-info → rxnn-0.1.13.dist-info}/RECORD
CHANGED
@@ -16,14 +16,14 @@ rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,80
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=FfEYE0THO73p_1eRupr2mcwfW4UbI_riIxkHfr8X_1c,14022
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
-rxnn/transformers/layers.py,sha256=
+rxnn/transformers/layers.py,sha256=HhIiykmrBgdsV4AbMQXr9t0cSo4gSIeN0dPtc8mDyOo,5629
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=
+rxnn/transformers/models.py,sha256=w-zB_8QB9-Fae-GkGgmVDNY-Ts_0gBeWcevpl9qzZVM,7169
 rxnn/transformers/moe.py,sha256=v21HDEhkDr10--If0P-XBjT5C7IlQJo0wGQlpDnVWEA,5020
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=wSz_1wNloqtuiix5w2Mcsj5NhaO9QlY0j__TVG7wJnM,3938
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.13.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.13.dist-info/METADATA,sha256=i32JDhkCLYc2-Chhy_LMSWbuwN7gQK2LjKiNDIJCQ0U,14629
+rxnn-0.1.13.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.13.dist-info/RECORD,,
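Each RECORD row has the form `path,sha256=<digest>,<size>`, where the digest is the urlsafe-base64 encoding of the file's SHA-256 hash with the trailing `=` padding stripped, per the wheel spec. The new hashes for `layers.py` and `models.py` can be reproduced from an unpacked 0.1.13 wheel with a snippet like this (the file path is illustrative):

```python
import base64
import hashlib
from pathlib import Path

def record_entry(path: str) -> str:
    """Compute the RECORD-style hash and size for a file."""
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
    return f"{path},sha256={digest.decode()},{len(data)}"

# Run inside the unpacked wheel directory:
print(record_entry("rxnn/transformers/layers.py"))
# expected: rxnn/transformers/layers.py,sha256=HhIiykmrBgdsV4AbMQXr9t0cSo4gSIeN0dPtc8mDyOo,5629
```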
{rxnn-0.1.12.dist-info → rxnn-0.1.13.dist-info}/LICENSE
File without changes
{rxnn-0.1.12.dist-info → rxnn-0.1.13.dist-info}/WHEEL
File without changes