project-llm-trainer 0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic. Click here for more details.
- llm_trainer/__init__.py +6 -0
- llm_trainer/checkpoint.py +161 -0
- llm_trainer/dataset.py +140 -0
- llm_trainer/dcp.py +93 -0
- llm_trainer/dpo_trainer.py +300 -0
- llm_trainer/ds_checkpoint.py +61 -0
- llm_trainer/eval.py +86 -0
- llm_trainer/generate_utils.py +424 -0
- llm_trainer/grpo_trainer.py +393 -0
- llm_trainer/log.py +16 -0
- llm_trainer/loss.py +171 -0
- llm_trainer/parallel.py +146 -0
- llm_trainer/parallel_ddp.py +39 -0
- llm_trainer/parallel_ds.py +45 -0
- llm_trainer/parallel_fsdp.py +115 -0
- llm_trainer/parallel_none.py +28 -0
- llm_trainer/scheduler.py +138 -0
- llm_trainer/sft_trainer.py +39 -0
- llm_trainer/tokenizer.py +166 -0
- llm_trainer/tools.py +102 -0
- llm_trainer/train_configs.py +445 -0
- llm_trainer/trainer.py +569 -0
- llm_trainer/utils.py +262 -0
- project_llm_trainer-0.3.data/scripts/calc_intermediate_size +15 -0
- project_llm_trainer-0.3.data/scripts/ddp_train +12 -0
- project_llm_trainer-0.3.data/scripts/ds_train +12 -0
- project_llm_trainer-0.3.data/scripts/plot_loss +39 -0
- project_llm_trainer-0.3.data/scripts/plot_lr +41 -0
- project_llm_trainer-0.3.data/scripts/py_train +12 -0
- project_llm_trainer-0.3.data/scripts/smart_train +28 -0
- project_llm_trainer-0.3.dist-info/METADATA +9 -0
- project_llm_trainer-0.3.dist-info/RECORD +34 -0
- project_llm_trainer-0.3.dist-info/WHEEL +5 -0
- project_llm_trainer-0.3.dist-info/top_level.txt +1 -0
llm_trainer/trainer.py
ADDED
|
@@ -0,0 +1,569 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from contextlib import nullcontext
|
|
3
|
+
from typing import Optional, Tuple, List, Dict, Any
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch import nn
|
|
7
|
+
import torch.distributed as dist
|
|
8
|
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
|
9
|
+
from torch.utils.data import Dataset
|
|
10
|
+
from llm_model import LlmModel, VlmModel
|
|
11
|
+
|
|
12
|
+
from .parallel_ds import DsParallel
|
|
13
|
+
from .parallel_fsdp import FsdpParallel
|
|
14
|
+
from .tools import TrainerTools
|
|
15
|
+
from .loss import LMLoss, KDLoss
|
|
16
|
+
from .dataset import TextDataset
|
|
17
|
+
|
|
18
|
+
from .train_configs import (
|
|
19
|
+
TrainConfig,
|
|
20
|
+
VLMConfig,
|
|
21
|
+
DsZero2Config,
|
|
22
|
+
DsZero3Config
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
from .scheduler import (
|
|
26
|
+
LRScheduler,
|
|
27
|
+
WarmupCosineAnnealingLRScheduler,
|
|
28
|
+
NoneLRScheduler
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
from .checkpoint import (
|
|
32
|
+
load_checkpoint,
|
|
33
|
+
save_checkpoint,
|
|
34
|
+
load_steps,
|
|
35
|
+
save_steps,
|
|
36
|
+
)
|
|
37
|
+
from .utils import (
|
|
38
|
+
set_seed,
|
|
39
|
+
pretrain_collate_fn,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
from .log import(
|
|
43
|
+
log,
|
|
44
|
+
get_log_dir
|
|
45
|
+
)
|
|
46
|
+
|
|
47
|
+
from .eval import submit_gen_task
|
|
48
|
+
|
|
49
|
+
class Trainer:
    """Training loop for LlmModel / VlmModel over a list of dataset files.

    Builds the model, optimizer, LR scheduler, loss functions, AMP context and
    gradient scaler from a TrainConfig, restores a saved checkpoint and step
    counters, and drives the epoch -> file -> batch loop in ``train()``. The
    parallel backend (DeepSpeed, FSDP, or whatever TrainerTools().parallel is)
    is taken from ``TrainerTools().parallel``.
    """

    def __init__(
            self,
            *,
            train_config: TrainConfig,
            eval_prompts: List[str],
            eval_image_tags: Optional[List[int]] = None
    ):
        """Set up everything needed by ``train()``.

        Args:
            train_config: full training configuration.
            eval_prompts: prompts cycled through for sample generation at
                checkpoint intervals and at epoch end (may be empty).
            eval_image_tags: optional per-prompt image tags (VLM evaluation);
                when given, must have the same length as ``eval_prompts``.

        Raises:
            ValueError: if ``eval_image_tags`` is given and its length differs
                from ``eval_prompts``.
        """
        set_seed()

        self.train_config: TrainConfig = train_config
        self.eval_prompts = eval_prompts
        self.eval_image_tags = eval_image_tags
        # Index of the last eval prompt used; advanced before each use in _get_eval_data.
        self.eval_idx = -1

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self.eval_image_tags and len(self.eval_prompts) != len(self.eval_image_tags):
            raise ValueError(
                f'eval_prompts ({len(self.eval_prompts)}) and eval_image_tags '
                f'({len(self.eval_image_tags)}) must have the same length'
            )

        parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim = self._convert_train_args()
        self.data_loader_kwargs: dict[str, Any] = data_loader_kwargs
        self.sampler_kwargs: dict[str, Any] = sampler_kwargs

        # Original note (translated): "the learning rate should be scaled with the
        # number of GPUs: the loss gradient sets the descent direction and the LR the
        # step size; with two GPUs the combined step is average LR * 2".
        # NOTE(review): no such scaling is applied here -- initial_lr is used exactly
        # as configured, so any scaling must already be baked into the config.
        initial_lr = train_config.lr_config.initial_lr

        self.train_model, self.optimizer = self._init_train_model_and_optim(initial_lr, parallel_kwargs, use_ds_optim)
        self.lr_scheduler = self._init_lr_scheduler(initial_lr)
        self.eval_model: Optional[nn.Module] = self._init_eval_model()

        self.criterion, self.kd_loss = self._init_loss()

        # AMP gradient scaler (a no-op when AMP is disabled). Created after the model
        # is wrapped so an FSDP model can get a shard-aware scaler.
        self.scalar = self._init_grad_scaler()

        self.ctx = torch.autocast(
            device_type=TrainerTools().parallel.device_type,
            dtype=TrainerTools().dtype,
            enabled=True,
            # In FSDP mode cache_enabled must be set to False.
            # https://www.zhihu.com/question/642793891
            cache_enabled=False if isinstance(self.train_model, FSDP) else None
        ) if TrainerTools().use_amp else nullcontext()

        load_checkpoint(
            self.train_model,
            optimizer=self.optimizer,
            device=TrainerTools().parallel.device
        )

        # Defaults (0, -1) are used when no saved steps exist.
        last_global_steps, last_lr_steps = load_steps(0, -1)
        self.last_global_steps = last_global_steps
        log(f'last_global_steps={last_global_steps}, last_lr_steps={last_lr_steps}')

        if last_lr_steps != -1:
            self.lr_scheduler.update_steps(last_lr_steps)

        if isinstance(train_config.model_config, VLMConfig):
            self.pixel_values_provider = train_config.pixel_values_provider
            self.tokens_per_image = train_config.model_config.tokens_per_image
        else:
            self.pixel_values_provider = None
            self.tokens_per_image = -1

    def _init_grad_scaler(self):
        """Return the gradient scaler used by the non-DeepSpeed backward/step path."""
        enabled = TrainerTools().use_amp
        if isinstance(self.train_model, FSDP):
            # With FSDP each rank only holds a shard of the gradients; ShardedGradScaler
            # synchronizes the inf/nan check across ranks so they agree on skipping a
            # step, which the plain GradScaler does not.
            from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
            return ShardedGradScaler(enabled=enabled)
        return torch.GradScaler(enabled=enabled)

    def _init_train_model_and_optim(
            self,
            initial_lr: float,
            parallel_kwargs: Optional[dict],
            use_ds_optim: bool
    ):
        """Build the model and its optimizer and hand both to the parallel backend.

        Args:
            initial_lr: learning rate the optimizer is created with.
            parallel_kwargs: backend-specific options from _convert_train_args (may be None).
            use_ds_optim: use DeepSpeedCPUAdam instead of torch AdamW; set by
                _convert_train_args when ZeRO optimizer offload is configured.

        Returns:
            ``(model, optimizer)`` as returned by ``TrainerTools().parallel.process``
            (the model may be wrapped by the backend).
        """
        model_config = self.train_config.model_config
        model = VlmModel(model_config) if isinstance(model_config, VLMConfig) else LlmModel(model_config)

        if self.train_config.init_state_dict:
            # Non-strict load: missing/unexpected keys are tolerated.
            model.load_state_dict(self.train_config.init_state_dict, strict=False)
            # Drop the config's reference so the state dict can be freed.
            self.train_config.init_state_dict = None

        if TrainerTools().parallel.is_main_process:
            total_params = sum(p.numel() for p in model.parameters())
            log(f"Total number of parameters: {total_params:,}")

            # Size estimate assumes 4 bytes (fp32) per parameter.
            total_size_bytes = total_params * 4
            total_size_mb = total_size_bytes / (1024 * 1024)
            log(f"Total size of the model: {total_size_mb:.2f} MB")

        if use_ds_optim:
            import deepspeed
            optim_cls = deepspeed.ops.adam.DeepSpeedCPUAdam
        else:
            optim_cls = torch.optim.AdamW

        origin_optim = optim_cls(
            model.parameters(),
            lr=initial_lr,
            weight_decay=self.train_config.lr_config.weight_decay
        )

        model, optim = TrainerTools().parallel.process(
            model=model,
            optimizer=origin_optim,
            kwargs=parallel_kwargs
        )

        return model, optim

    def _init_eval_model(self) -> Optional[nn.Module]:
        """Create a CPU instance of the model architecture for sample generation.

        Only the main process gets one; other ranks return None.
        """
        if not TrainerTools().parallel.is_main_process:
            return None

        model_config = self.train_config.model_config
        model_cls = VlmModel if isinstance(model_config, VLMConfig) else LlmModel
        return model_cls(model_config).to('cpu')

    def _init_lr_scheduler(self, initial_lr: float) -> LRScheduler:
        """Warmup + cosine-annealing scheduler if enabled, else a constant-LR scheduler."""
        lr_config = self.train_config.lr_config
        if not lr_config.enable_lr_scheduler:
            return NoneLRScheduler(initial_lr)

        return WarmupCosineAnnealingLRScheduler(
            optimizer=self.optimizer,
            initial_lr=initial_lr,
            min_lr=lr_config.min_lr,
            max_lr=lr_config.max_lr,
            warmup_iters=lr_config.warmup_iters,
            period=lr_config.period,
            period_mul=lr_config.period_mul,
            need_log=TrainerTools().parallel.is_main_process
        )

    def _init_loss(self):
        """Return ``(criterion, kd_loss)``; kd_loss is None unless kd_config is set."""
        loss_config = self.train_config.loss_config
        critical_tokens: Optional[List[int]] = None
        critical_alpha: float = 1.0
        if loss_config.critical_tokens:
            critical_tokens = loss_config.critical_tokens
            critical_alpha = loss_config.critical_alpha

        criterion = LMLoss(
            critical_tokens=critical_tokens,
            critical_alpha=critical_alpha,
            vocab_size=TrainerTools().tokenizer.vocab_size
        )

        kd_loss = KDLoss() if self.train_config.kd_config else None

        return criterion, kd_loss

    # Optional ZeRO fields copied verbatim into the DeepSpeed config when not None.
    _ZERO_COMMON_KEYS = (
        'allgather_partitions',
        'allgather_bucket_size',
        'overlap_comm',
        'reduce_scatter',
        'reduce_bucket_size',
        'contiguous_gradients',
    )
    _ZERO3_KEYS = (
        'sub_group_size',
        'stage3_prefetch_bucket_size',
        'stage3_param_persistence_threshold',
        'stage3_max_live_parameters',
        'stage3_max_reuse_distance',
        'stage3_gather_16bit_weights_on_model_save',
    )

    def _convert_train_args(self) -> Tuple[Optional[dict], dict, dict, bool]:
        """Translate TrainConfig into backend, DataLoader and sampler kwargs.

        Returns:
            ``(parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim)``.
            ``parallel_kwargs`` is a DeepSpeed-style config dict, an FSDP option
            dict, or None when neither backend/config applies. ``use_ds_optim`` is
            True when ZeRO optimizer offload is configured.
        """
        parallel_kwargs: Optional[Dict[str, Any]] = None
        use_ds_optim: bool = False
        ds_config = self.train_config.ds_config
        fsdp_config = self.train_config.fsdp_config

        if isinstance(TrainerTools().parallel, DsParallel) and ds_config:
            parallel_kwargs = {
                # Accumulation is done by the trainer loop itself, not by DeepSpeed.
                'gradient_accumulation_steps': 1,
                'gradient_clipping': ds_config.gradient_clipping,
                'train_micro_batch_size_per_gpu': self.train_config.batch_size
            }

            if ds_config.zero_config:
                zero_config = ds_config.zero_config
                zero_optimization: Dict[str, Any] = {'stage': zero_config.stage}

                for key in self._ZERO_COMMON_KEYS:
                    value = getattr(zero_config, key)
                    if value is not None:
                        zero_optimization[key] = value

                if isinstance(zero_config, (DsZero2Config, DsZero3Config)):
                    if zero_config.offload_optimizer is not None:
                        zero_optimization['offload_optimizer'] = {
                            "device": zero_config.offload_optimizer.device,
                            "pin_memory": zero_config.offload_optimizer.pin_memory
                        }
                        use_ds_optim = True
                    if zero_config.offload_param is not None:
                        zero_optimization['offload_param'] = {
                            "device": zero_config.offload_param.device,
                            "pin_memory": zero_config.offload_param.pin_memory
                        }

                if isinstance(zero_config, DsZero3Config):
                    for key in self._ZERO3_KEYS:
                        value = getattr(zero_config, key)
                        if value is not None:
                            zero_optimization[key] = value

                parallel_kwargs['zero_optimization'] = zero_optimization

            # bf16 takes precedence; fp16 is used only when bf16 is absent/disabled.
            if ds_config.bf16_config is not None and ds_config.bf16_config.enabled:
                parallel_kwargs['bf16'] = {
                    'enabled': ds_config.bf16_config.enabled
                }
            elif ds_config.fp16_config:
                fp16_config = ds_config.fp16_config
                fp16 = {
                    'enabled': fp16_config.enabled,
                    'loss_scale': fp16_config.loss_scale,
                    'loss_scale_window': fp16_config.loss_scale_window,
                    'initial_scale_power': fp16_config.initial_scale_power,
                    'hysteresis': fp16_config.hysteresis,
                    'min_loss_scale': fp16_config.min_loss_scale
                }

                if fp16_config.fp16_opt_level is not None:
                    fp16['fp16_opt_level'] = fp16_config.fp16_opt_level

                parallel_kwargs['fp16'] = fp16

            if ds_config.activation_checkpointing:
                ac_config = ds_config.activation_checkpointing
                activation_checkpointing: Dict[str, Any] = {
                    'partition_activations': ac_config.partition_activations,
                    'cpu_checkpointing': ac_config.cpu_checkpointing,
                    'contiguous_memory_optimization': ac_config.contiguous_memory_optimization,
                    'synchronize_checkpoint_boundary': ac_config.synchronize_checkpoint_boundary,
                    'profile': ac_config.profile
                }

                if ac_config.number_checkpoints is not None:
                    activation_checkpointing['number_checkpoints'] = ac_config.number_checkpoints

                parallel_kwargs['activation_checkpointing'] = activation_checkpointing
        elif isinstance(TrainerTools().parallel, FsdpParallel) and fsdp_config:
            parallel_kwargs = {
                'transformer_layer_cls': fsdp_config.transformer_layer_cls,
                'wrap_policy_num_params': fsdp_config.wrap_policy_num_params,
                'cpu_offload': fsdp_config.cpu_offload,
                'offload_params': fsdp_config.offload_params
            }

        dataloader_args = self.train_config.data_loader_config
        data_loader_kwargs = {
            "batch_size": self.train_config.batch_size,
            "pin_memory": dataloader_args.data_loader_pin_memory,
            "collate_fn": pretrain_collate_fn,
            "num_workers": dataloader_args.data_loader_num_workers,
            "shuffle": dataloader_args.data_loader_shuffle,
            "drop_last": dataloader_args.data_loader_drop_last,
        }
        sampler_kwargs = {
            "shuffle": dataloader_args.data_loader_shuffle,
            "drop_last": dataloader_args.data_loader_drop_last,
        }

        return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim

    def _create_dataset(self, file_path) -> Dataset:
        """Build the TextDataset for one file; both length arguments are max_position_embeddings."""
        max_position_embeddings = self.train_config.model_config.max_position_embeddings
        return TextDataset(file_path, max_position_embeddings, max_position_embeddings)

    def _calc_loss(self, inputs, attention_mask, logits, labels):
        """LM loss, blended with a knowledge-distillation loss when kd_config is set.

        The blend is ``(1 - kd_coef) * lm_loss + kd_coef * distil_loss``; teacher
        logits come from ``kd_config.teacher_logits_provider(inputs, attention_mask)``.
        """
        loss = self.criterion(logits, labels)

        # Knowledge-distillation loss
        if self.kd_loss:
            kd_config = self.train_config.kd_config
            teacher_logits = kd_config.teacher_logits_provider(inputs, attention_mask)
            distil_loss = self.kd_loss(logits, teacher_logits, labels)
            loss = (1 - kd_config.kd_coef) * loss + kd_config.kd_coef * distil_loss

        return loss

    def _backward_loss(self, loss):
        """Backward through the DeepSpeed engine, or through the AMP scaler otherwise."""
        if isinstance(TrainerTools().parallel, DsParallel):
            self.train_model.backward(loss)
        else:
            self.scalar.scale(loss).backward()

    def _step(self):
        """Advance the LR scheduler and apply one optimizer update, then synchronize."""
        self.lr_scheduler.step()
        if isinstance(TrainerTools().parallel, DsParallel):
            self.train_model.step()
        else:
            self.scalar.step(self.optimizer)
            self.scalar.update()
            # Free gradient memory as soon as possible.
            self.optimizer.zero_grad(set_to_none=True)

        TrainerTools().parallel.synchronize()

    def _get_eval_data(self) -> Tuple[str, Optional[int]]:
        """Return the next ``(prompt, image_tag)`` in round-robin order.

        Returns ``('', None)`` when there are no eval prompts; the image tag is
        None when no eval_image_tags were given.
        """
        if not self.eval_prompts:
            return '', None

        self.eval_idx = (self.eval_idx + 1) % len(self.eval_prompts)

        if not self.eval_image_tags:
            return self.eval_prompts[self.eval_idx], None

        return self.eval_prompts[self.eval_idx], self.eval_image_tags[self.eval_idx]

    def _log_loss(
            self,
            epoch_tag: str,
            file_tag: str,
            batch_tag: str,
            loss
    ):
        """Log the loss line to the console and append it to log.txt (main process only)."""
        if TrainerTools().parallel.is_main_process:
            log_dir = get_log_dir()
            log_msg = f"{epoch_tag}, {file_tag}, {batch_tag}, loss: {loss}"
            log(log_msg)
            log(f"{log_msg}\n", f'{log_dir}log.txt')

    def _on_exception(
            self,
            e: Exception,
            epoch: int,
            batch: int
    ):
        """Append where ``e`` was raised to log.txt, then re-raise it.

        The innermost traceback entry is reported (the frame that actually raised),
        not the frame of ``train()`` where the exception was caught.
        """
        log_dir = get_log_dir()

        tb = e.__traceback__
        while tb is not None and tb.tb_next is not None:
            tb = tb.tb_next

        if tb is not None:
            exception_file = tb.tb_frame.f_code.co_filename
            exception_line = tb.tb_lineno
        else:
            exception_file, exception_line = '<unknown>', -1

        log_msg = f"epoch: {epoch}, batch: {batch}, {e} at {exception_file} line {exception_line}\n"
        log(log_msg, f'{log_dir}log.txt')

        raise e

    def _submit_eval_task(self, gen_tag: str):
        """On the main process, submit a generation task for the next eval prompt.

        All ranks then call ``TrainerTools().parallel.wait()``.
        """
        if TrainerTools().parallel.is_main_process:
            eval_prompt, eval_image_tag = self._get_eval_data()
            # `is not None`: an integer tag of 0 is a valid tag, not "no image".
            if isinstance(self.train_config.model_config, VLMConfig) and eval_image_tag is not None:
                eval_pixel_values = self.pixel_values_provider([eval_image_tag])
            else:
                eval_pixel_values = None

            submit_gen_task(
                self.eval_model,
                self.train_config.eval_config,
                tag=gen_tag,
                prompt=eval_prompt,
                pixel_values=eval_pixel_values,
                max_position_embeddings=self.train_config.model_config.max_position_embeddings,
                tokens_per_image=self.tokens_per_image
            )

        TrainerTools().parallel.wait()

    def _on_batch_end(
            self,
            tag: str
    ):
        """Hook run after a checkpoint interval: submit a sample-generation task."""
        self._submit_eval_task(f'sign:batch/{tag}')

    def _on_epoch_end(
            self,
            tag: str
    ):
        """Hook run at the end of an epoch: submit a sample-generation task."""
        self._submit_eval_task(f'sign:epoch/{tag}')

    def _on_file_start(
            self,
            epoch: int,
            file_name: str
    ):
        """Append a 'start train <file>' line to log.txt (main process only)."""
        if TrainerTools().parallel.is_main_process:
            log(f"epoch: {epoch}, start train {file_name}\n", f'{get_log_dir()}log.txt')

    def train(self):
        """Run the full epoch -> file -> batch training loop, then tear down the backend.

        NOTE(review): on resume, batches are skipped while
        ``global_steps < self.last_global_steps``, so the batch whose step count was
        saved last is trained again. Steps are saved on every optimizer update
        (including, via ``finally``, one whose step raised), while model checkpoints
        are only saved every ``eval_batch_interval`` batches and at epoch end.
        Confirm this is the intended resume granularity.
        """
        # Gradient accumulation steps
        gradient_accumulation_steps = self.train_config.gradient_accumulation_steps
        global_steps = 0
        loss_accumulation = 0.0
        skipping_train = False

        for epoch in range(self.train_config.n_epochs):
            self.train_model.train()
            file_count = len(self.train_config.file_dataset)

            for file_idx in range(file_count):
                file_path = self.train_config.file_dataset[file_idx]

                dataset = self._create_dataset(file_path)
                train_data_loader = TrainerTools().parallel.process_dataloader(
                    dataset=dataset,
                    data_loader_kwargs=self.data_loader_kwargs,
                    sampler_kwargs=self.sampler_kwargs
                )

                last_ckpt_batch = 0
                batch_count_per_file = len(train_data_loader)

                TrainerTools().parallel.on_epoch_start(epoch)
                self._on_file_start(epoch, file_path)

                for batch, batch_data in enumerate(train_data_loader):
                    global_steps += 1
                    # Fast-forward over batches already consumed before a resume.
                    if global_steps < self.last_global_steps:
                        skipping_train = True
                        continue

                    skipping_train = False

                    # Update weights every `gradient_accumulation_steps` micro-batches,
                    # and always on the last batch of the file.
                    if gradient_accumulation_steps > 1:
                        need_update_grad = (batch + 1) % gradient_accumulation_steps == 0 or batch == batch_count_per_file - 1
                    else:
                        need_update_grad = True

                    inputs = batch_data['inputs']
                    labels = batch_data['labels']

                    try:
                        inputs, labels = inputs.to(TrainerTools().parallel.device), labels.to(TrainerTools().parallel.device)
                        attention_mask = inputs != TrainerTools().tokenizer.pad

                        if TrainerTools().parallel.parallel_train:
                            # Only all-reduce gradients on micro-batches that end an accumulation window.
                            self.train_model.require_backward_grad_sync = need_update_grad

                        if self.pixel_values_provider and 'image_tags' in batch_data:
                            image_tags = batch_data['image_tags']
                            pixel_values = self.pixel_values_provider(image_tags).to(TrainerTools().parallel.device)
                        else:
                            pixel_values = None

                        with self.ctx:
                            result = self.train_model(
                                inputs,
                                attention_mask=attention_mask,
                                pixel_values=pixel_values
                            )

                            loss = self._calc_loss(inputs, attention_mask, result['logits'], labels)
                            if result['aux_loss'] and self.train_config.loss_config.aux_loss_coef:
                                loss += self.train_config.loss_config.aux_loss_coef * result['aux_loss']

                        if gradient_accumulation_steps > 1:
                            loss = loss / gradient_accumulation_steps

                        loss_accumulation += loss.detach()
                        self._backward_loss(loss)

                        if need_update_grad:
                            # Average the accumulated (logged) loss across ranks.
                            if TrainerTools().parallel.parallel_train:
                                dist.all_reduce(loss_accumulation, dist.ReduceOp.AVG)

                            # DeepSpeed clips gradients itself (gradient_clipping in its config).
                            if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
                                # Unscale first so the threshold applies to the true gradients.
                                self.scalar.unscale_(self.optimizer)
                                if isinstance(self.train_model, FSDP):
                                    # Parameters are sharded: FSDP's method computes the global
                                    # norm across ranks (must be called on every rank).
                                    self.train_model.clip_grad_norm_(1.0)
                                else:
                                    torch.nn.utils.clip_grad_norm_(self.train_model.parameters(), 1.0)

                            self._step()

                            self._log_loss(
                                epoch_tag=f'epoch: {epoch}',
                                file_tag=f'file: {file_idx + 1}/{file_count}',
                                batch_tag=f'batch: {batch}/{batch_count_per_file}',
                                loss=loss_accumulation.item()
                            )
                            # reset to default
                            loss_accumulation = 0.0
                    except Exception as e:
                        self._on_exception(e, epoch, batch)
                    finally:
                        if need_update_grad:
                            save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)

                            if (batch - last_ckpt_batch) >= self.train_config.eval_batch_interval:
                                save_checkpoint(model=self.train_model, optimizer=self.optimizer)
                                last_ckpt_batch = batch
                                self._on_batch_end(tag=f'epoch:{epoch}/batch:{batch}')

                        # Release the graph-holding loss tensor; it is unbound if the
                        # exception happened before `loss` was assigned.
                        try:
                            del loss
                        except UnboundLocalError: ...

            # end epoch
            if not skipping_train:
                save_checkpoint(model=self.train_model, optimizer=self.optimizer)
                save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)
            TrainerTools().parallel.on_epoch_end(epoch)
            self._on_epoch_end(tag=f'epoch:{epoch}')

        # Give checkpoint saving time to finish before tearing down the backend.
        time.sleep(10)
        TrainerTools().parallel.destroy()
|