project-llm-trainer 0.12.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,297 @@
+ from typing import Tuple, List, Optional
+ import gc
+ import torch
+ from torch.utils.data import Dataset
+
+ from .base_trainer import BaseTrainer
+ from .train_configs import TrainConfig
+ from .dataset import DPODataset
+ from .loss import DPOLoss
+ from .tools import TrainerTools
+ from .utils import (
+     autocast,
+     get_dpo_collate_fn,
+     log_softmax,
+     disable_dropout_in_model
+ )
+
+ from .checkpoint import (
+     save_checkpoint,
+     save_steps,
+ )
+
+
+ class DPOTrainer(BaseTrainer):
+     def __init__(
+         self,
+         *,
+         train_config: TrainConfig,
+         eval_prompts: List[str]
+     ):
+         self.dpo_config = train_config.dpo_config
+         super().__init__(
+             train_config=train_config,
+             eval_prompts=eval_prompts,
+             gradient_accumulation_steps=self.dpo_config.gradient_accumulation_steps
+         )
+         self.ref_model = self._init_ref_model()
+
+     def _init_ref_model(self):
+         ref_model = self._new_model(self.train_config)
+
+         if self.dpo_config.ref_model_checkpoint:
+             ref_model.load_state_dict(self.dpo_config.ref_model_checkpoint)
+             # Release the state dict so it can be garbage collected.
+             self.dpo_config.ref_model_checkpoint = {}
+
+         # Freeze the reference model: eval mode, no gradients.
+         ref_model.eval()
+         for param in ref_model.parameters():
+             param.requires_grad = False
+
+         ref_model, _ = TrainerTools().parallel.process(
+             model=ref_model,
+             optimizer=None,
+             kwargs=self._init_ref_model_args(),
+             save_instance=False
+         )
+
+         return ref_model
+
+     def _new_model(self, train_config: TrainConfig):
+         model = super()._new_model(train_config)
+         disable_dropout_in_model(model)
+         return model
+
+     def _init_loss(self):
+         criterion = DPOLoss(
+             beta=self.dpo_config.loss_beta,
+             label_smoothing=self.dpo_config.loss_label_smoothing,
+             ipo=self.dpo_config.loss_ipo
+         )
+
+         return criterion, None
+
+     def _convert_train_args(self) -> Tuple[dict, dict, dict]:
+         dpo_collate_fn = get_dpo_collate_fn(self.dpo_config.mask_prompt)
+         parallel_kwargs, data_loader_kwargs, sampler_kwargs = super()._convert_train_args()
+         data_loader_kwargs.update({"collate_fn": dpo_collate_fn})
+
+         return parallel_kwargs, data_loader_kwargs, sampler_kwargs
+
+     def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
+         file_path = self.train_config.file_dataset[file_idx]
+         max_seq_len = self.train_config.max_seq_len
+         return DPODataset(file_path, max_seq_len), file_path
+
+     # Unused for DPO; the preference loss is computed inline in train().
+     def _calc_loss(self, inputs, attention_mask, logits, labels): ...
+
+     def _logprobs(self, logits, labels):
+         """
+         Calculate per-sequence log probabilities for a batch of sequences.
+
+         Args:
+             logits (torch.Tensor): Logits from the model with shape (B, T, V).
+             labels (torch.Tensor): Ground truth labels with shape (B, T).
+
+         Returns:
+             Tuple[torch.Tensor, torch.Tensor]: Summed and mean log probabilities
+                 for each sequence in the batch, each with shape (B,).
+         """
+         loss_masks = (labels != -100)
+
+         # Shift: logits at position t predict the token at position t + 1.
+         logits = logits[:, :-1, :]
+         labels = labels[:, 1:].clone()
+         loss_masks = loss_masks[:, 1:]
+
+         # Dummy token; the losses on these positions are masked out below
+         labels[labels == -100] = 0
+
+         # Gather the log probabilities for the actual labels
+         per_token_logps = log_softmax(logits, labels)
+
+         # Apply the mask to zero out the log-probs of padding tokens
+         logprobs_sums = (per_token_logps * loss_masks).sum(-1)
+         logprobs_means = (per_token_logps * loss_masks).sum(-1) / loss_masks.sum(-1).clamp(min=1.0)
+
+         return logprobs_sums, logprobs_means
+
+     def train(self):
+         # Gradient accumulation steps
+         gradient_accumulation_steps = max(1, self.gradient_accumulation_steps)
+         global_steps = 0
+         skipping_train = False
+
+         loss_accumulation = 0.0
+         aux_loss_accumulation = 0.0
+         nll_loss_accumulation = 0.0
+         batches_accumulated = 0
+
+         aux_loss_coef = self.train_config.loss_config.aux_loss_coef
+         nll_loss_coef = self.dpo_config.nll_loss_coef
+
+         for epoch in range(self.train_config.n_epochs):
+             self.train_model.train()
+             file_count = len(self.train_config.file_dataset)
+
+             for file_idx in range(file_count):
+                 dataset, file_path = self._create_dataset(file_idx)
+                 train_data_loader = TrainerTools().parallel.process_dataloader(
+                     dataset=dataset,
+                     data_loader_kwargs=self.data_loader_kwargs,
+                     sampler_kwargs=self.sampler_kwargs
+                 )
+
+                 last_ckpt_batch = 0
+                 batch_count_per_file = len(train_data_loader)
+
+                 TrainerTools().parallel.on_epoch_start(epoch)
+                 self._on_file_start(epoch, file_path)
+
+                 for batch, batch_data in enumerate(train_data_loader):
+                     global_steps += 1
+                     if global_steps < self.last_global_steps:
+                         skipping_train = True
+                         continue
+
+                     # Decide whether this step should apply a gradient update
+                     if skipping_train:
+                         need_update_grad = False
+                     elif gradient_accumulation_steps > 1:
+                         need_update_grad = (batch + 1) % gradient_accumulation_steps == 0 or batch == batch_count_per_file - 1
+                     else:
+                         need_update_grad = True
+
+                     # Must stay after the need_update_grad assignment; works around an
+                     # otherwise unexplained hang when resuming training.
+                     if skipping_train:
+                         TrainerTools().parallel.wait('skip train')
+                         skipping_train = False
+
+                     try:
+                         chosen_inputs: torch.Tensor = batch_data['chosen_inputs'].to(TrainerTools().parallel.device)
+                         chosen_labels: torch.Tensor = batch_data['chosen_labels'].to(TrainerTools().parallel.device)
+
+                         rejected_inputs: torch.Tensor = batch_data['rejected_inputs'].to(TrainerTools().parallel.device)
+                         rejected_labels: torch.Tensor = batch_data['rejected_labels'].to(TrainerTools().parallel.device)
+
+                         chosen_attention_masks: torch.Tensor = chosen_inputs != TrainerTools().tokenizer.pad
+                         rejected_attention_masks: torch.Tensor = rejected_inputs != TrainerTools().tokenizer.pad
+
+                         # Concatenate along the batch dimension:
+                         # [chosen, chosen, rejected, rejected]
+                         concat_inputs = torch.concat([chosen_inputs, rejected_inputs], dim=0)
+                         concat_labels = torch.concat([chosen_labels, rejected_labels], dim=0)
+                         concat_attention_masks = torch.concat([chosen_attention_masks, rejected_attention_masks], dim=0)
+
+                         if TrainerTools().parallel.parallel_train:
+                             self.train_model.require_backward_grad_sync = need_update_grad
+
+                         with autocast(TrainerTools().parallel.device_type):
+                             policy_outputs = self.train_model(concat_inputs, attention_mask=concat_attention_masks)
+                             policy_logprobs_sums, policy_logprobs_means = self._logprobs(policy_outputs['logits'], concat_labels)
+
+                             with torch.no_grad():
+                                 ref_outputs = self.ref_model(concat_inputs, attention_mask=concat_attention_masks)
+                                 ref_logprobs_sums, _ = self._logprobs(ref_outputs['logits'], concat_labels)
+
+                         # The first half of the concatenated batch is chosen, the second half rejected.
+                         policy_chosen_logps = policy_logprobs_sums[:chosen_inputs.shape[0]]
+                         policy_rejected_logps = policy_logprobs_sums[chosen_inputs.shape[0]:]
+
+                         ref_chosen_logps = ref_logprobs_sums[:chosen_inputs.shape[0]]
+                         ref_rejected_logps = ref_logprobs_sums[chosen_inputs.shape[0]:]
+
+                         # Optional NLL term over the chosen responses only.
+                         nll_loss = -policy_logprobs_means[:chosen_inputs.shape[0]].mean()
+
+                         # calc loss
+                         loss = self.criterion(
+                             policy_chosen_logps,
+                             policy_rejected_logps,
+                             ref_chosen_logps,
+                             ref_rejected_logps
+                         )
+
+                         if aux_loss_coef and policy_outputs.get('aux_loss') is not None:
+                             aux_loss = aux_loss_coef * policy_outputs.get('aux_loss')
+                         else:
+                             aux_loss = torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
+
+                         if nll_loss_coef and nll_loss:
+                             nll_loss = nll_loss_coef * nll_loss
+                         else:
+                             nll_loss = torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
+
+                         # Scale by the accumulation window so gradients average correctly.
+                         if gradient_accumulation_steps > 1:
+                             loss = loss / gradient_accumulation_steps
+                             aux_loss = aux_loss / gradient_accumulation_steps
+                             nll_loss = nll_loss / gradient_accumulation_steps
+
+                         total_loss = loss + aux_loss + nll_loss
+                         self._backward_loss(total_loss)
+
+                         loss_accumulation += total_loss.detach().item()
+                         aux_loss_accumulation += aux_loss.detach().item()
+                         nll_loss_accumulation += nll_loss.detach().item()
+
+                         batches_accumulated += 1
+
+                         if need_update_grad:
+                             self._apply_grad_clipping()
+                             self._apply_step()
+
+                             avg_loss, avg_aux_loss, avg_nll_loss = self._avg_loss(
+                                 losses=[
+                                     loss_accumulation,
+                                     aux_loss_accumulation,
+                                     nll_loss_accumulation,
+                                 ],
+                                 gradient_accumulation_steps=gradient_accumulation_steps,
+                                 batches_accumulated=batches_accumulated
+                             )
+
+                             self._log(
+                                 keys={
+                                     'epoch': epoch,
+                                     'file': f'{file_idx + 1}/{file_count}',
+                                     'batch': f'{batch}/{batch_count_per_file}',
+                                 },
+                                 values={
+                                     'loss': avg_loss,
+                                     'moe_aux_loss': avg_aux_loss,
+                                     'nll_loss': avg_nll_loss
+                                 }
+                             )
+
+                             # Reset accumulators to their defaults
+                             loss_accumulation = 0.0
+                             aux_loss_accumulation = 0.0
+                             nll_loss_accumulation = 0.0
+                             batches_accumulated = 0
+                     except Exception as e:
+                         self._on_exception(e, epoch, batch)
+                     finally:
+                         if need_update_grad:
+                             save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)
+
+                             if (batch - last_ckpt_batch) >= self.train_config.eval_config.eval_batch_interval:
+                                 save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+
+                                 last_ckpt_batch = batch
+                                 self._on_batch_end(tag=f'epoch:{epoch}/batch:{batch}')
+
+                 # After finishing a file, free its memory
+                 del train_data_loader
+                 del dataset
+                 if hasattr(TrainerTools().parallel, '_sampler'):
+                     TrainerTools().parallel._sampler = None
+
+                 gc.collect()
+                 torch.cuda.empty_cache()
+
+             # end epoch
+             if not skipping_train:
+                 save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)
+                 save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+
+             TrainerTools().parallel.on_epoch_end(epoch)
+             self._on_epoch_end(tag=f'epoch:{epoch}')
+
+         TrainerTools().parallel.destroy()
+
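
The DPOLoss criterion consumed above is not included in this diff. For orientation, a minimal sketch of the standard DPO objective over the four per-sequence log-prob tensors produced by _logprobs might look like the following; the function name, defaults, and the exact handling of the ipo and label_smoothing options are assumptions about the package's implementation, not a copy of it:

    import torch
    import torch.nn.functional as F

    def dpo_loss_sketch(
        policy_chosen_logps: torch.Tensor,    # (B,) summed log-probs, policy model
        policy_rejected_logps: torch.Tensor,  # (B,)
        ref_chosen_logps: torch.Tensor,       # (B,) summed log-probs, frozen reference
        ref_rejected_logps: torch.Tensor,     # (B,)
        beta: float = 0.1,
        label_smoothing: float = 0.0,
        ipo: bool = False,
    ) -> torch.Tensor:
        # Margin of chosen over rejected, measured relative to the reference model.
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = ref_chosen_logps - ref_rejected_logps
        logits = pi_logratios - ref_logratios

        if ipo:
            # IPO regresses the margin toward 1 / (2 * beta).
            losses = (logits - 1 / (2 * beta)) ** 2
        else:
            # Sigmoid DPO loss with optional conservative label smoothing.
            losses = (
                -F.logsigmoid(beta * logits) * (1 - label_smoothing)
                - F.logsigmoid(-beta * logits) * label_smoothing
            )
        return losses.mean()

Here beta scales how strongly the policy is pushed away from the reference, matching the loss_beta, loss_label_smoothing, and loss_ipo fields passed to DPOLoss in _init_loss.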
@@ -0,0 +1,63 @@
+ import os
+ from glob import glob
+ import shutil
+ from torch import nn
+ from .tools import TrainerTools
+
+ try:
+     import deepspeed
+     from deepspeed import DeepSpeedEngine
+     from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
+ except ImportError: ...
+
+ """
+ Function                                    | Purpose                                                    | Loads model | Writes file | Primary use
+ get_fp32_state_dict_from_zero_checkpoint    | Extract an FP32 state dict from a ZeRO checkpoint          | No          | No          | Obtain model weights for inference, migration, etc.
+ load_state_dict_from_zero_checkpoint        | Load model and optimizer state from a ZeRO checkpoint      | Yes         | No          | Restore training state and continue training
+ convert_zero_checkpoint_to_fp32_state_dict  | Convert a ZeRO checkpoint to a standalone FP32 state dict  | No          | Yes         | Create a portable FP32 weights file for deployment, sharing, etc.
+ """
+
+ def save_ds_checkpoint(model: nn.Module):
+     assert isinstance(model, DeepSpeedEngine)
+     ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
+
+     try:
+         # Includes the model, optimizer, and other training state
+         model.save_checkpoint(save_dir=ckpt_dir)
+     except Exception: ...
+
+     # Only run the cleanup on the main rank
+     if TrainerTools().parallel.is_main_process:
+         # Maximum number of checkpoints to keep, default 2
+         max_to_keep = int(os.environ.get('CKPT_MAX_TO_KEEP', '2'))
+         # Delete stale checkpoints
+         ckpt_paths = glob(os.path.join(ckpt_dir, "global_*"))
+         if len(ckpt_paths) > max_to_keep:
+             # Sort by modification time to find the oldest directory
+             oldest_ckpt = sorted(ckpt_paths, key=os.path.getmtime)[0]
+             try:
+                 shutil.rmtree(oldest_ckpt)
+             except OSError: ...
+
+     TrainerTools().parallel.wait('remove old ds checkpoint')
+
+
+ def load_ds_checkpoint(
+     model: nn.Module,
+     load_module_only: bool = False
+ ):
+     assert isinstance(model, DeepSpeedEngine)
+     ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
+
+     # Restores the model, optimizer, and other training state
+     if os.path.exists(ckpt_dir):
+         model.load_checkpoint(
+             load_dir=ckpt_dir,
+             load_module_only=load_module_only
+         )
+
+
+ def load_ds_checkpoint_for_eval(model: nn.Module):
+     ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
+     state_dict = get_fp32_state_dict_from_zero_checkpoint(ckpt_dir)
+     model.load_state_dict(state_dict)
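
A hypothetical end-to-end flow for these helpers might look like the sketch below; engine and eval_model are placeholder names (a DeepSpeedEngine returned by deepspeed.initialize and a plain nn.Module, respectively), not objects defined by this package:

    import os

    # Both environment variables are read by the helpers above.
    os.environ.setdefault('DIST_CHECKPOINT_DIR', 'checkpoint')
    os.environ.setdefault('CKPT_MAX_TO_KEEP', '2')

    load_ds_checkpoint(engine)      # resume model + optimizer state if a checkpoint exists
    # ... run training steps ...
    save_ds_checkpoint(engine)      # write a checkpoint, pruning the oldest beyond the cap

    # Later, materialize FP32 weights onto a plain module for evaluation.
    load_ds_checkpoint_for_eval(eval_model)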
llm_trainer/eval.py ADDED
@@ -0,0 +1,33 @@
+ import os
+ import torch
+
+ from .generate_utils import generate
+ from .tools import TrainerTools
+ from .train_configs import TrainConfig
+ from .log import _get_log_dir
+
+ def submit_gen_task(
+     eval_model: torch.nn.Module,
+     train_config: TrainConfig,
+     tag,
+     prompt,
+     pixel_values,
+     tokens_per_image
+ ):
+     tokens = TrainerTools().tokenizer.encode(prompt, unsqueeze=True, covert_tensor=True)
+     max_new_tokens = train_config.eval_config.max_new_tokens
+
+     gen_result = generate(
+         eval_model,
+         prompt=tokens,
+         max_new_tokens=max_new_tokens,
+         temperature=train_config.eval_config.temperature,
+         k=train_config.eval_config.top_k,
+         p=train_config.eval_config.top_p,
+         pixel_values=pixel_values,
+         tokens_per_image=tokens_per_image,
+         device=TrainerTools().parallel.device
+     )
+
+     with open(os.path.join(_get_log_dir(), 'gen.txt'), 'a') as f:
+         f.write(f"{tag}, gen->{gen_result}\n")
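
A hypothetical call site, assuming model is a trained text-only torch.nn.Module and config is a populated TrainConfig (both names are placeholders); for a text-only model the image arguments are simply None:

    submit_gen_task(
        eval_model=model,
        train_config=config,
        tag='epoch:0/batch:100',
        prompt='What is the capital of France?',
        pixel_values=None,       # no image inputs for a text-only model
        tokens_per_image=None,
    )

Each call appends one "tag, gen->result" line to gen.txt in the log directory, giving a quick record of generation quality over the course of training.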