rxnn 0.2.29__py3-none-any.whl → 0.2.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/base.py CHANGED
@@ -8,6 +8,7 @@ import torch.distributed as dist
8
8
  from torch.nn.parallel import DistributedDataParallel
9
9
  from typing import Callable
10
10
  from .callbacks import TrainerCallback
11
+ from .ddp import get_os_ddp_config, distributed_value_mean
11
12
 
12
13
 
13
14
  class BaseTrainer(ABC):
@@ -91,8 +92,7 @@ class BaseTrainer(ABC):
91
92
  optimizer = self.optimizer
92
93
 
93
94
  if self.use_ddp:
94
- rank = int(os.environ['RANK'])
95
- world_size = int(os.environ['WORLD_SIZE'])
95
+ rank, world_size = get_os_ddp_config()
96
96
  dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
97
97
  self.model = DistributedDataParallel(self.model, device_ids=[self.device.index])
98
98
  train_sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
@@ -218,10 +218,9 @@ class BaseTrainer(ABC):
218
218
  if self.validation_dataset:
219
219
  self.validation_steps = 0
220
220
  val_loss, val_metrics = self.validate(batch_size)
221
- val_loss_tensor = torch.tensor(val_loss).to(self.device)
222
221
  if self.use_ddp:
223
- dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.SUM)
224
- val_loss = val_loss_tensor.item() / dist.get_world_size()
222
+ val_loss = distributed_value_mean(val_loss, device=self.device)
223
+
225
224
  self.validation_metrics[epoch] = val_metrics
226
225
 
227
226
  if self.writer:
rxnn/training/bml.py CHANGED
@@ -7,6 +7,7 @@ import torch.distributed as dist
7
7
  from ..transformers.models import ReactiveTransformerDecoder
8
8
  from ..training.base import BaseTrainer
9
9
  from .models import MLMTrainingModel, JointTrainingModel
10
+ from .ddp import distributed_mean
10
11
 
11
12
  class MLMTrainer(BaseTrainer):
