project-llm-trainer 0.13.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of project-llm-trainer might be problematic.

@@ -0,0 +1,707 @@
+ from typing import Optional, Tuple, List, Dict, Any
+ import copy
+ import gc
+ import importlib.metadata
+ from packaging import version
+ from itertools import islice
+
+ import torch
+ import torch.distributed as dist
+ from torch.utils.data import Dataset
+ from llm_model import LlmModel
+
+ from .parallel import DsParallel
+ from .tools import TrainerTools
+ from .loss import LMLoss, KDLoss
+ from .eval import submit_gen_task
+ from .partition_utils import unwrap_model_for_generation
+
+ from .train_configs import (
+     TrainConfig,
+     DsZero2Config,
+     DsZero3Config,
+     KDConfig
+ )
+
+ from .scheduler import (
+     LRScheduler,
+     WarmupCosineAnnealingLRScheduler,
+     NoneLRScheduler
+ )
+
+ from .checkpoint import (
+     load_checkpoint,
+     save_checkpoint,
+     load_steps,
+     save_steps,
+ )
+
+ from .utils import (
+     set_seed,
+     autocast,
+ )
+
+ from .log import Logger
+
+ class BaseTrainer:
+     def __init__(
+         self,
+         *,
+         train_config: TrainConfig,
+         eval_prompts: List[str],
+         kd_config: Optional[KDConfig] = None,
+         gradient_accumulation_steps: int = 1
+     ):
+         set_seed()
+
+         self.train_config: TrainConfig = train_config
+         self.eval_prompts = eval_prompts
+         self.eval_idx = -1
+
+         self.resume_epoch = 0
+         self.resume_file_idx = 0
+         self.resume_batch_idx = 0
+
+         self.kd_config = kd_config
+         self.gradient_accumulation_steps = gradient_accumulation_steps
+
+         self.logger = Logger('log.txt')
+
+         self.parallel_kwargs, self.data_loader_kwargs, self.sampler_kwargs = self._convert_train_args()
+         # initialize a GradScaler; if enabled=False the scaler is a no-op
+         self.scaler = torch.amp.GradScaler(enabled=TrainerTools().use_amp)
+
+         # Note: the learning rate should be scaled with the number of GPUs.
+         # During training the loss gradient sets the descent direction and the
+         # learning rate sets the step size; with two GPUs the combined step is
+         # effectively the average learning rate * 2.
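+         # e.g. (illustrative): two data-parallel ranks each running initial_lr=3e-4
+         # take a combined step comparable to ~6e-4 on the aggregated batch, since
+         # gradients are averaged across ranks.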
+         initial_lr = train_config.optim_config.initial_lr
+
+         self.train_model, self.optimizer = self._init_train_model_and_optim(initial_lr)
+         self.lr_scheduler = self._init_lr_scheduler(initial_lr, self.optimizer)
+
+         self.criterion, self.kd_loss = self._init_loss()
+
+         self._load_train_model_checkpoint()
+         self._apply_restore_ckpt()
+
+     def _new_model(self, train_config: TrainConfig):
+         return LlmModel(train_config.model_config)
+
+     def _init_train_model_and_optim(self, initial_lr: float):
+         model = self._new_model(self.train_config)
+
+         if self.train_config.init_state_dict:
+             model.load_state_dict(self.train_config.init_state_dict, strict=False)
+             self.train_config.init_state_dict = None
+
+         self._check_freeze_llm_model(model)
+
+         if self.train_config.ds_config and self.train_config.ds_config.activation_checkpointing:
+             model.gradient_checkpointing_enable()
+
+         if TrainerTools().parallel.is_main_process:
+             total_params = sum(p.numel() for p in model.parameters())
+             Logger.std_log(f"Total number of parameters: {total_params:,}")
+
+             trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+             Logger.std_log(f"Trainable number of parameters: {trainable_params:,}")
+
+             # assumes fp32 weights, i.e. 4 bytes per parameter
+             total_size_bytes = total_params * 4
+             total_size_mb = total_size_bytes / (1024 * 1024)
+             Logger.std_log(f"Total size of the model: {total_size_mb:.2f} MB")
+
+         model, optim = TrainerTools().parallel.process(
+             model=model,
+             optimizer=self._config_optim(model, initial_lr),
+             kwargs=self.parallel_kwargs
+         )
+
+         return model, optim
+
+     def _check_freeze_llm_model(self, model): ...
+
+     def _config_optim(self, model, initial_lr):
+         optimizer_cls, use_lion_optim = self._get_optim_cls()
+
+         betas = self.train_config.optim_config.betas
+         weight_decay = self.train_config.optim_config.weight_decay
+
+         if betas is None:
+             betas = (0.95, 0.98) if use_lion_optim else (0.9, 0.999)
+
+         if weight_decay is None:
+             weight_decay = 0.015 if use_lion_optim else 0.01
+
+         no_decay_name_list = ["bias", "norm.weight"]
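+         # bias and normalization weights are conventionally excluded from weight
+         # decay: decaying them tends to hurt optimization without adding useful
+         # regularization.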
+         decay_params = []
+         no_decay_params = []
+
+         for name, param in model.named_parameters():
+             if not param.requires_grad:
+                 continue
+
+             if any(nd in name for nd in no_decay_name_list):
+                 no_decay_params.append(param)
+             else:
+                 decay_params.append(param)
+
+         optimizer_grouped_parameters = [
+             {
+                 "params": decay_params,
+                 "weight_decay": weight_decay,
+             },
+             {
+                 "params": no_decay_params,
+                 "weight_decay": 0.0,
+             },
+         ]
+
+         return optimizer_cls(
+             optimizer_grouped_parameters,
+             lr=initial_lr,
+             betas=betas,
+             weight_decay=weight_decay
+         )
+
+     def _get_optim_cls(self):
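+         # resolution summary (derived from the branches below):
+         #   DeepSpeed + cpu offload + lion -> DeepSpeedCPULion (deepspeed >= 0.17.6,
+         #                                     otherwise falls back to DeepSpeedCPUAdam)
+         #   DeepSpeed + cpu offload + adam -> DeepSpeedCPUAdam
+         #   DeepSpeed, no cpu offload      -> FusedLion / FusedAdam
+         #   no DeepSpeed                   -> lion_pytorch.Lion / torch.optim.AdamW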
+         optimizer = None
+         use_lion_optim = self.train_config.optim_config.optim_type == 'lion'
+
+         if isinstance(TrainerTools().parallel, DsParallel) and self.parallel_kwargs:
+             import deepspeed
+             if ('zero_optimization' in self.parallel_kwargs
+                     and 'offload_optimizer' in self.parallel_kwargs['zero_optimization']
+                     and self.parallel_kwargs['zero_optimization']['offload_optimizer']['device'] == 'cpu'):
+                 if self.train_config.optim_config.optim_type == 'lion':
+                     if version.parse(importlib.metadata.version("deepspeed")) >= version.parse('0.17.6'):
+                         optimizer = deepspeed.ops.lion.DeepSpeedCPULion
+                     else:
+                         optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam
+                         use_lion_optim = False
+                         if TrainerTools().parallel.is_main_process:
+                             Logger.std_log(
+                                 'lion with offload_optimizer requires deepspeed >= 0.17.6; falling back to adam!')
+                 else:
+                     optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam
+             else:
+                 if self.train_config.optim_config.optim_type == 'lion':
+                     optimizer = deepspeed.ops.lion.FusedLion
+                 else:
+                     optimizer = deepspeed.ops.adam.FusedAdam
+
+         if not optimizer:
+             if self.train_config.optim_config.optim_type == 'lion':
+                 try:
+                     import lion_pytorch
+                 except ImportError:
+                     raise Exception(
+                         'lion_pytorch is not installed; run `pip3 install lion_pytorch` or set optim_type to adam')
+
+                 optimizer = lion_pytorch.Lion
+             else:
+                 optimizer = torch.optim.AdamW
+
+         return optimizer, use_lion_optim
+
+     def _init_lr_scheduler(self, initial_lr: float, optimizer) -> LRScheduler:
+         if self.train_config.optim_config.enable_lr_scheduler:
+             warmup_iters = self.train_config.optim_config.warmup_iters
+             min_lr = self.train_config.optim_config.min_lr
+             max_lr = self.train_config.optim_config.max_lr
+             cosine_annealing_period = self.train_config.optim_config.cosine_annealing_period
+             cosine_annealing_period_mul = self.train_config.optim_config.cosine_annealing_period_mul
+
+             return WarmupCosineAnnealingLRScheduler(
+                 optimizer=optimizer,
+                 warmup_iters=warmup_iters,
+                 initial_lr=initial_lr,
+                 min_lr=min_lr,
+                 max_lr=max_lr,
+                 cosine_annealing_period=cosine_annealing_period,
+                 cosine_annealing_period_mul=cosine_annealing_period_mul,
+                 need_log=TrainerTools().parallel.is_main_process
+             )
+
+         return NoneLRScheduler(initial_lr)
+
+     def _init_loss(self):
+         critical_tokens: Optional[List[int]] = None
+         critical_alpha: float = 1.0
+         if self.train_config.loss_config.critical_tokens:
+             critical_tokens = self.train_config.loss_config.critical_tokens
+             critical_alpha = self.train_config.loss_config.critical_alpha
+
+         criterion = LMLoss(
+             critical_tokens=critical_tokens,
+             critical_alpha=critical_alpha,
+             vocab_size=TrainerTools().tokenizer.vocab_size
+         )
+
+         kd_loss = KDLoss() if self.kd_config else None
+
+         return criterion, kd_loss
+
+     def _load_train_model_checkpoint(self):
+         load_checkpoint(
+             self.train_model,
+             optimizer=self.optimizer,
+             device=TrainerTools().parallel.device
+         )
+
+     def _apply_restore_ckpt(self):
+         steps_dict = load_steps()
+         if steps_dict:
+             self.resume_epoch = steps_dict.get('epoch', 0)
+             self.resume_file_idx = steps_dict.get('file_idx', 0)
+             self.resume_batch_idx = steps_dict.get('batch_idx', 0)
+
+             self.lr_scheduler.restore_ckpt_dict(steps_dict)
+
+             if TrainerTools().parallel.is_main_process:
+                 Logger.std_log(f'restore steps_dict={steps_dict}')
+
+     def _convert_train_args(self) -> Tuple[Optional[dict], dict, dict]:
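+         # builds three kwarg dicts: DeepSpeed engine kwargs, DataLoader kwargs,
+         # and sampler kwargs. For reference, a ZeRO stage-2 config with CPU
+         # optimizer offload and bf16 yields engine kwargs shaped roughly like
+         # (values illustrative, not defaults):
+         #   {'gradient_accumulation_steps': 1, 'gradient_clipping': 1.0,
+         #    'train_micro_batch_size_per_gpu': 8,
+         #    'zero_optimization': {'stage': 2,
+         #                          'offload_optimizer': {'device': 'cpu', 'pin_memory': True}},
+         #    'bf16': {'enabled': True}}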
+         parallel_kwargs: Optional[Dict[str, Any]] = None
+         if isinstance(TrainerTools().parallel, DsParallel) and self.train_config.ds_config:
+             parallel_kwargs = {
+                 'gradient_accumulation_steps': 1,
+                 'gradient_clipping': self.train_config.ds_config.gradient_clipping,
+                 'train_micro_batch_size_per_gpu': self.train_config.batch_size
+             }
+
+             if self.train_config.ds_config.zero_config:
+                 zero_config = self.train_config.ds_config.zero_config
+                 zero_optimization: Dict[str, Any] = {'stage': zero_config.stage}
+
+                 if zero_config.allgather_partitions is not None:
+                     zero_optimization['allgather_partitions'] = zero_config.allgather_partitions
+                 if zero_config.allgather_bucket_size is not None:
+                     zero_optimization['allgather_bucket_size'] = zero_config.allgather_bucket_size
+                 if zero_config.overlap_comm is not None:
+                     zero_optimization['overlap_comm'] = zero_config.overlap_comm
+                 if zero_config.reduce_scatter is not None:
+                     zero_optimization['reduce_scatter'] = zero_config.reduce_scatter
+                 if zero_config.reduce_bucket_size is not None:
+                     zero_optimization['reduce_bucket_size'] = zero_config.reduce_bucket_size
+                 if zero_config.contiguous_gradients is not None:
+                     zero_optimization['contiguous_gradients'] = zero_config.contiguous_gradients
+
+                 if isinstance(zero_config, (DsZero2Config, DsZero3Config)):
+                     if zero_config.offload_optimizer is not None:
+                         zero_optimization['offload_optimizer'] = {
+                             "device": zero_config.offload_optimizer.device,
+                             "pin_memory": zero_config.offload_optimizer.pin_memory
+                         }
+                     if zero_config.offload_param is not None:
+                         zero_optimization['offload_param'] = {
+                             "device": zero_config.offload_param.device,
+                             "pin_memory": zero_config.offload_param.pin_memory
+                         }
+
+                 if isinstance(zero_config, DsZero3Config):
+                     if zero_config.sub_group_size is not None:
+                         zero_optimization['sub_group_size'] = zero_config.sub_group_size
+                     if zero_config.stage3_prefetch_bucket_size is not None:
+                         zero_optimization['stage3_prefetch_bucket_size'] = zero_config.stage3_prefetch_bucket_size
+                     if zero_config.stage3_param_persistence_threshold is not None:
+                         zero_optimization['stage3_param_persistence_threshold'] = zero_config.stage3_param_persistence_threshold
+                     if zero_config.stage3_max_live_parameters is not None:
+                         zero_optimization['stage3_max_live_parameters'] = zero_config.stage3_max_live_parameters
+                     if zero_config.stage3_max_reuse_distance is not None:
+                         zero_optimization['stage3_max_reuse_distance'] = zero_config.stage3_max_reuse_distance
+                     if zero_config.stage3_gather_16bit_weights_on_model_save is not None:
+                         zero_optimization['stage3_gather_16bit_weights_on_model_save'] = zero_config.stage3_gather_16bit_weights_on_model_save
+
+                 parallel_kwargs['zero_optimization'] = zero_optimization
+
+             if (self.train_config.ds_config.bf16_config is not None
+                     and self.train_config.ds_config.bf16_config.enabled):
+                 bf16_config = self.train_config.ds_config.bf16_config
+                 bf16 = {
+                     'enabled': bf16_config.enabled
+                 }
+                 parallel_kwargs['bf16'] = bf16
+             elif self.train_config.ds_config.fp16_config:
+                 fp16_config = self.train_config.ds_config.fp16_config
+                 fp16 = {
+                     'enabled': fp16_config.enabled,
+                     'loss_scale': fp16_config.loss_scale,
+                     'loss_scale_window': fp16_config.loss_scale_window,
+                     'initial_scale_power': fp16_config.initial_scale_power,
+                     'hysteresis': fp16_config.hysteresis,
+                     'min_loss_scale': fp16_config.min_loss_scale
+                 }
+
+                 if fp16_config.fp16_opt_level is not None:
+                     fp16['fp16_opt_level'] = fp16_config.fp16_opt_level
+
+                 parallel_kwargs['fp16'] = fp16
+
+             if self.train_config.ds_config.activation_checkpointing:
+                 activation_checkpointing_config = self.train_config.ds_config.activation_checkpointing
+                 activation_checkpointing: Dict[str, Any] = {
+                     'partition_activations': activation_checkpointing_config.partition_activations,
+                     'cpu_checkpointing': activation_checkpointing_config.cpu_checkpointing,
+                     'contiguous_memory_optimization': activation_checkpointing_config.contiguous_memory_optimization,
+                     'synchronize_checkpoint_boundary': activation_checkpointing_config.synchronize_checkpoint_boundary,
+                     'profile': activation_checkpointing_config.profile
+                 }
+
+                 if activation_checkpointing_config.number_checkpoints is not None:
+                     activation_checkpointing['number_checkpoints'] = activation_checkpointing_config.number_checkpoints
+
+                 parallel_kwargs['activation_checkpointing'] = activation_checkpointing
+
+         dataloader_args = self.train_config.data_loader_config
+         data_loader_kwargs = {
+             "batch_size": self.train_config.batch_size,
+             "pin_memory": dataloader_args.data_loader_pin_memory,
+             "num_workers": dataloader_args.data_loader_num_workers,
+             "shuffle": dataloader_args.data_loader_shuffle,
+             "drop_last": dataloader_args.data_loader_drop_last,
+         }
+         sampler_kwargs = {
+             "shuffle": dataloader_args.data_loader_shuffle,
+             "drop_last": dataloader_args.data_loader_drop_last,
+         }
+
+         return parallel_kwargs, data_loader_kwargs, sampler_kwargs
+
+     def _init_ref_model_args(self) -> Optional[dict]:
+         parallel_kwargs = copy.deepcopy(self.parallel_kwargs) if self.parallel_kwargs else None
+
+         if parallel_kwargs and isinstance(TrainerTools().parallel, DsParallel):
+             # see prepare_deepspeed in https://github.com/huggingface/trl/blob/main/trl/models/utils.py
+             # if model is not None:
+             #     hidden_size = (
+             #         max(model.config.hidden_sizes)
+             #         if getattr(model.config, "hidden_sizes", None)
+             #         else getattr(model.config, "hidden_size", None)
+             #     )
+             #     if hidden_size is not None and stage == 3:
+             #         # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache
+             #         # @ step 0: expected module 1, but got module 0`
+             #         # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
+             #         config_kwargs.update(
+             #             {
+             #                 "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
+             #                 "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
+             #                 "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
+             #             }
+             #         )
+
+             parallel_kwargs.pop('activation_checkpointing', None)
+             parallel_kwargs.pop('gradient_clipping', None)
+
+             # use ZeRO stage 0 for the ref model for now, to work around training hangs
+             parallel_kwargs["zero_optimization"] = {"stage": 0}
+             # if parallel_kwargs.get("zero_optimization", {}).get("stage", 0) != 3:
+             #     parallel_kwargs["zero_optimization"] = {"stage": 0}
+
+         return parallel_kwargs
+
+     def _create_dataset(self, file_idx) -> Tuple[Dataset, str]: ...
+
+     def _calc_loss(self, inputs, attention_mask, logits, labels):
+         # calc loss
+         if not self.kd_loss or self.kd_config.kd_coef == 0.0:
+             # no need to compute the kd loss
+             return self.criterion(logits, labels)
+
+         teacher_logits = self.kd_config.teacher_logits_provider(inputs, attention_mask)
+         loss = self.kd_loss(logits, teacher_logits, labels)
+
+         if self.kd_config.kd_coef == 1.0:
+             # no need to compute the ce loss
+             return loss
+
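+         # blended objective: (1 - kd_coef) * CE + kd_coef * KD, with kd_coef in [0, 1]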
+         ce_loss = self.criterion(logits, labels)
+         return (1 - self.kd_config.kd_coef) * ce_loss + self.kd_config.kd_coef * loss
+
+     def _backward_loss(self, loss):
+         if isinstance(TrainerTools().parallel, DsParallel):
+             self.train_model.backward(loss)
+         else:
+             self.scaler.scale(loss).backward()
+
+     def _apply_grad_clipping(self):
+         # DeepSpeed mode already applies gradient_clipping itself
+         if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
+             # clip grad
+             self.scaler.unscale_(self.optimizer)
+
+             trainable_params = filter(lambda p: p.requires_grad, self.train_model.parameters())
+             torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
+
+     def _apply_step(self):
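+         # non-DeepSpeed AMP sequence, for reference: scaler.scale(loss).backward()
+         # (in _backward_loss) -> scaler.unscale_ + clip (in _apply_grad_clipping)
+         # -> scaler.step -> scaler.update; unscale_ must precede clipping so the
+         # norm is computed on true gradients.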
+         self.lr_scheduler.step()
+         if isinstance(TrainerTools().parallel, DsParallel):
+             self.train_model.step()
+         else:
+             self.scaler.step(self.optimizer)
+             self.scaler.update()
+             self.optimizer.zero_grad(set_to_none=True)
+
+         TrainerTools().parallel.synchronize()
+
+     def _get_eval_data(self) -> Optional[str]:
+         if len(self.eval_prompts) == 0:
+             return None
+
+         # round-robin through the eval prompts
+         self.eval_idx += 1
+         if self.eval_idx == len(self.eval_prompts):
+             self.eval_idx = 0
+
+         return self.eval_prompts[self.eval_idx]
+
+     def _get_eval_pixel_values_and_tokens_count(self, eval_idx):
+         return None, None
+
+     def _log(self, keys: Dict[str, Any], values: Dict[str, Any]):
+         """
+         Format: keys_key1: keys_value1, keys_key2: keys_value2 -> values_key1: values_value1, values_key2: values_value2
+         """
+         if TrainerTools().parallel.is_main_process:
+             log_tags = ', '.join([f'{k}: {v}' for k, v in keys.items()])
+             log_values = ', '.join([f'{k}: {v}' for k, v in values.items()])
+
+             log_msg = f'{log_tags} -> {log_values}'
+             self.logger.log(log_msg)
+
+     def _on_exception(
+         self,
+         e: Exception,
+         epoch: int,
+         batch: int
+     ):
+         exception_file = e.__traceback__.tb_frame.f_globals["__file__"]
+         exception_line = e.__traceback__.tb_lineno
+         log_msg = f"epoch: {epoch}, batch: {batch} -> {e} at {exception_file} line {exception_line}"
+         Logger('exception.txt').log(log_msg, log_to_console=False).release()
+
+         raise e
+
+     def _get_model_dtype(self):
+         if isinstance(TrainerTools().parallel, DsParallel):
+             import deepspeed
+             assert isinstance(self.train_model, deepspeed.DeepSpeedEngine)
+             return self.train_model.get_data_types()[0]
+         else:
+             return torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
+
+     def _eval(self, tag: str):
+         with unwrap_model_for_generation(self.train_model) as eval_model:
+             if TrainerTools().parallel.is_main_process:
+                 eval_prompt = self._get_eval_data()
+
+                 if eval_prompt:
+                     eval_model = self._check_eval_model(eval_model)
+                     eval_model.eval()
+
+                     eval_pixel_values, tokens_per_image = self._get_eval_pixel_values_and_tokens_count(self.eval_idx)
+                     submit_gen_task(
+                         eval_model,
+                         self.train_config,
+                         tag=tag,
+                         prompt=eval_prompt,
+                         pixel_values=eval_pixel_values,
+                         tokens_per_image=tokens_per_image
+                     )
+
+                     eval_model.train()
+
+         TrainerTools().parallel.wait('eval')
+
+     def _check_eval_model(self, eval_model):
+         return eval_model
+
+     def _on_batch_end(self, tag: str):
+         self._eval(f'sign:batch/{tag}')
+
+     def _on_epoch_end(self, tag: str):
+         self._eval(f'sign:epoch/{tag}')
+
+     def _on_file_start(
+         self,
+         epoch: int,
+         file_name: str
+     ):
+         if TrainerTools().parallel.is_main_process:
+             self.logger.log(f"====epoch: {epoch}, start train {file_name}====", log_to_console=False)
+
+     def _avg_loss(
+         self,
+         losses: List[float],
+         gradient_accumulation_steps,
+         batches_accumulated
+     ) -> List[float]:
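+         # each accumulated micro-batch loss was pre-divided by
+         # gradient_accumulation_steps, so multiply it back and divide by the
+         # number of micro-batches actually accumulated (the last window of a
+         # file may be shorter) to recover the mean micro-batch loss, then
+         # average across ranks.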
+         loss_tensors = [
+             torch.tensor(loss * gradient_accumulation_steps / batches_accumulated,
+                          device=TrainerTools().parallel.device)
+             for loss in losses
+         ]
+
+         stacked_losses = torch.stack(loss_tensors)
+         if TrainerTools().parallel.parallel_train:
+             dist.all_reduce(stacked_losses, dist.ReduceOp.AVG)
+
+         return stacked_losses.detach().cpu().tolist()
+
+     def _get_pixel_values(self, batch_data):
+         return None
+
+     def train(self):
+         # number of gradient accumulation steps
+         gradient_accumulation_steps = max(1, self.gradient_accumulation_steps)
+
+         loss_accumulation = 0.0
+         aux_loss_accumulation = 0.0
+         batches_accumulated = 0
+
+         for epoch in range(self.resume_epoch, self.train_config.n_epochs):
+             self.train_model.train()
+             file_count = len(self.train_config.file_dataset)
+             start_file_idx = self.resume_file_idx if epoch == self.resume_epoch else 0
+
+             for file_idx in range(start_file_idx, file_count):
+                 dataset, file_path = self._create_dataset(file_idx)
+                 train_data_loader = TrainerTools().parallel.process_dataloader(
+                     dataset=dataset,
+                     data_loader_kwargs=self.data_loader_kwargs,
+                     sampler_kwargs=self.sampler_kwargs
+                 )
+
+                 last_ckpt_batch = 0
+                 batch_count_per_file = len(train_data_loader)
+
+                 TrainerTools().parallel.on_epoch_start(epoch)
+                 self._on_file_start(epoch, file_path)
+
+                 skip_batches = 0
+                 if epoch == self.resume_epoch and file_idx == self.resume_file_idx:
+                     skip_batches = self.resume_batch_idx
+                     if skip_batches > 0 and TrainerTools().parallel.is_main_process:
+                         Logger.std_log(f"Fast forwarding {skip_batches} batches in {file_path}...")
+
+                 data_iterator = iter(train_data_loader)
+
+                 if skip_batches > 0:
+                     data_iterator = islice(data_iterator, skip_batches, None)
+                     last_ckpt_batch = skip_batches
+
+                 for batch, batch_data in enumerate(data_iterator):
+                     batch = skip_batches + batch
+
+                     # do we need to apply a gradient update on this step?
+                     if gradient_accumulation_steps > 1:
+                         need_update_grad = (batch + 1) % gradient_accumulation_steps == 0 or batch == batch_count_per_file - 1
+                     else:
+                         need_update_grad = True
+
+                     inputs = batch_data['inputs']
+                     labels = batch_data['labels']
+
+                     try:
+                         inputs, labels = inputs.to(TrainerTools().parallel.device), labels.to(TrainerTools().parallel.device)
+                         attention_mask = inputs != TrainerTools().tokenizer.pad
+                         pixel_values = self._get_pixel_values(batch_data)
+
+                         if TrainerTools().parallel.parallel_train:
+                             self.train_model.require_backward_grad_sync = need_update_grad
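+                             # (DDP: setting this False skips the inter-rank gradient
+                             # all-reduce on pure accumulation steps; it syncs only when
+                             # the micro-batch completes an accumulation window.)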
+
+                         with autocast(TrainerTools().parallel.device_type):
+                             result = self.train_model(
+                                 inputs,
+                                 attention_mask=attention_mask,
+                                 pixel_values=pixel_values
+                             )
+
+                         # calc loss
+                         loss = self._calc_loss(inputs, attention_mask, result['logits'], labels)
+                         if result['aux_loss'] and self.train_config.loss_config.aux_loss_coef:
+                             aux_loss = self.train_config.loss_config.aux_loss_coef * result['aux_loss']
+                         else:
+                             aux_loss = torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
+
+                         if gradient_accumulation_steps > 1:
+                             loss = loss / gradient_accumulation_steps
+                             aux_loss = aux_loss / gradient_accumulation_steps
+
+                         total_loss = loss + aux_loss
+                         self._backward_loss(total_loss)
+
+                         loss_accumulation += total_loss.detach().item()
+                         aux_loss_accumulation += aux_loss.detach().item()
+
+                         batches_accumulated += 1
+
+                         if need_update_grad:
+                             self._apply_grad_clipping()
+                             self._apply_step()
+
+                             avg_loss, avg_aux_loss = self._avg_loss(
+                                 losses=[
+                                     loss_accumulation,
+                                     aux_loss_accumulation
+                                 ],
+                                 gradient_accumulation_steps=gradient_accumulation_steps,
+                                 batches_accumulated=batches_accumulated
+                             )
+
+                             self._log(
+                                 keys={
+                                     'epoch': epoch,
+                                     'file': f'{file_idx + 1}/{file_count}',
+                                     'batch': f'{batch + 1}/{batch_count_per_file}'
+                                 },
+                                 values={
+                                     'loss': avg_loss,
+                                     'moe_aux_loss': avg_aux_loss
+                                 }
+                             )
+
+                             # reset to default
+                             loss_accumulation = 0.0
+                             aux_loss_accumulation = 0.0
+                             batches_accumulated = 0
+
+                             if (batch - last_ckpt_batch) >= self.train_config.eval_config.eval_batch_interval:
+                                 save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+                                 save_steps(
+                                     epoch=epoch,
+                                     file_idx=file_idx,
+                                     batch_idx=batch + 1,
+                                     lr_scheduler=self.lr_scheduler
+                                 )
+
+                                 last_ckpt_batch = batch
+                                 self._on_batch_end(tag=f'epoch:{epoch}/batch:{batch}')
+                     except Exception as e:
+                         self._on_exception(e, epoch, batch)
+
+                 # after finishing one file, release memory
+                 del train_data_loader
+                 del dataset
+                 if hasattr(TrainerTools().parallel, '_sampler'):
+                     TrainerTools().parallel._sampler = None
+
+                 gc.collect()
+                 torch.cuda.empty_cache()
+
+             # end epoch
+
+             # reset resume state
+             self.resume_file_idx = 0
+             self.resume_batch_idx = 0
+
+             save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+             save_steps(
+                 epoch=epoch + 1,
+                 file_idx=0,
+                 batch_idx=0,
+                 lr_scheduler=self.lr_scheduler
+             )
+
+             TrainerTools().parallel.on_epoch_end(epoch)
+             self._on_epoch_end(tag=f'epoch:{epoch}')
+
+         TrainerTools().parallel.destroy()
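+
+ # Typical usage is to subclass BaseTrainer and implement _create_dataset.
+ # A minimal sketch (illustrative; `PretrainDataset` and the config values are
+ # assumptions, not part of this file):
+ #
+ #     class PretrainTrainer(BaseTrainer):
+ #         def _create_dataset(self, file_idx):
+ #             file_path = self.train_config.file_dataset[file_idx]
+ #             return PretrainDataset(file_path), file_path
+ #
+ #     trainer = PretrainTrainer(train_config=train_config, eval_prompts=['hello'])
+ #     trainer.train()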