project-llm-trainer 0.13.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -0,0 +1,311 @@
+ from typing import Tuple, List, Optional
+ import gc
+ import torch
+ from torch.utils.data import Dataset
+ from itertools import islice
+
+ from .base_trainer import BaseTrainer
+ from .train_configs import TrainConfig
+ from .dataset import DPODataset
+ from .loss import DPOLoss
+ from .tools import TrainerTools
+ from .utils import (
+     autocast,
+     get_dpo_collate_fn,
+     log_softmax,
+     disable_dropout_in_model
+ )
+
+ from .checkpoint import (
+     save_checkpoint,
+     save_steps,
+ )
+ from .log import Logger
+
+
+ class DPOTrainer(BaseTrainer):
+     def __init__(
+         self,
+         *,
+         train_config: TrainConfig,
+         eval_prompts: List[str]
+     ):
+         self.dpo_config = train_config.dpo_config
+         super().__init__(
+             train_config=train_config,
+             eval_prompts=eval_prompts,
+             gradient_accumulation_steps=self.dpo_config.gradient_accumulation_steps
+         )
+         self.ref_model = self._init_ref_model()
+
+     def _init_ref_model(self):
+         ref_model = self._new_model(self.train_config)
+
+         if self.dpo_config.ref_model_checkpoint:
+             ref_model.load_state_dict(self.dpo_config.ref_model_checkpoint)
+             self.dpo_config.ref_model_checkpoint = {}
+
+         ref_model.eval()
+         for param in ref_model.parameters():
+             param.requires_grad = False
+
+         ref_model, _ = TrainerTools().parallel.process(
+             model=ref_model,
+             optimizer=None,
+             kwargs=self._init_ref_model_args(),
+             save_instance=False
+         )
+
+         return ref_model
+
+     def _new_model(self, train_config: TrainConfig):
+         model = super()._new_model(train_config)
+         disable_dropout_in_model(model)
+         return model
+
+     def _init_loss(self):
+         criterion = DPOLoss(
+             beta=self.dpo_config.loss_beta,
+             label_smoothing=self.dpo_config.loss_label_smoothing,
+             ipo=self.dpo_config.loss_ipo
+         )
+
+         return criterion, None
+
+     def _convert_train_args(self) -> Tuple[dict, dict, dict]:
+         dpo_collate_fn = get_dpo_collate_fn(self.dpo_config.mask_prompt)
+         parallel_kwargs, data_loader_kwargs, sampler_kwargs = super()._convert_train_args()
+         data_loader_kwargs.update({"collate_fn": dpo_collate_fn})
+
+         return parallel_kwargs, data_loader_kwargs, sampler_kwargs
+
+     def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
+         file_path = self.train_config.file_dataset[file_idx]
+         block_size = self.train_config.dataset_block_size
+         return DPODataset(file_path, block_size), file_path
+
+     def _calc_loss(self, inputs, attention_mask, logits, labels): ...
+
+     def _logprobs(self, logits, labels):
+         """
+         Compute per-sequence log probabilities for a batch.
+
+         Args:
+             logits (torch.Tensor): Logits from the model with shape (B, T, V).
+             labels (torch.Tensor): Ground-truth labels with shape (B, T); positions set to -100 are ignored.
+
+         Returns:
+             Tuple[torch.Tensor, torch.Tensor]: Two tensors of shape (B,): the summed log
+             probability of each sequence and its mean log probability over unmasked tokens.
+         """
+         loss_masks = (labels != -100)
+
+         # Shift so that logits at position t predict the label at position t + 1
+         logits = logits[:, :-1, :]
+         labels = labels[:, 1:].clone()
+         loss_masks = loss_masks[:, 1:]
+
+         # Replace ignored positions with a dummy token id; their contribution is masked out below
+         labels[labels == -100] = 0
+
+         # Gather the log probabilities of the actual labels
+         per_token_logps = log_softmax(logits, labels)
+
+         # Mask out ignored tokens before reducing
+         logprobs_sums = (per_token_logps * loss_masks).sum(-1)
+         logprobs_means = (per_token_logps * loss_masks).sum(-1) / loss_masks.sum(-1).clamp(min=1.0)
+
+         return logprobs_sums, logprobs_means
+
+     def train(self):
+         # Number of gradient accumulation steps
+         gradient_accumulation_steps = max(1, self.gradient_accumulation_steps)
+
+         loss_accumulation = 0.0
+         aux_loss_accumulation = 0.0
+         nll_loss_accumulation = 0.0
+         batches_accumulated = 0
+
+         aux_loss_coef = self.train_config.loss_config.aux_loss_coef
+         nll_loss_coef = self.dpo_config.nll_loss_coef
+
+         for epoch in range(self.resume_epoch, self.train_config.n_epochs):
+             self.train_model.train()
+             file_count = len(self.train_config.file_dataset)
+             start_file_idx = self.resume_file_idx if epoch == self.resume_epoch else 0
+
+             for file_idx in range(start_file_idx, file_count):
+                 dataset, file_path = self._create_dataset(file_idx)
+                 train_data_loader = TrainerTools().parallel.process_dataloader(
+                     dataset=dataset,
+                     data_loader_kwargs=self.data_loader_kwargs,
+                     sampler_kwargs=self.sampler_kwargs
+                 )
+
+                 last_ckpt_batch = 0
+                 batch_count_per_file = len(train_data_loader)
+
+                 TrainerTools().parallel.on_epoch_start(epoch)
+                 self._on_file_start(epoch, file_path)
+
+                 skip_batches = 0
+                 if epoch == self.resume_epoch and file_idx == self.resume_file_idx:
+                     skip_batches = self.resume_batch_idx
+                     if skip_batches > 0 and TrainerTools().parallel.is_main_process:
+                         Logger.std_log(f"Fast forwarding {skip_batches} batches in {file_path}...")
+
+                 data_iterator = iter(train_data_loader)
+                 if skip_batches > 0:
+                     data_iterator = islice(data_iterator, skip_batches, None)
+                     last_ckpt_batch = skip_batches
+
+                 for batch, batch_data in enumerate(data_iterator):
+                     batch = skip_batches + batch
+
+                     # Decide whether this step should update the gradients
+                     if gradient_accumulation_steps > 1:
+                         need_update_grad = (batch + 1) % gradient_accumulation_steps == 0 or batch == batch_count_per_file - 1
+                     else:
+                         need_update_grad = True
+
+                     try:
+                         chosen_inputs: torch.Tensor = batch_data['chosen_inputs'].to(TrainerTools().parallel.device)
+                         chosen_labels: torch.Tensor = batch_data['chosen_labels'].to(TrainerTools().parallel.device)
+
+                         rejected_inputs: torch.Tensor = batch_data['rejected_inputs'].to(TrainerTools().parallel.device)
+                         rejected_labels: torch.Tensor = batch_data['rejected_labels'].to(TrainerTools().parallel.device)
+
+                         chosen_attention_masks: torch.Tensor = chosen_inputs != TrainerTools().tokenizer.pad
+                         rejected_attention_masks: torch.Tensor = rejected_inputs != TrainerTools().tokenizer.pad
+
+                         # Concatenate along the batch dimension: [chosen, chosen, rejected, rejected]
+                         concat_inputs = torch.concat([chosen_inputs, rejected_inputs], dim=0)
+                         concat_labels = torch.concat([chosen_labels, rejected_labels], dim=0)
+                         concat_attention_masks = torch.concat([chosen_attention_masks, rejected_attention_masks], dim=0)
+
+                         if TrainerTools().parallel.parallel_train:
+                             self.train_model.require_backward_grad_sync = need_update_grad
+
+                         with autocast(TrainerTools().parallel.device_type):
+                             policy_outputs = self.train_model(concat_inputs, attention_mask=concat_attention_masks)
+                             policy_logprobs_sums, policy_logprobs_means = self._logprobs(policy_outputs['logits'], concat_labels)
+
+                             with torch.no_grad():
+                                 ref_outputs = self.ref_model(concat_inputs, attention_mask=concat_attention_masks)
+                                 ref_logprobs_sums, _ = self._logprobs(ref_outputs['logits'], concat_labels)
+
+                         policy_chosen_logps = policy_logprobs_sums[:chosen_inputs.shape[0]]
+                         policy_rejected_logps = policy_logprobs_sums[chosen_inputs.shape[0]:]
+
+                         ref_chosen_logps = ref_logprobs_sums[:chosen_inputs.shape[0]]
+                         ref_rejected_logps = ref_logprobs_sums[chosen_inputs.shape[0]:]
+
+                         nll_loss = -policy_logprobs_means[:chosen_inputs.shape[0]].mean()
+
+                         # Compute the DPO preference loss
+                         loss = self.criterion(
+                             policy_chosen_logps,
+                             policy_rejected_logps,
+                             ref_chosen_logps,
+                             ref_rejected_logps
+                         )
+
+                         # Optional MoE auxiliary loss returned by the policy model
+                         if aux_loss_coef and policy_outputs.get('aux_loss') is not None:
+                             aux_loss = aux_loss_coef * policy_outputs.get('aux_loss')
+                         else:
+                             aux_loss = torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
+
+                         # Optional NLL regularizer on the chosen responses
+                         if nll_loss_coef:
+                             nll_loss = nll_loss_coef * nll_loss
+                         else:
+                             nll_loss = torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
+
+                         if gradient_accumulation_steps > 1:
+                             loss = loss / gradient_accumulation_steps
+                             aux_loss = aux_loss / gradient_accumulation_steps
+                             nll_loss = nll_loss / gradient_accumulation_steps
+
+                         total_loss = loss + aux_loss + nll_loss
+                         self._backward_loss(total_loss)
+
+                         loss_accumulation += total_loss.detach().item()
+                         aux_loss_accumulation += aux_loss.detach().item()
+                         nll_loss_accumulation += nll_loss.detach().item()
+
+                         batches_accumulated += 1
+
+                         if need_update_grad:
+                             self._apply_grad_clipping()
+                             self._apply_step()
+
+                             avg_loss, avg_aux_loss, avg_nll_loss = self._avg_loss(
+                                 losses=[
+                                     loss_accumulation,
+                                     aux_loss_accumulation,
+                                     nll_loss_accumulation,
+                                 ],
+                                 gradient_accumulation_steps=gradient_accumulation_steps,
+                                 batches_accumulated=batches_accumulated
+                             )
+
+                             self._log(
+                                 keys={
+                                     'epoch': epoch,
+                                     'file': f'{file_idx + 1}/{file_count}',
+                                     'batch': f'{batch + 1}/{batch_count_per_file}',
+                                 },
+                                 values={
+                                     'loss': avg_loss,
+                                     'moe_aux_loss': avg_aux_loss,
+                                     'nll_loss': avg_nll_loss
+                                 }
+                             )
+
+                             # reset to default
+                             loss_accumulation = 0.0
+                             aux_loss_accumulation = 0.0
+                             nll_loss_accumulation = 0.0
+                             batches_accumulated = 0
+
+                         if (batch - last_ckpt_batch) >= self.train_config.eval_config.eval_batch_interval:
+                             save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+                             save_steps(
+                                 epoch=epoch,
+                                 file_idx=file_idx,
+                                 batch_idx=batch + 1,
+                                 lr_scheduler=self.lr_scheduler
+                             )
+
+                             last_ckpt_batch = batch
+                             self._on_batch_end(tag=f'epoch:{epoch}/batch:{batch}')
+                     except Exception as e:
+                         self._on_exception(e, epoch, batch)
+
+                 # After finishing one file, release memory
+                 del train_data_loader
+                 del dataset
+                 if hasattr(TrainerTools().parallel, '_sampler'):
+                     TrainerTools().parallel._sampler = None
+
+                 gc.collect()
+                 torch.cuda.empty_cache()
+
+             # end epoch
+
+             # reset resume state
+             self.resume_file_idx = 0
+             self.resume_batch_idx = 0
+
+             save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+             save_steps(
+                 epoch=epoch + 1,
+                 file_idx=0,
+                 batch_idx=0,
+                 lr_scheduler=self.lr_scheduler
+             )
+
+             TrainerTools().parallel.on_epoch_end(epoch)
+             self._on_epoch_end(tag=f'epoch:{epoch}')
+
+         TrainerTools().parallel.destroy()
+
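For reference, the DPOLoss criterion above consumes the summed log probabilities returned by _logprobs. The sketch below shows the standard sigmoid DPO objective with optional cDPO-style label smoothing; it is a minimal illustration of the textbook formulation, not necessarily the package's exact implementation, and dpo_loss_sketch is a hypothetical name.

import torch
import torch.nn.functional as F

def dpo_loss_sketch(policy_chosen_logps, policy_rejected_logps,
                    ref_chosen_logps, ref_rejected_logps,
                    beta=0.1, label_smoothing=0.0):
    # Log-ratios of the policy vs. the frozen reference model
    chosen_ratio = policy_chosen_logps - ref_chosen_logps
    rejected_ratio = policy_rejected_logps - ref_rejected_logps
    # Preference margin, scaled by beta
    logits = beta * (chosen_ratio - rejected_ratio)
    # Plain DPO is -logsigmoid(logits); label_smoothing > 0 gives the cDPO variant
    losses = (-F.logsigmoid(logits) * (1 - label_smoothing)
              - F.logsigmoid(-logits) * label_smoothing)
    return losses.mean()

# Toy example: the chosen completion gained probability relative to the reference
print(dpo_loss_sketch(torch.tensor([-12.0]), torch.tensor([-15.0]),
                      torch.tensor([-13.0]), torch.tensor([-14.0])))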
@@ -0,0 +1,72 @@
+ import os
+ from glob import glob
+ from typing import Optional
+ import shutil
+ import torch
+ from torch import nn
+ from .tools import TrainerTools
+
+ try:
+     import deepspeed
+     from deepspeed import DeepSpeedEngine
+     from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
+ except ImportError: ...
+
+ """
+ Reference for the DeepSpeed ZeRO checkpoint utilities:
+
+ - get_fp32_state_dict_from_zero_checkpoint: extracts an FP32 state dict from a ZeRO checkpoint
+   (does not load the model into memory, does not write a file); used to obtain weights for inference, migration, etc.
+ - load_state_dict_from_zero_checkpoint: loads model and optimizer state from a ZeRO checkpoint
+   (loads into memory, does not write a file); used to restore training state and continue training.
+ - convert_zero_checkpoint_to_fp32_state_dict: converts a ZeRO checkpoint into a standalone FP32
+   state dict file (does not load the model, writes a file); used to create a portable FP32 weight file for deployment or sharing.
+ """
+
+ def save_ds_checkpoint(
+     model: nn.Module,
+     extra_module: Optional[nn.Module] = None
+ ):
+     assert isinstance(model, DeepSpeedEngine)
+     ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
+
+     try:
+         # Saves the full engine state, including model and optimizer
+         model.save_checkpoint(save_dir=ckpt_dir)
+     except Exception: ...
+
+     # Only run on the main rank
+     if TrainerTools().parallel.is_main_process:
+         if extra_module is not None:
+             torch.save(extra_module.state_dict(), os.path.join(ckpt_dir, "extra_module_state_dict.pt"))
+
+         # Maximum number of checkpoints to keep, default 2
+         max_to_keep = int(os.environ.get('CKPT_MAX_TO_KEEP', '2'))
+         # Remove historical checkpoints
+         ckpt_paths = glob(os.path.join(ckpt_dir, "global_*"))
+         if len(ckpt_paths) > max_to_keep:
+             # Sort by modification time to find the oldest directory
+             oldest_ckpt = sorted(ckpt_paths, key=os.path.getmtime)[0]
+             try:
+                 shutil.rmtree(oldest_ckpt)
+             except Exception: ...
+
+     TrainerTools().parallel.wait('remove old ds checkpoint')
+
+
+ def load_ds_checkpoint(
+     model: nn.Module,
+     load_module_only: bool = False,
+     extra_module: Optional[nn.Module] = None
+ ):
+     assert isinstance(model, DeepSpeedEngine)
+     ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
+
+     # Restores the full engine state, including model and optimizer
+     if os.path.exists(ckpt_dir):
+         model.load_checkpoint(
+             load_dir=ckpt_dir,
+             load_module_only=load_module_only
+         )
+
+     path = os.path.join(ckpt_dir, "extra_module_state_dict.pt")
+     if extra_module is not None and os.path.exists(path):
+         state = torch.load(path, map_location=TrainerTools().parallel.device, weights_only=True)
+         extra_module.load_state_dict(state)
+
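The two helpers above operate on a DeepSpeedEngine and read DIST_CHECKPOINT_DIR and CKPT_MAX_TO_KEEP from the environment. A minimal usage sketch, assuming the script runs under the deepspeed launcher, save_ds_checkpoint and load_ds_checkpoint are imported from this module, and the engine config is purely illustrative:

import os
import deepspeed
import torch.nn as nn

os.environ.setdefault('DIST_CHECKPOINT_DIR', 'checkpoint')  # where global_* shard dirs are written
os.environ.setdefault('CKPT_MAX_TO_KEEP', '2')              # retention applied by save_ds_checkpoint

model = nn.Linear(16, 16)
engine, _, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config={
        'train_batch_size': 8,
        'optimizer': {'type': 'Adam', 'params': {'lr': 1e-3}},
        'zero_optimization': {'stage': 1},
    },
)

save_ds_checkpoint(engine)                          # saves model + optimizer state, prunes the oldest global_* dir
load_ds_checkpoint(engine, load_module_only=True)   # restores module weights only, skipping optimizer state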
llm_trainer/eval.py ADDED
@@ -0,0 +1,33 @@
+ import os
+ import torch
+
+ from .generate_utils import generate
+ from .tools import TrainerTools
+ from .train_configs import TrainConfig
+ from .log import _get_log_dir
+
+ def submit_gen_task(
+     eval_model: torch.nn.Module,
+     train_config: TrainConfig,
+     tag,
+     prompt,
+     pixel_values,
+     tokens_per_image
+ ):
+     tokens = TrainerTools().tokenizer.encode(prompt, unsqueeze=True, covert_tensor=True)
+     max_new_tokens = max(train_config.eval_config.max_seq_len - tokens.shape[1], 0)
+
+     gen_result = generate(
+         eval_model,
+         prompt=tokens,
+         max_new_tokens=max_new_tokens,
+         temperature=train_config.eval_config.temperature,
+         k=train_config.eval_config.top_k,
+         p=train_config.eval_config.top_p,
+         pixel_values=pixel_values,
+         tokens_per_image=tokens_per_image,
+         device=TrainerTools().parallel.device
+     )
+
+     with open(os.path.join(_get_log_dir(), 'gen.txt'), 'a') as f:
+         f.write(f"{tag}, gen->{gen_result}\n")
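submit_gen_task encodes a prompt, generates with the eval model using the eval_config sampling settings, and appends the result to gen.txt in the log directory. A hypothetical call site, where model and config stand in for the trainer's eval model and its TrainConfig:

submit_gen_task(
    eval_model=model,
    train_config=config,
    tag='epoch:0/batch:100',
    prompt='Hello, how are you?',
    pixel_values=None,        # text-only evaluation
    tokens_per_image=None,
)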