12
13
  def __init__(
@@ -96,8 +97,7 @@ class MLMTrainer(BaseTrainer):
96
97
  acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
97
98
  node_acc = acc.item()
98
99
  if self.use_ddp:
99
- dist.all_reduce(acc, op=dist.ReduceOp.SUM)
100
- acc = acc / dist.get_world_size()
100
+ acc = distributed_mean(acc)
101
101
 
102
102
  metrics = {
103
103
  'accuracy': acc.item(),
@@ -198,8 +198,7 @@ class AutoregressiveTrainer(BaseTrainer):
198
198
  acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
199
199
  node_acc = acc.item()
200
200
  if self.use_ddp:
201
- dist.all_reduce(acc, op=dist.ReduceOp.SUM)
202
- acc = acc / dist.get_world_size()
201
+ acc = distributed_mean(acc)
203
202
 
204
203
  metrics = {
205
204
  'accuracy': acc.item(),
@@ -347,14 +346,10 @@ class JointLMTrainer(BaseTrainer):
347
346
  node_mlm_acc = mlm_acc.item()
348
347
  node_alm_acc = alm_acc.item()
349
348
  if self.use_ddp:
350
- dist.all_reduce(avg_dec_loss, op=dist.ReduceOp.SUM)
351
- dist.all_reduce(avg_enc_loss, op=dist.ReduceOp.SUM)
352
- dist.all_reduce(mlm_acc, op=dist.ReduceOp.SUM)
353
- dist.all_reduce(alm_acc, op=dist.ReduceOp.SUM)
354
- avg_dec_loss = avg_dec_loss / dist.get_world_size()
355
- avg_enc_loss = avg_enc_loss / dist.get_world_size()
356
- mlm_acc = mlm_acc / dist.get_world_size()
357
- alm_acc = alm_acc / dist.get_world_size()
349
+ avg_dec_loss = distributed_mean(avg_dec_loss)
350
+ avg_enc_loss = distributed_mean(avg_enc_loss)
351
+ mlm_acc = distributed_mean(mlm_acc)
352
+ alm_acc = distributed_mean(alm_acc)
358
353
 
359
354
  metrics = {
360
355
  'accuracy': {
@@ -536,6 +536,9 @@ class MrlTrainerCallback:
536
536
  def on_reward(self, actor: nn.Module, reward: float, generated: str, reference: str, saved_data: str, eval_mode: bool) -> None:
537
537
  pass
538
538
 
539
+ def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
540
+ pass
541
+
539
542
  def on_batch_updated(self, actor: nn.Module, epoch: int, step: int, policy_loss: float) -> None:
540
543
  pass
541
544
 
@@ -543,6 +546,9 @@ class MrlTrainerCallback:
543
546
  critic_loss: float) -> None:
544
547
  pass
545
548
 
549
+ def on_update_epoch_end(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int, policy_loss: float, critic_loss: float) -> None:
550
+ pass
551
+
546
552
  def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
547
553
  pass
548
554
 
@@ -572,6 +578,9 @@ class MrlPrintCallback(MrlTrainerCallback):
572
578
  reference: dict[str, torch.Tensor], saved_data: dict[str, torch.Tensor], eval_mode: bool) -> None:
573
579
  print(f"{'Eval' if eval_mode else 'Train'} | Collected reward {reward}")
574
580
 
581
+ def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
582
+ print(f'Epoch {global_epoch} | Starting update epoch {update_epoch}')
583
+
575
584
  def on_batch_updated(self, actor: nn.Module, epoch: int, step: int, policy_loss: float) -> None:
576
585
  print(f'Epoch {epoch} | Step {step} - updated policy loss {policy_loss}')
577
586
 
@@ -579,6 +588,9 @@ class MrlPrintCallback(MrlTrainerCallback):
579
588
  critic_loss: float) -> None:
580
589
  print(f'Epoch {epoch} | Step {step} - updated critic loss {critic_loss}')
581
590
 
591
+ def on_update_epoch_end(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int, policy_loss: float, critic_loss: float) -> None:
592
+ print(f'Epoch {global_epoch} | Update epoch {update_epoch} - mean policy loss {policy_loss} | mean critic loss {critic_loss}')
593
+
582
594
  def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
583
595
  print(f'Finished training for {curriculum_config["steps"]} steps in {curriculum_config["strategy"]} strategy.')
584
596
 
rxnn/training/ddp.py ADDED
@@ -0,0 +1,26 @@
1
+ import torch
2
+ import torch.distributed as dist
3
+ import os
4
+ from ..utils import set_random_seed
5
+
6
+ def get_os_ddp_config():
7
+ rank = int(os.environ['RANK'])
8
+ world_size = int(os.environ['WORLD_SIZE'])
9
+ return rank, world_size
10
+
11
+ def distributed_mean(x: torch.Tensor) -> torch.Tensor:
12
+ """Average tensor across all devices"""
13
+ x = x.clone()
14
+ dist.all_reduce(x, op=dist.ReduceOp.SUM)
15
+ x /= dist.get_world_size()
16
+ return x
17
+
18
+ def distributed_value_mean(value: float, device: torch.device = None) -> float:
19
+ """Average float value across all devices"""
20
+ tensor = torch.tensor(value, device=device)
21
+ reduced = distributed_mean(tensor)
22
+ return reduced.item()
23
+
24
+ def set_distributed_random_seed(seed: int):
25
+ rank = dist.get_rank() if dist.is_initialized() else get_os_ddp_config()[0]
26
+ set_random_seed(seed + rank)
rxnn/training/mrl.py CHANGED
@@ -3,7 +3,7 @@ from torch.utils.data import DataLoader, DistributedSampler
3
3
  from torch.utils.tensorboard import SummaryWriter
4
4
  import torch.distributed as dist
5
5
  from torch.nn.parallel import DistributedDataParallel
6
- from typing import Optional, TypedDict, Union, TypeAlias, Literal
6
+ from typing import Optional, TypedDict, Union, TypeAlias, Literal, Callable
7
7
  from enum import Enum
8
8
  import random, os
9
9
  from ..transformers.sampler import BatchSampler
@@ -13,7 +13,7 @@ from .utils import smart_concat, smart_concat_critic_states, TokenizedDict
13
13
  from .rl import RlAlgorithm
14
14
  from .reward import MrlRewardMode, MrlRewardModel
15
15
  from .models import MrlActorAction, MrlActorModel, MrlCriticModel
16
-
16
+ from .ddp import get_os_ddp_config, distributed_mean
17
17
 
18
18
  class MrlConfig(TypedDict):
19
19
  lr: float
@@ -24,6 +24,11 @@ class MrlConfig(TypedDict):
24
24
  critic_max_len: int
25
25
  weight_decay: float
26
26
  critic_weight_decay: float
27
+ update_epochs: int
28
+ pad_token_id: int
29
+ end_token_id: int
30
+ callbacks: Optional[list[MrlTrainerCallback]]
31
+ memory_aware_critic: bool
27
32
 
28
33
 
29
34
  class MrlStrategy(Enum):
@@ -31,8 +36,11 @@ class MrlStrategy(Enum):
31
36
  MULTI_STEP_STRATEGY = 2
32
37
  LONG_RANGE_STRATEGY = 3
33
38
 
39
+
40
+ UnfreezeStrategyFn = Callable[[int], None]
34
41
  UnfreezeItem = Union[int, tuple[int, float]]
35
- UnfreezeEpochsStrategy: TypeAlias = Union[int, tuple[UnfreezeItem, UnfreezeItem, UnfreezeItem, int]]
42
+ UnfreezeEpochsStrategy: TypeAlias = Union[int, tuple[UnfreezeItem, UnfreezeItem, UnfreezeItem, int], UnfreezeStrategyFn]
43
+
36
44
 
37
45
  class CurriculumConfig(TypedDict):
38
46
  steps: int
@@ -52,6 +60,7 @@ class CurriculumConfig(TypedDict):
52
60
  critic_lr: Optional[float]
53
61
  weight_decay: Optional[float]
54
62
  critic_weight_decay: Optional[float]
63
+ update_epochs: Optional[int]
55
64
 
56
65
 
57
66
  class SamplerConfig(TypedDict):
@@ -66,6 +75,7 @@ class MrlTrajectoryStep(TypedDict):
66
75
  log_probs: torch.Tensor
67
76
  reward: list[float]
68
77
  reference: TokenizedDict
78
+ done: bool
69
79
 
70
80
 
71
81
  class MrlTrajectoryEpisode(TypedDict):
@@ -84,21 +94,25 @@ class MRLTrainer:
84
94
  rl_algorithm: RlAlgorithm,
85
95
  sampler_config: Optional[SamplerConfig] = None,
86
96
  log_dir: str = None,
87
- pad_token_id: int = 0,
88
- end_token_id: int = 3,
89
97
  use_ddp: bool = False,
90
98
  use_amp: bool = False,
91
99
  dtype: torch.dtype = torch.float32,
92
- callbacks: list[MrlTrainerCallback] = None,
93
-
94
100
  ):
95
101
  """
96
- Trainer for Memory Reinforcement Learning (MRL) in Reactive Transformer.
102
+ Trainer for Memory Reinforcement Learning (MRL) algorithm for reactive models and Attention-Based Memory System.
97
103
 
98
104
  Args:
99
- actor: MRL Actor model with encoder, decoder and memory attention.
100
- critic: Critic network for advantage estimation.
101
- config: Configuration dictionary with hyperparameters.
105
+ actor (MrlActorModel): MRL Actor model with encoder, decoder and memory attention.
106
+ critic (MrlCriticModel): MRL Critic network for advantage estimation.
107
+ reward (MrlRewardModel): MRL Reward model or extension.
108
+ device (torch.device): Device used for training.
109
+ config (MrlConfig): Configuration dictionary with hyperparameters.
110
+ rl_algorithm (RlAlgorithm): Reinforcement Learning algorithm (currently only PPO available).
111
+ sampler_config (SamplerConfig): Sampler configuration.
112
+ log_dir (str): Log directory for TensorBoard logs.
113
+ use_ddp (bool): Use Distributed Data Parallel mode.
114
+ use_amp (bool): Use AMP Autocast for training.
115
+ dtype (torch.dtype): Data type used in training - in AMP mode it's auto cast, otherwise data and model are transformed to this type
102
116
  """
103
117
  self.actor = actor
104
118
  self.critic = critic
@@ -107,6 +121,10 @@ class MRLTrainer:
107
121
  self.device = device
108
122
  self.max_seq_len = config.get('max_seq_len', 256)
109
123
  self.critic_max_len = config.get('critic_max_len', 512)
124
+ self.memory_aware_critic = config.get('memory_aware_critic', False)
125
+ # Internal update epochs config
126
+ self.shared_update_epochs = config.get('update_epochs', 10)
127
+ self.update_epochs = self.shared_update_epochs
110
128
 
111
129
  # Move models to device
112
130
  if use_amp:
@@ -117,14 +135,15 @@ class MRLTrainer:
117
135
  self.critic.to(self.device, dtype=dtype)
118
136
 
119
137
  # Batch Sampler for answer generation
120
- self.generator = BatchSampler(self.actor, self.device, end_token_id=end_token_id)
138
+ self.generator = None
121
139
  self.sampler_config = SamplerConfig(
122
140
  temperature=1.0,
123
141
  top_k=None,
124
142
  top_p=None,
125
143
  ) if sampler_config is None else sampler_config
126
144
 
127
- self.pad_token_id = pad_token_id
145
+ self.pad_token_id = config.get('pad_token_id', 0)
146
+ self.end_token_id = config.get('end_token_id', 3)
128
147
 
129
148
  self.use_ddp = use_ddp
130
149
  self.use_amp = use_amp
@@ -172,7 +191,7 @@ class MRLTrainer:
172
191
  self.eval_dataset = None
173
192
  self.random_resets_ratio = 0.0
174
193
  self.strategy = None
175
- self.shared_callbacks = callbacks if callbacks else []
194
+ self.shared_callbacks = config.get('callbacks', [])
176
195
  self.callbacks = []
177
196
  self.global_epoch = 0
178
197
  self.global_epochs_count = 0
@@ -187,8 +206,8 @@ class MRLTrainer:
187
206
  ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
188
207
  if memory_lr is not None:
189
208
  optimizer = torch.optim.AdamW([
190
- { 'params': self.actor.not_memory_parameters(), 'lr': lr },
191
- { 'params': self.actor.memory_parameters(), 'lr': memory_lr },
209
+ {'params': self.actor.not_memory_parameters(), 'lr': lr},
210
+ {'params': self.actor.memory_parameters(), 'lr': memory_lr},
192
211
  ],
193
212
  weight_decay=weight_decay,
194
213
  )
@@ -207,12 +226,10 @@ class MRLTrainer:
207
226
 
208
227
  return optimizer, critic_optimizer
209
228
 
210
-
211
229
  def _init_steps(self):
212
230
  return {
213
231
  'collect': 0,
214
- 'critic': 0,
215
- 'rl': 0,
232
+ 'train': 0,
216
233
  'eval': 0,
217
234
  }
218
235
 
@@ -221,9 +238,12 @@ class MRLTrainer:
221
238
  self.epoch_step[step_type] += 1
222
239
  self.stage_step[step_type] += 1
223
240
 
224
- def reset_stm(self) -> bool:
241
+ def reset_stm(self, force: bool = False) -> bool:
225
242
  """Reset Short-Term Memory state with random reset ratio."""
226
- if self.random_resets_ratio == 1.0:
243
+ if force:
244
+ self.actor.reset_memory()
245
+ return True
246
+ elif self.random_resets_ratio == 1.0:
227
247
  self.actor.reset_memory()
228
248
  return True
229
249
  else:
@@ -351,7 +371,7 @@ class MRLTrainer:
351
371
  # state from existing one, instead of new random one)
352
372
  reset_done = self.reset_stm()
353
373
 
354
- # 4. Reset reward prev data running mean - it's calculated for multi-step retention, we have to reset it before episode
374
+ # 4. Reset reward prev data running mean - it's calculated for multistep retention, we have to reset it before episode
355
375
  self.reward.reset_running_mean()
356
376
 
357
377
  # 5. Get first batch of interactions (data to save) and follow-up interactions for current episode, based on curriculum step
@@ -406,6 +426,7 @@ class MRLTrainer:
406
426
  'log_probs': log_probs.detach().cpu(),
407
427
  'reward': reward,
408
428
  'reference': interaction['answer'],
429
+ 'done': is_last_interaction,
409
430
  }
