project-llm-trainer 0.12.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,521 @@
1
+ from typing import Tuple, List, Union, Callable, Optional
2
+ import gc
3
+ import torch
4
+ import torch.distributed as dist
5
+ from torch.utils.data import Dataset
6
+ import torch.nn as nn
7
+
8
+ from llm_model import LlmModel, VlmModel
9
+
10
+ from .base_trainer import BaseTrainer
11
+ from .train_configs import TrainConfig
12
+ from .dataset import RLDataset
13
+ from .loss import PPOLoss
14
+ from .tools import TrainerTools
15
+ from .generate_utils import batch_generate
16
+ from .utils import (
17
+ autocast,
18
+ left_pad_sequence,
19
+ log_softmax,
20
+ masked_whiten,
21
+ disable_dropout_in_model,
22
+ calc_position_ids
23
+ )
24
+ from .partition_utils import unwrap_model_for_generation
25
+ from .log import Logger
26
+ from .checkpoint import (
27
+ save_checkpoint,
28
+ save_steps,
29
+ )
30
+
31
+
32
+ class ValueModel(nn.Module):
33
+ def __init__(self, base_model: Union[LlmModel, VlmModel]):
34
+ super().__init__()
35
+ self.base_model = base_model
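+ # Scalar value head on top of the base LM: one value estimate per token position.
+ # The near-zero initialization keeps early value predictions small.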
36
+ self.value_head = nn.Linear(base_model.config.hidden_size, 1, bias=True)
37
+ self.value_head.weight.data.normal_(mean=0.0, std=0.01)
38
+ self.value_head.bias.data.zero_()
39
+
40
+ def forward(self, *args, **kwargs) -> torch.Tensor:
41
+ outputs = self.base_model(*args, **kwargs)
42
+ # [batch_size, seq_len, hidden_size]
43
+ last_hidden_state = outputs['hidden_states']
44
+ # [batch_size, seq_len, 1]
45
+ values = self.value_head(last_hidden_state)
46
+ # [batch_size, seq_len]
47
+ return values.squeeze(-1)
48
+
49
+
50
+ class PolicyAndValueModelWrapper(nn.Module):
51
+ def __init__(self, policy_model: nn.Module, value_model: nn.Module):
52
+ super().__init__()
53
+ self.policy_model = policy_model
54
+ self.value_model = value_model
55
+
56
+ def forward(self, *args, **kwargs):
57
+ return self.policy_model(*args, **kwargs), self.value_model(*args, **kwargs)
58
+
59
+
60
+ class PPOTrainer(BaseTrainer):
61
+ """
62
+ reward_func(prompt_ids: List[torch.Tensor], completion_ids: torch.Tensor, answer_ids: List[Optional[torch.Tensor]]) -> List[float]
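+
+     Example (a minimal sketch assuming an exact-match reward; `decode` is a
+     stand-in for whatever detokenization the caller uses):
+
+         def exact_match_reward(prompt_ids, completion_ids, answer_ids):
+             scores = []
+             for completion, answer in zip(completion_ids, answer_ids):
+                 if answer is None:
+                     scores.append(0.0)
+                 else:
+                     scores.append(1.0 if decode(completion) == decode(answer) else 0.0)
+             return scores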
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ *,
68
+ train_config: TrainConfig,
69
+ reward_func: Callable[[List[torch.Tensor], torch.Tensor, List[Optional[torch.Tensor]]], List[float]],
70
+ eval_prompts: List[str]
71
+ ):
72
+ self.ppo_config = train_config.ppo_config
73
+
74
+ super().__init__(
75
+ train_config=train_config,
76
+ eval_prompts=eval_prompts,
77
+ gradient_accumulation_steps=self.ppo_config.gradient_accumulation_steps
78
+ )
79
+ self.reward_func = reward_func
80
+
81
+ self.ref_model = self._init_ref_model()
82
+
83
+ if self.train_config.ppo_config.normalize_rewards and self.train_config.ppo_config.whiten_rewards:
84
+ self.train_config.ppo_config.whiten_rewards = False
85
+ if TrainerTools().parallel.is_main_process:
86
+ Logger.std_log('WARN: ppo_config.normalize_rewards is enabled; disabling ppo_config.whiten_rewards.')
87
+
88
+ def _init_train_model_and_optim(self, initial_lr: float):
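+ # The policy and value networks are built from the same config and (below) start from the
+ # same initial weights; the wrapper runs both in one forward so a single optimizer updates them.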
89
+ policy_model = self._new_model(self.train_config)
90
+ value_model = ValueModel(self._new_model(self.train_config))
91
+ train_model = PolicyAndValueModelWrapper(policy_model, value_model)
92
+
93
+ if self.train_config.init_state_dict:
94
+ policy_model.load_state_dict(self.train_config.init_state_dict)
95
+ value_model.base_model.load_state_dict(self.train_config.init_state_dict)
96
+ self.train_config.init_state_dict = None
97
+
98
+ if self.train_config.ppo_config.value_model_checkpoint:
99
+ value_model.load_state_dict(self.train_config.ppo_config.value_model_checkpoint)
100
+ self.train_config.ppo_config.value_model_checkpoint = {}
101
+
102
+ if TrainerTools().parallel.is_main_process:
103
+ for name, model in zip(['policy', 'value'], [policy_model, value_model]):
104
+ total_params = sum(p.numel() for p in model.parameters())
105
+ Logger.std_log(f"Total number of {name} model parameters: {total_params:,}")
106
+
107
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
108
+ Logger.std_log(f"Trainable number of {name} model parameters: {trainable_params:,}")
109
+
110
+ total_size_bytes = total_params * 4
111
+ total_size_mb = total_size_bytes / (1024 * 1024)
112
+ Logger.std_log(f"Total size of {name} model model: {total_size_mb:.2f} MB")
113
+
114
+ model, optim = TrainerTools().parallel.process(
115
+ model=train_model,
116
+ optimizer=self._config_optim(train_model, initial_lr),
117
+ kwargs=self.parallel_kwargs
118
+ )
119
+
120
+ return model, optim
121
+
122
+ def _init_ref_model(self):
123
+ ref_model = self._new_model(self.train_config)
124
+
125
+ if self.train_config.ppo_config.ref_model_checkpoint:
126
+ ref_model.load_state_dict(self.train_config.ppo_config.ref_model_checkpoint)
127
+ self.train_config.ppo_config.ref_model_checkpoint = {}
128
+
129
+ ref_model.eval()
130
+ for param in ref_model.parameters():
131
+ param.requires_grad = False
132
+
133
+ ref_model, _ = TrainerTools().parallel.process(
134
+ model=ref_model,
135
+ optimizer=None,
136
+ kwargs=self._init_ref_model_args(),
137
+ save_instance=False
138
+ )
139
+
140
+ return ref_model
141
+
142
+ def _new_model(self, train_config: TrainConfig):
143
+ model = super()._new_model(train_config)
144
+ disable_dropout_in_model(model)
145
+ return model
146
+
147
+ def _init_loss(self):
148
+ ppo_config = self.train_config.ppo_config
149
+ criterion = PPOLoss(
150
+ clip_eps=ppo_config.clip_eps,
151
+ vf_coef=ppo_config.vf_coef
152
+ )
153
+ return criterion, None
154
+
155
+ def _convert_train_args(self) -> Tuple[dict, dict, dict]:
156
+ parallel_kwargs, data_loader_kwargs, sampler_kwargs = super()._convert_train_args()
157
+ data_loader_kwargs.update({"collate_fn": lambda x: x})
158
+ return parallel_kwargs, data_loader_kwargs, sampler_kwargs
159
+
160
+ def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
161
+ file_path = self.train_config.file_dataset[file_idx]
162
+ return RLDataset(file_path), file_path
163
+
164
+ def _calc_loss(self, inputs, attention_mask, logits, labels): ...
165
+
166
+ def _check_eval_model(self, eval_model):
167
+ return eval_model.policy_model
168
+
169
+ def _compute_advantages_and_returns(
170
+ self,
171
+ rewards: torch.Tensor,
172
+ values: torch.Tensor,
173
+ last_values: torch.Tensor,
174
+ completion_mask: torch.Tensor,
175
+ dones: torch.Tensor,
176
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
177
+ gamma, lam = self.train_config.ppo_config.gamma, self.train_config.ppo_config.lam
178
+ advantages_reversed = []
179
+ last_gae_lam = 0
180
+ seq_len = rewards.size(1)
181
+
182
+ values = values * completion_mask
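+ # Generalized Advantage Estimation over the completion tokens:
+ # delta_t = r_t + gamma * V_{t+1} - V_t and A_t = delta_t + gamma * lam * A_{t+1},
+ # with the completion mask stopping accumulation across padded positions.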
183
+ for t in reversed(range(seq_len)):
184
+ if t == seq_len - 1:
185
+ next_values = torch.where(dones, 0.0, last_values)
186
+ else:
187
+ next_values = values[:, t + 1]
188
+
189
+ delta = rewards[:, t] + gamma * next_values - values[:, t]
190
+ last_gae_lam = delta + gamma * lam * last_gae_lam * completion_mask[:, t]
191
+ advantages_reversed.append(last_gae_lam)
192
+
193
+ advantages = torch.stack(advantages_reversed[::-1], dim=1)
194
+ returns = advantages + values
195
+
196
+ return advantages * completion_mask, returns * completion_mask
197
+
198
+ def _generate_rollout_data(self, batch_data: List[dict]) -> dict:
199
+ ppo_config = self.train_config.ppo_config
200
+ device = TrainerTools().parallel.device
201
+ pad_token_id = TrainerTools().tokenizer.pad
202
+ eos_token_id = TrainerTools().tokenizer.end
203
+
204
+ prompts = [item["prompt"] for item in batch_data]
205
+ answers = [item["answer"] for item in batch_data]
206
+
207
+ prompt_ids = left_pad_sequence(prompts, padding_value=pad_token_id)
208
+ prompt_ids = prompt_ids.to(device)
209
+ prompt_masks = (prompt_ids != pad_token_id)
210
+ prompt_len = prompt_ids.shape[1]
211
+
212
+ with torch.no_grad():
213
+ with unwrap_model_for_generation(self.train_model) as unwrapped_model:
214
+ full_ids, logitss = batch_generate(
215
+ model=unwrapped_model.policy_model,
216
+ tokens=prompt_ids,
217
+ attention_mask=prompt_masks,
218
+ max_new_tokens=ppo_config.gen_max_new_tokens,
219
+ temperature=ppo_config.gen_temperature,
220
+ k=ppo_config.gen_k,
221
+ p=ppo_config.gen_p,
222
+ suppress_tokens=ppo_config.gen_suppress_tokens,
223
+ device=device
224
+ )
225
+ completion_ids = full_ids[:, prompt_len:]
226
+ full_attention_mask = (full_ids != pad_token_id)
227
+ full_position_ids = calc_position_ids(full_attention_mask)
228
+
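+ # Value estimates for every position of prompt + completion, computed once under no_grad;
+ # they are reused as the 'old' values throughout the PPO epochs.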
229
+ with autocast(TrainerTools().parallel.device_type):
230
+ value_output = unwrapped_model.value_model(
231
+ full_ids,
232
+ attention_mask=full_attention_mask,
233
+ position_ids=full_position_ids
234
+ )
235
+
236
+ old_log_probs = log_softmax(logitss.float(), completion_ids)
237
+
238
+ with unwrap_model_for_generation(self.ref_model) as unwrapped_ref_model:
239
+ ref_outputs = unwrapped_ref_model(
240
+ full_ids,
241
+ attention_mask=full_attention_mask,
242
+ position_ids=full_position_ids
243
+ )
244
+ ref_logits_full = ref_outputs['logits']
245
+
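+ # Logits at positions [prompt_len - 1 : -1] predict exactly the completion tokens,
+ # so they align one-to-one with completion_ids.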
246
+ ref_logits_completion = ref_logits_full[:, prompt_len - 1: -1]
247
+ ref_log_probs_completion = log_softmax(ref_logits_completion.float(), completion_ids)
248
+
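+ # A sequence is 'done' once it emitted an EOS token; unfinished sequences later bootstrap
+ # from the value of the final position instead of a zero terminal value.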
249
+ dones = torch.any(completion_ids == eos_token_id, dim=1)
250
+ rewards = torch.zeros_like(completion_ids, dtype=torch.float32, device=device)
251
+ completion_mask = (completion_ids != pad_token_id)
252
+
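+ # Optional per-token KL penalty against the frozen reference model: 'k1' uses -log r,
+ # otherwise the lower-variance estimator (r - 1) - log r is used, with r = pi_ref / pi_old.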
253
+ if ppo_config.kl_beta > 0.0:
254
+ logr = ref_log_probs_completion - old_log_probs
255
+ kl = -logr if ppo_config.kl_estimator == "k1" else (logr.exp() - 1) - logr
256
+ kl_rewards = -ppo_config.kl_beta * kl
257
+ rewards += kl_rewards * completion_mask
258
+
259
+ env_rewards_tensor = torch.tensor(
260
+ self.reward_func(prompts, completion_ids, answers),
261
+ dtype=torch.float32,
262
+ device=device
263
+ )
264
+
265
+ if ppo_config.missing_eos_penalty is not None:
266
+ env_rewards_tensor[~dones] -= ppo_config.missing_eos_penalty
267
+
268
+ raw_reward_mean = env_rewards_tensor.mean()
269
+ if self.train_config.ppo_config.normalize_rewards:
270
+ batch_std = env_rewards_tensor.std()
271
+ if torch.isnan(batch_std) or batch_std < 1e-8:
272
+ batch_std = 1.0
273
+
274
+ env_rewards_tensor = (env_rewards_tensor - raw_reward_mean) / batch_std
275
+
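+ # The scalar environment reward for each sequence is added at its last non-padded
+ # completion token; sequences with an empty completion are skipped.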
276
+ last_token_indices = completion_mask.sum(dim=1) - 1
277
+ valid_indices_mask = last_token_indices >= 0
278
+
279
+ if valid_indices_mask.any():
280
+ valid_batch_indices = torch.arange(prompt_ids.size(0), device=device)[valid_indices_mask]
281
+ valid_last_token_indices = last_token_indices[valid_indices_mask]
282
+ valid_env_rewards = env_rewards_tensor[valid_indices_mask]
283
+ rewards[valid_batch_indices, valid_last_token_indices] += valid_env_rewards
284
+
285
+ return {
286
+ 'prompt_ids': prompt_ids.detach(),
287
+ 'completion_ids': completion_ids.detach(),
288
+ 'old_log_probs': old_log_probs.detach(),
289
+ 'values': value_output.detach(),
290
+ 'rewards': rewards.detach(),
291
+ 'env_rewards': raw_reward_mean.detach(),
292
+ 'dones': dones.detach(),
293
+ }
294
+
295
+ def _ppo_learning_phase(self, rollout_data: dict):
296
+ ppo_config = self.train_config.ppo_config
297
+
298
+ prompt_ids: torch.Tensor = rollout_data['prompt_ids']
299
+ completion_ids: torch.Tensor = rollout_data['completion_ids']
300
+ old_log_probs: torch.Tensor = rollout_data['old_log_probs']
301
+ old_values: torch.Tensor = rollout_data['values']
302
+ rewards: torch.Tensor = rollout_data['rewards']
303
+ dones: torch.Tensor = rollout_data['dones']
304
+
305
+ prompt_len = prompt_ids.shape[1]
306
+ batch_size = prompt_ids.shape[0]
307
+
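+ # The model output at position i scores token i + 1, so slicing [prompt_len - 1 : -1] aligns
+ # the cached values with the completion tokens; the value at the final position is kept as the
+ # bootstrap estimate for sequences that never emitted EOS.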
308
+ values_for_gae = old_values[:, prompt_len - 1: -1]
309
+ last_values = old_values[:, -1]
310
+ assert values_for_gae.shape[1] == completion_ids.shape[1]
311
+
312
+ completion_mask: torch.Tensor = (completion_ids != TrainerTools().tokenizer.pad)
313
+
314
+ if ppo_config.whiten_rewards:
315
+ rewards = masked_whiten(rewards, completion_mask, shift_mean=False)
316
+ rewards = torch.masked_fill(rewards, ~completion_mask, 0.0)
317
+
318
+ advantages, returns = self._compute_advantages_and_returns(
319
+ rewards, values_for_gae, last_values, completion_mask, dones
320
+ )
321
+
322
+ advantages_whitened = masked_whiten(advantages, completion_mask, shift_mean=True)
323
+ advantages_whitened = torch.masked_fill(advantages_whitened, ~completion_mask, 0.0)
324
+
325
+ input_ids = torch.cat((prompt_ids, completion_ids), dim=1)
326
+ attention_mask = (input_ids != TrainerTools().tokenizer.pad)
327
+
328
+ ppo_stats = {
329
+ "loss": 0.0, "moe_aux_loss": 0.0, "actor_loss": 0.0,
330
+ "value_loss": 0.0, "approx_kl": 0.0, "clip_frac": 0.0
331
+ }
332
+
333
+ grad_acc_steps = max(1, self.gradient_accumulation_steps)
334
+ ppo_batch_size = ppo_config.ppo_batch_size
335
+ num_micro_batches = (batch_size + ppo_batch_size - 1) // ppo_batch_size
336
+ total_micro_batches_processed = 0
337
+
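+ # Several PPO epochs reuse the same rollout: indices are reshuffled each epoch, split into
+ # micro-batches of ppo_batch_size, and gradients are accumulated for grad_acc_steps
+ # micro-batches (or whatever remains in the last group) before an optimizer step.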
338
+ for ppo_epoch in range(ppo_config.ppo_epochs):
339
+ indices = torch.randperm(batch_size, device=TrainerTools().parallel.device)
340
+
341
+ for i in range(0, batch_size, ppo_batch_size):
342
+ mini_batch_indices = indices[i:i + ppo_batch_size]
343
+ micro_batch_idx = i // ppo_batch_size
344
+ is_last_micro_batch = (micro_batch_idx == num_micro_batches - 1)
345
+ need_update_grad = ((micro_batch_idx + 1) % grad_acc_steps == 0) or is_last_micro_batch
346
+
347
+ if is_last_micro_batch:
348
+ remainder = (micro_batch_idx + 1) % grad_acc_steps
349
+ actual_acc_steps = remainder if remainder > 0 else grad_acc_steps
350
+ else:
351
+ actual_acc_steps = grad_acc_steps
352
+
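+ # Under distributed training, only synchronize gradients on the micro-batch that triggers an optimizer step.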
353
+ if TrainerTools().parallel.parallel_train:
354
+ self.train_model.require_backward_grad_sync = need_update_grad
355
+
356
+ mb_input_ids = input_ids[mini_batch_indices]
357
+ mb_attention_mask = attention_mask[mini_batch_indices]
358
+ mb_completion_ids = completion_ids[mini_batch_indices]
359
+ mb_completion_mask = completion_mask[mini_batch_indices]
360
+ mb_old_log_probs = old_log_probs[mini_batch_indices]
361
+ mb_values = values_for_gae[mini_batch_indices]
362
+ mb_returns = returns[mini_batch_indices]
363
+ mb_advantages = advantages_whitened[mini_batch_indices]
364
+ mb_position_ids = calc_position_ids(mb_attention_mask)
365
+
366
+ with autocast(TrainerTools().parallel.device_type):
367
+ policy_output, value_output = self.train_model(
368
+ mb_input_ids,
369
+ attention_mask=mb_attention_mask,
370
+ position_ids=mb_position_ids
371
+ )
372
+
373
+ target_dtype = policy_output['logits'].dtype
374
+ mb_old_log_probs = mb_old_log_probs.to(target_dtype)
375
+ mb_values = mb_values.to(target_dtype)
376
+ mb_returns = mb_returns.to(target_dtype)
377
+ mb_advantages = mb_advantages.to(target_dtype)
378
+
379
+ logits_completion = policy_output['logits'][:, prompt_len - 1: -1]
380
+ current_log_probs = log_softmax(logits_completion, mb_completion_ids)
381
+ current_values = value_output[:, prompt_len - 1: -1]
382
+
383
+ loss, actor_loss, value_loss, approx_kl, clip_frac = self.criterion(
384
+ log_probs=current_log_probs,
385
+ old_log_probs=mb_old_log_probs,
386
+ values=current_values,
387
+ old_values=mb_values,
388
+ returns=mb_returns,
389
+ advantages=mb_advantages,
390
+ mask=mb_completion_mask
391
+ )
392
+
393
+ aux_loss = torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
394
+ if policy_output.get('aux_loss') is not None and self.train_config.loss_config.aux_loss_coef:
395
+ aux_loss = self.train_config.loss_config.aux_loss_coef * policy_output['aux_loss']
396
+
397
+ total_loss = loss + aux_loss
398
+ scaled_total_loss = total_loss / actual_acc_steps
399
+ self._backward_loss(scaled_total_loss)
400
+
401
+ ppo_stats["loss"] += total_loss.detach().item()
402
+ ppo_stats["moe_aux_loss"] += aux_loss.detach().item()
403
+ ppo_stats["actor_loss"] += actor_loss.detach().item()
404
+ ppo_stats["value_loss"] += value_loss.detach().item()
405
+ ppo_stats["approx_kl"] += approx_kl.detach().item()
406
+ ppo_stats["clip_frac"] += clip_frac.detach().item()
407
+ total_micro_batches_processed += 1
408
+
409
+ if need_update_grad:
410
+ self._apply_grad_clipping()
411
+ self._apply_step()
412
+
413
+ if total_micro_batches_processed > 0:
414
+ for key in ppo_stats:
415
+ ppo_stats[key] /= total_micro_batches_processed
416
+
417
+ return ppo_stats
418
+
419
+ def train(self):
420
+ global_steps = 0
421
+ skipping_train = False
422
+
423
+ for epoch in range(self.train_config.n_epochs):
424
+ file_count = len(self.train_config.file_dataset)
425
+ for file_idx in range(file_count):
426
+ dataset, file_path = self._create_dataset(file_idx)
427
+ train_data_loader = TrainerTools().parallel.process_dataloader(
428
+ dataset=dataset,
429
+ data_loader_kwargs=self.data_loader_kwargs,
430
+ sampler_kwargs=self.sampler_kwargs
431
+ )
432
+
433
+ last_ckpt_batch = 0
434
+ batch_count_per_file = len(train_data_loader)
435
+
436
+ TrainerTools().parallel.on_epoch_start(epoch)
437
+ self._on_file_start(epoch, file_path)
438
+
439
+ for batch, batch_data in enumerate(train_data_loader):
440
+ global_steps += 1
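+ # Resume support: fast-forward over batches already covered by the saved global step.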
441
+ if global_steps < self.last_global_steps:
442
+ skipping_train = True
443
+ continue
444
+
445
+ if skipping_train:
446
+ TrainerTools().parallel.wait('skip train')
447
+ skipping_train = False
448
+
449
+ rollout_data = self._generate_rollout_data(batch_data)
450
+ torch.cuda.empty_cache()
451
+
452
+ try:
453
+ ppo_stats = self._ppo_learning_phase(rollout_data)
454
+
455
+ stats_tensor = torch.tensor([
456
+ ppo_stats['loss'],
457
+ ppo_stats['moe_aux_loss'],
458
+ ppo_stats['actor_loss'],
459
+ ppo_stats['value_loss'],
460
+ ppo_stats['approx_kl'],
461
+ ppo_stats['clip_frac'],
462
+ rollout_data['env_rewards'].item()
463
+ ], device=TrainerTools().parallel.device)
464
+
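+ # Average the PPO statistics across data-parallel ranks so the log reflects the whole batch.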
465
+ if TrainerTools().parallel.parallel_train:
466
+ dist.all_reduce(stats_tensor, op=dist.ReduceOp.AVG)
467
+
468
+ ppo_stats['loss'] = stats_tensor[0].item()
469
+ ppo_stats['moe_aux_loss'] = stats_tensor[1].item()
470
+ ppo_stats['actor_loss'] = stats_tensor[2].item()
471
+ ppo_stats['value_loss'] = stats_tensor[3].item()
472
+ ppo_stats['approx_kl'] = stats_tensor[4].item()
473
+ ppo_stats['clip_frac'] = stats_tensor[5].item()
474
+ reward_value = stats_tensor[6].item()
475
+
476
+ self._log(
477
+ keys={
478
+ 'epoch': epoch,
479
+ 'file': f'{file_idx + 1}/{file_count}',
480
+ 'batch': f'{batch}/{batch_count_per_file}'
481
+ },
482
+ values={
483
+ 'loss': ppo_stats['loss'],
484
+ 'moe_aux_loss': ppo_stats['moe_aux_loss'],
485
+ 'actor_loss': ppo_stats['actor_loss'],
486
+ 'value_loss': ppo_stats['value_loss'],
487
+ 'approx_kl': ppo_stats['approx_kl'],
488
+ 'clip_frac': ppo_stats['clip_frac'],
489
+ 'rewards': reward_value
490
+ }
491
+ )
492
+ except Exception as e:
493
+ self._on_exception(e, epoch, batch)
494
+ finally:
495
+ save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)
496
+
497
+ if (batch - last_ckpt_batch) >= self.train_config.eval_config.eval_batch_interval:
498
+ save_checkpoint(model=self.train_model, optimizer=self.optimizer)
499
+ last_ckpt_batch = batch
500
+ self._on_batch_end(tag=f'epoch:{epoch}/batch:{batch}')
501
+
502
+ torch.cuda.empty_cache()
503
+
504
+ # After finishing training on one file, release memory
505
+ del train_data_loader
506
+ del dataset
507
+ if hasattr(TrainerTools().parallel, '_sampler'):
508
+ TrainerTools().parallel._sampler = None
509
+
510
+ gc.collect()
511
+ torch.cuda.empty_cache()
512
+
513
+ # end epoch
514
+ if not skipping_train:
515
+ save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)
516
+ save_checkpoint(model=self.train_model, optimizer=self.optimizer)
517
+
518
+ TrainerTools().parallel.on_epoch_end(epoch)
519
+ self._on_epoch_end(tag=f'epoch:{epoch}')
520
+
521
+ TrainerTools().parallel.destroy()
@@ -0,0 +1,179 @@
1
+ from abc import ABC, abstractmethod
2
+ import math
3
+ import torch
4
+ from .log import Logger
5
+
6
+ class LRScheduler(ABC):
7
+ @property
8
+ @abstractmethod
9
+ def cur_steps(self): ...
10
+
11
+ @property
12
+ @abstractmethod
13
+ def cur_lr(self): ...
14
+
15
+ @abstractmethod
16
+ def step(self): ...
17
+
18
+ @abstractmethod
19
+ def can_clip_grad(self): ...
20
+
21
+ @abstractmethod
22
+ def get_ckpt_dict(self) -> dict: ...
23
+
24
+ @abstractmethod
25
+ def restore_ckpt_dict(self, ckpt: dict): ...
26
+
27
+
28
+ class WarmupCosineAnnealingLRScheduler(LRScheduler):
29
+ def __init__(
30
+ self,
31
+ *,
32
+ optimizer: torch.optim.Optimizer,
33
+ warmup_iters: int,
34
+ initial_lr: float,
35
+ min_lr: float,
36
+ max_lr: float,
37
+ cosine_annealing_period: int, # number of steps in each annealing period
38
+ cosine_annealing_period_mul: int = 0, # period-length multiplier applied after each cycle; 0 means a single period with no restarts
39
+ need_log: bool = False
40
+ ):
41
+ super().__init__()
42
+
43
+ self._optimizer = optimizer
44
+ self._initial_lr = initial_lr
45
+ self._min_lr = min_lr
46
+ self._max_lr = max_lr
47
+ self._warmup_iters = warmup_iters
48
+
49
+ self._cosine_annealing_period = cosine_annealing_period
50
+ self._cosine_annealing_period_mul = cosine_annealing_period_mul
51
+
52
+ self.T_cur = 0 # steps taken within the current cycle
53
+ self.cycle = 0 # index of the current cycle
54
+
55
+ if warmup_iters != 0:
56
+ self._lr_increment = (max_lr - initial_lr) / warmup_iters
57
+ else:
58
+ self._lr_increment = 0
59
+
60
+ self._steps = -1
61
+ self._current_lr = initial_lr
62
+ self._cosine_annealing_base_lr = None
63
+
64
+ if need_log:
65
+ self.logger = Logger('lr.txt')
66
+ else:
67
+ self.logger = None
68
+
69
+ @property
70
+ def cur_steps(self):
71
+ return self._steps
72
+
73
+ @property
74
+ def cur_lr(self):
75
+ return self._current_lr
76
+
77
+ def step(self):
78
+ self._steps += 1
79
+ self._update_lr()
80
+
81
+ def can_clip_grad(self):
82
+ return self._steps > self._warmup_iters
83
+
84
+ def _update_lr(self):
85
+ # If period_mul is 0 there are no restart cycles: once past warmup plus the cosine-annealing steps, hold the minimum lr
86
+ if self._cosine_annealing_period_mul == 0 and self._steps >= self._cosine_annealing_period + self._warmup_iters:
87
+ lr = self._min_lr
88
+ for param_group in self._optimizer.param_groups:
89
+ param_group['lr'] = lr
90
+ elif self._steps <= self._warmup_iters:
91
+ # Warmup: adjust learning rate linearly
92
+ # (max_lr - initial_lr) / warmup_iters
93
+ lr = self._initial_lr + self._steps * self._lr_increment
94
+ for param_group in self._optimizer.param_groups:
95
+ param_group['lr'] = lr
96
+ else:
97
+ if self._cosine_annealing_base_lr is None:
98
+ self._cosine_annealing_base_lr = self.cur_lr
99
+
100
+ """每步更新学习率"""
101
+ # Total number of steps in the current cycle
102
+ T_max = self._cosine_annealing_period * (max(self._cosine_annealing_period_mul, 1) ** self.cycle)
103
+
104
+ # Update the cycle state
105
+ self.T_cur += 1
106
+ calc_t = self.T_cur
107
+
108
+ if self.T_cur >= T_max:
109
+ if self._cosine_annealing_period_mul == 0:
110
+ self.T_cur = T_max
111
+ calc_t = T_max
112
+ else:
113
+ self.cycle += 1
114
+ self.T_cur = 0
115
+ calc_t = T_max
116
+
117
+ # Compute and set the new learning rate
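+ # lr = min_lr + (base_lr - min_lr) * (1 + cos(pi * t / T_max)) / 2, where base_lr is the
+ # learning rate reached at the end of warmup and t is the position within the current cycle.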
118
+ cos_factor = (1 + math.cos(math.pi * calc_t / T_max)) / 2
119
+ lr = self._min_lr + (self._cosine_annealing_base_lr - self._min_lr) * cos_factor
120
+
121
+ for param_group in self._optimizer.param_groups:
122
+ param_group['lr'] = lr
123
+
124
+ self._current_lr = lr
125
+
126
+ if self.logger:
127
+ self.logger.log(f"step: {self.cur_steps}, lr: {lr}", log_to_console=False)
128
+
129
+ def get_ckpt_dict(self) -> dict:
130
+ return {
131
+ 'cur_lr': self._current_lr,
132
+ 'lr_steps': self.cur_steps,
133
+ 'cosine_annealing_base_lr': self._cosine_annealing_base_lr,
134
+ 't_cur': self.T_cur,
135
+ 'cycle': self.cycle,
136
+ }
137
+
138
+ def restore_ckpt_dict(self, ckpt: dict):
139
+ if 'cur_lr' in ckpt:
140
+ self._current_lr = ckpt['cur_lr']
141
+
142
+ if 'lr_steps' in ckpt:
143
+ self._steps = ckpt['lr_steps']
144
+
145
+ if 'cosine_annealing_base_lr' in ckpt:
146
+ self._cosine_annealing_base_lr = ckpt['cosine_annealing_base_lr']
147
+
148
+ if 't_cur' in ckpt:
149
+ self.T_cur = ckpt['t_cur']
150
+
151
+ if 'cycle' in ckpt:
152
+ self.cycle = ckpt['cycle']
153
+
154
+ self._update_lr()
155
+
156
+
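+ # A pass-through scheduler that keeps the initial learning rate constant.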
157
+ class NoneLRScheduler(LRScheduler):
158
+ def __init__(self, initial_lr):
159
+ self._current_lr = initial_lr
160
+
161
+ @property
162
+ def cur_steps(self):
163
+ return -1
164
+
165
+ @property
166
+ def cur_lr(self):
167
+ return self._current_lr
168
+
169
+ def step(self): ...
170
+
171
+ def can_clip_grad(self):
172
+ return True
173
+
174
+ def get_ckpt_dict(self) -> dict:
175
+ return {'cur_lr': self._current_lr}
176
+
177
+ def restore_ckpt_dict(self, ckpt: dict):
178
+ if 'cur_lr' in ckpt:
179
+ self._current_lr = ckpt['cur_lr']