project-llm-trainer 0.13.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic. Click here for more details.
- llm_trainer/__init__.py +13 -0
- llm_trainer/base_trainer.py +707 -0
- llm_trainer/checkpoint.py +114 -0
- llm_trainer/dataset.py +335 -0
- llm_trainer/dpo_trainer.py +311 -0
- llm_trainer/ds_checkpoint.py +72 -0
- llm_trainer/eval.py +33 -0
- llm_trainer/generate_utils.py +463 -0
- llm_trainer/grpo_trainer.py +410 -0
- llm_trainer/log.py +65 -0
- llm_trainer/loss.py +266 -0
- llm_trainer/parallel.py +220 -0
- llm_trainer/partition_utils.py +219 -0
- llm_trainer/ppo_trainer.py +686 -0
- llm_trainer/scheduler.py +220 -0
- llm_trainer/sft_trainer.py +97 -0
- llm_trainer/tokenizer.py +162 -0
- llm_trainer/tools.py +116 -0
- llm_trainer/train_configs.py +327 -0
- llm_trainer/trainer.py +34 -0
- llm_trainer/utils.py +630 -0
- project_llm_trainer-0.13.4.data/scripts/calc_intermediate_size +15 -0
- project_llm_trainer-0.13.4.data/scripts/ddp_train +21 -0
- project_llm_trainer-0.13.4.data/scripts/ds_train +17 -0
- project_llm_trainer-0.13.4.data/scripts/py_train +12 -0
- project_llm_trainer-0.13.4.data/scripts/smart_train +37 -0
- project_llm_trainer-0.13.4.data/scripts/vis_log +98 -0
- project_llm_trainer-0.13.4.data/scripts/vis_lr +46 -0
- project_llm_trainer-0.13.4.dist-info/METADATA +9 -0
- project_llm_trainer-0.13.4.dist-info/RECORD +32 -0
- project_llm_trainer-0.13.4.dist-info/WHEEL +5 -0
- project_llm_trainer-0.13.4.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,707 @@
|
|
|
1
|
+
from typing import Optional, Tuple, List, Dict, Any
|
|
2
|
+
import copy
|
|
3
|
+
import gc
|
|
4
|
+
import importlib.metadata
|
|
5
|
+
from packaging import version
|
|
6
|
+
from itertools import islice
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.distributed as dist
|
|
10
|
+
from torch.utils.data import Dataset
|
|
11
|
+
from llm_model import LlmModel
|
|
12
|
+
|
|
13
|
+
from .parallel import DsParallel
|
|
14
|
+
from .tools import TrainerTools
|
|
15
|
+
from .loss import LMLoss, KDLoss
|
|
16
|
+
from .eval import submit_gen_task
|
|
17
|
+
from .partition_utils import unwrap_model_for_generation
|
|
18
|
+
|
|
19
|
+
from .train_configs import (
|
|
20
|
+
TrainConfig,
|
|
21
|
+
DsZero2Config,
|
|
22
|
+
DsZero3Config,
|
|
23
|
+
KDConfig
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
from .scheduler import (
|
|
27
|
+
LRScheduler,
|
|
28
|
+
WarmupCosineAnnealingLRScheduler,
|
|
29
|
+
NoneLRScheduler
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
from .checkpoint import (
|
|
33
|
+
load_checkpoint,
|
|
34
|
+
save_checkpoint,
|
|
35
|
+
load_steps,
|
|
36
|
+
save_steps,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
from .utils import (
|
|
40
|
+
set_seed,
|
|
41
|
+
autocast,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
from .log import Logger
|
|
45
|
+
|
|
46
|
+
class BaseTrainer:
|
|
47
|
+
def __init__(
|
|
48
|
+
self,
|
|
49
|
+
*,
|
|
50
|
+
train_config: TrainConfig,
|
|
51
|
+
eval_prompts: List[str],
|
|
52
|
+
kd_config: Optional[KDConfig] = None,
|
|
53
|
+
gradient_accumulation_steps: int = 1
|
|
54
|
+
):
|
|
55
|
+
set_seed()
|
|
56
|
+
|
|
57
|
+
self.train_config: TrainConfig = train_config
|
|
58
|
+
self.eval_prompts = eval_prompts
|
|
59
|
+
self.eval_idx = -1
|
|
60
|
+
|
|
61
|
+
self.resume_epoch = 0
|
|
62
|
+
self.resume_file_idx = 0
|
|
63
|
+
self.resume_batch_idx = 0
|
|
64
|
+
|
|
65
|
+
self.kd_config = kd_config
|
|
66
|
+
self.gradient_accumulation_steps = gradient_accumulation_steps
|
|
67
|
+
|
|
68
|
+
self.logger = Logger('log.txt')
|
|
69
|
+
|
|
70
|
+
self.parallel_kwargs, self.data_loader_kwargs, self.sampler_kwargs = self._convert_train_args()
|
|
71
|
+
# initialize a GradScaler. If enabled=False scaler is a no-op
|
|
72
|
+
self.scaler = torch.GradScaler(enabled=TrainerTools().use_amp)
|
|
73
|
+
|
|
74
|
+
# 注意:学习率要根据GPU的数量进行倍增:
|
|
75
|
+
# 在训练的过程中,损失梯度决定下降的方向,学习率决定下降的步长。如果有两块gpu,前进的综合步长为:平均学习率*2
|
|
76
|
+
initial_lr = train_config.optim_config.initial_lr
|
|
77
|
+
|
|
78
|
+
self.train_model, self.optimizer = self._init_train_model_and_optim(initial_lr)
|
|
79
|
+
self.lr_scheduler = self._init_lr_scheduler(initial_lr, self.optimizer)
|
|
80
|
+
|
|
81
|
+
self.criterion, self.kd_loss = self._init_loss()
|
|
82
|
+
|
|
83
|
+
self._load_train_model_checkpoint()
|
|
84
|
+
self._apply_restore_ckpt()
|
|
85
|
+
|
|
86
|
+
def _new_model(self, train_config: TrainConfig):
|
|
87
|
+
return LlmModel(train_config.model_config)
|
|
88
|
+
|
|
89
|
+
def _init_train_model_and_optim(self, initial_lr: float):
|
|
90
|
+
model = self._new_model(self.train_config)
|
|
91
|
+
|
|
92
|
+
if self.train_config.init_state_dict:
|
|
93
|
+
model.load_state_dict(self.train_config.init_state_dict, strict=False)
|
|
94
|
+
self.train_config.init_state_dict = None
|
|
95
|
+
|
|
96
|
+
self._check_freeze_llm_model(model)
|
|
97
|
+
|
|
98
|
+
if self.train_config.ds_config and self.train_config.ds_config.activation_checkpointing:
|
|
99
|
+
model.gradient_checkpointing_enable()
|
|
100
|
+
|
|
101
|
+
if TrainerTools().parallel.is_main_process:
|
|
102
|
+
total_params = sum(p.numel() for p in model.parameters())
|
|
103
|
+
Logger.std_log(f"Total number of parameters: {total_params:,}")
|
|
104
|
+
|
|
105
|
+
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
106
|
+
Logger.std_log(f"Trainable number of parameters: {trainable_params:,}")
|
|
107
|
+
|
|
108
|
+
total_size_bytes = total_params * 4
|
|
109
|
+
total_size_mb = total_size_bytes / (1024 * 1024)
|
|
110
|
+
Logger.std_log(f"Total size of the model: {total_size_mb:.2f} MB")
|
|
111
|
+
|
|
112
|
+
model, optim = TrainerTools().parallel.process(
|
|
113
|
+
model=model,
|
|
114
|
+
optimizer=self._config_optim(model, initial_lr),
|
|
115
|
+
kwargs=self.parallel_kwargs
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
return model, optim
|
|
119
|
+
|
|
120
|
+
def _check_freeze_llm_model(self, model): ...
|
|
121
|
+
|
|
122
|
+
def _config_optim(self, model, initial_lr):
|
|
123
|
+
optimizer_cls, use_lion_optim = self._get_optim_cls()
|
|
124
|
+
|
|
125
|
+
betas = self.train_config.optim_config.betas
|
|
126
|
+
weight_decay = self.train_config.optim_config.betas
|
|
127
|
+
|
|
128
|
+
if betas is None:
|
|
129
|
+
betas = (0.95, 0.98) if use_lion_optim else (0.9, 0.999)
|
|
130
|
+
|
|
131
|
+
if weight_decay is None:
|
|
132
|
+
weight_decay = 0.015 if use_lion_optim else 0.01
|
|
133
|
+
|
|
134
|
+
no_decay_name_list = ["bias", "norm.weight"]
|
|
135
|
+
decay_params = []
|
|
136
|
+
no_decay_params = []
|
|
137
|
+
|
|
138
|
+
for name, param in model.named_parameters():
|
|
139
|
+
if not param.requires_grad:
|
|
140
|
+
continue
|
|
141
|
+
|
|
142
|
+
if any(nd in name for nd in no_decay_name_list):
|
|
143
|
+
no_decay_params.append(param)
|
|
144
|
+
else:
|
|
145
|
+
decay_params.append(param)
|
|
146
|
+
|
|
147
|
+
optimizer_grouped_parameters = [
|
|
148
|
+
{
|
|
149
|
+
"params": decay_params,
|
|
150
|
+
"weight_decay": weight_decay,
|
|
151
|
+
},
|
|
152
|
+
{
|
|
153
|
+
"params": no_decay_params,
|
|
154
|
+
"weight_decay": 0.0,
|
|
155
|
+
},
|
|
156
|
+
]
|
|
157
|
+
|
|
158
|
+
return optimizer_cls(
|
|
159
|
+
optimizer_grouped_parameters,
|
|
160
|
+
lr=initial_lr,
|
|
161
|
+
betas=betas,
|
|
162
|
+
weight_decay=weight_decay
|
|
163
|
+
)
|
|
164
|
+
|
|
165
|
+
def _get_optim_cls(self):
|
|
166
|
+
optimizer = None
|
|
167
|
+
use_lion_optim = self.train_config.optim_config.optim_type == 'lion'
|
|
168
|
+
|
|
169
|
+
if isinstance(TrainerTools().parallel, DsParallel) and self.parallel_kwargs:
|
|
170
|
+
import deepspeed
|
|
171
|
+
if ('zero_optimization' in self.parallel_kwargs
|
|
172
|
+
and 'offload_optimizer' in self.parallel_kwargs['zero_optimization']
|
|
173
|
+
and self.parallel_kwargs['zero_optimization']['offload_optimizer']['device'] == 'cpu'):
|
|
174
|
+
if self.train_config.optim_config.optim_type == 'lion':
|
|
175
|
+
if version.parse(importlib.metadata.version("deepspeed")) >= version.parse('0.17.6'):
|
|
176
|
+
optimizer = deepspeed.ops.lion.DeepSpeedCPULion
|
|
177
|
+
else:
|
|
178
|
+
optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam
|
|
179
|
+
use_lion_optim = False
|
|
180
|
+
if TrainerTools().parallel.is_main_process:
|
|
181
|
+
Logger.std_log(
|
|
182
|
+
'When set offload_optimizer, lion optim is unsupported, so set optim to adam!!!!!')
|
|
183
|
+
else:
|
|
184
|
+
optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam
|
|
185
|
+
else:
|
|
186
|
+
if self.train_config.optim_config.optim_type == 'lion':
|
|
187
|
+
optimizer = deepspeed.ops.lion.FusedLion
|
|
188
|
+
else:
|
|
189
|
+
optimizer = deepspeed.ops.adam.FusedAdam
|
|
190
|
+
|
|
191
|
+
if not optimizer:
|
|
192
|
+
if self.train_config.optim_config.optim_type == 'lion':
|
|
193
|
+
try:
|
|
194
|
+
import lion_pytorch
|
|
195
|
+
except:
|
|
196
|
+
raise Exception(
|
|
197
|
+
'lion is not detected, please use `pip3 install lion_pytorch` to install or set optim_type to adam')
|
|
198
|
+
|
|
199
|
+
optimizer = lion_pytorch.Lion
|
|
200
|
+
else:
|
|
201
|
+
optimizer = torch.optim.AdamW
|
|
202
|
+
|
|
203
|
+
return optimizer, use_lion_optim
|
|
204
|
+
|
|
205
|
+
def _init_lr_scheduler(self, initial_lr: float, optimizer) -> LRScheduler:
|
|
206
|
+
if self.train_config.optim_config.enable_lr_scheduler:
|
|
207
|
+
warmup_iters = self.train_config.optim_config.warmup_iters
|
|
208
|
+
min_lr = self.train_config.optim_config.min_lr
|
|
209
|
+
max_lr = self.train_config.optim_config.max_lr
|
|
210
|
+
cosine_annealing_period = self.train_config.optim_config.cosine_annealing_period
|
|
211
|
+
cosine_annealing_period_mul = self.train_config.optim_config.cosine_annealing_period_mul
|
|
212
|
+
|
|
213
|
+
return WarmupCosineAnnealingLRScheduler(
|
|
214
|
+
optimizer=optimizer,
|
|
215
|
+
warmup_iters=warmup_iters,
|
|
216
|
+
initial_lr=initial_lr,
|
|
217
|
+
min_lr=min_lr,
|
|
218
|
+
max_lr=max_lr,
|
|
219
|
+
cosine_annealing_period=cosine_annealing_period,
|
|
220
|
+
cosine_annealing_period_mul=cosine_annealing_period_mul,
|
|
221
|
+
need_log=TrainerTools().parallel.is_main_process
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
return NoneLRScheduler(initial_lr)
|
|
225
|
+
|
|
226
|
+
def _init_loss(self):
|
|
227
|
+
critical_tokens: Optional[List[int]] = None
|
|
228
|
+
critical_alpha: float = 1.0
|
|
229
|
+
if self.train_config.loss_config.critical_tokens:
|
|
230
|
+
critical_tokens = self.train_config.loss_config.critical_tokens
|
|
231
|
+
critical_alpha = self.train_config.loss_config.critical_alpha
|
|
232
|
+
|
|
233
|
+
criterion = LMLoss(
|
|
234
|
+
critical_tokens=critical_tokens,
|
|
235
|
+
critical_alpha=critical_alpha,
|
|
236
|
+
vocab_size=TrainerTools().tokenizer.vocab_size
|
|
237
|
+
)
|
|
238
|
+
|
|
239
|
+
kd_loss = KDLoss() if self.kd_config else None
|
|
240
|
+
|
|
241
|
+
return criterion, kd_loss
|
|
242
|
+
|
|
243
|
+
def _load_train_model_checkpoint(self):
|
|
244
|
+
load_checkpoint(
|
|
245
|
+
self.train_model,
|
|
246
|
+
optimizer=self.optimizer,
|
|
247
|
+
device=TrainerTools().parallel.device
|
|
248
|
+
)
|
|
249
|
+
|
|
250
|
+
def _apply_restore_ckpt(self):
|
|
251
|
+
steps_dict = load_steps()
|
|
252
|
+
if steps_dict:
|
|
253
|
+
self.resume_epoch = steps_dict.get('epoch', 0)
|
|
254
|
+
self.resume_file_idx = steps_dict.get('file_idx', 0)
|
|
255
|
+
self.resume_batch_idx = steps_dict.get('batch_idx', 0)
|
|
256
|
+
|
|
257
|
+
self.lr_scheduler.restore_ckpt_dict(steps_dict)
|
|
258
|
+
|
|
259
|
+
if TrainerTools().parallel.is_main_process:
|
|
260
|
+
Logger.std_log(f'restore steps_dict={steps_dict}')
|
|
261
|
+
|
|
262
|
+
def _convert_train_args(self) -> Tuple[dict, dict, dict]:
|
|
263
|
+
parallel_kwargs: Optional[Dict[str, Any]] = None
|
|
264
|
+
if isinstance(TrainerTools().parallel, DsParallel) and self.train_config.ds_config:
|
|
265
|
+
parallel_kwargs = {
|
|
266
|
+
'gradient_accumulation_steps': 1,
|
|
267
|
+
'gradient_clipping': self.train_config.ds_config.gradient_clipping,
|
|
268
|
+
'train_micro_batch_size_per_gpu': self.train_config.batch_size
|
|
269
|
+
}
|
|
270
|
+
|
|
271
|
+
if self.train_config.ds_config.zero_config:
|
|
272
|
+
zero_config = self.train_config.ds_config.zero_config
|
|
273
|
+
zero_optimization: Dict[str, Any] = {'stage': zero_config.stage}
|
|
274
|
+
|
|
275
|
+
if zero_config.allgather_partitions is not None:
|
|
276
|
+
zero_optimization['allgather_partitions'] = zero_config.allgather_partitions
|
|
277
|
+
if zero_config.allgather_bucket_size is not None:
|
|
278
|
+
zero_optimization['allgather_bucket_size'] = zero_config.allgather_bucket_size
|
|
279
|
+
if zero_config.overlap_comm is not None:
|
|
280
|
+
zero_optimization['overlap_comm'] = zero_config.overlap_comm
|
|
281
|
+
if zero_config.reduce_scatter is not None:
|
|
282
|
+
zero_optimization['reduce_scatter'] = zero_config.reduce_scatter
|
|
283
|
+
if zero_config.reduce_bucket_size is not None:
|
|
284
|
+
zero_optimization['reduce_bucket_size'] = zero_config.reduce_bucket_size
|
|
285
|
+
if zero_config.contiguous_gradients is not None:
|
|
286
|
+
zero_optimization['contiguous_gradients'] = zero_config.contiguous_gradients
|
|
287
|
+
|
|
288
|
+
if isinstance(zero_config, DsZero2Config) or isinstance(zero_config, DsZero3Config):
|
|
289
|
+
if zero_config.offload_optimizer is not None:
|
|
290
|
+
zero_optimization['offload_optimizer'] = {
|
|
291
|
+
"device": zero_config.offload_optimizer.device,
|
|
292
|
+
"pin_memory": zero_config.offload_optimizer.pin_memory
|
|
293
|
+
}
|
|
294
|
+
if zero_config.offload_param is not None:
|
|
295
|
+
zero_optimization['offload_param'] = {
|
|
296
|
+
"device": zero_config.offload_param.device,
|
|
297
|
+
"pin_memory": zero_config.offload_param.pin_memory
|
|
298
|
+
}
|
|
299
|
+
|
|
300
|
+
if isinstance(zero_config, DsZero3Config):
|
|
301
|
+
if zero_config.sub_group_size is not None:
|
|
302
|
+
zero_optimization['sub_group_size'] = zero_config.sub_group_size
|
|
303
|
+
if zero_config.stage3_prefetch_bucket_size is not None:
|
|
304
|
+
zero_optimization['stage3_prefetch_bucket_size'] = zero_config.stage3_prefetch_bucket_size
|
|
305
|
+
if zero_config.stage3_param_persistence_threshold is not None:
|
|
306
|
+
zero_optimization['stage3_param_persistence_threshold'] = zero_config.stage3_param_persistence_threshold
|
|
307
|
+
if zero_config.stage3_max_live_parameters is not None:
|
|
308
|
+
zero_optimization['stage3_max_live_parameters'] = zero_config.stage3_max_live_parameters
|
|
309
|
+
if zero_config.stage3_max_reuse_distance is not None:
|
|
310
|
+
zero_optimization['stage3_max_reuse_distance'] = zero_config.stage3_max_reuse_distance
|
|
311
|
+
if zero_config.stage3_gather_16bit_weights_on_model_save is not None:
|
|
312
|
+
zero_optimization['stage3_gather_16bit_weights_on_model_save'] = zero_config.stage3_gather_16bit_weights_on_model_save
|
|
313
|
+
|
|
314
|
+
parallel_kwargs['zero_optimization'] = zero_optimization
|
|
315
|
+
|
|
316
|
+
if (self.train_config.ds_config.bf16_config is not None
|
|
317
|
+
and self.train_config.ds_config.bf16_config.enabled):
|
|
318
|
+
bf16_config = self.train_config.ds_config.bf16_config
|
|
319
|
+
bf16 = {
|
|
320
|
+
'enabled': bf16_config.enabled
|
|
321
|
+
}
|
|
322
|
+
parallel_kwargs['bf16'] = bf16
|
|
323
|
+
elif self.train_config.ds_config.fp16_config:
|
|
324
|
+
fp16_config = self.train_config.ds_config.fp16_config
|
|
325
|
+
fp16 = {
|
|
326
|
+
'enabled': fp16_config.enabled,
|
|
327
|
+
'loss_scale': fp16_config.loss_scale,
|
|
328
|
+
'loss_scale_window': fp16_config.loss_scale_window,
|
|
329
|
+
'initial_scale_power': fp16_config.initial_scale_power,
|
|
330
|
+
'hysteresis': fp16_config.hysteresis,
|
|
331
|
+
'min_loss_scale': fp16_config.min_loss_scale
|
|
332
|
+
}
|
|
333
|
+
|
|
334
|
+
if fp16_config.fp16_opt_level is not None:
|
|
335
|
+
fp16['fp16_opt_level'] = fp16_config.fp16_opt_level
|
|
336
|
+
|
|
337
|
+
parallel_kwargs['fp16'] = fp16
|
|
338
|
+
|
|
339
|
+
if self.train_config.ds_config.activation_checkpointing:
|
|
340
|
+
activation_checkpointing_config = self.train_config.ds_config.activation_checkpointing
|
|
341
|
+
activation_checkpointing: Dict[str, Any] = {
|
|
342
|
+
'partition_activations': activation_checkpointing_config.partition_activations,
|
|
343
|
+
'cpu_checkpointing': activation_checkpointing_config.cpu_checkpointing,
|
|
344
|
+
'contiguous_memory_optimization': activation_checkpointing_config.contiguous_memory_optimization,
|
|
345
|
+
'synchronize_checkpoint_boundary': activation_checkpointing_config.synchronize_checkpoint_boundary,
|
|
346
|
+
'profile': activation_checkpointing_config.profile
|
|
347
|
+
}
|
|
348
|
+
|
|
349
|
+
if activation_checkpointing_config.number_checkpoints is not None:
|
|
350
|
+
activation_checkpointing['number_checkpoints'] = activation_checkpointing_config.number_checkpoints
|
|
351
|
+
|
|
352
|
+
parallel_kwargs['activation_checkpointing'] = activation_checkpointing
|
|
353
|
+
|
|
354
|
+
dataloader_args = self.train_config.data_loader_config
|
|
355
|
+
data_loader_kwargs = {
|
|
356
|
+
"batch_size": self.train_config.batch_size,
|
|
357
|
+
"pin_memory": dataloader_args.data_loader_pin_memory,
|
|
358
|
+
"num_workers": dataloader_args.data_loader_num_workers,
|
|
359
|
+
"shuffle": dataloader_args.data_loader_shuffle,
|
|
360
|
+
"drop_last": dataloader_args.data_loader_drop_last,
|
|
361
|
+
}
|
|
362
|
+
sampler_kwargs = {
|
|
363
|
+
"shuffle": dataloader_args.data_loader_shuffle,
|
|
364
|
+
"drop_last": dataloader_args.data_loader_drop_last,
|
|
365
|
+
}
|
|
366
|
+
|
|
367
|
+
return parallel_kwargs, data_loader_kwargs, sampler_kwargs
|
|
368
|
+
|
|
369
|
+
def _init_ref_model_args(self) -> dict:
|
|
370
|
+
parallel_kwargs = copy.deepcopy(self.parallel_kwargs) if self.parallel_kwargs else None
|
|
371
|
+
|
|
372
|
+
if parallel_kwargs and isinstance(TrainerTools().parallel, DsParallel):
|
|
373
|
+
# reference to https://github.com/huggingface/trl/blob/main/trl/models/utils.py:prepare_deepspeed
|
|
374
|
+
# if model is not None:
|
|
375
|
+
# hidden_size = (
|
|
376
|
+
# max(model.config.hidden_sizes)
|
|
377
|
+
# if getattr(model.config, "hidden_sizes", None)
|
|
378
|
+
# else getattr(model.config, "hidden_size", None)
|
|
379
|
+
# )
|
|
380
|
+
# if hidden_size is not None and stage == 3:
|
|
381
|
+
# # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache
|
|
382
|
+
# # @ step 0: expected module 1, but got module 0`
|
|
383
|
+
# # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
|
384
|
+
# config_kwargs.update(
|
|
385
|
+
# {
|
|
386
|
+
# "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
|
387
|
+
# "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
|
388
|
+
# "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
|
389
|
+
# }
|
|
390
|
+
# )
|
|
391
|
+
|
|
392
|
+
parallel_kwargs.pop('activation_checkpointing', None)
|
|
393
|
+
parallel_kwargs.pop('gradient_clipping', None)
|
|
394
|
+
|
|
395
|
+
# ref_model暂时先使用stage 0, 解决训练卡住问题
|
|
396
|
+
parallel_kwargs["zero_optimization"] = {"stage": 0}
|
|
397
|
+
# if parallel_kwargs.get("zero_optimization", {}).get("stage", 0) != 3:
|
|
398
|
+
# parallel_kwargs["zero_optimization"] = {"stage": 0}
|
|
399
|
+
|
|
400
|
+
return parallel_kwargs
|
|
401
|
+
|
|
402
|
+
def _create_dataset(self, file_idx) -> Tuple[Dataset, str]: ...
|
|
403
|
+
|
|
404
|
+
def _calc_loss(self, inputs, attention_mask, logits, labels):
|
|
405
|
+
# calc loss
|
|
406
|
+
if not self.kd_loss or self.kd_config.kd_coef == 0.0:
|
|
407
|
+
# 不用计算kd_loss
|
|
408
|
+
return self.criterion(logits, labels)
|
|
409
|
+
|
|
410
|
+
teacher_logits = self.kd_config.teacher_logits_provider(inputs, attention_mask)
|
|
411
|
+
loss = self.kd_loss(logits, teacher_logits, labels)
|
|
412
|
+
|
|
413
|
+
if self.kd_config.kd_coef == 1.0:
|
|
414
|
+
# 不用计算ce loss
|
|
415
|
+
return loss
|
|
416
|
+
|
|
417
|
+
ce_loss = self.criterion(logits, labels)
|
|
418
|
+
return (1 - self.kd_config.kd_coef) * ce_loss + self.kd_config.kd_coef * loss
|
|
419
|
+
|
|
420
|
+
def _backward_loss(self, loss):
|
|
421
|
+
if isinstance(TrainerTools().parallel, DsParallel):
|
|
422
|
+
self.train_model.backward(loss)
|
|
423
|
+
else:
|
|
424
|
+
self.scaler.scale(loss).backward()
|
|
425
|
+
|
|
426
|
+
def _apply_grad_clipping(self):
|
|
427
|
+
# ds模式已经集成gradient_clipping
|
|
428
|
+
if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
|
|
429
|
+
# clip grad
|
|
430
|
+
self.scaler.unscale_(self.optimizer)
|
|
431
|
+
|
|
432
|
+
trainable_params = filter(lambda p: p.requires_grad, self.train_model.parameters())
|
|
433
|
+
torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
|
|
434
|
+
|
|
435
|
+
def _apply_step(self):
|
|
436
|
+
self.lr_scheduler.step()
|
|
437
|
+
if isinstance(TrainerTools().parallel, DsParallel):
|
|
438
|
+
self.train_model.step()
|
|
439
|
+
else:
|
|
440
|
+
self.scaler.step(self.optimizer)
|
|
441
|
+
self.scaler.update()
|
|
442
|
+
self.optimizer.zero_grad(set_to_none=True)
|
|
443
|
+
|
|
444
|
+
TrainerTools().parallel.synchronize()
|
|
445
|
+
|
|
446
|
+
def _get_eval_data(self) -> Optional[str]:
|
|
447
|
+
if len(self.eval_prompts) == 0:
|
|
448
|
+
return None
|
|
449
|
+
|
|
450
|
+
self.eval_idx += 1
|
|
451
|
+
if self.eval_idx == len(self.eval_prompts):
|
|
452
|
+
self.eval_idx = 0
|
|
453
|
+
|
|
454
|
+
return self.eval_prompts[self.eval_idx]
|
|
455
|
+
|
|
456
|
+
def _get_eval_pixel_values_and_tokens_count(self, eval_idx):
|
|
457
|
+
return None, None
|
|
458
|
+
|
|
459
|
+
def _log(self, keys: Dict[str, any], values: Dict[str, any]):
|
|
460
|
+
"""
|
|
461
|
+
格式:keys_key1: keys_value1, keys_key2: keys_value2 -> values_key1: values_value1, values_key2: values_value2
|
|
462
|
+
"""
|
|
463
|
+
if TrainerTools().parallel.is_main_process:
|
|
464
|
+
log_tags = ', '.join([f'{k}: {v}' for k, v in keys.items()])
|
|
465
|
+
log_values = ', '.join([f'{k}: {v}' for k, v in values.items()])
|
|
466
|
+
|
|
467
|
+
log_msg = f'{log_tags} -> {log_values}'
|
|
468
|
+
self.logger.log(log_msg)
|
|
469
|
+
|
|
470
|
+
def _on_exception(
|
|
471
|
+
self,
|
|
472
|
+
e: Exception,
|
|
473
|
+
epoch: int,
|
|
474
|
+
batch: int
|
|
475
|
+
):
|
|
476
|
+
exception_file = e.__traceback__.tb_frame.f_globals["__file__"]
|
|
477
|
+
exception_line = e.__traceback__.tb_lineno
|
|
478
|
+
log_msg = f"epoch: {epoch}, batch: {batch} -> {e} at {exception_file} line {exception_line}"
|
|
479
|
+
Logger('exception.txt').log(log_msg, log_to_console=False).release()
|
|
480
|
+
|
|
481
|
+
raise e
|
|
482
|
+
|
|
483
|
+
def _get_model_dtype(self):
|
|
484
|
+
if isinstance(TrainerTools().parallel, DsParallel):
|
|
485
|
+
import deepspeed
|
|
486
|
+
assert isinstance(self.train_model, deepspeed.DeepSpeedEngine)
|
|
487
|
+
return self.train_model.get_data_types()[0]
|
|
488
|
+
else:
|
|
489
|
+
return torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
|
|
490
|
+
|
|
491
|
+
def _eval(self, tag: str):
|
|
492
|
+
with unwrap_model_for_generation(self.train_model) as eval_model:
|
|
493
|
+
if TrainerTools().parallel.is_main_process:
|
|
494
|
+
eval_prompt = self._get_eval_data()
|
|
495
|
+
|
|
496
|
+
if eval_prompt:
|
|
497
|
+
eval_model = self._check_eval_model(eval_model)
|
|
498
|
+
eval_model.eval()
|
|
499
|
+
|
|
500
|
+
eval_pixel_values, tokens_per_image = self._get_eval_pixel_values_and_tokens_count(self.eval_idx)
|
|
501
|
+
submit_gen_task(
|
|
502
|
+
eval_model,
|
|
503
|
+
self.train_config,
|
|
504
|
+
tag=tag,
|
|
505
|
+
prompt=eval_prompt,
|
|
506
|
+
pixel_values=eval_pixel_values,
|
|
507
|
+
tokens_per_image=tokens_per_image
|
|
508
|
+
)
|
|
509
|
+
|
|
510
|
+
eval_model.train()
|
|
511
|
+
|
|
512
|
+
TrainerTools().parallel.wait('eval')
|
|
513
|
+
|
|
514
|
+
def _check_eval_model(self, eval_model):
|
|
515
|
+
return eval_model
|
|
516
|
+
|
|
517
|
+
def _on_batch_end(self, tag: str):
|
|
518
|
+
self._eval(f'sign:batch/{tag}')
|
|
519
|
+
|
|
520
|
+
def _on_epoch_end(self, tag: str):
|
|
521
|
+
self._eval(f'sign:epoch/{tag}')
|
|
522
|
+
|
|
523
|
+
def _on_file_start(
|
|
524
|
+
self,
|
|
525
|
+
epoch: int,
|
|
526
|
+
file_name: str
|
|
527
|
+
):
|
|
528
|
+
if TrainerTools().parallel.is_main_process:
|
|
529
|
+
self.logger.log(f"====epoch: {epoch}, start train {file_name}====", log_to_console=False)
|
|
530
|
+
|
|
531
|
+
def _avg_loss(
|
|
532
|
+
self,
|
|
533
|
+
losses: List[float],
|
|
534
|
+
gradient_accumulation_steps,
|
|
535
|
+
batches_accumulated
|
|
536
|
+
) -> List[float]:
|
|
537
|
+
loss_tensors = [
|
|
538
|
+
torch.tensor(loss * gradient_accumulation_steps / batches_accumulated,
|
|
539
|
+
device=TrainerTools().parallel.device)
|
|
540
|
+
for loss in losses
|
|
541
|
+
]
|
|
542
|
+
|
|
543
|
+
stacked_losses = torch.stack(loss_tensors)
|
|
544
|
+
if TrainerTools().parallel.parallel_train:
|
|
545
|
+
dist.all_reduce(stacked_losses, dist.ReduceOp.AVG)
|
|
546
|
+
|
|
547
|
+
return stacked_losses.detach().cpu().tolist()
|
|
548
|
+
|
|
549
|
+
def _get_pixel_values(self, batch_data):
|
|
550
|
+
return None
|
|
551
|
+
|
|
552
|
+
def train(self):
|
|
553
|
+
# 梯度累积步数
|
|
554
|
+
gradient_accumulation_steps = max(1, self.gradient_accumulation_steps)
|
|
555
|
+
|
|
556
|
+
loss_accumulation = 0.0
|
|
557
|
+
aux_loss_accumulation = 0.0
|
|
558
|
+
batches_accumulated = 0
|
|
559
|
+
|
|
560
|
+
for epoch in range(self.resume_epoch, self.train_config.n_epochs):
|
|
561
|
+
self.train_model.train()
|
|
562
|
+
file_count = len(self.train_config.file_dataset)
|
|
563
|
+
start_file_idx = self.resume_file_idx if epoch == self.resume_epoch else 0
|
|
564
|
+
|
|
565
|
+
for file_idx in range(start_file_idx, file_count):
|
|
566
|
+
dataset, file_path = self._create_dataset(file_idx)
|
|
567
|
+
train_data_loader = TrainerTools().parallel.process_dataloader(
|
|
568
|
+
dataset=dataset,
|
|
569
|
+
data_loader_kwargs=self.data_loader_kwargs,
|
|
570
|
+
sampler_kwargs=self.sampler_kwargs
|
|
571
|
+
)
|
|
572
|
+
|
|
573
|
+
last_ckpt_batch = 0
|
|
574
|
+
batch_count_per_file = len(train_data_loader)
|
|
575
|
+
|
|
576
|
+
TrainerTools().parallel.on_epoch_start(epoch)
|
|
577
|
+
self._on_file_start(epoch, file_path)
|
|
578
|
+
|
|
579
|
+
skip_batches = 0
|
|
580
|
+
if epoch == self.resume_epoch and file_idx == self.resume_file_idx:
|
|
581
|
+
skip_batches = self.resume_batch_idx
|
|
582
|
+
if skip_batches > 0 and TrainerTools().parallel.is_main_process:
|
|
583
|
+
Logger.std_log(f"Fast forwarding {skip_batches} batches in {file_path}...")
|
|
584
|
+
|
|
585
|
+
data_iterator = iter(train_data_loader)
|
|
586
|
+
|
|
587
|
+
if skip_batches > 0:
|
|
588
|
+
data_iterator = islice(data_iterator, skip_batches, None)
|
|
589
|
+
last_ckpt_batch = skip_batches
|
|
590
|
+
|
|
591
|
+
for batch, batch_data in enumerate(data_iterator):
|
|
592
|
+
batch = skip_batches + batch
|
|
593
|
+
|
|
594
|
+
# 是否需要更新梯度
|
|
595
|
+
if gradient_accumulation_steps > 1:
|
|
596
|
+
need_update_grad = (batch + 1) % gradient_accumulation_steps == 0 or batch == batch_count_per_file - 1
|
|
597
|
+
else:
|
|
598
|
+
need_update_grad = True
|
|
599
|
+
|
|
600
|
+
inputs = batch_data['inputs']
|
|
601
|
+
labels = batch_data['labels']
|
|
602
|
+
|
|
603
|
+
try:
|
|
604
|
+
inputs, labels = inputs.to(TrainerTools().parallel.device), labels.to(TrainerTools().parallel.device)
|
|
605
|
+
attention_mask = inputs != TrainerTools().tokenizer.pad
|
|
606
|
+
pixel_values = self._get_pixel_values(batch_data)
|
|
607
|
+
|
|
608
|
+
if TrainerTools().parallel.parallel_train:
|
|
609
|
+
self.train_model.require_backward_grad_sync = need_update_grad
|
|
610
|
+
|
|
611
|
+
with autocast(TrainerTools().parallel.device_type):
|
|
612
|
+
result = self.train_model(
|
|
613
|
+
inputs,
|
|
614
|
+
attention_mask=attention_mask,
|
|
615
|
+
pixel_values=pixel_values
|
|
616
|
+
)
|
|
617
|
+
|
|
618
|
+
# calc loss
|
|
619
|
+
loss = self._calc_loss(inputs, attention_mask, result['logits'], labels)
|
|
620
|
+
if result['aux_loss'] and self.train_config.loss_config.aux_loss_coef:
|
|
621
|
+
aux_loss = self.train_config.loss_config.aux_loss_coef * result['aux_loss']
|
|
622
|
+
else:
|
|
623
|
+
aux_loss = torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
|
|
624
|
+
|
|
625
|
+
if gradient_accumulation_steps > 1:
|
|
626
|
+
loss = loss / gradient_accumulation_steps
|
|
627
|
+
aux_loss = aux_loss / gradient_accumulation_steps
|
|
628
|
+
|
|
629
|
+
total_loss = loss + aux_loss
|
|
630
|
+
self._backward_loss(total_loss)
|
|
631
|
+
|
|
632
|
+
loss_accumulation += total_loss.detach().item()
|
|
633
|
+
aux_loss_accumulation += aux_loss.detach().item()
|
|
634
|
+
|
|
635
|
+
batches_accumulated += 1
|
|
636
|
+
|
|
637
|
+
if need_update_grad:
|
|
638
|
+
self._apply_grad_clipping()
|
|
639
|
+
self._apply_step()
|
|
640
|
+
|
|
641
|
+
avg_loss, avg_aux_loss = self._avg_loss(
|
|
642
|
+
losses=[
|
|
643
|
+
loss_accumulation,
|
|
644
|
+
aux_loss_accumulation
|
|
645
|
+
],
|
|
646
|
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
|
647
|
+
batches_accumulated=batches_accumulated
|
|
648
|
+
)
|
|
649
|
+
|
|
650
|
+
self._log(
|
|
651
|
+
keys={
|
|
652
|
+
'epoch': epoch,
|
|
653
|
+
'file': f'{file_idx + 1}/{file_count}',
|
|
654
|
+
'batch': f'{batch + 1}/{batch_count_per_file}'
|
|
655
|
+
},
|
|
656
|
+
values={
|
|
657
|
+
'loss': avg_loss,
|
|
658
|
+
'moe_aux_loss': avg_aux_loss
|
|
659
|
+
}
|
|
660
|
+
)
|
|
661
|
+
|
|
662
|
+
# reset to default
|
|
663
|
+
loss_accumulation = 0.0
|
|
664
|
+
aux_loss_accumulation = 0.0
|
|
665
|
+
batches_accumulated = 0
|
|
666
|
+
|
|
667
|
+
if (batch - last_ckpt_batch) >= self.train_config.eval_config.eval_batch_interval:
|
|
668
|
+
save_checkpoint(model=self.train_model, optimizer=self.optimizer)
|
|
669
|
+
save_steps(
|
|
670
|
+
epoch=epoch,
|
|
671
|
+
file_idx=file_idx,
|
|
672
|
+
batch_idx=batch + 1,
|
|
673
|
+
lr_scheduler=self.lr_scheduler
|
|
674
|
+
)
|
|
675
|
+
|
|
676
|
+
last_ckpt_batch = batch
|
|
677
|
+
self._on_batch_end(tag=f'epoch:{epoch}/batch:{batch}')
|
|
678
|
+
except Exception as e:
|
|
679
|
+
self._on_exception(e, epoch, batch)
|
|
680
|
+
|
|
681
|
+
# 一个文件训练结束后,清理内存
|
|
682
|
+
del train_data_loader
|
|
683
|
+
del dataset
|
|
684
|
+
if hasattr(TrainerTools().parallel, '_sampler'):
|
|
685
|
+
TrainerTools().parallel._sampler = None
|
|
686
|
+
|
|
687
|
+
gc.collect()
|
|
688
|
+
torch.cuda.empty_cache()
|
|
689
|
+
|
|
690
|
+
# end epoch
|
|
691
|
+
|
|
692
|
+
# reset resume state
|
|
693
|
+
self.resume_file_idx = 0
|
|
694
|
+
self.resume_batch_idx = 0
|
|
695
|
+
|
|
696
|
+
save_checkpoint(model=self.train_model, optimizer=self.optimizer)
|
|
697
|
+
save_steps(
|
|
698
|
+
epoch=epoch + 1,
|
|
699
|
+
file_idx=0,
|
|
700
|
+
batch_idx=0,
|
|
701
|
+
lr_scheduler=self.lr_scheduler
|
|
702
|
+
)
|
|
703
|
+
|
|
704
|
+
TrainerTools().parallel.on_epoch_end(epoch)
|
|
705
|
+
self._on_epoch_end(tag=f'epoch:{epoch}')
|
|
706
|
+
|
|
707
|
+
TrainerTools().parallel.destroy()
|