410
431
  episode_steps.append(trajectory)
411
432
  episode_rewards.append(reward)
@@ -432,201 +453,268 @@ class MRLTrainer:
432
453
 
433
454
  return trajectories
434
455
 
435
- def _critic_loss(self, inputs: TokenizedDict, rewards: torch.Tensor) -> torch.Tensor:
436
- # 1. Calculate values with critic encoder
437
- values = self.critic(
438
- inputs['input_ids'],
439
- attention_mask=inputs['attention_mask'],
440
- ).squeeze()
441
- # 2. Calculate critic loss
442
- loss = self.rl_algorithm.critic_loss(values, rewards)
443
- return loss
444
-
445
456
  def _critic_writer(self, critic_loss: float, epoch: int):
446
457
  if self.writer is not None:
447
- self.writer.add_scalar('Loss/critic (global)', critic_loss, self.global_step['critic'])
458
+ self.writer.add_scalar('Loss/critic (global)', critic_loss, self.global_step['train'])
448
459
  self.writer.add_scalar(f'Loss/critic (steps: {self.curriculum_steps}, epoch: {epoch})', critic_loss,
449
- self.epoch_step['critic'])
460
+ self.epoch_step['train'])
450
461
  self.writer.add_scalar(f'Loss/critic (steps: {self.curriculum_steps})', critic_loss,
451
- self.stage_step['critic'])
452
-
453
- def update_critic(self, states: list[tuple[TokenizedDict, TokenizedDict, TokenizedDict]],
454
- rewards: list[torch.Tensor], epoch: int):
455
- """Update critic network using MSE loss."""
456
- # 1. Run critic updates for all collected batches
457
- critic_losses = []
458
- for step_idx, (state, reward) in enumerate(zip(states, rewards)):
459
- self._increment_steps('critic')
460
- # 2. Move state batches to training device (GPU)
461
- prev_query, prev_answer, next_query = self._move_multiple_batches(*state)
462
-
463
- # 3. Reset critic gradients
464
- self.critic_optimizer.zero_grad()
465
-
466
- # 4. Run critic and calculate loss - in autocast on/off mode
467
- if self.use_amp:
468
- # Move tensors to training device and calculate loss in autocast mode
469
- batch_rewards = reward.to(self.device)
470
- with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
471
- # Concatenate state into single critic input sequence
472
- inputs = smart_concat_critic_states(
473
- prev_query, prev_answer, next_query,
474
- max_length=self.critic_max_len,
475
- pad_token_id=self.pad_token_id,
476
- )
477
- loss = self._critic_loss(inputs, batch_rewards)
478
- # Run backpropagation with scaler
479
- self.critic_scaler.scale(loss).backward()
480
- # Unscale and clip gradients
481
- self.critic_scaler.unscale_(self.critic_optimizer)
482
- torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0, error_if_nonfinite=False)
483
- # Run scaled optimization step
484
- self.critic_scaler.step(self.critic_optimizer)
485
- self.critic_scaler.update()
486
- else:
487
- # Concatenate state into single critic input sequence
488
- inputs = smart_concat_critic_states(
489
- prev_query, prev_answer, next_query,
490
- max_length=self.critic_max_len,
491
- pad_token_id=self.pad_token_id,
492
- )
493
- # Calculate loss
494
- loss = self._critic_loss(inputs, reward.to(self.device, dtype=self.dtype))
495
- # Run backpropagation
496
- loss.backward()
497
- # Clip gradients
498
- torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0, error_if_nonfinite=False)
499
- # Run optimizer step
500
- self.critic_optimizer.step()
501
- critic_loss = loss.item()
502
- self._critic_writer(critic_loss, epoch)
503
-
504
- # 5. Run "on critic updated" callbacks
505
- for cb in self.callbacks:
506
- cb.on_critic_updated(self.actor, self.critic, epoch, step_idx, critic_loss)
507
-
508
- # 6. Accumulate loss for epoch callbacks
509
- critic_losses.append(critic_loss)
510
-
511
- # 7. Calculate mean loss for epoch callbacks
512
- critic_mean_loss = torch.tensor(critic_losses).mean().item()
513
-
514
- return critic_mean_loss
515
-
516
- def _critic_advantages(self, critic_state: TokenizedDict, rewards: torch.Tensor) -> torch.Tensor:
517
- with torch.no_grad():
518
- values = self.critic(critic_state['input_ids'],
519
- attention_mask=critic_state['attention_mask']).squeeze()
520
- return self.rl_algorithm.calculate_advantages(rewards, values)
462
+ self.stage_step['train'])
521
463
 
522
464
  def _rl_writer(self, policy_loss: float, epoch: int):
523
465
  if self.writer is not None:
524
- self.writer.add_scalar('Loss/policy (global)', policy_loss, self.global_step['rl'])
466
+ self.writer.add_scalar('Loss/policy (global)', policy_loss, self.global_step['train'])
525
467
  self.writer.add_scalar(f'Loss/policy (steps: {self.curriculum_steps}, epoch: {epoch})', policy_loss,
526
- self.epoch_step['rl'])
527
- self.writer.add_scalar(f'Loss/policy (steps: {self.curriculum_steps})', policy_loss, self.stage_step['rl'])
468
+ self.epoch_step['train'])
469
+ self.writer.add_scalar(f'Loss/policy (steps: {self.curriculum_steps})', policy_loss,
470
+ self.stage_step['train'])
528
471
 
