rxnn 0.2.30__tar.gz → 0.2.31__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.30 → rxnn-0.2.31}/PKG-INFO +1 -1
  2. {rxnn-0.2.30 → rxnn-0.2.31}/pyproject.toml +1 -1
  3. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/training/base.py +4 -5
  4. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/training/bml.py +7 -12
  5. rxnn-0.2.31/src/rxnn/training/ddp.py +26 -0
  6. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/training/mrl.py +161 -104
  7. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/training/rl.py +21 -8
  8. {rxnn-0.2.30 → rxnn-0.2.31}/LICENSE +0 -0
  9. {rxnn-0.2.30 → rxnn-0.2.31}/README.md +0 -0
  10. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/.DS_Store +0 -0
  11. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/__init__.py +0 -0
  12. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/experimental/__init__.py +0 -0
  13. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/experimental/attention.py +0 -0
  14. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/experimental/models.py +0 -0
  15. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/experimental/moe.py +0 -0
  16. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/memory/__init__.py +0 -0
  17. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/memory/attention.py +0 -0
  18. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/memory/norm.py +0 -0
  19. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/memory/stm.py +0 -0
  20. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/rxt/__init__.py +0 -0
  21. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/rxt/models.py +0 -0
  22. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/training/__init__.py +0 -0
  23. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/training/callbacks.py +0 -0
  24. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/training/dataset.py +0 -0
  25. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/training/models.py +0 -0
  26. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/training/reward.py +0 -0
  27. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/training/scheduler.py +0 -0
  28. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/training/tokenizer.py +0 -0
  29. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/training/utils.py +0 -0
  30. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/transformers/__init__.py +0 -0
  31. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/transformers/attention.py +0 -0
  32. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/transformers/ff.py +0 -0
  33. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/transformers/layers.py +0 -0
  34. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.30 → rxnn-0.2.31}/src/rxnn/utils.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.30
+Version: 0.2.31
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

 [tool.poetry]
 name = "rxnn"
-version = "0.2.30"
+version = "0.2.31"
 description = "RxNN: Reactive Neural Networks Platform"

 license = "Apache-2.0"
@@ -8,6 +8,7 @@ import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
 from typing import Callable
 from .callbacks import TrainerCallback
+from .ddp import get_os_ddp_config, distributed_value_mean


 class BaseTrainer(ABC):
@@ -91,8 +92,7 @@ class BaseTrainer(ABC):
        optimizer = self.optimizer

        if self.use_ddp:
-            rank = int(os.environ['RANK'])
-            world_size = int(os.environ['WORLD_SIZE'])
+            rank, world_size = get_os_ddp_config()
            dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
            self.model = DistributedDataParallel(self.model, device_ids=[self.device.index])
            train_sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
@@ -218,10 +218,9 @@ class BaseTrainer(ABC):
            if self.validation_dataset:
                self.validation_steps = 0
                val_loss, val_metrics = self.validate(batch_size)
-                val_loss_tensor = torch.tensor(val_loss).to(self.device)
                if self.use_ddp:
-                    dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.SUM)
-                    val_loss = val_loss_tensor.item() / dist.get_world_size()
+                    val_loss = distributed_value_mean(val_loss, device=self.device)
+
                self.validation_metrics[epoch] = val_metrics

                if self.writer:
@@ -7,6 +7,7 @@ import torch.distributed as dist
 from ..transformers.models import ReactiveTransformerDecoder
 from ..training.base import BaseTrainer
 from .models import MLMTrainingModel, JointTrainingModel
+from .ddp import distributed_mean

 class MLMTrainer(BaseTrainer):
     def __init__(
@@ -96,8 +97,7 @@ class MLMTrainer(BaseTrainer):
        acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
        node_acc = acc.item()
        if self.use_ddp:
-            dist.all_reduce(acc, op=dist.ReduceOp.SUM)
-            acc = acc / dist.get_world_size()
+            acc = distributed_mean(acc)

        metrics = {
            'accuracy': acc.item(),
@@ -198,8 +198,7 @@ class AutoregressiveTrainer(BaseTrainer):
        acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
        node_acc = acc.item()
        if self.use_ddp:
-            dist.all_reduce(acc, op=dist.ReduceOp.SUM)
-            acc = acc / dist.get_world_size()
+            acc = distributed_mean(acc)

        metrics = {
            'accuracy': acc.item(),
@@ -347,14 +346,10 @@ class JointLMTrainer(BaseTrainer):
        node_mlm_acc = mlm_acc.item()
        node_alm_acc = alm_acc.item()
        if self.use_ddp:
-            dist.all_reduce(avg_dec_loss, op=dist.ReduceOp.SUM)
-            dist.all_reduce(avg_enc_loss, op=dist.ReduceOp.SUM)
-            dist.all_reduce(mlm_acc, op=dist.ReduceOp.SUM)
-            dist.all_reduce(alm_acc, op=dist.ReduceOp.SUM)
-            avg_dec_loss = avg_dec_loss / dist.get_world_size()
-            avg_enc_loss = avg_enc_loss / dist.get_world_size()
-            mlm_acc = mlm_acc / dist.get_world_size()
-            alm_acc = alm_acc / dist.get_world_size()
+            avg_dec_loss = distributed_mean(avg_dec_loss)
+            avg_enc_loss = distributed_mean(avg_enc_loss)
+            mlm_acc = distributed_mean(mlm_acc)
+            alm_acc = distributed_mean(alm_acc)

        metrics = {
            'accuracy': {
@@ -0,0 +1,26 @@
+import torch
+import torch.distributed as dist
+import os
+from ..utils import set_random_seed
+
+def get_os_ddp_config():
+    rank = int(os.environ['RANK'])
+    world_size = int(os.environ['WORLD_SIZE'])
+    return rank, world_size
+
+def distributed_mean(x: torch.Tensor) -> torch.Tensor:
+    """Average tensor across all devices"""
+    x = x.clone()
+    dist.all_reduce(x, op=dist.ReduceOp.SUM)
+    x /= dist.get_world_size()
+    return x
+
+def distributed_value_mean(value: float, device: torch.device = None) -> float:
+    """Average float value across all devices"""
+    tensor = torch.tensor(value, device=device)
+    reduced = distributed_mean(tensor)
+    return reduced.item()
+
+def set_distributed_random_seed(seed: int):
+    rank = dist.get_rank() if dist.is_initialized() else get_os_ddp_config()[0]
+    set_random_seed(seed + rank)
@@ -3,7 +3,7 @@ from torch.utils.data import DataLoader, DistributedSampler
 from torch.utils.tensorboard import SummaryWriter
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
-from typing import Optional, TypedDict, Union, TypeAlias, Literal
+from typing import Optional, TypedDict, Union, TypeAlias, Literal, Callable
 from enum import Enum
 import random, os
 from ..transformers.sampler import BatchSampler
@@ -13,7 +13,7 @@ from .utils import smart_concat, smart_concat_critic_states, TokenizedDict
 from .rl import RlAlgorithm
 from .reward import MrlRewardMode, MrlRewardModel
 from .models import MrlActorAction, MrlActorModel, MrlCriticModel
-
+from .ddp import get_os_ddp_config, distributed_mean

 class MrlConfig(TypedDict):
     lr: float
@@ -25,6 +25,10 @@ class MrlConfig(TypedDict):
     weight_decay: float
     critic_weight_decay: float
     update_epochs: int
+    pad_token_id: int
+    end_token_id: int
+    callbacks: Optional[list[MrlTrainerCallback]]
+    memory_aware_critic: bool


 class MrlStrategy(Enum):
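With this release, pad_token_id, end_token_id, callbacks and the new memory_aware_critic flag are read from MrlConfig rather than passed to the MRLTrainer constructor (see the constructor hunk below). A hedged sketch of a new-style config; the values are illustrative and only the keys visible in this diff are listed:

from rxnn.training.mrl import MrlConfig

config: MrlConfig = {
    'lr': 1e-4,                    # illustrative hyperparameters
    'weight_decay': 0.01,
    'critic_weight_decay': 0.01,
    'update_epochs': 10,
    'pad_token_id': 0,             # moved from the constructor (old default: 0)
    'end_token_id': 3,             # moved from the constructor (old default: 3)
    'callbacks': [],               # moved from the constructor
    'memory_aware_critic': False,  # new option, defaults to False
}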
@@ -33,8 +37,9 @@ class MrlStrategy(Enum):
     LONG_RANGE_STRATEGY = 3


+UnfreezeStrategyFn = Callable[[int], None]
 UnfreezeItem = Union[int, tuple[int, float]]
-UnfreezeEpochsStrategy: TypeAlias = Union[int, tuple[UnfreezeItem, UnfreezeItem, UnfreezeItem, int]]
+UnfreezeEpochsStrategy: TypeAlias = Union[int, tuple[UnfreezeItem, UnfreezeItem, UnfreezeItem, int], UnfreezeStrategyFn]


 class CurriculumConfig(TypedDict):
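UnfreezeEpochsStrategy now also accepts a callable taking an epoch index; as the training-loop hunks further below show, it is called with -1 before a curriculum stage starts and then once per epoch, in place of the built-in freeze/unfreeze handling. A minimal sketch of such a strategy, assuming an MrlActorModel instance named actor is in scope; anything beyond the 'joint' freeze mode visible in this diff is left as a comment:

def my_unfreeze_strategy(epoch: int) -> None:
    # epoch == -1: invoked once before the curriculum stage starts
    if epoch == -1:
        actor.freeze_components('joint')   # mem-att/cross-att trainable, rest frozen
        print("Custom unfreeze: starting stage in 'joint' warmup mode")
    elif epoch == 3:
        # from a chosen epoch on, adjust trainable components as the experiment requires
        print("Custom unfreeze: epoch 3 reached, adjusting trainable components")

The function is then passed wherever an UnfreezeEpochsStrategy is expected (the curriculum value unpacked as unfreeze_epoch in the training loop).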
@@ -89,21 +94,25 @@ class MRLTrainer:
            rl_algorithm: RlAlgorithm,
            sampler_config: Optional[SamplerConfig] = None,
            log_dir: str = None,
-            pad_token_id: int = 0,
-            end_token_id: int = 3,
            use_ddp: bool = False,
            use_amp: bool = False,
            dtype: torch.dtype = torch.float32,
-            callbacks: list[MrlTrainerCallback] = None,
-
    ):
        """
-        Trainer for Memory Reinforcement Learning (MRL) in Reactive Transformer.
+        Trainer for Memory Reinforcement Learning (MRL) algorithm for reactive models and Attention-Based Memory System.

        Args:
-            actor: MRL Actor model with encoder, decoder and memory attention.
-            critic: Critic network for advantage estimation.
-            config: Configuration dictionary with hyperparameters.
+            actor (MrlActorModel): MRL Actor model with encoder, decoder and memory attention.
+            critic (MrlCriticModel): MRL Critic network for advantage estimation.
+            reward (MrlRewardModel): MRL Reward model or extension.
+            device (torch.device): Device used for training.
+            config (MrlConfig): Configuration dictionary with hyperparameters.
+            rl_algorithm (RlAlgorithm): Reinforcement Learning algorithm (currently only PPO available).
+            sampler_config (SamplerConfig): Sampler configuration.
+            log_dir (str): Log directory for TensorBoard logs.
+            use_ddp (bool): Use Distributed Data Parallel mode.
+            use_amp (bool): Use AMP Autocast for training.
+            dtype (torch.dtype): Data type used in training - in AMP mode it's auto cast, otherwise data and model are transformed to this type
        """
        self.actor = actor
        self.critic = critic
@@ -112,6 +121,7 @@ class MRLTrainer:
        self.device = device
        self.max_seq_len = config.get('max_seq_len', 256)
        self.critic_max_len = config.get('critic_max_len', 512)
+        self.memory_aware_critic = config.get('memory_aware_critic', False)
        # Internal update epochs config
        self.shared_update_epochs = config.get('update_epochs', 10)
        self.update_epochs = self.shared_update_epochs
@@ -125,14 +135,15 @@ class MRLTrainer:
        self.critic.to(self.device, dtype=dtype)

        # Batch Sampler for answer generation
-        self.generator = BatchSampler(self.actor, self.device, end_token_id=end_token_id)
+        self.generator = None
        self.sampler_config = SamplerConfig(
            temperature=1.0,
            top_k=None,
            top_p=None,
        ) if sampler_config is None else sampler_config

-        self.pad_token_id = pad_token_id
+        self.pad_token_id = config.get('pad_token_id', 0)
+        self.end_token_id = config.get('end_token_id', 3)

        self.use_ddp = use_ddp
        self.use_amp = use_amp
@@ -180,7 +191,7 @@ class MRLTrainer:
        self.eval_dataset = None
        self.random_resets_ratio = 0.0
        self.strategy = None
-        self.shared_callbacks = callbacks if callbacks else []
+        self.shared_callbacks = config.get('callbacks', [])
        self.callbacks = []
        self.global_epoch = 0
        self.global_epochs_count = 0
@@ -218,7 +229,7 @@ class MRLTrainer:
    def _init_steps(self):
        return {
            'collect': 0,
-            'rl': 0,
+            'train': 0,
            'eval': 0,
        }
@@ -227,9 +238,12 @@ class MRLTrainer:
        self.epoch_step[step_type] += 1
        self.stage_step[step_type] += 1

-    def reset_stm(self) -> bool:
+    def reset_stm(self, force: bool = False) -> bool:
        """Reset Short-Term Memory state with random reset ratio."""
-        if self.random_resets_ratio == 1.0:
+        if force:
+            self.actor.reset_memory()
+            return True
+        elif self.random_resets_ratio == 1.0:
            self.actor.reset_memory()
            return True
        else:
@@ -439,33 +453,24 @@ class MRLTrainer:

        return trajectories

-    def _critic_loss(self, inputs: TokenizedDict, ref_values: torch.Tensor) -> torch.Tensor:
-        # 1. Calculate values with critic encoder
-        values = self.critic(
-            inputs['input_ids'],
-            attention_mask=inputs['attention_mask'],
-        ).squeeze()
-        # 2. Calculate critic loss
-        loss = self.rl_algorithm.critic_loss(values, ref_values)
-        return loss
-
    def _critic_writer(self, critic_loss: float, epoch: int):
        if self.writer is not None:
-            self.writer.add_scalar('Loss/critic (global)', critic_loss, self.global_step['rl'])
+            self.writer.add_scalar('Loss/critic (global)', critic_loss, self.global_step['train'])
            self.writer.add_scalar(f'Loss/critic (steps: {self.curriculum_steps}, epoch: {epoch})', critic_loss,
-                                   self.epoch_step['rl'])
+                                   self.epoch_step['train'])
            self.writer.add_scalar(f'Loss/critic (steps: {self.curriculum_steps})', critic_loss,
-                                   self.stage_step['rl'])
+                                   self.stage_step['train'])

    def _rl_writer(self, policy_loss: float, epoch: int):
        if self.writer is not None:
-            self.writer.add_scalar('Loss/policy (global)', policy_loss, self.global_step['rl'])
+            self.writer.add_scalar('Loss/policy (global)', policy_loss, self.global_step['train'])
            self.writer.add_scalar(f'Loss/policy (steps: {self.curriculum_steps}, epoch: {epoch})', policy_loss,
-                                   self.epoch_step['rl'])
-            self.writer.add_scalar(f'Loss/policy (steps: {self.curriculum_steps})', policy_loss, self.stage_step['rl'])
+                                   self.epoch_step['train'])
+            self.writer.add_scalar(f'Loss/policy (steps: {self.curriculum_steps})', policy_loss,
+                                   self.stage_step['train'])

-    def _update_critic(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], ref_values: torch.Tensor,
-                       epoch: int) -> float:
+    def update_critic(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], ref_values: torch.Tensor,
+                      epoch: int) -> float:
        # 1. Reset critic gradients
        self.critic_optimizer.zero_grad()
@@ -475,7 +480,8 @@ class MRLTrainer:
                # 2.1 Concat states and calculate critic loss
                critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
                                                          pad_token_id=self.pad_token_id)
-                critic_loss = self._critic_loss(critic_state, ref_values)
+                values = self.critic(critic_state['input_ids'], attention_mask=critic_state['attention_mask']).squeeze()
+                critic_loss = self.rl_algorithm.critic_loss(values, ref_values)
                # 2.2 Run backpropagation with scaler
                self.critic_scaler.scale(critic_loss).backward()
                # 2.3 Unscale and clip gradients
@@ -488,7 +494,8 @@ class MRLTrainer:
            # 2.1 Concat states and calculate critic loss
            critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
                                                      pad_token_id=self.pad_token_id)
-            critic_loss = self._critic_loss(critic_state, ref_values)
+            values = self.critic(critic_state['input_ids'], attention_mask=critic_state['attention_mask']).squeeze()
+            critic_loss = self.rl_algorithm.critic_loss(values, ref_values)
            # 2.2 Run backpropagation
            critic_loss.backward()
            # 2.3 Clip gradients
@@ -503,19 +510,20 @@ class MRLTrainer:

        # 5. Run "on critic updated" callbacks
        for cb in self.callbacks:
-            cb.on_critic_updated(self.actor, self.critic, epoch, self.epoch_step['rl'], critic_loss_item)
+            cb.on_critic_updated(self.actor, self.critic, epoch, self.epoch_step['train'], critic_loss_item)
        # 6. Return loss item
        return critic_loss_item

-    def _update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
-                      advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
+    def update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
+                     advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
        # 1. Reset actor gradients
        self.optimizer.zero_grad()
        # 2. Unpack state dicts
        query, answer, next_query = state

-        # 3. Encode and update STM on each step, to include encoder and memory attention gradients in loss
-        self.encode_and_update_stm(query, answer)
+        # 3. Encode and update STM on each step, to include encoder and memory attention gradients in loss (skip if it was updated before with memory aware critic)
+        if not self.memory_aware_critic:
+            self.encode_and_update_stm(query, answer)

        # 4. Update actor - with autocast on/off
        if self.use_amp:
@@ -560,13 +568,13 @@ class MRLTrainer:

        # 7. Run "on batch updated" callback
        for cb in self.callbacks:
-            cb.on_batch_updated(self.actor, epoch, self.epoch_step['rl'], policy_loss_item)
+            cb.on_batch_updated(self.actor, epoch, self.epoch_step['train'], policy_loss_item)

        # 8. Return loss item
        return policy_loss_item

    def rl_step(self, trajectories: list[MrlTrajectoryEpisode], advantages: torch.Tensor, ref_values: torch.Tensor,
-                epoch: int, batch_size: int) -> tuple[float, float]:
+                epoch: int, batch_size: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Perform PPO update step using trajectories."""
        # 1. Run update separately for episodes in trajectory - we have to reset memory before each episode, and update
        # memory, based on collected episode data
@@ -577,18 +585,19 @@ class MRLTrainer:
            should_reset_stm = episode['reset_stm']

            # 2. Get advantages and reference values for current full episode (batch_size * episode_steps)
-            start = episode_idx * episode_steps
-            end = start + episode_steps
+            num_steps = len(episode_steps)
+            start = episode_idx * num_steps
+            end = start + num_steps
            episode_critic_values = ref_values[start:end]
            episode_advantages = advantages[start:end]

            # 3. Reset memory for current batch episode
            if should_reset_stm:
-                self.reset_stm()
+                self.reset_stm(force=True)

            # 4. Run episode steps - each episode has number of steps depending on curriculum stage. Each step is run for all batch
            for step_idx, step in enumerate(episode_steps):
-                self._increment_steps('rl')
+                self._increment_steps('train')
                # 5. Get and move to device collected states, action and log probs
                state, action, _, log_probs = step['state'], step['action'], step['reward'], step['log_probs']
                query, answer, next_query = self._move_multiple_batches(*state)
@@ -599,41 +608,74 @@ class MRLTrainer:
                step_critic_values = episode_critic_values[step_idx]
                step_advantages = episode_advantages[step_idx]

+                # 7. In memory aware critic version, encode and update STM before critic update, to include its gradients in critic loss too
+                if self.memory_aware_critic:
+                    self.encode_and_update_stm(query, answer)
+
                # 7. Update critic
-                critic_loss_item = self._update_critic((query, answer, next_query), step_critic_values, epoch)
+                critic_loss_item = self.update_critic((query, answer, next_query), step_critic_values, epoch)

                # 8. Accumulate critic loss for epoch callbacks
                critic_losses.append(critic_loss_item)

                # 9. Update actor
-                policy_loss_item = self._update_actor((query, answer, next_query), action, step_advantages, log_probs,
-                                                      epoch)
+                policy_loss_item = self.update_actor((query, answer, next_query), action, step_advantages, log_probs,
+                                                     epoch)
                all_losses.append(policy_loss_item)
        # 10. Return mean losses for epoch callbacks
-        return torch.mean(torch.tensor(all_losses)).item(), torch.mean(torch.tensor(critic_losses)).item()
+        return torch.mean(torch.tensor(all_losses)), torch.mean(torch.tensor(critic_losses))

    def _critic_values_rewards_and_dones(self, trajectories: list[MrlTrajectoryEpisode]):
-        flat_trajectories = [t for episode in trajectories for t in episode['steps']]
-        values = [
-            self._critic_values(
-                smart_concat_critic_states(
-                    *self._move_multiple_batches(*t['state']),
-                    max_length=self.critic_max_len,
-                    pad_token_id=self.pad_token_id,
-                )
-            ) for t in flat_trajectories
-        ]
-        values = torch.stack(values).to(self.device)
-        rewards = torch.stack([torch.tensor(t['reward']) for t in flat_trajectories]).to(self.device)
-        dones = torch.stack([torch.BoolTensor(t['done']) for t in flat_trajectories]).to(self.device)
+        if self.memory_aware_critic:
+            flat_trajectories = [
+                (t, i == 0 and episode['reset_stm'])
+                for episode in trajectories
+                for i, t in enumerate(episode['steps'])
+            ]
+            values = torch.stack([
+                self._critic_values_with_memory(r, *self._move_multiple_batches(*t['state'])) for t, r in flat_trajectories
+            ]).to(self.device)
+            rewards = torch.stack([torch.tensor(t['reward']) for t, _ in flat_trajectories]).to(self.device)
+            dones = torch.stack([torch.tensor(t['done']) for t, _ in flat_trajectories]).to(self.device)
+        else:
+            flat_trajectories = [t for episode in trajectories for t in episode['steps']]
+            values = torch.stack([
+                self._critic_values(*self._move_multiple_batches(*t['state'])) for t in flat_trajectories
+            ]).to(self.device)
+            rewards = torch.stack([torch.tensor(t['reward']) for t in flat_trajectories]).to(self.device)
+            dones = torch.stack([torch.tensor(t['done']) for t in flat_trajectories]).to(self.device)
        return values, rewards, dones

-    def _critic_values(self, inputs: TokenizedDict) -> torch.Tensor:
+    def _critic_values_with_memory(self, reset_stm: bool, *moved_state: tuple[TokenizedDict, TokenizedDict, TokenizedDict]) -> torch.Tensor:
+        # 1. Calculate critic values in memory aware version - reset/update STM before calculating values
        with torch.no_grad():
+            # 2. Reset STM if it was reset in trajectory collection
+            if reset_stm:
+                self.reset_stm(force=True)
+            # 3. Encode and update STM for critic
+            self.encode_and_update_stm(*moved_state)
+            # 4. Get concatenated critic states
+            inputs = smart_concat_critic_states(
+                *moved_state,
+                max_length=self.critic_max_len,
+                pad_token_id=self.pad_token_id,
+            )
+            # 5. Calculate values for current batch
            return self.critic(inputs['input_ids'],
-                               attention_mask=inputs['attention_mask']).squeeze()
+                               attention_mask=inputs['attention_mask']).squeeze()

-        # return self.rl_algorithm.calculate_advantages(rewards, values)
+    def _critic_values(self, *moved_state: tuple[TokenizedDict, TokenizedDict, TokenizedDict]) -> torch.Tensor:
+        # 1. Calculate critic values
+        with torch.no_grad():
+            # 2. Get concatenated critic states
+            inputs = smart_concat_critic_states(
+                *moved_state,
+                max_length=self.critic_max_len,
+                pad_token_id=self.pad_token_id,
+            )
+            # 3. Calculate values for current batch
+            return self.critic(inputs['input_ids'],
+                               attention_mask=inputs['attention_mask']).squeeze()

    def train_epoch(self, dataloader: DataLoader, epoch: int, batch_size: int):
        """Train for one epoch."""
@@ -655,16 +697,23 @@ class MRLTrainer:
            # 4. Run 'on update epoch start' callbacks
            for cb in self.callbacks:
                cb.on_update_epoch_start(self.actor, self.critic, epoch, update_epoch)
+
            # 5. Run RL algorithm step
            policy_loss, critic_loss = self.rl_step(trajectories[:-1], advantages, ref_values, epoch, batch_size)

+            if self.use_ddp:
+                policy_loss = distributed_mean(policy_loss)
+                critic_loss = distributed_mean(critic_loss)
+
+            # 6. Run 'on update epoch end' callbacks
            for cb in self.callbacks:
                cb.on_update_epoch_end(self.actor, self.critic, epoch, update_epoch, policy_loss, critic_loss)

+            # 7. Accumulate losses for epoch callbacks
            critic_loss_sum += critic_loss
            policy_loss_sum += policy_loss

-        # 6. Return policy and critic mean losses for epoch callbacks
+        # 8. Return policy and critic mean losses for epoch callbacks
        return policy_loss_sum / self.update_epochs, critic_loss_sum / self.update_epochs

    def _eval_loader(self, batch_size: int):
@@ -770,12 +819,11 @@ class MRLTrainer:
                cb.on_eval_episode_end(self.actor, epoch, self.epoch_step['eval'], avg_episode_reward)

        # 15. Calculate average reward
+        avg_reward = (total_reward / count) if count > 0 else torch.tensor(0.0).to(self.device)
        if self.use_ddp:
-            total_sum = dist.all_reduce(total_reward, dist.ReduceOp.SUM)
-            count_sum = dist.all_reduce(count, dist.ReduceOp.SUM)
-            avg_reward = (total_sum / count_sum).item() if count_sum > 0 else 0
-        else:
-            avg_reward = (total_reward / count).item() if count > 0 else 0
+            avg_reward = distributed_mean(avg_reward)
+
+        avg_reward = avg_reward.item()

        should_stop_stage = False
        # 16. Run "on eval end" callbacks
@@ -924,31 +972,36 @@ class MRLTrainer:

        # 1. Init DDP for distributed training mode
        if self.use_ddp:
-            rank = int(os.environ['RANK'])
-            world_size = int(os.environ['WORLD_SIZE'])
+            rank, world_size = get_os_ddp_config()
            dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
            self.actor = DistributedDataParallel(self.actor, device_ids=[self.device.index])
            self.critic = DistributedDataParallel(self.critic, device_ids=[self.device.index])

-        # 2. Run each curriculum step based on config
+        # 2. Init BatchSampler with actor model (we have to run it after DDP init)
+        self.generator = BatchSampler(self.actor, self.device, end_token_id=self.end_token_id)
+
+        # 3. Run each curriculum step based on config
        for current_curriculum_step in curriculum_config:
-            # 3. Setup training config for curriculum step
+            # 4. Setup training config for curriculum step
            epochs_config, random_resets_config = self._setup_curriculum_step(current_curriculum_step)
            epochs, unfreeze_epoch = epochs_config
            random_resets, random_resets_from, random_resets_ratio = random_resets_config
            assert self.train_dataset is not None

-            # 4. Freeze all components except memory attention and memory cross-attention layers in decoder/encoder
+            # 5. Freeze all components except memory attention and memory cross-attention layers in decoder/encoder
            if unfreeze_epoch != 0:
-                self.actor.freeze_components('joint')
-                if isinstance(unfreeze_epoch, tuple):
-                    print(
-                        f"Starting training with unfreeze strategies - 'warmup' - mem-att/cross-att trainable / rest model frozen")
+                if callable(unfreeze_epoch):
+                    unfreeze_epoch(-1)
                else:
-                    print(
-                        f"Starting training with simple unfreeze - 'joint' - mem-att/cross-att trainable / rest model frozen")
-
-            # 5. Setup train DataLoader
+                    self.actor.freeze_components('joint')
+                    if isinstance(unfreeze_epoch, tuple):
+                        print(
+                            f"Starting training with unfreeze strategies - 'warmup' - mem-att/cross-att trainable / rest model frozen")
+                    else:
+                        print(
+                            f"Starting training with simple unfreeze - 'joint' - mem-att/cross-att trainable / rest model frozen")
+
+            # 6. Setup train DataLoader
            if self.use_ddp:
                train_sampler = DistributedSampler(self.train_dataset, shuffle=True)
                dataloader = DataLoader(
@@ -957,6 +1010,7 @@ class MRLTrainer:
                    sampler=train_sampler,
                    pin_memory=True,
                    collate_fn=MrlCurriculumDataset.collate_mrl_batch,
+                    drop_last=True,
                )
            else:
                train_sampler = None
@@ -968,65 +1022,68 @@ class MRLTrainer:
                    collate_fn=MrlCurriculumDataset.collate_mrl_batch,
                )

-            # 6. Run selected number of epochs for given curriculum stage
+            # 7. Run selected number of epochs for given curriculum stage
            for epoch in range(epochs):
-                # 7. Increment global epoch
+                # 8. Increment global epoch
                self.global_epoch += 1
-                # 8. Run "on epoch start" callbacks (log info, etc.)
+                # 9. Run "on epoch start" callbacks (log info, etc.)
                for cb in self.callbacks:
                    cb.on_epoch_start(self.actor, epoch, epochs, current_curriculum_step, self.global_epoch,
                                      self.global_epochs_count)

-                # 9. Reset steps counter for epoch
+                # 10. Reset steps counter for epoch
                self.epoch_step = self._init_steps()

-                # 10. Set random STM resets ratio from selected epoch
+                # 11. Set random STM resets ratio from selected epoch
                if random_resets and random_resets_from <= epoch:
                    self.random_resets_ratio = random_resets_ratio
                else:
                    self.random_resets_ratio = 1.0

-                # 11. Apply the unfreeze strategy
-                self._apply_unfreeze_strategy(epoch, unfreeze_epoch)
+                # 12. Apply the unfreeze strategy
+                if callable(unfreeze_epoch):
+                    unfreeze_epoch(epoch)
+                else:
+                    self._apply_unfreeze_strategy(epoch, unfreeze_epoch)

-                # 12. Set epoch for distributed sampler
+                # 13. Set epoch for distributed sampler
                if train_sampler is not None:
                    train_sampler.set_epoch(epoch)

-                # 13. Run reinforcement learning algorithms for current epoch
+                # 14. Run reinforcement learning algorithms for current epoch
                policy_loss, critic_loss = self.train_epoch(dataloader, epoch, batch_size)

-                # 14. If evaluation dataset is provided, run evaluation steps
+                # 15. If evaluation dataset is provided, run evaluation steps
                if self.eval_dataset:
                    should_stop_stage = self.evaluate(batch_size, epoch)
                else:
                    should_stop_stage = False

-                # 15. Finally, run "on epoch end" callbacks (save models, etc.)
+                # 16. Finally, run "on epoch end" callbacks (save models, etc.)
                for cb in self.callbacks:
                    cb.on_epoch_end(self.actor, epoch, epochs, policy_loss, critic_loss, self.global_epoch,
                                    self.global_epochs_count)

-                # 16. Synchronize TensorBoard writer
+                # 17. Synchronize TensorBoard writer
                if self.writer:
                    self.writer.flush()

-                # 17. Synchronize devices in DDP mode
+                # 18. Synchronize devices in DDP mode
                if self.use_ddp:
                    dist.barrier()

-                # 18. Finish curriculum stage if rewards are not increased or reached threshold point
+                # 19. Finish curriculum stage if rewards are not increased or reached threshold point
                if should_stop_stage:
                    break

-            # 19. Run "on_training_end" callbacks after each curriculum stage (they have own callbacks)
+            # 20. Run "on_training_end" callbacks after each curriculum stage (they have own callbacks)
            for cb in self.callbacks:
                cb.on_training_end(self.actor, self.critic, current_curriculum_step)

-        # 20. Training end - finish processes after all curriculum stages
+        # 21. Training end - finish processes after all curriculum stages
        if self.use_ddp:
            dist.destroy_process_group()

-        # 21. Close writer
+        # 22. Close writer
        if self.writer:
            self.writer.close()
@@ -2,8 +2,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from abc import ABC, abstractmethod
-from typing import TypedDict
+from typing import TypedDict, Optional
 from .utils import TokenizedDict
+from .ddp import distributed_mean


 class RlAlgorithm(ABC):
@@ -23,21 +24,28 @@ class RlAlgorithm(ABC):
    def critic_loss(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        return self.critic_loss(rewards, values)

+
 class PPOConfig(TypedDict):
-    clip_eps: float
-    gae_lambda: float
-    gae_gamma: float
-    entropy_coef: float
+    clip_eps: Optional[float]
+    gae_lambda: Optional[float]
+    gae_gamma: Optional[float]
+    entropy_coef: Optional[float]
+    use_distributed_advantage_norm: Optional[bool]
+

 class PPOAlgorithm(RlAlgorithm):
-    def __init__(self, config: PPOConfig):
+    def __init__(self, config: Optional[PPOConfig] = None):
        super(PPOAlgorithm, self).__init__()

+        if config is None:
+            config = {}
+
        # PPO Config
        self.clip_eps = config.get('clip_eps', 0.2)
        self.gae_lambda = config.get('gae_lambda', 0.95)
        self.gae_gamma = config.get('gae_gamma', 0.99)
        self.entropy_coef = config.get('entropy_coef', 0.01)
+        self.use_distributed_advantage_norm = config.get('use_distributed_advantage_norm', False)

    def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
                    old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
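Since config is now optional and every PPOConfig field has a default, the algorithm can be constructed without arguments. A brief sketch with illustrative overrides:

from rxnn.training.rl import PPOAlgorithm

# Defaults: clip_eps=0.2, gae_lambda=0.95, gae_gamma=0.99, entropy_coef=0.01,
# use_distributed_advantage_norm=False
algorithm = PPOAlgorithm()

# Selective overrides; the distributed advantage normalization averages the per-rank
# advantage mean/std across devices and is intended for DDP runs.
algorithm = PPOAlgorithm({'clip_eps': 0.1, 'use_distributed_advantage_norm': True})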
@@ -101,7 +109,7 @@ class PPOAlgorithm(RlAlgorithm):
            next_values = values[t + 1]

            # Mask next values if episode ended
-            next_values = next_values * (1 - dones[t])
+            next_values = next_values * ~dones[t]
            delta = rewards[t] + self.gae_gamma * next_values - values[t]
            advantages[t] = delta + self.gae_gamma * self.gae_lambda * last_advantage
            last_advantage = advantages[t]
@@ -111,5 +119,10 @@ class PPOAlgorithm(RlAlgorithm):

    def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        advantages, ref_values = self._compute_gae(rewards[:-1], values[:-1], values[-1], dones[:-1])
-        normalized_advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
+        if self.use_distributed_advantage_norm:
+            mean_advantage = distributed_mean(advantages.mean())
+            std_advantage = distributed_mean(advantages.std())
+            normalized_advantages = (advantages - mean_advantage) / (std_advantage + 1e-8)
+        else:
+            normalized_advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        return normalized_advantages, ref_values
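The GAE masking switches from the arithmetic (1 - dones[t]) to bitwise negation, which assumes dones is a boolean tensor (consistent with the torch.tensor(t['done']) stacking in mrl.py above); subtracting a bool tensor from an int raises an error in current PyTorch. A small illustration of the equivalence:

import torch

dones = torch.tensor([False, True, False])      # episode-end flags, bool dtype
next_values = torch.tensor([1.0, 2.0, 3.0])

masked = next_values * ~dones                   # -> tensor([1., 0., 3.])
masked_old = next_values * (1 - dones.float())  # old float formulation, same result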