rxnn 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/base.py CHANGED
@@ -49,6 +49,7 @@ class BaseTrainer(ABC):
49
49
  self.validation_metrics = {}
50
50
  self.target_field_name = target_field_name
51
51
  self.total_tokens = 0
52
+ self.total_steps = 0
52
53
  self.gradient_accumulation_steps = gradient_accumulation_steps
53
54
  self.accumulated_loss = 0.0
54
55
  self.optimizer_step_count = 0
@@ -140,6 +141,7 @@ class BaseTrainer(ABC):
140
141
  for callback in self.callbacks:
141
142
  callback.on_batch_start(self.model, batch_idx, batch)
142
143
  if self.get_batch_size(batch) == batch_size:
144
+ self.total_steps += 1
143
145
  loss = self.train_step(batch, batch_idx)
144
146
  orig_loss = loss.item()
145
147
  self.accumulated_loss += orig_loss
@@ -226,10 +228,11 @@ class BaseTrainer(ABC):
226
228
  self.writer.close()
227
229
 
228
230
  def _valid_writer(self, epoch: int, val_loss: float, val_metrics: dict):
229
- self.writer.add_scalar('Loss/validation', val_loss, epoch)
230
- self.writer.add_scalar('Perplexity/validation', math.exp(val_loss), epoch)
231
+ self.writer.add_scalar('Loss/Valid', val_loss, epoch)
232
+ self.writer.add_scalar('Perplexity/Valid', math.exp(val_loss), epoch)
231
233
  if val_metrics['accuracy']:
232
- self.writer.add_scalar('Accuracy/validation', val_metrics['accuracy'], epoch)
234
+ self.writer.add_scalar('Node Accuracy/Valid', val_metrics['node_accuracy'], epoch)
235
+ self.writer.add_scalar('Avg. Accuracy/Valid', val_metrics['accuracy'], epoch)
233
236
 