529
- def rl_step(self, trajectories: list[MrlTrajectoryEpisode], epoch: int):
472
+ def update_critic(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], ref_values: torch.Tensor,
473
+ epoch: int) -> float:
474
+ # 1. Reset critic gradients
475
+ self.critic_optimizer.zero_grad()
476
+
477
+ # 2. Update critic - with autocast on/off
478
+ if self.use_amp:
479
+ with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
480
+ # 2.1 Concat states and calculate critic loss
481
+ critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
482
+ pad_token_id=self.pad_token_id)
483
+ values = self.critic(critic_state['input_ids'], attention_mask=critic_state['attention_mask']).squeeze()
484
+ critic_loss = self.rl_algorithm.critic_loss(values, ref_values)
485
+ # 2.2 Run backpropagation with scaler
486
+ self.critic_scaler.scale(critic_loss).backward()
487
+ # 2.3 Unscale and clip gradients
488
+ self.critic_scaler.unscale_(self.critic_optimizer)
489
+ torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0, error_if_nonfinite=False)
490
+ # 2.4 Run scaled optimization step
491
+ self.critic_scaler.step(self.critic_optimizer)
492
+ self.critic_scaler.update()
493
+ else:
494
+ # 2.1 Concat states and calculate critic loss
495
+ critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
496
+ pad_token_id=self.pad_token_id)
497
+ values = self.critic(critic_state['input_ids'], attention_mask=critic_state['attention_mask']).squeeze()
498
+ critic_loss = self.rl_algorithm.critic_loss(values, ref_values)
499
+ # 2.2 Run backpropagation
500
+ critic_loss.backward()
501
+ # 2.3 Clip gradients
502
+ torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0, error_if_nonfinite=False)
503
+ # 2.4 Run optimizer step
504
+ self.critic_optimizer.step()
505
+ # 3. Get float loss value for callbacks/writer
506
+ critic_loss_item = critic_loss.item()
507
+
508
+ # 4. Write to TensorBoard
509
+ self._critic_writer(critic_loss_item, epoch)
510
+
511
+ # 5. Run "on critic updated" callbacks
512
+ for cb in self.callbacks:
513
+ cb.on_critic_updated(self.actor, self.critic, epoch, self.epoch_step['train'], critic_loss_item)
514
+ # 6. Return loss item
515
+ return critic_loss_item
516
+
517
+ def update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
518
+ advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
519
+ # 1. Reset actor gradients
520
+ self.optimizer.zero_grad()
521
+ # 2. Unpack state dicts
522
+ query, answer, next_query = state
523
+
524
+ # 3. Encode and update STM on each step, to include encoder and memory attention gradients in loss (skip if it was updated before with memory aware critic)
525
+ if not self.memory_aware_critic:
526
+ self.encode_and_update_stm(query, answer)
527
+
528
+ # 4. Update actor - with autocast on/off
529
+ if self.use_amp:
530
+ with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
531
+ # 4.1 Concatenate next query and action and get action logits from decoder
532
+ inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
533
+ pad_token_id=self.pad_token_id)
534
+ logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
535
+ action=MrlActorAction.DECODE)
536
+ # 4.2 Calculate policy loss with selected algorithm
537
+ policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs,
538
+ advantages)
539
+ # 4.3 Run backpropagation with scaler
540
+ self.scaler.scale(policy_loss).backward(retain_graph=True)
541
+ # 4.4 Unscale and clip gradient norms
542
+ self.scaler.unscale_(self.optimizer)
543
+ torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
544
+ error_if_nonfinite=False)
545
+ # 4.5 Run scaled optimization step
546
+ self.scaler.step(self.optimizer)
547
+ self.scaler.update()
548
+ else:
549
+ # 4.1 Concatenate next query and action and get action logits from decoder
550
+ inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
551
+ pad_token_id=self.pad_token_id)
552
+ logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
553
+ action=MrlActorAction.DECODE)
554
+ # 4.2 Calculate policy loss with selected algorithm
555
+ policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs, advantages)
556
+ # 4.3 Run backpropagation
557
+ policy_loss.backward(retain_graph=True)
558
+ # 4.4 Clip gradient norms
559
+ torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
560
+ error_if_nonfinite=False)
561
+ # 4.5 Run scaled optimization step
562
+ self.optimizer.step()
563
+ # 5. Get float loss value for callbacks/writer
564
+ policy_loss_item = policy_loss.item()
565
+
566
+ # 6. Write to TensorBoard
567
+ self._rl_writer(policy_loss_item, epoch)
568
+
569
+ # 7. Run "on batch updated" callback
570
+ for cb in self.callbacks:
571
+ cb.on_batch_updated(self.actor, epoch, self.epoch_step['train'], policy_loss_item)
572
+
573
+ # 8. Return loss item
574
+ return policy_loss_item
575
+
576
+ def rl_step(self, trajectories: list[MrlTrajectoryEpisode], advantages: torch.Tensor, ref_values: torch.Tensor,
577
+ epoch: int, batch_size: int) -> tuple[torch.Tensor, torch.Tensor]:
530
578
  """Perform PPO update step using trajectories."""
531
579
  # 1. Run update separately for episodes in trajectory - we have to reset memory before each episode, and update
532
580
  # memory, based on collected episode data
533
581
  all_losses = []
534
- trajectories_len = len(trajectories)
582
+ critic_losses = []
535
583
  for episode_idx, episode in enumerate(trajectories):
536
584
  episode_steps = episode['steps']
537
585
  should_reset_stm = episode['reset_stm']
538
586
 
539
- # 2. Reset memory for current batch episode
587
+ # 2. Get advantages and reference values for current full episode (batch_size * episode_steps)
588
+ num_steps = len(episode_steps)
589
+ start = episode_idx * num_steps
590
+ end = start + num_steps
591
+ episode_critic_values = ref_values[start:end]
592
+ episode_advantages = advantages[start:end]
593
+
594
+ # 3. Reset memory for current batch episode
540
595
  if should_reset_stm:
541
- self.reset_stm()
596
+ self.reset_stm(force=True)
542
597
 
543
- # 3. Run episode steps - each episode has number of steps depending on curriculum stage. Each step is run for all batch
544
- for step in episode_steps:
545
- self._increment_steps('rl')
546
- state, action, reward, log_probs = step['state'], step['action'], step['reward'], step['log_probs']
598
+ # 4. Run episode steps - each episode has number of steps depending on curriculum stage. Each step is run for all batch
599
+ for step_idx, step in enumerate(episode_steps):
600
+ self._increment_steps('train')
601
+ # 5. Get and move to device collected states, action and log probs
602
+ state, action, _, log_probs = step['state'], step['action'], step['reward'], step['log_probs']
547
603
  query, answer, next_query = self._move_multiple_batches(*state)
548
604
  action = self._move_batch(action)
549
605
  log_probs = log_probs.to(self.device)
550
- rewards = torch.tensor(reward).to(self.device)
551
-
552
- # 4. Compute advantages using critic
553
- if self.use_amp:
554
- with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
555
- critic_state = smart_concat_critic_states(query, answer, next_query,
556
- max_length=self.critic_max_len,
557
- pad_token_id=self.pad_token_id)
558
- advantages = self._critic_advantages(critic_state, rewards)
559
- else:
560
- critic_state = smart_concat_critic_states(query, answer, next_query, max_length=self.critic_max_len,
561
- pad_token_id=self.pad_token_id)
562
- advantages = self._critic_advantages(critic_state, rewards)
563
-
564
- # 5. Encode and update STM on each step, to include encoder and memory attention gradients in loss
565
- self.encode_and_update_stm(query, answer)
566
- # 6. Concatenate next query and action and get action logits from decoder
567
- if self.use_amp:
568
- with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
569
- inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
570
- pad_token_id=self.pad_token_id)
571
- logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
572
- action=MrlActorAction.DECODE)
573
- else:
574
- inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
575
- pad_token_id=self.pad_token_id)
576
- logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
577
- action=MrlActorAction.DECODE)
578
-
579
- # 7. Calculate RL Algorithm (PPO etc.) loss
580
- policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, log_probs, advantages)
581
-
582
- # 8. Reset gradients
583
- self.optimizer.zero_grad()
584
-
585
- # 9. Update the model in AMP or regular mode
586
- if self.use_amp:
587
- self.scaler.scale(policy_loss).backward(retain_graph=True)
588
- self.scaler.unscale_(self.optimizer)
589
- torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
590
- error_if_nonfinite=False)
591
- self.scaler.step(self.optimizer)
592
- self.scaler.update()
593
- else:
594
- policy_loss.backward(retain_graph=True)
595
- torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
596
- error_if_nonfinite=False)
597
- self.optimizer.step()
598
606
 
