rxnn 0.2.30__py3-none-any.whl → 0.2.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/training/base.py +4 -5
- rxnn/training/bml.py +7 -12
- rxnn/training/ddp.py +26 -0
- rxnn/training/mrl.py +161 -104
- rxnn/training/rl.py +21 -8
- {rxnn-0.2.30.dist-info → rxnn-0.2.31.dist-info}/METADATA +1 -1
- {rxnn-0.2.30.dist-info → rxnn-0.2.31.dist-info}/RECORD +9 -8
- {rxnn-0.2.30.dist-info → rxnn-0.2.31.dist-info}/LICENSE +0 -0
- {rxnn-0.2.30.dist-info → rxnn-0.2.31.dist-info}/WHEEL +0 -0
rxnn/training/base.py
CHANGED
@@ -8,6 +8,7 @@ import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
 from typing import Callable
 from .callbacks import TrainerCallback
+from .ddp import get_os_ddp_config, distributed_value_mean


 class BaseTrainer(ABC):
@@ -91,8 +92,7 @@ class BaseTrainer(ABC):
         optimizer = self.optimizer

         if self.use_ddp:
-            rank = int(os.environ['RANK'])
-            world_size = int(os.environ['WORLD_SIZE'])
+            rank, world_size = get_os_ddp_config()
             dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
             self.model = DistributedDataParallel(self.model, device_ids=[self.device.index])
             train_sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
@@ -218,10 +218,9 @@ class BaseTrainer(ABC):
         if self.validation_dataset:
             self.validation_steps = 0
             val_loss, val_metrics = self.validate(batch_size)
-            val_loss_tensor = torch.tensor(val_loss).to(self.device)
             if self.use_ddp:
-
-
+                val_loss = distributed_value_mean(val_loss, device=self.device)
+
             self.validation_metrics[epoch] = val_metrics

             if self.writer:
rxnn/training/bml.py
CHANGED
@@ -7,6 +7,7 @@ import torch.distributed as dist
 from ..transformers.models import ReactiveTransformerDecoder
 from ..training.base import BaseTrainer
 from .models import MLMTrainingModel, JointTrainingModel
+from .ddp import distributed_mean

 class MLMTrainer(BaseTrainer):
     def __init__(
@@ -96,8 +97,7 @@ class MLMTrainer(BaseTrainer):
         acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
         node_acc = acc.item()
         if self.use_ddp:
-
-            acc = acc / dist.get_world_size()
+            acc = distributed_mean(acc)

         metrics = {
             'accuracy': acc.item(),
@@ -198,8 +198,7 @@ class AutoregressiveTrainer(BaseTrainer):
         acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
         node_acc = acc.item()
         if self.use_ddp:
-
-            acc = acc / dist.get_world_size()
+            acc = distributed_mean(acc)

         metrics = {
             'accuracy': acc.item(),
@@ -347,14 +346,10 @@ class JointLMTrainer(BaseTrainer):
         node_mlm_acc = mlm_acc.item()
         node_alm_acc = alm_acc.item()
         if self.use_ddp:
-
-
-
-
-            avg_dec_loss = avg_dec_loss / dist.get_world_size()
-            avg_enc_loss = avg_enc_loss / dist.get_world_size()
-            mlm_acc = mlm_acc / dist.get_world_size()
-            alm_acc = alm_acc / dist.get_world_size()
+            avg_dec_loss = distributed_mean(avg_dec_loss)
+            avg_enc_loss = distributed_mean(avg_enc_loss)
+            mlm_acc = distributed_mean(mlm_acc)
+            alm_acc = distributed_mean(alm_acc)

         metrics = {
             'accuracy': {
rxnn/training/ddp.py
ADDED
@@ -0,0 +1,26 @@
+import torch
+import torch.distributed as dist
+import os
+from ..utils import set_random_seed
+
+def get_os_ddp_config():
+    rank = int(os.environ['RANK'])
+    world_size = int(os.environ['WORLD_SIZE'])
+    return rank, world_size
+
+def distributed_mean(x: torch.Tensor) -> torch.Tensor:
+    """Average tensor across all devices"""
+    x = x.clone()
+    dist.all_reduce(x, op=dist.ReduceOp.SUM)
+    x /= dist.get_world_size()
+    return x
+
+def distributed_value_mean(value: float, device: torch.device = None) -> float:
+    """Average float value across all devices"""
+    tensor = torch.tensor(value, device=device)
+    reduced = distributed_mean(tensor)
+    return reduced.item()
+
+def set_distributed_random_seed(seed: int):
+    rank = dist.get_rank() if dist.is_initialized() else get_os_ddp_config()[0]
+    set_random_seed(seed + rank)
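For context, a minimal usage sketch of the new helpers (not part of the diff). It assumes the process is launched with torchrun or an equivalent launcher that sets RANK and WORLD_SIZE, and the metric values are placeholders:

    import torch
    import torch.distributed as dist
    from rxnn.training.ddp import get_os_ddp_config, distributed_mean, distributed_value_mean

    # RANK / WORLD_SIZE are expected to be set by the launcher (e.g. torchrun)
    rank, world_size = get_os_ddp_config()
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
    device = torch.device('cuda', rank)

    # Placeholder per-rank metrics, just to show the two call styles
    local_acc = torch.tensor(87.5, device=device)   # tensor metric -> distributed_mean
    local_val_loss = 0.42                           # plain float -> distributed_value_mean

    global_acc = distributed_mean(local_acc)
    global_val_loss = distributed_value_mean(local_val_loss, device=device)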
rxnn/training/mrl.py
CHANGED
@@ -3,7 +3,7 @@ from torch.utils.data import DataLoader, DistributedSampler
 from torch.utils.tensorboard import SummaryWriter
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
-from typing import Optional, TypedDict, Union, TypeAlias, Literal
+from typing import Optional, TypedDict, Union, TypeAlias, Literal, Callable
 from enum import Enum
 import random, os
 from ..transformers.sampler import BatchSampler
@@ -13,7 +13,7 @@ from .utils import smart_concat, smart_concat_critic_states, TokenizedDict
 from .rl import RlAlgorithm
 from .reward import MrlRewardMode, MrlRewardModel
 from .models import MrlActorAction, MrlActorModel, MrlCriticModel
-
+from .ddp import get_os_ddp_config, distributed_mean

 class MrlConfig(TypedDict):
     lr: float
@@ -25,6 +25,10 @@ class MrlConfig(TypedDict):
     weight_decay: float
     critic_weight_decay: float
     update_epochs: int
+    pad_token_id: int
+    end_token_id: int
+    callbacks: Optional[list[MrlTrainerCallback]]
+    memory_aware_critic: bool


 class MrlStrategy(Enum):
@@ -33,8 +37,9 @@ class MrlStrategy(Enum):
     LONG_RANGE_STRATEGY = 3


+UnfreezeStrategyFn = Callable[[int], None]
 UnfreezeItem = Union[int, tuple[int, float]]
-UnfreezeEpochsStrategy: TypeAlias = Union[int, tuple[UnfreezeItem, UnfreezeItem, UnfreezeItem, int]]
+UnfreezeEpochsStrategy: TypeAlias = Union[int, tuple[UnfreezeItem, UnfreezeItem, UnfreezeItem, int], UnfreezeStrategyFn]


 class CurriculumConfig(TypedDict):
@@ -89,21 +94,25 @@ class MRLTrainer:
             rl_algorithm: RlAlgorithm,
             sampler_config: Optional[SamplerConfig] = None,
             log_dir: str = None,
-            pad_token_id: int = 0,
-            end_token_id: int = 3,
             use_ddp: bool = False,
             use_amp: bool = False,
             dtype: torch.dtype = torch.float32,
-            callbacks: list[MrlTrainerCallback] = None,
-
     ):
         """
-        Trainer for Memory Reinforcement Learning (MRL)
+        Trainer for Memory Reinforcement Learning (MRL) algorithm for reactive models and Attention-Based Memory System.

         Args:
-            actor: MRL Actor model with encoder, decoder and memory attention.
-            critic: Critic network for advantage estimation.
-
+            actor (MrlActorModel): MRL Actor model with encoder, decoder and memory attention.
+            critic (MrlCriticModel): MRL Critic network for advantage estimation.
+            reward (MrlRewardModel): MRL Reward model or extension.
+            device (torch.device): Device used for training.
+            config (MrlConfig): Configuration dictionary with hyperparameters.
+            rl_algorithm (RlAlgorithm): Reinforcement Learning algorithm (currently only PPO available).
+            sampler_config (SamplerConfig): Sampler configuration.
+            log_dir (str): Log directory for TensorBoard logs.
+            use_ddp (bool): Use Distributed Data Parallel mode.
+            use_amp (bool): Use AMP Autocast for training.
+            dtype (torch.dtype): Data type used in training - in AMP mode it's auto cast, otherwise data and model are transformed to this type
         """
         self.actor = actor
         self.critic = critic
@@ -112,6 +121,7 @@ class MRLTrainer:
         self.device = device
         self.max_seq_len = config.get('max_seq_len', 256)
         self.critic_max_len = config.get('critic_max_len', 512)
+        self.memory_aware_critic = config.get('memory_aware_critic', False)
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
@@ -125,14 +135,15 @@ class MRLTrainer:
         self.critic.to(self.device, dtype=dtype)

         # Batch Sampler for answer generation
-        self.generator =
+        self.generator = None
         self.sampler_config = SamplerConfig(
             temperature=1.0,
             top_k=None,
             top_p=None,
         ) if sampler_config is None else sampler_config

-        self.pad_token_id = pad_token_id
+        self.pad_token_id = config.get('pad_token_id', 0)
+        self.end_token_id = config.get('end_token_id', 3)

         self.use_ddp = use_ddp
         self.use_amp = use_amp
@@ -180,7 +191,7 @@ class MRLTrainer:
         self.eval_dataset = None
         self.random_resets_ratio = 0.0
         self.strategy = None
-        self.shared_callbacks = callbacks
+        self.shared_callbacks = config.get('callbacks', [])
         self.callbacks = []
         self.global_epoch = 0
         self.global_epochs_count = 0
@@ -218,7 +229,7 @@ class MRLTrainer:
     def _init_steps(self):
         return {
             'collect': 0,
-            '
+            'train': 0,
             'eval': 0,
         }

@@ -227,9 +238,12 @@ class MRLTrainer:
         self.epoch_step[step_type] += 1
         self.stage_step[step_type] += 1

-    def reset_stm(self) -> bool:
+    def reset_stm(self, force: bool = False) -> bool:
         """Reset Short-Term Memory state with random reset ratio."""
-        if self.random_resets_ratio == 1.0:
+        if force:
+            self.actor.reset_memory()
+            return True
+        elif self.random_resets_ratio == 1.0:
             self.actor.reset_memory()
             return True
         else:
@@ -439,33 +453,24 @@ class MRLTrainer:

         return trajectories

-    def _critic_loss(self, inputs: TokenizedDict, ref_values: torch.Tensor) -> torch.Tensor:
-        # 1. Calculate values with critic encoder
-        values = self.critic(
-            inputs['input_ids'],
-            attention_mask=inputs['attention_mask'],
-        ).squeeze()
-        # 2. Calculate critic loss
-        loss = self.rl_algorithm.critic_loss(values, ref_values)
-        return loss
-
     def _critic_writer(self, critic_loss: float, epoch: int):
         if self.writer is not None:
-            self.writer.add_scalar('Loss/critic (global)', critic_loss, self.global_step['
+            self.writer.add_scalar('Loss/critic (global)', critic_loss, self.global_step['train'])
             self.writer.add_scalar(f'Loss/critic (steps: {self.curriculum_steps}, epoch: {epoch})', critic_loss,
-                                   self.epoch_step['
+                                   self.epoch_step['train'])
             self.writer.add_scalar(f'Loss/critic (steps: {self.curriculum_steps})', critic_loss,
-                                   self.stage_step['
+                                   self.stage_step['train'])

     def _rl_writer(self, policy_loss: float, epoch: int):
         if self.writer is not None:
-            self.writer.add_scalar('Loss/policy (global)', policy_loss, self.global_step['
+            self.writer.add_scalar('Loss/policy (global)', policy_loss, self.global_step['train'])
             self.writer.add_scalar(f'Loss/policy (steps: {self.curriculum_steps}, epoch: {epoch})', policy_loss,
-                                   self.epoch_step['
-            self.writer.add_scalar(f'Loss/policy (steps: {self.curriculum_steps})', policy_loss,
+                                   self.epoch_step['train'])
+            self.writer.add_scalar(f'Loss/policy (steps: {self.curriculum_steps})', policy_loss,
+                                   self.stage_step['train'])

-    def
-
+    def update_critic(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], ref_values: torch.Tensor,
+                      epoch: int) -> float:
         # 1. Reset critic gradients
         self.critic_optimizer.zero_grad()

@@ -475,7 +480,8 @@ class MRLTrainer:
             # 2.1 Concat states and calculate critic loss
             critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
                                                       pad_token_id=self.pad_token_id)
-
+            values = self.critic(critic_state['input_ids'], attention_mask=critic_state['attention_mask']).squeeze()
+            critic_loss = self.rl_algorithm.critic_loss(values, ref_values)
             # 2.2 Run backpropagation with scaler
             self.critic_scaler.scale(critic_loss).backward()
             # 2.3 Unscale and clip gradients
@@ -488,7 +494,8 @@ class MRLTrainer:
             # 2.1 Concat states and calculate critic loss
             critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
                                                       pad_token_id=self.pad_token_id)
-
+            values = self.critic(critic_state['input_ids'], attention_mask=critic_state['attention_mask']).squeeze()
+            critic_loss = self.rl_algorithm.critic_loss(values, ref_values)
             # 2.2 Run backpropagation
             critic_loss.backward()
             # 2.3 Clip gradients
@@ -503,19 +510,20 @@ class MRLTrainer:

         # 5. Run "on critic updated" callbacks
         for cb in self.callbacks:
-            cb.on_critic_updated(self.actor, self.critic, epoch, self.epoch_step['
+            cb.on_critic_updated(self.actor, self.critic, epoch, self.epoch_step['train'], critic_loss_item)
         # 6. Return loss item
         return critic_loss_item

-    def
-
+    def update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
+                     advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
         # 1. Reset actor gradients
         self.optimizer.zero_grad()
         # 2. Unpack state dicts
         query, answer, next_query = state

-        # 3. Encode and update STM on each step, to include encoder and memory attention gradients in loss
-        self.encode_and_update_stm(query, answer)
+        # 3. Encode and update STM on each step, to include encoder and memory attention gradients in loss (skip if it was updated before with memory aware critic)
+        if not self.memory_aware_critic:
+            self.encode_and_update_stm(query, answer)

         # 4. Update actor - with autocast on/off
         if self.use_amp:
@@ -560,13 +568,13 @@ class MRLTrainer:

         # 7. Run "on batch updated" callback
         for cb in self.callbacks:
-            cb.on_batch_updated(self.actor, epoch, self.epoch_step['
+            cb.on_batch_updated(self.actor, epoch, self.epoch_step['train'], policy_loss_item)

         # 8. Return loss item
         return policy_loss_item

     def rl_step(self, trajectories: list[MrlTrajectoryEpisode], advantages: torch.Tensor, ref_values: torch.Tensor,
-                epoch: int, batch_size: int) -> tuple[
+                epoch: int, batch_size: int) -> tuple[torch.Tensor, torch.Tensor]:
         """Perform PPO update step using trajectories."""
         # 1. Run update separately for episodes in trajectory - we have to reset memory before each episode, and update
         # memory, based on collected episode data
@@ -577,18 +585,19 @@ class MRLTrainer:
             should_reset_stm = episode['reset_stm']

             # 2. Get advantages and reference values for current full episode (batch_size * episode_steps)
-
-
+            num_steps = len(episode_steps)
+            start = episode_idx * num_steps
+            end = start + num_steps
             episode_critic_values = ref_values[start:end]
             episode_advantages = advantages[start:end]

             # 3. Reset memory for current batch episode
             if should_reset_stm:
-                self.reset_stm()
+                self.reset_stm(force=True)

             # 4. Run episode steps - each episode has number of steps depending on curriculum stage. Each step is run for all batch
             for step_idx, step in enumerate(episode_steps):
-                self._increment_steps('
+                self._increment_steps('train')
                 # 5. Get and move to device collected states, action and log probs
                 state, action, _, log_probs = step['state'], step['action'], step['reward'], step['log_probs']
                 query, answer, next_query = self._move_multiple_batches(*state)
@@ -599,41 +608,74 @@ class MRLTrainer:
                 step_critic_values = episode_critic_values[step_idx]
                 step_advantages = episode_advantages[step_idx]

+                # 7. In memory aware critic version, encode and update STM before critic update, to include its gradients in critic loss too
+                if self.memory_aware_critic:
+                    self.encode_and_update_stm(query, answer)
+
                 # 7. Update critic
-                critic_loss_item = self.
+                critic_loss_item = self.update_critic((query, answer, next_query), step_critic_values, epoch)

                 # 8. Accumulate critic loss for epoch callbacks
                 critic_losses.append(critic_loss_item)

                 # 9. Update actor
-                policy_loss_item = self.
-
+                policy_loss_item = self.update_actor((query, answer, next_query), action, step_advantages, log_probs,
+                                                     epoch)
                 all_losses.append(policy_loss_item)
         # 10. Return mean losses for epoch callbacks
-        return torch.mean(torch.tensor(all_losses))
+        return torch.mean(torch.tensor(all_losses)), torch.mean(torch.tensor(critic_losses))

     def _critic_values_rewards_and_dones(self, trajectories: list[MrlTrajectoryEpisode]):
-
-
-
-
-
-
-
-        )
-        )
-
-
-
-
+        if self.memory_aware_critic:
+            flat_trajectories = [
+                (t, i == 0 and episode['reset_stm'])
+                for episode in trajectories
+                for i, t in enumerate(episode['steps'])
+            ]
+            values = torch.stack([
+                self._critic_values_with_memory(r, *self._move_multiple_batches(*t['state'])) for t, r in flat_trajectories
+            ]).to(self.device)
+            rewards = torch.stack([torch.tensor(t['reward']) for t, _ in flat_trajectories]).to(self.device)
+            dones = torch.stack([torch.tensor(t['done']) for t, _ in flat_trajectories]).to(self.device)
+        else:
+            flat_trajectories = [t for episode in trajectories for t in episode['steps']]
+            values = torch.stack([
+                self._critic_values(*self._move_multiple_batches(*t['state'])) for t in flat_trajectories
+            ]).to(self.device)
+            rewards = torch.stack([torch.tensor(t['reward']) for t in flat_trajectories]).to(self.device)
+            dones = torch.stack([torch.tensor(t['done']) for t in flat_trajectories]).to(self.device)
         return values, rewards, dones

-    def
+    def _critic_values_with_memory(self, reset_stm: bool, *moved_state: tuple[TokenizedDict, TokenizedDict, TokenizedDict]) -> torch.Tensor:
+        # 1. Calculate critic values in memory aware version - reset/update STM before calculating values
         with torch.no_grad():
+            # 2. Reset STM if it was reset in trajectory collection
+            if reset_stm:
+                self.reset_stm(force=True)
+            # 3. Encode and update STM for critic
+            self.encode_and_update_stm(*moved_state)
+            # 4. Get concatenated critic states
+            inputs = smart_concat_critic_states(
+                *moved_state,
+                max_length=self.critic_max_len,
+                pad_token_id=self.pad_token_id,
+            )
+            # 5. Calculate values for current batch
             return self.critic(inputs['input_ids'],
-
+                               attention_mask=inputs['attention_mask']).squeeze()

-
+    def _critic_values(self, *moved_state: tuple[TokenizedDict, TokenizedDict, TokenizedDict]) -> torch.Tensor:
+        # 1. Calculate critic values
+        with torch.no_grad():
+            # 2. Get concatenated critic states
+            inputs = smart_concat_critic_states(
+                *moved_state,
+                max_length=self.critic_max_len,
+                pad_token_id=self.pad_token_id,
+            )
+            # 3. Calculate values for current batch
+            return self.critic(inputs['input_ids'],
+                               attention_mask=inputs['attention_mask']).squeeze()

     def train_epoch(self, dataloader: DataLoader, epoch: int, batch_size: int):
         """Train for one epoch."""
@@ -655,16 +697,23 @@ class MRLTrainer:
             # 4. Run 'on update epoch start' callbacks
             for cb in self.callbacks:
                 cb.on_update_epoch_start(self.actor, self.critic, epoch, update_epoch)
+
             # 5. Run RL algorithm step
             policy_loss, critic_loss = self.rl_step(trajectories[:-1], advantages, ref_values, epoch, batch_size)

+            if self.use_ddp:
+                policy_loss = distributed_mean(policy_loss)
+                critic_loss = distributed_mean(critic_loss)
+
+            # 6. Run 'on update epoch end' callbacks
             for cb in self.callbacks:
                 cb.on_update_epoch_end(self.actor, self.critic, epoch, update_epoch, policy_loss, critic_loss)

+            # 7. Accumulate losses for epoch callbacks
             critic_loss_sum += critic_loss
             policy_loss_sum += policy_loss

-        #
+        # 8. Return policy and critic mean losses for epoch callbacks
         return policy_loss_sum / self.update_epochs, critic_loss_sum / self.update_epochs

     def _eval_loader(self, batch_size: int):
@@ -770,12 +819,11 @@ class MRLTrainer:
             cb.on_eval_episode_end(self.actor, epoch, self.epoch_step['eval'], avg_episode_reward)

         # 15. Calculate average reward
+        avg_reward = (total_reward / count) if count > 0 else torch.tensor(0.0).to(self.device)
         if self.use_ddp:
-
-
-
-        else:
-            avg_reward = (total_reward / count).item() if count > 0 else 0
+            avg_reward = distributed_mean(avg_reward)
+
+        avg_reward = avg_reward.item()

         should_stop_stage = False
         # 16. Run "on eval end" callbacks
@@ -924,31 +972,36 @@ class MRLTrainer:

         # 1. Init DDP for distributed training mode
         if self.use_ddp:
-            rank = int(os.environ['RANK'])
-            world_size = int(os.environ['WORLD_SIZE'])
+            rank, world_size = get_os_ddp_config()
             dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
             self.actor = DistributedDataParallel(self.actor, device_ids=[self.device.index])
             self.critic = DistributedDataParallel(self.critic, device_ids=[self.device.index])

-        # 2.
+        # 2. Init BatchSampler with actor model (we have to run it after DDP init)
+        self.generator = BatchSampler(self.actor, self.device, end_token_id=self.end_token_id)
+
+        # 3. Run each curriculum step based on config
         for current_curriculum_step in curriculum_config:
-            #
+            # 4. Setup training config for curriculum step
            epochs_config, random_resets_config = self._setup_curriculum_step(current_curriculum_step)
            epochs, unfreeze_epoch = epochs_config
            random_resets, random_resets_from, random_resets_ratio = random_resets_config
            assert self.train_dataset is not None

-            #
+            # 5. Freeze all components except memory attention and memory cross-attention layers in decoder/encoder
             if unfreeze_epoch != 0:
-
-
-                print(
-                    f"Starting training with unfreeze strategies - 'warmup' - mem-att/cross-att trainable / rest model frozen")
+                if callable(unfreeze_epoch):
+                    unfreeze_epoch(-1)
                 else:
-
-
-
-
+                    self.actor.freeze_components('joint')
+                    if isinstance(unfreeze_epoch, tuple):
+                        print(
+                            f"Starting training with unfreeze strategies - 'warmup' - mem-att/cross-att trainable / rest model frozen")
+                    else:
+                        print(
+                            f"Starting training with simple unfreeze - 'joint' - mem-att/cross-att trainable / rest model frozen")
+
+            # 6. Setup train DataLoader
             if self.use_ddp:
                 train_sampler = DistributedSampler(self.train_dataset, shuffle=True)
                 dataloader = DataLoader(
@@ -957,6 +1010,7 @@ class MRLTrainer:
                     sampler=train_sampler,
                     pin_memory=True,
                     collate_fn=MrlCurriculumDataset.collate_mrl_batch,
+                    drop_last=True,
                 )
             else:
                 train_sampler = None
@@ -968,65 +1022,68 @@ class MRLTrainer:
                 collate_fn=MrlCurriculumDataset.collate_mrl_batch,
             )

-            #
+            # 7. Run selected number of epochs for given curriculum stage
             for epoch in range(epochs):
-                #
+                # 8. Increment global epoch
                 self.global_epoch += 1
-                #
+                # 9. Run "on epoch start" callbacks (log info, etc.)
                 for cb in self.callbacks:
                     cb.on_epoch_start(self.actor, epoch, epochs, current_curriculum_step, self.global_epoch,
                                       self.global_epochs_count)

-                #
+                # 10. Reset steps counter for epoch
                 self.epoch_step = self._init_steps()

-                #
+                # 11. Set random STM resets ratio from selected epoch
                 if random_resets and random_resets_from <= epoch:
                     self.random_resets_ratio = random_resets_ratio
                 else:
                     self.random_resets_ratio = 1.0

-                #
-
+                # 12. Apply the unfreeze strategy
+                if callable(unfreeze_epoch):
+                    unfreeze_epoch(epoch)
+                else:
+                    self._apply_unfreeze_strategy(epoch, unfreeze_epoch)

-                #
+                # 13. Set epoch for distributed sampler
                 if train_sampler is not None:
                     train_sampler.set_epoch(epoch)

-                #
+                # 14. Run reinforcement learning algorithms for current epoch
                 policy_loss, critic_loss = self.train_epoch(dataloader, epoch, batch_size)

-                #
+                # 15. If evaluation dataset is provided, run evaluation steps
                 if self.eval_dataset:
                     should_stop_stage = self.evaluate(batch_size, epoch)
                 else:
                     should_stop_stage = False

-                #
+                # 16. Finally, run "on epoch end" callbacks (save models, etc.)
                 for cb in self.callbacks:
                     cb.on_epoch_end(self.actor, epoch, epochs, policy_loss, critic_loss, self.global_epoch,
                                     self.global_epochs_count)

-                #
+                # 17. Synchronize TensorBoard writer
                 if self.writer:
                     self.writer.flush()

-                #
+                # 18. Synchronize devices in DDP mode
                 if self.use_ddp:
                     dist.barrier()

-                #
+                # 19. Finish curriculum stage if rewards are not increased or reached threshold point
                 if should_stop_stage:
                     break

-            #
+            # 20. Run "on_training_end" callbacks after each curriculum stage (they have own callbacks)
             for cb in self.callbacks:
                 cb.on_training_end(self.actor, self.critic, current_curriculum_step)

-        #
+        # 21. Training end - finish processes after all curriculum stages
         if self.use_ddp:
             dist.destroy_process_group()

-        #
+        # 22. Close writer
         if self.writer:
             self.writer.close()
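One practical consequence of the mrl.py changes: `unfreeze_epoch` in a curriculum stage may now be a plain callable matching `UnfreezeStrategyFn = Callable[[int], None]`, which the trainer invokes with -1 before the first epoch and with the epoch index on every epoch; `memory_aware_critic`, `pad_token_id`, `end_token_id` and `callbacks` now live in `MrlConfig`. A rough sketch of such a custom strategy follows; the epoch threshold and the `unfreeze_all` helper are illustrative assumptions, only `freeze_components('joint')` appears in the diff:

    # Hypothetical custom unfreeze strategy matching UnfreezeStrategyFn = Callable[[int], None]
    def make_unfreeze_strategy(actor, unfreeze_from_epoch: int = 2):
        def strategy(epoch: int) -> None:
            if epoch == -1:
                # Called once before the stage starts - keep only memory attention / cross-attention trainable
                actor.freeze_components('joint')
            elif epoch == unfreeze_from_epoch:
                # Assumed helper - the real actor API for unfreezing may differ
                actor.unfreeze_all()
        return strategy

    # curriculum_step = { ..., 'unfreeze_epoch': make_unfreeze_strategy(actor) }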
rxnn/training/rl.py
CHANGED
@@ -2,8 +2,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from abc import ABC, abstractmethod
-from typing import TypedDict
+from typing import TypedDict, Optional
 from .utils import TokenizedDict
+from .ddp import distributed_mean


 class RlAlgorithm(ABC):
@@ -23,21 +24,28 @@ class RlAlgorithm(ABC):
     def critic_loss(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
         return self.critic_loss(rewards, values)

+
 class PPOConfig(TypedDict):
-    clip_eps: float
-    gae_lambda: float
-    gae_gamma: float
-    entropy_coef: float
+    clip_eps: Optional[float]
+    gae_lambda: Optional[float]
+    gae_gamma: Optional[float]
+    entropy_coef: Optional[float]
+    use_distributed_advantage_norm: Optional[bool]
+

 class PPOAlgorithm(RlAlgorithm):
-    def __init__(self, config: PPOConfig):
+    def __init__(self, config: Optional[PPOConfig] = None):
         super(PPOAlgorithm, self).__init__()

+        if config is None:
+            config = {}
+
         # PPO Config
         self.clip_eps = config.get('clip_eps', 0.2)
         self.gae_lambda = config.get('gae_lambda', 0.95)
         self.gae_gamma = config.get('gae_gamma', 0.99)
         self.entropy_coef = config.get('entropy_coef', 0.01)
+        self.use_distributed_advantage_norm = config.get('use_distributed_advantage_norm', False)

     def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
                     old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
@@ -101,7 +109,7 @@ class PPOAlgorithm(RlAlgorithm):
             next_values = values[t + 1]

             # Mask next values if episode ended
-            next_values = next_values *
+            next_values = next_values * ~dones[t]
             delta = rewards[t] + self.gae_gamma * next_values - values[t]
             advantages[t] = delta + self.gae_gamma * self.gae_lambda * last_advantage
             last_advantage = advantages[t]
@@ -111,5 +119,10 @@ class PPOAlgorithm(RlAlgorithm):

     def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         advantages, ref_values = self._compute_gae(rewards[:-1], values[:-1], values[-1], dones[:-1])
-
+        if self.use_distributed_advantage_norm:
+            mean_advantage = distributed_mean(advantages.mean())
+            std_advantage = distributed_mean(advantages.std())
+            normalized_advantages = (advantages - mean_advantage) / (std_advantage + 1e-8)
+        else:
+            normalized_advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
         return normalized_advantages, ref_values
{rxnn-0.2.30.dist-info → rxnn-0.2.31.dist-info}/RECORD
CHANGED
@@ -11,14 +11,15 @@ rxnn/memory/stm.py,sha256=IH_3INw7FdI013t56ui3Zq9GPUq-k3HeZGjx6BerS4g,3888
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/rxt/models.py,sha256=r8wZeeNTC2VAhiiNe4y7LrbnB4wjFu_cupKiGkpdgjI,13002
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/training/base.py,sha256=
-rxnn/training/bml.py,sha256=
+rxnn/training/base.py,sha256=TGz_37RfI1qLI31GNRV5rLowW1kAHnJwqPm7DNfLfe4,11730
+rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
 rxnn/training/callbacks.py,sha256=p72lbzFAmFjpcUvyy4aUB3qd53I8C6Sk5w9nQvsKgTk,35852
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
+rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
 rxnn/training/models.py,sha256=2KhNT7yx0AgUke4nmsFqzQKx_YYp78QvsLWYZjWeUgQ,6812
-rxnn/training/mrl.py,sha256=
+rxnn/training/mrl.py,sha256=Aimiiqf_4p6dp5Ty9pY9VwetySBS_OFpCQlcVHVkO4Q,55124
 rxnn/training/reward.py,sha256=7MTVdNm5HnWmt6zFDi3TAYmnVSL_-24riOoY2F7z4x8,11290
-rxnn/training/rl.py,sha256=
+rxnn/training/rl.py,sha256=eL3C0yryiNBgl_xb-D-5dyYUtK4V4-K4t3a60x5ir28,5142
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
@@ -32,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.31.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.31.dist-info/METADATA,sha256=zxD2qPHL_QrFH1bYZrMv4odbXE4B_YIVEpGDzV2MYEI,25960
+rxnn-0.2.31.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.31.dist-info/RECORD,,
{rxnn-0.2.30.dist-info → rxnn-0.2.31.dist-info}/LICENSE
File without changes
{rxnn-0.2.30.dist-info → rxnn-0.2.31.dist-info}/WHEEL
File without changes