234
237
  def valid_step(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
235
238
  if self.use_amp:
rxnn/training/bml.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
3
  import torch.nn.functional as F
4
+ from torch.nn.parallel import DistributedDataParallel
4
5
  import math
5
6
  from huggingface_hub import PyTorchModelHubMixin
6
7
  from typing import Union
@@ -90,12 +91,15 @@ class MLMTrainer(BaseTrainer):
90
91
  total += valid_indices.sum()
91
92
 
92
93
  avg_loss = (val_loss / len(val_dataloader)).item()
94
+ acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
95
+ node_acc = acc.item()
93
96
  if self.use_ddp:
94
- dist.all_reduce(correct, op=dist.ReduceOp.SUM)
95
- dist.all_reduce(total, op=dist.ReduceOp.SUM)
97
+ dist.all_reduce(acc, op=dist.ReduceOp.SUM)
98
+ acc = acc / dist.get_world_size()
96
99
 
97
100
  metrics = {
98
- 'accuracy': (correct / total * 100).item() if total > 0 else 0.0
101
+ 'accuracy': acc.item(),
102
+ 'node_accuracy': node_acc,
99
103
  }
100
104
  self.model.train()
101
105
  return avg_loss, metrics
@@ -154,13 +158,99 @@ class AutoregressiveTrainer(BaseTrainer):
154
158
  total += valid_indices.sum()
155
159
 
156
160
  avg_loss = (val_loss / len(val_dataloader)).item()
161
+ acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
162
+ node_acc = acc.item()
163
+ if self.use_ddp:
164
+ dist.all_reduce(acc, op=dist.ReduceOp.SUM)
165
+ acc = acc / dist.get_world_size()
166
+
167
+ metrics = {
168
+ 'accuracy': acc.item(),
169
+ 'node_accuracy': node_acc,
170
+ }
171
+ self.model.train()
172
+ return avg_loss, metrics
173
+
174
+
175
+ class AutoregressiveMoeTrainer(BaseTrainer):
176
+ def __init__(
177
+ self,
178
+ model: ReactiveTransformerDecoder,
179
+ device: torch.device,
180
+ vocab_size: int,
181
+ use_amp: bool = False,
182
+ dtype: torch.dtype = None,
183
+ router_loss_scale: float = 0.1,
184
+ **kwargs
185
+ ):
186
+ super(AutoregressiveMoeTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype,
187
+ target_field_name='targets', **kwargs)
188
+ self.vocab_size = vocab_size
189
+ self.router_loss_scale = router_loss_scale
190
+
191
+ def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
192
+ inputs = batch['input_ids']
193
+ attention_mask = batch['attention_mask']
194
+ targets = batch['targets']
195
+
196
+ outputs = self.model(
197
+ inputs,
198
+ attention_mask=attention_mask
199
+ )
200
+
201
+ shifted_logits = outputs[:, :-1].contiguous()
202
+ shifted_targets = targets[:, 1:].contiguous()
157
203
 
204
+ main_loss = F.cross_entropy(
205
+ shifted_logits.view(-1, self.vocab_size),
206
+ shifted_targets.view(-1)
207
+ )
208
+
209
+ model = next(self.model.children()) if isinstance(self.model, DistributedDataParallel) else self.model
210
+
211
+ router_loss = model.moe_router_loss()
212
+ loss = main_loss + self.router_loss_scale * router_loss
213
+
214
+ if self.writer is not None:
215
+ if self.model.training:
216
+ self.writer.add_scalar('Router aux loss/Train', router_loss.item(), self.total_steps)
217
+ self.writer.add_scalar('Model loss/Train', main_loss.item(), self.total_steps)
218
+ else:
219
+ self.writer.add_scalar('Router aux loss/Valid', router_loss.item(), self.total_steps)
220
+ self.writer.add_scalar('Model loss/Valid', main_loss.item(), self.total_steps)
221
+
222
+ return loss, outputs
223
+
224
+ def validate(self, batch_size: int) -> tuple[float, dict]:
225
+ self.model.eval()
226
+ val_dataloader = self._valid_loader(batch_size)
227
+ val_loss = torch.tensor(0.0).to(self.device)
228
+ correct = torch.tensor(0).to(self.device)
229
+ total = torch.tensor(0).to(self.device)
230
+
231
+ with torch.no_grad():
232
+ for batch in val_dataloader:
233
+ if self.get_batch_size(batch) == batch_size:
234
+ loss, logits = self.valid_step(batch)
235
+ val_loss += loss
236
+ shifted_logits = logits[:, :-1].contiguous()
237
+ shifted_targets = batch[self.target_field_name][:, 1:].to(self.device).contiguous()
238
+ valid_indices = shifted_targets != -100
239
+ if valid_indices.any():
240
+ preds = shifted_logits.argmax(-1)
241
+ correct += (preds[valid_indices] == shifted_targets[valid_indices]).sum()
242
+ total += valid_indices.sum()
243
+
244
+ avg_loss = (val_loss / len(val_dataloader)).item()
245
+ acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
246
+ node_acc = acc.item()
158
247
  if self.use_ddp:
159
- dist.all_reduce(correct, op=dist.ReduceOp.SUM)
160
- dist.all_reduce(total, op=dist.ReduceOp.SUM)
248
+ dist.all_reduce(acc, op=dist.ReduceOp.SUM)
249
+ acc = acc / dist.get_world_size()
161
250
 
162
251
  metrics = {
163
- 'accuracy': (correct / total * 100).item() if total > 0 else 0.0
252
+ 'accuracy': acc.item(),
253
+ 'node_accuracy': node_acc,
164
254
  }
165
255
  self.model.train()
166
256
  return avg_loss, metrics
@@ -257,16 +347,18 @@ class JointLMTrainer(BaseTrainer):
257
347
  return (encoder_loss, decoder_loss), (encoder_logits, decoder_logits)
258
348
 
259
349
  def _valid_writer(self, epoch: int, val_loss: float, val_metrics: dict):
260
- self.writer.add_scalar('Loss/validation', val_loss, epoch)
261
- self.writer.add_scalar('Perplexity/validation', math.exp(val_loss), epoch)
350
+ self.writer.add_scalar('Loss/Valid', val_loss, epoch)
351
+ self.writer.add_scalar('Perplexity/Valid', math.exp(val_loss), epoch)
262
352
  if val_metrics['accuracy']:
263
- self.writer.add_scalar('Encoder accuracy/validation', val_metrics['accuracy']['encoder'], epoch)
264
- self.writer.add_scalar('Decoder accuracy/validation', val_metrics['accuracy']['decoder'], epoch)
353
+ self.writer.add_scalar('Encoder node accuracy/Valid', val_metrics['accuracy']['node_encoder'], epoch)
354
+ self.writer.add_scalar('Decoder node accuracy/Valid', val_metrics['accuracy']['node_decoder'], epoch)
355
+ self.writer.add_scalar('Encoder avg. accuracy/Valid', val_metrics['accuracy']['encoder'], epoch)
356
+ self.writer.add_scalar('Decoder avg. accuracy/Valid', val_metrics['accuracy']['decoder'], epoch)
265
357
  if val_metrics['loss']:
266
- self.writer.add_scalar('Encoder loss/validation', val_metrics['loss']['encoder'], epoch)
267
- self.writer.add_scalar('Encoder perplexity/validation', math.exp(val_metrics['loss']['encoder']), epoch)
268
- self.writer.add_scalar('Decoder accuracy/validation', val_metrics['loss']['decoder'], epoch)
269
- self.writer.add_scalar('Decoder perplexity/validation', math.exp(val_metrics['loss']['decoder']), epoch)
358
+ self.writer.add_scalar('Encoder loss/Valid', val_metrics['loss']['encoder'], epoch)
359
+ self.writer.add_scalar('Encoder perplexity/Valid', math.exp(val_metrics['loss']['encoder']), epoch)
360
+ self.writer.add_scalar('Decoder accuracy/Valid', val_metrics['loss']['decoder'], epoch)
361
+ self.writer.add_scalar('Decoder perplexity/Valid', math.exp(val_metrics['loss']['decoder']), epoch)
270
362
 
271
363
  def validate(self, batch_size: int) -> tuple[float, dict]:
272
364
  self.model.eval()
@@ -317,28 +409,30 @@ class JointLMTrainer(BaseTrainer):
317
409
  avg_loss = val_loss / loader_len
318
410
  avg_dec_loss = dec_loss / loader_len
319
411
  avg_enc_loss = enc_loss / loader_len
320
-
412
+ mlm_acc = (correct_mlm / total_mlm * 100) if total_mlm > 0 else torch.tensor(0.0).to(self.device)
413
+ alm_acc = (correct_alm / total_alm * 100) if total_alm > 0 else torch.tensor(0.0).to(self.device)
414
+ node_mlm_acc = mlm_acc.item()
415
+ node_alm_acc = alm_acc.item()
321
416
  if self.use_ddp:
322
417
  dist.all_reduce(avg_dec_loss, op=dist.ReduceOp.SUM)
323
418
  dist.all_reduce(avg_enc_loss, op=dist.ReduceOp.SUM)
324
- dist.all_reduce(correct_mlm, op=dist.ReduceOp.SUM)
325
- dist.all_reduce(total_mlm, op=dist.ReduceOp.SUM)
326
- dist.all_reduce(correct_alm, op=dist.ReduceOp.SUM)
327
- dist.all_reduce(total_alm, op=dist.ReduceOp.SUM)
419
+ dist.all_reduce(mlm_acc, op=dist.ReduceOp.SUM)
420
+ dist.all_reduce(alm_acc, op=dist.ReduceOp.SUM)
328
421
  avg_dec_loss = avg_dec_loss / dist.get_world_size()
329
422
  avg_enc_loss = avg_enc_loss / dist.get_world_size()
330
-
331
- mlm_acc = (correct_mlm / total_mlm * 100).item() if total_mlm > 0 else 0.0
332
- alm_acc = (correct_alm / total_alm * 100).item() if total_alm > 0 else 0.0
423
+ mlm_acc = mlm_acc / dist.get_world_size()
424
+ alm_acc = alm_acc / dist.get_world_size()
333
425
 
334
426
  metrics = {
335
427
  'accuracy': {
336
- 'encoder': mlm_acc,
337
- 'decoder': alm_acc,
428
+ 'encoder': mlm_acc.item(),
429
+ 'decoder': alm_acc.item(),
430
+ 'node_encoder': node_mlm_acc,
431
+ 'node_decoder': node_alm_acc,
338
432
  },
339
433
  'loss': {
340
- 'encoder': avg_enc_loss,
341
- 'decoder': avg_dec_loss,
434
+ 'encoder': avg_enc_loss.item(),
435
+ 'decoder': avg_dec_loss.item(),
342
436
  }
343
437
  }
344
438
  self.model.train()
@@ -83,9 +83,12 @@ class PrintAccuracyCallback(TrainerCallback):
83
83
 
84
84
  def on_validation_end(self, model: nn.Module, epoch: int, val_loss: float, val_metrics: dict) -> None:
85
85
  if self.joint_mode:
86
+ print(f"Epoch {epoch} - encoder node accuracy: {val_metrics['accuracy']['node_encoder']:.4f}")
87
+ print(f"Epoch {epoch} - decoder node accuracy: {val_metrics['accuracy']['node_decoder']:.4f}")
86
88
  print(f"Epoch {epoch} - encoder accuracy: {val_metrics['accuracy']['encoder']:.4f}")
87
89
  print(f"Epoch {epoch} - decoder accuracy: {val_metrics['accuracy']['decoder']:.4f}")
88
90
  else:
91
+ print(f"Epoch {epoch} - node accuracy: {val_metrics['node_accuracy']:.4f}")
89
92
  print(f"Epoch {epoch} - accuracy: {val_metrics['accuracy']:.4f}")
90
93
 
91
94
 
@@ -130,6 +133,7 @@ class ModelSaveCallback(TrainerCallback):
130
133
  save_checkpoint_after_n_batches: int = None,
131
134
  push_batch_checkpoint: bool = False,
132
135
  display_exc_trace: bool = False,
136
+ use_ddp: bool = False,
133
137
  ):
