project-llm-trainer 0.13.4__py3-none-any.whl

Potentially problematic release.

@@ -0,0 +1,686 @@
from typing import Tuple, List, Union, Callable, Optional
import gc
import torch
import torch.distributed as dist
from torch.utils.data import Dataset
import torch.nn as nn
from itertools import islice

from llm_model import LlmModel, VlmModel

from .base_trainer import BaseTrainer
from .train_configs import TrainConfig
from .dataset import RLDataset
from .loss import PPOLoss
from .tools import TrainerTools
from .generate_utils import batch_generate
from .utils import (
    autocast,
    left_pad_sequence,
    log_softmax,
    masked_whiten,
    disable_dropout_in_model,
    calc_position_ids,
    RunningMeanStd
)
from .partition_utils import unwrap_model_for_generation
from .log import Logger
from .checkpoint import (
    save_checkpoint,
    save_steps,
    load_checkpoint
)
from .scheduler import (
    LRScheduler,
    WarmupCosineAnnealingLRScheduler,
    CompositeLRScheduler,
    NoneLRScheduler
)


class ValueModel(nn.Module):
    def __init__(self, base_model: Union[LlmModel, VlmModel]):
        super().__init__()
        self.base_model = base_model
        self.value_head = nn.Linear(base_model.config.hidden_size, 1, bias=True)
        self.value_head.weight.data.normal_(mean=0.0, std=0.01)
        self.value_head.bias.data.zero_()

    def forward(self, *args, **kwargs) -> torch.Tensor:
        outputs = self.base_model(*args, **kwargs)
        # [batch_size, seq_len, hidden_size]
        last_hidden_state = outputs['hidden_states']
        # [batch_size, seq_len, 1]
        values = self.value_head(last_hidden_state)
        # [batch_size, seq_len]
        return values.squeeze(-1)


class PolicyAndValueModelWrapper(nn.Module):
    def __init__(self, policy_model: nn.Module, value_model: nn.Module):
        super().__init__()
        self.policy_model = policy_model
        self.value_model = value_model

    def forward(self, *args, **kwargs):
        return self.policy_model(*args, **kwargs), self.value_model(*args, **kwargs)


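# Illustrative sketch (not part of this release): a reward_func compatible with the
# signature documented on PPOTrainer below. `prompt_ids` is a list of 1-D prompt tensors,
# `completion_ids` is a [batch, new_tokens] tensor of generated ids, and `answer_ids` is a
# list of optional reference tensors; the callable must return one float score per sample.
#
#     def exact_match_reward(prompt_ids, completion_ids, answer_ids):
#         scores = []
#         for completion, answer in zip(completion_ids, answer_ids):
#             if answer is None:
#                 scores.append(0.0)
#             else:
#                 n = answer.numel()
#                 match = completion[:n].cpu().equal(answer.cpu())
#                 scores.append(1.0 if match else 0.0)
#         return scores
#
# The name `exact_match_reward` and the exact-match criterion are hypothetical examples.
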
class PPOTrainer(BaseTrainer):
    """
    reward_func(prompt_ids, completion_ids, answer_ids) -> scores
    """

    def __init__(
        self,
        *,
        train_config: TrainConfig,
        reward_func: Callable[[List[torch.Tensor], torch.Tensor, List[Optional[torch.Tensor]]], List[float]],
        eval_prompts: List[str]
    ):
        self.ppo_config = train_config.ppo_config

        if self.ppo_config.normalize_rewards:
            if self.ppo_config.normalize_method == 'RunningMeanStd':
                self.reward_normalizer = RunningMeanStd(shape=()).to(TrainerTools().parallel.device)
            else:
                self.reward_normalizer = None

            if self.ppo_config.whiten_rewards:
                self.ppo_config.whiten_rewards = False
                if TrainerTools().parallel.is_main_process:
                    Logger.std_log('WARN: ppo_config.normalize_rewards is enabled, so ppo_config.whiten_rewards has been disabled.')
        else:
            # keep the attribute defined even when reward normalization is disabled;
            # it is passed to save_checkpoint/load_checkpoint as extra_module.
            self.reward_normalizer = None

        super().__init__(
            train_config=train_config,
            eval_prompts=eval_prompts,
            gradient_accumulation_steps=self.ppo_config.gradient_accumulation_steps
        )

        self.reward_func = reward_func
        self.ref_model = self._init_ref_model()

    def _init_train_model_and_optim(self, initial_lr: float):
        policy_model = self._new_model(self.train_config)
        value_model = ValueModel(self._new_model(self.train_config))

        if self.train_config.ds_config and self.train_config.ds_config.activation_checkpointing:
            policy_model.gradient_checkpointing_enable()
            value_model.base_model.gradient_checkpointing_enable()

        train_model = PolicyAndValueModelWrapper(policy_model, value_model)

        if self.train_config.init_state_dict:
            policy_model.load_state_dict(self.train_config.init_state_dict, strict=False)
            value_model.base_model.load_state_dict(self.train_config.init_state_dict, strict=False)
            self.train_config.init_state_dict = None

        if self.train_config.ppo_config.value_model_checkpoint:
            value_model.load_state_dict(self.train_config.ppo_config.value_model_checkpoint)
            self.train_config.ppo_config.value_model_checkpoint = {}

        if TrainerTools().parallel.is_main_process:
            for name, model in zip(['policy', 'value'], [policy_model, value_model]):
                total_params = sum(p.numel() for p in model.parameters())
                Logger.std_log(f"Total number of {name} model parameters: {total_params:,}")

                trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
                Logger.std_log(f"Trainable number of {name} model parameters: {trainable_params:,}")

                # size estimate assumes 4 bytes (fp32) per parameter
                total_size_bytes = total_params * 4
                total_size_mb = total_size_bytes / (1024 * 1024)
                Logger.std_log(f"Total size of {name} model: {total_size_mb:.2f} MB")

        model, optim = TrainerTools().parallel.process(
            model=train_model,
            optimizer=self._config_optim(train_model, initial_lr),
            kwargs=self.parallel_kwargs
        )

        return model, optim

    def _config_optim(self, model, initial_lr):
        optimizer_cls, use_lion_optim = self._get_optim_cls()

        policy_config = self.train_config.optim_config
        value_config = self.ppo_config.value_optim_config if self.ppo_config.value_optim_config else policy_config

        no_decay_name_list = ["bias", "norm.weight"]

        def get_param_groups(module, config, name_prefix):
            current_betas = config.betas
            current_weight_decay = config.weight_decay

            if current_betas is None:
                current_betas = (0.95, 0.98) if use_lion_optim else (0.9, 0.999)

            if current_weight_decay is None:
                current_weight_decay = 0.015 if use_lion_optim else 0.01

            decay_params = []
            no_decay_params = []

            for name, param in module.named_parameters():
                if not param.requires_grad:
                    continue
                if any(nd in name for nd in no_decay_name_list):
                    no_decay_params.append(param)
                else:
                    decay_params.append(param)

            return [
                {
                    "params": decay_params,
                    "weight_decay": current_weight_decay,
                    "lr": config.initial_lr,
                    "betas": current_betas,
                    "name": f"{name_prefix}_decay"
                },
                {
                    "params": no_decay_params,
                    "weight_decay": 0.0,
                    "lr": config.initial_lr,
                    "betas": current_betas,
                    "name": f"{name_prefix}_no_decay"
                }
            ]

        optimizer_grouped_parameters = []
        optimizer_grouped_parameters.extend(get_param_groups(model.policy_model, policy_config, "policy"))
        optimizer_grouped_parameters.extend(get_param_groups(model.value_model, value_config, "value"))

        default_betas = policy_config.betas if policy_config.betas else ((0.95, 0.98) if use_lion_optim else (0.9, 0.999))
        default_weight_decay = policy_config.weight_decay if policy_config.weight_decay else (0.015 if use_lion_optim else 0.01)

        return optimizer_cls(
            optimizer_grouped_parameters,
            lr=policy_config.initial_lr,
            betas=default_betas,
            weight_decay=default_weight_decay
        )

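    # Note (descriptive comment added for clarity): _config_optim above registers four
    # parameter groups in a fixed order -- policy_decay, policy_no_decay, value_decay,
    # value_no_decay -- which is what the hard-coded group indices [0, 1] and [2, 3] in
    # _init_lr_scheduler below refer to.
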
    def _init_lr_scheduler(self, initial_lr: float, optimizer) -> LRScheduler:
        policy_config = self.train_config.optim_config
        value_config = self.ppo_config.value_optim_config

        if value_config is None:
            return super()._init_lr_scheduler(initial_lr, optimizer)

        schedulers = []

        def create_scheduler(config, group_indices, need_log):
            initial_lr = config.initial_lr
            if config.enable_lr_scheduler:
                warmup_iters = config.warmup_iters
                min_lr = config.min_lr
                max_lr = config.max_lr
                cosine_annealing_period = config.cosine_annealing_period
                cosine_annealing_period_mul = config.cosine_annealing_period_mul

                return WarmupCosineAnnealingLRScheduler(
                    optimizer=optimizer,
                    warmup_iters=warmup_iters,
                    initial_lr=initial_lr,
                    min_lr=min_lr,
                    max_lr=max_lr,
                    cosine_annealing_period=cosine_annealing_period,
                    cosine_annealing_period_mul=cosine_annealing_period_mul,
                    param_group_indices=group_indices,
                    need_log=TrainerTools().parallel.is_main_process and need_log
                )
            else:
                return NoneLRScheduler(initial_lr)

        schedulers.append(create_scheduler(policy_config, [0, 1], True))
        schedulers.append(create_scheduler(value_config, [2, 3], False))

        return CompositeLRScheduler(schedulers)

    def _init_ref_model(self):
        ref_model = self._new_model(self.train_config)

        if self.train_config.ppo_config.ref_model_checkpoint:
            ref_model.load_state_dict(self.train_config.ppo_config.ref_model_checkpoint)
            self.train_config.ppo_config.ref_model_checkpoint = {}

        ref_model.eval()
        for param in ref_model.parameters():
            param.requires_grad = False

        ref_model, _ = TrainerTools().parallel.process(
            model=ref_model,
            optimizer=None,
            kwargs=self._init_ref_model_args(),
            save_instance=False
        )

        return ref_model

    def _new_model(self, train_config: TrainConfig):
        model = super()._new_model(train_config)
        disable_dropout_in_model(model)
        return model

    def _init_loss(self):
        ppo_config = self.train_config.ppo_config
        criterion = PPOLoss(
            clip_eps=ppo_config.clip_eps,
            vf_coef=ppo_config.vf_coef
        )
        return criterion, None

    def _load_train_model_checkpoint(self):
        load_checkpoint(
            self.train_model,
            optimizer=self.optimizer,
            device=TrainerTools().parallel.device,
            extra_module=self.reward_normalizer
        )

    def _convert_train_args(self) -> Tuple[dict, dict, dict]:
        parallel_kwargs, data_loader_kwargs, sampler_kwargs = super()._convert_train_args()
        data_loader_kwargs.update({"collate_fn": lambda x: x})
        return parallel_kwargs, data_loader_kwargs, sampler_kwargs

    def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
        file_path = self.train_config.file_dataset[file_idx]
        return RLDataset(file_path), file_path

    def _calc_loss(self, inputs, attention_mask, logits, labels): ...

    def _check_eval_model(self, eval_model):
        return eval_model.policy_model

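    # Descriptive comment: the method below implements GAE (generalized advantage estimation)
    # over completion tokens:
    #   delta_t = r_t + gamma * V_{t+1} - V_t
    #   A_t     = delta_t + gamma * lam * mask_t * A_{t+1}
    # where V_{t+1} at the final step is bootstrapped from `last_values` unless the sample
    # terminated (`dones`), and returns are R_t = A_t + V_t; both outputs are re-masked.
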
    def _compute_advantages_and_returns(
        self,
        rewards: torch.Tensor,
        values: torch.Tensor,
        last_values: torch.Tensor,
        completion_mask: torch.Tensor,
        dones: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        gamma, lam = self.train_config.ppo_config.gamma, self.train_config.ppo_config.lam
        advantages_reversed = []
        last_gae_lam = 0
        seq_len = rewards.size(1)

        values = values * completion_mask
        for t in reversed(range(seq_len)):
            if t == seq_len - 1:
                next_values = torch.where(dones, 0.0, last_values)
            else:
                next_values = values[:, t + 1]

            delta = rewards[:, t] + gamma * next_values - values[:, t]
            last_gae_lam = delta + gamma * lam * last_gae_lam * completion_mask[:, t]
            advantages_reversed.append(last_gae_lam)

        advantages = torch.stack(advantages_reversed[::-1], dim=1)
        returns = advantages + values

        return advantages * completion_mask, returns * completion_mask

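    # Rollout overview (descriptive comment): the method below left-pads the prompts, samples
    # completions from the current policy, and records the behaviour log-probs and value
    # estimates. The frozen reference model scores the same completions, an optional per-token
    # KL penalty is added to the rewards (estimator "k1" uses -log r, otherwise
    # (exp(log r) - 1) - log r), and the scalar score from reward_func (optionally penalized
    # for a missing EOS and normalized) is added to the last non-pad completion token.
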
    def _generate_rollout_data(self, batch_data: List[dict]) -> dict:
        ppo_config = self.train_config.ppo_config
        device = TrainerTools().parallel.device
        pad_token_id = TrainerTools().tokenizer.pad
        eos_token_id = TrainerTools().tokenizer.end

        prompts = [item["prompt"] for item in batch_data]
        answers = [item["answer"] for item in batch_data]

        prompt_ids = left_pad_sequence(prompts, padding_value=pad_token_id)
        prompt_ids = prompt_ids.to(device)
        prompt_masks = (prompt_ids != pad_token_id)
        prompt_len = prompt_ids.shape[1]

        max_new_tokens = ppo_config.gen_max_seq_len - prompt_len
        if max_new_tokens <= 0:
            raise ValueError(
                f"Prompt length ({prompt_len}) >= gen_max_seq_len ({ppo_config.gen_max_seq_len}). "
                f"Cannot generate any tokens. Please increase gen_max_seq_len or reduce dataset_block_size."
            )

        with torch.no_grad():
            with unwrap_model_for_generation(self.train_model) as unwrapped_model:
                full_ids, logitss = batch_generate(
                    model=unwrapped_model.policy_model,
                    tokens=prompt_ids,
                    attention_mask=prompt_masks,
                    max_new_tokens=max_new_tokens,
                    temperature=ppo_config.gen_temperature,
                    k=ppo_config.gen_k,
                    p=ppo_config.gen_p,
                    suppress_tokens=ppo_config.gen_suppress_tokens,
                    device=device
                )

                completion_ids = full_ids[:, prompt_len:]
                full_attention_mask = (full_ids != pad_token_id)
                full_position_ids = calc_position_ids(full_attention_mask)

                old_log_probs = log_softmax(logitss.float(), completion_ids)
                del logitss

                with autocast(TrainerTools().parallel.device_type):
                    value_output = unwrapped_model.value_model(
                        full_ids,
                        attention_mask=full_attention_mask,
                        position_ids=full_position_ids
                    )

            with unwrap_model_for_generation(self.ref_model) as unwrapped_ref_model:
                ref_outputs = unwrapped_ref_model(
                    full_ids,
                    attention_mask=full_attention_mask,
                    position_ids=full_position_ids
                )
                ref_logits_full = ref_outputs['logits']

                ref_logits_completion = ref_logits_full[:, prompt_len - 1: -1]
                ref_log_probs_completion = log_softmax(ref_logits_completion.float(), completion_ids)
                del ref_outputs, ref_logits_full, ref_logits_completion

            dones = torch.any(completion_ids == eos_token_id, dim=1)
            rewards = torch.zeros_like(completion_ids, dtype=torch.float32, device=device)
            completion_mask = (completion_ids != pad_token_id)

            if ppo_config.kl_beta > 0.0:
                logr = ref_log_probs_completion - old_log_probs
                kl = -logr if ppo_config.kl_estimator == "k1" else (logr.exp() - 1) - logr
                kl_rewards = -ppo_config.kl_beta * kl
                rewards += kl_rewards * completion_mask

            env_rewards_tensor = torch.tensor(
                self.reward_func(prompts, completion_ids, answers),
                dtype=torch.float32,
                device=device
            )

            if ppo_config.missing_eos_penalty is not None:
                env_rewards_tensor[~dones] -= ppo_config.missing_eos_penalty

            raw_reward_mean = env_rewards_tensor.mean()

            if self.train_config.ppo_config.normalize_rewards:
                if self.reward_normalizer:
                    self.reward_normalizer.update(env_rewards_tensor)
                    env_rewards_tensor = self.reward_normalizer(env_rewards_tensor)
                else:
                    batch_std = env_rewards_tensor.std()
                    if torch.isnan(batch_std) or batch_std < 1e-8:
                        batch_std = 1.0

                    env_rewards_tensor = (env_rewards_tensor - raw_reward_mean) / batch_std

            last_token_indices = completion_mask.sum(dim=1) - 1
            valid_indices_mask = last_token_indices >= 0

            if valid_indices_mask.any():
                valid_batch_indices = torch.arange(prompt_ids.size(0), device=device)[valid_indices_mask]
                valid_last_token_indices = last_token_indices[valid_indices_mask]
                valid_env_rewards = env_rewards_tensor[valid_indices_mask]
                rewards[valid_batch_indices, valid_last_token_indices] += valid_env_rewards

        return {
            'prompt_ids': prompt_ids.detach(),
            'completion_ids': completion_ids.detach(),
            'old_log_probs': old_log_probs.detach(),
            'values': value_output.detach(),
            'rewards': rewards.detach(),
            'env_rewards': raw_reward_mean.detach(),
            'dones': dones.detach(),
        }

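    # PPO update overview (descriptive comment): the method below optionally whitens rewards,
    # computes advantages/returns via _compute_advantages_and_returns, then runs ppo_epochs
    # passes over shuffled mini-batches. Each micro-batch contributes the clipped-surrogate
    # policy loss and value loss from PPOLoss plus an optional MoE auxiliary loss, scaled by
    # the effective gradient-accumulation step count; statistics are averaged at the end.
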
    def _ppo_learning_phase(self, rollout_data: dict):
        ppo_config = self.train_config.ppo_config

        prompt_ids: torch.Tensor = rollout_data['prompt_ids']
        completion_ids: torch.Tensor = rollout_data['completion_ids']
        old_log_probs: torch.Tensor = rollout_data['old_log_probs']
        old_values: torch.Tensor = rollout_data['values']
        rewards: torch.Tensor = rollout_data['rewards']
        dones: torch.Tensor = rollout_data['dones']

        prompt_len = prompt_ids.shape[1]
        batch_size = prompt_ids.shape[0]

        values_for_gae = old_values[:, prompt_len - 1: -1]
        last_values = old_values[:, -1]
        assert values_for_gae.shape[1] == completion_ids.shape[1]

        completion_mask: torch.Tensor = (completion_ids != TrainerTools().tokenizer.pad)

        if ppo_config.whiten_rewards:
            rewards = masked_whiten(rewards, completion_mask, shift_mean=False)
            rewards = torch.masked_fill(rewards, ~completion_mask, 0.0)

        advantages, returns = self._compute_advantages_and_returns(
            rewards, values_for_gae, last_values, completion_mask, dones
        )

        advantages_whitened = masked_whiten(advantages, completion_mask, shift_mean=True)
        advantages_whitened = torch.masked_fill(advantages_whitened, ~completion_mask, 0.0)

        input_ids = torch.cat((prompt_ids, completion_ids), dim=1)
        attention_mask = (input_ids != TrainerTools().tokenizer.pad)

        ppo_stats = {
            "loss": 0.0, "moe_aux_loss": 0.0, "actor_loss": 0.0,
            "value_loss": 0.0, "approx_kl": 0.0, "clip_frac": 0.0
        }

        grad_acc_steps = max(1, self.gradient_accumulation_steps)
        ppo_batch_size = ppo_config.ppo_batch_size
        num_micro_batches = (batch_size + ppo_batch_size - 1) // ppo_batch_size
        total_micro_batches_processed = 0

        for ppo_epoch in range(ppo_config.ppo_epochs):
            indices = torch.randperm(batch_size, device=TrainerTools().parallel.device)

            for i in range(0, batch_size, ppo_batch_size):
                mini_batch_indices = indices[i:i + ppo_batch_size]
                micro_batch_idx = i // ppo_batch_size
                is_last_micro_batch = (micro_batch_idx == num_micro_batches - 1)
                need_update_grad = ((micro_batch_idx + 1) % grad_acc_steps == 0) or is_last_micro_batch

                if is_last_micro_batch:
                    remainder = (micro_batch_idx + 1) % grad_acc_steps
                    actual_acc_steps = remainder if remainder > 0 else grad_acc_steps
                else:
                    actual_acc_steps = grad_acc_steps

                if TrainerTools().parallel.parallel_train:
                    self.train_model.require_backward_grad_sync = need_update_grad

                mb_input_ids = input_ids[mini_batch_indices]
                mb_attention_mask = attention_mask[mini_batch_indices]
                mb_completion_ids = completion_ids[mini_batch_indices]
                mb_completion_mask = completion_mask[mini_batch_indices]
                mb_old_log_probs = old_log_probs[mini_batch_indices]
                mb_values = values_for_gae[mini_batch_indices]
                mb_returns = returns[mini_batch_indices]
                mb_advantages = advantages_whitened[mini_batch_indices]
                mb_position_ids = calc_position_ids(mb_attention_mask)

                with autocast(TrainerTools().parallel.device_type):
                    policy_output, value_output = self.train_model(
                        mb_input_ids,
                        attention_mask=mb_attention_mask,
                        position_ids=mb_position_ids
                    )

                    target_dtype = policy_output['logits'].dtype
                    mb_old_log_probs = mb_old_log_probs.to(target_dtype)
                    mb_values = mb_values.to(target_dtype)
                    mb_returns = mb_returns.to(target_dtype)
                    mb_advantages = mb_advantages.to(target_dtype)

                    logits_completion = policy_output['logits'][:, prompt_len - 1: -1]
                    current_log_probs = log_softmax(logits_completion, mb_completion_ids)
                    current_values = value_output[:, prompt_len - 1: -1]

                    loss, actor_loss, value_loss, approx_kl, clip_frac = self.criterion(
                        log_probs=current_log_probs,
                        old_log_probs=mb_old_log_probs,
                        values=current_values,
                        old_values=mb_values,
                        returns=mb_returns,
                        advantages=mb_advantages,
                        mask=mb_completion_mask
                    )

                    aux_loss = torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
                    if policy_output.get('aux_loss') and self.train_config.loss_config.aux_loss_coef:
                        aux_loss = self.train_config.loss_config.aux_loss_coef * policy_output['aux_loss']

                    total_loss = loss + aux_loss

                scaled_total_loss = total_loss / actual_acc_steps
                self._backward_loss(scaled_total_loss)

                ppo_stats["loss"] += total_loss.detach().item()
                ppo_stats["moe_aux_loss"] += aux_loss.detach().item()
                ppo_stats["actor_loss"] += actor_loss.detach().item()
                ppo_stats["value_loss"] += value_loss.detach().item()
                ppo_stats["approx_kl"] += approx_kl.detach().item()
                ppo_stats["clip_frac"] += clip_frac.detach().item()
                total_micro_batches_processed += 1

                if need_update_grad:
                    self._apply_grad_clipping()
                    self._apply_step()

        if total_micro_batches_processed > 0:
            for key in ppo_stats:
                ppo_stats[key] /= total_micro_batches_processed

        return ppo_stats

    def train(self):
        for epoch in range(self.resume_epoch, self.train_config.n_epochs):
            file_count = len(self.train_config.file_dataset)
            start_file_idx = self.resume_file_idx if epoch == self.resume_epoch else 0

            for file_idx in range(start_file_idx, file_count):
                dataset, file_path = self._create_dataset(file_idx)
                train_data_loader = TrainerTools().parallel.process_dataloader(
                    dataset=dataset,
                    data_loader_kwargs=self.data_loader_kwargs,
                    sampler_kwargs=self.sampler_kwargs
                )

                last_ckpt_batch = 0
                batch_count_per_file = len(train_data_loader)

                TrainerTools().parallel.on_epoch_start(epoch)
                self._on_file_start(epoch, file_path)

                skip_batches = 0
                if epoch == self.resume_epoch and file_idx == self.resume_file_idx:
                    skip_batches = self.resume_batch_idx
                    if skip_batches > 0 and TrainerTools().parallel.is_main_process:
                        Logger.std_log(f"Fast forwarding {skip_batches} batches in {file_path}...")

                data_iterator = iter(train_data_loader)
                if skip_batches > 0:
                    data_iterator = islice(data_iterator, skip_batches, None)
                    last_ckpt_batch = skip_batches

                for batch, batch_data in enumerate(data_iterator):
                    batch = skip_batches + batch

                    rollout_data = self._generate_rollout_data(batch_data)
                    torch.cuda.empty_cache()

                    try:
                        ppo_stats = self._ppo_learning_phase(rollout_data)

                        stats_tensor = torch.tensor([
                            ppo_stats['loss'],
                            ppo_stats['moe_aux_loss'],
                            ppo_stats['actor_loss'],
                            ppo_stats['value_loss'],
                            ppo_stats['approx_kl'],
                            ppo_stats['clip_frac'],
                            rollout_data['env_rewards'].item()
                        ], device=TrainerTools().parallel.device)

                        if TrainerTools().parallel.parallel_train:
                            dist.all_reduce(stats_tensor, op=dist.ReduceOp.AVG)

                        ppo_stats['loss'] = stats_tensor[0].item()
                        ppo_stats['moe_aux_loss'] = stats_tensor[1].item()
                        ppo_stats['actor_loss'] = stats_tensor[2].item()
                        ppo_stats['value_loss'] = stats_tensor[3].item()
                        ppo_stats['approx_kl'] = stats_tensor[4].item()
                        ppo_stats['clip_frac'] = stats_tensor[5].item()
                        reward_value = stats_tensor[6].item()

                        self._log(
                            keys={
                                'epoch': epoch,
                                'file': f'{file_idx + 1}/{file_count}',
                                'batch': f'{batch + 1}/{batch_count_per_file}'
                            },
                            values={
                                'loss': ppo_stats['loss'],
                                'moe_aux_loss': ppo_stats['moe_aux_loss'],
                                'actor_loss': ppo_stats['actor_loss'],
                                'value_loss': ppo_stats['value_loss'],
                                'approx_kl': ppo_stats['approx_kl'],
                                'clip_frac': ppo_stats['clip_frac'],
                                'rewards': reward_value
                            }
                        )

                        if (batch - last_ckpt_batch) >= self.train_config.eval_config.eval_batch_interval:
                            save_checkpoint(
                                model=self.train_model,
                                optimizer=self.optimizer,
                                extra_module=self.reward_normalizer
                            )
                            save_steps(
                                epoch=epoch,
                                file_idx=file_idx,
                                batch_idx=batch + 1,
                                lr_scheduler=self.lr_scheduler
                            )

                            last_ckpt_batch = batch
                            self._on_batch_end(tag=f'epoch:{epoch}/batch:{batch}')

                        del rollout_data
                    except Exception as e:
                        self._on_exception(e, epoch, batch)

                # free memory after finishing training on one file
                del train_data_loader
                del dataset
                if hasattr(TrainerTools().parallel, '_sampler'):
                    TrainerTools().parallel._sampler = None

                gc.collect()
                torch.cuda.empty_cache()

            # end epoch

            # reset resume state
            self.resume_file_idx = 0
            self.resume_batch_idx = 0

            save_checkpoint(
                model=self.train_model,
                optimizer=self.optimizer,
                extra_module=self.reward_normalizer
            )
            save_steps(
                epoch=epoch + 1,
                file_idx=0,
                batch_idx=0,
                lr_scheduler=self.lr_scheduler
            )

            TrainerTools().parallel.on_epoch_end(epoch)
            self._on_epoch_end(tag=f'epoch:{epoch}')

        TrainerTools().parallel.destroy()