rxnn 0.2.29__py3-none-any.whl → 0.2.30__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
rxnn/training/callbacks.py CHANGED
@@ -536,6 +536,9 @@ class MrlTrainerCallback:
  def on_reward(self, actor: nn.Module, reward: float, generated: str, reference: str, saved_data: str, eval_mode: bool) -> None:
  pass
 
+ def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
+ pass
+
  def on_batch_updated(self, actor: nn.Module, epoch: int, step: int, policy_loss: float) -> None:
  pass
 
@@ -543,6 +546,9 @@ class MrlTrainerCallback:
  critic_loss: float) -> None:
  pass
 
+ def on_update_epoch_end(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int, policy_loss: float, critic_loss: float) -> None:
+ pass
+
  def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
  pass
 
@@ -572,6 +578,9 @@ class MrlPrintCallback(MrlTrainerCallback):
  reference: dict[str, torch.Tensor], saved_data: dict[str, torch.Tensor], eval_mode: bool) -> None:
  print(f"{'Eval' if eval_mode else 'Train'} | Collected reward {reward}")
 
+ def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
+ print(f'Epoch {global_epoch} | Starting update epoch {update_epoch}')
+
  def on_batch_updated(self, actor: nn.Module, epoch: int, step: int, policy_loss: float) -> None:
  print(f'Epoch {epoch} | Step {step} - updated policy loss {policy_loss}')
 
@@ -579,6 +588,9 @@ class MrlPrintCallback(MrlTrainerCallback):
  critic_loss: float) -> None:
  print(f'Epoch {epoch} | Step {step} - updated critic loss {critic_loss}')
 
+ def on_update_epoch_end(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int, policy_loss: float, critic_loss: float) -> None:
+ print(f'Epoch {global_epoch} | Update epoch {update_epoch} - mean policy loss {policy_loss} | mean critic loss {critic_loss}')
+
  def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
  print(f'Finished training for {curriculum_config["steps"]} steps in {curriculum_config["strategy"]} strategy.')
 
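The two hooks added above, on_update_epoch_start and on_update_epoch_end, are invoked by MRLTrainer around each internal PPO update epoch (see the train_epoch changes below). As a rough illustration only, here is a minimal custom callback sketch; the import path is an assumption based on the wheel's file layout, and only the hook signatures come from this diff:

import torch.nn as nn
from rxnn.training.callbacks import MrlTrainerCallback  # assumed import path


class UpdateEpochLossTracker(MrlTrainerCallback):
    # Hypothetical callback that records the mean losses reported for each internal update epoch.
    def __init__(self) -> None:
        super().__init__()
        self.history: list[dict] = []
        self._current: dict = {}

    def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
        # Called before each internal update epoch of the RL phase.
        self._current = {'epoch': global_epoch, 'update_epoch': update_epoch}

    def on_update_epoch_end(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int,
                            policy_loss: float, critic_loss: float) -> None:
        # Called with the mean policy/critic losses of the finished update epoch.
        self._current.update(policy_loss=policy_loss, critic_loss=critic_loss)
        self.history.append(self._current)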
rxnn/training/mrl.py CHANGED
@@ -24,6 +24,7 @@ class MrlConfig(TypedDict):
  critic_max_len: int
  weight_decay: float
  critic_weight_decay: float
+ update_epochs: int
 
 
  class MrlStrategy(Enum):
@@ -31,9 +32,11 @@ class MrlStrategy(Enum):
  MULTI_STEP_STRATEGY = 2
  LONG_RANGE_STRATEGY = 3
 
+
  UnfreezeItem = Union[int, tuple[int, float]]
  UnfreezeEpochsStrategy: TypeAlias = Union[int, tuple[UnfreezeItem, UnfreezeItem, UnfreezeItem, int]]
 
+
  class CurriculumConfig(TypedDict):
  steps: int
  epochs: int
@@ -52,6 +55,7 @@ class CurriculumConfig(TypedDict):
  critic_lr: Optional[float]
  weight_decay: Optional[float]
  critic_weight_decay: Optional[float]
+ update_epochs: Optional[int]
 
 
  class SamplerConfig(TypedDict):
@@ -66,6 +70,7 @@ class MrlTrajectoryStep(TypedDict):
  log_probs: torch.Tensor
  reward: list[float]
  reference: TokenizedDict
+ done: bool
 
 
  class MrlTrajectoryEpisode(TypedDict):
@@ -107,6 +112,9 @@ class MRLTrainer:
  self.device = device
  self.max_seq_len = config.get('max_seq_len', 256)
  self.critic_max_len = config.get('critic_max_len', 512)
+ # Internal update epochs config
+ self.shared_update_epochs = config.get('update_epochs', 10)
+ self.update_epochs = self.shared_update_epochs
 
  # Move models to device
  if use_amp:
@@ -187,8 +195,8 @@ class MRLTrainer:
  ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
  if memory_lr is not None:
  optimizer = torch.optim.AdamW([
- { 'params': self.actor.not_memory_parameters(), 'lr': lr },
- { 'params': self.actor.memory_parameters(), 'lr': memory_lr },
+ {'params': self.actor.not_memory_parameters(), 'lr': lr},
+ {'params': self.actor.memory_parameters(), 'lr': memory_lr},
  ],
  weight_decay=weight_decay,
  )
@@ -207,11 +215,9 @@ class MRLTrainer:
 
  return optimizer, critic_optimizer
 
-
  def _init_steps(self):
  return {
  'collect': 0,
- 'critic': 0,
  'rl': 0,
  'eval': 0,
  }
@@ -351,7 +357,7 @@ class MRLTrainer:
  # state from existing one, instead of new random one)
  reset_done = self.reset_stm()
 
- # 4. Reset reward prev data running mean - it's calculated for multi-step retention, we have to reset it before episode
+ # 4. Reset reward prev data running mean - it's calculated for multistep retention, we have to reset it before episode
  self.reward.reset_running_mean()
 
  # 5. Get first batch of interactions (data to save) and follow-up interactions for current episode, based on curriculum step
@@ -406,6 +412,7 @@ class MRLTrainer:
  'log_probs': log_probs.detach().cpu(),
  'reward': reward,
  'reference': interaction['answer'],
+ 'done': is_last_interaction,
  }
  episode_steps.append(trajectory)
  episode_rewards.append(reward)
@@ -432,92 +439,23 @@ class MRLTrainer:
 
  return trajectories
 
- def _critic_loss(self, inputs: TokenizedDict, rewards: torch.Tensor) -> torch.Tensor:
+ def _critic_loss(self, inputs: TokenizedDict, ref_values: torch.Tensor) -> torch.Tensor:
  # 1. Calculate values with critic encoder
  values = self.critic(
  inputs['input_ids'],
  attention_mask=inputs['attention_mask'],
  ).squeeze()
  # 2. Calculate critic loss
- loss = self.rl_algorithm.critic_loss(values, rewards)
+ loss = self.rl_algorithm.critic_loss(values, ref_values)
  return loss
 
  def _critic_writer(self, critic_loss: float, epoch: int):
  if self.writer is not None:
- self.writer.add_scalar('Loss/critic (global)', critic_loss, self.global_step['critic'])
+ self.writer.add_scalar('Loss/critic (global)', critic_loss, self.global_step['rl'])
  self.writer.add_scalar(f'Loss/critic (steps: {self.curriculum_steps}, epoch: {epoch})', critic_loss,
- self.epoch_step['critic'])
+ self.epoch_step['rl'])
  self.writer.add_scalar(f'Loss/critic (steps: {self.curriculum_steps})', critic_loss,
- self.stage_step['critic'])
-
- def update_critic(self, states: list[tuple[TokenizedDict, TokenizedDict, TokenizedDict]],
- rewards: list[torch.Tensor], epoch: int):
- """Update critic network using MSE loss."""
- # 1. Run critic updates for all collected batches
- critic_losses = []
- for step_idx, (state, reward) in enumerate(zip(states, rewards)):
- self._increment_steps('critic')
- # 2. Move state batches to training device (GPU)
- prev_query, prev_answer, next_query = self._move_multiple_batches(*state)
-
- # 3. Reset critic gradients
- self.critic_optimizer.zero_grad()
-
- # 4. Run critic and calculate loss - in autocast on/off mode
- if self.use_amp:
- # Move tensors to training device and calculate loss in autocast mode
- batch_rewards = reward.to(self.device)
- with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
- # Concatenate state into single critic input sequence
- inputs = smart_concat_critic_states(
- prev_query, prev_answer, next_query,
- max_length=self.critic_max_len,
- pad_token_id=self.pad_token_id,
- )
- loss = self._critic_loss(inputs, batch_rewards)
- # Run backpropagation with scaler
- self.critic_scaler.scale(loss).backward()
- # Unscale and clip gradients
- self.critic_scaler.unscale_(self.critic_optimizer)
- torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0, error_if_nonfinite=False)
- # Run scaled optimization step
- self.critic_scaler.step(self.critic_optimizer)
- self.critic_scaler.update()
- else:
- # Concatenate state into single critic input sequence
- inputs = smart_concat_critic_states(
- prev_query, prev_answer, next_query,
- max_length=self.critic_max_len,
- pad_token_id=self.pad_token_id,
- )
- # Calculate loss
- loss = self._critic_loss(inputs, reward.to(self.device, dtype=self.dtype))
- # Run backpropagation
- loss.backward()
- # Clip gradients
- torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0, error_if_nonfinite=False)
- # Run optimizer step
- self.critic_optimizer.step()
- critic_loss = loss.item()
- self._critic_writer(critic_loss, epoch)
-
- # 5. Run "on critic updated" callbacks
- for cb in self.callbacks:
- cb.on_critic_updated(self.actor, self.critic, epoch, step_idx, critic_loss)
-
- # 6. Accumulate loss for epoch callbacks
- critic_losses.append(critic_loss)
-
- # 7. Calculate mean loss for epoch callbacks
- critic_mean_loss = torch.tensor(critic_losses).mean().item()
-
- return critic_mean_loss
-
- def _critic_advantages(self, critic_state: TokenizedDict, rewards: torch.Tensor) -> torch.Tensor:
- with torch.no_grad():
- values = self.critic(critic_state['input_ids'],
- attention_mask=critic_state['attention_mask']).squeeze()
- return self.rl_algorithm.calculate_advantages(rewards, values)
+ self.stage_step['rl'])
 
  def _rl_writer(self, policy_loss: float, epoch: int):
  if self.writer is not None:
@@ -526,107 +464,208 @@ class MRLTrainer:
  self.epoch_step['rl'])
  self.writer.add_scalar(f'Loss/policy (steps: {self.curriculum_steps})', policy_loss, self.stage_step['rl'])
 
- def rl_step(self, trajectories: list[MrlTrajectoryEpisode], epoch: int):
+ def _update_critic(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], ref_values: torch.Tensor,
+ epoch: int) -> float:
+ # 1. Reset critic gradients
+ self.critic_optimizer.zero_grad()
+
+ # 2. Update critic - with autocast on/off
+ if self.use_amp:
+ with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
+ # 2.1 Concat states and calculate critic loss
+ critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
+ pad_token_id=self.pad_token_id)
+ critic_loss = self._critic_loss(critic_state, ref_values)
+ # 2.2 Run backpropagation with scaler
+ self.critic_scaler.scale(critic_loss).backward()
+ # 2.3 Unscale and clip gradients
+ self.critic_scaler.unscale_(self.critic_optimizer)
+ torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0, error_if_nonfinite=False)
+ # 2.4 Run scaled optimization step
+ self.critic_scaler.step(self.critic_optimizer)
+ self.critic_scaler.update()
+ else:
+ # 2.1 Concat states and calculate critic loss
+ critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
+ pad_token_id=self.pad_token_id)
+ critic_loss = self._critic_loss(critic_state, ref_values)
+ # 2.2 Run backpropagation
+ critic_loss.backward()
+ # 2.3 Clip gradients
+ torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0, error_if_nonfinite=False)
+ # 2.4 Run optimizer step
+ self.critic_optimizer.step()
+ # 3. Get float loss value for callbacks/writer
+ critic_loss_item = critic_loss.item()
+
+ # 4. Write to TensorBoard
+ self._critic_writer(critic_loss_item, epoch)
+
+ # 5. Run "on critic updated" callbacks
+ for cb in self.callbacks:
+ cb.on_critic_updated(self.actor, self.critic, epoch, self.epoch_step['rl'], critic_loss_item)
+ # 6. Return loss item
+ return critic_loss_item
+
+ def _update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
+ advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
+ # 1. Reset actor gradients
+ self.optimizer.zero_grad()
+ # 2. Unpack state dicts
+ query, answer, next_query = state
+
+ # 3. Encode and update STM on each step, to include encoder and memory attention gradients in loss
+ self.encode_and_update_stm(query, answer)
+
+ # 4. Update actor - with autocast on/off
+ if self.use_amp:
+ with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
+ # 4.1 Concatenate next query and action and get action logits from decoder
+ inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
+ pad_token_id=self.pad_token_id)
+ logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
+ action=MrlActorAction.DECODE)
+ # 4.2 Calculate policy loss with selected algorithm
+ policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs,
+ advantages)
+ # 4.3 Run backpropagation with scaler
+ self.scaler.scale(policy_loss).backward(retain_graph=True)
+ # 4.4 Unscale and clip gradient norms
+ self.scaler.unscale_(self.optimizer)
+ torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
+ error_if_nonfinite=False)
+ # 4.5 Run scaled optimization step
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+ else:
+ # 4.1 Concatenate next query and action and get action logits from decoder
+ inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
+ pad_token_id=self.pad_token_id)
+ logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
+ action=MrlActorAction.DECODE)
+ # 4.2 Calculate policy loss with selected algorithm
+ policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs, advantages)
+ # 4.3 Run backpropagation
+ policy_loss.backward(retain_graph=True)
+ # 4.4 Clip gradient norms
+ torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
+ error_if_nonfinite=False)
+ # 4.5 Run scaled optimization step
+ self.optimizer.step()
+ # 5. Get float loss value for callbacks/writer
+ policy_loss_item = policy_loss.item()
+
+ # 6. Write to TensorBoard
+ self._rl_writer(policy_loss_item, epoch)
+
+ # 7. Run "on batch updated" callback
+ for cb in self.callbacks:
+ cb.on_batch_updated(self.actor, epoch, self.epoch_step['rl'], policy_loss_item)
+
+ # 8. Return loss item
+ return policy_loss_item
+
+ def rl_step(self, trajectories: list[MrlTrajectoryEpisode], advantages: torch.Tensor, ref_values: torch.Tensor,
+ epoch: int, batch_size: int) -> tuple[float, float]:
  """Perform PPO update step using trajectories."""
  # 1. Run update separately for episodes in trajectory - we have to reset memory before each episode, and update
  # memory, based on collected episode data
  all_losses = []
- trajectories_len = len(trajectories)
+ critic_losses = []
  for episode_idx, episode in enumerate(trajectories):
  episode_steps = episode['steps']
  should_reset_stm = episode['reset_stm']
 
- # 2. Reset memory for current batch episode
+ # 2. Get advantages and reference values for current full episode (batch_size * episode_steps)
+ start = episode_idx * episode_steps
+ end = start + episode_steps
+ episode_critic_values = ref_values[start:end]
+ episode_advantages = advantages[start:end]
+
+ # 3. Reset memory for current batch episode
  if should_reset_stm:
  self.reset_stm()
 
- # 3. Run episode steps - each episode has number of steps depending on curriculum stage. Each step is run for all batch
- for step in episode_steps:
+ # 4. Run episode steps - each episode has number of steps depending on curriculum stage. Each step is run for all batch
+ for step_idx, step in enumerate(episode_steps):
  self._increment_steps('rl')
- state, action, reward, log_probs = step['state'], step['action'], step['reward'], step['log_probs']
+ # 5. Get and move to device collected states, action and log probs
+ state, action, _, log_probs = step['state'], step['action'], step['reward'], step['log_probs']
  query, answer, next_query = self._move_multiple_batches(*state)
  action = self._move_batch(action)
  log_probs = log_probs.to(self.device)
- rewards = torch.tensor(reward).to(self.device)
-
- # 4. Compute advantages using critic
- if self.use_amp:
- with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
- critic_state = smart_concat_critic_states(query, answer, next_query,
- max_length=self.critic_max_len,
- pad_token_id=self.pad_token_id)
- advantages = self._critic_advantages(critic_state, rewards)
- else:
- critic_state = smart_concat_critic_states(query, answer, next_query, max_length=self.critic_max_len,
- pad_token_id=self.pad_token_id)
- advantages = self._critic_advantages(critic_state, rewards)
-
- # 5. Encode and update STM on each step, to include encoder and memory attention gradients in loss
- self.encode_and_update_stm(query, answer)
- # 6. Concatenate next query and action and get action logits from decoder
- if self.use_amp:
- with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
- inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
- pad_token_id=self.pad_token_id)
- logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
- action=MrlActorAction.DECODE)
- else:
- inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
- pad_token_id=self.pad_token_id)
- logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
- action=MrlActorAction.DECODE)
-
- # 7. Calculate RL Algorithm (PPO etc.) loss
- policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, log_probs, advantages)
-
- # 8. Reset gradients
- self.optimizer.zero_grad()
-
- # 9. Update the model in AMP or regular mode
- if self.use_amp:
- self.scaler.scale(policy_loss).backward(retain_graph=True)
- self.scaler.unscale_(self.optimizer)
- torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
- error_if_nonfinite=False)
- self.scaler.step(self.optimizer)
- self.scaler.update()
- else:
- policy_loss.backward(retain_graph=True)
- torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
- error_if_nonfinite=False)
- self.optimizer.step()
 
- policy_loss_item = policy_loss.item()
- self._rl_writer(policy_loss_item, epoch)
- all_losses.append(policy_loss_item)
+ # 6. Select advantages and reference values for current step (batch_size)
+ step_critic_values = episode_critic_values[step_idx]
+ step_advantages = episode_advantages[step_idx]
 
- # 10. Run "on batch updated" callback
- for cb in self.callbacks:
- cb.on_batch_updated(self.actor, epoch, self.epoch_step['rl'], policy_loss_item)
+ # 7. Update critic
+ critic_loss_item = self._update_critic((query, answer, next_query), step_critic_values, epoch)
 
- return torch.mean(torch.tensor(all_losses)).item()
+ # 8. Accumulate critic loss for epoch callbacks
+ critic_losses.append(critic_loss_item)
 
- def _critic_states_and_rewards(self, trajectories: list[MrlTrajectoryEpisode]):
+ # 9. Update actor
+ policy_loss_item = self._update_actor((query, answer, next_query), action, step_advantages, log_probs,
+ epoch)
+ all_losses.append(policy_loss_item)
+ # 10. Return mean losses for epoch callbacks
+ return torch.mean(torch.tensor(all_losses)).item(), torch.mean(torch.tensor(critic_losses)).item()
+
+ def _critic_values_rewards_and_dones(self, trajectories: list[MrlTrajectoryEpisode]):
  flat_trajectories = [t for episode in trajectories for t in episode['steps']]
- states = [t['state'] for t in flat_trajectories]
- rewards = [torch.tensor(t['reward']) for t in flat_trajectories]
- return states, rewards
+ values = [
+ self._critic_values(
+ smart_concat_critic_states(
+ *self._move_multiple_batches(*t['state']),
+ max_length=self.critic_max_len,
+ pad_token_id=self.pad_token_id,
+ )
+ ) for t in flat_trajectories
+ ]
+ values = torch.stack(values).to(self.device)
+ rewards = torch.stack([torch.tensor(t['reward']) for t in flat_trajectories]).to(self.device)
+ dones = torch.stack([torch.BoolTensor(t['done']) for t in flat_trajectories]).to(self.device)
+ return values, rewards, dones
+
+ def _critic_values(self, inputs: TokenizedDict) -> torch.Tensor:
+ with torch.no_grad():
+ return self.critic(inputs['input_ids'],
+ attention_mask=inputs['attention_mask']).squeeze()
+
+ # return self.rl_algorithm.calculate_advantages(rewards, values)
 
  def train_epoch(self, dataloader: DataLoader, epoch: int, batch_size: int):
  """Train for one epoch."""
  # 1. Collect trajectories for current epoch
  trajectories = self.collect_trajectories(dataloader, epoch, batch_size)
 
- # 2. Flatten trajectories and collect state and rewards for critic update
- states, rewards = self._critic_states_and_rewards(trajectories)
- # 3. Update critic model, based on states and rewards
- critic_loss = self.update_critic(states, rewards, epoch)
+ # 2. Flatten trajectories, call critic and collect values, dones and rewards, and calculate advantages
+ if self.use_amp:
+ with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
+ values, rewards, dones = self._critic_values_rewards_and_dones(trajectories)
+ advantages, ref_values = self.rl_algorithm.calculate_advantages(rewards, values, dones)
+ else:
+ values, rewards, dones = self._critic_values_rewards_and_dones(trajectories)
+ advantages, ref_values = self.rl_algorithm.calculate_advantages(rewards, values, dones)
 
- # 4. Run PPO algorithm step
- policy_loss = self.rl_step(trajectories, epoch)
+ # 3. Run internal update epochs
+ critic_loss_sum, policy_loss_sum = 0.0, 0.0
+ for update_epoch in range(self.update_epochs):
+ # 4. Run 'on update epoch start' callbacks
+ for cb in self.callbacks:
+ cb.on_update_epoch_start(self.actor, self.critic, epoch, update_epoch)
+ # 5. Run RL algorithm step
+ policy_loss, critic_loss = self.rl_step(trajectories[:-1], advantages, ref_values, epoch, batch_size)
 
- # 5. Return policy and critic mean losses for epoch callbacks
- return policy_loss, critic_loss
+ for cb in self.callbacks:
+ cb.on_update_epoch_end(self.actor, self.critic, epoch, update_epoch, policy_loss, critic_loss)
+
+ critic_loss_sum += critic_loss
+ policy_loss_sum += policy_loss
+
+ # 6. Return policy and critic mean losses for epoch callbacks
+ return policy_loss_sum / self.update_epochs, critic_loss_sum / self.update_epochs
 
  def _eval_loader(self, batch_size: int):
  if self.use_ddp:
@@ -747,55 +786,6 @@ class MRLTrainer:
 
  return should_stop_stage
 
- def _setup_curriculum_step(self, config: CurriculumConfig) -> tuple[tuple[int, UnfreezeEpochsStrategy], tuple[bool, int, float]]:
- # 1. Set common fields based on config
- self.curriculum_steps = config.get('steps', 1) # number of steps to run in episode
- self.train_dataset = config.get('dataset', None) # training dataset for current curriculum stage
- self.eval_dataset = config.get('eval_dataset', None) # evaluation dataset for current curriculum stage
- self.callbacks = config.get('callbacks',
- self.shared_callbacks) # trainer callbacks for current curriculum stage
- self.strategy = config.get('strategy',
- MrlStrategy.MULTI_STEP_STRATEGY) # MRL strategy for given curriculum stage
- self.reward = config.get('reward_model', self.shared_reward_model) # MRL Reward Model for curriculum stage
- if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config['critic_weight_decay'] is not None or (config['separate_memory_lr'] and config['memory_lr'] is not None):
- if config.get('separate_memory_lr', False):
- self.optim_config = {
- 'lr': config.get('lr', self.base_optim_config['lr']),
- 'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
- 'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
- 'critic_weight_decay': config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay']),
- 'memory_lr': config.get('memory_lr', self.base_optim_config['memory_lr']),
- }
- else:
- self.optim_config = {
- 'lr': config.get('lr', self.base_optim_config['lr']),
- 'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
- 'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
- 'critic_weight_decay': config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay']),
- }
- self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
- elif self.optim_config != self.base_optim_config:
- self.optim_config = self.base_optim_config
- self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
-
-
-
-
- # 2. Get epochs and random resets configs
- epochs = config.get('epochs', 5) # number of epochs for current stage
- unfreeze_epoch = config.get('unfreeze_epoch',
- 0) # epoch when components (other than memory) are unfrozen (before epoch starts)
- random_resets = config.get('random_resets',
- False) # flag for using random STM resets (recommended, as model should learn transitions between different states)
- random_resets_from = config.get('random_resets_from', None) # epoch from which random STM resets are started
- random_resets_ratio = config.get('random_resets_ratio',
- None) # ratio of random STM resets - 1.0 is "always reset", 0.0 is "no resets"
-
- # 3. Reset stage step counter
- self.stage_step = self._init_steps()
-
- return (epochs, unfreeze_epoch), (random_resets, random_resets_from, random_resets_ratio)
-
  def _apply_unfreeze_strategy(self, epoch: int, unfreeze_epoch: UnfreezeEpochsStrategy):
  is_staged_unfreeze = isinstance(unfreeze_epoch, tuple)
  if is_staged_unfreeze:
@@ -808,28 +798,31 @@ class MRLTrainer:
  self.optimizer = self._init_unfreeze_optimizer('update', cross_att_lr)
  print(f"Activating 'update' unfreeze strategy with custom cross_att_lr: {cross_att_lr}")
  elif epoch == update_epoch:
- self.actor.freeze_components('update')
- print(f"Activating 'update' unfreeze strategy - mem-att trainable / cross-att frozen / rest model frozen")
+ self.actor.freeze_components('update')
+ print(
+ f"Activating 'update' unfreeze strategy - mem-att trainable / cross-att frozen / rest model frozen")
 
  if isinstance(fetch_epoch, tuple):
  switch_epoch, mem_att_lr = fetch_epoch
- if epoch == fetch_epoch:
+ if epoch == switch_epoch:
  self.actor.freeze_components('joint')
  self.optimizer = self._init_unfreeze_optimizer('fetch', mem_att_lr)
  print(f"Activating 'fetch' unfreeze strategy with custom mem_att_lr: {mem_att_lr}")
  elif epoch == fetch_epoch:
  self.actor.freeze_components('fetch')
- print(f"Activating 'fetch' unfreeze strategy - mem-att frozen / cross-att trainable / rest model frozen")
+ print(
+ f"Activating 'fetch' unfreeze strategy - mem-att frozen / cross-att trainable / rest model frozen")
 
  if isinstance(joint_epoch, tuple):
  switch_epoch, model_lr = joint_epoch
- if epoch == joint_epoch:
+ if epoch == switch_epoch:
  self.actor.unfreeze_components()
  self.optimizer = self._init_unfreeze_optimizer('joint', model_lr)
  print(f"Activating 'joint' unfreeze strategy with custom model_lr: {model_lr}")
  elif epoch == joint_epoch:
- self.actor.freeze_components('joint')
- print(f"Activating 'joint' unfreeze strategy - mem-att/cross-att trainable / rest model frozen")
+ self.actor.freeze_components('joint')
+ print(f"Activating 'joint' unfreeze strategy - mem-att/cross-att trainable / rest model frozen")
+
  if epoch == all_epoch:
  self.actor.unfreeze_components()
  self.optimizer = self._init_unfreeze_optimizer('all', 0.)
@@ -871,6 +864,56 @@ class MRLTrainer:
 
  return torch.optim.AdamW(params, weight_decay=self.optim_config['weight_decay'])
 
+ def _setup_curriculum_step(self, config: CurriculumConfig) -> tuple[
+ tuple[int, UnfreezeEpochsStrategy], tuple[bool, int, float]]:
+ # 1. Set common fields based on config
+ self.curriculum_steps = config.get('steps', 1) # number of steps to run in episode
+ self.train_dataset = config.get('dataset', None) # training dataset for current curriculum stage
+ self.eval_dataset = config.get('eval_dataset', None) # evaluation dataset for current curriculum stage
+ self.callbacks = config.get('callbacks',
+ self.shared_callbacks) # trainer callbacks for current curriculum stage
+ self.strategy = config.get('strategy',
+ MrlStrategy.MULTI_STEP_STRATEGY) # MRL strategy for given curriculum stage
+ self.reward = config.get('reward_model', self.shared_reward_model) # MRL Reward Model for curriculum stage
+ self.update_epochs = config.get('update_epochs', self.shared_update_epochs) # Internal update epochs
+ if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config[
+ 'critic_weight_decay'] is not None or (config['separate_memory_lr'] and config['memory_lr'] is not None):
+ if config.get('separate_memory_lr', False):
+ self.optim_config = {
+ 'lr': config.get('lr', self.base_optim_config['lr']),
+ 'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
+ 'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
+ 'critic_weight_decay': config.get('critic_weight_decay',
+ self.base_optim_config['critic_weight_decay']),
+ 'memory_lr': config.get('memory_lr', self.base_optim_config['memory_lr']),
+ }
+ else:
+ self.optim_config = {
+ 'lr': config.get('lr', self.base_optim_config['lr']),
+ 'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
+ 'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
+ 'critic_weight_decay': config.get('critic_weight_decay',
+ self.base_optim_config['critic_weight_decay']),
+ }
+ self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
+ elif self.optim_config != self.base_optim_config:
+ self.optim_config = self.base_optim_config
+ self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
+
+ # 2. Get epochs and random resets configs
+ epochs = config.get('epochs', 5) # number of epochs for current stage
+ unfreeze_epoch = config.get('unfreeze_epoch',
+ 0) # epoch when components (other than memory) are unfrozen (before epoch starts)
+ random_resets = config.get('random_resets',
+ False) # flag for using random STM resets (recommended, as model should learn transitions between different states)
+ random_resets_from = config.get('random_resets_from', None) # epoch from which random STM resets are started
+ random_resets_ratio = config.get('random_resets_ratio',
+ None) # ratio of random STM resets - 1.0 is "always reset", 0.0 is "no resets"
+
+ # 3. Reset stage step counter
+ self.stage_step = self._init_steps()
+
+ return (epochs, unfreeze_epoch), (random_resets, random_resets_from, random_resets_ratio)
 
  def __call__(self, curriculum_config: list[CurriculumConfig], batch_size: int):
  """Start Memory Reinforcement Learning Curriculum."""
@@ -899,9 +942,11 @@ class MRLTrainer:
  if unfreeze_epoch != 0:
  self.actor.freeze_components('joint')
  if isinstance(unfreeze_epoch, tuple):
- print(f"Starting training with unfreeze strategies - 'warmup' - mem-att/cross-att trainable / rest model frozen")
+ print(
+ f"Starting training with unfreeze strategies - 'warmup' - mem-att/cross-att trainable / rest model frozen")
  else:
- print(f"Starting training with simple unfreeze - 'joint' - mem-att/cross-att trainable / rest model frozen")
+ print(
+ f"Starting training with simple unfreeze - 'joint' - mem-att/cross-att trainable / rest model frozen")
 
  # 5. Setup train DataLoader
  if self.use_ddp:
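The update_epochs setting added in this file controls how many internal PPO update epochs are run over each freshly collected batch of trajectories (train_epoch averages the returned losses over them), and a curriculum stage can override the shared value. A hedged configuration sketch, showing only the keys visible in this diff with illustrative values:

# Only keys that appear in this diff are shown; real MrlConfig/CurriculumConfig dicts need their remaining fields too.
mrl_config = {
    'max_seq_len': 256,
    'critic_max_len': 512,
    'weight_decay': 0.01,          # illustrative value
    'critic_weight_decay': 0.01,   # illustrative value
    'update_epochs': 10,           # shared default used by MRLTrainer when a stage does not override it
}

curriculum_stage = {
    'steps': 4,                    # illustrative value
    'epochs': 5,
    'update_epochs': 6,            # optional per-stage override (CurriculumConfig.update_epochs)
}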
rxnn/training/rl.py CHANGED
@@ -17,7 +17,7 @@ class RlAlgorithm(ABC):
  pass
 
  @abstractmethod
- def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
+ def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  pass
 
  def critic_loss(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
@@ -25,6 +25,9 @@ class RlAlgorithm(ABC):
 
  class PPOConfig(TypedDict):
  clip_eps: float
+ gae_lambda: float
+ gae_gamma: float
+ entropy_coef: float
 
  class PPOAlgorithm(RlAlgorithm):
  def __init__(self, config: PPOConfig):
@@ -32,6 +35,9 @@ class PPOAlgorithm(RlAlgorithm):
 
  # PPO Config
  self.clip_eps = config.get('clip_eps', 0.2)
+ self.gae_lambda = config.get('gae_lambda', 0.95)
+ self.gae_gamma = config.get('gae_gamma', 0.99)
+ self.entropy_coef = config.get('entropy_coef', 0.01)
 
  def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
  old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
@@ -78,11 +84,32 @@ class PPOAlgorithm(RlAlgorithm):
 
  # d) Entropy bonus
  entropy = -torch.sum(new_probs * new_probs.exp(), dim=-1).mean()
- policy_loss -= 0.01 * entropy
+ policy_loss -= self.entropy_coef * entropy
 
  return policy_loss
 
- def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
- advantages = rewards - values
+ def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor, next_value: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ T, B = rewards.shape
+ advantages = torch.zeros_like(rewards, device=values.device)
+ last_advantage = 0
+ last_value = next_value.detach()
+
+ for t in reversed(range(T)):
+ if t == T - 1:
+ next_values = last_value
+ else:
+ next_values = values[t + 1]
+
+ # Mask next values if episode ended
+ next_values = next_values * (1 - dones[t])
+ delta = rewards[t] + self.gae_gamma * next_values - values[t]
+ advantages[t] = delta + self.gae_gamma * self.gae_lambda * last_advantage
+ last_advantage = advantages[t]
+
+ returns = advantages + values
+ return advantages, returns
+
+ def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ advantages, ref_values = self._compute_gae(rewards[:-1], values[:-1], values[-1], dones[:-1])
  normalized_advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
- return normalized_advantages
+ return normalized_advantages, ref_values
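calculate_advantages now returns both normalized advantages and reference values (returns) computed with GAE over the step dimension, instead of the old rewards - values baseline. Below is a small standalone re-implementation of the same recurrence for illustration only (plain PyTorch with dummy tensors, not the library call); gamma and lam mirror the gae_gamma/gae_lambda defaults from PPOConfig:

import torch

def compute_gae(rewards, values, next_value, dones, gamma=0.99, lam=0.95):
    # rewards/values/dones: (T, B) tensors over T steps and B batched episodes.
    T, _ = rewards.shape
    advantages = torch.zeros_like(rewards)
    last_advantage = torch.zeros_like(next_value)
    for t in reversed(range(T)):
        next_values = next_value if t == T - 1 else values[t + 1]
        next_values = next_values * (1 - dones[t].float())   # zero the bootstrap after an episode ends
        delta = rewards[t] + gamma * next_values - values[t]
        advantages[t] = delta + gamma * lam * last_advantage
        last_advantage = advantages[t]
    returns = advantages + values   # used as the critic regression target (ref_values)
    return advantages, returns

# Dummy rollout: 4 steps, batch of 2; the extra value row is the bootstrap for the step after the last one.
T, B = 4, 2
rewards, values, dones = torch.rand(T, B), torch.rand(T + 1, B), torch.zeros(T, B)
dones[-1] = 1.0
adv, ref = compute_gae(rewards, values[:-1], values[-1], dones)
norm_adv = (adv - adv.mean()) / (adv.std() + 1e-8)   # same normalization as calculate_advantages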
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rxnn
- Version: 0.2.29
+ Version: 0.2.30
  Summary: RxNN: Reactive Neural Networks Platform
  License: Apache-2.0
  Keywords: deep-learning,ai,machine-learning
@@ -13,12 +13,12 @@ rxnn/rxt/models.py,sha256=r8wZeeNTC2VAhiiNe4y7LrbnB4wjFu_cupKiGkpdgjI,13002
  rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
  rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
- rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
+ rxnn/training/callbacks.py,sha256=p72lbzFAmFjpcUvyy4aUB3qd53I8C6Sk5w9nQvsKgTk,35852
  rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
  rxnn/training/models.py,sha256=2KhNT7yx0AgUke4nmsFqzQKx_YYp78QvsLWYZjWeUgQ,6812
- rxnn/training/mrl.py,sha256=MnLaYWxblc5cF261R5PNjIvddVQVNxyjAkEYtchBn9E,49299
+ rxnn/training/mrl.py,sha256=DGevQoimkB9qBEkqIw1kkh5DfLqBM-XGvFkraqh-uYk,51545
  rxnn/training/reward.py,sha256=7MTVdNm5HnWmt6zFDi3TAYmnVSL_-24riOoY2F7z4x8,11290
- rxnn/training/rl.py,sha256=j-KNLoZjhaEKasYNOc8DxHtwvknAgAJFwvXKot6otFA,3272
+ rxnn/training/rl.py,sha256=U-mlTK2hF0wZQslzjlvF4S_sMkeTuSqKsCB3IWEsd2A,4558
  rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
  rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
  rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
  rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
  rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
  rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
- rxnn-0.2.29.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
- rxnn-0.2.29.dist-info/METADATA,sha256=WVEyKmyYbMOb5sm7vjjnCN9j8ABz0QfGJCYkQbWvwT8,25960
- rxnn-0.2.29.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
- rxnn-0.2.29.dist-info/RECORD,,
+ rxnn-0.2.30.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+ rxnn-0.2.30.dist-info/METADATA,sha256=zRJ_oHLqUD0QDKJoGRJ6FH5MC-y0k8nOn_inZ_iEP8c,25960
+ rxnn-0.2.30.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+ rxnn-0.2.30.dist-info/RECORD,,