134
138
  self.save_dir = save_dir
135
139
  self.save_best_only = save_best_only
@@ -146,10 +150,11 @@ class ModelSaveCallback(TrainerCallback):
146
150
  self.push_batch_checkpoint = push_batch_checkpoint
147
151
  self.finished_epochs = 0
148
152
  self.display_exc_trace = display_exc_trace
153
+ self.rank = int(os.environ['RANK']) if use_ddp else 0
149
154
 
150
155
  def on_batch_end(self, model: torch.nn.Module, batch_idx: int, loss: int, batch: dict[str, torch.Tensor]) -> Union[
151
156
  bool, None]:
152
- if self.save_checkpoint_after_n_batches is not None and batch_idx != 0 and batch_idx % self.save_checkpoint_after_n_batches == 0:
157
+ if self.rank == 0 and self.save_checkpoint_after_n_batches is not None and batch_idx != 0 and batch_idx % self.save_checkpoint_after_n_batches == 0:
153
158
  if isinstance(model, DistributedDataParallel):
154
159
  model = next(model.children())
155
160
  try:
@@ -195,90 +200,92 @@ class ModelSaveCallback(TrainerCallback):
195
200
  val_loss: float,
196
201
  val_metrics: dict
197
202
  ):
198
- self.finished_epochs += 1
199
- if val_loss < self.best_loss:
200
- self.best_loss = val_loss
203
+ if self.rank == 0:
204
+ self.finished_epochs += 1
205
+ if val_loss < self.best_loss:
206
+ self.best_loss = val_loss
207
+ if isinstance(model, DistributedDataParallel):
208
+ model = next(model.children())
209
+ try:
210
+ if model.save_pretrained is not None:
211
+ ckpt_path = os.path.join(
212
+ self.save_dir,
213
+ f'epoch_{epoch}_val_loss_{val_loss:.4f}'
214
+ )
215
+ path_exists = os.path.exists(ckpt_path)
216
+ if not path_exists:
217
+ os.makedirs(ckpt_path)
218
+ model.save_pretrained(save_directory=ckpt_path)
219
+ else:
220
+ path_exists = os.path.exists(self.save_dir)
221
+ if not path_exists:
222
+ os.makedirs(self.save_dir)
223
+ ckpt_path = os.path.join(
224
+ self.save_dir,
225
+ f'epoch_{epoch}_val_loss_{val_loss:.4f}.pt'
226
+ )
227
+ torch.save(model.state_dict(), ckpt_path)
228
+ self.ckpt_paths.append(ckpt_path)
229
+
230
+ # Keep only N best checkpoints
231
+ if len(self.ckpt_paths) > self.max_keep:
232
+ oldest_path = self.ckpt_paths.pop(0)
233
+ if model.save_pretrained is not None:
234
+ shutil.rmtree(oldest_path)
235
+ else:
236
+ os.remove(oldest_path)
237
+ except Exception as e:
238
+ print(f"Error saving epoch checkpoint: {str(e)}")
239
+ if self.display_exc_trace:
240
+ traceback.print_exc()
241
+
242
+ try:
243
+ if self.push_to_hub and self.push_checkpoint_weights and model.push_to_hub is not None and self.hub_model_id:
244
+ model.push_to_hub(
245
+ repo_id=self.hub_model_id,
246
+ commit_message=f'Epoch {epoch} - Val loss {val_loss:.4f}',
247
+ token=self.hf_token,
248
+ private=self.private_repo,
249
+ )
250
+ except Exception as e:
251
+ print(f"Error pushing epoch checkpoint: {str(e)}")
252
+ if self.display_exc_trace:
253
+ traceback.print_exc()
254
+
255
+ def on_training_end(self, model: Union[torch.nn.Module, PyTorchModelHubMixin]):
256
+ if self.rank == 0:
201
257
  if isinstance(model, DistributedDataParallel):