599
- policy_loss_item = policy_loss.item()
600
- self._rl_writer(policy_loss_item, epoch)
601
- all_losses.append(policy_loss_item)
607
+ # 6. Select advantages and reference values for current step (batch_size)
608
+ step_critic_values = episode_critic_values[step_idx]
609
+ step_advantages = episode_advantages[step_idx]
602
610
 
603
- # 10. Run "on batch updated" callback
604
- for cb in self.callbacks:
605
- cb.on_batch_updated(self.actor, epoch, self.epoch_step['rl'], policy_loss_item)
611
+ # 7. In memory aware critic version, encode and update STM before critic update, to include its gradients in critic loss too
612
+ if self.memory_aware_critic:
613
+ self.encode_and_update_stm(query, answer)
614
+
615
+ # 7. Update critic
616
+ critic_loss_item = self.update_critic((query, answer, next_query), step_critic_values, epoch)
606
617
 
607
- return torch.mean(torch.tensor(all_losses)).item()
618
+ # 8. Accumulate critic loss for epoch callbacks
619
+ critic_losses.append(critic_loss_item)
608
620
 
609
- def _critic_states_and_rewards(self, trajectories: list[MrlTrajectoryEpisode]):
610
- flat_trajectories = [t for episode in trajectories for t in episode['steps']]
611
- states = [t['state'] for t in flat_trajectories]
612
- rewards = [torch.tensor(t['reward']) for t in flat_trajectories]
613
- return states, rewards
621
+ # 9. Update actor
622
+ policy_loss_item = self.update_actor((query, answer, next_query), action, step_advantages, log_probs,
623
+ epoch)
624
+ all_losses.append(policy_loss_item)
625
+ # 10. Return mean losses for epoch callbacks
626
+ return torch.mean(torch.tensor(all_losses)), torch.mean(torch.tensor(critic_losses))
627
+
628
+ def _critic_values_rewards_and_dones(self, trajectories: list[MrlTrajectoryEpisode]):
629
+ if self.memory_aware_critic:
630
+ flat_trajectories = [
631
+ (t, i == 0 and episode['reset_stm'])
632
+ for episode in trajectories
633
+ for i, t in enumerate(episode['steps'])
634
+ ]
635
+ values = torch.stack([
636
+ self._critic_values_with_memory(r, *self._move_multiple_batches(*t['state'])) for t, r in flat_trajectories
637
+ ]).to(self.device)
638
+ rewards = torch.stack([torch.tensor(t['reward']) for t, _ in flat_trajectories]).to(self.device)
639
+ dones = torch.stack([torch.tensor(t['done']) for t, _ in flat_trajectories]).to(self.device)
640
+ else:
641
+ flat_trajectories = [t for episode in trajectories for t in episode['steps']]
642
+ values = torch.stack([
643
+ self._critic_values(*self._move_multiple_batches(*t['state'])) for t in flat_trajectories
644
+ ]).to(self.device)
645
+ rewards = torch.stack([torch.tensor(t['reward']) for t in flat_trajectories]).to(self.device)
646
+ dones = torch.stack([torch.tensor(t['done']) for t in flat_trajectories]).to(self.device)
647
+ return values, rewards, dones
648
+
649
+ def _critic_values_with_memory(self, reset_stm: bool, *moved_state: tuple[TokenizedDict, TokenizedDict, TokenizedDict]) -> torch.Tensor:
650
+ # 1. Calculate critic values in memory aware version - reset/update STM before calculating values
651
+ with torch.no_grad():
652
+ # 2. Reset STM if it was reset in trajectory collection
653
+ if reset_stm:
654
+ self.reset_stm(force=True)
655
+ # 3. Encode and update STM for critic
656
+ self.encode_and_update_stm(*moved_state)
657
+ # 4. Get concatenated critic states
658
+ inputs = smart_concat_critic_states(
659
+ *moved_state,
660
+ max_length=self.critic_max_len,
661
+ pad_token_id=self.pad_token_id,
662
+ )
663
+ # 5. Calculate values for current batch
664
+ return self.critic(inputs['input_ids'],
665
+ attention_mask=inputs['attention_mask']).squeeze()
666
+
667
+ def _critic_values(self, *moved_state: tuple[TokenizedDict, TokenizedDict, TokenizedDict]) -> torch.Tensor:
668
+ # 1. Calculate critic values
669
+ with torch.no_grad():
670
+ # 2. Get concatenated critic states
671
+ inputs = smart_concat_critic_states(
672
+ *moved_state,
673
+ max_length=self.critic_max_len,
674
+ pad_token_id=self.pad_token_id,
675
+ )
676
+ # 3. Calculate values for current batch
677
+ return self.critic(inputs['input_ids'],
678
+ attention_mask=inputs['attention_mask']).squeeze()
614
679
 
615
680
  def train_epoch(self, dataloader: DataLoader, epoch: int, batch_size: int):
616
681
  """Train for one epoch."""
617
682
  # 1. Collect trajectories for current epoch
618
683
  trajectories = self.collect_trajectories(dataloader, epoch, batch_size)
619
684
 
620
- # 2. Flatten trajectories and collect state and rewards for critic update
621
- states, rewards = self._critic_states_and_rewards(trajectories)
622
- # 3. Update critic model, based on states and rewards
623
- critic_loss = self.update_critic(states, rewards, epoch)
685
+ # 2. Flatten trajectories, call critic and collect values, dones and rewards, and calculate advantages
686
+ if self.use_amp:
687
+ with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
688
+ values, rewards, dones = self._critic_values_rewards_and_dones(trajectories)
689
+ advantages, ref_values = self.rl_algorithm.calculate_advantages(rewards, values, dones)
690
+ else:
691
+ values, rewards, dones = self._critic_values_rewards_and_dones(trajectories)
692
+ advantages, ref_values = self.rl_algorithm.calculate_advantages(rewards, values, dones)
693
+
694
+ # 3. Run internal update epochs
695
+ critic_loss_sum, policy_loss_sum = 0.0, 0.0
696
+ for update_epoch in range(self.update_epochs):
697
+ # 4. Run 'on update epoch start' callbacks
698
+ for cb in self.callbacks:
699
+ cb.on_update_epoch_start(self.actor, self.critic, epoch, update_epoch)
700
+
701
+ # 5. Run RL algorithm step
702
+ policy_loss, critic_loss = self.rl_step(trajectories[:-1], advantages, ref_values, epoch, batch_size)
703
+
704
+ if self.use_ddp:
705
+ policy_loss = distributed_mean(policy_loss)
706
+ critic_loss = distributed_mean(critic_loss)
707
+
708
+ # 6. Run 'on update epoch end' callbacks
709
+ for cb in self.callbacks:
710
+ cb.on_update_epoch_end(self.actor, self.critic, epoch, update_epoch, policy_loss, critic_loss)
624
711
 
625
- # 4. Run PPO algorithm step
626
- policy_loss = self.rl_step(trajectories, epoch)
712
+ # 7. Accumulate losses for epoch callbacks
713
+ critic_loss_sum += critic_loss
714
+ policy_loss_sum += policy_loss
627
715
 
628
- # 5. Return policy and critic mean losses for epoch callbacks
629
- return policy_loss, critic_loss
716
+ # 8. Return policy and critic mean losses for epoch callbacks
717
+ return policy_loss_sum / self.update_epochs, critic_loss_sum / self.update_epochs
630
718
 
631
719
  def _eval_loader(self, batch_size: int):
632
720
  if self.use_ddp:
@@ -731,12 +819,11 @@ class MRLTrainer:
731
819
  cb.on_eval_episode_end(self.actor, epoch, self.epoch_step['eval'], avg_episode_reward)
732
820
 
733
821
  # 15. Calculate average reward
822
+ avg_reward = (total_reward / count) if count > 0 else torch.tensor(0.0).to(self.device)
734
823
  if self.use_ddp:
