rxnn 0.2.29.tar.gz → 0.2.31.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.29 → rxnn-0.2.31}/PKG-INFO +1 -1
  2. {rxnn-0.2.29 → rxnn-0.2.31}/pyproject.toml +1 -1
  3. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/base.py +4 -5
  4. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/bml.py +7 -12
  5. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/callbacks.py +12 -0
  6. rxnn-0.2.31/src/rxnn/training/ddp.py +26 -0
  7. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/mrl.py +374 -272
  8. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/rl.py +49 -9
  9. {rxnn-0.2.29 → rxnn-0.2.31}/LICENSE +0 -0
  10. {rxnn-0.2.29 → rxnn-0.2.31}/README.md +0 -0
  11. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/.DS_Store +0 -0
  12. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/__init__.py +0 -0
  13. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/experimental/__init__.py +0 -0
  14. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/experimental/attention.py +0 -0
  15. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/experimental/models.py +0 -0
  16. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/experimental/moe.py +0 -0
  17. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/memory/__init__.py +0 -0
  18. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/memory/attention.py +0 -0
  19. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/memory/norm.py +0 -0
  20. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/memory/stm.py +0 -0
  21. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/rxt/__init__.py +0 -0
  22. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/rxt/models.py +0 -0
  23. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/__init__.py +0 -0
  24. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/dataset.py +0 -0
  25. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/models.py +0 -0
  26. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/reward.py +0 -0
  27. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/scheduler.py +0 -0
  28. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/tokenizer.py +0 -0
  29. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/utils.py +0 -0
  30. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/transformers/__init__.py +0 -0
  31. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/transformers/attention.py +0 -0
  32. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/transformers/ff.py +0 -0
  33. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/transformers/layers.py +0 -0
  34. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/utils.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.29
+Version: 0.2.31
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning

pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

 [tool.poetry]
 name = "rxnn"
-version = "0.2.29"
+version = "0.2.31"
 description = "RxNN: Reactive Neural Networks Platform"

 license = "Apache-2.0"

src/rxnn/training/base.py
@@ -8,6 +8,7 @@ import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
 from typing import Callable
 from .callbacks import TrainerCallback
+from .ddp import get_os_ddp_config, distributed_value_mean


 class BaseTrainer(ABC):
@@ -91,8 +92,7 @@ class BaseTrainer(ABC):
         optimizer = self.optimizer

         if self.use_ddp:
-            rank = int(os.environ['RANK'])
-            world_size = int(os.environ['WORLD_SIZE'])
+            rank, world_size = get_os_ddp_config()
             dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
             self.model = DistributedDataParallel(self.model, device_ids=[self.device.index])
             train_sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
@@ -218,10 +218,9 @@ class BaseTrainer(ABC):
             if self.validation_dataset:
                 self.validation_steps = 0
                 val_loss, val_metrics = self.validate(batch_size)
-                val_loss_tensor = torch.tensor(val_loss).to(self.device)
                 if self.use_ddp:
-                    dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.SUM)
-                    val_loss = val_loss_tensor.item() / dist.get_world_size()
+                    val_loss = distributed_value_mean(val_loss, device=self.device)
+
                 self.validation_metrics[epoch] = val_metrics

                 if self.writer:

src/rxnn/training/bml.py
@@ -7,6 +7,7 @@ import torch.distributed as dist
 from ..transformers.models import ReactiveTransformerDecoder
 from ..training.base import BaseTrainer
 from .models import MLMTrainingModel, JointTrainingModel
+from .ddp import distributed_mean

 class MLMTrainer(BaseTrainer):
     def __init__(
@@ -96,8 +97,7 @@ class MLMTrainer(BaseTrainer):
         acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
         node_acc = acc.item()
         if self.use_ddp:
-            dist.all_reduce(acc, op=dist.ReduceOp.SUM)
-            acc = acc / dist.get_world_size()
+            acc = distributed_mean(acc)

         metrics = {
             'accuracy': acc.item(),
@@ -198,8 +198,7 @@ class AutoregressiveTrainer(BaseTrainer):
         acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
         node_acc = acc.item()
         if self.use_ddp:
-            dist.all_reduce(acc, op=dist.ReduceOp.SUM)
-            acc = acc / dist.get_world_size()
+            acc = distributed_mean(acc)

         metrics = {
             'accuracy': acc.item(),
@@ -347,14 +346,10 @@ class JointLMTrainer(BaseTrainer):
         node_mlm_acc = mlm_acc.item()
         node_alm_acc = alm_acc.item()
         if self.use_ddp:
-            dist.all_reduce(avg_dec_loss, op=dist.ReduceOp.SUM)
-            dist.all_reduce(avg_enc_loss, op=dist.ReduceOp.SUM)
-            dist.all_reduce(mlm_acc, op=dist.ReduceOp.SUM)
-            dist.all_reduce(alm_acc, op=dist.ReduceOp.SUM)
-            avg_dec_loss = avg_dec_loss / dist.get_world_size()
-            avg_enc_loss = avg_enc_loss / dist.get_world_size()
-            mlm_acc = mlm_acc / dist.get_world_size()
-            alm_acc = alm_acc / dist.get_world_size()
+            avg_dec_loss = distributed_mean(avg_dec_loss)
+            avg_enc_loss = distributed_mean(avg_enc_loss)
+            mlm_acc = distributed_mean(mlm_acc)
+            alm_acc = distributed_mean(alm_acc)

         metrics = {
             'accuracy': {

src/rxnn/training/callbacks.py
@@ -536,6 +536,9 @@ class MrlTrainerCallback:
     def on_reward(self, actor: nn.Module, reward: float, generated: str, reference: str, saved_data: str, eval_mode: bool) -> None:
         pass

+    def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
+        pass
+
     def on_batch_updated(self, actor: nn.Module, epoch: int, step: int, policy_loss: float) -> None:
         pass

@@ -543,6 +546,9 @@ class MrlTrainerCallback:
                           critic_loss: float) -> None:
         pass

+    def on_update_epoch_end(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int, policy_loss: float, critic_loss: float) -> None:
+        pass
+
     def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
         pass

@@ -572,6 +578,9 @@ class MrlPrintCallback(MrlTrainerCallback):
                   reference: dict[str, torch.Tensor], saved_data: dict[str, torch.Tensor], eval_mode: bool) -> None:
         print(f"{'Eval' if eval_mode else 'Train'} | Collected reward {reward}")

+    def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
+        print(f'Epoch {global_epoch} | Starting update epoch {update_epoch}')
+
     def on_batch_updated(self, actor: nn.Module, epoch: int, step: int, policy_loss: float) -> None:
         print(f'Epoch {epoch} | Step {step} - updated policy loss {policy_loss}')

@@ -579,6 +588,9 @@ class MrlPrintCallback(MrlTrainerCallback):
                   critic_loss: float) -> None:
         print(f'Epoch {epoch} | Step {step} - updated critic loss {critic_loss}')

+    def on_update_epoch_end(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int, policy_loss: float, critic_loss: float) -> None:
+        print(f'Epoch {global_epoch} | Update epoch {update_epoch} - mean policy loss {policy_loss} | mean critic loss {critic_loss}')
+
     def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
         print(f'Finished training for {curriculum_config["steps"]} steps in {curriculum_config["strategy"]} strategy.')

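The callback hunks above add two hooks, on_update_epoch_start and on_update_epoch_end, to the MRL callback interface and to the built-in print callback. Below is a minimal sketch of a custom callback that uses them; the class name UpdateEpochLossLogger is hypothetical, and the assumption that the trainer fires these hooks once per inner RL update epoch is inferred from the hook names and print messages (the actual call sites are in mrl.py, whose diff is not shown here).

import torch.nn as nn
from rxnn.training.callbacks import MrlTrainerCallback

class UpdateEpochLossLogger(MrlTrainerCallback):
    """Hypothetical callback: records the mean losses reported for each inner update epoch."""

    def __init__(self):
        self.history = []  # list of (global_epoch, update_epoch, policy_loss, critic_loss)

    def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
        # Presumed to be called by the MRL trainer before an inner update epoch starts.
        print(f'Global epoch {global_epoch} | update epoch {update_epoch} starting')

    def on_update_epoch_end(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int,
                            policy_loss: float, critic_loss: float) -> None:
        # Keep the reported mean losses so they can be inspected or plotted after training.
        self.history.append((global_epoch, update_epoch, policy_loss, critic_loss))
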
src/rxnn/training/ddp.py (new file)
@@ -0,0 +1,26 @@
+import torch
+import torch.distributed as dist
+import os
+from ..utils import set_random_seed
+
+def get_os_ddp_config():
+    rank = int(os.environ['RANK'])
+    world_size = int(os.environ['WORLD_SIZE'])
+    return rank, world_size
+
+def distributed_mean(x: torch.Tensor) -> torch.Tensor:
+    """Average tensor across all devices"""
+    x = x.clone()
+    dist.all_reduce(x, op=dist.ReduceOp.SUM)
+    x /= dist.get_world_size()
+    return x
+
+def distributed_value_mean(value: float, device: torch.device = None) -> float:
+    """Average float value across all devices"""
+    tensor = torch.tensor(value, device=device)
+    reduced = distributed_mean(tensor)
+    return reduced.item()
+
+def set_distributed_random_seed(seed: int):
+    rank = dist.get_rank() if dist.is_initialized() else get_os_ddp_config()[0]
+    set_random_seed(seed + rank)
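
The new ddp.py module centralizes the DistributedDataParallel boilerplate that was previously inlined in base.py and bml.py. A short usage sketch follows; the function name reduce_metrics_example is hypothetical, and it assumes the process was launched with torchrun (so RANK and WORLD_SIZE are set) and that an NCCL backend is available, mirroring the init_process_group call in BaseTrainer above.

import torch
import torch.distributed as dist
from rxnn.training.ddp import (
    get_os_ddp_config,
    distributed_mean,
    distributed_value_mean,
    set_distributed_random_seed,
)

def reduce_metrics_example(accuracy: torch.Tensor, val_loss: float, device: torch.device):
    # Read rank and world size from the environment variables set by torchrun.
    rank, world_size = get_os_ddp_config()
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)

    # Seed each rank differently but reproducibly (seed + rank).
    set_distributed_random_seed(42)

    # Average a tensor metric and a plain float across all ranks, as the trainers
    # now do through these helpers instead of manual all_reduce calls.
    mean_accuracy = distributed_mean(accuracy.to(device))
    mean_val_loss = distributed_value_mean(val_loss, device=device)
    return mean_accuracy, mean_val_loss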