202
258
  model = next(model.children())
203
259
  try:
260
+ # Save final model
204
261
  if model.save_pretrained is not None:
205
262
  ckpt_path = os.path.join(
206
263
  self.save_dir,
207
- f'epoch_{epoch}_val_loss_{val_loss:.4f}'
264
+ 'final_model'
208
265
  )
209
- path_exists = os.path.exists(ckpt_path)
210
- if not path_exists:
211
- os.makedirs(ckpt_path)
212
266
  model.save_pretrained(save_directory=ckpt_path)
213
267
  else:
214
- path_exists = os.path.exists(self.save_dir)
215
- if not path_exists:
216
- os.makedirs(self.save_dir)
217
- ckpt_path = os.path.join(
218
- self.save_dir,
219
- f'epoch_{epoch}_val_loss_{val_loss:.4f}.pt'
220
- )
268
+ ckpt_path = os.path.join(self.save_dir, 'final_model.pt')
221
269
  torch.save(model.state_dict(), ckpt_path)
222
- self.ckpt_paths.append(ckpt_path)
223
-
224
- # Keep only N best checkpoints
225
- if len(self.ckpt_paths) > self.max_keep:
226
- oldest_path = self.ckpt_paths.pop(0)
227
- if model.save_pretrained is not None:
228
- shutil.rmtree(oldest_path)
229
- else:
230
- os.remove(oldest_path)
270
+ print(f"Final model saved to {ckpt_path}")
231
271
  except Exception as e:
