rxnn 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/training/base.py +6 -3
- rxnn/training/bml.py +120 -26
- rxnn/training/callbacks.py +89 -78
- rxnn/transformers/layers.py +8 -0
- rxnn/transformers/models.py +7 -0
- rxnn/transformers/moe.py +7 -5
- {rxnn-0.1.9.dist-info → rxnn-0.1.11.dist-info}/METADATA +1 -1
- {rxnn-0.1.9.dist-info → rxnn-0.1.11.dist-info}/RECORD +10 -10
- {rxnn-0.1.9.dist-info → rxnn-0.1.11.dist-info}/LICENSE +0 -0
- {rxnn-0.1.9.dist-info → rxnn-0.1.11.dist-info}/WHEEL +0 -0
rxnn/training/base.py
CHANGED
@@ -49,6 +49,7 @@ class BaseTrainer(ABC):
|
|
49
49
|
self.validation_metrics = {}
|
50
50
|
self.target_field_name = target_field_name
|
51
51
|
self.total_tokens = 0
|
52
|
+
self.total_steps = 0
|
52
53
|
self.gradient_accumulation_steps = gradient_accumulation_steps
|
53
54
|
self.accumulated_loss = 0.0
|
54
55
|
self.optimizer_step_count = 0
|
@@ -140,6 +141,7 @@ class BaseTrainer(ABC):
|
|
140
141
|
for callback in self.callbacks:
|
141
142
|
callback.on_batch_start(self.model, batch_idx, batch)
|
142
143
|
if self.get_batch_size(batch) == batch_size:
|
144
|
+
self.total_steps += 1
|
143
145
|
loss = self.train_step(batch, batch_idx)
|
144
146
|
orig_loss = loss.item()
|
145
147
|
self.accumulated_loss += orig_loss
|
@@ -226,10 +228,11 @@ class BaseTrainer(ABC):
|
|
226
228
|
self.writer.close()
|
227
229
|
|
228
230
|
def _valid_writer(self, epoch: int, val_loss: float, val_metrics: dict):
|
229
|
-
self.writer.add_scalar('Loss/
|
230
|
-
self.writer.add_scalar('Perplexity/
|
231
|
+
self.writer.add_scalar('Loss/Valid', val_loss, epoch)
|
232
|
+
self.writer.add_scalar('Perplexity/Valid', math.exp(val_loss), epoch)
|
231
233
|
if val_metrics['accuracy']:
|
232
|
-
self.writer.add_scalar('Accuracy/
|
234
|
+
self.writer.add_scalar('Node Accuracy/Valid', val_metrics['node_accuracy'], epoch)
|
235
|
+
self.writer.add_scalar('Avg. Accuracy/Valid', val_metrics['accuracy'], epoch)
|
233
236
|
|
234
237
|
def valid_step(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
235
238
|
if self.use_amp:
|
rxnn/training/bml.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
import torch
|
2
2
|
import torch.nn as nn
|
3
3
|
import torch.nn.functional as F
|
4
|
+
from torch.nn.parallel import DistributedDataParallel
|
4
5
|
import math
|
5
6
|
from huggingface_hub import PyTorchModelHubMixin
|
6
7
|
from typing import Union
|
@@ -90,12 +91,15 @@ class MLMTrainer(BaseTrainer):
|
|
90
91
|
total += valid_indices.sum()
|
91
92
|
|
92
93
|
avg_loss = (val_loss / len(val_dataloader)).item()
|
94
|
+
acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
|
95
|
+
node_acc = acc.item()
|
93
96
|
if self.use_ddp:
|
94
|
-
dist.all_reduce(
|
95
|
-
|
97
|
+
dist.all_reduce(acc, op=dist.ReduceOp.SUM)
|
98
|
+
acc = acc / dist.get_world_size()
|
96
99
|
|
97
100
|
metrics = {
|
98
|
-
'accuracy':
|
101
|
+
'accuracy': acc.item(),
|
102
|
+
'node_accuracy': node_acc,
|
99
103
|
}
|
100
104
|
self.model.train()
|
101
105
|
return avg_loss, metrics
|
@@ -154,13 +158,99 @@ class AutoregressiveTrainer(BaseTrainer):
|
|
154
158
|
total += valid_indices.sum()
|
155
159
|
|
156
160
|
avg_loss = (val_loss / len(val_dataloader)).item()
|
161
|
+
acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
|
162
|
+
node_acc = acc.item()
|
163
|
+
if self.use_ddp:
|
164
|
+
dist.all_reduce(acc, op=dist.ReduceOp.SUM)
|
165
|
+
acc = acc / dist.get_world_size()
|
166
|
+
|
167
|
+
metrics = {
|
168
|
+
'accuracy': acc.item(),
|
169
|
+
'node_accuracy': node_acc,
|
170
|
+
}
|
171
|
+
self.model.train()
|
172
|
+
return avg_loss, metrics
|
173
|
+
|
174
|
+
|
175
|
+
class AutoregressiveMoeTrainer(BaseTrainer):
|
176
|
+
def __init__(
|
177
|
+
self,
|
178
|
+
model: ReactiveTransformerDecoder,
|
179
|
+
device: torch.device,
|
180
|
+
vocab_size: int,
|
181
|
+
use_amp: bool = False,
|
182
|
+
dtype: torch.dtype = None,
|
183
|
+
router_loss_scale: float = 0.1,
|
184
|
+
**kwargs
|
185
|
+
):
|
186
|
+
super(AutoregressiveMoeTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype,
|
187
|
+
target_field_name='targets', **kwargs)
|
188
|
+
self.vocab_size = vocab_size
|
189
|
+
self.router_loss_scale = router_loss_scale
|
190
|
+
|
191
|
+
def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
192
|
+
inputs = batch['input_ids']
|
193
|
+
attention_mask = batch['attention_mask']
|
194
|
+
targets = batch['targets']
|
195
|
+
|
196
|
+
outputs = self.model(
|
197
|
+
inputs,
|
198
|
+
attention_mask=attention_mask
|
199
|
+
)
|
200
|
+
|
201
|
+
shifted_logits = outputs[:, :-1].contiguous()
|
202
|
+
shifted_targets = targets[:, 1:].contiguous()
|
157
203
|
|
204
|
+
main_loss = F.cross_entropy(
|
205
|
+
shifted_logits.view(-1, self.vocab_size),
|
206
|
+
shifted_targets.view(-1)
|
207
|
+
)
|
208
|
+
|
209
|
+
model = next(self.model.children()) if isinstance(self.model, DistributedDataParallel) else self.model
|
210
|
+
|
211
|
+
router_loss = model.moe_router_loss()
|
212
|
+
loss = main_loss + self.router_loss_scale * router_loss
|
213
|
+
|
214
|
+
if self.writer is not None:
|
215
|
+
if self.model.training:
|
216
|
+
self.writer.add_scalar('Router aux loss/Train', router_loss.item(), self.total_steps)
|
217
|
+
self.writer.add_scalar('Model loss/Train', main_loss.item(), self.total_steps)
|
218
|
+
else:
|
219
|
+
self.writer.add_scalar('Router aux loss/Valid', router_loss.item(), self.total_steps)
|
220
|
+
self.writer.add_scalar('Model loss/Valid', main_loss.item(), self.total_steps)
|
221
|
+
|
222
|
+
return loss, outputs
|
223
|
+
|
224
|
+
def validate(self, batch_size: int) -> tuple[float, dict]:
|
225
|
+
self.model.eval()
|
226
|
+
val_dataloader = self._valid_loader(batch_size)
|
227
|
+
val_loss = torch.tensor(0.0).to(self.device)
|
228
|
+
correct = torch.tensor(0).to(self.device)
|
229
|
+
total = torch.tensor(0).to(self.device)
|
230
|
+
|
231
|
+
with torch.no_grad():
|
232
|
+
for batch in val_dataloader:
|
233
|
+
if self.get_batch_size(batch) == batch_size:
|
234
|
+
loss, logits = self.valid_step(batch)
|
235
|
+
val_loss += loss
|
236
|
+
shifted_logits = logits[:, :-1].contiguous()
|
237
|
+
shifted_targets = batch[self.target_field_name][:, 1:].to(self.device).contiguous()
|
238
|
+
valid_indices = shifted_targets != -100
|
239
|
+
if valid_indices.any():
|
240
|
+
preds = shifted_logits.argmax(-1)
|
241
|
+
correct += (preds[valid_indices] == shifted_targets[valid_indices]).sum()
|
242
|
+
total += valid_indices.sum()
|
243
|
+
|
244
|
+
avg_loss = (val_loss / len(val_dataloader)).item()
|
245
|
+
acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
|
246
|
+
node_acc = acc.item()
|
158
247
|
if self.use_ddp:
|
159
|
-
dist.all_reduce(
|
160
|
-
|
248
|
+
dist.all_reduce(acc, op=dist.ReduceOp.SUM)
|
249
|
+
acc = acc / dist.get_world_size()
|
161
250
|
|
162
251
|
metrics = {
|
163
|
-
'accuracy':
|
252
|
+
'accuracy': acc.item(),
|
253
|
+
'node_accuracy': node_acc,
|
164
254
|
}
|
165
255
|
self.model.train()
|
166
256
|
return avg_loss, metrics
|
@@ -257,16 +347,18 @@ class JointLMTrainer(BaseTrainer):
|
|
257
347
|
return (encoder_loss, decoder_loss), (encoder_logits, decoder_logits)
|
258
348
|
|
259
349
|
def _valid_writer(self, epoch: int, val_loss: float, val_metrics: dict):
|
260
|
-
self.writer.add_scalar('Loss/
|
261
|
-
self.writer.add_scalar('Perplexity/
|
350
|
+
self.writer.add_scalar('Loss/Valid', val_loss, epoch)
|
351
|
+
self.writer.add_scalar('Perplexity/Valid', math.exp(val_loss), epoch)
|
262
352
|
if val_metrics['accuracy']:
|
263
|
-
self.writer.add_scalar('Encoder accuracy/
|
264
|
-
self.writer.add_scalar('Decoder accuracy/
|
353
|
+
self.writer.add_scalar('Encoder node accuracy/Valid', val_metrics['accuracy']['node_encoder'], epoch)
|
354
|
+
self.writer.add_scalar('Decoder node accuracy/Valid', val_metrics['accuracy']['node_decoder'], epoch)
|
355
|
+
self.writer.add_scalar('Encoder avg. accuracy/Valid', val_metrics['accuracy']['encoder'], epoch)
|
356
|
+
self.writer.add_scalar('Decoder avg. accuracy/Valid', val_metrics['accuracy']['decoder'], epoch)
|
265
357
|
if val_metrics['loss']:
|
266
|
-
self.writer.add_scalar('Encoder loss/
|
267
|
-
self.writer.add_scalar('Encoder perplexity/
|
268
|
-
self.writer.add_scalar('Decoder accuracy/
|
269
|
-
self.writer.add_scalar('Decoder perplexity/
|
358
|
+
self.writer.add_scalar('Encoder loss/Valid', val_metrics['loss']['encoder'], epoch)
|
359
|
+
self.writer.add_scalar('Encoder perplexity/Valid', math.exp(val_metrics['loss']['encoder']), epoch)
|
360
|
+
self.writer.add_scalar('Decoder accuracy/Valid', val_metrics['loss']['decoder'], epoch)
|
361
|
+
self.writer.add_scalar('Decoder perplexity/Valid', math.exp(val_metrics['loss']['decoder']), epoch)
|
270
362
|
|
271
363
|
def validate(self, batch_size: int) -> tuple[float, dict]:
|
272
364
|
self.model.eval()
|
@@ -317,28 +409,30 @@ class JointLMTrainer(BaseTrainer):
|
|
317
409
|
avg_loss = val_loss / loader_len
|
318
410
|
avg_dec_loss = dec_loss / loader_len
|
319
411
|
avg_enc_loss = enc_loss / loader_len
|
320
|
-
|
412
|
+
mlm_acc = (correct_mlm / total_mlm * 100) if total_mlm > 0 else torch.tensor(0.0).to(self.device)
|
413
|
+
alm_acc = (correct_alm / total_alm * 100) if total_alm > 0 else torch.tensor(0.0).to(self.device)
|
414
|
+
node_mlm_acc = mlm_acc.item()
|
415
|
+
node_alm_acc = alm_acc.item()
|
321
416
|
if self.use_ddp:
|
322
417
|
dist.all_reduce(avg_dec_loss, op=dist.ReduceOp.SUM)
|
323
418
|
dist.all_reduce(avg_enc_loss, op=dist.ReduceOp.SUM)
|
324
|
-
dist.all_reduce(
|
325
|
-
dist.all_reduce(
|
326
|
-
dist.all_reduce(correct_alm, op=dist.ReduceOp.SUM)
|
327
|
-
dist.all_reduce(total_alm, op=dist.ReduceOp.SUM)
|
419
|
+
dist.all_reduce(mlm_acc, op=dist.ReduceOp.SUM)
|
420
|
+
dist.all_reduce(alm_acc, op=dist.ReduceOp.SUM)
|
328
421
|
avg_dec_loss = avg_dec_loss / dist.get_world_size()
|
329
422
|
avg_enc_loss = avg_enc_loss / dist.get_world_size()
|
330
|
-
|
331
|
-
|
332
|
-
alm_acc = (correct_alm / total_alm * 100).item() if total_alm > 0 else 0.0
|
423
|
+
mlm_acc = mlm_acc / dist.get_world_size()
|
424
|
+
alm_acc = alm_acc / dist.get_world_size()
|
333
425
|
|
334
426
|
metrics = {
|
335
427
|
'accuracy': {
|
336
|
-
'encoder': mlm_acc,
|
337
|
-
'decoder': alm_acc,
|
428
|
+
'encoder': mlm_acc.item(),
|
429
|
+
'decoder': alm_acc.item(),
|
430
|
+
'node_encoder': node_mlm_acc,
|
431
|
+
'node_decoder': node_alm_acc,
|
338
432
|
},
|
339
433
|
'loss': {
|
340
|
-
'encoder': avg_enc_loss,
|
341
|
-
'decoder': avg_dec_loss,
|
434
|
+
'encoder': avg_enc_loss.item(),
|
435
|
+
'decoder': avg_dec_loss.item(),
|
342
436
|
}
|
343
437
|
}
|
344
438
|
self.model.train()
|
rxnn/training/callbacks.py
CHANGED
@@ -83,9 +83,12 @@ class PrintAccuracyCallback(TrainerCallback):
|
|
83
83
|
|
84
84
|
def on_validation_end(self, model: nn.Module, epoch: int, val_loss: float, val_metrics: dict) -> None:
|
85
85
|
if self.joint_mode:
|
86
|
+
print(f"Epoch {epoch} - encoder node accuracy: {val_metrics['accuracy']['node_encoder']:.4f}")
|
87
|
+
print(f"Epoch {epoch} - decoder node accuracy: {val_metrics['accuracy']['node_decoder']:.4f}")
|
86
88
|
print(f"Epoch {epoch} - encoder accuracy: {val_metrics['accuracy']['encoder']:.4f}")
|
87
89
|
print(f"Epoch {epoch} - decoder accuracy: {val_metrics['accuracy']['decoder']:.4f}")
|
88
90
|
else:
|
91
|
+
print(f"Epoch {epoch} - node accuracy: {val_metrics['node_accuracy']:.4f}")
|
89
92
|
print(f"Epoch {epoch} - accuracy: {val_metrics['accuracy']:.4f}")
|
90
93
|
|
91
94
|
|
@@ -130,6 +133,7 @@ class ModelSaveCallback(TrainerCallback):
|
|
130
133
|
save_checkpoint_after_n_batches: int = None,
|
131
134
|
push_batch_checkpoint: bool = False,
|
132
135
|
display_exc_trace: bool = False,
|
136
|
+
use_ddp: bool = False,
|
133
137
|
):
|
134
138
|
self.save_dir = save_dir
|
135
139
|
self.save_best_only = save_best_only
|
@@ -146,10 +150,11 @@ class ModelSaveCallback(TrainerCallback):
|
|
146
150
|
self.push_batch_checkpoint = push_batch_checkpoint
|
147
151
|
self.finished_epochs = 0
|
148
152
|
self.display_exc_trace = display_exc_trace
|
153
|
+
self.rank = int(os.environ['RANK']) if use_ddp else 0
|
149
154
|
|
150
155
|
def on_batch_end(self, model: torch.nn.Module, batch_idx: int, loss: int, batch: dict[str, torch.Tensor]) -> Union[
|
151
156
|
bool, None]:
|
152
|
-
if self.save_checkpoint_after_n_batches is not None and batch_idx != 0 and batch_idx % self.save_checkpoint_after_n_batches == 0:
|
157
|
+
if self.rank == 0 and self.save_checkpoint_after_n_batches is not None and batch_idx != 0 and batch_idx % self.save_checkpoint_after_n_batches == 0:
|
153
158
|
if isinstance(model, DistributedDataParallel):
|
154
159
|
model = next(model.children())
|
155
160
|
try:
|
@@ -195,90 +200,92 @@ class ModelSaveCallback(TrainerCallback):
|
|
195
200
|
val_loss: float,
|
196
201
|
val_metrics: dict
|
197
202
|
):
|
198
|
-
self.
|
199
|
-
|
200
|
-
self.best_loss
|
203
|
+
if self.rank == 0:
|
204
|
+
self.finished_epochs += 1
|
205
|
+
if val_loss < self.best_loss:
|
206
|
+
self.best_loss = val_loss
|
207
|
+
if isinstance(model, DistributedDataParallel):
|
208
|
+
model = next(model.children())
|
209
|
+
try:
|
210
|
+
if model.save_pretrained is not None:
|
211
|
+
ckpt_path = os.path.join(
|
212
|
+
self.save_dir,
|
213
|
+
f'epoch_{epoch}_val_loss_{val_loss:.4f}'
|
214
|
+
)
|
215
|
+
path_exists = os.path.exists(ckpt_path)
|
216
|
+
if not path_exists:
|
217
|
+
os.makedirs(ckpt_path)
|
218
|
+
model.save_pretrained(save_directory=ckpt_path)
|
219
|
+
else:
|
220
|
+
path_exists = os.path.exists(self.save_dir)
|
221
|
+
if not path_exists:
|
222
|
+
os.makedirs(self.save_dir)
|
223
|
+
ckpt_path = os.path.join(
|
224
|
+
self.save_dir,
|
225
|
+
f'epoch_{epoch}_val_loss_{val_loss:.4f}.pt'
|
226
|
+
)
|
227
|
+
torch.save(model.state_dict(), ckpt_path)
|
228
|
+
self.ckpt_paths.append(ckpt_path)
|
229
|
+
|
230
|
+
# Keep only N best checkpoints
|
231
|
+
if len(self.ckpt_paths) > self.max_keep:
|
232
|
+
oldest_path = self.ckpt_paths.pop(0)
|
233
|
+
if model.save_pretrained is not None:
|
234
|
+
shutil.rmtree(oldest_path)
|
235
|
+
else:
|
236
|
+
os.remove(oldest_path)
|
237
|
+
except Exception as e:
|
238
|
+
print(f"Error saving epoch checkpoint: {str(e)}")
|
239
|
+
if self.display_exc_trace:
|
240
|
+
traceback.print_exc()
|
241
|
+
|
242
|
+
try:
|
243
|
+
if self.push_to_hub and self.push_checkpoint_weights and model.push_to_hub is not None and self.hub_model_id:
|
244
|
+
model.push_to_hub(
|
245
|
+
repo_id=self.hub_model_id,
|
246
|
+
commit_message=f'Epoch {epoch} - Val loss {val_loss:.4f}',
|
247
|
+
token=self.hf_token,
|
248
|
+
private=self.private_repo,
|
249
|
+
)
|
250
|
+
except Exception as e:
|
251
|
+
print(f"Error pushing epoch checkpoint: {str(e)}")
|
252
|
+
if self.display_exc_trace:
|
253
|
+
traceback.print_exc()
|
254
|
+
|
255
|
+
def on_training_end(self, model: Union[torch.nn.Module, PyTorchModelHubMixin]):
|
256
|
+
if self.rank == 0:
|
201
257
|
if isinstance(model, DistributedDataParallel):
|
202
258
|
model = next(model.children())
|
203
259
|
try:
|
260
|
+
# Save final model
|
204
261
|
if model.save_pretrained is not None:
|
205
262
|
ckpt_path = os.path.join(
|
206
263
|
self.save_dir,
|
207
|
-
|
264
|
+
'final_model'
|
208
265
|
)
|
209
|
-
path_exists = os.path.exists(ckpt_path)
|
210
|
-
if not path_exists:
|
211
|
-
os.makedirs(ckpt_path)
|
212
266
|
model.save_pretrained(save_directory=ckpt_path)
|
213
267
|
else:
|
214
|
-
|
215
|
-
if not path_exists:
|
216
|
-
os.makedirs(self.save_dir)
|
217
|
-
ckpt_path = os.path.join(
|
218
|
-
self.save_dir,
|
219
|
-
f'epoch_{epoch}_val_loss_{val_loss:.4f}.pt'
|
220
|
-
)
|
268
|
+
ckpt_path = os.path.join(self.save_dir, 'final_model.pt')
|
221
269
|
torch.save(model.state_dict(), ckpt_path)
|
222
|
-
|
223
|
-
|
224
|
-
# Keep only N best checkpoints
|
225
|
-
if len(self.ckpt_paths) > self.max_keep:
|
226
|
-
oldest_path = self.ckpt_paths.pop(0)
|
227
|
-
if model.save_pretrained is not None:
|
228
|
-
shutil.rmtree(oldest_path)
|
229
|
-
else:
|
230
|
-
os.remove(oldest_path)
|
270
|
+
print(f"Final model saved to {ckpt_path}")
|
231
271
|
except Exception as e:
|
232
|
-
print(f"Error saving
|
272
|
+
print(f"Error saving final model: {str(e)}")
|
233
273
|
if self.display_exc_trace:
|
234
274
|
traceback.print_exc()
|
235
|
-
|
236
275
|
try:
|
237
|
-
if self.push_to_hub and
|
276
|
+
if self.push_to_hub and model.push_to_hub is not None:
|
238
277
|
model.push_to_hub(
|
239
278
|
repo_id=self.hub_model_id,
|
240
|
-
commit_message=f'
|
279
|
+
commit_message=self.final_commit_message or f'Final pre-trained model, after {self.finished_epochs} epochs',
|
241
280
|
token=self.hf_token,
|
242
281
|
private=self.private_repo,
|
243
282
|
)
|
283
|
+
print(f"Model uploaded to repo: {self.hub_model_id}")
|
244
284
|
except Exception as e:
|
245
|
-
print(f"Error pushing
|
285
|
+
print(f"Error pushing final model: {str(e)}")
|
246
286
|
if self.display_exc_trace:
|
247
287
|
traceback.print_exc()
|
248
288
|
|
249
|
-
def on_training_end(self, model: Union[torch.nn.Module, PyTorchModelHubMixin]):
|
250
|
-
if isinstance(model, DistributedDataParallel):
|
251
|
-
model = next(model.children())
|
252
|
-
try:
|
253
|
-
# Save final model
|
254
|
-
if model.save_pretrained is not None:
|
255
|
-
ckpt_path = os.path.join(
|
256
|
-
self.save_dir,
|
257
|
-
'final_model'
|
258
|
-
)
|
259
|
-
model.save_pretrained(save_directory=ckpt_path)
|
260
|
-
else:
|
261
|
-
ckpt_path = os.path.join(self.save_dir, 'final_model.pt')
|
262
|
-
torch.save(model.state_dict(), ckpt_path)
|
263
|
-
print(f"Final model saved to {ckpt_path}")
|
264
|
-
except Exception as e:
|
265
|
-
print(f"Error saving final model: {str(e)}")
|
266
|
-
if self.display_exc_trace:
|
267
|
-
traceback.print_exc()
|
268
|
-
try:
|
269
|
-
if self.push_to_hub and model.push_to_hub is not None:
|
270
|
-
model.push_to_hub(
|
271
|
-
repo_id=self.hub_model_id,
|
272
|
-
commit_message=self.final_commit_message or f'Final pre-trained model, after {self.finished_epochs} epochs',
|
273
|
-
token=self.hf_token,
|
274
|
-
private=self.private_repo,
|
275
|
-
)
|
276
|
-
print(f"Model uploaded to repo: {self.hub_model_id}")
|
277
|
-
except Exception as e:
|
278
|
-
print(f"Error pushing final model: {str(e)}")
|
279
|
-
if self.display_exc_trace:
|
280
|
-
traceback.print_exc()
|
281
|
-
|
282
289
|
|
283
290
|
class JointModelSaveCallback(TrainerCallback):
|
284
291
|
def __init__(
|
@@ -298,6 +305,7 @@ class JointModelSaveCallback(TrainerCallback):
|
|
298
305
|
push_batch_checkpoint: bool = False,
|
299
306
|
mlm_mode: bool = False,
|
300
307
|
display_exc_trace: bool = False,
|
308
|
+
use_ddp: bool = False,
|
301
309
|
):
|
302
310
|
self.save_dir = save_dir
|
303
311
|
self.save_best_only = save_best_only
|
@@ -317,6 +325,7 @@ class JointModelSaveCallback(TrainerCallback):
|
|
317
325
|
self.finished_epochs = 0
|
318
326
|
self.mlm_mode = mlm_mode
|
319
327
|
self.display_exc_trace = display_exc_trace
|
328
|
+
self.rank = int(os.environ['RANK']) if use_ddp else 0
|
320
329
|
|
321
330
|
def _save_batch(self, model: Union[nn.Module, PyTorchModelHubMixin], component: str, hub_id: str = None):
|
322
331
|
try:
|
@@ -362,7 +371,7 @@ class JointModelSaveCallback(TrainerCallback):
|
|
362
371
|
|
363
372
|
def on_batch_end(self, model: torch.nn.Module, batch_idx: int, loss: int, batch: dict[str, torch.Tensor]) -> Union[
|
364
373
|
bool, None]:
|
365
|
-
if self.save_checkpoint_after_n_batches is not None and batch_idx != 0 and batch_idx % self.save_checkpoint_after_n_batches == 0:
|
374
|
+
if self.rank == 0 and self.save_checkpoint_after_n_batches is not None and batch_idx != 0 and batch_idx % self.save_checkpoint_after_n_batches == 0:
|
366
375
|
if isinstance(model, DistributedDataParallel):
|
367
376
|
model = next(model.children())
|
368
377
|
self._save_batch(model.encoder, 'encoder', hub_id=self.hub_model_encoder)
|
@@ -430,15 +439,16 @@ class JointModelSaveCallback(TrainerCallback):
|
|
430
439
|
val_loss: float,
|
431
440
|
val_metrics: dict
|
432
441
|
):
|
433
|
-
self.
|
434
|
-
|
435
|
-
self.best_loss
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
+
if self.rank == 0:
|
443
|
+
self.finished_epochs += 1
|
444
|
+
if val_loss < self.best_loss:
|
445
|
+
self.best_loss = val_loss
|
446
|
+
if isinstance(model, DistributedDataParallel):
|
447
|
+
model = next(model.children())
|
448
|
+
self._save_validation(model.encoder, 'encoder', epoch, val_loss, hub_id=self.hub_model_encoder)
|
449
|
+
if not self.mlm_mode:
|
450
|
+
self._save_validation(model.decoder, 'decoder', epoch, val_loss, hub_id=self.hub_model_decoder)
|
451
|
+
self._save_validation(model.mlm_head, 'head', epoch, val_loss, hub_id=self.hub_model_head)
|
442
452
|
|
443
453
|
def _save_final(self, model: Union[nn.Module, PyTorchModelHubMixin], component: str, hub_id: str = None):
|
444
454
|
try:
|
@@ -482,9 +492,10 @@ class JointModelSaveCallback(TrainerCallback):
|
|
482
492
|
traceback.print_exc()
|
483
493
|
|
484
494
|
def on_training_end(self, model: Union[torch.nn.Module, PyTorchModelHubMixin]):
|
485
|
-
if
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
495
|
+
if self.rank == 0:
|
496
|
+
if isinstance(model, DistributedDataParallel):
|
497
|
+
model = next(model.children())
|
498
|
+
self._save_final(model.encoder, 'encoder', hub_id=self.hub_model_encoder)
|
499
|
+
if not self.mlm_mode:
|
500
|
+
self._save_final(model.decoder, 'decoder', hub_id=self.hub_model_decoder)
|
501
|
+
self._save_final(model.mlm_head, 'head', hub_id=self.hub_model_head)
|
rxnn/transformers/layers.py
CHANGED
@@ -53,11 +53,15 @@ class ReactiveTransformerLayer(nn.Module):
|
|
53
53
|
self.norm2 = nn.LayerNorm(embed_dim)
|
54
54
|
self.norm3 = nn.LayerNorm(embed_dim)
|
55
55
|
self.use_post_norm = use_post_norm
|
56
|
+
self.use_moe = use_moe
|
56
57
|
|
57
58
|
def trainable_cross_attention_(self, is_trainable: bool):
|
58
59
|
for param in self.memory_cross_attention.parameters():
|
59
60
|
param.requires_grad_(is_trainable)
|
60
61
|
|
62
|
+
def moe_router_loss_(self):
|
63
|
+
return self.ff.router_loss() if self.use_moe else None
|
64
|
+
|
61
65
|
def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
|
62
66
|
# First step, self-attention
|
63
67
|
residual = x
|
@@ -129,6 +133,10 @@ class ClassicTransformerLayer(nn.Module):
|
|
129
133
|
self.norm1 = nn.LayerNorm(embed_dim)
|
130
134
|
self.norm2 = nn.LayerNorm(embed_dim)
|
131
135
|
self.use_post_norm = use_post_norm
|
136
|
+
self.use_moe = use_moe
|
137
|
+
|
138
|
+
def moe_router_loss_(self):
|
139
|
+
return self.ff.router_loss() if self.use_moe else torch.tensor(0.0)
|
132
140
|
|
133
141
|
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
|
134
142
|
# First step, self-attention
|
rxnn/transformers/models.py
CHANGED
@@ -37,6 +37,10 @@ class ReactiveTransformerBase(nn.Module):
|
|
37
37
|
for i in range(self.num_own_layers):
|
38
38
|
self.layers[i].trainable_cross_attention_(is_trainable)
|
39
39
|
|
40
|
+
def moe_router_loss_(self):
|
41
|
+
return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe] + [
|
42
|
+
self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe]).mean()
|
43
|
+
|
40
44
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
41
45
|
# Shared logic for encoders and decoders - apply embeddings and positional encoding
|
42
46
|
x = self.embedding(x)
|
@@ -119,6 +123,9 @@ class ClassicTransformerBase(nn.Module):
|
|
119
123
|
self.layers = layers
|
120
124
|
self.num_layers = len(layers) if layers else 0
|
121
125
|
|
126
|
+
def moe_router_loss_(self):
|
127
|
+
return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_layers) if self.layers[i].use_moe]).mean()
|
128
|
+
|
122
129
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
123
130
|
# Shared logic for encoders and decoders - apply embeddings and positional encoding
|
124
131
|
x = self.embedding(x)
|
rxnn/transformers/moe.py
CHANGED
@@ -11,7 +11,8 @@ class MoeRouter(nn.Module):
|
|
11
11
|
self.top_k = top_k
|
12
12
|
self.num_experts = num_experts
|
13
13
|
self.gate = nn.Linear(embed_dim, num_experts, bias=False)
|
14
|
-
|
14
|
+
# For expert load balancing
|
15
|
+
self.register_buffer('aux_loss', torch.tensor(0.0), persistent=False)
|
15
16
|
|
16
17
|
def forward(self, x: torch.Tensor):
|
17
18
|
# x shape: [batch_size*seq_len, embed_dim]
|
@@ -19,10 +20,8 @@ class MoeRouter(nn.Module):
|
|
19
20
|
probs = F.softmax(logits, dim=-1)
|
20
21
|
|
21
22
|
# Expert load balancing loss
|
22
|
-
|
23
|
-
|
24
|
-
self.aux_loss = (probs_for_bal.mean(dim=0) *
|
25
|
-
torch.log(probs_for_bal.mean(dim=0) + 1e-9)).sum()
|
23
|
+
mean_probs = probs.mean(dim=0) # Mean probability per expert across batch
|
24
|
+
self.aux_loss = (mean_probs * torch.log(mean_probs + 1e-9)).sum() # Entropy-based loss
|
26
25
|
|
27
26
|
top_k_weights, top_k_indices = probs.topk(self.top_k, dim=-1)
|
28
27
|
top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)
|
@@ -74,6 +73,9 @@ class MoeFeedForward(nn.Module):
|
|
74
73
|
def _activate(self, h: torch.Tensor):
|
75
74
|
return self.activation(h)
|
76
75
|
|
76
|
+
def router_loss(self):
|
77
|
+
return self.router.aux_loss
|
78
|
+
|
77
79
|
def forward(self, x: torch.Tensor):
|
78
80
|
orig_shape = x.shape
|
79
81
|
x = x.view(-1, self.embed_dim) # [batch*seq_len, embed_dim]
|
@@ -7,23 +7,23 @@ rxnn/memory/stm.py,sha256=EsD8slSP4_9dLuq6aFPDmuFe8PWilxh90so5Z3nm-ig,2057
|
|
7
7
|
rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
8
8
|
rxnn/rxt/models.py,sha256=INTFeNcqzAsjyWhNtbBHL4Tx7tYDsaQHgm72tf6u20M,6918
|
9
9
|
rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
10
|
-
rxnn/training/base.py,sha256=
|
11
|
-
rxnn/training/bml.py,sha256=
|
12
|
-
rxnn/training/callbacks.py,sha256=
|
10
|
+
rxnn/training/base.py,sha256=QD8uS14jSyR5Y_8BgCaBQTKpsarerU3lyufsWsCq_6o,11227
|
11
|
+
rxnn/training/bml.py,sha256=2kk9q3Buxq4wBHUQhyIAuHoBCninYX2K8hykWAJnxB0,18654
|
12
|
+
rxnn/training/callbacks.py,sha256=_YfMKY_eFdc-tubhO9nYH2PXDZDQwlSI74FVOoCXpQg,22108
|
13
13
|
rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
|
14
14
|
rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
|
15
15
|
rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
|
16
16
|
rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
17
|
rxnn/transformers/attention.py,sha256=FfEYE0THO73p_1eRupr2mcwfW4UbI_riIxkHfr8X_1c,14022
|
18
18
|
rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
|
19
|
-
rxnn/transformers/layers.py,sha256=
|
19
|
+
rxnn/transformers/layers.py,sha256=xMocHzdSu7hcC_mPE_aG3-LQg2RXgunKSxcgNXYnOeo,5631
|
20
20
|
rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
|
21
|
-
rxnn/transformers/models.py,sha256=
|
22
|
-
rxnn/transformers/moe.py,sha256=
|
21
|
+
rxnn/transformers/models.py,sha256=PVhiTTSQ7VTDVdOcyRf-xGNvj6oOa_2fUV2mfthcE0Y,7171
|
22
|
+
rxnn/transformers/moe.py,sha256=v21HDEhkDr10--If0P-XBjT5C7IlQJo0wGQlpDnVWEA,5020
|
23
23
|
rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
|
24
24
|
rxnn/transformers/sampler.py,sha256=wSz_1wNloqtuiix5w2Mcsj5NhaO9QlY0j__TVG7wJnM,3938
|
25
25
|
rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
|
26
|
-
rxnn-0.1.
|
27
|
-
rxnn-0.1.
|
28
|
-
rxnn-0.1.
|
29
|
-
rxnn-0.1.
|
26
|
+
rxnn-0.1.11.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
27
|
+
rxnn-0.1.11.dist-info/METADATA,sha256=WFoe6AqfJVI6wFZ23i3qGQ3babDlLtjIMU0htjOIikw,14629
|
28
|
+
rxnn-0.1.11.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
|
29
|
+
rxnn-0.1.11.dist-info/RECORD,,
|
File without changes
|
File without changes
|