735
- total_sum = dist.all_reduce(total_reward, dist.ReduceOp.SUM)
736
- count_sum = dist.all_reduce(count, dist.ReduceOp.SUM)
737
- avg_reward = (total_sum / count_sum).item() if count_sum > 0 else 0
738
- else:
739
- avg_reward = (total_reward / count).item() if count > 0 else 0
824
+ avg_reward = distributed_mean(avg_reward)
825
+
826
+ avg_reward = avg_reward.item()
740
827
 
741
828
  should_stop_stage = False
742
829
  # 16. Run "on eval end" callbacks
@@ -747,55 +834,6 @@ class MRLTrainer:
747
834
 
748
835
  return should_stop_stage
749
836
 
750
- def _setup_curriculum_step(self, config: CurriculumConfig) -> tuple[tuple[int, UnfreezeEpochsStrategy], tuple[bool, int, float]]:
751
- # 1. Set common fields based on config
752
- self.curriculum_steps = config.get('steps', 1) # number of steps to run in episode
753
- self.train_dataset = config.get('dataset', None) # training dataset for current curriculum stage
754
- self.eval_dataset = config.get('eval_dataset', None) # evaluation dataset for current curriculum stage
755
- self.callbacks = config.get('callbacks',
756
- self.shared_callbacks) # trainer callbacks for current curriculum stage
757
- self.strategy = config.get('strategy',
758
- MrlStrategy.MULTI_STEP_STRATEGY) # MRL strategy for given curriculum stage
759
- self.reward = config.get('reward_model', self.shared_reward_model) # MRL Reward Model for curriculum stage
760
- if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config['critic_weight_decay'] is not None or (config['separate_memory_lr'] and config['memory_lr'] is not None):
761
- if config.get('separate_memory_lr', False):
762
- self.optim_config = {
763
- 'lr': config.get('lr', self.base_optim_config['lr']),
764
- 'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
765
- 'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
766
- 'critic_weight_decay': config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay']),
767
- 'memory_lr': config.get('memory_lr', self.base_optim_config['memory_lr']),
768
- }
769
- else:
770
- self.optim_config = {
771
- 'lr': config.get('lr', self.base_optim_config['lr']),
772
- 'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
773
- 'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
774
- 'critic_weight_decay': config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay']),
775
- }
776
- self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
777
- elif self.optim_config != self.base_optim_config:
778
- self.optim_config = self.base_optim_config
779
- self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
780
-
781
-
782
-
783
-
784
- # 2. Get epochs and random resets configs
785
- epochs = config.get('epochs', 5) # number of epochs for current stage
786
- unfreeze_epoch = config.get('unfreeze_epoch',
787
- 0) # epoch when components (other than memory) are unfrozen (before epoch starts)
788
- random_resets = config.get('random_resets',
789
- False) # flag for using random STM resets (recommended, as model should learn transitions between different states)
790
- random_resets_from = config.get('random_resets_from', None) # epoch from which random STM resets are started
791
- random_resets_ratio = config.get('random_resets_ratio',
792
- None) # ratio of random STM resets - 1.0 is "always reset", 0.0 is "no resets"
793
-
794
- # 3. Reset stage step counter
795
- self.stage_step = self._init_steps()
796
-
797
- return (epochs, unfreeze_epoch), (random_resets, random_resets_from, random_resets_ratio)
798
-
799
837
  def _apply_unfreeze_strategy(self, epoch: int, unfreeze_epoch: UnfreezeEpochsStrategy):
800
838
  is_staged_unfreeze = isinstance(unfreeze_epoch, tuple)
801
839
  if is_staged_unfreeze:
@@ -808,28 +846,31 @@ class MRLTrainer:
808
846
  self.optimizer = self._init_unfreeze_optimizer('update', cross_att_lr)
809
847
  print(f"Activating 'update' unfreeze strategy with custom cross_att_lr: {cross_att_lr}")
810
848
  elif epoch == update_epoch:
811
- self.actor.freeze_components('update')
812
- print(f"Activating 'update' unfreeze strategy - mem-att trainable / cross-att frozen / rest model frozen")
849
+ self.actor.freeze_components('update')
850
+ print(
851
+ f"Activating 'update' unfreeze strategy - mem-att trainable / cross-att frozen / rest model frozen")
813
852
 
814
853
  if isinstance(fetch_epoch, tuple):
815
854
  switch_epoch, mem_att_lr = fetch_epoch
816
- if epoch == fetch_epoch:
855
+ if epoch == switch_epoch:
817
856
  self.actor.freeze_components('joint')
818
857
  self.optimizer = self._init_unfreeze_optimizer('fetch', mem_att_lr)
819
858
  print(f"Activating 'fetch' unfreeze strategy with custom mem_att_lr: {mem_att_lr}")
820
859
  elif epoch == fetch_epoch:
821
860
  self.actor.freeze_components('fetch')
822
- print(f"Activating 'fetch' unfreeze strategy - mem-att frozen / cross-att trainable / rest model frozen")
861
+ print(
862
+ f"Activating 'fetch' unfreeze strategy - mem-att frozen / cross-att trainable / rest model frozen")
823
863
 
824
864
  if isinstance(joint_epoch, tuple):
825
865
  switch_epoch, model_lr = joint_epoch
826
- if epoch == joint_epoch:
866
+ if epoch == switch_epoch:
827
867
  self.actor.unfreeze_components()
828
868
  self.optimizer = self._init_unfreeze_optimizer('joint', model_lr)
829
869
  print(f"Activating 'joint' unfreeze strategy with custom model_lr: {model_lr}")
830
870
  elif epoch == joint_epoch:
831
- self.actor.freeze_components('joint')
832
- print(f"Activating 'joint' unfreeze strategy - mem-att/cross-att trainable / rest model frozen")
871
+ self.actor.freeze_components('joint')
872
+ print(f"Activating 'joint' unfreeze strategy - mem-att/cross-att trainable / rest model frozen")
873
+
833
874
  if epoch == all_epoch:
834
875
  self.actor.unfreeze_components()
835
876
  self.optimizer = self._init_unfreeze_optimizer('all', 0.)
@@ -871,6 +912,56 @@ class MRLTrainer:
871
912
 
872
913
  return torch.optim.AdamW(params, weight_decay=self.optim_config['weight_decay'])
873
914
 