232
- print(f"Error saving epoch checkpoint: {str(e)}")
272
+ print(f"Error saving final model: {str(e)}")
233
273
  if self.display_exc_trace:
234
274
  traceback.print_exc()
235
-
236
275
  try:
237
- if self.push_to_hub and self.push_checkpoint_weights and model.push_to_hub is not None and self.hub_model_id:
276
+ if self.push_to_hub and model.push_to_hub is not None:
238
277
  model.push_to_hub(
239
278
  repo_id=self.hub_model_id,
240
- commit_message=f'Epoch {epoch} - Val loss {val_loss:.4f}',
279
+ commit_message=self.final_commit_message or f'Final pre-trained model, after {self.finished_epochs} epochs',
241
280
  token=self.hf_token,
242
281
  private=self.private_repo,
243
282
  )
283
+ print(f"Model uploaded to repo: {self.hub_model_id}")
244
284
  except Exception as e:
245
- print(f"Error pushing epoch checkpoint: {str(e)}")
285
+ print(f"Error pushing final model: {str(e)}")
246
286
  if self.display_exc_trace:
247
287
  traceback.print_exc()
248
288
 
249
- def on_training_end(self, model: Union[torch.nn.Module, PyTorchModelHubMixin]):
250
- if isinstance(model, DistributedDataParallel):
251
- model = next(model.children())
252
- try:
253
- # Save final model
254
- if model.save_pretrained is not None:
255
- ckpt_path = os.path.join(
256
- self.save_dir,
257
- 'final_model'
258
- )
259
- model.save_pretrained(save_directory=ckpt_path)
260
- else:
261
- ckpt_path = os.path.join(self.save_dir, 'final_model.pt')
262
- torch.save(model.state_dict(), ckpt_path)
263
- print(f"Final model saved to {ckpt_path}")
264
- except Exception as e:
265
- print(f"Error saving final model: {str(e)}")
266
- if self.display_exc_trace:
267
- traceback.print_exc()
268
- try:
269
- if self.push_to_hub and model.push_to_hub is not None:
270
- model.push_to_hub(
271
- repo_id=self.hub_model_id,
272
- commit_message=self.final_commit_message or f'Final pre-trained model, after {self.finished_epochs} epochs',
273
- token=self.hf_token,
274
- private=self.private_repo,
275
- )
276
- print(f"Model uploaded to repo: {self.hub_model_id}")
277
- except Exception as e:
278
- print(f"Error pushing final model: {str(e)}")
279
- if self.display_exc_trace:
280
- traceback.print_exc()
281
-
282
289
 
