project-llm-trainer 0.13.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of project-llm-trainer might be problematic.

@@ -0,0 +1,410 @@
+ from typing import Tuple, List, Callable, Optional
+ import gc
+ import torch
+ from torch.utils.data import Dataset
+ import torch.nn.functional as F
+ from itertools import islice
+
+ from .base_trainer import BaseTrainer
+ from .train_configs import TrainConfig
+ from .dataset import RLDataset
+ from .loss import GRPOLoss
+ from .tools import TrainerTools
+ from .generate_utils import batch_generate
+ from .log import Logger
+ from .utils import (
+     autocast,
+     left_pad_sequence,
+     log_softmax,
+     disable_dropout_in_model,
+     calc_position_ids
+ )
+
+ from .partition_utils import (
+     sync_model_params,
+     unwrap_model_for_generation
+ )
+
+ from .checkpoint import (
+     save_checkpoint,
+     save_steps,
+ )
+
+ class GRPOTrainer(BaseTrainer):
+     """
+     GRPO (Group Relative Policy Optimization) trainer.
+
+     reward_func(prompt_ids, completion_ids, answer_ids) -> scores
+     """
+     def __init__(
+         self,
+         *,
+         train_config: TrainConfig,
+         reward_func: Callable[[List[torch.Tensor], torch.Tensor, List[Optional[torch.Tensor]]], List[float]],
+         eval_prompts: List[str]
+     ):
+         self.grpo_config = train_config.grpo_config
+         super().__init__(
+             train_config=train_config,
+             eval_prompts=eval_prompts
+         )
+
+         self.reward_func = reward_func
+         self.ref_model = self._init_ref_model()
+
+     def _init_ref_model(self):
+         # When loss_beta == 0 the KL penalty is disabled, so no reference model is needed.
+         if self.grpo_config.loss_beta == 0.0:
+             return None
+
+         ref_model = self._new_model(self.train_config)
+
+         # The reference model is frozen: eval mode, no gradients.
+         ref_model.eval()
+         for param in ref_model.parameters():
+             param.requires_grad = False
+
+         ref_model, _ = TrainerTools().parallel.process(
+             model=ref_model,
+             optimizer=None,
+             kwargs=self._init_ref_model_args(),
+             save_instance=False
+         )
+
+         return ref_model
+
+     def _new_model(self, train_config: TrainConfig):
+         model = super()._new_model(train_config)
+         # Disable dropout so log-prob computations are deterministic across forward passes.
+         disable_dropout_in_model(model)
+         return model
+
+     def _init_loss(self):
+         criterion = GRPOLoss(
+             beta=self.grpo_config.loss_beta,
+             clip_eps_low=self.grpo_config.loss_clip_eps,
+             clip_eps_high=self.grpo_config.loss_clip_eps_high,
+             delta=self.grpo_config.loss_delta,
+             importance_sampling_level=self.grpo_config.loss_importance_sampling_level,
+             loss_type=self.grpo_config.loss_type
+         )
+
+         return criterion, None
+
+     def _convert_train_args(self) -> Tuple[dict, dict, dict]:
+         parallel_kwargs, data_loader_kwargs, sampler_kwargs = super()._convert_train_args()
+         # RL batches are lists of prompt/answer dicts; keep them as-is instead of collating into tensors.
+         data_loader_kwargs.update({"collate_fn": lambda x: x})
+
+         return parallel_kwargs, data_loader_kwargs, sampler_kwargs
+
+     def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
+         file_path = self.train_config.file_dataset[file_idx]
+         return RLDataset(file_path), file_path
+
+     # The GRPO loss is computed in _maximize_grpo_objective, so the base-class hook is a no-op.
+     def _calc_loss(self, inputs, attention_mask, logits, labels): ...
+
+     def _compute_log_probs(
+         self,
+         model,
+         input_ids,
+         attention_mask
+     ):
+         position_ids = calc_position_ids(attention_mask)
+
+         # [batch_size, total_seq_len, vocab_size]
+         outputs = model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids
+         )
+
+         # Shift by one: logits at position t predict the token at position t + 1.
+         # [batch_size, total_seq_len - 1, vocab_size]
+         logits = outputs['logits'][:, :-1, :]
+         input_ids = input_ids[:, 1:]
+
+         # Compute and return the log probabilities for the selected tokens.
+         return log_softmax(logits, input_ids), outputs['aux_loss']
+
+     def _compute_group_relative_advantages(self, rewards):
+         group_size = self.grpo_config.group_size
+
+         # Reshape rewards to group by prompt
+         # [batch, group_size]
+         rewards_by_group = rewards.view(-1, group_size)
+
+         # Compute mean and standard deviation for each prompt group
+         # [batch]
+         group_means = rewards_by_group.mean(dim=1)
+         group_stds = rewards_by_group.std(dim=1)
+
+         # Expand the means and stds to match the original flat rewards tensor shape
+         # [batch*group_size]
+         expanded_means = group_means.repeat_interleave(group_size)
+         expanded_stds = group_stds.repeat_interleave(group_size)
+
+         # Normalize rewards to get advantages
+         # [batch*group_size]
+         advantages = (rewards - expanded_means) / (expanded_stds + 1e-4)
+
+         # [batch*group_size, 1]
+         return advantages.unsqueeze(1)  # Add dimension for token-wise operations
+
+     def _generate_completions(self, model, prompts, group_size: int):
+         pad_token_id = TrainerTools().tokenizer.pad
+         device = TrainerTools().parallel.device
+
+         # Left-pad so every prompt in the batch has the same length.
+         # [batch, max_prompt_len]
+         prompt_ids = left_pad_sequence(prompts, padding_value=pad_token_id)
+         prompt_ids = prompt_ids.to(device)
+
+         prompt_len = prompt_ids.shape[1]
+
+         # [batch*group_size, max_prompt_len]
+         prompt_ids = prompt_ids.repeat_interleave(group_size, 0)
+         # [batch*group_size, max_prompt_len]
+         prompt_masks = prompt_ids != pad_token_id
+
+         max_new_tokens = self.grpo_config.gen_max_seq_len - prompt_len
+         if max_new_tokens <= 0:
+             raise ValueError(
+                 f"Prompt length ({prompt_len}) >= gen_max_seq_len ({self.grpo_config.gen_max_seq_len}). "
+                 f"Cannot generate any tokens. Please increase gen_max_seq_len or reduce dataset_block_size."
+             )
+
+         # [batch*group_size, max_prompt_len+max_gen_len]
+         outputs, _ = batch_generate(
+             model=model,
+             tokens=prompt_ids,
+             attention_mask=prompt_masks,
+             max_new_tokens=max_new_tokens,
+             temperature=self.grpo_config.gen_temperature,
+             k=self.grpo_config.gen_k,
+             p=self.grpo_config.gen_p,
+             device=device,
+             suppress_tokens=self.grpo_config.gen_suppress_tokens,
+             return_logits=False
+         )
+
+         # [batch*group_size, max_gen_len]
+         completion_ids = outputs[:, prompt_len:]
+         # [batch*group_size, max_gen_len]
+         completion_masks = (completion_ids != pad_token_id).int()
+
+         return prompt_ids, prompt_masks, completion_ids, completion_masks
+
+     def _generate_rollout_data(self, generate_model, batch_data: List[dict]):
+         prompts = [item["prompt"] for item in batch_data]
+         answers = [item["answer"] for item in batch_data]
+         group_size = self.grpo_config.group_size
+
+         # Use no_grad instead of inference_mode.
+         # Works around: "Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor."
+         with torch.no_grad():
+             # with torch.inference_mode():
+             prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_completions(generate_model, prompts, group_size)
+             input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+             attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+
+             old_log_probs, _ = self._compute_log_probs(generate_model, input_ids, attention_mask)
+
+             if self.ref_model:
+                 ref_log_probs, _ = self._compute_log_probs(self.ref_model, input_ids, attention_mask)
+             else:
+                 ref_log_probs = None
+
+             repeated_prompts = [p for p in prompts for _ in range(group_size)]
+             repeated_answers = [a for a in answers for _ in range(group_size)]
+
+             return {
+                 'input_ids': input_ids,
+                 'attention_mask': attention_mask,
+                 'completion_mask': completion_mask,
+                 'old_log_probs': old_log_probs,
+                 'ref_log_probs': ref_log_probs,
+                 'completion_ids': completion_ids,
+                 'repeated_prompts': repeated_prompts,
+                 'repeated_answers': repeated_answers,
+             }
+
+     def _maximize_grpo_objective(self, rollout_data):
+         device = TrainerTools().parallel.device
+
+         input_ids = rollout_data['input_ids']
+         attention_mask = rollout_data['attention_mask']
+         completion_mask = rollout_data['completion_mask']
+         old_log_probs = rollout_data['old_log_probs']
+         ref_log_probs = rollout_data['ref_log_probs']
+         completion_ids = rollout_data['completion_ids']
+         repeated_prompts = rollout_data['repeated_prompts']
+         repeated_answers = rollout_data['repeated_answers']
+
+         prompt_len = input_ids.shape[1] - completion_ids.shape[1]
+
+         # [batch*group_size]
+         rewards = torch.tensor(
+             self.reward_func(repeated_prompts, completion_ids, repeated_answers),
+             dtype=torch.float32,
+             device=device
+         )
+
+         # [batch*group_size, 1]
+         advantages = self._compute_group_relative_advantages(rewards)
+
+         # Compute current log probabilities
+         log_probs, aux_loss = self._compute_log_probs(self.train_model, input_ids, attention_mask)
+
+         # Align the completion mask with the shifted log-prob sequence (length total_seq_len - 1).
+         pad_len = prompt_len - 1
+         if pad_len > 0:
+             padded_completion_mask = F.pad(completion_mask, (pad_len, 0), 'constant', 0)
+         else:
+             padded_completion_mask = completion_mask
+
+         assert padded_completion_mask.shape == log_probs.shape, \
+             f"Shape mismatch! Padded completion mask: {padded_completion_mask.shape}, Log probs: {log_probs.shape}"
+
+         loss = self.criterion(
+             log_probs=log_probs,
+             old_log_probs=old_log_probs,
+             ref_log_probs=ref_log_probs,
+             completion_mask=padded_completion_mask,
+             advantages=advantages,
+             completion_len=completion_ids.shape[1]
+         )
+
+         return loss, aux_loss, rewards
+
+     def train(self):
+         aux_loss_coef = self.train_config.loss_config.aux_loss_coef
+
+         for epoch in range(self.resume_epoch, self.train_config.n_epochs):
+             if self.ref_model:
+                 sync_model_params(
+                     _from=self.train_model,
+                     _to=self.ref_model,
+                     mixup_alpha=self.grpo_config.mixup_alpha
+                 )
+
+             file_count = len(self.train_config.file_dataset)
+             start_file_idx = self.resume_file_idx if epoch == self.resume_epoch else 0
+
+             for file_idx in range(start_file_idx, file_count):
+                 dataset, file_path = self._create_dataset(file_idx)
+
+                 train_data_loader = TrainerTools().parallel.process_dataloader(
+                     dataset=dataset,
+                     data_loader_kwargs=self.data_loader_kwargs,
+                     sampler_kwargs=self.sampler_kwargs
+                 )
+
+                 last_ckpt_batch = 0
+                 batch_count_per_file = len(train_data_loader)
+
+                 TrainerTools().parallel.on_epoch_start(epoch)
+                 self._on_file_start(epoch, file_path)
+
+                 skip_batches = 0
+                 if epoch == self.resume_epoch and file_idx == self.resume_file_idx:
+                     skip_batches = self.resume_batch_idx
+                     if skip_batches > 0 and TrainerTools().parallel.is_main_process:
+                         Logger.std_log(f"Fast forwarding {skip_batches} batches in {file_path}...")
+
+                 data_iterator = iter(train_data_loader)
+                 if skip_batches > 0:
+                     data_iterator = islice(data_iterator, skip_batches, None)
+                     last_ckpt_batch = skip_batches
+
+                 for batch, batch_data in enumerate(data_iterator):
+                     batch = skip_batches + batch
+
+                     # start generate
+                     if TrainerTools().parallel.is_main_process:
+                         Logger.std_log(f'start generate for batch {batch + 1}/{batch_count_per_file}')
+
+                     # Generate rollout data.
+                     with unwrap_model_for_generation(self.train_model) as generate_model:
+                         rollout_data = self._generate_rollout_data(generate_model, batch_data)
+                     # end generate
+
+                     torch.cuda.empty_cache()
+
+                     try:
+                         if TrainerTools().parallel.is_main_process:
+                             Logger.std_log(f'start train for batch {batch + 1}/{batch_count_per_file}')
+
+                         for grpo_step in range(self.grpo_config.grpo_steps):
+                             with autocast(TrainerTools().parallel.device_type):
+                                 loss, aux_loss, rewards = self._maximize_grpo_objective(rollout_data)
+                                 if aux_loss_coef and aux_loss is not None:
+                                     aux_loss = aux_loss_coef * aux_loss
+                                 else:
+                                     aux_loss = torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
+
+                             total_loss = loss + aux_loss
+                             self._backward_loss(total_loss)
+                             self._apply_grad_clipping()
+                             self._apply_step()
+
+                             loss_accumulation = total_loss.detach().item()
+                             aux_loss_accumulation = aux_loss.detach().item()
+
+                             avg_loss, avg_aux_loss = self._avg_loss(
+                                 losses=[
+                                     loss_accumulation,
+                                     aux_loss_accumulation
+                                 ],
+                                 gradient_accumulation_steps=1,
+                                 batches_accumulated=1
+                             )
+
+                             self._log(
+                                 keys={
+                                     'epoch': epoch,
+                                     'file': f'{file_idx + 1}/{file_count}',
+                                     'batch': f'{batch + 1}/{batch_count_per_file}',
+                                     'grpo_step': grpo_step
+                                 },
+                                 values={
+                                     'loss': avg_loss,
+                                     'moe_aux_loss': avg_aux_loss,
+                                     'rewards': (rewards.sum() / rewards.size(0)).item(),
+                                 }
+                             )
+
+                         if (batch - last_ckpt_batch) >= self.train_config.eval_config.eval_batch_interval:
+                             save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+                             save_steps(
+                                 epoch=epoch,
+                                 file_idx=file_idx,
+                                 batch_idx=batch + 1,
+                                 lr_scheduler=self.lr_scheduler
+                             )
+
+                             last_ckpt_batch = batch
+                             self._on_batch_end(tag=f'epoch:{epoch}/batch:{batch}')
+                     except Exception as e:
+                         self._on_exception(e, epoch, batch)
+
+                 # After finishing one dataset file, release memory.
+                 del train_data_loader
+                 del dataset
+                 if hasattr(TrainerTools().parallel, '_sampler'):
+                     TrainerTools().parallel._sampler = None
+
+                 gc.collect()
+                 torch.cuda.empty_cache()
+
+             # end epoch
+
+             # reset resume state
+             self.resume_file_idx = 0
+             self.resume_batch_idx = 0
+
+             save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+             save_steps(
+                 epoch=epoch + 1,
+                 file_idx=0,
+                 batch_idx=0,
+                 lr_scheduler=self.lr_scheduler
+             )
+
+             TrainerTools().parallel.on_epoch_end(epoch)
+             self._on_epoch_end(tag=f'epoch:{epoch}')
+
+         TrainerTools().parallel.destroy()
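
For orientation only: the snippet below is a minimal, hypothetical reward function matching the Callable signature declared in GRPOTrainer.__init__ (a list of prompt-id tensors, a [batch*group_size, gen_len] completion-id tensor, a list of optional answer-id tensors, returning one float per completion). It is not part of the released package, and the exact-match scoring rule is purely illustrative.

    from typing import List, Optional

    import torch


    def exact_match_reward(
        prompt_ids: List[torch.Tensor],
        completion_ids: torch.Tensor,                # [batch * group_size, gen_len]
        answer_ids: List[Optional[torch.Tensor]],
    ) -> List[float]:
        # Score 1.0 when the reference answer appears verbatim inside the completion, else 0.0.
        scores: List[float] = []
        for row, answer in zip(completion_ids, answer_ids):
            if answer is None:
                scores.append(0.0)
                continue
            completion, target = row.tolist(), answer.tolist()
            hit = any(
                completion[i:i + len(target)] == target
                for i in range(len(completion) - len(target) + 1)
            )
            scores.append(1.0 if hit else 0.0)
        return scores


    # Hypothetical wiring (TrainConfig construction is not shown in this diff):
    # trainer = GRPOTrainer(train_config=config, reward_func=exact_match_reward, eval_prompts=["1+1="])
    # trainer.train()
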
llm_trainer/log.py ADDED
@@ -0,0 +1,65 @@
+ import time, os, atexit
+ from io import TextIOWrapper
+ from typing import Optional
+
+
+ def _get_log_dir() -> str:
+     log_dir = os.environ.get('LOG_DIR', './log')
+     os.makedirs(log_dir, exist_ok=True)
+     return log_dir
+
+
+ class Logger:
+     def __init__(self, log_file_name=None, log_dir=None):
+         self.log_file_name = log_file_name
+         self.log_file: Optional[TextIOWrapper] = None
+
+         if not log_dir:
+             self.log_dir = _get_log_dir()
+         else:
+             os.makedirs(log_dir, exist_ok=True)
+             self.log_dir = log_dir
+
+         self.flush_interval = int(os.environ.get('LOG_FLUSH_INTERVAL', '1'))
+         self.log_steps = 0
+
+     @staticmethod
+     def std_log(msg: str):
+         log_content = Logger._build_log(msg)
+         print(log_content)
+
+     def log(self, msg: str, log_to_console=True):
+         log_content = Logger._build_log(msg)
+
+         if log_to_console:
+             print(log_content)
+
+         if self._open_file():
+             self.log_file.write(f'{log_content}\n')
+             if self.log_steps % self.flush_interval == 0:
+                 self.log_file.flush()
+
+         self.log_steps += 1
+         return self
+
+     def release(self):
+         if self.log_file:
+             self.log_file.close()
+             self.log_file = None
+
+     @staticmethod
+     def _build_log(msg: str):
+         cur_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+         return f'[{cur_time}] {msg}'
+
+     def _open_file(self) -> bool:
+         if not self.log_file_name:
+             return False
+
+         if self.log_file:
+             return True
+
+         self.log_file = open(os.path.join(self.log_dir, self.log_file_name), 'a', encoding='utf-8')
+         atexit.register(self.release)
+
+         return True
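
As a usage sketch (the file name, directory, and interval below are illustrative, not defaults mandated by the package): Logger prints timestamped lines and, when a log_file_name is given, also appends them to a file under LOG_DIR, flushing roughly every LOG_FLUSH_INTERVAL calls.

    import os

    # Both variables are read by llm_trainer/log.py; set them before constructing the Logger.
    os.environ['LOG_DIR'] = './log'
    os.environ['LOG_FLUSH_INTERVAL'] = '10'

    from llm_trainer.log import Logger

    logger = Logger(log_file_name='train.log')
    logger.log('start train for batch 1/100')   # printed and appended to ./log/train.log
    Logger.std_log('console only')              # static helper, never writes to the file
    logger.release()                            # close the file (also registered via atexit)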