project-llm-trainer 0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic. Click here for more details.
- llm_trainer/__init__.py +6 -0
- llm_trainer/checkpoint.py +161 -0
- llm_trainer/dataset.py +140 -0
- llm_trainer/dcp.py +93 -0
- llm_trainer/dpo_trainer.py +300 -0
- llm_trainer/ds_checkpoint.py +61 -0
- llm_trainer/eval.py +86 -0
- llm_trainer/generate_utils.py +424 -0
- llm_trainer/grpo_trainer.py +393 -0
- llm_trainer/log.py +16 -0
- llm_trainer/loss.py +171 -0
- llm_trainer/parallel.py +146 -0
- llm_trainer/parallel_ddp.py +39 -0
- llm_trainer/parallel_ds.py +45 -0
- llm_trainer/parallel_fsdp.py +115 -0
- llm_trainer/parallel_none.py +28 -0
- llm_trainer/scheduler.py +138 -0
- llm_trainer/sft_trainer.py +39 -0
- llm_trainer/tokenizer.py +166 -0
- llm_trainer/tools.py +102 -0
- llm_trainer/train_configs.py +445 -0
- llm_trainer/trainer.py +569 -0
- llm_trainer/utils.py +262 -0
- project_llm_trainer-0.3.data/scripts/calc_intermediate_size +15 -0
- project_llm_trainer-0.3.data/scripts/ddp_train +12 -0
- project_llm_trainer-0.3.data/scripts/ds_train +12 -0
- project_llm_trainer-0.3.data/scripts/plot_loss +39 -0
- project_llm_trainer-0.3.data/scripts/plot_lr +41 -0
- project_llm_trainer-0.3.data/scripts/py_train +12 -0
- project_llm_trainer-0.3.data/scripts/smart_train +28 -0
- project_llm_trainer-0.3.dist-info/METADATA +9 -0
- project_llm_trainer-0.3.dist-info/RECORD +34 -0
- project_llm_trainer-0.3.dist-info/WHEEL +5 -0
- project_llm_trainer-0.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,393 @@
|
|
|
1
|
+
import time
|
|
2
|
+
import copy
|
|
3
|
+
from typing import Tuple, List, Union, Callable, Optional
|
|
4
|
+
import torch
|
|
5
|
+
from torch.utils.data import Dataset
|
|
6
|
+
from torch.nn.utils.rnn import pad_sequence
|
|
7
|
+
import torch.distributed as dist
|
|
8
|
+
import torch.nn.functional as F
|
|
9
|
+
|
|
10
|
+
from llm_model import LlmModel
|
|
11
|
+
|
|
12
|
+
from .parallel_ds import DsParallel
|
|
13
|
+
from .trainer import Trainer
|
|
14
|
+
from .train_configs import TrainConfig
|
|
15
|
+
from .dataset import GRPORolloutDataset
|
|
16
|
+
from .loss import GRPOLoss
|
|
17
|
+
from .tools import TrainerTools
|
|
18
|
+
from .generate_utils import batch_generate
|
|
19
|
+
|
|
20
|
+
from .checkpoint import (
|
|
21
|
+
save_checkpoint,
|
|
22
|
+
load_checkpoint_for_eval,
|
|
23
|
+
save_steps,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
class GRPOTrainer(Trainer):
    """Trainer that optimizes the policy with a GRPO-style objective.

    For every batch of prompts the trainer samples ``group_size`` completions
    per prompt with a separate generation model, scores them with
    ``reward_func``, normalizes the rewards inside each prompt group and runs
    ``grpo_steps`` optimization steps on ``train_model`` using ``GRPOLoss``.
    """

    def __init__(
            self,
            *,
            train_config: TrainConfig,
            reward_func: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], List[float]],
            eval_prompts: List[str],
            eval_image_tags: Optional[List[int]] = None
    ):
        """Build the trainer plus its frozen reference and generation models.

        Args:
            train_config: Training configuration forwarded to ``Trainer``.
            reward_func: Called as ``reward_func(prompts, completion_ids,
                answers)`` with one entry per sampled completion; its result
                is converted with ``torch.tensor`` into one reward per
                completion.
            eval_prompts: Forwarded to ``Trainer``.
            eval_image_tags: Forwarded to ``Trainer``.
        """
        super().__init__(
            train_config=train_config,
            eval_prompts=eval_prompts,
            eval_image_tags=eval_image_tags
        )

        self.reward_func = reward_func
        self.reference_model = self._init_reference_model()
        self.generate_model = self._init_generate_model()

        # Use torch's pad_sequence(padding_side='left') by default. If the
        # installed torch does not accept that keyword this flag is flipped to
        # False and left padding is emulated by flipping the sequences.
        self._use_origin_pad_sequence = True

        # Save a checkpoint of the train model so the reference/generation
        # models can load their weights from it later (see train()).
        save_checkpoint(self.train_model, self.optimizer)

    def _init_reference_model(self):
        """Create the frozen reference model on the CPU in eval mode.

        No weights are loaded here; train() loads them from the checkpoint at
        the start of every epoch.
        """
        reference_model = LlmModel(self.train_config.model_config)

        # Kept on the CPU; train() moves it to the training device only while
        # rollout data is being generated.
        device = 'cpu'
        reference_model.to(device)

        reference_model.eval()
        for param in reference_model.parameters():
            param.requires_grad = False

        return reference_model

    def _init_generate_model(self):
        """Return an independent deep copy of the reference model used for sampling."""
        return copy.deepcopy(self.reference_model)

    def _init_loss(self):
        """Return ``(GRPOLoss, None)`` built from ``grpo_config``."""
        criterion = GRPOLoss(
            clip_eps=self.train_config.grpo_config.clip_eps,
            kl_weight=self.train_config.grpo_config.kl_weight
        )

        return criterion, None

    def _convert_train_args(self) -> Tuple[dict, dict, dict, bool]:
        """Extend the base train args with an identity ``collate_fn``.

        The identity collate keeps each batch as the raw list of dataset
        items, which is what ``_generate_rollout_data`` consumes.
        """
        parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim = super()._convert_train_args()
        data_loader_kwargs.update({"collate_fn": lambda x: x})

        return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim

    def _create_dataset(self, file_path) -> Dataset:
        """Create the rollout dataset for ``file_path``."""
        return GRPORolloutDataset(file_path)

    # Intentionally a no-op: the loss is computed in _maximize_grpo_objective.
    def _calc_loss(self, inputs, attention_mask, logits, labels): ...

    def _left_pad_sequence(
            self,
            sequences: Union[torch.Tensor, List[torch.Tensor]],
            padding_value: float,
    ) -> torch.Tensor:
        """Pad ``sequences`` on the left to a common length (batch first).

        Args:
            sequences: Sequences to pad.
            padding_value: Value used for the padding positions.

        Returns:
            The padded batch with padding placed before each sequence.
        """
        if self._use_origin_pad_sequence:
            try:
                return pad_sequence(sequences, batch_first=True, padding_value=padding_value, padding_side='left')
            except TypeError:
                # pad_sequence of this torch version does not accept the
                # ``padding_side`` keyword: remember that and use the flip
                # based fallback from now on. Only TypeError is handled so
                # that genuine errors are no longer silently swallowed.
                self._use_origin_pad_sequence = False

        # Reverse every sequence (e.g. [1, 2, 3] -> [3, 2, 1]) ...
        reversed_sequences = [seq.flip(dims=(0,)) for seq in sequences]
        # ... pad on the right (the default) ...
        padded_reversed = pad_sequence(reversed_sequences, batch_first=True, padding_value=padding_value)
        # ... and reverse back so the padding ends up on the left.
        return padded_reversed.flip(dims=(1,))

    def _selective_log_softmax(self, logits, input_ids):
        """Return the log-probability of each token in ``input_ids``.

        ``logits`` is expected as [batch_size, seq_len, vocab_size] and
        ``input_ids`` as [batch_size, seq_len]; the result has the shape of
        ``input_ids``.
        """
        # Convert raw logits into log probabilities along the vocabulary axis.
        # [batch_size, seq_len, vocab_size]
        log_probs = F.log_softmax(logits, dim=-1)

        # Reshape input_ids from (batch_size, seq_len) to (batch_size, seq_len, 1) for gathering.
        # Then, gather the log probability for each token in input_ids.
        selected_log_probs = log_probs.gather(dim=-1, index=input_ids.unsqueeze(-1))

        # Remove the extra last dimension to get back to shape (batch_size, seq_len).
        return selected_log_probs.squeeze(-1)

    def _compute_log_probabilities(
            self,
            model,
            input_ids,
            attention_mask,
            logits_to_keep
    ):
        """Compute per-token log-probabilities of the completion tokens.

        Args:
            model: Model called with ``input_ids``, ``attention_mask`` and
                ``logits_to_keep``; must return a mapping with ``'logits'``
                and ``'aux_loss'``.
            input_ids: Prompt followed by completion token ids.
            attention_mask: Mask matching ``input_ids``.
            logits_to_keep: Number of trailing (completion) tokens to score.

        Returns:
            ``(log_probs, aux_loss)`` where ``log_probs`` covers the last
            ``logits_to_keep`` tokens of ``input_ids``.
        """
        # Example: the prompt is [1, 2, 3] and the generated part is [4, 5],
        # so logits_to_keep=2. For the input [1, 2, 3, 4, 5] the model
        # predicts [2, 3, 4, 5, 6]; keeping only the last 2 logits would give
        # the predictions [5, 6], but the ones we need are those for [4, 5].
        # Therefore logits_to_keep + 1 is requested, giving [4, 5, 6], and the
        # last position is dropped below.

        # [batch_size, total_seq_len, vocab_size]
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            logits_to_keep=logits_to_keep + 1
        )

        # [batch_size, total_seq_len - 1, vocab_size]
        logits = outputs['logits'][:, :-1, :]

        input_ids = input_ids[:, -logits_to_keep:]
        logits = logits[:, -logits_to_keep:, :]

        # Compute and return the log probabilities for the selected tokens.
        return self._selective_log_softmax(logits, input_ids), outputs['aux_loss']

    def _compute_group_relative_advantages(self, rewards):
        """Normalize rewards inside each prompt group.

        Args:
            rewards: Flat tensor of shape [batch*group_size], grouped
                contiguously per prompt.

        Returns:
            Advantages of shape [batch*group_size, 1].
        """
        group_size = self.train_config.grpo_config.group_size

        # Reshape rewards to group by prompt
        # [batch, group_size]
        rewards_by_group = rewards.view(-1, group_size)

        # Compute mean and standard deviation for each prompt group
        # [batch]
        group_means = rewards_by_group.mean(dim=1)
        group_stds = rewards_by_group.std(dim=1)

        # Expand the means and stds to match the original flat rewards tensor shape
        # [batch*group_size]
        expanded_means = group_means.repeat_interleave(group_size)
        expanded_stds = group_stds.repeat_interleave(group_size)

        # Normalize rewards to get advantages
        # [batch*group_size]
        advantages = (rewards - expanded_means) / (expanded_stds + 1e-4)

        # [batch*group_size, 1]
        return advantages.unsqueeze(1)  # Add dimension for token-wise operations

    def _generate_completions(self, prompts, group_size: int):
        """Sample ``group_size`` completions for every prompt.

        Returns:
            ``(prompt_ids, prompt_masks, completion_ids, completion_masks)``,
            each with ``batch*group_size`` rows.
        """
        pad_token_id = TrainerTools().tokenizer.pad
        device = TrainerTools().parallel.device

        # Left-pad the prompts so they all share the same length.
        # [batch, max_prompt_len]
        prompt_ids = self._left_pad_sequence(prompts, padding_value=pad_token_id)
        prompt_ids = prompt_ids.to(device)

        prompt_len = prompt_ids.shape[1]

        # [batch*group_size, max_prompt_len]
        prompt_ids = prompt_ids.repeat_interleave(group_size, 0)
        # [batch*group_size, max_prompt_len]
        prompt_masks = prompt_ids != pad_token_id

        # [batch*group_size, max_prompt_len+max_gen_len]
        outputs: torch.Tensor = batch_generate(
            model=self.generate_model,
            tokens=prompt_ids,
            pad_token_id=pad_token_id,
            attention_mask=prompt_masks,
            max_position_embeddings=self.train_config.model_config.max_position_embeddings,
            max_new_tokens=self.train_config.grpo_config.gen_max_new_tokens,
            temperature=self.train_config.grpo_config.gen_temperature,
            k=self.train_config.grpo_config.gen_k,
            p=self.train_config.grpo_config.gen_p,
            device=device,
            suppress_tokens=self.train_config.grpo_config.gen_suppress_tokens
        )

        # [batch*group_size, max_gen_len]
        completion_ids = outputs[:, prompt_len:]
        # [batch*group_size, max_gen_len]
        completion_masks = (completion_ids != pad_token_id).int()

        return prompt_ids, prompt_masks, completion_ids, completion_masks

    def _generate_rollout_data(self, batch_data: List[dict]):
        """Generate completions and the log-probabilities needed by the loss.

        Args:
            batch_data: List of items, each providing ``"prompt"`` and
                ``"answer"``.

        Returns:
            Dict with the concatenated inputs, masks, old/reference
            log-probabilities, completion ids, the prompts/answers repeated
            ``group_size`` times and ``logits_to_keep``.
        """
        prompts = [item["prompt"] for item in batch_data]
        answers = [item["answer"] for item in batch_data]
        group_size = self.train_config.grpo_config.group_size

        # no_grad is used instead of inference_mode to avoid the error
        # "Inference tensors cannot be saved for backward" when these tensors
        # are later used in the training step.
        with torch.no_grad():
            prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_completions(prompts, group_size)
            input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
            attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
            logits_to_keep = completion_ids.shape[1]

            # Compute old_log_probs from the current model, with gradients disabled.
            old_log_probs, _ = self._compute_log_probabilities(self.generate_model, input_ids, attention_mask, logits_to_keep)

            # Compute ref_log_probs from the reference model, which remains static.
            ref_log_probs, _ = self._compute_log_probabilities(self.reference_model, input_ids, attention_mask, logits_to_keep)

        repeated_prompts = [p for p in prompts for _ in range(group_size)]
        repeated_answers = [a for a in answers for _ in range(group_size)]

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'completion_mask': completion_mask,
            'old_log_probs': old_log_probs,
            'ref_log_probs': ref_log_probs,
            'completion_ids': completion_ids,
            'repeated_prompts': repeated_prompts,
            'repeated_answers': repeated_answers,
            'logits_to_keep': logits_to_keep
        }

    def _maximize_grpo_objective(self, rollout_data):
        """Compute the GRPO loss of ``train_model`` for one rollout.

        Args:
            rollout_data: Dict produced by ``_generate_rollout_data``.

        Returns:
            ``(loss, aux_loss)`` where ``aux_loss`` is whatever the model
            returned under ``'aux_loss'``.
        """
        device = TrainerTools().parallel.device

        input_ids = rollout_data['input_ids']
        attention_mask = rollout_data['attention_mask']
        completion_mask = rollout_data['completion_mask']
        old_log_probs = rollout_data['old_log_probs']
        ref_log_probs = rollout_data['ref_log_probs']
        logits_to_keep = rollout_data['logits_to_keep']
        completion_ids = rollout_data['completion_ids']
        repeated_prompts = rollout_data['repeated_prompts']
        repeated_answers = rollout_data['repeated_answers']

        # [batch*group_size]
        rewards = torch.tensor(
            self.reward_func(repeated_prompts, completion_ids, repeated_answers),
            dtype=torch.float32,
            device=device
        )

        # [batch*group_size, 1]
        advantages = self._compute_group_relative_advantages(rewards)

        # Compute current log probabilities
        log_probs, aux_loss = self._compute_log_probabilities(self.train_model, input_ids, attention_mask, logits_to_keep)

        loss = self.criterion(
            log_probs=log_probs,
            old_log_probs=old_log_probs,
            ref_log_probs=ref_log_probs,
            completion_mask=completion_mask,
            advantages=advantages
        )

        return loss, aux_loss

    def train(self):
        """Run the full GRPO training loop over all epochs and dataset files."""
        global_steps = 0
        skipping_train = False
        device = TrainerTools().parallel.device
        aux_loss_coef = self.train_config.loss_config.aux_loss_coef

        for epoch in range(self.train_config.n_epochs):
            # Refresh the reference model from the latest checkpoint once per epoch.
            load_checkpoint_for_eval(model=self.reference_model, device=device)
            self.train_model.train()
            file_count = len(self.train_config.file_dataset)

            for file_idx in range(file_count):
                file_path = self.train_config.file_dataset[file_idx]
                dataset = self._create_dataset(file_path)

                train_data_loader = TrainerTools().parallel.process_dataloader(
                    dataset=dataset,
                    data_loader_kwargs=self.data_loader_kwargs,
                    sampler_kwargs=self.sampler_kwargs
                )

                last_ckpt_batch = 0
                batch_count_per_file = len(train_data_loader)

                TrainerTools().parallel.on_epoch_start(epoch)
                self._on_file_start(epoch, file_path)

                for batch, batch_data in enumerate(train_data_loader):
                    global_steps += 1
                    # Skip batches that were already processed before a resume.
                    if global_steps < self.last_global_steps:
                        skipping_train = True
                        continue

                    skipping_train = False

                    # start generate
                    # A separate model is used for generation because
                    # generating with train_model hangs under deepspeed
                    # parallel training.
                    self.generate_model.to(TrainerTools().parallel.device)
                    self.reference_model.to(TrainerTools().parallel.device)

                    # Reload from the saved train_model checkpoint so the
                    # generation model samples with the latest parameters.
                    load_checkpoint_for_eval(self.generate_model, TrainerTools().parallel.device)
                    # Generate the rollout data.
                    rollout_data = self._generate_rollout_data(batch_data)

                    # Offload to the CPU until the next generation phase.
                    self.generate_model.to('cpu')
                    self.reference_model.to('cpu')
                    torch.cuda.empty_cache()
                    # end generate

                    try:
                        for grpo_step in range(self.train_config.grpo_config.grpo_steps):
                            with self.ctx:
                                loss, aux_loss = self._maximize_grpo_objective(rollout_data)
                                if aux_loss_coef and aux_loss:
                                    loss += aux_loss_coef * aux_loss

                            self._backward_loss(loss)

                            if TrainerTools().parallel.parallel_train:
                                dist.all_reduce(loss, dist.ReduceOp.AVG)

                            # DeepSpeed already integrates gradient clipping.
                            if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
                                # clip grad
                                self.scalar.unscale_(self.optimizer)
                                torch.nn.utils.clip_grad_norm_(self.train_model.parameters(), 1.0)

                            self._step()

                            self._log_loss(
                                epoch_tag=f'epoch: {epoch}',
                                file_tag=f'file: {file_idx + 1}/{file_count}',
                                batch_tag=f'batch: {batch}/{batch_count_per_file}',
                                loss=loss.detach().item()
                            )
                    except Exception as e:
                        self._on_exception(e, epoch, batch)
                    finally:
                        save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)

                        if (batch - last_ckpt_batch) >= self.train_config.eval_batch_interval:
                            save_checkpoint(model=self.train_model, optimizer=self.optimizer)
                            last_ckpt_batch = batch
                            self._on_batch_end(tag=f'epoch:{epoch}/batch:{batch}')

                        # ``loss`` is unbound if the first step failed early.
                        try:
                            del loss
                        except UnboundLocalError: ...

            # end epoch
            if not skipping_train:
                save_checkpoint(model=self.train_model, optimizer=self.optimizer)
                save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)
                TrainerTools().parallel.on_epoch_end(epoch)
                self._on_epoch_end(tag=f'epoch:{epoch}')

        # Give the checkpoint save time to finish before tearing down.
        time.sleep(10)
        TrainerTools().parallel.destroy()
|
llm_trainer/log.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import time, os
|
|
2
|
+
|
|
3
|
+
def get_log_dir() -> str:
    """Return the log directory from ``$LOG_DIR``, creating it if needed.

    Returns:
        The directory path, always ending with ``'/'``.

    Raises:
        KeyError: If the ``LOG_DIR`` environment variable is not set.
    """
    log_dir = os.environ['LOG_DIR']
    # makedirs(exist_ok=True) avoids the check-then-create race when several
    # processes start at once, and also creates missing parent directories.
    os.makedirs(log_dir, exist_ok=True)

    return log_dir if log_dir.endswith('/') else f'{log_dir}/'
|
|
9
|
+
|
|
10
|
+
def log(msg: str, log_file=None):
    """Emit ``msg`` prefixed with the current local time.

    Args:
        msg: Message text. It is written as-is; no newline is appended when
            writing to a file.
        log_file: Path of a file to append to. When falsy the line is printed
            to stdout instead.
    """
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    line = f'[{timestamp}] {msg}'

    if log_file:
        with open(log_file, 'a') as f:
            f.write(line)
        return

    print(line)
|
llm_trainer/loss.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
from typing import List, Optional, Tuple
|
|
2
|
+
import torch
|
|
3
|
+
from torch import nn
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class LMLoss(nn.Module):
    """Next-token cross-entropy loss with optional extra weight on critical tokens."""

    def __init__(
            self,
            ignore_index: int = -100,
            *,
            critical_tokens: Optional[List[int]] = None,
            critical_alpha: float = 1.0,
            vocab_size: int = 0
    ):
        """
        Args:
            ignore_index: Label value excluded from the loss.
            critical_tokens: Token ids whose class weight is set to
                ``critical_alpha``; all other tokens keep weight 1.
            critical_alpha: Class weight of the critical tokens.
            vocab_size: Vocabulary size used to build the weight vector.
                Weighting is only active when both ``critical_tokens`` and a
                positive ``vocab_size`` are given.
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.critical_tokens = critical_tokens
        self.critical_alpha = critical_alpha

        if critical_tokens and vocab_size > 0:
            weights = torch.ones(vocab_size)
            # Give the critical tokens their own class weight.
            weights[critical_tokens] = critical_alpha
        else:
            weights = None

        # Always register the buffer (possibly as None) so forward() can test
        # it directly; previously critical_tokens without vocab_size raised
        # AttributeError in forward(). A None buffer is not saved in the
        # state dict.
        self.register_buffer('weights', weights)

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the mean cross-entropy of predicting ``labels[t + 1]`` from ``logits[t]``.

        Args:
            logits: Shape (batch, seq_len, vocab_size).
            labels: Shape (batch, seq_len).
        """
        # Shift so that position t predicts token t + 1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        logits = shift_logits.reshape(-1, logits.shape[-1])
        targets = shift_labels.reshape(-1)

        weight = None
        if self.weights is not None:
            weight = self.weights.to(logits.device, dtype=logits.dtype)

        ce_loss = F.cross_entropy(
            logits,
            targets,
            ignore_index=self.ignore_index,
            weight=weight
        )

        return ce_loss
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class KDLoss(nn.Module):
    """
    Language Model Knowledge Distillation Loss
    https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/models/loss.py#L266
    """

    def __init__(self, ignore_index: int = -100):
        """
        Args:
            ignore_index: Label value whose positions are excluded from the loss.
        """
        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, logits: torch.Tensor, teacher_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the cross-entropy between teacher and student distributions.

        The per-position value ``-sum(teacher_prob * student_logprob)`` is
        averaged over the positions whose label is not ``ignore_index``.
        When no position is valid the result is 0 instead of NaN.
        """
        teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
        # Entries with infinite student logits are zeroed below so they do
        # not contribute to the sum.
        inf_mask = torch.isinf(logits)

        logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
        prod_probs = torch.masked_fill(teacher_probs * logprobs, inf_mask, 0)

        x = torch.sum(prod_probs, dim=-1).view(-1)
        mask = (labels != self.ignore_index).int().view(-1)

        # Clamp the count so a fully ignored batch yields 0 rather than 0/0.
        valid_count = torch.sum(mask, dim=0).clamp(min=1)
        distil_loss = -torch.sum(x * mask, dim=0) / valid_count

        return distil_loss
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class DPOLoss(nn.Module):
    """Direct Preference Optimization loss, with optional label smoothing or IPO."""

    def __init__(
            self,
            beta: float,
            label_smoothing: float = 0.0,
            ipo: bool = False
    ):
        """
        Args:
            beta: Scale applied to the log-ratio difference.
            label_smoothing: Weight of the flipped-preference term.
            ipo: Use the IPO squared loss instead of the sigmoid loss.
        """
        super().__init__()
        self.beta = beta
        self.label_smoothing = label_smoothing
        self.ipo = ipo

    def forward(
            self,
            policy_logps: torch.Tensor,
            reference_logps: torch.Tensor,
    ) -> torch.Tensor:
        """Return the mean preference loss.

        Both inputs are split at the midpoint of the reference batch: the
        first half is treated as chosen samples, the second half as rejected.
        """
        half = reference_logps.shape[0] // 2

        policy_chosen, policy_rejected = policy_logps[:half], policy_logps[half:]
        ref_chosen, ref_rejected = reference_logps[:half], reference_logps[half:]

        margin = (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected)

        if not self.ipo:
            # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives
            # original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
            preferred_term = -F.logsigmoid(self.beta * margin) * (1 - self.label_smoothing)
            flipped_term = F.logsigmoid(-self.beta * margin) * self.label_smoothing
            per_pair = preferred_term - flipped_term
        else:
            # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
            per_pair = (margin - 1 / (2 * self.beta)) ** 2

        return per_pair.mean()
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class GRPOLoss(nn.Module):
    """Clipped policy-ratio surrogate loss with a KL penalty toward a reference policy."""

    def __init__(
            self,
            clip_eps: float,
            kl_weight: float
    ):
        """
        Args:
            clip_eps: The policy ratio is clamped to ``[1 - clip_eps, 1 + clip_eps]``.
            kl_weight: Weight of the KL penalty term.
        """
        super().__init__()
        self.clip_eps = clip_eps
        self.kl_weight = kl_weight

    def forward(
            self,
            log_probs: torch.Tensor,
            old_log_probs: torch.Tensor,
            ref_log_probs: torch.Tensor,
            completion_mask: torch.Tensor,
            advantages: torch.Tensor
    ) -> torch.Tensor:
        """Return the scalar loss (the negated masked-mean objective).

        Args:
            log_probs: Per-token log-probabilities of the current policy.
            old_log_probs: Per-token log-probabilities of the sampling policy.
            ref_log_probs: Per-token log-probabilities of the reference policy.
            completion_mask: 1 for completion tokens that count, 0 otherwise;
                summed over dim 1 to average per row.
            advantages: Advantages multiplied with the policy ratio.
        """
        # Compute policy ratio
        ratio = torch.exp(log_probs - old_log_probs)

        # Compute surrogate loss with clipping
        surrogate1 = ratio * advantages
        surrogate2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
        surrogate_loss = torch.min(surrogate1, surrogate2)

        # Compute KL divergence penalty
        kl_div = torch.exp(ref_log_probs - log_probs) - (ref_log_probs - log_probs) - 1

        # Combine losses
        per_token_loss = surrogate_loss - self.kl_weight * kl_div

        # A row with no completion tokens (all padding) would divide 0 by 0
        # and turn the whole loss into NaN; clamp so it contributes 0.
        token_counts = completion_mask.sum(dim=1).clamp(min=1)
        loss = -((per_token_loss * completion_mask).sum(dim=1) / token_counts).mean()

        return loss
|