915
+ def _setup_curriculum_step(self, config: CurriculumConfig) -> tuple[
916
+ tuple[int, UnfreezeEpochsStrategy], tuple[bool, int, float]]:
917
+ # 1. Set common fields based on config
918
+ self.curriculum_steps = config.get('steps', 1) # number of steps to run in episode
919
+ self.train_dataset = config.get('dataset', None) # training dataset for current curriculum stage
920
+ self.eval_dataset = config.get('eval_dataset', None) # evaluation dataset for current curriculum stage
921
+ self.callbacks = config.get('callbacks',
922
+ self.shared_callbacks) # trainer callbacks for current curriculum stage
923
+ self.strategy = config.get('strategy',
924
+ MrlStrategy.MULTI_STEP_STRATEGY) # MRL strategy for given curriculum stage
925
+ self.reward = config.get('reward_model', self.shared_reward_model) # MRL Reward Model for curriculum stage
926
+ self.update_epochs = config.get('update_epochs', self.shared_update_epochs) # Internal update epochs
927
+ if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config[
928
+ 'critic_weight_decay'] is not None or (config['separate_memory_lr'] and config['memory_lr'] is not None):
929
+ if config.get('separate_memory_lr', False):
930
+ self.optim_config = {
931
+ 'lr': config.get('lr', self.base_optim_config['lr']),
932
+ 'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
933
+ 'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
934
+ 'critic_weight_decay': config.get('critic_weight_decay',
935
+ self.base_optim_config['critic_weight_decay']),
936
+ 'memory_lr': config.get('memory_lr', self.base_optim_config['memory_lr']),
937
+ }
938
+ else:
939
+ self.optim_config = {
940
+ 'lr': config.get('lr', self.base_optim_config['lr']),
941
+ 'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
942
+ 'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
943
+ 'critic_weight_decay': config.get('critic_weight_decay',
944
+ self.base_optim_config['critic_weight_decay']),
945
+ }
946
+ self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
947
+ elif self.optim_config != self.base_optim_config:
948
+ self.optim_config = self.base_optim_config
949
+ self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
950
+
951
+ # 2. Get epochs and random resets configs
952
+ epochs = config.get('epochs', 5) # number of epochs for current stage
953
+ unfreeze_epoch = config.get('unfreeze_epoch',
954
+ 0) # epoch when components (other than memory) are unfrozen (before epoch starts)
955
+ random_resets = config.get('random_resets',
956
+ False) # flag for using random STM resets (recommended, as model should learn transitions between different states)
957
+ random_resets_from = config.get('random_resets_from', None) # epoch from which random STM resets are started
958
+ random_resets_ratio = config.get('random_resets_ratio',
959
+ None) # ratio of random STM resets - 1.0 is "always reset", 0.0 is "no resets"
960
+
961
+ # 3. Reset stage step counter
962
+ self.stage_step = self._init_steps()
963
+
964
+ return (epochs, unfreeze_epoch), (random_resets, random_resets_from, random_resets_ratio)
874
965
 
875
966
  def __call__(self, curriculum_config: list[CurriculumConfig], batch_size: int):
876
967
  """Start Memory Reinforcement Learning Curriculum."""
@@ -881,29 +972,36 @@ class MRLTrainer:
881
972
 
882
973
  # 1. Init DDP for distributed training mode
883
974
  if self.use_ddp:
884
- rank = int(os.environ['RANK'])
885
- world_size = int(os.environ['WORLD_SIZE'])
975
+ rank, world_size = get_os_ddp_config()
886
976
  dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
887
977
  self.actor = DistributedDataParallel(self.actor, device_ids=[self.device.index])
888
978
  self.critic = DistributedDataParallel(self.critic, device_ids=[self.device.index])
889
979
 
890
- # 2. Run each curriculum step based on config
980
+ # 2. Init BatchSampler with actor model (we have to run it after DDP init)
981
+ self.generator = BatchSampler(self.actor, self.device, end_token_id=self.end_token_id)
982
+
983
+ # 3. Run each curriculum step based on config
891
984
  for current_curriculum_step in curriculum_config:
892
- # 3. Setup training config for curriculum step
985
+ # 4. Setup training config for curriculum step
893
986
  epochs_config, random_resets_config = self._setup_curriculum_step(current_curriculum_step)
894
987
  epochs, unfreeze_epoch = epochs_config
895
988
  random_resets, random_resets_from, random_resets_ratio = random_resets_config
896
989
  assert self.train_dataset is not None
897
990
 
898
- # 4. Freeze all components except memory attention and memory cross-attention layers in decoder/encoder
991
+ # 5. Freeze all components except memory attention and memory cross-attention layers in decoder/encoder
899
992
  if unfreeze_epoch != 0:
900
- self.actor.freeze_components('joint')
901
- if isinstance(unfreeze_epoch, tuple):
902
- print(f"Starting training with unfreeze strategies - 'warmup' - mem-att/cross-att trainable / rest model frozen")
993
+ if callable(unfreeze_epoch):
994
+ unfreeze_epoch(-1)
903
995
  else:
904
- print(f"Starting training with simple unfreeze - 'joint' - mem-att/cross-att trainable / rest model frozen")
905
-
906
- # 5. Setup train DataLoader
996
+ self.actor.freeze_components('joint')
997
+ if isinstance(unfreeze_epoch, tuple):
998
+ print(
999
+ f"Starting training with unfreeze strategies - 'warmup' - mem-att/cross-att trainable / rest model frozen")
1000
+ else:
1001
+ print(
1002
+ f"Starting training with simple unfreeze - 'joint' - mem-att/cross-att trainable / rest model frozen")
1003
+
1004
+ # 6. Setup train DataLoader
907
1005
  if self.use_ddp:
908
1006
  train_sampler = DistributedSampler(self.train_dataset, shuffle=True)
909
1007
  dataloader = DataLoader(
@@ -912,6 +1010,7 @@ class MRLTrainer:
912
1010
  sampler=train_sampler,
913
1011
  pin_memory=True,
914
1012
  collate_fn=MrlCurriculumDataset.collate_mrl_batch,
1013
+ drop_last=True,
915
1014
  )
916
1015
  else:
917
1016
  train_sampler = None
@@ -923,65 +1022,68 @@ class MRLTrainer:
923
1022
  collate_fn=MrlCurriculumDataset.collate_mrl_batch,
924
1023
  )
925
1024
 
926
- # 6. Run selected number of epochs for given curriculum stage
1025
+ # 7. Run selected number of epochs for given curriculum stage
927
1026
  for epoch in range(epochs):
928
- # 7. Increment global epoch
1027
+ # 8. Increment global epoch
929
1028
  self.global_epoch += 1
930
- # 8. Run "on epoch start" callbacks (log info, etc.)
1029
+ # 9. Run "on epoch start" callbacks (log info, etc.)
931
1030
  for cb in self.callbacks:
932
1031
  cb.on_epoch_start(self.actor, epoch, epochs, current_curriculum_step, self.global_epoch,
933
1032
  self.global_epochs_count)
934
1033
 
935
- # 9. Reset steps counter for epoch
1034
+ # 10. Reset steps counter for epoch
936
1035
  self.epoch_step = self._init_steps()
937
1036
 
938
- # 10. Set random STM resets ratio from selected epoch
1037
+ # 11. Set random STM resets ratio from selected epoch
939
1038
  if random_resets and random_resets_from <= epoch:
940
1039
  self.random_resets_ratio = random_resets_ratio
941
1040
  else:
942
1041
  self.random_resets_ratio = 1.0
943
1042
 
944
- # 11. Apply the unfreeze strategy
945
- self._apply_unfreeze_strategy(epoch, unfreeze_epoch)
1043
+ # 12. Apply the unfreeze strategy
1044
+ if callable(unfreeze_epoch):
1045
+ unfreeze_epoch(epoch)
1046
+ else:
1047
+ self._apply_unfreeze_strategy(epoch, unfreeze_epoch)
946
1048
 
947
- # 12. Set epoch for distributed sampler
1049
+ # 13. Set epoch for distributed sampler
948
1050
  if train_sampler is not None:
949
1051
  train_sampler.set_epoch(epoch)
950
1052
 
951
- # 13. Run reinforcement learning algorithms for current epoch
1053
+ # 14. Run reinforcement learning algorithms for current epoch
952
1054
  policy_loss, critic_loss = self.train_epoch(dataloader, epoch, batch_size)
953
1055
 
954
- # 14. If evaluation dataset is provided, run evaluation steps
1056
+ # 15. If evaluation dataset is provided, run evaluation steps
955
1057
  if self.eval_dataset:
956
1058
  should_stop_stage = self.evaluate(batch_size, epoch)
957
1059
  else:
958
1060
  should_stop_stage = False
959
1061
 
960
- # 15. Finally, run "on epoch end" callbacks (save models, etc.)
1062
+ # 16. Finally, run "on epoch end" callbacks (save models, etc.)
961
1063
  for cb in self.callbacks:
962
1064
  cb.on_epoch_end(self.actor, epoch, epochs, policy_loss, critic_loss, self.global_epoch,
963
1065
  self.global_epochs_count)
964
1066
 
