project-llm-trainer 0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of project-llm-trainer might be problematic.

@@ -0,0 +1,300 @@
+ import time
+ from typing import Tuple, List, Optional
+ import torch
+ from torch.utils.data import Dataset
+ import torch.distributed as dist
+ import torch.nn.functional as F
+
+ from llm_model import LlmModel
+
+ from .parallel_ds import DsParallel
+ from .parallel_fsdp import FsdpParallel
+ from .trainer import Trainer
+ from .train_configs import TrainConfig
+ from .dataset import DPODataset
+ from .loss import DPOLoss
+ from .tools import TrainerTools
+ from .utils import get_dpo_collate_fn
+
+ from .checkpoint import (
+     save_checkpoint,
+     load_checkpoint_for_eval,
+     save_steps,
+ )
+
+ class DPOTrainer(Trainer):
+     def __init__(
+         self,
+         *,
+         train_config: TrainConfig,
+         eval_prompts: List[str],
+         eval_image_tags: Optional[List[int]] = None
+     ):
+         super().__init__(
+             train_config=train_config,
+             eval_prompts=eval_prompts,
+             eval_image_tags=eval_image_tags
+         )
+
+         self.reference_model = self._init_reference_model()
+
+     def _init_reference_model(self):
+         parallel = TrainerTools().new_parallel()
+
+         reference_model = LlmModel(self.train_config.model_config)
+         if self.train_config.init_state_dict:
+             reference_model.load_state_dict(self.train_config.init_state_dict, strict=False)
+             self.train_config.init_state_dict = None
+         else:
+             load_checkpoint_for_eval(model=reference_model, device=parallel.device)
+
+         reference_model, _ = parallel.process(
+             model=reference_model,
+             optimizer=None,
+             kwargs=self._init_reference_args()
+         )
+
+         # the reference model is frozen: eval mode, no gradients
+         parallel.raw_model.eval()
+         for param in parallel.raw_model.parameters():
+             param.requires_grad = False
+
+         return reference_model
+
+     def _init_reference_args(self):
+         # the reference model only runs forward passes, so its DeepSpeed/FSDP
+         # kwargs disable gradient accumulation and ZeRO partitioning
+         if isinstance(TrainerTools().parallel, DsParallel) and self.train_config.ds_config:
+             parallel_kwargs = {
+                 'gradient_accumulation_steps': 1,
+                 'train_micro_batch_size_per_gpu': 1
+             }
+
+             if self.train_config.ds_config.zero_config:
+                 zero_optimization = {'stage': 0}
+                 parallel_kwargs['zero_optimization'] = zero_optimization
+
+             if self.train_config.ds_config.fp16_config:
+                 fp16_config = self.train_config.ds_config.fp16_config
+                 fp16 = {'enabled': fp16_config.enabled}
+
+                 if fp16_config.fp16_opt_level is not None:
+                     fp16['fp16_opt_level'] = fp16_config.fp16_opt_level
+
+                 parallel_kwargs['fp16'] = fp16
+
+             if self.train_config.ds_config.bf16_config:
+                 bf16_config = self.train_config.ds_config.bf16_config
+                 bf16 = {'enabled': bf16_config.enabled}
+                 parallel_kwargs['bf16'] = bf16
+         elif isinstance(TrainerTools().parallel, FsdpParallel) and self.train_config.fsdp_config:
+             parallel_kwargs = {
+                 'transformer_layer_cls': self.train_config.fsdp_config.transformer_layer_cls,
+                 'wrap_policy_num_params': self.train_config.fsdp_config.wrap_policy_num_params,
+                 'cpu_offload': self.train_config.fsdp_config.cpu_offload,
+                 'offload_params': self.train_config.fsdp_config.offload_params
+             }
+         else:
+             parallel_kwargs = None
+
+         return parallel_kwargs
+
+     def _init_loss(self):
+         criterion = DPOLoss(
+             beta=self.train_config.dpo_config.loss_beta,
+             label_smoothing=self.train_config.dpo_config.loss_label_smoothing,
+             ipo=self.train_config.dpo_config.loss_ipo
+         )
+
+         return criterion, None
+
+     def _convert_train_args(self) -> Tuple[dict, dict, dict, bool]:
+         dpo_collate_fn = get_dpo_collate_fn(self.train_config.mask_prompt)
+         parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim = super()._convert_train_args()
+         data_loader_kwargs.update({"collate_fn": dpo_collate_fn})
+
+         return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim
+
+     def _create_dataset(self, file_path) -> Dataset:
+         max_position_embeddings = self.train_config.model_config.max_position_embeddings
+         return DPODataset(file_path, max_position_embeddings)
+
+     def _calc_loss(self, inputs, attention_mask, logits, labels): ...
+
+     def _log_probs_from_logits(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+         # https://github.com/OpenRLHF/OpenRLHF/pull/718#issuecomment-2641081881
+         if logits.dtype in [torch.float32, torch.float64]:
+             logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
+             logsumexp_values = torch.stack(
+                 [torch.logsumexp(l, dim=-1) for l in logits]  # loop to reduce peak mem consumption
+             )
+             log_probs_labels = logits_labels - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
+         else:
+             log_probs_labels = []
+             for row_logits, row_labels in zip(logits, labels):  # loop to reduce peak mem consumption
+                 row_log_probs = F.log_softmax(row_logits, dim=-1)
+                 row_log_probs_labels = row_log_probs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
+                 log_probs_labels.append(row_log_probs_labels)
+             log_probs_labels = torch.stack(log_probs_labels)
+
+         return log_probs_labels
+
+
+     def _logprobs(self, logits, labels, mask):
+         """
+         Calculate the summed log probabilities of the labels for a batch of sequences.
+
+         Args:
+             logits (torch.Tensor): Logits from the model with shape (B, T, V).
+             labels (torch.Tensor): Ground truth labels with shape (B, T).
+             mask (torch.Tensor): Mask tensor with shape (B, T) indicating
+                 which tokens are not padding (1 for valid tokens, 0 for padding).
+
+         Returns:
+             torch.Tensor: Summed log probabilities for each sequence in the batch,
+                 with shape (B,).
+         """
+         # shift so that each position predicts the next token
+         labels = labels[:, 1:].clone()
+         logits = logits[:, :-1, :]
+
+         # Shift mask right by one to align with labels
+         mask = mask[:, 1:].clone()
+
+         # dummy token; we'll ignore the losses on these tokens later
+         labels[labels == -100] = 0
+
+         # Gather the log probabilities for the actual labels
+         per_token_logps = self._log_probs_from_logits(logits, labels)
+
+         # Apply the mask to zero out log-probs of padding tokens, then sum per sequence
+         logprobs_sums = (per_token_logps * mask).sum(-1)
+
+         # logprobs_means = (per_token_logps * mask).sum(-1) / mask.sum(-1)
+
+         return logprobs_sums  # , -logprobs_means.mean()
+
+     def train(self):
+         # number of gradient accumulation steps
+         gradient_accumulation_steps = self.train_config.gradient_accumulation_steps
+         global_steps = 0
+         loss_accumulation = 0.0
+         skipping_train = False
+
+         aux_loss_coef = self.train_config.loss_config.aux_loss_coef
+
+         for epoch in range(self.train_config.n_epochs):
+             self.train_model.train()
+             file_count = len(self.train_config.file_dataset)
+
+             for file_idx in range(file_count):
+                 file_path = self.train_config.file_dataset[file_idx]
+
+                 dataset = self._create_dataset(file_path)
+                 train_data_loader = TrainerTools().parallel.process_dataloader(
+                     dataset=dataset,
+                     data_loader_kwargs=self.data_loader_kwargs,
+                     sampler_kwargs=self.sampler_kwargs
+                 )
+
+                 last_ckpt_batch = 0
+                 batch_count_per_file = len(train_data_loader)
+
+                 TrainerTools().parallel.on_epoch_start(epoch)
+                 self._on_file_start(epoch, file_path)
+
+                 for batch, batch_data in enumerate(train_data_loader):
+                     global_steps += 1
+                     if global_steps < self.last_global_steps:
+                         skipping_train = True
+                         continue
+
+                     skipping_train = False
+
+                     # decide whether the optimizer should step on this batch
+                     if gradient_accumulation_steps > 1:
+                         need_update_grad = (batch + 1) % gradient_accumulation_steps == 0 or batch == batch_count_per_file - 1
+                     else:
+                         need_update_grad = True
+
+                     try:
+                         chosen_inputs: torch.Tensor = batch_data['chosen_inputs'].to(TrainerTools().parallel.device)
+                         chosen_labels: torch.Tensor = batch_data['chosen_labels'].to(TrainerTools().parallel.device)
+                         rejected_inputs: torch.Tensor = batch_data['rejected_inputs'].to(TrainerTools().parallel.device)
+                         rejected_labels: torch.Tensor = batch_data['rejected_labels'].to(TrainerTools().parallel.device)
+
+                         chosen_attention_mask: torch.Tensor = chosen_inputs != TrainerTools().tokenizer.pad
+                         rejected_attention_mask: torch.Tensor = rejected_inputs != TrainerTools().tokenizer.pad
+
+                         # concatenate along the batch dimension:
+                         # [chosen, chosen, rejected, rejected]
+                         concat_inputs = torch.concat([chosen_inputs, rejected_inputs], dim=0)
+                         concat_labels = torch.concat([chosen_labels, rejected_labels], dim=0)
+                         concat_mask = torch.concat([chosen_attention_mask, rejected_attention_mask], dim=0)
+
+                         if TrainerTools().parallel.parallel_train:
+                             self.train_model.require_backward_grad_sync = need_update_grad
+
+                         with self.ctx:
+                             policy_outputs = self.train_model(concat_inputs, attention_mask=concat_mask)
+                             with torch.inference_mode():
+                                 ref_outputs = self.reference_model(concat_inputs, attention_mask=concat_mask)
+
+                             policy_probs = self._logprobs(policy_outputs['logits'], concat_labels, concat_mask)
+                             ref_probs = self._logprobs(ref_outputs['logits'], concat_labels, concat_mask)
+
+                             # calc loss
+                             loss = self.criterion(policy_probs, ref_probs)
+
+                             if aux_loss_coef and policy_outputs['aux_loss']:
+                                 loss += aux_loss_coef * policy_outputs['aux_loss']
+
+                             if gradient_accumulation_steps > 1:
+                                 loss = loss / gradient_accumulation_steps
+
+                         loss_accumulation += loss.detach()
+                         self._backward_loss(loss)
+
+                         if need_update_grad:
+                             # todo check all_reduce??
+                             if TrainerTools().parallel.parallel_train:
+                                 dist.all_reduce(loss_accumulation, dist.ReduceOp.AVG)
+
+                             # DeepSpeed mode already applies gradient clipping internally
+                             if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
+                                 # clip grad
+                                 self.scalar.unscale_(self.optimizer)
+                                 torch.nn.utils.clip_grad_norm_(self.train_model.parameters(), 1.0)
+
+                             self._step()
+
+                             self._log_loss(
+                                 epoch_tag=f'epoch: {epoch}',
+                                 file_tag=f'file: {file_idx + 1}/{file_count}',
+                                 batch_tag=f'batch: {batch}/{batch_count_per_file}',
+                                 loss=loss_accumulation.item()
+                             )
+                             # reset to default
+                             loss_accumulation = 0.0
+                     except Exception as e:
+                         self._on_exception(e, epoch, batch)
+                     finally:
+                         if need_update_grad:
+                             save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)
+
+                             if (batch - last_ckpt_batch) >= self.train_config.eval_batch_interval:
+                                 save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+                                 last_ckpt_batch = batch
+                                 self._on_batch_end(tag=f'epoch:{epoch}/batch:{batch}')
+
+                         try:
+                             del loss
+                         except UnboundLocalError: ...
+
+             # end epoch
+             if not skipping_train:
+                 save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+                 save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)
+             TrainerTools().parallel.on_epoch_end(epoch)
+             self._on_epoch_end(tag=f'epoch:{epoch}')
+
+         # wait for the checkpoint to finish saving
+         time.sleep(10)
+         TrainerTools().parallel.destroy()
+
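The DPOLoss module imported from .loss is not included in this diff. For reference, the sketch below shows how a DPO loss with beta, label_smoothing, and ipo options is commonly implemented, matching the [chosen, rejected] batch concatenation and the summed log-probs that DPOTrainer.train() passes to self.criterion. The function name, signature, and defaults here are illustrative assumptions, not the package's actual code.

    # Illustrative sketch only -- not the package's DPOLoss implementation.
    # Assumes the first half of each tensor holds chosen sequences and the
    # second half rejected ones, as produced by DPOTrainer.train().
    import torch
    import torch.nn.functional as F

    def dpo_loss_sketch(policy_logps: torch.Tensor,
                        ref_logps: torch.Tensor,
                        beta: float = 0.1,
                        label_smoothing: float = 0.0,
                        ipo: bool = False) -> torch.Tensor:
        half = policy_logps.shape[0] // 2
        policy_chosen, policy_rejected = policy_logps[:half], policy_logps[half:]
        ref_chosen, ref_rejected = ref_logps[:half], ref_logps[half:]

        # margin of policy-vs-reference log-ratios between chosen and rejected
        logits = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)

        if ipo:
            # IPO objective: squared distance to the 1/(2*beta) target margin
            losses = (logits - 1 / (2 * beta)) ** 2
        else:
            # standard DPO objective with optional conservative label smoothing
            losses = (-F.logsigmoid(beta * logits) * (1 - label_smoothing)
                      - F.logsigmoid(-beta * logits) * label_smoothing)
        return losses.mean()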
@@ -0,0 +1,61 @@
+ import os
+ from typing import Optional
+ from glob import glob
+ import shutil
+ from torch import nn
+ try:
+     from deepspeed import DeepSpeedEngine
+     from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
+ except ImportError: ...
+
+ """
+ get_fp32_state_dict_from_zero_checkpoint:
+     Extracts an FP32 state dict from a ZeRO checkpoint. Does not load the model into
+     memory and does not write any file. Used to obtain weights for inference, migration, etc.
+ load_state_dict_from_zero_checkpoint:
+     Loads model and optimizer state from a ZeRO checkpoint. Loads the model into memory
+     but does not write any file. Used to restore training state and resume training.
+ convert_zero_checkpoint_to_fp32_state_dict:
+     Converts a ZeRO checkpoint into a standalone FP32 state dict file. Does not load the
+     model into memory but does write a file. Used to create a portable FP32 weight file
+     for deployment, sharing, etc.
+ """
+
+ def save_ds_checkpoint(
+     model: nn.Module,
+     suffix: Optional[str] = None
+ ):
+     assert isinstance(model, DeepSpeedEngine)
+     ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
+     if suffix:
+         ckpt_dir = f"{ckpt_dir}_{suffix}"
+
+     try:
+         # saves model, optimizer and other engine state
+         model.save_checkpoint(save_dir=ckpt_dir)
+     except Exception:
+         return
+
+     # delete old checkpoints
+     ckpt_paths = glob(os.path.join(ckpt_dir, "global_*"))
+     if len(ckpt_paths) > 2:
+         # sort by modification time and find the oldest directory
+         oldest_ckpt = sorted(ckpt_paths, key=os.path.getmtime)[0]
+         try:
+             shutil.rmtree(oldest_ckpt)
+         except Exception: ...
+
+
+ def load_ds_checkpoint(
+     model: nn.Module,
+     load_module_only: bool = False,
+     suffix: Optional[str] = None
+ ):
+     assert isinstance(model, DeepSpeedEngine)
+     ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
+     if suffix:
+         ckpt_dir = f"{ckpt_dir}_{suffix}"
+
+     # restores model, optimizer and other engine state
+     if os.path.exists(ckpt_dir):
+         model.load_checkpoint(load_dir=ckpt_dir, load_module_only=load_module_only)
+
+
+ def load_ds_checkpoint_for_eval(model: nn.Module):
+     ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
+     state_dict = get_fp32_state_dict_from_zero_checkpoint(ckpt_dir)
+     model.load_state_dict(state_dict)
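The helpers above rely on DeepSpeed's zero_to_fp32 utilities summarized in the docstring. A minimal usage sketch, assuming a ZeRO checkpoint directory like the one written by save_ds_checkpoint; the paths are placeholders, and the second argument of convert_zero_checkpoint_to_fp32_state_dict is an output file in older DeepSpeed releases and an output directory in newer ones.

    # Illustrative sketch only -- not part of the package.
    import torch
    from deepspeed.utils.zero_to_fp32 import (
        get_fp32_state_dict_from_zero_checkpoint,
        convert_zero_checkpoint_to_fp32_state_dict,
    )

    ckpt_dir = 'checkpoint'  # placeholder; DIST_CHECKPOINT_DIR in the code above

    # 1) Build an FP32 state dict in memory and persist it yourself:
    state_dict = get_fp32_state_dict_from_zero_checkpoint(ckpt_dir)
    torch.save(state_dict, 'model_fp32.pt')

    # 2) Or let DeepSpeed write the consolidated FP32 weights directly:
    convert_zero_checkpoint_to_fp32_state_dict(ckpt_dir, 'model_fp32.pt')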
llm_trainer/eval.py ADDED
@@ -0,0 +1,86 @@
+ import time
+
+ import torch
+
+ from .generate_utils import generate
+ from .checkpoint import load_checkpoint_for_eval
+ from .log import get_log_dir
+ from .tools import TrainerTools
+ from .train_configs import EvalConfig
+
+
+ def _eval_task(
+     eval_model: torch.nn.Module,
+     eval_config: EvalConfig,
+     tag,
+     prompt,
+     pixel_values,
+     max_position_embeddings,
+     tokens_per_image,
+     device
+ ):
+     log_dir = get_log_dir()
+
+     # This can be tried when eval_model is not a standalone model:
+     # if isinstance(eval_model, FSDP):
+     #     with FSDP.summon_full_params(module=eval_model, writeback=False, recurse=False):
+     #         gen = generate(
+     #             eval_model,
+     #             prompt=prompt,
+     #             max_position_embeddings=max_position_embeddings,
+     #             max_new_tokens=max_new_tokens,
+     #             # temperature=None,
+     #             # k=None,
+     #             # p=None,
+     #             device='cpu',
+     #             item_callback=lambda item: write_temp(item)
+     #         )
+
+     # ---------
+     try:
+         load_checkpoint_for_eval(eval_model, device=device)
+     except Exception:
+         return
+
+     gen_result = generate(
+         eval_model,
+         prompt=prompt,
+         max_position_embeddings=max_position_embeddings,
+         max_new_tokens=eval_config.max_new_tokens,
+         temperature=eval_config.temperature,
+         k=eval_config.top_k,
+         p=eval_config.top_p,
+         pixel_values=pixel_values,
+         tokens_per_image=tokens_per_image,
+         device=device
+     )
+
+     with open(f'{log_dir}gen.txt', 'a') as f:
+         f.write(f"{tag}, gen->{gen_result}\n")
+
+
+ def submit_gen_task(
+     eval_model: torch.nn.Module,
+     eval_config: EvalConfig,
+     tag,
+     prompt,
+     pixel_values,
+     max_position_embeddings,
+     tokens_per_image
+ ):
+     # wait 1s to avoid a missing-checkpoint race in deepspeed mode
+     time.sleep(1)
+     eval_model.to(TrainerTools().parallel.device)
+     _eval_task(
+         eval_model=eval_model,
+         eval_config=eval_config,
+         tag=tag,
+         prompt=prompt,
+         pixel_values=pixel_values,
+         max_position_embeddings=max_position_embeddings,
+         tokens_per_image=tokens_per_image,
+         device=TrainerTools().parallel.device
+     )
+     eval_model.to('cpu')
+
+     # threading.Thread(target=_eval_task, args=args).start()