283
290
  class JointModelSaveCallback(TrainerCallback):
284
291
  def __init__(
@@ -298,6 +305,7 @@ class JointModelSaveCallback(TrainerCallback):
298
305
  push_batch_checkpoint: bool = False,
299
306
  mlm_mode: bool = False,
300
307
  display_exc_trace: bool = False,
308
+ use_ddp: bool = False,
301
309
  ):
302
310
  self.save_dir = save_dir
303
311
  self.save_best_only = save_best_only
@@ -317,6 +325,7 @@ class JointModelSaveCallback(TrainerCallback):
317
325
  self.finished_epochs = 0
318
326
  self.mlm_mode = mlm_mode
319
327
  self.display_exc_trace = display_exc_trace
328
+ self.rank = int(os.environ['RANK']) if use_ddp else 0
320
329
 
321
330
  def _save_batch(self, model: Union[nn.Module, PyTorchModelHubMixin], component: str, hub_id: str = None):
322
331
  try:
@@ -362,7 +371,7 @@ class JointModelSaveCallback(TrainerCallback):
362
371
 
363
372
  def on_batch_end(self, model: torch.nn.Module, batch_idx: int, loss: int, batch: dict[str, torch.Tensor]) -> Union[
364
373
  bool, None]:
365
- if self.save_checkpoint_after_n_batches is not None and batch_idx != 0 and batch_idx % self.save_checkpoint_after_n_batches == 0:
374
+ if self.rank == 0 and self.save_checkpoint_after_n_batches is not None and batch_idx != 0 and batch_idx % self.save_checkpoint_after_n_batches == 0:
366
375
  if isinstance(model, DistributedDataParallel):
367
376
  model = next(model.children())
368
377
  self._save_batch(model.encoder, 'encoder', hub_id=self.hub_model_encoder)
@@ -430,15 +439,16 @@ class JointModelSaveCallback(TrainerCallback):
430
439
  val_loss: float,
431
440
  val_metrics: dict
432
441
  ):
433
- self.finished_epochs += 1
434
- if val_loss < self.best_loss:
435
- self.best_loss = val_loss
436
- if isinstance(model, DistributedDataParallel):
437
- model = next(model.children())
438
- self._save_validation(model.encoder, 'encoder', epoch, val_loss, hub_id=self.hub_model_encoder)
439
- if not self.mlm_mode:
440
- self._save_validation(model.decoder, 'decoder', epoch, val_loss, hub_id=self.hub_model_decoder)
441
- self._save_validation(model.mlm_head, 'head', epoch, val_loss, hub_id=self.hub_model_head)
442
+ if self.rank == 0:
443
+ self.finished_epochs += 1
444
+ if val_loss < self.best_loss:
445
+ self.best_loss = val_loss
446
+ if isinstance(model, DistributedDataParallel):
447
+ model = next(model.children())
448
+ self._save_validation(model.encoder, 'encoder', epoch, val_loss, hub_id=self.hub_model_encoder)
449
+ if not self.mlm_mode:
450
+ self._save_validation(model.decoder, 'decoder', epoch, val_loss, hub_id=self.hub_model_decoder)
451
+ self._save_validation(model.mlm_head, 'head', epoch, val_loss, hub_id=self.hub_model_head)
442
452
 
443
453
  def _save_final(self, model: Union[nn.Module, PyTorchModelHubMixin], component: str, hub_id: str = None):
444
454
  try:
@@ -482,9 +492,10 @@ class JointModelSaveCallback(TrainerCallback):
482
492
  traceback.print_exc()
483
493
 
484
494
  def on_training_end(self, model: Union[torch.nn.Module, PyTorchModelHubMixin]):
