project-llm-trainer 0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of project-llm-trainer has been flagged as potentially problematic.

llm_trainer/trainer.py ADDED
@@ -0,0 +1,569 @@
+ import time
+ from contextlib import nullcontext
+ from typing import Optional, Tuple, List, Dict, Any
+
+ import torch
+ from torch import nn
+ import torch.distributed as dist
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+ from torch.utils.data import Dataset
+ from llm_model import LlmModel, VlmModel
+
+ from .parallel_ds import DsParallel
+ from .parallel_fsdp import FsdpParallel
+ from .tools import TrainerTools
+ from .loss import LMLoss, KDLoss
+ from .dataset import TextDataset
+
+ from .train_configs import (
+     TrainConfig,
+     VLMConfig,
+     DsZero2Config,
+     DsZero3Config
+ )
+
+ from .scheduler import (
+     LRScheduler,
+     WarmupCosineAnnealingLRScheduler,
+     NoneLRScheduler
+ )
+
+ from .checkpoint import (
+     load_checkpoint,
+     save_checkpoint,
+     load_steps,
+     save_steps,
+ )
+ from .utils import (
+     set_seed,
+     pretrain_collate_fn,
+ )
+
+ from .log import (
+     log,
+     get_log_dir
+ )
+
+ from .eval import submit_gen_task
+
+ class Trainer:
+     def __init__(
+         self,
+         *,
+         train_config: TrainConfig,
+         eval_prompts: List[str],
+         eval_image_tags: Optional[List[int]] = None
+     ):
+         set_seed()
+
+         self.train_config: TrainConfig = train_config
+         self.eval_prompts = eval_prompts
+         self.eval_image_tags = eval_image_tags
+         self.eval_idx = -1
+
+         if self.eval_image_tags:
+             assert len(self.eval_prompts) == len(self.eval_image_tags)
+
+         parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim = self._convert_train_args()
+         self.data_loader_kwargs: Dict[str, Any] = data_loader_kwargs
+         self.sampler_kwargs: Dict[str, Any] = sampler_kwargs
+
+         # Initialize a GradScaler. If enabled=False, the scaler is a no-op.
+         self.scaler = torch.GradScaler(enabled=TrainerTools().use_amp)
+
+         # Note: the learning rate should be scaled with the number of GPUs.
+         # During training, the loss gradient sets the descent direction and the
+         # learning rate sets the step size; with two GPUs the effective combined
+         # step per update is roughly the average learning rate * 2.
+         initial_lr = train_config.lr_config.initial_lr
+
+         self.train_model, self.optimizer = self._init_train_model_and_optim(initial_lr, parallel_kwargs, use_ds_optim)
+         self.lr_scheduler = self._init_lr_scheduler(initial_lr)
+         self.eval_model: Optional[nn.Module] = self._init_eval_model()
+
+         self.criterion, self.kd_loss = self._init_loss()
+
+         self.ctx = torch.autocast(
+             device_type=TrainerTools().parallel.device_type,
+             dtype=TrainerTools().dtype,
+             enabled=True,
+             # In FSDP mode, cache_enabled must be set to False:
+             # https://www.zhihu.com/question/642793891
+             cache_enabled=False if isinstance(self.train_model, FSDP) else None
+         ) if TrainerTools().use_amp else nullcontext()
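+         # For reference (a sketch, not additional logic): with AMP enabled,
+         # this class follows the standard torch.amp recipe, spread across
+         # _backward_loss(), _step() and train() below:
+         #     with autocast(): out = model(x)        # mixed-precision forward
+         #     scaler.scale(loss).backward()          # scaled backward
+         #     scaler.unscale_(optimizer)             # so grads can be clipped
+         #     clip_grad_norm_(model.parameters(), 1.0)
+         #     scaler.step(optimizer)                 # skipped on inf/nan grads
+         #     scaler.update()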
+
+         load_checkpoint(
+             self.train_model,
+             optimizer=self.optimizer,
+             device=TrainerTools().parallel.device
+         )
+
+         last_global_steps, last_lr_steps = load_steps(0, -1)
+         self.last_global_steps = last_global_steps
+         log(f'last_global_steps={last_global_steps}, last_lr_steps={last_lr_steps}')
+
+         if last_lr_steps != -1:
+             self.lr_scheduler.update_steps(last_lr_steps)
+
+         if isinstance(train_config.model_config, VLMConfig):
+             self.pixel_values_provider = train_config.pixel_values_provider
+             self.tokens_per_image = train_config.model_config.tokens_per_image
+         else:
+             self.pixel_values_provider = None
+             self.tokens_per_image = -1
+
+     def _init_train_model_and_optim(
+         self,
+         initial_lr: float,
+         parallel_kwargs: dict,
+         use_ds_optim: bool
+     ):
+         if isinstance(self.train_config.model_config, VLMConfig):
+             model = VlmModel(self.train_config.model_config)
+         else:
+             model = LlmModel(self.train_config.model_config)
+
+         if self.train_config.init_state_dict:
+             model.load_state_dict(self.train_config.init_state_dict, strict=False)
+             self.train_config.init_state_dict = None
+
+         if TrainerTools().parallel.is_main_process:
+             total_params = sum(p.numel() for p in model.parameters())
+             log(f"Total number of parameters: {total_params:,}")
+
+             # Assumes fp32 parameters (4 bytes each).
+             total_size_bytes = total_params * 4
+             total_size_mb = total_size_bytes / (1024 * 1024)
+             log(f"Total size of the model: {total_size_mb:.2f} MB")
+
+         if use_ds_optim:
+             import deepspeed
+             origin_optim = deepspeed.ops.adam.DeepSpeedCPUAdam(
+                 model.parameters(),
+                 lr=initial_lr,
+                 weight_decay=self.train_config.lr_config.weight_decay
+             )
+         else:
+             origin_optim = torch.optim.AdamW(
+                 model.parameters(),
+                 lr=initial_lr,
+                 weight_decay=self.train_config.lr_config.weight_decay
+             )
+         model, optim = TrainerTools().parallel.process(
+             model=model,
+             optimizer=origin_optim,
+             kwargs=parallel_kwargs
+         )
+
+         return model, optim
+
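+     # Sizing note for the fp32 estimate above: a 1B-parameter model is
+     # 1e9 * 4 / 1024**2, roughly 3815 MB; fp16/bf16 weights would halve that.
+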
+     def _init_eval_model(self) -> Optional[nn.Module]:
+         if TrainerTools().parallel.is_main_process:
+             if isinstance(self.train_config.model_config, VLMConfig):
+                 return VlmModel(self.train_config.model_config).to('cpu')
+             else:
+                 return LlmModel(self.train_config.model_config).to('cpu')
+
+         return None
+
+     def _init_lr_scheduler(self, initial_lr: float) -> LRScheduler:
+         if self.train_config.lr_config.enable_lr_scheduler:
+             min_lr = self.train_config.lr_config.min_lr
+             max_lr = self.train_config.lr_config.max_lr
+             warmup_iters = self.train_config.lr_config.warmup_iters
+             period = self.train_config.lr_config.period
+             period_mul = self.train_config.lr_config.period_mul
+
+             return WarmupCosineAnnealingLRScheduler(
+                 optimizer=self.optimizer,
+                 initial_lr=initial_lr,
+                 min_lr=min_lr,
+                 max_lr=max_lr,
+                 warmup_iters=warmup_iters,
+                 period=period,
+                 period_mul=period_mul,
+                 need_log=TrainerTools().parallel.is_main_process
+             )
+
+         return NoneLRScheduler(initial_lr)
+
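+     # Sketch (assumed; the scheduler body is not part of this diff): warmup +
+     # cosine annealing with restarts typically ramps linearly to max_lr over
+     # warmup_iters, then follows
+     #     lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(pi * t_cur / period))
+     # where t_cur counts steps within the current period, and the period is
+     # multiplied by period_mul after each restart.
+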
+     def _init_loss(self):
+         critical_tokens: Optional[List[int]] = None
+         critical_alpha: float = 1.0
+         if self.train_config.loss_config.critical_tokens:
+             critical_tokens = self.train_config.loss_config.critical_tokens
+             critical_alpha = self.train_config.loss_config.critical_alpha
+
+         criterion = LMLoss(
+             critical_tokens=critical_tokens,
+             critical_alpha=critical_alpha,
+             vocab_size=TrainerTools().tokenizer.vocab_size
+         )
+
+         kd_loss = KDLoss() if self.train_config.kd_config else None
+
+         return criterion, kd_loss
+
+     def _convert_train_args(self) -> Tuple[Optional[dict], dict, dict, bool]:
+         parallel_kwargs: Optional[Dict[str, Any]] = None
+         use_ds_optim: bool = False
+         if isinstance(TrainerTools().parallel, DsParallel) and self.train_config.ds_config:
+             parallel_kwargs = {
+                 'gradient_accumulation_steps': 1,
+                 'gradient_clipping': self.train_config.ds_config.gradient_clipping,
+                 'train_micro_batch_size_per_gpu': self.train_config.batch_size
+             }
+
+             if self.train_config.ds_config.zero_config:
+                 zero_config = self.train_config.ds_config.zero_config
+                 zero_optimization: Dict[str, Any] = {'stage': zero_config.stage}
+
+                 if zero_config.allgather_partitions is not None:
+                     zero_optimization['allgather_partitions'] = zero_config.allgather_partitions
+                 if zero_config.allgather_bucket_size is not None:
+                     zero_optimization['allgather_bucket_size'] = zero_config.allgather_bucket_size
+                 if zero_config.overlap_comm is not None:
+                     zero_optimization['overlap_comm'] = zero_config.overlap_comm
+                 if zero_config.reduce_scatter is not None:
+                     zero_optimization['reduce_scatter'] = zero_config.reduce_scatter
+                 if zero_config.reduce_bucket_size is not None:
+                     zero_optimization['reduce_bucket_size'] = zero_config.reduce_bucket_size
+                 if zero_config.contiguous_gradients is not None:
+                     zero_optimization['contiguous_gradients'] = zero_config.contiguous_gradients
+
+                 if isinstance(zero_config, (DsZero2Config, DsZero3Config)):
+                     if zero_config.offload_optimizer is not None:
+                         zero_optimization['offload_optimizer'] = {
+                             "device": zero_config.offload_optimizer.device,
+                             "pin_memory": zero_config.offload_optimizer.pin_memory
+                         }
+                         # CPU-offloaded optimizer states require DeepSpeedCPUAdam.
+                         use_ds_optim = True
+                     if zero_config.offload_param is not None:
+                         zero_optimization['offload_param'] = {
+                             "device": zero_config.offload_param.device,
+                             "pin_memory": zero_config.offload_param.pin_memory
+                         }
+
+                 if isinstance(zero_config, DsZero3Config):
+                     if zero_config.sub_group_size is not None:
+                         zero_optimization['sub_group_size'] = zero_config.sub_group_size
+                     if zero_config.stage3_prefetch_bucket_size is not None:
+                         zero_optimization['stage3_prefetch_bucket_size'] = zero_config.stage3_prefetch_bucket_size
+                     if zero_config.stage3_param_persistence_threshold is not None:
+                         zero_optimization['stage3_param_persistence_threshold'] = zero_config.stage3_param_persistence_threshold
+                     if zero_config.stage3_max_live_parameters is not None:
+                         zero_optimization['stage3_max_live_parameters'] = zero_config.stage3_max_live_parameters
+                     if zero_config.stage3_max_reuse_distance is not None:
+                         zero_optimization['stage3_max_reuse_distance'] = zero_config.stage3_max_reuse_distance
+                     if zero_config.stage3_gather_16bit_weights_on_model_save is not None:
+                         zero_optimization['stage3_gather_16bit_weights_on_model_save'] = zero_config.stage3_gather_16bit_weights_on_model_save
+
+                 parallel_kwargs['zero_optimization'] = zero_optimization
+
+             if (self.train_config.ds_config.bf16_config is not None
+                     and self.train_config.ds_config.bf16_config.enabled):
+                 bf16_config = self.train_config.ds_config.bf16_config
+                 bf16 = {
+                     'enabled': bf16_config.enabled
+                 }
+                 parallel_kwargs['bf16'] = bf16
+             elif self.train_config.ds_config.fp16_config:
+                 fp16_config = self.train_config.ds_config.fp16_config
+                 fp16 = {
+                     'enabled': fp16_config.enabled,
+                     'loss_scale': fp16_config.loss_scale,
+                     'loss_scale_window': fp16_config.loss_scale_window,
+                     'initial_scale_power': fp16_config.initial_scale_power,
+                     'hysteresis': fp16_config.hysteresis,
+                     'min_loss_scale': fp16_config.min_loss_scale
+                 }
+
+                 if fp16_config.fp16_opt_level is not None:
+                     fp16['fp16_opt_level'] = fp16_config.fp16_opt_level
+
+                 parallel_kwargs['fp16'] = fp16
+
+             if self.train_config.ds_config.activation_checkpointing:
+                 activation_checkpointing_config = self.train_config.ds_config.activation_checkpointing
+                 activation_checkpointing: Dict[str, Any] = {
+                     'partition_activations': activation_checkpointing_config.partition_activations,
+                     'cpu_checkpointing': activation_checkpointing_config.cpu_checkpointing,
+                     'contiguous_memory_optimization': activation_checkpointing_config.contiguous_memory_optimization,
+                     'synchronize_checkpoint_boundary': activation_checkpointing_config.synchronize_checkpoint_boundary,
+                     'profile': activation_checkpointing_config.profile
+                 }
+
+                 if activation_checkpointing_config.number_checkpoints is not None:
+                     activation_checkpointing['number_checkpoints'] = activation_checkpointing_config.number_checkpoints
+
+                 parallel_kwargs['activation_checkpointing'] = activation_checkpointing
+         elif isinstance(TrainerTools().parallel, FsdpParallel) and self.train_config.fsdp_config:
+             parallel_kwargs = {
+                 'transformer_layer_cls': self.train_config.fsdp_config.transformer_layer_cls,
+                 'wrap_policy_num_params': self.train_config.fsdp_config.wrap_policy_num_params,
+                 'cpu_offload': self.train_config.fsdp_config.cpu_offload,
+                 'offload_params': self.train_config.fsdp_config.offload_params
+             }
+
+         dataloader_args = self.train_config.data_loader_config
+         data_loader_kwargs = {
+             "batch_size": self.train_config.batch_size,
+             "pin_memory": dataloader_args.data_loader_pin_memory,
+             "collate_fn": pretrain_collate_fn,
+             "num_workers": dataloader_args.data_loader_num_workers,
+             "shuffle": dataloader_args.data_loader_shuffle,
+             "drop_last": dataloader_args.data_loader_drop_last,
+         }
+         sampler_kwargs = {
+             "shuffle": dataloader_args.data_loader_shuffle,
+             "drop_last": dataloader_args.data_loader_drop_last,
+         }
+
+         return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim
+
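+     # Example of the assembled DeepSpeed kwargs for a ZeRO-2 + bf16 run
+     # (illustrative values, not defaults of this package):
+     #     {
+     #         'gradient_accumulation_steps': 1,
+     #         'gradient_clipping': 1.0,
+     #         'train_micro_batch_size_per_gpu': 8,
+     #         'zero_optimization': {'stage': 2, 'overlap_comm': True,
+     #                               'offload_optimizer': {'device': 'cpu', 'pin_memory': True}},
+     #         'bf16': {'enabled': True}
+     #     }
+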
+     def _create_dataset(self, file_path) -> Dataset:
+         max_position_embeddings = self.train_config.model_config.max_position_embeddings
+         return TextDataset(file_path, max_position_embeddings, max_position_embeddings)
+
+     def _calc_loss(self, inputs, attention_mask, logits, labels):
+         # language-model loss
+         loss = self.criterion(logits, labels)
+
+         # knowledge-distillation loss
+         if self.kd_loss:
+             teacher_logits = self.train_config.kd_config.teacher_logits_provider(inputs, attention_mask)
+             distil_loss = self.kd_loss(logits, teacher_logits, labels)
+             loss = (1 - self.train_config.kd_config.kd_coef) * loss + self.train_config.kd_config.kd_coef * distil_loss
+
+         return loss
+
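+     # With kd_coef = c, the blended objective above is
+     #     loss = (1 - c) * LM_loss + c * KD_loss,
+     # where KDLoss (defined in .loss, not shown in this diff) is typically a
+     # KL divergence between the student and teacher token distributions.
+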
+     def _backward_loss(self, loss):
+         if isinstance(TrainerTools().parallel, DsParallel):
+             # DeepSpeed handles loss scaling internally.
+             self.train_model.backward(loss)
+         else:
+             self.scaler.scale(loss).backward()
+
+     def _step(self):
+         self.lr_scheduler.step()
+         if isinstance(TrainerTools().parallel, DsParallel):
+             self.train_model.step()
+         else:
+             # scaler.step() calls optimizer.step(), but skips it if the
+             # gradients contain inf/nan; scaler.update() then adjusts the scale.
+             self.scaler.step(self.optimizer)
+             self.scaler.update()
+             # flush the gradients as soon as we can, no need for this memory anymore
+             self.optimizer.zero_grad(set_to_none=True)
+
+         TrainerTools().parallel.synchronize()
+
+     def _get_eval_data(self) -> Tuple[str, Optional[int]]:
+         if len(self.eval_prompts) == 0:
+             return '', None
+
+         # round-robin over the eval prompts
+         self.eval_idx += 1
+         if self.eval_idx == len(self.eval_prompts):
+             self.eval_idx = 0
+
+         if not self.eval_image_tags:
+             return self.eval_prompts[self.eval_idx], None
+
+         return self.eval_prompts[self.eval_idx], self.eval_image_tags[self.eval_idx]
+
+     def _log_loss(
+         self,
+         epoch_tag: str,
+         file_tag: str,
+         batch_tag: str,
+         loss
+     ):
+         if TrainerTools().parallel.is_main_process:
+             log_dir = get_log_dir()
+             log_msg = f"{epoch_tag}, {file_tag}, {batch_tag}, loss: {loss}"
+             log(log_msg)
+             log(f"{log_msg}\n", f'{log_dir}log.txt')
+
+     def _on_exception(
+         self,
+         e: Exception,
+         epoch: int,
+         batch: int
+     ):
+         log_dir = get_log_dir()
+         exception_file = e.__traceback__.tb_frame.f_globals["__file__"]
+         exception_line = e.__traceback__.tb_lineno
+         log_msg = f"epoch: {epoch}, batch: {batch}, {e} at {exception_file} line {exception_line}\n"
+         log(log_msg, f'{log_dir}log.txt')
+
+         raise e
+
+     def _on_batch_end(
+         self,
+         tag: str
+     ):
+         if TrainerTools().parallel.is_main_process:
+             eval_prompt, eval_image_tag = self._get_eval_data()
+             if isinstance(self.train_config.model_config, VLMConfig) and eval_image_tag:
+                 eval_pixel_values = self.pixel_values_provider([eval_image_tag])
+             else:
+                 eval_pixel_values = None
+
+             submit_gen_task(
+                 self.eval_model,
+                 self.train_config.eval_config,
+                 tag=f'sign:batch/{tag}',
+                 prompt=eval_prompt,
+                 pixel_values=eval_pixel_values,
+                 max_position_embeddings=self.train_config.model_config.max_position_embeddings,
+                 tokens_per_image=self.tokens_per_image
+             )
+         TrainerTools().parallel.wait()
+
+     def _on_epoch_end(
+         self,
+         tag: str
+     ):
+         if TrainerTools().parallel.is_main_process:
+             eval_prompt, eval_image_tag = self._get_eval_data()
+             if isinstance(self.train_config.model_config, VLMConfig) and eval_image_tag:
+                 eval_pixel_values = self.pixel_values_provider([eval_image_tag])
+             else:
+                 eval_pixel_values = None
+
+             submit_gen_task(
+                 self.eval_model,
+                 self.train_config.eval_config,
+                 tag=f'sign:epoch/{tag}',
+                 prompt=eval_prompt,
+                 pixel_values=eval_pixel_values,
+                 max_position_embeddings=self.train_config.model_config.max_position_embeddings,
+                 tokens_per_image=self.tokens_per_image
+             )
+
+         TrainerTools().parallel.wait()
+
+     def _on_file_start(
+         self,
+         epoch: int,
+         file_name: str
+     ):
+         if TrainerTools().parallel.is_main_process:
+             log(f"epoch: {epoch}, start train {file_name}\n", f'{get_log_dir()}log.txt')
+
+     def train(self):
+         # number of gradient-accumulation steps
+         gradient_accumulation_steps = self.train_config.gradient_accumulation_steps
+         global_steps = 0
+         loss_accumulation = 0.0
+         skipping_train = False
+
+         for epoch in range(self.train_config.n_epochs):
+             self.train_model.train()
+             file_count = len(self.train_config.file_dataset)
+
+             for file_idx in range(file_count):
+                 file_path = self.train_config.file_dataset[file_idx]
+
+                 dataset = self._create_dataset(file_path)
+                 train_data_loader = TrainerTools().parallel.process_dataloader(
+                     dataset=dataset,
+                     data_loader_kwargs=self.data_loader_kwargs,
+                     sampler_kwargs=self.sampler_kwargs
+                 )
+
+                 last_ckpt_batch = 0
+                 batch_count_per_file = len(train_data_loader)
+
+                 TrainerTools().parallel.on_epoch_start(epoch)
+                 self._on_file_start(epoch, file_path)
+
+                 for batch, batch_data in enumerate(train_data_loader):
+                     global_steps += 1
+                     # fast-forward through batches already covered by a resumed run
+                     if global_steps < self.last_global_steps:
+                         skipping_train = True
+                         continue
+
+                     skipping_train = False
+
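+                     # With gradient accumulation, the optimizer only steps every
+                     # gradient_accumulation_steps micro-batches (or at the end of
+                     # the file), so the effective batch size is roughly
+                     #     batch_size * gradient_accumulation_steps * world_size.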
+                     # whether to apply the gradient update on this micro-batch
+                     if gradient_accumulation_steps > 1:
+                         need_update_grad = (batch + 1) % gradient_accumulation_steps == 0 or batch == batch_count_per_file - 1
+                     else:
+                         need_update_grad = True
+
+                     inputs = batch_data['inputs']
+                     labels = batch_data['labels']
+
+                     try:
+                         inputs, labels = inputs.to(TrainerTools().parallel.device), labels.to(TrainerTools().parallel.device)
+                         attention_mask = inputs != TrainerTools().tokenizer.pad
+
+                         if TrainerTools().parallel.parallel_train:
+                             # skip gradient sync on accumulation-only steps
+                             self.train_model.require_backward_grad_sync = need_update_grad
+
+                         if self.pixel_values_provider and 'image_tags' in batch_data:
+                             image_tags = batch_data['image_tags']
+                             pixel_values = self.pixel_values_provider(image_tags).to(TrainerTools().parallel.device)
+                         else:
+                             pixel_values = None
+
+                         with self.ctx:
+                             result = self.train_model(
+                                 inputs,
+                                 attention_mask=attention_mask,
+                                 pixel_values=pixel_values
+                             )
+
+                         # calc loss
+                         loss = self._calc_loss(inputs, attention_mask, result['logits'], labels)
+                         if result['aux_loss'] and self.train_config.loss_config.aux_loss_coef:
+                             loss += self.train_config.loss_config.aux_loss_coef * result['aux_loss']
+
+                         if gradient_accumulation_steps > 1:
+                             # normalize so the accumulated gradient matches a full batch
+                             loss = loss / gradient_accumulation_steps
+
+                         loss_accumulation += loss.detach()
+                         self._backward_loss(loss)
+
+                         if need_update_grad:
+                             # todo check all_reduce??
+                             if TrainerTools().parallel.parallel_train:
+                                 dist.all_reduce(loss_accumulation, dist.ReduceOp.AVG)
+
+                             # DeepSpeed already applies gradient clipping internally
+                             if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
+                                 # unscale first so clipping sees the true gradients
+                                 self.scaler.unscale_(self.optimizer)
+                                 torch.nn.utils.clip_grad_norm_(self.train_model.parameters(), 1.0)
+
+                             self._step()
+
+                             self._log_loss(
+                                 epoch_tag=f'epoch: {epoch}',
+                                 file_tag=f'file: {file_idx + 1}/{file_count}',
+                                 batch_tag=f'batch: {batch}/{batch_count_per_file}',
+                                 loss=loss_accumulation.item()
+                             )
+                             # reset to default
+                             loss_accumulation = 0.0
+                     except Exception as e:
+                         self._on_exception(e, epoch, batch)
+                     finally:
+                         if need_update_grad:
+                             save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)
+
+                             if (batch - last_ckpt_batch) >= self.train_config.eval_batch_interval:
+                                 save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+                                 last_ckpt_batch = batch
+                                 self._on_batch_end(tag=f'epoch:{epoch}/batch:{batch}')
+
+                         try:
+                             del loss
+                         except UnboundLocalError:
+                             pass
+
+             # end of epoch
+             if not skipping_train:
+                 save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+                 save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)
+                 TrainerTools().parallel.on_epoch_end(epoch)
+                 self._on_epoch_end(tag=f'epoch:{epoch}')
+
+         # wait for the checkpoint save to finish
+         time.sleep(10)
+         TrainerTools().parallel.destroy()
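
For orientation, a minimal driver for this class might look as follows. This is a sketch: TrainConfig's constructor arguments live in .train_configs (not shown in this diff), so the values here are placeholders.

    from llm_trainer.trainer import Trainer
    from llm_trainer.train_configs import TrainConfig

    config = TrainConfig(...)  # model_config, lr_config, batch_size, n_epochs,
                               # file_dataset, data_loader_config, ...
    trainer = Trainer(
        train_config=config,
        eval_prompts=['Hello'],  # sampled round-robin at eval points
    )
    trainer.train()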