rxnn 0.2.29__tar.gz → 0.2.31__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rxnn-0.2.29 → rxnn-0.2.31}/PKG-INFO +1 -1
- {rxnn-0.2.29 → rxnn-0.2.31}/pyproject.toml +1 -1
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/base.py +4 -5
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/bml.py +7 -12
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/callbacks.py +12 -0
- rxnn-0.2.31/src/rxnn/training/ddp.py +26 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/mrl.py +374 -272
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/rl.py +49 -9
- {rxnn-0.2.29 → rxnn-0.2.31}/LICENSE +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/README.md +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/.DS_Store +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/__init__.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/experimental/__init__.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/experimental/attention.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/experimental/models.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/experimental/moe.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/memory/__init__.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/memory/attention.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/memory/norm.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/memory/stm.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/rxt/__init__.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/rxt/models.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/__init__.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/dataset.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/models.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/reward.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/scheduler.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/tokenizer.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/training/utils.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/transformers/__init__.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/transformers/attention.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/transformers/ff.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/transformers/layers.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/transformers/mask.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/transformers/models.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/transformers/moe.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/transformers/positional.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/transformers/sampler.py +0 -0
- {rxnn-0.2.29 → rxnn-0.2.31}/src/rxnn/utils.py +0 -0
src/rxnn/training/base.py

@@ -8,6 +8,7 @@ import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
 from typing import Callable
 from .callbacks import TrainerCallback
+from .ddp import get_os_ddp_config, distributed_value_mean


 class BaseTrainer(ABC):
@@ -91,8 +92,7 @@ class BaseTrainer(ABC):
         optimizer = self.optimizer

         if self.use_ddp:
-            rank = int(os.environ['RANK'])
-            world_size = int(os.environ['WORLD_SIZE'])
+            rank, world_size = get_os_ddp_config()
             dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
             self.model = DistributedDataParallel(self.model, device_ids=[self.device.index])
             train_sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
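The DDP branch above now delegates environment parsing to the new get_os_ddp_config() helper. As a rough sketch of how this fits together, assuming a torchrun-style launcher that exports RANK and WORLD_SIZE (the init_ddp function below is illustrative, not part of the package):

# Sketch only: mirrors the setup the trainer performs, assuming RANK and
# WORLD_SIZE are exported by the launcher (e.g. torchrun).
import torch
import torch.distributed as dist

from rxnn.training.ddp import get_os_ddp_config

def init_ddp() -> tuple[int, int]:
    rank, world_size = get_os_ddp_config()  # reads os.environ['RANK'] / ['WORLD_SIZE']
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank % torch.cuda.device_count())
    return rank, world_size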
@@ -218,10 +218,9 @@ class BaseTrainer(ABC):
         if self.validation_dataset:
             self.validation_steps = 0
             val_loss, val_metrics = self.validate(batch_size)
-            val_loss_tensor = torch.tensor(val_loss).to(self.device)
             if self.use_ddp:
-
-
+                val_loss = distributed_value_mean(val_loss, device=self.device)
+
             self.validation_metrics[epoch] = val_metrics

             if self.writer:
src/rxnn/training/bml.py

@@ -7,6 +7,7 @@ import torch.distributed as dist
 from ..transformers.models import ReactiveTransformerDecoder
 from ..training.base import BaseTrainer
 from .models import MLMTrainingModel, JointTrainingModel
+from .ddp import distributed_mean

 class MLMTrainer(BaseTrainer):
     def __init__(
@@ -96,8 +97,7 @@ class MLMTrainer(BaseTrainer):
         acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
         node_acc = acc.item()
         if self.use_ddp:
-
-            acc = acc / dist.get_world_size()
+            acc = distributed_mean(acc)

         metrics = {
             'accuracy': acc.item(),
@@ -198,8 +198,7 @@ class AutoregressiveTrainer(BaseTrainer):
         acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
         node_acc = acc.item()
         if self.use_ddp:
-
-            acc = acc / dist.get_world_size()
+            acc = distributed_mean(acc)

         metrics = {
             'accuracy': acc.item(),
@@ -347,14 +346,10 @@ class JointLMTrainer(BaseTrainer):
         node_mlm_acc = mlm_acc.item()
         node_alm_acc = alm_acc.item()
         if self.use_ddp:
-
-
-
-
-            avg_dec_loss = avg_dec_loss / dist.get_world_size()
-            avg_enc_loss = avg_enc_loss / dist.get_world_size()
-            mlm_acc = mlm_acc / dist.get_world_size()
-            alm_acc = alm_acc / dist.get_world_size()
+            avg_dec_loss = distributed_mean(avg_dec_loss)
+            avg_enc_loss = distributed_mean(avg_enc_loss)
+            mlm_acc = distributed_mean(mlm_acc)
+            alm_acc = distributed_mean(alm_acc)

         metrics = {
             'accuracy': {
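The MLM, autoregressive, and joint trainers all shared the same DDP metric-averaging pattern; each occurrence is now a single distributed_mean call, which (per ddp.py below) sum-reduces a clone of the tensor and divides by the world size. A minimal sketch of the pattern, assuming the process group is initialized (average_accuracy is illustrative, not library code):

# Sketch only: averaging a per-rank metric tensor across all DDP ranks.
import torch
import torch.distributed as dist

from rxnn.training.ddp import distributed_mean

def average_accuracy(correct: torch.Tensor, total: int) -> float:
    acc = (correct / total * 100) if total > 0 else torch.tensor(0.0, device=correct.device)
    if dist.is_initialized():
        acc = distributed_mean(acc)  # all_reduce(SUM) on a clone, then divide by world_size
    return acc.item()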
src/rxnn/training/callbacks.py

@@ -536,6 +536,9 @@ class MrlTrainerCallback:
     def on_reward(self, actor: nn.Module, reward: float, generated: str, reference: str, saved_data: str, eval_mode: bool) -> None:
         pass

+    def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
+        pass
+
     def on_batch_updated(self, actor: nn.Module, epoch: int, step: int, policy_loss: float) -> None:
         pass

@@ -543,6 +546,9 @@ class MrlTrainerCallback:
                          critic_loss: float) -> None:
         pass

+    def on_update_epoch_end(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int, policy_loss: float, critic_loss: float) -> None:
+        pass
+
     def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
         pass

@@ -572,6 +578,9 @@ class MrlPrintCallback(MrlTrainerCallback):
                   reference: dict[str, torch.Tensor], saved_data: dict[str, torch.Tensor], eval_mode: bool) -> None:
         print(f"{'Eval' if eval_mode else 'Train'} | Collected reward {reward}")

+    def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
+        print(f'Epoch {global_epoch} | Starting update epoch {update_epoch}')
+
     def on_batch_updated(self, actor: nn.Module, epoch: int, step: int, policy_loss: float) -> None:
         print(f'Epoch {epoch} | Step {step} - updated policy loss {policy_loss}')

@@ -579,6 +588,9 @@ class MrlPrintCallback(MrlTrainerCallback):
                            critic_loss: float) -> None:
         print(f'Epoch {epoch} | Step {step} - updated critic loss {critic_loss}')

+    def on_update_epoch_end(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int, policy_loss: float, critic_loss: float) -> None:
+        print(f'Epoch {global_epoch} | Update epoch {update_epoch} - mean policy loss {policy_loss} | mean critic loss {critic_loss}')
+
     def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
         print(f'Finished training for {curriculum_config["steps"]} steps in {curriculum_config["strategy"]} strategy.')

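The new on_update_epoch_start and on_update_epoch_end hooks bracket each update epoch within a global training epoch, with the end hook receiving the mean policy and critic losses for that update epoch, as the print callback above shows. A hedged sketch of a custom callback using them; the class name and history list are illustrative, not part of the package:

# Illustrative custom callback: collects the mean policy/critic losses
# reported at the end of each update epoch.
import torch.nn as nn

from rxnn.training.callbacks import MrlTrainerCallback

class UpdateEpochHistoryCallback(MrlTrainerCallback):
    def __init__(self):
        self.history = []

    def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
        # Called when an update epoch begins inside global epoch `global_epoch`
        print(f'Starting update epoch {update_epoch} of epoch {global_epoch}')

    def on_update_epoch_end(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int,
                            policy_loss: float, critic_loss: float) -> None:
        # Store the mean losses reported for this update epoch
        self.history.append((global_epoch, update_epoch, policy_loss, critic_loss))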
src/rxnn/training/ddp.py (new file)

@@ -0,0 +1,26 @@
+import torch
+import torch.distributed as dist
+import os
+from ..utils import set_random_seed
+
+def get_os_ddp_config():
+    rank = int(os.environ['RANK'])
+    world_size = int(os.environ['WORLD_SIZE'])
+    return rank, world_size
+
+def distributed_mean(x: torch.Tensor) -> torch.Tensor:
+    """Average tensor across all devices"""
+    x = x.clone()
+    dist.all_reduce(x, op=dist.ReduceOp.SUM)
+    x /= dist.get_world_size()
+    return x
+
+def distributed_value_mean(value: float, device: torch.device = None) -> float:
+    """Average float value across all devices"""
+    tensor = torch.tensor(value, device=device)
+    reduced = distributed_mean(tensor)
+    return reduced.item()
+
+def set_distributed_random_seed(seed: int):
+    rank = dist.get_rank() if dist.is_initialized() else get_os_ddp_config()[0]
+    set_random_seed(seed + rank)