rxnn 0.1.83__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/.DS_Store +0 -0
- rxnn/experimental/attention.py +5 -0
- rxnn/memory/attention.py +42 -0
- rxnn/memory/stm.py +53 -12
- rxnn/rxt/models.py +71 -0
- rxnn/training/bml.py +2 -59
- rxnn/training/callbacks.py +302 -39
- rxnn/training/dataset.py +344 -1
- rxnn/training/models.py +142 -0
- rxnn/training/mrl.py +808 -0
- rxnn/training/reward.py +111 -0
- rxnn/training/rl.py +69 -0
- rxnn/training/utils.py +148 -0
- rxnn/transformers/attention.py +10 -0
- rxnn/transformers/layers.py +6 -0
- rxnn/transformers/models.py +16 -4
- rxnn/transformers/positional.py +7 -0
- rxnn/transformers/sampler.py +283 -9
- {rxnn-0.1.83.dist-info → rxnn-0.2.0.dist-info}/METADATA +11 -9
- rxnn-0.2.0.dist-info/RECORD +38 -0
- rxnn-0.1.83.dist-info/RECORD +0 -31
- {rxnn-0.1.83.dist-info → rxnn-0.2.0.dist-info}/LICENSE +0 -0
- {rxnn-0.1.83.dist-info → rxnn-0.2.0.dist-info}/WHEEL +0 -0
rxnn/training/mrl.py
ADDED
@@ -0,0 +1,808 @@
import torch
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.tensorboard import SummaryWriter
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from typing import Optional, TypedDict
from enum import Enum
import random, os
from ..transformers.sampler import BatchSampler
from .callbacks import MrlTrainerCallback
from .dataset import MrlCurriculumDataset
from .utils import smart_concat, smart_concat_critic_states, SpecialTokenIds, TokenizedDict
from .rl import RlAlgorithm
from .reward import MrlRewardMode, MrlRewardModel
from .models import MrlActorAction, MrlActorModel, MrlCriticModel


class MrlConfig(TypedDict):
    lr: float
    critic_lr: float
    max_seq_len: int
    critic_max_len: int
    weight_decay: float
    critic_weight_decay: float


class MrlStrategy(Enum):
    SINGLE_STEP_STRATEGY = 1
    MULTI_STEP_STRATEGY = 2
    LONG_RANGE_STRATEGY = 3


class CurriculumConfig(TypedDict):
    steps: int
    epochs: int
    dataset: MrlCurriculumDataset
    eval_dataset: Optional[MrlCurriculumDataset]
    callbacks: Optional[list[MrlTrainerCallback]]
    strategy: MrlStrategy
    unfreeze_epoch: Optional[int]
    random_resets: Optional[bool]
    random_resets_from: Optional[int]
    random_resets_ratio: Optional[float]


class SamplerConfig(TypedDict):
    temperature: float
    top_k: Optional[int]
    top_p: Optional[float]


class MrlTrajectoryStep(TypedDict):
    state: tuple[TokenizedDict, TokenizedDict, TokenizedDict]
    action: TokenizedDict
    log_probs: torch.Tensor
    reward: list[float]
    reference: TokenizedDict


class MrlTrajectoryEpisode(TypedDict):
    reset_stm: bool
    steps: list[MrlTrajectoryStep]


class MRLTrainer:
    def __init__(
            self,
            actor: MrlActorModel,
            critic: MrlCriticModel,
            reward: MrlRewardModel,
            device: torch.device,
            config: MrlConfig,
            rl_algorithm: RlAlgorithm,
            sampler_config: Optional[SamplerConfig] = None,
            log_dir: str = None,
            pad_token_id: int = 0,
            start_token_id: int = 2,
            end_token_id: int = 3,
            use_ddp: bool = False,
            use_amp: bool = False,
            dtype: torch.dtype = torch.float32,
            callbacks: list[MrlTrainerCallback] = None,
    ):
        """
        Trainer for Memory Reinforcement Learning (MRL) in Reactive Transformer.

        Args:
            actor: MRL Actor model with encoder, decoder and memory attention.
            critic: Critic network for advantage estimation.
            config: Configuration dictionary with hyperparameters.
        """
        self.actor = actor
        self.critic = critic
        self.reward = reward
        self.device = device
        self.max_seq_len = config.get('max_seq_len', 1024)
        self.critic_max_len = config.get('critic_max_len', 2048)

        # Move models to device
        if use_amp:
            self.actor.to(self.device)
            self.critic.to(self.device)
        else:
            self.actor.to(self.device, dtype=dtype)
            self.critic.to(self.device, dtype=dtype)

        # Batch Sampler for answer generation
        self.generator = BatchSampler(self.actor, self.device, end_token_id=end_token_id)
        self.sampler_config = SamplerConfig(
            temperature=1.0,
            top_k=None,
            top_p=None,
        ) if sampler_config is None else sampler_config

        self.special_token_ids: SpecialTokenIds = {
            'pad': pad_token_id,
            'bos': start_token_id,
            'eos': end_token_id,
        }

        self.use_ddp = use_ddp
        self.use_amp = use_amp
        self.dtype = dtype

        # Optimizers
        self.optimizer = torch.optim.AdamW(
            self.actor.unique_parameters(),
            lr=config.get("lr", 3e-4),
            weight_decay=config.get("weight_decay", 0.01),
        )
        self.critic_optimizer = torch.optim.AdamW(
            self.critic.parameters(),
            lr=config.get("critic_lr", 1e-4),
            weight_decay=config.get("critic_weight_decay", 0.01),
        )

        self.scaler = torch.amp.GradScaler() if self.use_amp else None
        self.critic_scaler = torch.amp.GradScaler() if self.use_amp else None

        # TensorBoard Writer
        if log_dir and not os.path.exists(log_dir):
            os.makedirs(log_dir)
        self.writer = SummaryWriter(log_dir) if log_dir else None

        self.global_step = self._init_steps()
        self.epoch_step = self._init_steps()
        self.stage_step = self._init_steps()

        self.rl_algorithm = rl_algorithm

        # Dynamic fields, updated for each curriculum step
        self.curriculum_steps = 0
        self.train_dataset = None
        self.eval_dataset = None
        self.random_resets_ratio = 0.0
        self.strategy = None
        self.shared_callbacks = callbacks if callbacks else []
        self.callbacks = []
        self.global_epoch = 0
        self.global_epochs_count = 0

    def _init_steps(self):
        return {
            'collect': 0,
            'critic': 0,
            'rl': 0,
            'eval': 0,
        }

    def _increment_steps(self, step_type: str):
        self.global_step[step_type] += 1
        self.epoch_step[step_type] += 1
        self.stage_step[step_type] += 1

    def reset_stm(self) -> bool:
        """Reset Short-Term Memory state with random reset ratio."""
        if self.random_resets_ratio == 1.0:
            self.actor.reset_memory()
            return True
        else:
            rng = random.random()
            if rng <= self.random_resets_ratio:
                self.actor.reset_memory()
                return True
            else:
                return False

    def encode_and_update_stm(self, query: TokenizedDict, answer: TokenizedDict):
        """Encode interaction and update STM."""
        # 1. Encode data and update memory - with autocast on/off
        if self.use_amp:
            with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
                # 2. Concatenate batch of queries and answers (they are already on training device)
                inputs = smart_concat(query, answer, self.max_seq_len, self.special_token_ids['pad'])
                # 3. Encode data and update STM
                self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'], action=MrlActorAction.UPDATE)
        else:
            # 2. Concatenate batch of queries and answers (they are already on training device)
            inputs = smart_concat(query, answer, self.max_seq_len, self.special_token_ids['pad'])
            # 3. Encode data and update STM
            self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'], action=MrlActorAction.UPDATE)

    def generate_answer(self, query: TokenizedDict) -> tuple[TokenizedDict, torch.Tensor]:
        """Generate response using batch sampler with decoder."""
        # 1. Generate answer with BatchSampler - with autocast on/off
        if self.use_amp:
            with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
                input_ids, attention_mask, log_probs = self.generator(
                    query['input_ids'],
                    query['attention_mask'],
                    max_gen_len=self.max_seq_len,
                    **self.sampler_config,
                )
        else:
            input_ids, attention_mask, log_probs = self.generator(
                query['input_ids'],
                query['attention_mask'],
                max_gen_len=self.max_seq_len,
                **self.sampler_config,
            )
        # 2. Convert generated answer to TokenizedDict
        generated_answer: TokenizedDict = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
        }

        return generated_answer, log_probs

    def compute_reward(self, generated: TokenizedDict, reference: TokenizedDict,
                       saved_data: tuple[TokenizedDict, TokenizedDict], mode: MrlRewardMode = MrlRewardMode.STANDARD,
                       eval_mode: bool = False) -> list[float]:
        """Compute reward based on memory retention (e.g., BLEU-4)."""
        saved_query, saved_answer = saved_data
        # 1. Concat saved (previous) interaction and calculate reward using generated sequence, reference and saved data - with autocast on/off
        if self.use_amp:
            with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
                saved_interaction = smart_concat(saved_query, saved_answer, max_length=self.max_seq_len,
                                                 pad_token_id=self.special_token_ids['pad'])
                reward = self.reward(generated, reference, saved_interaction, mode=mode)
        else:
            saved_interaction = smart_concat(saved_query, saved_answer, max_length=self.max_seq_len,
                                             pad_token_id=self.special_token_ids['pad'])
            reward = self.reward(generated, reference, saved_interaction, mode=mode)

        # 2. Run 'on reward' callbacks
        for cb in self.callbacks:
            cb.on_reward(self.actor, reward, generated, reference, saved_interaction, eval_mode)
        # 3. Return rewards for batch
        return reward

    def _move_batch(self, batch: TokenizedDict) -> TokenizedDict:
        if self.use_amp:
            return {
                'input_ids': batch['input_ids'].to(self.device),
                'attention_mask': batch['attention_mask'].to(self.device),
            }
        else:
            return {
                'input_ids': batch['input_ids'].to(self.device, dtype=self.dtype),
                'attention_mask': batch['attention_mask'].to(self.device, dtype=self.dtype),
            }

    def _move_multiple_batches(self, *batches: TokenizedDict) -> list[TokenizedDict]:
        return [self._move_batch(batch) for batch in batches]

    def _cpu_detach(self, batch: TokenizedDict) -> TokenizedDict:
        return {
            'input_ids': batch['input_ids'].detach().cpu(),
            'attention_mask': batch['attention_mask'].detach().cpu(),
        }

    def _cpu_detach_multiple(self, *batches: TokenizedDict) -> list[TokenizedDict]:
        return [self._cpu_detach(batch) for batch in batches]

    def _collect_writer(self, avg_reward: float, epoch: int):
        if self.writer is not None:
            self.writer.add_scalar('Collect/episode reward (global)', avg_reward, self.global_step['collect'])
            self.writer.add_scalar(f'Collect/episode reward (steps: {self.curriculum_steps}, epoch: {epoch})',
                                   avg_reward, self.epoch_step['collect'])
            self.writer.add_scalar(f'Collect/episode reward (steps: {self.curriculum_steps})', avg_reward,
                                   self.stage_step['collect'])

    def collect_trajectories(self, dataloader: DataLoader, epoch: int) -> list[MrlTrajectoryEpisode]:
        """Collect trajectories for PPO for current curriculum step."""
        # 1. Init trajectories list
        trajectories = []

        with torch.no_grad():
            # 2. Collect episode trajectories for all batches in dataset
            for batch_idx, batch in enumerate(dataloader):
                self._increment_steps('collect')
                # 3. Reset Short-Term Memory state (with random reset ratio - sometimes it will be good to build memory
                # state from existing one, instead of new random one)
                reset_done = self.reset_stm()

                # 4. Get first batch of interactions (data to save) and follow-up interactions for current episode, based on curriculum step
                first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
                interactions = interactions[:self.curriculum_steps]
                interactions_len = len(interactions)
                # 5. Encode and update STM with data to save from first interaction
                self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))

                # 6. Save first interaction as data to save (for trajectory state)
                query, answer = first_query, first_answer

                # 7. Run training strategy for follow-up interactions
                episode_steps = []
                episode_rewards = []
                for i, interaction in enumerate(interactions):
                    # 8. Generate batch of answers based on batch of follow-up queries
                    next_query = self._move_batch(interaction['query'])
                    generated_answer, log_probs = self.generate_answer(next_query)

                    is_last_interaction = (i + 1) == interactions_len

                    detached_answer = self._cpu_detach(generated_answer)  # detach and keep states on CPU

                    # 9. Depending on strategy compute reward
                    if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
                        # a) long-range - first interaction - change topic - negative reward (it shouldn't include saved data)
                        reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
                                                     mode=MrlRewardMode.NEGATIVE)
                    elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
                        # b) long-range - last interaction - first interaction topic - long-range reward (it should include content from first interaction)
                        reward = self.compute_reward(detached_answer, interaction['answer'],
                                                     (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE)
                    else:
                        # c) standard reward - generated answer should include some content from previous interaction (saved data), like reference answer
                        reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
                                                     mode=MrlRewardMode.STANDARD)

                    # 10. Update STM with generated response (except last interaction, it's not needed)
                    if not is_last_interaction:
                        self.encode_and_update_stm(next_query, generated_answer)  # update with generated_answer on GPU

                    # 11. Store trajectory step
                    trajectory: MrlTrajectoryStep = {
                        'state': (query, answer, interaction['query']),
                        'action': detached_answer,
                        'log_probs': log_probs.detach().cpu(),
                        'reward': reward,
                        'reference': interaction['answer'],
                    }
                    episode_steps.append(trajectory)
                    episode_rewards.append(reward)

                    # 12. Set current interaction query and generated answer (batches), as saved data for next interaction
                    query, answer = interaction['query'], detached_answer

                # 13. Append full batched episode (number of steps depends on curriculum stage) to trajectories
                episode_trajectory: MrlTrajectoryEpisode = {
                    'reset_stm': reset_done,
                    'steps': episode_steps,
                }
                trajectories.append(episode_trajectory)

                mean_episode_reward = torch.tensor(episode_rewards).mean().item()

                self._collect_writer(mean_episode_reward, epoch)

                # 14. Run "on episode collected" callbacks
                for cb in self.callbacks:
                    cb.on_episode_collected(self.actor, batch_idx, episode_trajectory, mean_episode_reward)

        return trajectories

    def _critic_loss(self, inputs: TokenizedDict, rewards: torch.Tensor) -> torch.Tensor:
        # 1. Calculate values with critic encoder
        values = self.critic(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
        ).squeeze()
        # 2. Calculate critic loss
        loss = self.rl_algorithm.critic_loss(values, rewards)
        return loss

    def _critic_writer(self, critic_loss: float, epoch: int):
        if self.writer is not None:
            self.writer.add_scalar('Loss/critic (global)', critic_loss, self.global_step['critic'])
            self.writer.add_scalar(f'Loss/critic (steps: {self.curriculum_steps}, epoch: {epoch})', critic_loss,
                                   self.epoch_step['critic'])
            self.writer.add_scalar(f'Loss/critic (steps: {self.curriculum_steps})', critic_loss,
                                   self.stage_step['critic'])

    def update_critic(self, states: list[tuple[TokenizedDict, TokenizedDict, TokenizedDict]],
                      rewards: list[torch.Tensor], epoch: int):
        """Update critic network using MSE loss."""
        # 1. Run critic updates for all collected batches
        critic_losses = []
        for step_idx, (state, reward) in enumerate(zip(states, rewards)):
            self._increment_steps('critic')
            # 2. Move state batches to training device (GPU)
            prev_query, prev_answer, next_query = self._move_multiple_batches(*state)

            # 3. Reset critic gradients
            self.critic_optimizer.zero_grad()

            # 4. Run critic and calculate loss - in autocast on/off mode
            if self.use_amp:
                # Move tensors to training device and calculate loss in autocast mode
                batch_rewards = reward.to(self.device)
                with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
                    # Concatenate state into single critic input sequence
                    inputs = smart_concat_critic_states(
                        prev_query, prev_answer, next_query,
                        max_length=self.critic_max_len,
                        pad_token_id=self.special_token_ids['pad'],
                    )
                    loss = self._critic_loss(inputs, batch_rewards)
                # Run backpropagation with scaler
                self.critic_scaler.scale(loss).backward()
                # Unscale and clip gradients
                self.critic_scaler.unscale_(self.critic_optimizer)
                torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0, error_if_nonfinite=False)
                # Run scaled optimization step
                self.critic_scaler.step(self.critic_optimizer)
                self.critic_scaler.update()
            else:
                # Concatenate state into single critic input sequence
                inputs = smart_concat_critic_states(
                    prev_query, prev_answer, next_query,
                    max_length=self.critic_max_len,
                    pad_token_id=self.special_token_ids['pad'],
                )
                # Calculate loss
                loss = self._critic_loss(inputs, reward.to(self.device, dtype=self.dtype))
                # Run backpropagation
                loss.backward()
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0, error_if_nonfinite=False)
                # Run optimizer step
                self.critic_optimizer.step()
            critic_loss = loss.item()
            self._critic_writer(critic_loss, epoch)

            # 5. Run "on critic updated" callbacks
            for cb in self.callbacks:
                cb.on_critic_updated(self.actor, self.critic, epoch, step_idx, critic_loss)

            # 6. Accumulate loss for epoch callbacks
            critic_losses.append(critic_loss)

        # 7. Calculate mean loss for epoch callbacks
        critic_mean_loss = torch.stack(critic_losses).mean().item()

        return critic_mean_loss

    def _critic_advantages(self, critic_state: TokenizedDict, rewards: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            values = self.critic(critic_state['input_ids'],
                                 attention_mask=critic_state['attention_mask']).squeeze()
        return self.rl_algorithm.calculate_advantages(rewards, values)

    def _rl_writer(self, policy_loss: float, epoch: int):
        if self.writer is not None:
            self.writer.add_scalar('Loss/policy (global)', policy_loss, self.global_step['rl'])
            self.writer.add_scalar(f'Loss/policy (steps: {self.curriculum_steps}, epoch: {epoch})', policy_loss,
                                   self.epoch_step['rl'])
            self.writer.add_scalar(f'Loss/policy (steps: {self.curriculum_steps})', policy_loss, self.stage_step['rl'])

    def rl_step(self, trajectories: list[MrlTrajectoryEpisode], epoch: int):
        """Perform PPO update step using trajectories."""
        # 1. Run update separately for episodes in trajectory - we have to reset memory before each episode, and update
        # memory, based on collected episode data
        all_losses = []
        trajectories_len = len(trajectories)
        for episode_idx, episode in enumerate(trajectories):
            episode_steps = episode['steps']
            should_reset_stm = episode['reset_stm']

            # 2. Reset memory for current batch episode
            if should_reset_stm:
                self.reset_stm()

            # 3. Run episode steps - each episode has number of steps depending on curriculum stage. Each step is run for all batch
            for step in episode_steps:
                self._increment_steps('rl')
                state, action, reward, log_probs = step['state'], step['action'], step['reward'], step['log_probs']
                query, answer, next_query = self._move_multiple_batches(*state)
                action = self._move_batch(action)
                log_probs = log_probs.to(self.device)
                rewards = torch.tensor(reward).to(self.device)

                # 4. Compute advantages using critic
                if self.use_amp:
                    with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
                        critic_state = smart_concat_critic_states(query, answer, next_query,
                                                                  max_length=self.critic_max_len,
                                                                  pad_token_id=self.special_token_ids['pad'])
                        advantages = self._critic_advantages(critic_state, rewards)
                else:
                    critic_state = smart_concat_critic_states(query, answer, next_query, max_length=self.critic_max_len,
                                                              pad_token_id=self.special_token_ids['pad'])
                    advantages = self._critic_advantages(critic_state, rewards)

                # 5. Encode and update STM on each step, to include encoder and memory attention gradients in loss
                self.encode_and_update_stm(query, answer)
                # 6. Concatenate next query and action and get action logits from decoder
                if self.use_amp:
                    with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
                        inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
                                              pad_token_id=self.special_token_ids['pad'])
                        logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
                                            action=MrlActorAction.DECODE)
                else:
                    inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
                                          pad_token_id=self.special_token_ids['pad'])
                    logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
                                        action=MrlActorAction.DECODE)

                # 7. Calculate RL Algorithm (PPO etc.) loss
                policy_loss = self.rl_algorithm.policy_loss(action['input_ids'], logits, log_probs, advantages)

                # 8. Reset gradients
                self.optimizer.zero_grad()

                # 9. Update the model in AMP or regular mode
                if self.use_amp:
                    self.scaler.scale(policy_loss).backward()
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                                   error_if_nonfinite=False)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    policy_loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                                   error_if_nonfinite=False)
                    self.optimizer.step()

                policy_loss_item = policy_loss.item()
                self._rl_writer(policy_loss_item, epoch)
                all_losses.append(policy_loss_item)

                # 10. Run "on batch updated" callback
                for cb in self.callbacks:
                    cb.on_batch_updated(self.actor, epoch, self.epoch_step['rl'], policy_loss_item)

        return torch.mean(torch.tensor(all_losses)).item()

    def _critic_states_and_rewards(self, trajectories: list[MrlTrajectoryEpisode]):
        flat_trajectories = [t for episode in trajectories for t in episode['steps']]
        states = [t['state'] for t in flat_trajectories]
        rewards = [torch.tensor(t['reward']) for t in flat_trajectories]
        return states, rewards

    def train_epoch(self, dataloader: DataLoader, epoch: int):
        """Train for one epoch."""
        # 1. Collect trajectories for current epoch
        trajectories = self.collect_trajectories(dataloader, epoch)

        # 2. Flatten trajectories and collect state and rewards for critic update
        states, rewards = self._critic_states_and_rewards(trajectories)
        # 3. Update critic model, based on states and rewards
        critic_loss = self.update_critic(states, rewards, epoch)

        # 4. Run PPO algorithm step
        policy_loss = self.rl_step(trajectories, epoch)

        # 5. Return policy and critic mean losses for epoch callbacks
        return policy_loss, critic_loss

    def _eval_loader(self, batch_size: int):
        if self.use_ddp:
            return DataLoader(
                self.eval_dataset,
                batch_size=batch_size,
                pin_memory=True,
                sampler=DistributedSampler(self.eval_dataset, shuffle=False),
                collate_fn=MrlCurriculumDataset.collate_mrl_batch,
            )
        else:
            return DataLoader(
                self.eval_dataset,
                batch_size=batch_size,
                shuffle=False,
                pin_memory=True,
                collate_fn=MrlCurriculumDataset.collate_mrl_batch,
            )

    def _eval_writer(self, avg_reward: float, epoch: int):
        if self.writer is not None:
            self.writer.add_scalar('Eval/episode reward (global)', avg_reward, self.global_step['eval'])
            self.writer.add_scalar(f'Eval/episode reward (steps: {self.curriculum_steps}, epoch: {epoch})', avg_reward,
                                   self.epoch_step['eval'])
            self.writer.add_scalar(f'Eval/episode reward (steps: {self.curriculum_steps})', avg_reward,
                                   self.stage_step['eval'])

    def evaluate(self, batch_size: int, epoch: int):
        """Evaluate model on validation dataset."""
        # 1. Init evaluation DataLoader
        dataloader = self._eval_loader(batch_size)
        total_reward = torch.tensor(0.0).to(self.device)
        count = torch.tensor(0).to(self.device)

        # 2. Run evaluation on all batch episodes
        for batch in dataloader:
            with torch.no_grad():
                self._increment_steps('eval')
                # 3. Reset STM with random resets ratio
                self.reset_stm()

                # 4. Get batches for first queries, answers and all follow-up interactions
                first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
                # 5. Encode and update STM with initial interactions (batch)
                self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))

                # 6. Save follow-up interactions len and first query and answer as previous one for iteration
                interactions_len = len(interactions)
                query, answer = first_query, first_answer
                episode_reward = torch.tensor(0.0).to(self.device)
                episode_interactions = torch.tensor(0).to(self.device)
                # 7. Run all follow-up interactions
                for i, interaction in enumerate(interactions):
                    # 8. Generate batch of answers
                    next_query = self._move_batch(interaction['query'])
                    generated_answer, _ = self.generate_answer(next_query)

                    is_last_interaction = (i + 1) == interactions_len

                    detached_answer = self._cpu_detach(generated_answer)

                    # 9. Depending on current strategy and step, compute reward
                    if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
                        reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
                                                     mode=MrlRewardMode.NEGATIVE, eval_mode=True)
                    elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
                        reward = self.compute_reward(detached_answer, interaction['answer'],
                                                     (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE,
                                                     eval_mode=True)
                    else:
                        reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
                                                     mode=MrlRewardMode.STANDARD, eval_mode=True)

                    # 10. Encode and update memory for the next interaction
                    if not is_last_interaction:
                        self.encode_and_update_stm(next_query, generated_answer)

                    # 11. Accumulate rewards
                    step_reward = torch.tensor(reward).mean().to(self.device)
                    # total
                    total_reward += step_reward
                    count += 1
                    # episode
                    episode_reward += step_reward
                    episode_interactions += 1
                    # 12. Save previous interaction
                    query, answer = interaction['query'], detached_answer
                avg_episode_reward = (episode_reward / episode_interactions).item()
                # 13. Run eval TensorBoard writer with average episode reward
                self._eval_writer(avg_episode_reward, epoch)

                # 14. Run "on eval episode end" callbacks
                for cb in self.callbacks:
                    cb.on_eval_episode_end(self.actor, epoch, self.epoch_step['eval'], avg_episode_reward)

        # 15. Calculate average reward
        if self.use_ddp:
            total_sum = dist.all_reduce(total_reward, dist.ReduceOp.SUM)
            count_sum = dist.all_reduce(count, dist.ReduceOp.SUM)
            avg_reward = (total_sum / count_sum).item() if count_sum > 0 else 0
        else:
            avg_reward = (total_reward / count).item() if count > 0 else 0

        should_stop_stage = False
        # 16. Run "on eval end" callbacks
        for cb in self.callbacks:
            should_stop = cb.on_eval_end(self.actor, self.critic, epoch, avg_reward)
            if should_stop:
                should_stop_stage = True

        return should_stop_stage

    def _setup_curriculum_step(self, config: CurriculumConfig) -> tuple[tuple[int, int], tuple[bool, int, float]]:
        # 1. Set common fields based on config
        self.curriculum_steps = config.get('steps', 1)  # number of steps to run in episode
        self.train_dataset = config.get('dataset', None)  # training dataset for current curriculum stage
        self.eval_dataset = config.get('eval_dataset', None)  # evaluation dataset for current curriculum stage
        self.callbacks = config.get('callbacks',
                                    self.shared_callbacks)  # trainer callbacks for current curriculum stage
        self.strategy = config.get('strategy',
                                   MrlStrategy.MULTI_STEP_STRATEGY)  # MRL strategy for given curriculum stage

        # 2. Get epochs and random resets configs
        epochs = config.get('epochs', 5)  # number of epochs for current stage
        unfreeze_epoch = config.get('unfreeze_epoch',
                                    0)  # epoch when components (other than memory) are unfrozen (before epoch starts)
        random_resets = config.get('random_resets',
                                   False)  # flag for using random STM resets (recommended, as model should learn transitions between different states)
        random_resets_from = config.get('random_resets_from', None)  # epoch from which random STM resets are started
        random_resets_ratio = config.get('random_resets_ratio',
                                         None)  # ratio of random STM resets - 1.0 is "always reset", 0.0 is "no resets"

        # 3. Reset stage step counter
        self.stage_step = self._init_steps()

        return (epochs, unfreeze_epoch), (random_resets, random_resets_from, random_resets_ratio)

    def __call__(self, curriculum_config: list[CurriculumConfig], batch_size: int):
        """Start Memory Reinforcement Learning Curriculum."""

        # 0. Set global epoch count for all stages
        self.global_epochs_count = sum(stage['epochs'] for stage in curriculum_config)

        # 1. Init DDP for distributed training mode
        if self.use_ddp:
            rank = int(os.environ['RANK'])
            world_size = int(os.environ['WORLD_SIZE'])
            dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
            self.actor = DistributedDataParallel(self.actor, device_ids=[self.device.index])
            self.critic = DistributedDataParallel(self.critic, device_ids=[self.device.index])

        # 2. Run each curriculum step based on config
        for current_curriculum_step in curriculum_config:
            # 3. Setup training config for curriculum step
            epochs_config, random_resets_config = self._setup_curriculum_step(current_curriculum_step)
            epochs, unfreeze_epoch = epochs_config
            random_resets, random_resets_from, random_resets_ratio = random_resets_config
            assert self.train_dataset is not None
            print(f'Curriculum Steps Increased to {self.curriculum_steps}')

            # 4. Freeze all components except memory attention and memory cross-attention layers in decoder/encoder
            if unfreeze_epoch != 0:
                self.actor.freeze_components()

            # 5. Setup train DataLoader
            if self.use_ddp:
                train_sampler = DistributedSampler(self.train_dataset, shuffle=True)
                dataloader = DataLoader(
                    self.train_dataset,
                    batch_size=batch_size,
                    sampler=train_sampler,
                    pin_memory=True,
                    collate_fn=MrlCurriculumDataset.collate_mrl_batch,
                )
            else:
                train_sampler = None
                dataloader = DataLoader(
                    self.train_dataset,
                    batch_size=batch_size,
                    shuffle=True,
                    pin_memory=True,
                    collate_fn=MrlCurriculumDataset.collate_mrl_batch,
                )

            # 6. Run selected number of epochs for given curriculum stage
            for epoch in range(epochs):
                # 7. Increment global epoch
                self.global_epoch += 1
                # 8. Run "on epoch start" callbacks (log info, etc.)
                for cb in self.callbacks:
                    cb.on_epoch_start(self.actor, epoch, epochs, current_curriculum_step, self.global_epoch,
                                      self.global_epochs_count)

                # 9. Reset steps counter for epoch
                self.epoch_step = self._init_steps()

                # 10. Set random STM resets ratio from selected epoch
                if random_resets and random_resets_from <= epoch:
                    self.random_resets_ratio = random_resets_ratio
                else:
                    self.random_resets_ratio = 1.0

                # 11. Unfreeze all components before selected epoch
                if epoch == unfreeze_epoch:
                    self.actor.unfreeze_components()

                # 12. Set epoch for distributed sampler
                if train_sampler is not None:
                    train_sampler.set_epoch(epoch)

                # 13. Run reinforcement learning algorithms for current epoch
                policy_loss, critic_loss = self.train_epoch(dataloader, epoch)

                # 14. If evaluation dataset is provided, run evaluation steps
                if self.eval_dataset:
                    should_stop_stage = self.evaluate(batch_size, epoch)
                else:
                    should_stop_stage = False

                # 15. Finally, run "on epoch end" callbacks (save models, etc.)
                for cb in self.callbacks:
                    cb.on_epoch_end(self.actor, epoch, epochs, policy_loss, critic_loss, self.global_epoch,
                                    self.global_epochs_count)

                # 16. Synchronize TensorBoard writer
                if self.writer:
                    self.writer.flush()

                # 17. Synchronize devices in DDP mode
                if self.use_ddp:
                    dist.barrier()

                # 18. Finish curriculum stage if rewards are not increased or reached threshold point
                if should_stop_stage:
                    break

            # 19. Run "on_training_end" callbacks after each curriculum stage (they have own callbacks)
            for cb in self.callbacks:
                cb.on_training_end(self.actor, self.critic, current_curriculum_step)

        # 20. Training end - finish processes after all curriculum stages
        if self.use_ddp:
            dist.destroy_process_group()

        # 21. Close writer
        if self.writer:
            self.writer.close()
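For orientation, below is a minimal usage sketch of the new MRLTrainer, based only on what is visible in this file: the MrlConfig, CurriculumConfig and MrlStrategy types, the constructor signature, and __call__(curriculum_config, batch_size). The construction of MrlActorModel, MrlCriticModel, MrlRewardModel, the MrlCurriculumDataset instances and the concrete RlAlgorithm (defined in the other files added in this release, e.g. rxnn/training/rl.py and rxnn/training/reward.py) is not shown in this diff, so those objects appear as labelled placeholders. This is an assumption-laden sketch, not documented package usage.

import torch
from rxnn.training.mrl import MRLTrainer, MrlConfig, CurriculumConfig, MrlStrategy

# Placeholders: these classes are imported by mrl.py, but how to build them is not part of this diff.
actor = ...            # MrlActorModel instance (encoder, decoder, memory attention)
critic = ...           # MrlCriticModel instance
reward_model = ...     # MrlRewardModel instance (rxnn.training.reward)
rl_algorithm = ...     # RlAlgorithm implementation (rxnn.training.rl)
stage_one_dataset = ...  # MrlCurriculumDataset
stage_two_dataset = ...  # MrlCurriculumDataset
eval_dataset = ...       # MrlCurriculumDataset or None

# Hyperparameters follow the MrlConfig keys and the defaults read via config.get(...) above.
config: MrlConfig = {
    'lr': 3e-4,
    'critic_lr': 1e-4,
    'max_seq_len': 1024,
    'critic_max_len': 2048,
    'weight_decay': 0.01,
    'critic_weight_decay': 0.01,
}

trainer = MRLTrainer(
    actor=actor,
    critic=critic,
    reward=reward_model,
    device=torch.device('cuda'),
    config=config,
    rl_algorithm=rl_algorithm,
    use_amp=True,
    dtype=torch.bfloat16,
)

# One entry per curriculum stage; fields mirror the CurriculumConfig TypedDict.
curriculum: list[CurriculumConfig] = [
    {
        'steps': 1, 'epochs': 5, 'dataset': stage_one_dataset, 'eval_dataset': eval_dataset,
        'callbacks': [], 'strategy': MrlStrategy.SINGLE_STEP_STRATEGY, 'unfreeze_epoch': 0,
        'random_resets': False, 'random_resets_from': None, 'random_resets_ratio': None,
    },
    {
        'steps': 4, 'epochs': 5, 'dataset': stage_two_dataset, 'eval_dataset': eval_dataset,
        'callbacks': [], 'strategy': MrlStrategy.MULTI_STEP_STRATEGY, 'unfreeze_epoch': 1,
        'random_resets': True, 'random_resets_from': 2, 'random_resets_ratio': 0.5,
    },
]

trainer(curriculum, batch_size=16)

Per the code above, a stage with a non-zero unfreeze_epoch starts with freeze_components() and unfreezes everything except memory-related layers only when that epoch is reached, while random_resets, random_resets_from and random_resets_ratio decide how often the short-term memory is reset between episodes from a given epoch onward; that is why these fields are configured per curriculum stage rather than once in the constructor.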