485
- if isinstance(model, DistributedDataParallel):
486
- model = next(model.children())
487
- self._save_final(model.encoder, 'encoder', hub_id=self.hub_model_encoder)
488
- if not self.mlm_mode:
489
- self._save_final(model.decoder, 'decoder', hub_id=self.hub_model_decoder)
490
- self._save_final(model.mlm_head, 'head', hub_id=self.hub_model_head)
495
+ if self.rank == 0:
496
+ if isinstance(model, DistributedDataParallel):
497
+ model = next(model.children())
498
+ self._save_final(model.encoder, 'encoder', hub_id=self.hub_model_encoder)
499
+ if not self.mlm_mode:
500
+ self._save_final(model.decoder, 'decoder', hub_id=self.hub_model_decoder)
501
+ self._save_final(model.mlm_head, 'head', hub_id=self.hub_model_head)
@@ -53,11 +53,15 @@ class ReactiveTransformerLayer(nn.Module):
53
53
  self.norm2 = nn.LayerNorm(embed_dim)
54
54
  self.norm3 = nn.LayerNorm(embed_dim)
55
55
  self.use_post_norm = use_post_norm
56
+ self.use_moe = use_moe
56
57
 
57
58
  def trainable_cross_attention_(self, is_trainable: bool):
58
59
  for param in self.memory_cross_attention.parameters():
59
60
  param.requires_grad_(is_trainable)
60
61
 
62
+ def moe_router_loss_(self):
63
+ return self.ff.router_loss() if self.use_moe else None
64
+
61
65
  def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
62
66
  # First step, self-attention
63
67
  residual = x
@@ -129,6 +133,10 @@ class ClassicTransformerLayer(nn.Module):
129
133
  self.norm1 = nn.LayerNorm(embed_dim)
130
134
  self.norm2 = nn.LayerNorm(embed_dim)
131
135
  self.use_post_norm = use_post_norm
136
+ self.use_moe = use_moe
137
+
138
+ def moe_router_loss_(self):
139
+ return self.ff.router_loss() if self.use_moe else torch.tensor(0.0)
132
140
 
133
141
  def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
134
142
  # First step, self-attention
@@ -37,6 +37,10 @@ class ReactiveTransformerBase(nn.Module):
37
37
  for i in range(self.num_own_layers):
38
38
  self.layers[i].trainable_cross_attention_(is_trainable)
39
39
 
40
+ def moe_router_loss_(self):
41
+ return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe] + [
42
+ self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe]).mean()
43
+
40
44
  def forward(self, x: torch.Tensor) -> torch.Tensor:
41
45
  # Shared logic for encoders and decoders - apply embeddings and positional encoding
42
46
  x = self.embedding(x)
@@ -119,6 +123,9 @@ class ClassicTransformerBase(nn.Module):
119
123
  self.layers = layers
120
124
  self.num_layers = len(layers) if layers else 0
121
125
 
126
+ def moe_router_loss_(self):
127
+ return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_layers) if self.layers[i].use_moe]).mean()
128
+
122
129
  def forward(self, x: torch.Tensor) -> torch.Tensor:
123
130
  # Shared logic for encoders and decoders - apply embeddings and positional encoding
124
131
  x = self.embedding(x)
rxnn/transformers/moe.py CHANGED
@@ -11,7 +11,8 @@ class MoeRouter(nn.Module):
11
11
  self.top_k = top_k
12
12
  self.num_experts = num_experts
13
13
  self.gate = nn.Linear(embed_dim, num_experts, bias=False)
14
- self.aux_loss = 0.0 # For expert load balancing
14
+ # For expert load balancing
15
+ self.register_buffer('aux_loss', torch.tensor(0.0), persistent=False)
15
16
 
16
17
  def forward(self, x: torch.Tensor):
17
18
  # x shape: [batch_size*seq_len, embed_dim]
@@ -19,10 +20,8 @@ class MoeRouter(nn.Module):
19
20
  probs = F.softmax(logits, dim=-1)
20
21
 
21
22
  # Expert load balancing loss
22
- if self.training:
23
- probs_for_bal = F.softmax(logits, dim=0)
24
- self.aux_loss = (probs_for_bal.mean(dim=0) *
25
- torch.log(probs_for_bal.mean(dim=0) + 1e-9)).sum()
23
+ mean_probs = probs.mean(dim=0) # Mean probability per expert across batch
24
+ self.aux_loss = (mean_probs * torch.log(mean_probs + 1e-9)).sum() # Entropy-based loss
26
25
 
