project-llm-trainer 0.4.15__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of project-llm-trainer might be problematic. Click here for more details.

Files changed (30)
  1. llm_trainer/checkpoint.py +0 -50
  2. llm_trainer/dpo_trainer.py +6 -3
  3. llm_trainer/eval.py +3 -30
  4. llm_trainer/generate_utils.py +9 -74
  5. llm_trainer/grpo_trainer.py +27 -28
  6. llm_trainer/loss.py +1 -1
  7. llm_trainer/partition_utils.py +146 -0
  8. llm_trainer/tokenizer.py +10 -10
  9. llm_trainer/tools.py +0 -2
  10. llm_trainer/train_configs.py +5 -25
  11. llm_trainer/trainer.py +28 -67
  12. llm_trainer/utils.py +0 -1
  13. {project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.1.dist-info}/METADATA +1 -1
  14. project_llm_trainer-0.5.1.dist-info/RECORD +33 -0
  15. llm_trainer/dcp.py +0 -93
  16. llm_trainer/ds_model_params.py +0 -72
  17. llm_trainer/fsdp_checkpoint.py +0 -52
  18. llm_trainer/fsdp_model_params.py +0 -39
  19. llm_trainer/model_params.py +0 -28
  20. llm_trainer/parallel_fsdp.py +0 -121
  21. project_llm_trainer-0.4.15.dist-info/RECORD +0 -38
  22. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/calc_intermediate_size +0 -0
  23. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/ddp_train +0 -0
  24. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/ds_train +0 -0
  25. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/plot_loss +0 -0
  26. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/plot_lr +0 -0
  27. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/py_train +0 -0
  28. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/smart_train +0 -0
  29. {project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.1.dist-info}/WHEEL +0 -0
  30. {project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.1.dist-info}/top_level.txt +0 -0
llm_trainer/checkpoint.py CHANGED
@@ -6,35 +6,11 @@ from torch.optim import Optimizer
6
6
  from torch.nn.parallel import DistributedDataParallel as DDP
7
7
 
8
8
  from .parallel_ds import DsParallel
9
- from .parallel_fsdp import FsdpParallel
10
- from .parallel_ddp import DdpParallel
11
9
  from .scheduler import LRScheduler
12
10
  from .tools import TrainerTools
13
11
 
14
- try:
15
- from .dcp import save_dcp, load_dcp, convert_dcp_to_pth
16
- except:
17
- os.environ['ENABLE_DCP'] = "0"
18
-
19
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
20
-
21
- # https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
22
-
23
12
  DEFAULT_CHECKPOINT_NAME = "checkpoint.pth"
24
13
 