965
- # 16. Synchronize TensorBoard writer
1067
+ # 17. Synchronize TensorBoard writer
966
1068
  if self.writer:
967
1069
  self.writer.flush()
968
1070
 
969
- # 17. Synchronize devices in DDP mode
1071
+ # 18. Synchronize devices in DDP mode
970
1072
  if self.use_ddp:
971
1073
  dist.barrier()
972
1074
 
973
- # 18. Finish curriculum stage if rewards are not increased or reached threshold point
1075
+ # 19. Finish curriculum stage if rewards are not increased or reached threshold point
974
1076
  if should_stop_stage:
975
1077
  break
976
1078
 
977
- # 19. Run "on_training_end" callbacks after each curriculum stage (they have own callbacks)
1079
+ # 20. Run "on_training_end" callbacks after each curriculum stage (they have own callbacks)
978
1080
  for cb in self.callbacks:
979
1081
  cb.on_training_end(self.actor, self.critic, current_curriculum_step)
980
1082
 
981
- # 20. Training end - finish processes after all curriculum stages
1083
+ # 21. Training end - finish processes after all curriculum stages
982
1084
  if self.use_ddp:
983
1085
  dist.destroy_process_group()
984
1086
 
985
- # 21. Close writer
1087
+ # 22. Close writer
986
1088
  if self.writer:
987
1089
  self.writer.close()
rxnn/training/rl.py CHANGED
@@ -2,8 +2,9 @@ import torch
2
2
  import torch.nn as nn
3
3
  import torch.nn.functional as F
4
4
  from abc import ABC, abstractmethod
5
- from typing import TypedDict
5
+ from typing import TypedDict, Optional
6
6
  from .utils import TokenizedDict
7
+ from .ddp import distributed_mean
7
8
 
8
9
 
9
10
  class RlAlgorithm(ABC):
@@ -17,21 +18,34 @@ class RlAlgorithm(ABC):
17
18
  pass
18
19
 
19
20
  @abstractmethod
20
- def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
21
+ def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
21
22
  pass
22
23
 
23
24
  def critic_loss(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
24
25
  return self.critic_loss(rewards, values)
25
26
 
27
+
26
28
  class PPOConfig(TypedDict):
27
- clip_eps: float
29
+ clip_eps: Optional[float]
30
+ gae_lambda: Optional[float]
31
+ gae_gamma: Optional[float]
32
+ entropy_coef: Optional[float]
33
+ use_distributed_advantage_norm: Optional[bool]
34
+
28
35
 
29
36
  class PPOAlgorithm(RlAlgorithm):
30
- def __init__(self, config: PPOConfig):
37
+ def __init__(self, config: Optional[PPOConfig] = None):
31
38
  super(PPOAlgorithm, self).__init__()
32
39
 
40
+ if config is None:
41
+ config = {}
42
+
33
43
  # PPO Config
34
44
  self.clip_eps = config.get('clip_eps', 0.2)
45
+ self.gae_lambda = config.get('gae_lambda', 0.95)
46
+ self.gae_gamma = config.get('gae_gamma', 0.99)
47
+ self.entropy_coef = config.get('entropy_coef', 0.01)
48
+ self.use_distributed_advantage_norm = config.get('use_distributed_advantage_norm', False)
35
49
 
36
50
  def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
37
51
  old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
@@ -78,11 +92,37 @@ class PPOAlgorithm(RlAlgorithm):
78
92
 
79
93
  # d) Entropy bonus
80
94
  entropy = -torch.sum(new_probs * new_probs.exp(), dim=-1).mean()
81
- policy_loss -= 0.01 * entropy
95
+ policy_loss -= self.entropy_coef * entropy
82
96
 
83
97
  return policy_loss
84
98
 
85
- def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
86
- advantages = rewards - values
87
- normalized_advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
88
- return normalized_advantages
99
+ def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor, next_value: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
100
+ T, B = rewards.shape
101
+ advantages = torch.zeros_like(rewards, device=values.device)
102
+ last_advantage = 0
103
+ last_value = next_value.detach()
104
+
105
+ for t in reversed(range(T)):
106
+ if t == T - 1:
107
+ next_values = last_value
108
+ else:
109
+ next_values = values[t + 1]
110
+
111
+ # Mask next values if episode ended
112
+ next_values = next_values * ~dones[t]
113
+ delta = rewards[t] + self.gae_gamma * next_values - values[t]
114
+ advantages[t] = delta + self.gae_gamma * self.gae_lambda * last_advantage
115
+ last_advantage = advantages[t]
116
+
117
+ returns = advantages + values
118
+ return advantages, returns
119
+
120
+ def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
121
+ advantages, ref_values = self._compute_gae(rewards[:-1], values[:-1], values[-1], dones[:-1])
122
+ if self.use_distributed_advantage_norm:
123
+ mean_advantage = distributed_mean(advantages.mean())
124
+ std_advantage = distributed_mean(advantages.std())
125
+ normalized_advantages = (advantages - mean_advantage) / (std_advantage + 1e-8)
126
+ else:
127
+ normalized_advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
128
+ return normalized_advantages, ref_values
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.2.29
3
+ Version: 0.2.31
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -11,14 +11,15 @@ rxnn/memory/stm.py,sha256=IH_3INw7FdI013t56ui3Zq9GPUq-k3HeZGjx6BerS4g,3888
11
11
  rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
12
  rxnn/rxt/models.py,sha256=r8wZeeNTC2VAhiiNe4y7LrbnB4wjFu_cupKiGkpdgjI,13002
13
13
  rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
15
- rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
16
- rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
14
+ rxnn/training/base.py,sha256=TGz_37RfI1qLI31GNRV5rLowW1kAHnJwqPm7DNfLfe4,11730
15
+ rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
16
+ rxnn/training/callbacks.py,sha256=p72lbzFAmFjpcUvyy4aUB3qd53I8C6Sk5w9nQvsKgTk,35852
17
17
  rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
18
+ rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
18
19
  rxnn/training/models.py,sha256=2KhNT7yx0AgUke4nmsFqzQKx_YYp78QvsLWYZjWeUgQ,6812
19
- rxnn/training/mrl.py,sha256=MnLaYWxblc5cF261R5PNjIvddVQVNxyjAkEYtchBn9E,49299
20
+ rxnn/training/mrl.py,sha256=Aimiiqf_4p6dp5Ty9pY9VwetySBS_OFpCQlcVHVkO4Q,55124
20
21
  rxnn/training/reward.py,sha256=7MTVdNm5HnWmt6zFDi3TAYmnVSL_-24riOoY2F7z4x8,11290
21
- rxnn/training/rl.py,sha256=j-KNLoZjhaEKasYNOc8DxHtwvknAgAJFwvXKot6otFA,3272
22
+ rxnn/training/rl.py,sha256=eL3C0yryiNBgl_xb-D-5dyYUtK4V4-K4t3a60x5ir28,5142
22
23
  rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
23
24
  rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
24
25
  rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
@@ -32,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
32
33
  rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
33
34
  rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
34
35
  rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
35
- rxnn-0.2.29.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
36
- rxnn-0.2.29.dist-info/METADATA,sha256=WVEyKmyYbMOb5sm7vjjnCN9j8ABz0QfGJCYkQbWvwT8,25960
37
- rxnn-0.2.29.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
38
- rxnn-0.2.29.dist-info/RECORD,,
36
+ rxnn-0.2.31.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
37
+ rxnn-0.2.31.dist-info/METADATA,sha256=zxD2qPHL_QrFH1bYZrMv4odbXE4B_YIVEpGDzV2MYEI,25960
38
+ rxnn-0.2.31.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
39
+ rxnn-0.2.31.dist-info/RECORD,,
File without changes
File without changes