27
26
  top_k_weights, top_k_indices = probs.topk(self.top_k, dim=-1)
28
27
  top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)
@@ -74,6 +73,9 @@ class MoeFeedForward(nn.Module):
74
73
  def _activate(self, h: torch.Tensor):
75
74
  return self.activation(h)
76
75
 
76
+ def router_loss(self):
77
+ return self.router.aux_loss
78
+
77
79
  def forward(self, x: torch.Tensor):
78
80
  orig_shape = x.shape
79
81
  x = x.view(-1, self.embed_dim) # [batch*seq_len, embed_dim]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.1.9
3
+ Version: 0.1.11
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -7,23 +7,23 @@ rxnn/memory/stm.py,sha256=EsD8slSP4_9dLuq6aFPDmuFe8PWilxh90so5Z3nm-ig,2057
7
7
  rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
8
  rxnn/rxt/models.py,sha256=INTFeNcqzAsjyWhNtbBHL4Tx7tYDsaQHgm72tf6u20M,6918
9
9
  rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
- rxnn/training/base.py,sha256=UOFFA1Ai6g8l2iOwdYuWrEJPkioej8DOU2YsYN4K9QI,11071
11
- rxnn/training/bml.py,sha256=pyK6aRLpXlPuLge6CQ9PD64Un57yUgbOpu8lUfTdV9k,14575
12
- rxnn/training/callbacks.py,sha256=IyVJAJ0ggJmfIWBZnpzV9U08URYCeWIStK_wbx7m3pg,21090
10
+ rxnn/training/base.py,sha256=QD8uS14jSyR5Y_8BgCaBQTKpsarerU3lyufsWsCq_6o,11227
11
+ rxnn/training/bml.py,sha256=2kk9q3Buxq4wBHUQhyIAuHoBCninYX2K8hykWAJnxB0,18654
12
+ rxnn/training/callbacks.py,sha256=_YfMKY_eFdc-tubhO9nYH2PXDZDQwlSI74FVOoCXpQg,22108
13
13
  rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
14
14
  rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
15
15
  rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
16
16
  rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
17
  rxnn/transformers/attention.py,sha256=FfEYE0THO73p_1eRupr2mcwfW4UbI_riIxkHfr8X_1c,14022
18
18
  rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
19
- rxnn/transformers/layers.py,sha256=jdM7L0uOMO68aZiu9p6jba1Hx3aLGOChF1Zz-j4vJ5U,5364
19
+ rxnn/transformers/layers.py,sha256=xMocHzdSu7hcC_mPE_aG3-LQg2RXgunKSxcgNXYnOeo,5631
20
20
  rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
21
- rxnn/transformers/models.py,sha256=sLYMkVOWQ1NcM1evpCTUMucXvklySpeNT0IqpIGKmyc,6716
22
- rxnn/transformers/moe.py,sha256=JQ5QSX4FS7S-fqB7-s1ZmJbPpOeD_Injn8o4vo7wGQE,4936
21
+ rxnn/transformers/models.py,sha256=PVhiTTSQ7VTDVdOcyRf-xGNvj6oOa_2fUV2mfthcE0Y,7171
22
+ rxnn/transformers/moe.py,sha256=v21HDEhkDr10--If0P-XBjT5C7IlQJo0wGQlpDnVWEA,5020
23
23
  rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
24
24
  rxnn/transformers/sampler.py,sha256=wSz_1wNloqtuiix5w2Mcsj5NhaO9QlY0j__TVG7wJnM,3938
25
25
  rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
26
- rxnn-0.1.9.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
27
- rxnn-0.1.9.dist-info/METADATA,sha256=AraTWJtxAkj6Zx2UUB2YwfFWSk-WwZ5tgcYhWkLZEEM,14628
28
- rxnn-0.1.9.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
29
- rxnn-0.1.9.dist-info/RECORD,,
26
+ rxnn-0.1.11.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
27
+ rxnn-0.1.11.dist-info/METADATA,sha256=WFoe6AqfJVI6wFZ23i3qGQ3babDlLtjIMU0htjOIikw,14629
28
+ rxnn-0.1.11.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
29
+ rxnn-0.1.11.dist-info/RECORD,,
File without changes
File without changes