rxnn 0.1.10__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/training/base.py +6 -4
- rxnn/training/bml.py +95 -10
- rxnn/transformers/layers.py +8 -0
- rxnn/transformers/models.py +7 -0
- rxnn/transformers/moe.py +7 -5
- {rxnn-0.1.10.dist-info → rxnn-0.1.12.dist-info}/METADATA +1 -1
- {rxnn-0.1.10.dist-info → rxnn-0.1.12.dist-info}/RECORD +9 -9
- {rxnn-0.1.10.dist-info → rxnn-0.1.12.dist-info}/LICENSE +0 -0
- {rxnn-0.1.10.dist-info → rxnn-0.1.12.dist-info}/WHEEL +0 -0
rxnn/training/base.py
CHANGED
```diff
@@ -49,6 +49,7 @@ class BaseTrainer(ABC):
         self.validation_metrics = {}
         self.target_field_name = target_field_name
         self.total_tokens = 0
+        self.total_steps = 0
         self.gradient_accumulation_steps = gradient_accumulation_steps
         self.accumulated_loss = 0.0
         self.optimizer_step_count = 0
@@ -140,6 +141,7 @@ class BaseTrainer(ABC):
         for callback in self.callbacks:
             callback.on_batch_start(self.model, batch_idx, batch)
         if self.get_batch_size(batch) == batch_size:
+            self.total_steps += 1
             loss = self.train_step(batch, batch_idx)
             orig_loss = loss.item()
             self.accumulated_loss += orig_loss
@@ -226,11 +228,11 @@ class BaseTrainer(ABC):
         self.writer.close()

     def _valid_writer(self, epoch: int, val_loss: float, val_metrics: dict):
-        self.writer.add_scalar('Loss/
-        self.writer.add_scalar('Perplexity/
+        self.writer.add_scalar('Loss/Valid', val_loss, epoch)
+        self.writer.add_scalar('Perplexity/Valid', math.exp(val_loss), epoch)
         if val_metrics['accuracy']:
-            self.writer.add_scalar('Node Accuracy/
-            self.writer.add_scalar('Avg. Accuracy/
+            self.writer.add_scalar('Node Accuracy/Valid', val_metrics['node_accuracy'], epoch)
+            self.writer.add_scalar('Avg. Accuracy/Valid', val_metrics['accuracy'], epoch)

     def valid_step(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         if self.use_amp:
```
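The new `total_steps` counter gives the trainer a monotonic step index that keeps increasing across epochs; the MoE trainer added in `bml.py` below uses it as the `global_step` for its TensorBoard scalars. A minimal sketch of that pattern, assuming a standard `torch.utils.tensorboard` writer (the loop, log directory, and values are illustrative, not package code):

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/example')  # hypothetical log directory
total_steps = 0
for epoch in range(2):
    for step_loss in (1.2, 1.1, 0.9):  # stand-in per-batch losses
        total_steps += 1  # keeps increasing across epochs, unlike a per-epoch batch_idx
        writer.add_scalar('Loss/Train', step_loss, total_steps)
writer.close()
```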
rxnn/training/bml.py
CHANGED
```diff
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn.parallel import DistributedDataParallel
 import math
 from huggingface_hub import PyTorchModelHubMixin
 from typing import Union
@@ -171,6 +172,90 @@ class AutoregressiveTrainer(BaseTrainer):
         return avg_loss, metrics


+class AutoregressiveMoeTrainer(BaseTrainer):
+    def __init__(
+            self,
+            model: ReactiveTransformerDecoder,
+            device: torch.device,
+            vocab_size: int,
+            use_amp: bool = False,
+            dtype: torch.dtype = None,
+            router_loss_scale: float = 0.1,
+            **kwargs
+    ):
+        super(AutoregressiveMoeTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype,
+                                                       target_field_name='targets', **kwargs)
+        self.vocab_size = vocab_size
+        self.router_loss_scale = router_loss_scale
+
+    def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+        inputs = batch['input_ids']
+        attention_mask = batch['attention_mask']
+        targets = batch['targets']
+
+        outputs = self.model(
+            inputs,
+            attention_mask=attention_mask
+        )
+
+        shifted_logits = outputs[:, :-1].contiguous()
+        shifted_targets = targets[:, 1:].contiguous()
+
+        main_loss = F.cross_entropy(
+            shifted_logits.view(-1, self.vocab_size),
+            shifted_targets.view(-1)
+        )
+
+        model = next(self.model.children()) if isinstance(self.model, DistributedDataParallel) else self.model
+
+        router_loss = model.model.moe_router_loss()
+        loss = main_loss + self.router_loss_scale * router_loss
+
+        if self.writer is not None:
+            if self.model.training:
+                self.writer.add_scalar('Router aux loss/Train', router_loss.item(), self.total_steps)
+                self.writer.add_scalar('Model loss/Train', main_loss.item(), self.total_steps)
+            else:
+                self.writer.add_scalar('Router aux loss/Valid', router_loss.item(), self.total_steps)
+                self.writer.add_scalar('Model loss/Valid', main_loss.item(), self.total_steps)
+
+        return loss, outputs
+
+    def validate(self, batch_size: int) -> tuple[float, dict]:
+        self.model.eval()
+        val_dataloader = self._valid_loader(batch_size)
+        val_loss = torch.tensor(0.0).to(self.device)
+        correct = torch.tensor(0).to(self.device)
+        total = torch.tensor(0).to(self.device)
+
+        with torch.no_grad():
+            for batch in val_dataloader:
+                if self.get_batch_size(batch) == batch_size:
+                    loss, logits = self.valid_step(batch)
+                    val_loss += loss
+                    shifted_logits = logits[:, :-1].contiguous()
+                    shifted_targets = batch[self.target_field_name][:, 1:].to(self.device).contiguous()
+                    valid_indices = shifted_targets != -100
+                    if valid_indices.any():
+                        preds = shifted_logits.argmax(-1)
+                        correct += (preds[valid_indices] == shifted_targets[valid_indices]).sum()
+                        total += valid_indices.sum()
+
+        avg_loss = (val_loss / len(val_dataloader)).item()
+        acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
+        node_acc = acc.item()
+        if self.use_ddp:
+            dist.all_reduce(acc, op=dist.ReduceOp.SUM)
+            acc = acc / dist.get_world_size()
+
+        metrics = {
+            'accuracy': acc.item(),
+            'node_accuracy': node_acc,
+        }
+        self.model.train()
+        return avg_loss, metrics
+
+
 class JointTrainingModel(nn.Module):
     def __init__(
             self,
@@ -262,18 +347,18 @@ class JointLMTrainer(BaseTrainer):
         return (encoder_loss, decoder_loss), (encoder_logits, decoder_logits)

     def _valid_writer(self, epoch: int, val_loss: float, val_metrics: dict):
-        self.writer.add_scalar('Loss/
-        self.writer.add_scalar('Perplexity/
+        self.writer.add_scalar('Loss/Valid', val_loss, epoch)
+        self.writer.add_scalar('Perplexity/Valid', math.exp(val_loss), epoch)
         if val_metrics['accuracy']:
-            self.writer.add_scalar('Encoder node accuracy/
-            self.writer.add_scalar('Decoder node accuracy/
-            self.writer.add_scalar('Encoder avg. accuracy/
-            self.writer.add_scalar('Decoder avg. accuracy/
+            self.writer.add_scalar('Encoder node accuracy/Valid', val_metrics['accuracy']['node_encoder'], epoch)
+            self.writer.add_scalar('Decoder node accuracy/Valid', val_metrics['accuracy']['node_decoder'], epoch)
+            self.writer.add_scalar('Encoder avg. accuracy/Valid', val_metrics['accuracy']['encoder'], epoch)
+            self.writer.add_scalar('Decoder avg. accuracy/Valid', val_metrics['accuracy']['decoder'], epoch)
         if val_metrics['loss']:
-            self.writer.add_scalar('Encoder loss/
-            self.writer.add_scalar('Encoder perplexity/
-            self.writer.add_scalar('Decoder accuracy/
-            self.writer.add_scalar('Decoder perplexity/
+            self.writer.add_scalar('Encoder loss/Valid', val_metrics['loss']['encoder'], epoch)
+            self.writer.add_scalar('Encoder perplexity/Valid', math.exp(val_metrics['loss']['encoder']), epoch)
+            self.writer.add_scalar('Decoder accuracy/Valid', val_metrics['loss']['decoder'], epoch)
+            self.writer.add_scalar('Decoder perplexity/Valid', math.exp(val_metrics['loss']['decoder']), epoch)

     def validate(self, batch_size: int) -> tuple[float, dict]:
         self.model.eval()
```
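The new trainer optimizes `main_loss + router_loss_scale * router_loss`: the usual next-token cross-entropy plus the router's load-balancing term, scaled by 0.1 by default. A self-contained sketch of that loss arithmetic with stand-in tensors (the `router_loss` value here is a placeholder for what `moe_router_loss()` returns):

```python
import torch
import torch.nn.functional as F

vocab_size = 10
logits = torch.randn(2, 5, vocab_size)          # stand-in decoder outputs [batch, seq, vocab]
targets = torch.randint(0, vocab_size, (2, 5))  # stand-in token ids

# Shift so position t predicts token t+1, as in compute_loss above.
shifted_logits = logits[:, :-1].contiguous()
shifted_targets = targets[:, 1:].contiguous()
main_loss = F.cross_entropy(shifted_logits.view(-1, vocab_size), shifted_targets.view(-1))

router_loss = torch.tensor(-1.2)       # placeholder for the entropy-based aux loss
loss = main_loss + 0.1 * router_loss   # router_loss_scale defaults to 0.1
```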
rxnn/transformers/layers.py
CHANGED
```diff
@@ -53,11 +53,15 @@ class ReactiveTransformerLayer(nn.Module):
         self.norm2 = nn.LayerNorm(embed_dim)
         self.norm3 = nn.LayerNorm(embed_dim)
         self.use_post_norm = use_post_norm
+        self.use_moe = use_moe

     def trainable_cross_attention_(self, is_trainable: bool):
         for param in self.memory_cross_attention.parameters():
             param.requires_grad_(is_trainable)

+    def moe_router_loss_(self):
+        return self.ff.router_loss() if self.use_moe else None
+
     def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
         # First step, self-attention
         residual = x
@@ -129,6 +133,10 @@ class ClassicTransformerLayer(nn.Module):
         self.norm1 = nn.LayerNorm(embed_dim)
         self.norm2 = nn.LayerNorm(embed_dim)
         self.use_post_norm = use_post_norm
+        self.use_moe = use_moe
+
+    def moe_router_loss_(self):
+        return self.ff.router_loss() if self.use_moe else torch.tensor(0.0)

     def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
         # First step, self-attention
```
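Both layer classes now record a `use_moe` flag and expose `moe_router_loss_()`, which simply delegates to the feed-forward's `router_loss()` when the layer is MoE-based. A stand-in sketch of that guard pattern (the class names and loss value here are illustrative, not the package's classes):

```python
import torch
import torch.nn as nn

class MoeFF(nn.Module):
    """Stand-in for an MoE feed-forward; only these expose router_loss()."""
    def router_loss(self):
        return torch.tensor(-1.5)  # placeholder aux loss

class Layer(nn.Module):
    def __init__(self, use_moe: bool):
        super().__init__()
        self.use_moe = use_moe
        self.ff = MoeFF() if use_moe else nn.Linear(8, 8)  # dense fallback

    def moe_router_loss_(self):
        # Guard with the flag: a dense feed-forward has no router.
        return self.ff.router_loss() if self.use_moe else None

print(Layer(True).moe_router_loss_())   # tensor(-1.5000)
print(Layer(False).moe_router_loss_())  # None
```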
rxnn/transformers/models.py
CHANGED
```diff
@@ -37,6 +37,10 @@ class ReactiveTransformerBase(nn.Module):
         for i in range(self.num_own_layers):
             self.layers[i].trainable_cross_attention_(is_trainable)

+    def moe_router_loss_(self):
+        return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe] + [
+            self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe]).mean()
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Shared logic for encoders and decoders - apply embeddings and positional encoding
         x = self.embedding(x)
@@ -119,6 +123,9 @@ class ClassicTransformerBase(nn.Module):
         self.layers = layers
         self.num_layers = len(layers) if layers else 0

+    def moe_router_loss_(self):
+        return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_layers) if self.layers[i].use_moe]).mean()
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Shared logic for encoders and decoders - apply embeddings and positional encoding
         x = self.embedding(x)
```
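At the model level, the router loss is the mean of the per-layer aux losses, taken over MoE layers only (dense layers are filtered out via `use_moe` before `torch.stack`). A quick stand-in of that aggregation with placeholder loss values:

```python
import torch

layer_losses = [torch.tensor(-2.0), torch.tensor(-1.0), torch.tensor(-3.0)]
use_moe_flags = [True, False, True]  # middle layer is dense, so it is skipped

moe_losses = [loss for loss, moe in zip(layer_losses, use_moe_flags) if moe]
router_loss = torch.stack(moe_losses).mean()
print(router_loss)  # tensor(-2.5000): mean of -2.0 and -3.0
```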
rxnn/transformers/moe.py
CHANGED
```diff
@@ -11,7 +11,8 @@ class MoeRouter(nn.Module):
         self.top_k = top_k
         self.num_experts = num_experts
         self.gate = nn.Linear(embed_dim, num_experts, bias=False)
-
+        # For expert load balancing
+        self.register_buffer('aux_loss', torch.tensor(0.0), persistent=False)

     def forward(self, x: torch.Tensor):
         # x shape: [batch_size*seq_len, embed_dim]
@@ -19,10 +20,8 @@
         probs = F.softmax(logits, dim=-1)

         # Expert load balancing loss
-
-
-        self.aux_loss = (probs_for_bal.mean(dim=0) *
-                         torch.log(probs_for_bal.mean(dim=0) + 1e-9)).sum()
+        mean_probs = probs.mean(dim=0)  # Mean probability per expert across batch
+        self.aux_loss = (mean_probs * torch.log(mean_probs + 1e-9)).sum()  # Entropy-based loss

         top_k_weights, top_k_indices = probs.topk(self.top_k, dim=-1)
         top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)
@@ -74,6 +73,9 @@ class MoeFeedForward(nn.Module):
     def _activate(self, h: torch.Tensor):
         return self.activation(h)

+    def router_loss(self):
+        return self.router.aux_loss
+
     def forward(self, x: torch.Tensor):
         orig_shape = x.shape
         x = x.view(-1, self.embed_dim)  # [batch*seq_len, embed_dim]
```
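The rewritten balancing term is the negative entropy of the experts' mean routing probabilities, `sum(mean_probs * log(mean_probs))`. It is most negative when routing is uniform across experts, so adding it to the training loss (as the trainer above does) penalizes routers that collapse onto a few experts. A quick numeric check with stand-in probabilities:

```python
import torch

def aux_loss(probs: torch.Tensor) -> torch.Tensor:
    # probs: [tokens, num_experts] routing probabilities
    mean_probs = probs.mean(dim=0)
    return (mean_probs * torch.log(mean_probs + 1e-9)).sum()

uniform = torch.full((4, 4), 0.25)                     # balanced routing
skewed = torch.tensor([[0.97, 0.01, 0.01, 0.01]] * 4)  # collapsed routing
print(aux_loss(uniform))  # ≈ -1.386 (= -log 4), the minimum
print(aux_loss(skewed))   # ≈ -0.168, larger, so penalized
```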
{rxnn-0.1.10.dist-info → rxnn-0.1.12.dist-info}/RECORD
CHANGED
```diff
@@ -7,8 +7,8 @@ rxnn/memory/stm.py,sha256=EsD8slSP4_9dLuq6aFPDmuFe8PWilxh90so5Z3nm-ig,2057
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/rxt/models.py,sha256=INTFeNcqzAsjyWhNtbBHL4Tx7tYDsaQHgm72tf6u20M,6918
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/training/base.py,sha256=
-rxnn/training/bml.py,sha256=
+rxnn/training/base.py,sha256=QD8uS14jSyR5Y_8BgCaBQTKpsarerU3lyufsWsCq_6o,11227
+rxnn/training/bml.py,sha256=o_88ZL1YWd5gWXaBqYPK2UzSTbJaiTiw96E6z73LeOQ,18660
 rxnn/training/callbacks.py,sha256=_YfMKY_eFdc-tubhO9nYH2PXDZDQwlSI74FVOoCXpQg,22108
 rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
 rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
@@ -16,14 +16,14 @@ rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,80
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=FfEYE0THO73p_1eRupr2mcwfW4UbI_riIxkHfr8X_1c,14022
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
-rxnn/transformers/layers.py,sha256=
+rxnn/transformers/layers.py,sha256=xMocHzdSu7hcC_mPE_aG3-LQg2RXgunKSxcgNXYnOeo,5631
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=
-rxnn/transformers/moe.py,sha256=
+rxnn/transformers/models.py,sha256=PVhiTTSQ7VTDVdOcyRf-xGNvj6oOa_2fUV2mfthcE0Y,7171
+rxnn/transformers/moe.py,sha256=v21HDEhkDr10--If0P-XBjT5C7IlQJo0wGQlpDnVWEA,5020
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=wSz_1wNloqtuiix5w2Mcsj5NhaO9QlY0j__TVG7wJnM,3938
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.12.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.12.dist-info/METADATA,sha256=mdoZLApjlSpC6GnprzoPuVpVhHpmVDejSjJABq_HKbk,14629
+rxnn-0.1.12.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.12.dist-info/RECORD,,
```
{rxnn-0.1.10.dist-info → rxnn-0.1.12.dist-info}/LICENSE
File without changes

{rxnn-0.1.10.dist-info → rxnn-0.1.12.dist-info}/WHEEL
File without changes