rxnn 0.1.83__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/mrl.py ADDED
@@ -0,0 +1,808 @@
1
+ import torch
2
+ from torch.utils.data import DataLoader, DistributedSampler
3
+ from torch.utils.tensorboard import SummaryWriter
4
+ import torch.distributed as dist
5
+ from torch.nn.parallel import DistributedDataParallel
6
+ from typing import Optional, TypedDict
7
+ from enum import Enum
8
+ import random, os
9
+ from ..transformers.sampler import BatchSampler
10
+ from .callbacks import MrlTrainerCallback
11
+ from .dataset import MrlCurriculumDataset
12
+ from .utils import smart_concat, smart_concat_critic_states, SpecialTokenIds, TokenizedDict
13
+ from .rl import RlAlgorithm
14
+ from .reward import MrlRewardMode, MrlRewardModel
15
+ from .models import MrlActorAction, MrlActorModel, MrlCriticModel
16
+
17
+
18
+ class MrlConfig(TypedDict):
19
+ lr: float
20
+ critic_lr: float
21
+ max_seq_len: int
22
+ critic_max_len: int
23
+ weight_decay: float
24
+ critic_weight_decay: float
25
+
26
+
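
A config for the trainer is a plain dict matching the fields above. A minimal sketch (the values simply mirror the defaults that MRLTrainer.__init__ reads via config.get(...), not tuned hyperparameters):

from rxnn.training.mrl import MrlConfig

example_config: MrlConfig = {
    'lr': 3e-4,                   # actor learning rate (trainer default)
    'critic_lr': 1e-4,            # critic learning rate (trainer default)
    'max_seq_len': 1024,          # max actor sequence length (trainer default)
    'critic_max_len': 2048,       # max critic input length (trainer default)
    'weight_decay': 0.01,
    'critic_weight_decay': 0.01,
}
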
27
+ class MrlStrategy(Enum):
28
+ SINGLE_STEP_STRATEGY = 1
29
+ MULTI_STEP_STRATEGY = 2
30
+ LONG_RANGE_STRATEGY = 3
31
+
32
+
33
+ class CurriculumConfig(TypedDict):
34
+ steps: int
35
+ epochs: int
36
+ dataset: MrlCurriculumDataset
37
+ eval_dataset: Optional[MrlCurriculumDataset]
38
+ callbacks: Optional[list[MrlTrainerCallback]]
39
+ strategy: MrlStrategy
40
+ unfreeze_epoch: Optional[int]
41
+ random_resets: Optional[bool]
42
+ random_resets_from: Optional[int]
43
+ random_resets_ratio: Optional[float]
44
+
45
+
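
A single curriculum stage could look like the sketch below; train_dataset stands in for an MrlCurriculumDataset instance whose construction is outside this file:

from rxnn.training.mrl import CurriculumConfig, MrlStrategy

stage: CurriculumConfig = {
    'steps': 4,                                  # follow-up interactions per episode
    'epochs': 5,
    'dataset': train_dataset,                    # hypothetical MrlCurriculumDataset
    'eval_dataset': None,
    'callbacks': [],                             # per-stage callbacks (trainer-level callbacks apply only when this key is omitted)
    'strategy': MrlStrategy.MULTI_STEP_STRATEGY,
    'unfreeze_epoch': 1,                         # epoch 0 trains only the unfrozen memory components
    'random_resets': True,
    'random_resets_from': 2,                     # start random STM resets from epoch 2
    'random_resets_ratio': 0.5,                  # reset STM for ~50% of episodes after that
}
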
46
+ class SamplerConfig(TypedDict):
47
+ temperature: float
48
+ top_k: Optional[int]
49
+ top_p: Optional[float]
50
+
51
+
52
+ class MrlTrajectoryStep(TypedDict):
53
+ state: tuple[TokenizedDict, TokenizedDict, TokenizedDict]
54
+ action: TokenizedDict
55
+ log_probs: torch.Tensor
56
+ reward: list[float]
57
+ reference: TokenizedDict
58
+
59
+
60
+ class MrlTrajectoryEpisode(TypedDict):
61
+ reset_stm: bool
62
+ steps: list[MrlTrajectoryStep]
63
+
64
+
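
To make the nesting concrete, here is a toy instance of these trajectory types (batch size 2, sequence length 4; the tensor values are placeholders):

import torch
from rxnn.training.mrl import MrlTrajectoryStep, MrlTrajectoryEpisode
from rxnn.training.utils import TokenizedDict

toy_batch: TokenizedDict = {
    'input_ids': torch.zeros(2, 4, dtype=torch.long),
    'attention_mask': torch.ones(2, 4, dtype=torch.long),
}
toy_step: MrlTrajectoryStep = {
    'state': (toy_batch, toy_batch, toy_batch),  # (previous query, previous answer, next query)
    'action': toy_batch,                         # generated answer
    'log_probs': torch.zeros(2, 4),              # log-probabilities of the generated tokens
    'reward': [0.5, 0.7],                        # one reward per batch element
    'reference': toy_batch,                      # reference answer from the dataset
}
toy_episode: MrlTrajectoryEpisode = {'reset_stm': True, 'steps': [toy_step]}
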
65
+ class MRLTrainer:
66
+ def __init__(
67
+ self,
68
+ actor: MrlActorModel,
69
+ critic: MrlCriticModel,
70
+ reward: MrlRewardModel,
71
+ device: torch.device,
72
+ config: MrlConfig,
73
+ rl_algorithm: RlAlgorithm,
74
+ sampler_config: Optional[SamplerConfig] = None,
75
+ log_dir: Optional[str] = None,
76
+ pad_token_id: int = 0,
77
+ start_token_id: int = 2,
78
+ end_token_id: int = 3,
79
+ use_ddp: bool = False,
80
+ use_amp: bool = False,
81
+ dtype: torch.dtype = torch.float32,
82
+ callbacks: Optional[list[MrlTrainerCallback]] = None,
83
+ ):
84
+ """
85
+ Trainer for Memory Reinforcement Learning (MRL) in Reactive Transformer.
86
+
87
+ Args:
88
+ actor: MRL Actor model with encoder, decoder and memory attention.
89
+ critic: Critic network for advantage estimation.
90
+ config: Configuration dictionary with hyperparameters.
91
+ """
92
+ self.actor = actor
93
+ self.critic = critic
94
+ self.reward = reward
95
+ self.device = device
96
+ self.max_seq_len = config.get('max_seq_len', 1024)
97
+ self.critic_max_len = config.get('critic_max_len', 2048)
98
+
99
+ # Move models to device
100
+ if use_amp:
101
+ self.actor.to(self.device)
102
+ self.critic.to(self.device)
103
+ else:
104
+ self.actor.to(self.device, dtype=dtype)
105
+ self.critic.to(self.device, dtype=dtype)
106
+
107
+ # Batch Sampler for answer generation
108
+ self.generator = BatchSampler(self.actor, self.device, end_token_id=end_token_id)
109
+ self.sampler_config = SamplerConfig(
110
+ temperature=1.0,
111
+ top_k=None,
112
+ top_p=None,
113
+ ) if sampler_config is None else sampler_config
114
+
115
+ self.special_token_ids: SpecialTokenIds = {
116
+ 'pad': pad_token_id,
117
+ 'bos': start_token_id,
118
+ 'eos': end_token_id,
119
+ }
120
+
121
+ self.use_ddp = use_ddp
122
+ self.use_amp = use_amp
123
+ self.dtype = dtype
124
+
125
+ # Optimizers
126
+ self.optimizer = torch.optim.AdamW(
127
+ self.actor.unique_parameters(),
128
+ lr=config.get("lr", 3e-4),
129
+ weight_decay=config.get("weight_decay", 0.01),
130
+ )
131
+ self.critic_optimizer = torch.optim.AdamW(
132
+ self.critic.parameters(),
133
+ lr=config.get("critic_lr", 1e-4),
134
+ weight_decay=config.get("critic_weight_decay", 0.01),
135
+ )
136
+
137
+ self.scaler = torch.amp.GradScaler() if self.use_amp else None
138
+ self.critic_scaler = torch.amp.GradScaler() if self.use_amp else None
139
+
140
+ # TensorBoard Writer
141
+ if log_dir and not os.path.exists(log_dir):
142
+ os.makedirs(log_dir)
143
+ self.writer = SummaryWriter(log_dir) if log_dir else None
144
+
145
+ self.global_step = self._init_steps()
146
+ self.epoch_step = self._init_steps()
147
+ self.stage_step = self._init_steps()
148
+
149
+ self.rl_algorithm = rl_algorithm
150
+
151
+ # Dynamic fields, updated for each curriculum step
152
+ self.curriculum_steps = 0
153
+ self.train_dataset = None
154
+ self.eval_dataset = None
155
+ self.random_resets_ratio = 0.0
156
+ self.strategy = None
157
+ self.shared_callbacks = callbacks if callbacks else []
158
+ self.callbacks = []
159
+ self.global_epoch = 0
160
+ self.global_epochs_count = 0
161
+
162
+ def _init_steps(self):
163
+ return {
164
+ 'collect': 0,
165
+ 'critic': 0,
166
+ 'rl': 0,
167
+ 'eval': 0,
168
+ }
169
+
170
+ def _increment_steps(self, step_type: str):
171
+ self.global_step[step_type] += 1
172
+ self.epoch_step[step_type] += 1
173
+ self.stage_step[step_type] += 1
174
+
175
+ def reset_stm(self) -> bool:
176
+ """Reset Short-Term Memory state with random reset ratio."""
177
+ if self.random_resets_ratio == 1.0:
178
+ self.actor.reset_memory()
179
+ return True
180
+ else:
181
+ rng = random.random()
182
+ if rng <= self.random_resets_ratio:
183
+ self.actor.reset_memory()
184
+ return True
185
+ else:
186
+ return False
187
+
188
+ def encode_and_update_stm(self, query: TokenizedDict, answer: TokenizedDict):
189
+ """Encode interaction and update STM."""
190
+ # 1. Encode data and update memory - with autocast on/off
191
+ if self.use_amp:
192
+ with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
193
+ # 2. Concatenate batch of queries and answers (they are already on training device)
194
+ inputs = smart_concat(query, answer, self.max_seq_len, self.special_token_ids['pad'])
195
+ # 3. Encode data and update STM
196
+ self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'], action=MrlActorAction.UPDATE)
197
+ else:
198
+ # 2. Concatenate batch of queries and answers (they are already on training device)
199
+ inputs = smart_concat(query, answer, self.max_seq_len, self.special_token_ids['pad'])
200
+ # 3. Encode data and update STM
201
+ self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'], action=MrlActorAction.UPDATE)
202
+
203
+ def generate_answer(self, query: TokenizedDict) -> tuple[TokenizedDict, torch.Tensor]:
204
+ """Generate response using batch sampler with decoder."""
205
+ # 1. Generate answer with BatchSampler - with autocast on/off
206
+ if self.use_amp:
207
+ with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
208
+ input_ids, attention_mask, log_probs = self.generator(
209
+ query['input_ids'],
210
+ query['attention_mask'],
211
+ max_gen_len=self.max_seq_len,
212
+ **self.sampler_config,
213
+ )
214
+ else:
215
+ input_ids, attention_mask, log_probs = self.generator(
216
+ query['input_ids'],
217
+ query['attention_mask'],
218
+ max_gen_len=self.max_seq_len,
219
+ **self.sampler_config,
220
+ )
221
+ # 2. Convert generated answer to TokenizedDict
222
+ generated_answer: TokenizedDict = {
223
+ 'input_ids': input_ids,
224
+ 'attention_mask': attention_mask,
225
+ }
226
+
227
+ return generated_answer, log_probs
228
+
229
+ def compute_reward(self, generated: TokenizedDict, reference: TokenizedDict,
230
+ saved_data: tuple[TokenizedDict, TokenizedDict], mode: MrlRewardMode = MrlRewardMode.STANDARD,
231
+ eval_mode: bool = False) -> list[float]:
232
+ """Compute reward based on memory retention (e.g., BLEU-4)."""
233
+ saved_query, saved_answer = saved_data
234
+ # 1. Concat saved (previous) interaction and calculate reward using generated sequence, reference and saved data - with autocast on/off
235
+ if self.use_amp:
236
+ with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
237
+ saved_interaction = smart_concat(saved_query, saved_answer, max_length=self.max_seq_len,
238
+ pad_token_id=self.special_token_ids['pad'])
239
+ reward = self.reward(generated, reference, saved_interaction, mode=mode)
240
+ else:
241
+ saved_interaction = smart_concat(saved_query, saved_answer, max_length=self.max_seq_len,
242
+ pad_token_id=self.special_token_ids['pad'])
243
+ reward = self.reward(generated, reference, saved_interaction, mode=mode)
244
+
245
+ # 2. Run 'on reward' callbacks
246
+ for cb in self.callbacks:
247
+ cb.on_reward(self.actor, reward, generated, reference, saved_interaction, eval_mode)
248
+ # 3. Return rewards for batch
249
+ return reward
250
+
251
+ def _move_batch(self, batch: TokenizedDict) -> TokenizedDict:
252
+ if self.use_amp:
253
+ return {
254
+ 'input_ids': batch['input_ids'].to(self.device),
255
+ 'attention_mask': batch['attention_mask'].to(self.device),
256
+ }
257
+ else:
258
+ return {
259
+ 'input_ids': batch['input_ids'].to(self.device),  # token ids stay integer for embedding lookup
260
+ 'attention_mask': batch['attention_mask'].to(self.device, dtype=self.dtype),
261
+ }
262
+
263
+ def _move_multiple_batches(self, *batches: TokenizedDict) -> list[TokenizedDict]:
264
+ return [self._move_batch(batch) for batch in batches]
265
+
266
+ def _cpu_detach(self, batch: TokenizedDict) -> TokenizedDict:
267
+ return {
268
+ 'input_ids': batch['input_ids'].detach().cpu(),
269
+ 'attention_mask': batch['attention_mask'].detach().cpu(),
270
+ }
271
+
272
+ def _cpu_detach_multiple(self, *batches: TokenizedDict) -> list[TokenizedDict]:
273
+ return [self._cpu_detach(batch) for batch in batches]
274
+
275
+ def _collect_writer(self, avg_reward: float, epoch: int):
276
+ if self.writer is not None:
277
+ self.writer.add_scalar('Collect/episode reward (global)', avg_reward, self.global_step['collect'])
278
+ self.writer.add_scalar(f'Collect/episode reward (steps: {self.curriculum_steps}, epoch: {epoch})',
279
+ avg_reward, self.epoch_step['collect'])
280
+ self.writer.add_scalar(f'Collect/episode reward (steps: {self.curriculum_steps})', avg_reward,
281
+ self.stage_step['collect'])
282
+
283
+ def collect_trajectories(self, dataloader: DataLoader, epoch: int) -> list[MrlTrajectoryEpisode]:
284
+ """Collect trajectories for PPO for current curriculum step."""
285
+ # 1. Init trajectories list
286
+ trajectories = []
287
+
288
+ with torch.no_grad():
289
+ # 2. Collect episode trajectories for all batches in dataset
290
+ for batch_idx, batch in enumerate(dataloader):
291
+ self._increment_steps('collect')
292
+ # 3. Reset Short-Term Memory state (with random reset ratio - sometimes it is better to keep building on
293
+ # the existing memory state instead of starting from a fresh random one)
294
+ reset_done = self.reset_stm()
295
+
296
+ # 4. Get first batch of interactions (data to save) and follow-up interactions for current episode, based on curriculum step
297
+ first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
298
+ interactions = interactions[:self.curriculum_steps]
299
+ interactions_len = len(interactions)
300
+ # 5. Encode and update STM with data to save from first interaction
301
+ self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
302
+
303
+ # 6. Save first interaction as data to save (for trajectory state)
304
+ query, answer = first_query, first_answer
305
+
306
+ # 7. Run training strategy for follow-up interactions
307
+ episode_steps = []
308
+ episode_rewards = []
309
+ for i, interaction in enumerate(interactions):
310
+ # 8. Generate batch of answers based on batch of follow-up queries
311
+ next_query = self._move_batch(interaction['query'])
312
+ generated_answer, log_probs = self.generate_answer(next_query)
313
+
314
+ is_last_interaction = (i + 1) == interactions_len
315
+
316
+ detached_answer = self._cpu_detach(generated_answer) # detach and keep states on CPU
317
+
318
+ # 9. Depending on strategy compute reward
319
+ if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
320
+ # a) long-range - first interaction - change topic - negative reward (it shouldn't include saved data)
321
+ reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
322
+ mode=MrlRewardMode.NEGATIVE)
323
+ elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
324
+ # b) long-range - last interaction - first interaction topic - long-range reward (it should include content from first interaction)
325
+ reward = self.compute_reward(detached_answer, interaction['answer'],
326
+ (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE)
327
+ else:
328
+ # c) standard reward - generated answer should include some content from previous interaction (saved data), like reference answer
329
+ reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
330
+ mode=MrlRewardMode.STANDARD)
331
+
332
+ # 10. Update STM with generated response (except last interaction, it's not needed)
333
+ if not is_last_interaction:
334
+ self.encode_and_update_stm(next_query, generated_answer) # update with generated_answer on GPU
335
+
336
+ # 11. Store trajectory step
337
+ trajectory: MrlTrajectoryStep = {
338
+ 'state': (query, answer, interaction['query']),
339
+ 'action': detached_answer,
340
+ 'log_probs': log_probs.detach().cpu(),
341
+ 'reward': reward,
342
+ 'reference': interaction['answer'],
343
+ }
344
+ episode_steps.append(trajectory)
345
+ episode_rewards.append(reward)
346
+
347
+ # 12. Set current interaction query and generated answer (batches), as saved data for next interaction
348
+ query, answer = interaction['query'], detached_answer
349
+
350
+ # 13. Append full batched episode (number of steps depends on curriculum stage) to trajectories
351
+ episode_trajectory: MrlTrajectoryEpisode = {
352
+ 'reset_stm': reset_done,
353
+ 'steps': episode_steps,
354
+ }
355
+ trajectories.append(episode_trajectory)
356
+
357
+ mean_episode_reward = torch.tensor(episode_rewards).mean().item()
358
+
359
+ self._collect_writer(mean_episode_reward, epoch)
360
+
361
+ # 14. Run "on episode collected" callbacks
362
+ for cb in self.callbacks:
363
+ cb.on_episode_collected(self.actor, batch_idx, episode_trajectory, mean_episode_reward)
364
+
365
+ return trajectories
366
+
367
+ def _critic_loss(self, inputs: TokenizedDict, rewards: torch.Tensor) -> torch.Tensor:
368
+ # 1. Calculate values with critic encoder
369
+ values = self.critic(
370
+ input_ids=inputs['input_ids'],
371
+ attention_mask=inputs['attention_mask'],
372
+ ).squeeze()
373
+ # 2. Calculate critic loss
374
+ loss = self.rl_algorithm.critic_loss(values, rewards)
375
+ return loss
376
+
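
The loss itself is delegated to rl_algorithm.critic_loss. As a rough mental model (the update_critic docstring below names MSE, though the library's exact formulation may differ), value regression reduces to a mean-squared error between predicted values and observed rewards:

import torch
import torch.nn.functional as F

values = torch.tensor([0.2, 0.6, 0.4])   # critic value estimates for a batch
rewards = torch.tensor([0.0, 1.0, 0.5])  # observed rewards for the same batch
mse_critic_loss = F.mse_loss(values, rewards)
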
377
+ def _critic_writer(self, critic_loss: float, epoch: int):
378
+ if self.writer is not None:
379
+ self.writer.add_scalar('Loss/critic (global)', critic_loss, self.global_step['critic'])
380
+ self.writer.add_scalar(f'Loss/critic (steps: {self.curriculum_steps}, epoch: {epoch})', critic_loss,
381
+ self.epoch_step['critic'])
382
+ self.writer.add_scalar(f'Loss/critic (steps: {self.curriculum_steps})', critic_loss,
383
+ self.stage_step['critic'])
384
+
385
+ def update_critic(self, states: list[tuple[TokenizedDict, TokenizedDict, TokenizedDict]],
386
+ rewards: list[torch.Tensor], epoch: int):
387
+ """Update critic network using MSE loss."""
388
+ # 1. Run critic updates for all collected batches
389
+ critic_losses = []
390
+ for step_idx, (state, reward) in enumerate(zip(states, rewards)):
391
+ self._increment_steps('critic')
392
+ # 2. Move state batches to training device (GPU)
393
+ prev_query, prev_answer, next_query = self._move_multiple_batches(*state)
394
+
395
+ # 3. Reset critic gradients
396
+ self.critic_optimizer.zero_grad()
397
+
398
+ # 4. Run critic and calculate loss - in autocast on/off mode
399
+ if self.use_amp:
400
+ # Move tensors to training device and calculate loss in autocast mode
401
+ batch_rewards = reward.to(self.device)
402
+ with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
403
+ # Concatenate state into single critic input sequence
404
+ inputs = smart_concat_critic_states(
405
+ prev_query, prev_answer, next_query,
406
+ max_length=self.critic_max_len,
407
+ pad_token_id=self.special_token_ids['pad'],
408
+ )
409
+ loss = self._critic_loss(inputs, batch_rewards)
410
+ # Run backpropagation with scaler
411
+ self.critic_scaler.scale(loss).backward()
412
+ # Unscale and clip gradients
413
+ self.critic_scaler.unscale_(self.critic_optimizer)
414
+ torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0, error_if_nonfinite=False)
415
+ # Run scaled optimization step
416
+ self.critic_scaler.step(self.critic_optimizer)
417
+ self.critic_scaler.update()
418
+ else:
419
+ # Concatenate state into single critic input sequence
420
+ inputs = smart_concat_critic_states(
421
+ prev_query, prev_answer, next_query,
422
+ max_length=self.critic_max_len,
423
+ pad_token_id=self.special_token_ids['pad'],
424
+ )
425
+ # Calculate loss
426
+ loss = self._critic_loss(inputs, reward.to(self.device, dtype=self.dtype))
427
+ # Run backpropagation
428
+ loss.backward()
429
+ # Clip gradients
430
+ torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0, error_if_nonfinite=False)
431
+ # Run optimizer step
432
+ self.critic_optimizer.step()
433
+ critic_loss = loss.item()
434
+ self._critic_writer(critic_loss, epoch)
435
+
436
+ # 5. Run "on critic updated" callbacks
437
+ for cb in self.callbacks:
438
+ cb.on_critic_updated(self.actor, self.critic, epoch, step_idx, critic_loss)
439
+
440
+ # 6. Accumulate loss for epoch callbacks
441
+ critic_losses.append(critic_loss)
442
+
443
+ # 7. Calculate mean loss for epoch callbacks
444
+ critic_mean_loss = torch.tensor(critic_losses).mean().item()
445
+
446
+ return critic_mean_loss
447
+
448
+ def _critic_advantages(self, critic_state: TokenizedDict, rewards: torch.Tensor) -> torch.Tensor:
449
+ with torch.no_grad():
450
+ values = self.critic(critic_state['input_ids'],
451
+ attention_mask=critic_state['attention_mask']).squeeze()
452
+ return self.rl_algorithm.calculate_advantages(rewards, values)
453
+
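
calculate_advantages is also algorithm-specific; the simplest variant (an illustrative assumption - PPO setups often use GAE or normalization instead) is plain baseline subtraction:

import torch

rewards = torch.tensor([1.0, 0.3, 0.8])  # per-sample rewards
values = torch.tensor([0.6, 0.4, 0.7])   # critic estimates for the same states
advantages = rewards - values            # positive => the action did better than the critic expected
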
454
+ def _rl_writer(self, policy_loss: float, epoch: int):
455
+ if self.writer is not None:
456
+ self.writer.add_scalar('Loss/policy (global)', policy_loss, self.global_step['rl'])
457
+ self.writer.add_scalar(f'Loss/policy (steps: {self.curriculum_steps}, epoch: {epoch})', policy_loss,
458
+ self.epoch_step['rl'])
459
+ self.writer.add_scalar(f'Loss/policy (steps: {self.curriculum_steps})', policy_loss, self.stage_step['rl'])
460
+
461
+ def rl_step(self, trajectories: list[MrlTrajectoryEpisode], epoch: int):
462
+ """Perform PPO update step using trajectories."""
463
+ # 1. Run update separately for episodes in trajectory - we have to reset memory before each episode, and update
464
+ # memory, based on collected episode data
465
+ all_losses = []
466
+ trajectories_len = len(trajectories)
467
+ for episode_idx, episode in enumerate(trajectories):
468
+ episode_steps = episode['steps']
469
+ should_reset_stm = episode['reset_stm']
470
+
471
+ # 2. Reset memory for current batch episode
472
+ if should_reset_stm:
473
+ self.reset_stm()
474
+
475
+ # 3. Run episode steps - each episode has a number of steps depending on the curriculum stage. Each step is run for the whole batch
476
+ for step in episode_steps:
477
+ self._increment_steps('rl')
478
+ state, action, reward, log_probs = step['state'], step['action'], step['reward'], step['log_probs']
479
+ query, answer, next_query = self._move_multiple_batches(*state)
480
+ action = self._move_batch(action)
481
+ log_probs = log_probs.to(self.device)
482
+ rewards = torch.tensor(reward).to(self.device)
483
+
484
+ # 4. Compute advantages using critic
485
+ if self.use_amp:
486
+ with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
487
+ critic_state = smart_concat_critic_states(query, answer, next_query,
488
+ max_length=self.critic_max_len,
489
+ pad_token_id=self.special_token_ids['pad'])
490
+ advantages = self._critic_advantages(critic_state, rewards)
491
+ else:
492
+ critic_state = smart_concat_critic_states(query, answer, next_query, max_length=self.critic_max_len,
493
+ pad_token_id=self.special_token_ids['pad'])
494
+ advantages = self._critic_advantages(critic_state, rewards)
495
+
496
+ # 5. Encode and update STM on each step, to include encoder and memory attention gradients in loss
497
+ self.encode_and_update_stm(query, answer)
498
+ # 6. Concatenate next query and action and get action logits from decoder
499
+ if self.use_amp:
500
+ with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
501
+ inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
502
+ pad_token_id=self.special_token_ids['pad'])
503
+ logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
504
+ action=MrlActorAction.DECODE)
505
+ else:
506
+ inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
507
+ pad_token_id=self.special_token_ids['pad'])
508
+ logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
509
+ action=MrlActorAction.DECODE)
510
+
511
+ # 7. Calculate RL Algorithm (PPO etc.) loss
512
+ policy_loss = self.rl_algorithm.policy_loss(action['input_ids'], logits, log_probs, advantages)
513
+
514
+ # 8. Reset gradients
515
+ self.optimizer.zero_grad()
516
+
517
+ # 9. Update the model in AMP or regular mode
518
+ if self.use_amp:
519
+ self.scaler.scale(policy_loss).backward()
520
+ self.scaler.unscale_(self.optimizer)
521
+ torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
522
+ error_if_nonfinite=False)
523
+ self.scaler.step(self.optimizer)
524
+ self.scaler.update()
525
+ else:
526
+ policy_loss.backward()
527
+ torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
528
+ error_if_nonfinite=False)
529
+ self.optimizer.step()
530
+
531
+ policy_loss_item = policy_loss.item()
532
+ self._rl_writer(policy_loss_item, epoch)
533
+ all_losses.append(policy_loss_item)
534
+
535
+ # 10. Run "on batch updated" callback
536
+ for cb in self.callbacks:
537
+ cb.on_batch_updated(self.actor, epoch, self.epoch_step['rl'], policy_loss_item)
538
+
539
+ return torch.mean(torch.tensor(all_losses)).item()
540
+
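
policy_loss is computed by the injected RlAlgorithm. When that algorithm is PPO, the objective typically takes the clipped-ratio form sketched below (a generic PPO sketch, not necessarily this library's implementation; gathering per-token log-probabilities from the decoder logits and broadcasting advantages to their shape is assumed to happen elsewhere):

import torch

def ppo_clip_loss(new_log_probs: torch.Tensor, old_log_probs: torch.Tensor,
                  advantages: torch.Tensor, clip_eps: float = 0.2) -> torch.Tensor:
    # Probability ratio between the updated policy and the policy used during collection
    ratio = torch.exp(new_log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # PPO maximizes the clipped surrogate, so the loss is its negated minimum
    return -torch.min(unclipped, clipped).mean()
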
541
+ def _critic_states_and_rewards(self, trajectories: list[MrlTrajectoryEpisode]):
542
+ flat_trajectories = [t for episode in trajectories for t in episode['steps']]
543
+ states = [t['state'] for t in flat_trajectories]
544
+ rewards = [torch.tensor(t['reward']) for t in flat_trajectories]
545
+ return states, rewards
546
+
547
+ def train_epoch(self, dataloader: DataLoader, epoch: int):
548
+ """Train for one epoch."""
549
+ # 1. Collect trajectories for current epoch
550
+ trajectories = self.collect_trajectories(dataloader, epoch)
551
+
552
+ # 2. Flatten trajectories and collect state and rewards for critic update
553
+ states, rewards = self._critic_states_and_rewards(trajectories)
554
+ # 3. Update critic model, based on states and rewards
555
+ critic_loss = self.update_critic(states, rewards, epoch)
556
+
557
+ # 4. Run PPO algorithm step
558
+ policy_loss = self.rl_step(trajectories, epoch)
559
+
560
+ # 5. Return policy and critic mean losses for epoch callbacks
561
+ return policy_loss, critic_loss
562
+
563
+ def _eval_loader(self, batch_size: int):
564
+ if self.use_ddp:
565
+ return DataLoader(
566
+ self.eval_dataset,
567
+ batch_size=batch_size,
568
+ pin_memory=True,
569
+ sampler=DistributedSampler(self.eval_dataset, shuffle=False),
570
+ collate_fn=MrlCurriculumDataset.collate_mrl_batch,
571
+ )
572
+ else:
573
+ return DataLoader(
574
+ self.eval_dataset,
575
+ batch_size=batch_size,
576
+ shuffle=False,
577
+ pin_memory=True,
578
+ collate_fn=MrlCurriculumDataset.collate_mrl_batch,
579
+ )
580
+
581
+ def _eval_writer(self, avg_reward: float, epoch: int):
582
+ if self.writer is not None:
583
+ self.writer.add_scalar('Eval/episode reward (global)', avg_reward, self.global_step['eval'])
584
+ self.writer.add_scalar(f'Eval/episode reward (steps: {self.curriculum_steps}, epoch: {epoch})', avg_reward,
585
+ self.epoch_step['eval'])
586
+ self.writer.add_scalar(f'Eval/episode reward (steps: {self.curriculum_steps})', avg_reward,
587
+ self.stage_step['eval'])
588
+
589
+ def evaluate(self, batch_size: int, epoch: int):
590
+ """Evaluate model on validation dataset."""
591
+ # 1. Init evaluation DataLoader
592
+ dataloader = self._eval_loader(batch_size)
593
+ total_reward = torch.tensor(0.0).to(self.device)
594
+ count = torch.tensor(0).to(self.device)
595
+
596
+ # 2. Run evaluation on all batch episodes
597
+ for batch in dataloader:
598
+ with torch.no_grad():
599
+ self._increment_steps('eval')
600
+ # 3. Reset STM with random resets ratio
601
+ self.reset_stm()
602
+
603
+ # 4. Get batches for first queries, answers and all follow-up interactions
604
+ first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
605
+ # 5. Encode and update STM with initial interactions (batch)
606
+ self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
607
+
608
+ # 6. Store the number of follow-up interactions and use the first query/answer as the previous interaction for the loop
609
+ interactions_len = len(interactions)
610
+ query, answer = first_query, first_answer
611
+ episode_reward = torch.tensor(0.0).to(self.device)
612
+ episode_interactions = torch.tensor(0).to(self.device)
613
+ # 7. Run all follow-up interactions
614
+ for i, interaction in enumerate(interactions):
615
+ # 8. Generate batch of answers
616
+ next_query = self._move_batch(interaction['query'])
617
+ generated_answer, _ = self.generate_answer(next_query)
618
+
619
+ is_last_interaction = (i + 1) == interactions_len
620
+
621
+ detached_answer = self._cpu_detach(generated_answer)
622
+
623
+ # 9. Depending on current strategy and step, compute reward
624
+ if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
625
+ reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
626
+ mode=MrlRewardMode.NEGATIVE, eval_mode=True)
627
+ elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
628
+ reward = self.compute_reward(detached_answer, interaction['answer'],
629
+ (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE,
630
+ eval_mode=True)
631
+ else:
632
+ reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
633
+ mode=MrlRewardMode.STANDARD, eval_mode=True)
634
+
635
+ # 10. Encode and update memory for the next interaction
636
+ if not is_last_interaction:
637
+ self.encode_and_update_stm(next_query, generated_answer)
638
+
639
+ # 11. Accumulate rewards
640
+ step_reward = torch.tensor(reward).mean().to(self.device)
641
+ # total
642
+ total_reward += step_reward
643
+ count += 1
644
+ # episode
645
+ episode_reward += step_reward
646
+ episode_interactions += 1
647
+ # 12. Save previous interaction
648
+ query, answer = interaction['query'], detached_answer
649
+ avg_episode_reward = (episode_reward / episode_interactions).item()
650
+ # 13. Run eval TensorBoard writer with average episode reward
651
+ self._eval_writer(avg_episode_reward, epoch)
652
+
653
+ # 14. Run "on eval episode end" callbacks
654
+ for cb in self.callbacks:
655
+ cb.on_eval_episode_end(self.actor, epoch, self.epoch_step['eval'], avg_episode_reward)
656
+
657
+ # 15. Calculate average reward
658
+ if self.use_ddp:
659
+ dist.all_reduce(total_reward, dist.ReduceOp.SUM)  # all_reduce sums in place and returns None
660
+ dist.all_reduce(count, dist.ReduceOp.SUM)
661
+ avg_reward = (total_reward / count).item() if count > 0 else 0
662
+ else:
663
+ avg_reward = (total_reward / count).item() if count > 0 else 0
664
+
665
+ should_stop_stage = False
666
+ # 16. Run "on eval end" callbacks
667
+ for cb in self.callbacks:
668
+ should_stop = cb.on_eval_end(self.actor, self.critic, epoch, avg_reward)
669
+ if should_stop:
670
+ should_stop_stage = True
671
+
672
+ return should_stop_stage
673
+
674
+ def _setup_curriculum_step(self, config: CurriculumConfig) -> tuple[tuple[int, int], tuple[bool, int, float]]:
675
+ # 1. Set common fields based on config
676
+ self.curriculum_steps = config.get('steps', 1) # number of steps to run in episode
677
+ self.train_dataset = config.get('dataset', None) # training dataset for current curriculum stage
678
+ self.eval_dataset = config.get('eval_dataset', None) # evaluation dataset for current curriculum stage
679
+ self.callbacks = config.get('callbacks',
680
+ self.shared_callbacks) # trainer callbacks for current curriculum stage
681
+ self.strategy = config.get('strategy',
682
+ MrlStrategy.MULTI_STEP_STRATEGY) # MRL strategy for given curriculum stage
683
+
684
+ # 2. Get epochs and random resets configs
685
+ epochs = config.get('epochs', 5) # number of epochs for current stage
686
+ unfreeze_epoch = config.get('unfreeze_epoch',
687
+ 0) # epoch when components (other than memory) are unfrozen (before epoch starts)
688
+ random_resets = config.get('random_resets',
689
+ False) # flag for using random STM resets (recommended, as model should learn transitions between different states)
690
+ random_resets_from = config.get('random_resets_from', None) # epoch from which random STM resets are started
691
+ random_resets_ratio = config.get('random_resets_ratio',
692
+ None) # ratio of random STM resets - 1.0 is "always reset", 0.0 is "no resets"
693
+
694
+ # 3. Reset stage step counter
695
+ self.stage_step = self._init_steps()
696
+
697
+ return (epochs, unfreeze_epoch), (random_resets, random_resets_from, random_resets_ratio)
698
+
699
+ def __call__(self, curriculum_config: list[CurriculumConfig], batch_size: int):
700
+ """Start Memory Reinforcement Learning Curriculum."""
701
+
702
+ # 0. Set global epoch count for all stages
703
+ self.global_epochs_count = sum(stage['epochs'] for stage in curriculum_config)
704
+
705
+ # 1. Init DDP for distributed training mode
706
+ if self.use_ddp:
707
+ rank = int(os.environ['RANK'])
708
+ world_size = int(os.environ['WORLD_SIZE'])
709
+ dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
710
+ self.actor = DistributedDataParallel(self.actor, device_ids=[self.device.index])
711
+ self.critic = DistributedDataParallel(self.critic, device_ids=[self.device.index])
712
+
713
+ # 2. Run each curriculum step based on config
714
+ for current_curriculum_step in curriculum_config:
715
+ # 3. Setup training config for curriculum step
716
+ epochs_config, random_resets_config = self._setup_curriculum_step(current_curriculum_step)
717
+ epochs, unfreeze_epoch = epochs_config
718
+ random_resets, random_resets_from, random_resets_ratio = random_resets_config
719
+ assert self.train_dataset is not None
720
+ print(f'Curriculum Steps Increased to {self.curriculum_steps}')
721
+
722
+ # 4. Freeze all components except memory attention and memory cross-attention layers in decoder/encoder
723
+ if unfreeze_epoch != 0:
724
+ self.actor.freeze_components()
725
+
726
+ # 5. Setup train DataLoader
727
+ if self.use_ddp:
728
+ train_sampler = DistributedSampler(self.train_dataset, shuffle=True)
729
+ dataloader = DataLoader(
730
+ self.train_dataset,
731
+ batch_size=batch_size,
732
+ sampler=train_sampler,
733
+ pin_memory=True,
734
+ collate_fn=MrlCurriculumDataset.collate_mrl_batch,
735
+ )
736
+ else:
737
+ train_sampler = None
738
+ dataloader = DataLoader(
739
+ self.train_dataset,
740
+ batch_size=batch_size,
741
+ shuffle=True,
742
+ pin_memory=True,
743
+ collate_fn=MrlCurriculumDataset.collate_mrl_batch,
744
+ )
745
+
746
+ # 6. Run selected number of epochs for given curriculum stage
747
+ for epoch in range(epochs):
748
+ # 7. Increment global epoch
749
+ self.global_epoch += 1
750
+ # 8. Run "on epoch start" callbacks (log info, etc.)
751
+ for cb in self.callbacks:
752
+ cb.on_epoch_start(self.actor, epoch, epochs, current_curriculum_step, self.global_epoch,
753
+ self.global_epochs_count)
754
+
755
+ # 9. Reset steps counter for epoch
756
+ self.epoch_step = self._init_steps()
757
+
758
+ # 10. Set random STM resets ratio from selected epoch
759
+ if random_resets and random_resets_from <= epoch:
760
+ self.random_resets_ratio = random_resets_ratio
761
+ else:
762
+ self.random_resets_ratio = 1.0
763
+
764
+ # 11. Unfreeze all components at the configured epoch (checked before that epoch's training starts)
765
+ if epoch == unfreeze_epoch:
766
+ self.actor.unfreeze_components()
767
+
768
+ # 12. Set epoch for distributed sampler
769
+ if train_sampler is not None:
770
+ train_sampler.set_epoch(epoch)
771
+
772
+ # 13. Run reinforcement learning algorithms for current epoch
773
+ policy_loss, critic_loss = self.train_epoch(dataloader, epoch)
774
+
775
+ # 14. If evaluation dataset is provided, run evaluation steps
776
+ if self.eval_dataset:
777
+ should_stop_stage = self.evaluate(batch_size, epoch)
778
+ else:
779
+ should_stop_stage = False
780
+
781
+ # 15. Finally, run "on epoch end" callbacks (save models, etc.)
782
+ for cb in self.callbacks:
783
+ cb.on_epoch_end(self.actor, epoch, epochs, policy_loss, critic_loss, self.global_epoch,
784
+ self.global_epochs_count)
785
+
786
+ # 16. Synchronize TensorBoard writer
787
+ if self.writer:
788
+ self.writer.flush()
789
+
790
+ # 17. Synchronize devices in DDP mode
791
+ if self.use_ddp:
792
+ dist.barrier()
793
+
794
+ # 18. Finish the curriculum stage early if rewards stopped improving or reached the threshold
795
+ if should_stop_stage:
796
+ break
797
+
798
+ # 19. Run "on_training_end" callbacks after each curriculum stage (each stage can have its own callbacks)
799
+ for cb in self.callbacks:
800
+ cb.on_training_end(self.actor, self.critic, current_curriculum_step)
801
+
802
+ # 20. Training end - finish processes after all curriculum stages
803
+ if self.use_ddp:
804
+ dist.destroy_process_group()
805
+
806
+ # 21. Close writer
807
+ if self.writer:
808
+ self.writer.close()
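
Putting the pieces together, a hedged end-to-end sketch of driving the trainer. The actor, critic, reward model, RL algorithm and curriculum datasets are placeholders whose construction lives outside this file; only the call pattern comes from the code above:

import torch
from rxnn.training.mrl import MRLTrainer, MrlConfig, CurriculumConfig

config: MrlConfig = {
    'lr': 3e-4, 'critic_lr': 1e-4,
    'max_seq_len': 1024, 'critic_max_len': 2048,
    'weight_decay': 0.01, 'critic_weight_decay': 0.01,
}

actor = ...         # MrlActorModel instance (construction not shown in this diff)
critic = ...        # MrlCriticModel instance
reward_model = ...  # MrlRewardModel instance
rl_algorithm = ...  # RlAlgorithm implementation (e.g. PPO)

trainer = MRLTrainer(
    actor, critic, reward_model,
    device=torch.device('cuda'),
    config=config,
    rl_algorithm=rl_algorithm,
    use_amp=True,
    dtype=torch.bfloat16,
)

curriculum: list[CurriculumConfig] = [
    # one or more stages shaped like the CurriculumConfig example above,
    # typically with 'steps' increasing from stage to stage
]
trainer(curriculum, batch_size=16)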