project-llm-trainer 0.12.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,13 @@
+ from .trainer import Trainer
+ from .sft_trainer import SFTTrainer
+ from .dpo_trainer import DPOTrainer
+ from .ppo_trainer import PPOTrainer
+ from .grpo_trainer import GRPOTrainer
+ from .tools import (
+     TrainerTools,
+     FileDataset,
+     estimate_data_size,
+     extract_policy_weights_from_ppo,
+     extract_value_weights_from_ppo
+ )
+ from .generate_utils import generate, streaming_generate
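For orientation, a minimal sketch of how these exports might be used downstream. It is a sketch only: the top-level import name `project_llm_trainer`, the `train_configs` module path, and the assumption that `SFTTrainer` keeps `BaseTrainer`'s keyword-only constructor (train_config, eval_prompts, gradient_accumulation_steps) are all assumptions, not something this diff confirms, and the `TrainConfig` fields are elided.

# Hedged sketch only: import paths and the SFTTrainer signature are assumptions.
from project_llm_trainer import SFTTrainer                  # assumed top-level package name
from project_llm_trainer.train_configs import TrainConfig   # assumed module path

train_config = TrainConfig(...)  # model/optimizer/dataloader settings, elided here
trainer = SFTTrainer(
    train_config=train_config,
    eval_prompts=['Hello, who are you?'],  # prompts sampled for the periodic eval generations
    gradient_accumulation_steps=4,
)
trainer.train()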
@@ -0,0 +1,683 @@
+ from typing import Optional, Tuple, List, Dict, Any
+ import copy
+ import gc
+ import importlib.metadata
+ from packaging import version
+
+ import torch
+ import torch.distributed as dist
+ from torch.utils.data import Dataset
+ from llm_model import LlmModel
+
+ from .parallel import DsParallel
+ from .tools import TrainerTools
+ from .loss import LMLoss, KDLoss
+ from .eval import submit_gen_task
+ from .partition_utils import unwrap_model_for_generation
+
+ from .train_configs import (
+     TrainConfig,
+     DsZero2Config,
+     DsZero3Config,
+     KDConfig
+ )
+
+ from .scheduler import (
+     LRScheduler,
+     WarmupCosineAnnealingLRScheduler,
+     NoneLRScheduler
+ )
+
+ from .checkpoint import (
+     load_checkpoint,
+     save_checkpoint,
+     load_steps,
+     save_steps,
+ )
+
+ from .utils import (
+     set_seed,
+     autocast,
+ )
+
+ from .log import Logger
+
+ class BaseTrainer:
+     def __init__(
+         self,
+         *,
+         train_config: TrainConfig,
+         eval_prompts: List[str],
+         kd_config: Optional[KDConfig] = None,
+         gradient_accumulation_steps: int = 1
+     ):
+         set_seed()
+
+         self.train_config: TrainConfig = train_config
+         self.eval_prompts = eval_prompts
+         self.eval_idx = -1
+         self.last_global_steps = 0
+         self.kd_config = kd_config
+         self.gradient_accumulation_steps = gradient_accumulation_steps
+
+         self.logger = Logger('log.txt')
+
+         self.parallel_kwargs, self.data_loader_kwargs, self.sampler_kwargs = self._convert_train_args()
+         # initialize a GradScaler; with enabled=False the scaler is a no-op
+         self.scaler = torch.GradScaler(enabled=TrainerTools().use_amp)
+
+         # Note: the learning rate should be scaled with the number of GPUs.
+         # During training the loss gradient sets the descent direction and the learning rate sets the step size; with two GPUs the combined effective step is roughly the average learning rate * 2.
+         initial_lr = train_config.optim_config.initial_lr
+
+         self.train_model, self.optimizer = self._init_train_model_and_optim(initial_lr)
+         self.lr_scheduler = self._init_lr_scheduler(initial_lr, self.optimizer)
+
+         self.criterion, self.kd_loss = self._init_loss()
+
+         load_checkpoint(
+             self.train_model,
+             optimizer=self.optimizer,
+             device=TrainerTools().parallel.device
+         )
+
+         steps_dict = load_steps()
+         self._apply_restore_ckpt(steps_dict)
+
+     def _new_model(self, train_config: TrainConfig):
+         return LlmModel(train_config.model_config)
+
+     def _init_train_model_and_optim(self, initial_lr: float):
+         model = self._new_model(self.train_config)
+
+         if self.train_config.init_state_dict:
+             model.load_state_dict(self.train_config.init_state_dict, strict=False)
+             self.train_config.init_state_dict = None
+
+         self._check_freeze_llm_model(model)
+
+         if TrainerTools().parallel.is_main_process:
+             total_params = sum(p.numel() for p in model.parameters())
+             Logger.std_log(f"Total number of parameters: {total_params:,}")
+
+             trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+             Logger.std_log(f"Trainable number of parameters: {trainable_params:,}")
+
+             total_size_bytes = total_params * 4  # assumes 4 bytes (fp32) per parameter
+             total_size_mb = total_size_bytes / (1024 * 1024)
+             Logger.std_log(f"Total size of the model: {total_size_mb:.2f} MB")
+
+         model, optim = TrainerTools().parallel.process(
+             model=model,
+             optimizer=self._config_optim(model, initial_lr),
+             kwargs=self.parallel_kwargs
+         )
+
+         return model, optim
+
+     def _check_freeze_llm_model(self, model): ...
+
+     def _config_optim(self, model, initial_lr):
+         optimizer = None
+         use_lion_optim = self.train_config.optim_config.optim_type == 'lion'
+
+         if isinstance(TrainerTools().parallel, DsParallel) and self.parallel_kwargs:
+             import deepspeed
+             if ('zero_optimization' in self.parallel_kwargs
+                     and 'offload_optimizer' in self.parallel_kwargs['zero_optimization']
+                     and self.parallel_kwargs['zero_optimization']['offload_optimizer']['device'] == 'cpu'):
+                 if self.train_config.optim_config.optim_type == 'lion':
+                     if version.parse(importlib.metadata.version("deepspeed")) >= version.parse('0.17.6'):
+                         optimizer = deepspeed.ops.lion.DeepSpeedCPULion
+                     else:
+                         optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam
+                         use_lion_optim = False
+                         if TrainerTools().parallel.is_main_process:
+                             Logger.std_log('The Lion optimizer is not supported with offload_optimizer on this DeepSpeed version; falling back to Adam!')
+                 else:
+                     optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam
+             else:
+                 if self.train_config.optim_config.optim_type == 'lion':
+                     optimizer = deepspeed.ops.lion.FusedLion
+                 else:
+                     optimizer = deepspeed.ops.adam.FusedAdam
+
+         if not optimizer:
+             if self.train_config.optim_config.optim_type == 'lion':
+                 try:
+                     import lion_pytorch
+                 except ImportError:
+                     raise Exception('lion_pytorch is not installed; run `pip3 install lion_pytorch` or set optim_type to adam')
+
+                 optimizer = lion_pytorch.Lion
+             else:
+                 optimizer = torch.optim.AdamW
+
+         betas = self.train_config.optim_config.betas
+         weight_decay = self.train_config.optim_config.weight_decay
+
+         if betas is None:
+             if use_lion_optim:
+                 betas = (0.95, 0.98)
+             else:
+                 betas = (0.9, 0.999)
+
+         if weight_decay is None:
+             if use_lion_optim:
+                 weight_decay = 0.015
+             else:
+                 weight_decay = 0.01
+
+         no_decay_name_list = ["bias", "norm.weight"]
+         decay_params = []
+         no_decay_params = []
+
+         for name, param in model.named_parameters():
+             if not param.requires_grad:
+                 continue
+
+             if any(nd in name for nd in no_decay_name_list):
+                 no_decay_params.append(param)
+             else:
+                 decay_params.append(param)
+
+         optimizer_grouped_parameters = [
+             {
+                 "params": decay_params,
+                 "weight_decay": weight_decay,
+             },
+             {
+                 "params": no_decay_params,
+                 "weight_decay": 0.0,
+             },
+         ]
+
+         return optimizer(
+             optimizer_grouped_parameters,
+             lr=initial_lr,
+             betas=betas,
+             weight_decay=weight_decay
+         )
+
+     def _init_lr_scheduler(self, initial_lr: float, optimizer) -> LRScheduler:
+         if self.train_config.optim_config.enable_lr_scheduler:
+             warmup_iters = self.train_config.optim_config.warmup_iters
+             min_lr = self.train_config.optim_config.min_lr
+             max_lr = self.train_config.optim_config.max_lr
+             cosine_annealing_period = self.train_config.optim_config.cosine_annealing_period
+             cosine_annealing_period_mul = self.train_config.optim_config.cosine_annealing_period_mul
+
+             return WarmupCosineAnnealingLRScheduler(
+                 optimizer=optimizer,
+                 warmup_iters=warmup_iters,
+                 initial_lr=initial_lr,
+                 min_lr=min_lr,
+                 max_lr=max_lr,
+                 cosine_annealing_period=cosine_annealing_period,
+                 cosine_annealing_period_mul=cosine_annealing_period_mul,
+                 need_log=TrainerTools().parallel.is_main_process
+             )
+
+         return NoneLRScheduler(initial_lr)
+
+     def _init_loss(self):
+         critical_tokens: Optional[List[int]] = None
+         critical_alpha: float = 1.0
+         if self.train_config.loss_config.critical_tokens:
+             critical_tokens = self.train_config.loss_config.critical_tokens
+             critical_alpha = self.train_config.loss_config.critical_alpha
+
+         criterion = LMLoss(
+             critical_tokens=critical_tokens,
+             critical_alpha=critical_alpha,
+             vocab_size=TrainerTools().tokenizer.vocab_size
+         )
+
+         kd_loss = KDLoss() if self.kd_config else None
+
+         return criterion, kd_loss
+
+     def _apply_restore_ckpt(self, steps_dict):
+         if steps_dict:
+             self.last_global_steps = steps_dict['global_steps']
+             if not self.last_global_steps:
+                 self.last_global_steps = 0
+
+             self.lr_scheduler.restore_ckpt_dict(steps_dict)
+
+             if TrainerTools().parallel.is_main_process:
+                 Logger.std_log(f'restore steps_dict={steps_dict}')
+
+     def _convert_train_args(self) -> Tuple[dict, dict, dict]:
+         parallel_kwargs: Optional[Dict[str, Any]] = None
+         if isinstance(TrainerTools().parallel, DsParallel) and self.train_config.ds_config:
+             parallel_kwargs = {
+                 'gradient_accumulation_steps': 1,
+                 'gradient_clipping': self.train_config.ds_config.gradient_clipping,
+                 'train_micro_batch_size_per_gpu': self.train_config.batch_size
+             }
+
+             if self.train_config.ds_config.zero_config:
+                 zero_config = self.train_config.ds_config.zero_config
+                 zero_optimization: Dict[str, Any] = {'stage': zero_config.stage}
+
+                 if zero_config.allgather_partitions is not None:
+                     zero_optimization['allgather_partitions'] = zero_config.allgather_partitions
+                 if zero_config.allgather_bucket_size is not None:
+                     zero_optimization['allgather_bucket_size'] = zero_config.allgather_bucket_size
+                 if zero_config.overlap_comm is not None:
+                     zero_optimization['overlap_comm'] = zero_config.overlap_comm
+                 if zero_config.reduce_scatter is not None:
+                     zero_optimization['reduce_scatter'] = zero_config.reduce_scatter
+                 if zero_config.reduce_bucket_size is not None:
+                     zero_optimization['reduce_bucket_size'] = zero_config.reduce_bucket_size
+                 if zero_config.contiguous_gradients is not None:
+                     zero_optimization['contiguous_gradients'] = zero_config.contiguous_gradients
+
+                 if isinstance(zero_config, (DsZero2Config, DsZero3Config)):
+                     if zero_config.offload_optimizer is not None:
+                         zero_optimization['offload_optimizer'] = {
+                             "device": zero_config.offload_optimizer.device,
+                             "pin_memory": zero_config.offload_optimizer.pin_memory
+                         }
+                     if zero_config.offload_param is not None:
+                         zero_optimization['offload_param'] = {
+                             "device": zero_config.offload_param.device,
+                             "pin_memory": zero_config.offload_param.pin_memory
+                         }
+
+                 if isinstance(zero_config, DsZero3Config):
+                     if zero_config.sub_group_size is not None:
+                         zero_optimization['sub_group_size'] = zero_config.sub_group_size
+                     if zero_config.stage3_prefetch_bucket_size is not None:
+                         zero_optimization['stage3_prefetch_bucket_size'] = zero_config.stage3_prefetch_bucket_size
+                     if zero_config.stage3_param_persistence_threshold is not None:
+                         zero_optimization['stage3_param_persistence_threshold'] = zero_config.stage3_param_persistence_threshold
+                     if zero_config.stage3_max_live_parameters is not None:
+                         zero_optimization['stage3_max_live_parameters'] = zero_config.stage3_max_live_parameters
+                     if zero_config.stage3_max_reuse_distance is not None:
+                         zero_optimization['stage3_max_reuse_distance'] = zero_config.stage3_max_reuse_distance
+                     if zero_config.stage3_gather_16bit_weights_on_model_save is not None:
+                         zero_optimization['stage3_gather_16bit_weights_on_model_save'] = zero_config.stage3_gather_16bit_weights_on_model_save
+
+                 parallel_kwargs['zero_optimization'] = zero_optimization
+
+             if (self.train_config.ds_config.bf16_config is not None
+                     and self.train_config.ds_config.bf16_config.enabled):
+                 bf16_config = self.train_config.ds_config.bf16_config
+                 bf16 = {
+                     'enabled': bf16_config.enabled
+                 }
+                 parallel_kwargs['bf16'] = bf16
+             elif self.train_config.ds_config.fp16_config:
+                 fp16_config = self.train_config.ds_config.fp16_config
+                 fp16 = {
+                     'enabled': fp16_config.enabled,
+                     'loss_scale': fp16_config.loss_scale,
+                     'loss_scale_window': fp16_config.loss_scale_window,
+                     'initial_scale_power': fp16_config.initial_scale_power,
+                     'hysteresis': fp16_config.hysteresis,
+                     'min_loss_scale': fp16_config.min_loss_scale
+                 }
+
+                 if fp16_config.fp16_opt_level is not None:
+                     fp16['fp16_opt_level'] = fp16_config.fp16_opt_level
+
+                 parallel_kwargs['fp16'] = fp16
+
+             if self.train_config.ds_config.activation_checkpointing:
+                 activation_checkpointing_config = self.train_config.ds_config.activation_checkpointing
+                 activation_checkpointing: Dict[str, Any] = {
+                     'partition_activations': activation_checkpointing_config.partition_activations,
+                     'cpu_checkpointing': activation_checkpointing_config.cpu_checkpointing,
+                     'contiguous_memory_optimization': activation_checkpointing_config.contiguous_memory_optimization,
+                     'synchronize_checkpoint_boundary': activation_checkpointing_config.synchronize_checkpoint_boundary,
+                     'profile': activation_checkpointing_config.profile
+                 }
+
+                 if activation_checkpointing_config.number_checkpoints is not None:
+                     activation_checkpointing['number_checkpoints'] = activation_checkpointing_config.number_checkpoints
+
+                 parallel_kwargs['activation_checkpointing'] = activation_checkpointing
+
+         dataloader_args = self.train_config.data_loader_config
+         data_loader_kwargs = {
+             "batch_size": self.train_config.batch_size,
+             "pin_memory": dataloader_args.data_loader_pin_memory,
+             "num_workers": dataloader_args.data_loader_num_workers,
+             "shuffle": dataloader_args.data_loader_shuffle,
+             "drop_last": dataloader_args.data_loader_drop_last,
+         }
+         sampler_kwargs = {
+             "shuffle": dataloader_args.data_loader_shuffle,
+             "drop_last": dataloader_args.data_loader_drop_last,
+         }
+
+         return parallel_kwargs, data_loader_kwargs, sampler_kwargs
+
+     def _init_ref_model_args(self) -> dict:
+         parallel_kwargs = copy.deepcopy(self.parallel_kwargs) if self.parallel_kwargs else None
+
+         if parallel_kwargs and isinstance(TrainerTools().parallel, DsParallel):
+             # reference: https://github.com/huggingface/trl/blob/main/trl/models/utils.py:prepare_deepspeed
+             # if model is not None:
+             #     hidden_size = (
+             #         max(model.config.hidden_sizes)
+             #         if getattr(model.config, "hidden_sizes", None)
+             #         else getattr(model.config, "hidden_size", None)
+             #     )
+             #     if hidden_size is not None and stage == 3:
+             #         # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache
+             #         # @ step 0: expected module 1, but got module 0`
+             #         # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
+             #         config_kwargs.update(
+             #             {
+             #                 "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
+             #                 "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
+             #                 "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
+             #             }
+             #         )
+
+             parallel_kwargs.pop('activation_checkpointing', None)
+             parallel_kwargs.pop('gradient_clipping', None)
+
+             # use ZeRO stage 0 for the ref model for now, to work around training hangs
+             parallel_kwargs["zero_optimization"] = {"stage": 0}
+             # if parallel_kwargs.get("zero_optimization", {}).get("stage", 0) != 3:
+             #     parallel_kwargs["zero_optimization"] = {"stage": 0}
+
+         return parallel_kwargs
+
+     def _create_dataset(self, file_idx) -> Tuple[Dataset, str]: ...
+
+     def _calc_loss(self, inputs, attention_mask, logits, labels):
+         # calc loss
+         if not self.kd_loss or self.kd_config.kd_coef == 0.0:
+             # no need to compute the KD loss
+             return self.criterion(logits, labels)
+
+         teacher_logits = self.kd_config.teacher_logits_provider(inputs, attention_mask)
+         loss = self.kd_loss(logits, teacher_logits, labels)
+
+         if self.kd_config.kd_coef == 1.0:
+             # no need to compute the CE loss
+             return loss
+
+         ce_loss = self.criterion(logits, labels)
+         return (1 - self.kd_config.kd_coef) * ce_loss + self.kd_config.kd_coef * loss
+
+     def _backward_loss(self, loss):
+         if isinstance(TrainerTools().parallel, DsParallel):
+             self.train_model.backward(loss)
+         else:
+             self.scaler.scale(loss).backward()
+
+     def _apply_grad_clipping(self):
+         # DeepSpeed mode already applies gradient_clipping internally
+         if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
+             # clip grad
+             self.scaler.unscale_(self.optimizer)
+
+             trainable_params = filter(lambda p: p.requires_grad, self.train_model.parameters())
+             torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
+
+     def _apply_step(self):
+         self.lr_scheduler.step()
+         if isinstance(TrainerTools().parallel, DsParallel):
+             self.train_model.step()
+         else:
+             self.scaler.step(self.optimizer)
+             self.scaler.update()
+             self.optimizer.zero_grad(set_to_none=True)
+
+         TrainerTools().parallel.synchronize()
+
+     def _get_eval_data(self) -> Optional[str]:
+         if len(self.eval_prompts) == 0:
+             return None
+
+         self.eval_idx += 1
+         if self.eval_idx == len(self.eval_prompts):
+             self.eval_idx = 0
+
+         return self.eval_prompts[self.eval_idx]
+
+     def _get_eval_pixel_values_and_tokens_count(self, eval_idx):
+         return None, None
+
+     def _log(self, keys: Dict[str, Any], values: Dict[str, Any]):
+         """
+         Format: keys_key1: keys_value1, keys_key2: keys_value2 -> values_key1: values_value1, values_key2: values_value2
+         """
+         if TrainerTools().parallel.is_main_process:
+             log_tags = ', '.join([f'{k}: {v}' for k, v in keys.items()])
+             log_values = ', '.join([f'{k}: {v}' for k, v in values.items()])
+
+             log_msg = f'{log_tags} -> {log_values}'
+             self.logger.log(log_msg)
+
+     def _on_exception(
+         self,
+         e: Exception,
+         epoch: int,
+         batch: int
+     ):
+         exception_file = e.__traceback__.tb_frame.f_globals["__file__"]
+         exception_line = e.__traceback__.tb_lineno
+         log_msg = f"epoch: {epoch}, batch: {batch} -> {e} at {exception_file} line {exception_line}"
+         Logger('exception.txt').log(log_msg, log_to_console=False).release()
+
+         raise e
+
+     def _get_model_dtype(self):
+         if isinstance(TrainerTools().parallel, DsParallel):
+             import deepspeed
+             assert isinstance(self.train_model, deepspeed.DeepSpeedEngine)
+             return self.train_model.get_data_types()[0]
+         else:
+             return torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
+
+     def _eval(self, tag: str):
+         with unwrap_model_for_generation(self.train_model) as eval_model:
+             if TrainerTools().parallel.is_main_process:
+                 eval_prompt = self._get_eval_data()
+
+                 if eval_prompt:
+                     eval_model = self._check_eval_model(eval_model)
+                     eval_model.eval()
+
+                     eval_pixel_values, tokens_per_image = self._get_eval_pixel_values_and_tokens_count(self.eval_idx)
+                     submit_gen_task(
+                         eval_model,
+                         self.train_config,
+                         tag=tag,
+                         prompt=eval_prompt,
+                         pixel_values=eval_pixel_values,
+                         tokens_per_image=tokens_per_image
+                     )
+
+                     eval_model.train()
+
+             TrainerTools().parallel.wait('eval')
+
+     def _check_eval_model(self, eval_model):
+         return eval_model
+
+     def _on_batch_end(self, tag: str):
+         self._eval(f'sign:batch/{tag}')
+
+     def _on_epoch_end(self, tag: str):
+         self._eval(f'sign:epoch/{tag}')
+
+     def _on_file_start(
+         self,
+         epoch: int,
+         file_name: str
+     ):
+         if TrainerTools().parallel.is_main_process:
+             self.logger.log(f"====epoch: {epoch}, start train {file_name}====", log_to_console=False)
+
+     def _avg_loss(
+         self,
+         losses: List[float],
+         gradient_accumulation_steps,
+         batches_accumulated
+     ) -> List[float]:
+         avg_losses = []
+         for loss in losses:
+             avg_loss = torch.tensor(
+                 loss * gradient_accumulation_steps / batches_accumulated,
+                 device=TrainerTools().parallel.device)
+
+             if TrainerTools().parallel.parallel_train:
+                 dist.all_reduce(avg_loss, dist.ReduceOp.AVG)
+
+             avg_losses.append(avg_loss.detach().item())
+
+         return avg_losses
+
+     def _get_pixel_values(self, batch_data):
+         return None
+
+     def train(self):
+         # gradient accumulation steps
+         gradient_accumulation_steps = max(1, self.gradient_accumulation_steps)
+         global_steps = 0
+         skipping_train = False
+
+         loss_accumulation = 0.0
+         aux_loss_accumulation = 0.0
+         batches_accumulated = 0
+
+         for epoch in range(self.train_config.n_epochs):
+             self.train_model.train()
+             file_count = len(self.train_config.file_dataset)
+
+             for file_idx in range(file_count):
+                 dataset, file_path = self._create_dataset(file_idx)
+                 train_data_loader = TrainerTools().parallel.process_dataloader(
+                     dataset=dataset,
+                     data_loader_kwargs=self.data_loader_kwargs,
+                     sampler_kwargs=self.sampler_kwargs
+                 )
+
+                 last_ckpt_batch = 0
+                 batch_count_per_file = len(train_data_loader)
+
+                 TrainerTools().parallel.on_epoch_start(epoch)
+                 self._on_file_start(epoch, file_path)
+
+                 for batch, batch_data in enumerate(train_data_loader):
+                     global_steps += 1
+                     if global_steps < self.last_global_steps:
+                         skipping_train = True
+                         continue
+
+                     # decide whether gradients should be updated on this step
+                     if skipping_train:
+                         need_update_grad = False
+                     elif gradient_accumulation_steps > 1:
+                         need_update_grad = (batch + 1) % gradient_accumulation_steps == 0 or batch == batch_count_per_file - 1
+                     else:
+                         need_update_grad = True
+
+                     # must come after need_update_grad is assigned; works around an unexplained hang when resuming training
+                     if skipping_train:
+                         TrainerTools().parallel.wait('skip train')
+                         skipping_train = False
+
+                     inputs = batch_data['inputs']
+                     labels = batch_data['labels']
+
+                     try:
+                         inputs, labels = inputs.to(TrainerTools().parallel.device), labels.to(TrainerTools().parallel.device)
+                         attention_mask = inputs != TrainerTools().tokenizer.pad
+                         pixel_values = self._get_pixel_values(batch_data)
+
+                         if TrainerTools().parallel.parallel_train:
+                             self.train_model.require_backward_grad_sync = need_update_grad
+
+                         with autocast(TrainerTools().parallel.device_type):
+                             result = self.train_model(
+                                 inputs,
+                                 attention_mask=attention_mask,
+                                 pixel_values=pixel_values
+                             )
+
+                         # calc loss
+                         loss = self._calc_loss(inputs, attention_mask, result['logits'], labels)
+                         if result['aux_loss'] and self.train_config.loss_config.aux_loss_coef:
+                             aux_loss = self.train_config.loss_config.aux_loss_coef * result['aux_loss']
+                         else:
+                             aux_loss = torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
+
+                         if gradient_accumulation_steps > 1:
+                             loss = loss / gradient_accumulation_steps
+                             aux_loss = aux_loss / gradient_accumulation_steps
+
+                         total_loss = loss + aux_loss
+                         self._backward_loss(total_loss)
+
+                         loss_accumulation += total_loss.detach().item()
+                         aux_loss_accumulation += aux_loss.detach().item()
+
+                         batches_accumulated += 1
+
+                         if need_update_grad:
+                             self._apply_grad_clipping()
+                             self._apply_step()
+
+                             avg_loss, avg_aux_loss = self._avg_loss(
+                                 losses=[
+                                     loss_accumulation,
+                                     aux_loss_accumulation
+                                 ],
+                                 gradient_accumulation_steps=gradient_accumulation_steps,
+                                 batches_accumulated=batches_accumulated
+                             )
+
+                             self._log(
+                                 keys={
+                                     'epoch': epoch,
+                                     'file': f'{file_idx + 1}/{file_count}',
+                                     'batch': f'{batch}/{batch_count_per_file}'
+                                 },
+                                 values={
+                                     'loss': avg_loss,
+                                     'moe_aux_loss': avg_aux_loss
+                                 }
+                             )
+
+                             # reset to default
+                             loss_accumulation = 0.0
+                             aux_loss_accumulation = 0.0
+                             batches_accumulated = 0
+                     except Exception as e:
+                         self._on_exception(e, epoch, batch)
+                     finally:
+                         if need_update_grad:
+                             save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)
+
+                             if (batch - last_ckpt_batch) >= self.train_config.eval_config.eval_batch_interval:
+                                 save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+                                 last_ckpt_batch = batch
+                                 self._on_batch_end(tag=f'epoch:{epoch}/batch:{batch}')
+
+                 # after finishing one file, release its memory
+                 del train_data_loader
+                 del dataset
+                 if hasattr(TrainerTools().parallel, '_sampler'):
+                     TrainerTools().parallel._sampler = None
+
+                 gc.collect()
+                 torch.cuda.empty_cache()
+
+             # end epoch
+             if not skipping_train:
+                 save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)
+                 save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+
+             TrainerTools().parallel.on_epoch_end(epoch)
+             self._on_epoch_end(tag=f'epoch:{epoch}')
+
+         TrainerTools().parallel.destroy()
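For orientation, a minimal sketch of how a concrete trainer might plug into BaseTrainer: the train() loop above drives everything once a subclass supplies _create_dataset (file index to a Dataset plus its path) and, for text-only data, leaves _get_pixel_values returning None. The batch dicts must carry 'inputs' and 'labels' tensors, since the loop reads exactly those keys. MyTokenizedDataset, indexing train_config.file_dataset by position, and the per-file torch.load format are assumptions for illustration, not part of the package.

# Hedged sketch only: MyTokenizedDataset, the file_dataset shape, and the on-disk
# format are assumptions; BaseTrainer is assumed importable from the package
# (its module path is not shown in this diff).
from typing import Tuple
import torch
from torch.utils.data import Dataset

class MyTokenizedDataset(Dataset):
    # hypothetical Dataset yielding dicts with 'inputs' and 'labels' tensors,
    # matching the keys read by BaseTrainer.train()
    def __init__(self, file_path: str):
        self.samples = torch.load(file_path)  # assumption: pre-tokenized samples saved per file

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return {'inputs': self.samples[idx]['inputs'], 'labels': self.samples[idx]['labels']}

class MySFTStyleTrainer(BaseTrainer):
    def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
        # train() iterates train_config.file_dataset one file at a time and frees the
        # dataset afterwards, so only one file needs to be resident in memory.
        file_path = self.train_config.file_dataset[file_idx]  # assumed to be indexable paths
        return MyTokenizedDataset(file_path), file_path

    def _get_pixel_values(self, batch_data):
        # text-only training: no image features in the batch
        return None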