25
-
26
- def _can_use_dcp(model: nn.Module) -> bool:
27
- if os.environ.get('ENABLE_DCP', '1') != '1':
28
- return False
29
-
30
- # 如果是fsdp或者ddp,才能使用dcp保存
31
- if (isinstance(TrainerTools().parallel, FsdpParallel)
32
- or isinstance(TrainerTools().parallel, DdpParallel)):
33
- return True
34
-
35
- return False
36
-
37
-
38
14
  def save_checkpoint(
39
15
  model: nn.Module,
40
16
  optimizer: Optional[Optimizer] = None,
@@ -43,11 +19,6 @@ def save_checkpoint(
43
19
  if isinstance(TrainerTools().parallel, DsParallel):
44
20
  from .ds_checkpoint import save_ds_checkpoint
45
21
  save_ds_checkpoint(model, suffix)
46
- elif _can_use_dcp(model):
47
- save_dcp(model, optimizer, suffix)
48
- elif isinstance(model, FSDP):
49
- from .fsdp_checkpoint import save_fsdp_checkpoint
50
- save_fsdp_checkpoint(model, optimizer, suffix)
51
22
  else:
52
23
  if TrainerTools().parallel.is_main_process:
53
24
  checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
@@ -73,11 +44,6 @@ def load_checkpoint(
73
44
  if isinstance(TrainerTools().parallel, DsParallel):
74
45
  from .ds_checkpoint import load_ds_checkpoint
75
46
  load_ds_checkpoint(model, load_module_only=load_module_only, suffix=suffix)
76
- elif _can_use_dcp(model):
77
- load_dcp(model, optimizer, suffix)
78
- elif isinstance(model, FSDP):
79
- from .fsdp_checkpoint import load_fsdp_checkpoint
80
- load_fsdp_checkpoint(model, optimizer, device, suffix)
81
47
  else:
82
48
  checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
83
49
  if suffix:
@@ -99,22 +65,6 @@ def load_checkpoint_for_eval(
99
65
  if isinstance(TrainerTools().parallel, DsParallel):
100
66
  from .ds_checkpoint import load_ds_checkpoint_for_eval
101
67
  load_ds_checkpoint_for_eval(model)
102
- elif _can_use_dcp(model):
103
- checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
104
-
105
- # load_dcp方式在cpu上会报错,所以改为先将ckpt转换为pth,然后再加载pth
106
- # load_dcp(model, optimizer)
107
- pth_name = os.environ.get('EVAL_CHECKPOINT_NAME', checkpoint_name)
108
- if suffix:
109
- pth_name = f'{pth_name}_{suffix}'
110
-
111
- convert_dcp_to_pth(pth_name)
112
-
113
- if os.path.exists(pth_name):
114
- ckpt = torch.load(pth_name, map_location=device, weights_only=True)
115
- model.load_state_dict(ckpt['app']['model_state_dict'])
116
- # 使用完删除
117
- os.remove(pth_name)
118
68
  else:
119
69
  load_checkpoint(model, None, device, suffix=suffix)
120
70
 
llm_trainer/dpo_trainer.py CHANGED
@@ -12,7 +12,7 @@ from .dataset import DPODataset
12
12
  from .loss import DPOLoss
13
13
  from .tools import TrainerTools
14
14
  from .utils import get_dpo_collate_fn
15
- from .model_params import copy_model_params
15
+ from .partition_utils import sync_model_params
16
16
 
17
17
  from .checkpoint import (
18
18
  save_checkpoint,
@@ -38,7 +38,6 @@ class DPOTrainer(Trainer):
38
38
 
39
39
  def _init_reference_model(self):
40
40
  reference_model = self._new_model(self.train_config)
41
- copy_model_params(_from=self.train_model, _to=reference_model)
42
41
 
43
42
  reference_model, _ = TrainerTools().parallel.process(
44
43
  model=reference_model,
@@ -51,6 +50,11 @@ class DPOTrainer(Trainer):
51
50
  for param in reference_model.parameters():
52
51
  param.requires_grad = False
53
52
 
53
+ sync_model_params(
54
+ _from=self.train_model,
55
+ _to=reference_model
56
+ )
57
+
54
58
  return reference_model
55
59
 
56
60
  def _init_loss(self):
@@ -210,7 +214,6 @@ class DPOTrainer(Trainer):
210
214
  if need_update_grad:
211
215
  loss_tensor = torch.tensor(loss_accumulation, device=TrainerTools().parallel.device)
212
216
 
213
- # todo check all_reduce??
214
217
  if TrainerTools().parallel.parallel_train:
215
218
  dist.all_reduce(loss_tensor, dist.ReduceOp.AVG)
216
219
 
llm_trainer/eval.py CHANGED
@@ -5,16 +5,14 @@ from .log import get_log_dir
5
5
  from .tools import TrainerTools
6
6
  from .train_configs import EvalConfig
7
7
 
8
-
9
- def _eval_task(
8
+ def submit_gen_task(
10
9
  eval_model: torch.nn.Module,
11
10
  eval_config: EvalConfig,
12
11
  tag,
13
12
  prompt,
14
13
  pixel_values,
15
14
  max_position_embeddings,
16
- tokens_per_image,
17
- device
15
+ tokens_per_image
18
16
  ):
19
17
  log_dir = get_log_dir()
20
18
 
@@ -28,33 +26,8 @@ def _eval_task(
28
26
  p=eval_config.top_p,
29
27
  pixel_values=pixel_values,
30
28
  tokens_per_image=tokens_per_image,
31
- device=device
29
+ device=TrainerTools().parallel.device
32
30
  )
33
31
 
34
32
  with open(f'{log_dir}gen.txt', 'a') as f:
35
33
  f.write(f"{tag}, gen->{gen_result}\n")
36
-
37
-
38
- def submit_gen_task(
39
- eval_model: torch.nn.Module,
40
- eval_config: EvalConfig,
41
- tag,
42
- prompt,
43
- pixel_values,
44
- max_position_embeddings,
45
- tokens_per_image
46
- ):
47
- eval_model.to(TrainerTools().parallel.device)
48
- _eval_task(
49
- eval_model=eval_model,
50
- eval_config=eval_config,
51
- tag=tag,
52
- prompt=prompt,
53
- pixel_values=pixel_values,
54
- max_position_embeddings=max_position_embeddings,
55
- tokens_per_image=tokens_per_image,
56
- device=TrainerTools().parallel.device
57
- )
58
- eval_model.to('cpu')
59
-
60
- # threading.Thread(target=_eval_task, args=args).start()
llm_trainer/generate_utils.py CHANGED
@@ -1,7 +1,6 @@
1
1
  from typing import Union, Optional, List
2
2
  from contextlib import nullcontext
3
3
  import torch
4
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
5
4
  from llm_model import VlmModel, KVCache
6
5
  from .tools import TrainerTools
7
6
  from .utils import batch_repeat_image_tok
@@ -108,8 +107,7 @@ def _generate(
108
107
  pixel_values: Optional[torch.Tensor] = None,
109
108
  tokens_per_image: int = -1,
110
109
  suppress_tokens: Optional[List[int]] = None,
111
- device: Union[str, torch.device, int],
112
- reasoning_budget: Optional[int] = None
110
+ device: Union[str, torch.device, int]
113
111
  ):
114
112
  """
115
113
  :param model:
@@ -131,8 +129,7 @@ def _generate(
131
129
  device_type=device,
132
130
  dtype=TrainerTools().dtype,
133
131
  enabled=True,
134
- # fsdp模式,需要将cache_enabled设置为false
135
- cache_enabled=False if isinstance(model, FSDP) else None
132
+ cache_enabled=None
136
133
  ) if TrainerTools().use_amp else nullcontext()
137
134
 
138
135
  if isinstance(model, VlmModel):
@@ -144,28 +141,6 @@ def _generate(
144
141
  kv_cache: Optional[KVCache] = None
145
142
  generate_tokens = tokens.clone()
146
143
 
147
- reasoning_start = TrainerTools().tokenizer.reasoning_start
148
- reasoning_end = TrainerTools().tokenizer.reasoning_end
149
-
150
- # --- 状态初始化 ---
151
- in_reasoning_block = False
152
- reasoning_step_count = 0
153
- # “冷静期”标志位。当强制结束思考后,在下一步抑制<reasoning>的生成。
154
- suppress_reasoning_start_next = False
155
-
156
- if reasoning_budget is not None:
157
- prompt_tokens = tokens[0]
158
- start_indices = (prompt_tokens == reasoning_start).nonzero(as_tuple=True)[0]
159
- end_indices = (prompt_tokens == reasoning_end).nonzero(as_tuple=True)[0]
160
-
161
- last_start_idx = start_indices[-1].item() if len(start_indices) > 0 else -1
162
- last_end_idx = end_indices[-1].item() if len(end_indices) > 0 else -1
163
-
164
- if last_start_idx > last_end_idx:
165
- in_reasoning_block = True
166
- reasoning_step_count = len(prompt_tokens) - 1 - last_start_idx
167
-
168
- model.eval()
169
144
  with torch.inference_mode():
170
145
  for _ in range(max_new_tokens):
171
146
  # 是否需要截取??
@@ -185,23 +160,6 @@ def _generate(
185
160
  # (batch, vocab_size)
186
161
  logits = logits[:, -1, :]
187
162
 
188
- # --- 推理预算逻辑 ---
189
- force_end_reasoning_token = False
190
- if reasoning_budget is not None:
191
- # 检查是否需要在此步抑制 <reasoning>
192
- should_suppress_this_step = suppress_reasoning_start_next
193
- suppress_reasoning_start_next = False # 立即重置标志位
194
-
195
- # 修改: 检查是否超出预算
196
- if in_reasoning_block and reasoning_step_count >= reasoning_budget:
197
- force_end_reasoning_token = True
198
- # 设置标志位,在下一步抑制 <reasoning>
199
- suppress_reasoning_start_next = True
200
-
201
- # 如果上一轮设置了抑制标志,则在此轮执行抑制
202
- if should_suppress_this_step:
203
- logits[:, reasoning_start] = -float("inf")
204
-
205
163
  # 抑制特殊token输出
206
164
  if suppress_tokens and len(suppress_tokens) != 0:
207
165
  logits = _suppress_warper(logits, suppress_tokens)
@@ -217,10 +175,6 @@ def _generate(
217
175
  if p and 0 < p <= 1:
218
176
  logits = _top_p_warper(logits, p)
219
177
 
220
- if force_end_reasoning_token:
221
- logits[:] = -float("inf")
222
- logits[:, reasoning_end] = 0.0
223
-
224
178
  if multinomial:
225
179
  prob = logits.softmax(dim=-1)
226
180
  # 返回下标
@@ -229,18 +183,6 @@ def _generate(
229
183
  # 返回下标
230
184
  next_token = logits.argmax(dim=-1, keepdim=True)
231
185
 
232
- if reasoning_budget is not None:
233
- current_token_id = next_token.item()
234
- if not in_reasoning_block and current_token_id == reasoning_start:
235
- in_reasoning_block = True
236
- reasoning_step_count = 0
237
- elif in_reasoning_block:
238
- if current_token_id == reasoning_end:
239
- in_reasoning_block = False
240
- reasoning_step_count = 0
241
- else:
242
- reasoning_step_count += 1
243
-
244
186
  # token, is_full_result
245
187
  yield next_token, False
246
188
 
@@ -269,8 +211,7 @@ def _streaming_generate(
269
211
  pixel_values: Optional[torch.Tensor] = None,
270
212
  tokens_per_image: int = -1,
271
213
  suppress_tokens: Optional[List[int]] = None,
272
- device: Union[str, torch.device, int] = None,
273
- reasoning_budget: Optional[int] = None
214
+ device: Union[str, torch.device, int] = None
274
215
  ):
275
216
  device = TrainerTools().parallel.device if not device else device
276
217
  encoded_tokens = TrainerTools().tokenizer.encode(prompt, unsqueeze=True, covert_tensor=True).to(device)
@@ -286,8 +227,7 @@ def _streaming_generate(
286
227
  pixel_values=pixel_values,
287
228
  tokens_per_image=tokens_per_image,
288
229
  suppress_tokens=suppress_tokens,
289
- device=device,
290
- reasoning_budget=reasoning_budget
230
+ device=device
291
231
  )
292
232
 
293
233
  for (token, is_full_result) in generate_text_iterator:
@@ -306,8 +246,7 @@ def streaming_generate(
306
246
  pixel_values: Optional[torch.Tensor] = None,
307
247
  tokens_per_image: int = -1,
308
248
  suppress_tokens: Optional[List[int]] = None,
309
- device: Union[str, torch.device, int] = None,
310
- reasoning_budget: Optional[int] = None
249
+ device: Union[str, torch.device, int] = None
311
250
  ):
312
251
  text_iterator = _streaming_generate(
313
252
  model=model,
@@ -320,8 +259,7 @@ def streaming_generate(
320
259
  pixel_values=pixel_values,
321
260
  tokens_per_image=tokens_per_image,
322
261
  suppress_tokens=suppress_tokens,
323
- device=device,
324
- reasoning_budget=reasoning_budget
262
+ device=device
325
263
  )
326
264
 
327
265
  for (token, is_full_result) in text_iterator:
@@ -341,8 +279,7 @@ def generate(
341
279
  pixel_values: Optional[torch.Tensor] = None,
342
280
  tokens_per_image: int = -1,
343
281
  suppress_tokens: Optional[List[int]] = None,
344
- device: Union[str, torch.device, int] = None,
345
- reasoning_budget: Optional[int] = None
282
+ device: Union[str, torch.device, int] = None
346
283
  ):
347
284
  text_iterator = _streaming_generate(
348
285
  model=model,
@@ -355,8 +292,7 @@ def generate(
355
292
  suppress_tokens=suppress_tokens,
356
293
  pixel_values=pixel_values,
357
294
  tokens_per_image=tokens_per_image,
358
- device=device,
359
- reasoning_budget=reasoning_budget
295
+ device=device
360
296
  )
361
297
 
362
298
  for (token, is_full_result) in text_iterator:
@@ -386,7 +322,7 @@ def batch_generate(
386
322
  device_type=device,
387
323
  dtype=TrainerTools().dtype,
388
324
  enabled=True,
389
- cache_enabled=False if isinstance(model, FSDP) else None
325
+ cache_enabled=None
390
326
  ) if TrainerTools().use_amp else nullcontext()
391
327
 
392
328
  if isinstance(model, VlmModel):
@@ -403,7 +339,6 @@ def batch_generate(
403
339
  end_token = TrainerTools().tokenizer.end
404
340
  done = torch.zeros(batch_size, dtype=torch.bool, device=device)
405
341
 
406
- model.eval()
407
342
  with torch.inference_mode():
408
343
  for _ in range(max_new_tokens):
409
344
  # 只处理未完成的样本
llm_trainer/grpo_trainer.py CHANGED
@@ -1,5 +1,4 @@
1
1
  import time
2
- import copy
3
2
  from typing import Tuple, List, Union, Callable, Optional
4
3
  import torch
5
4
  from torch.utils.data import Dataset
@@ -15,7 +14,11 @@ from .loss import GRPOLoss
15
14
  from .tools import TrainerTools
16
15
  from .generate_utils import batch_generate
17
16
  from .log import log
18
- from .model_params import copy_model_params
17
+
18
+ from .partition_utils import (
19
+ sync_model_params,
20
+ unwrap_model_for_generation
21
+ )
19
22
 
20
23
  from .checkpoint import (
21
24
  save_checkpoint,
@@ -39,7 +42,6 @@ class GRPOTrainer(Trainer):
39
42
 
40
43
  self.reward_func = reward_func
41
44
  self.reference_model = self._init_reference_model()
42
- self.generate_model = self._init_generate_model()
43
45
 
44
46
  # 默认使用torch提供的pad_sequence
45
47
  # 如果pad_sequence不支持padding_side参数,则将改参数置为False,使用反转的方式
@@ -47,17 +49,20 @@ class GRPOTrainer(Trainer):
47
49
 
48
50
  def _init_reference_model(self):
49
51
  reference_model = self._new_model(self.train_config)
50
- reference_model.to('cpu')
51
- reference_model.eval()
52
52
 
53
+ reference_model, _ = TrainerTools().parallel.process(
54
+ model=reference_model,
55
+ optimizer=None,
56
+ kwargs=self._init_reference_args(),
57
+ save_instance=False
58
+ )
59
+
60
+ reference_model.eval()
53
61
  for param in reference_model.parameters():
54
62
  param.requires_grad = False
55
63
 
56
64
  return reference_model
57
65
 
58
- def _init_generate_model(self):
59
- return copy.deepcopy(self.reference_model)
60
-
61
66
  def _init_loss(self):
62
67
  criterion = GRPOLoss(
63
68
  clip_eps=self.train_config.grpo_config.clip_eps,
@@ -163,7 +168,7 @@ class GRPOTrainer(Trainer):
163
168
  # [batch*group_size, 1]
164
169
  return advantages.unsqueeze(1) # Add dimension for token-wise operations
165
170
 
166
- def _generate_completions(self, prompts, group_size: int):
171
+ def _generate_completions(self, model, prompts, group_size: int):
167
172
  pad_token_id = TrainerTools().tokenizer.pad
168
173
  device = TrainerTools().parallel.device
169
174
 
@@ -181,7 +186,7 @@ class GRPOTrainer(Trainer):
181
186
 
182
187
  # [batch*group_size, max_prompt_len+max_gen_len]
183
188
  outputs: torch.Tensor = batch_generate(
184
- model=self.generate_model,
189
+ model=model,
185
190
  tokens=prompt_ids,
186
191
  pad_token_id=pad_token_id,
187
192
  attention_mask=prompt_masks,
@@ -201,7 +206,7 @@ class GRPOTrainer(Trainer):
201
206
 
202
207
  return prompt_ids, prompt_masks, completion_ids, completion_masks
203
208
 
204
- def _generate_rollout_data(self, batch_data: List[dict]):
209
+ def _generate_rollout_data(self, generate_model, batch_data: List[dict]):
205
210
  prompts = [item["prompt"] for item in batch_data]
206
211
  answers = [item["answer"] for item in batch_data]
207
212
  group_size = self.train_config.grpo_config.group_size
@@ -210,13 +215,13 @@ class GRPOTrainer(Trainer):
210
215
  # 修复问题:Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal
211
216
  with torch.no_grad():
212
217
  # with torch.inference_mode():
213
- prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_completions(prompts, group_size)
218
+ prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_completions(generate_model, prompts, group_size)
214
219
  input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
215
220
  attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
216
221
  logits_to_keep = completion_ids.shape[1]
217
222
 
218
223
  # Compute old_log_probs from the current model, with gradients disabled.
219
- old_log_probs, _ = self._compute_log_probabilities(self.generate_model, input_ids, attention_mask, logits_to_keep)
224
+ old_log_probs, _ = self._compute_log_probabilities(generate_model, input_ids, attention_mask, logits_to_keep)
220
225
 
221
226
  # Compute ref_log_probs from the reference model, which remains static.
222
227
  ref_log_probs, _ = self._compute_log_probabilities(self.reference_model, input_ids, attention_mask, logits_to_keep)
@@ -275,12 +280,15 @@ class GRPOTrainer(Trainer):
275
280
  def train(self):
276
281
  global_steps = 0
277
282
  skipping_train = False
278
- device = TrainerTools().parallel.device
279
283
  aux_loss_coef = self.train_config.loss_config.aux_loss_coef
280
284
 
281
285
  for epoch in range(self.train_config.n_epochs):
282
- copy_model_params(_from=self.train_model, _to=self.reference_model)
283
- self.train_model.train()
286
+ sync_model_params(
287
+ _from=self.train_model,
288
+ _to=self.reference_model,
289
+ mixup_alpha=self.train_config.grpo_config.mixup_alpha
290
+ )
291
+
284
292
  file_count = len(self.train_config.file_dataset)
285
293
 
286
294
  for file_idx in range(file_count):
@@ -307,22 +315,13 @@ class GRPOTrainer(Trainer):
307
315
  skipping_train = False
308
316
 
309
317
  # start generate
310
- # 使用单独的模型生成数据, 原因是在deepspeed并行训练时,使用train_model生成数据会卡死
311
- self.generate_model.to(device)
312
- self.reference_model.to(device)
313
-
314
318
  if TrainerTools().parallel.is_main_process:
315
319
  log(f'start generate for batch {batch}/{batch_count_per_file}')
316
320
 
317
321
  # 生成数据
318
- with torch.no_grad():
319
- # 保存了train_model checkpoint后,这里保证生成模型使用的参数是最新
320
- copy_model_params(_from=self.train_model, _to=self.generate_model)
321
- rollout_data = self._generate_rollout_data(batch_data)
322
-
323
- # 卸载到cpu上,等待下次使用时再to gpu
324
- self.generate_model.to('cpu')
325
- self.reference_model.to('cpu')
322
+ with unwrap_model_for_generation(self.train_model) as generate_model:
323
+ rollout_data = self._generate_rollout_data(generate_model, batch_data)
324
+
326
325
  torch.cuda.empty_cache()
327
326
  # end generate
328
327
 
llm_trainer/loss.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import List, Optional, Tuple
1
+ from typing import List, Optional
2
2
  import torch
3
3
  from torch import nn
4
4
  import torch.nn.functional as F
llm_trainer/partition_utils.py ADDED
@@ -0,0 +1,146 @@
1
+ from typing import Optional
2
+ from contextlib import contextmanager
3
+ import itertools
4
+ from packaging import version
5
+ from torch import nn
6
+ from torch.nn.parallel import DistributedDataParallel as DDP
7
+
8
+ from .tools import TrainerTools
9
+ from .parallel_ds import DsParallel
10
+ from .parallel_ddp import DdpParallel
11
+
12
+
13
+ @contextmanager
14
+ def unwrap_model_for_generation(model: nn.Module):
15
+ """
16
+ Context manager to unwrap distributed or accelerated models for generation tasks.
17
+
18
+ Args:
19
+ model:
20
+ Model to be unwrapped.
21
+ Yields:
22
+ Unwrapped model.
23
+
24
+ Example:
25
+ ```python
26
+ with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
27
+ generated_outputs = unwrapped_model.generate(input_ids)
28
+ ```
29
+ """
30
+ if isinstance(TrainerTools().parallel, DsParallel):
31
+ import deepspeed
32
+ assert isinstance(model, deepspeed.DeepSpeedEngine)
33
+
34
+ if model.zero_optimization_stage() == 3:
35
+ with deepspeed.zero.GatheredParameters(model.parameters()):
36
+ _remove_hooks(model)
37
+ yield unwrap_model(model)
38
+ _add_hooks(model)
39
+ else:
40
+ yield unwrap_model(model)
41
+ elif isinstance(TrainerTools().parallel, DdpParallel):
42
+ yield unwrap_model(model)
43
+ else:
44
+ yield model
45
+
46
+
47
+ def sync_model_params(_from: nn.Module, _to: Optional[nn.Module], mixup_alpha: float = 1.0):
48
+ if isinstance(TrainerTools().parallel, DsParallel):
49
+ _sync_ds_model_params(_from, _to, mixup_alpha)
50
+ elif isinstance(TrainerTools().parallel, DdpParallel):
51
+ _sync_ddp_model_params(_from, _to, mixup_alpha)
52
+ else:
53
+ _copy_params(_from, _to, mixup_alpha)
54
+
55
+
56
+ def unwrap_model(model) -> nn.Module:
57
+ try:
58
+ import deepspeed
59
+ if isinstance(model, deepspeed.DeepSpeedEngine):
60
+ return model.module
61
+ except: ...
62
+
63
+ if isinstance(model, DDP):
64
+ return model.module
65
+
66
+ return model
67
+
68
+
69
+ def _copy_params(model, target_model, mixup_alpha):
70
+ for target_param, copy_param in zip(target_model.parameters(), model.parameters()):
71
+ target_param.data.mul_(1.0 - mixup_alpha).add_(copy_param.data, alpha=mixup_alpha)
72
+
73
+
74
+ def _sync_ds_model_params(_from: nn.Module, _to: Optional[nn.Module], mixup_alpha: float = 1.0):
75
+ import deepspeed
76
+ assert isinstance(_from, deepspeed.DeepSpeedEngine)
77
+
78
+ origin_from = unwrap_model(_from)
79
+
80
+ if _from.zero_optimization_stage() == 3:
81
+ with deepspeed.zero.GatheredParameters(list(origin_from.parameters()) + list(_to.parameters()), modifier_rank=0):
82
+ if TrainerTools().parallel.is_main_process:
83
+ _copy_params(origin_from, _to, mixup_alpha)
84
+ else:
85
+ _copy_params(origin_from, _to, mixup_alpha)
86
+
87
+
88
+ def _sync_ddp_model_params(_from: nn.Module, _to: Optional[nn.Module], mixup_alpha: float = 1.0):
89
+ assert isinstance(_from, DDP)
90
+
91
+ origin_from = unwrap_model(_from)
92
+ _copy_params(origin_from, _to, mixup_alpha)
93
+
94
+
95
+ def _add_hooks(model: nn.Module) -> None:
96
+ """Adds the optimizer hooks from a DeepSpeed ZeRO-3 model."""
97
+ import deepspeed
98
+ assert isinstance(model, deepspeed.DeepSpeedEngine)
99
+
100
+ if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer
101
+ return
102
+ if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
103
+ optimizer_offload = model.optimizer.parameter_offload
104
+ elif model.optimizer is not None:
105
+ optimizer_offload = model.optimizer
106
+ else:
107
+ raise RuntimeError("The model optimizer is None, which is not yet supported.")
108
+ if version.parse(deepspeed.__version__) >= version.parse("0.16.4"):
109
+ # Account for renaming in https://github.com/deepspeedai/DeepSpeed/pull/6847
110
+ optimizer_offload._register_deepspeed_module(optimizer_offload.module)
111
+ else:
112
+ optimizer_offload._register_hooks_recursively(optimizer_offload.module)
113
+
114
+
115
+ def _remove_hooks(model: nn.Module) -> None:
116
+ """Removes the optimizer hooks from a DeepSpeed ZeRO-3 model."""
117
+ import deepspeed
118
+ assert isinstance(model, deepspeed.DeepSpeedEngine)
119
+
120
+ if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer
121
+ return
122
+ if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
123
+ optimizer_offload = model.optimizer.parameter_offload
124
+ elif model.optimizer is not None:
125
+ optimizer_offload = model.optimizer
126
+ else:
127
+ raise RuntimeError("The model optimizer is None, which is not yet supported.")
128
+
129
+ for param in _iter_params(optimizer_offload.module, recurse=True):
130
+ param.ds_active_sub_modules.clear()
131
+
132
+ for hook in optimizer_offload.forward_hooks:
133
+ hook.remove()
134
+ for hook in optimizer_offload.backward_hooks:
135
+ hook.remove()
136
+
137
+ optimizer_offload.forward_hooks = []
138
+ optimizer_offload.backward_hooks = []
139
+
140
+
141
+ def _iter_params(module, recurse=False):
142
+ return [param for _, param in _get_all_parameters(module, recurse)]
143
+
144
+
145
+ def _get_all_parameters(sub_module, recurse=False):
146
+ return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters())