project-llm-trainer 0.12.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,385 @@
+ from typing import Tuple, List, Callable, Optional
+ import gc
+ import torch
+ from torch.utils.data import Dataset
+ import torch.nn.functional as F
+
+ from .base_trainer import BaseTrainer
+ from .train_configs import TrainConfig
+ from .dataset import RLDataset
+ from .loss import GRPOLoss
+ from .tools import TrainerTools
+ from .generate_utils import batch_generate
+ from .log import Logger
+ from .utils import (
+     autocast,
+     left_pad_sequence,
+     log_softmax,
+     disable_dropout_in_model,
+     calc_position_ids
+ )
+
+ from .partition_utils import (
+     sync_model_params,
+     unwrap_model_for_generation
+ )
+
+ from .checkpoint import (
+     save_checkpoint,
+     save_steps,
+ )
+
+ class GRPOTrainer(BaseTrainer):
+     """
+     reward_func(prompt_ids, complete_ids, answer_ids) -> scores
+     """
+     def __init__(
+         self,
+         *,
+         train_config: TrainConfig,
+         reward_func: Callable[[List[torch.Tensor], torch.Tensor, List[Optional[torch.Tensor]]], List[float]],
+         eval_prompts: List[str]
+     ):
+         self.grpo_config = train_config.grpo_config
+         super().__init__(
+             train_config=train_config,
+             eval_prompts=eval_prompts
+         )
+
+         self.reward_func = reward_func
+         self.ref_model = self._init_ref_model()
+
+     def _init_ref_model(self):
+         # beta == 0: no ref_model is needed
+         if self.grpo_config.loss_beta == 0.0:
+             return None
+
+         ref_model = self._new_model(self.train_config)
+
+         ref_model.eval()
+         for param in ref_model.parameters():
+             param.requires_grad = False
+
+         ref_model, _ = TrainerTools().parallel.process(
+             model=ref_model,
+             optimizer=None,
+             kwargs=self._init_ref_model_args(),
+             save_instance=False
+         )
+
+         return ref_model
+
+     def _new_model(self, train_config: TrainConfig):
+         model = super()._new_model(train_config)
+         disable_dropout_in_model(model)
+         return model
+
+     def _init_loss(self):
+         criterion = GRPOLoss(
+             beta=self.grpo_config.loss_beta,
+             clip_eps_low=self.grpo_config.loss_clip_eps,
+             clip_eps_high=self.grpo_config.loss_clip_eps_high,
+             delta=self.grpo_config.loss_delta,
+             importance_sampling_level=self.grpo_config.loss_importance_sampling_level,
+             loss_type=self.grpo_config.loss_type,
+             gen_max_new_tokens=self.grpo_config.gen_max_new_tokens
+         )
+
+         return criterion, None
+
+     def _convert_train_args(self) -> Tuple[dict, dict, dict]:
+         parallel_kwargs, data_loader_kwargs, sampler_kwargs = super()._convert_train_args()
+         data_loader_kwargs.update({"collate_fn": lambda x: x})
+
+         return parallel_kwargs, data_loader_kwargs, sampler_kwargs
+
+     def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
+         file_path = self.train_config.file_dataset[file_idx]
+         return RLDataset(file_path), file_path
+
+     def _calc_loss(self, inputs, attention_mask, logits, labels): ...
+
+     def _compute_log_probs(
+         self,
+         model,
+         input_ids,
+         attention_mask
+     ):
+         position_ids = calc_position_ids(attention_mask)
+
+         # [batch_size, total_seq_len, vocab_size]
+         outputs = model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids
+         )
+
+         # [batch_size, total_seq_len - 1, vocab_size]
+         logits = outputs['logits'][:, :-1, :]
+         input_ids = input_ids[:, 1:]
+
+         # Compute and return the log probabilities for the selected tokens.
+         return log_softmax(logits, input_ids), outputs['aux_loss']
+
+     def _compute_group_relative_advantages(self, rewards):
+         group_size = self.grpo_config.group_size
+
+         # Reshape rewards to group by prompt
+         # [batch, group_size]
+         rewards_by_group = rewards.view(-1, group_size)
+
+         # Compute mean and standard deviation for each prompt group
+         # [batch]
+         group_means = rewards_by_group.mean(dim=1)
+         group_stds = rewards_by_group.std(dim=1)
+
+         # Expand the means and stds to match the original flat rewards tensor shape
+         # [batch*group_size]
+         expanded_means = group_means.repeat_interleave(group_size)
+         expanded_stds = group_stds.repeat_interleave(group_size)
+
+         # Normalize rewards to get advantages
+         # [batch*group_size]
+         advantages = (rewards - expanded_means) / (expanded_stds + 1e-4)
+
+         # [batch*group_size, 1]
+         return advantages.unsqueeze(1)  # Add dimension for token-wise operations
+
+     def _generate_completions(self, model, prompts, group_size: int):
+         pad_token_id = TrainerTools().tokenizer.pad
+         device = TrainerTools().parallel.device
+
+         # Left-pad so all prompts are aligned to the same length
+         # [batch, max_prompt_len]
+         prompt_ids = left_pad_sequence(prompts, padding_value=pad_token_id)
+         prompt_ids = prompt_ids.to(device)
+
+         prompt_len = prompt_ids.shape[1]
+
+         # [batch*group_size, max_prompt_len]
+         prompt_ids = prompt_ids.repeat_interleave(group_size, 0)
+         # [batch*group_size, max_prompt_len]
+         prompt_masks = prompt_ids != pad_token_id
+
+         # [batch*group_size, max_prompt_len+max_gen_len]
+         outputs, _ = batch_generate(
+             model=model,
+             tokens=prompt_ids,
+             attention_mask=prompt_masks,
+             max_new_tokens=self.grpo_config.gen_max_new_tokens,
+             temperature=self.grpo_config.gen_temperature,
+             k=self.grpo_config.gen_k,
+             p=self.grpo_config.gen_p,
+             device=device,
+             suppress_tokens=self.grpo_config.gen_suppress_tokens,
+             return_logits=False
+         )
+
+         # [batch*group_size, max_gen_len]
+         completion_ids = outputs[:, prompt_len:]
+         # [batch*group_size, max_gen_len]
+         completion_masks = (completion_ids != pad_token_id).int()
+
+         return prompt_ids, prompt_masks, completion_ids, completion_masks
+
+     def _generate_rollout_data(self, generate_model, batch_data: List[dict]):
+         prompts = [item["prompt"] for item in batch_data]
+         answers = [item["answer"] for item in batch_data]
+         group_size = self.grpo_config.group_size
+
+         # Use no_grad instead of inference_mode.
+         # Fixes: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal
+         with torch.no_grad():
+             # with torch.inference_mode():
+             prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_completions(generate_model, prompts, group_size)
+             input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+             attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+
+             old_log_probs, _ = self._compute_log_probs(generate_model, input_ids, attention_mask)
+
+             if self.ref_model:
+                 ref_log_probs, _ = self._compute_log_probs(self.ref_model, input_ids, attention_mask)
+             else:
+                 ref_log_probs = None
+
+         repeated_prompts = [p for p in prompts for _ in range(group_size)]
+         repeated_answers = [a for a in answers for _ in range(group_size)]
+
+         return {
+             'input_ids': input_ids,
+             'attention_mask': attention_mask,
+             'completion_mask': completion_mask,
+             'old_log_probs': old_log_probs,
+             'ref_log_probs': ref_log_probs,
+             'completion_ids': completion_ids,
+             'repeated_prompts': repeated_prompts,
+             'repeated_answers': repeated_answers,
+         }
+
+     def _maximize_grpo_objective(self, rollout_data):
+         device = TrainerTools().parallel.device
+
+         input_ids = rollout_data['input_ids']
+         attention_mask = rollout_data['attention_mask']
+         completion_mask = rollout_data['completion_mask']
+         old_log_probs = rollout_data['old_log_probs']
+         ref_log_probs = rollout_data['ref_log_probs']
+         completion_ids = rollout_data['completion_ids']
+         repeated_prompts = rollout_data['repeated_prompts']
+         repeated_answers = rollout_data['repeated_answers']
+
+         prompt_len = input_ids.shape[1] - completion_ids.shape[1]
+
+         # [batch*group_size]
+         rewards = torch.tensor(
+             self.reward_func(repeated_prompts, completion_ids, repeated_answers),
+             dtype=torch.float32,
+             device=device
+         )
+
+         # [batch*group_size, 1]
+         advantages = self._compute_group_relative_advantages(rewards)
+
+         # Compute current log probabilities
+         log_probs, aux_loss = self._compute_log_probs(self.train_model, input_ids, attention_mask)
+
+         pad_len = prompt_len - 1
+         if pad_len > 0:
+             padded_completion_mask = F.pad(completion_mask, (pad_len, 0), 'constant', 0)
+         else:
+             padded_completion_mask = completion_mask
+
+         assert padded_completion_mask.shape == log_probs.shape, \
+             f"Shape mismatch! Padded completion mask: {padded_completion_mask.shape}, Log probs: {log_probs.shape}"
+
+         loss = self.criterion(
+             log_probs=log_probs,
+             old_log_probs=old_log_probs,
+             ref_log_probs=ref_log_probs,
+             completion_mask=padded_completion_mask,
+             advantages=advantages
+         )
+
+         return loss, aux_loss, rewards
+
+     def train(self):
+         global_steps = 0
+         skipping_train = False
+         aux_loss_coef = self.train_config.loss_config.aux_loss_coef
+
+         for epoch in range(self.train_config.n_epochs):
+             if self.ref_model:
+                 sync_model_params(
+                     _from=self.train_model,
+                     _to=self.ref_model,
+                     mixup_alpha=self.grpo_config.mixup_alpha
+                 )
+
+             file_count = len(self.train_config.file_dataset)
+
+             for file_idx in range(file_count):
+                 dataset, file_path = self._create_dataset(file_idx)
+
+                 train_data_loader = TrainerTools().parallel.process_dataloader(
+                     dataset=dataset,
+                     data_loader_kwargs=self.data_loader_kwargs,
+                     sampler_kwargs=self.sampler_kwargs
+                 )
+
+                 last_ckpt_batch = 0
+                 batch_count_per_file = len(train_data_loader)
+
+                 TrainerTools().parallel.on_epoch_start(epoch)
+                 self._on_file_start(epoch, file_path)
+
+                 for batch, batch_data in enumerate(train_data_loader):
+                     global_steps += 1
+                     if global_steps < self.last_global_steps:
+                         skipping_train = True
+                         continue
+
+                     if skipping_train:
+                         TrainerTools().parallel.wait('skip train')
+                         skipping_train = False
+
+                     # start generate
+                     if TrainerTools().parallel.is_main_process:
+                         Logger.std_log(f'start generate for batch {batch}/{batch_count_per_file}')
+
+                     # Generate rollout data
+                     with unwrap_model_for_generation(self.train_model) as generate_model:
+                         rollout_data = self._generate_rollout_data(generate_model, batch_data)
+                     # end generate
+
+                     torch.cuda.empty_cache()
+
+                     try:
+                         if TrainerTools().parallel.is_main_process:
+                             Logger.std_log(f'start train for batch {batch}/{batch_count_per_file}')
+
+                         for grpo_step in range(self.grpo_config.grpo_steps):
+                             with autocast(TrainerTools().parallel.device_type):
+                                 loss, aux_loss, rewards = self._maximize_grpo_objective(rollout_data)
+                                 if aux_loss_coef and aux_loss is not None:
+                                     aux_loss = aux_loss_coef * aux_loss
+                                 else:
+                                     aux_loss = torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
+
+                                 total_loss = loss + aux_loss
+                             self._backward_loss(total_loss)
+                             self._apply_grad_clipping()
+                             self._apply_step()
+
+                             loss_accumulation = total_loss.detach().item()
+                             aux_loss_accumulation = aux_loss.detach().item()
+
+                             avg_loss, avg_aux_loss = self._avg_loss(
+                                 losses=[
+                                     loss_accumulation,
+                                     aux_loss_accumulation
+                                 ],
+                                 gradient_accumulation_steps=1,
+                                 batches_accumulated=1
+                             )
+
+                             self._log(
+                                 keys={
+                                     'epoch': epoch,
+                                     'file': f'{file_idx + 1}/{file_count}',
+                                     'batch': f'{batch}/{batch_count_per_file}',
+                                     'grpo_step': grpo_step
+                                 },
+                                 values={
+                                     'loss': avg_loss,
+                                     'moe_aux_loss': avg_aux_loss,
+                                     'rewards': (rewards.sum() / rewards.size(0)).item(),
+                                 }
+                             )
+                     except Exception as e:
+                         self._on_exception(e, epoch, batch)
+                     finally:
+                         save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)
+
+                         if (batch - last_ckpt_batch) >= self.train_config.eval_config.eval_batch_interval:
+                             save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+                             last_ckpt_batch = batch
+                             self._on_batch_end(tag=f'epoch:{epoch}/batch:{batch}')
+
+                 # After finishing one dataset file, free memory
+                 del train_data_loader
+                 del dataset
+                 if hasattr(TrainerTools().parallel, '_sampler'):
+                     TrainerTools().parallel._sampler = None
+
+                 gc.collect()
+                 torch.cuda.empty_cache()
+
+             # end epoch
+             if not skipping_train:
+                 save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)
+                 save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+
+             TrainerTools().parallel.on_epoch_end(epoch)
+             self._on_epoch_end(tag=f'epoch:{epoch}')
+
+         TrainerTools().parallel.destroy()
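
For readers skimming this diff, the heart of GRPOTrainer is the group-relative advantage step: each prompt is sampled group_size times, rewards inside each group are normalized by the group mean and standard deviation, and the resulting advantages feed the clipped GRPO loss. The sketch below is illustration only, not part of the package: toy_reward_func is a made-up reward (it favors longer non-pad completions and assumes pad id 0), while group_relative_advantages mirrors _compute_group_relative_advantages above.

import torch

def toy_reward_func(prompt_ids, completion_ids, answer_ids):
    # Hypothetical reward: fraction of non-pad tokens in each completion (pad id assumed to be 0).
    # Only the signature matches what GRPOTrainer expects; real rewards are task-specific.
    return [(row != 0).sum().item() / row.numel() for row in completion_ids]

def group_relative_advantages(rewards: torch.Tensor, group_size: int) -> torch.Tensor:
    # Same normalization as _compute_group_relative_advantages:
    # (reward - group mean) / (group std + eps), returned as [batch*group_size, 1].
    grouped = rewards.view(-1, group_size)
    means = grouped.mean(dim=1).repeat_interleave(group_size)
    stds = grouped.std(dim=1).repeat_interleave(group_size)
    return ((rewards - means) / (stds + 1e-4)).unsqueeze(1)

# Two prompts, group_size = 4: normalizing per group makes rewards comparable across prompts
# even when one prompt's raw rewards are uniformly higher than the other's.
rewards = torch.tensor([0.1, 0.2, 0.9, 0.2, 5.0, 5.1, 5.2, 4.9])
print(group_relative_advantages(rewards, group_size=4))
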
llm_trainer/log.py ADDED
@@ -0,0 +1,65 @@
+ import time, os, atexit
+ from io import TextIOWrapper
+ from typing import Optional
+
+
+ def _get_log_dir() -> str:
+     log_dir = os.environ.get('LOG_DIR', './log')
+     os.makedirs(log_dir, exist_ok=True)
+     return log_dir
+
+
+ class Logger:
+     def __init__(self, log_file_name = None, log_dir = None):
+         self.log_file_name = log_file_name
+         self.log_file: Optional[TextIOWrapper] = None
+
+         if not log_dir:
+             self.log_dir = _get_log_dir()
+         else:
+             os.makedirs(log_dir, exist_ok=True)
+             self.log_dir = log_dir
+
+         self.flush_interval = int(os.environ.get('LOG_FLUSH_INTERVAL', '1'))
+         self.log_steps = 0
+
+     @staticmethod
+     def std_log(msg: str):
+         log_content = Logger._build_log(msg)
+         print(log_content)
+
+     def log(self, msg: str, log_to_console = True):
+         log_content = Logger._build_log(msg)
+
+         if log_to_console:
+             print(log_content)
+
+         if self._open_file():
+             self.log_file.write(f'{log_content}\n')
+             if self.log_steps % self.flush_interval == 0:
+                 self.log_file.flush()
+
+         self.log_steps += 1
+         return self
+
+     def release(self):
+         if self.log_file:
+             self.log_file.close()
+             self.log_file = None
+
+     @staticmethod
+     def _build_log(msg: str):
+         cur_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+         return f'[{cur_time}] {msg}'
+
+     def _open_file(self) -> bool:
+         if not self.log_file_name:
+             return False
+
+         if self.log_file:
+             return True
+
+         self.log_file = open(os.path.join(self.log_dir, self.log_file_name), 'a', encoding='utf-8')
+         atexit.register(self.release)
+
+         return True
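
A short usage note on the Logger added above (again illustration only, not shipped in the wheel): output goes to LOG_DIR (default ./log), the file handle is flushed roughly every LOG_FLUSH_INTERVAL calls to log(), and release() is registered via atexit the first time a log file is opened, so explicit cleanup is optional.

import os
from llm_trainer.log import Logger

os.environ['LOG_DIR'] = './log'            # directory for log files (created if missing)
os.environ['LOG_FLUSH_INTERVAL'] = '10'    # flush the file handle every 10 log() calls

Logger.std_log('console-only message')     # static helper; never opens a file

logger = Logger(log_file_name='train.log') # file is opened lazily on the first log()
logger.log('epoch 0 started')              # printed and appended to ./log/train.log
logger.log('debug detail', log_to_console=False)  # written to the file only
logger.release()                           # optional; atexit closes the file anyway
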