project-llm-trainer 0.4.15__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of project-llm-trainer might be problematic.

Files changed (29)
  1. llm_trainer/checkpoint.py +0 -50
  2. llm_trainer/dpo_trainer.py +6 -3
  3. llm_trainer/eval.py +3 -30
  4. llm_trainer/generate_utils.py +2 -6
  5. llm_trainer/grpo_trainer.py +27 -28
  6. llm_trainer/loss.py +1 -1
  7. llm_trainer/partition_utils.py +146 -0
  8. llm_trainer/tools.py +0 -2
  9. llm_trainer/train_configs.py +5 -25
  10. llm_trainer/trainer.py +28 -67
  11. llm_trainer/utils.py +0 -1
  12. {project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.0.dist-info}/METADATA +1 -1
  13. project_llm_trainer-0.5.0.dist-info/RECORD +33 -0
  14. llm_trainer/dcp.py +0 -93
  15. llm_trainer/ds_model_params.py +0 -72
  16. llm_trainer/fsdp_checkpoint.py +0 -52
  17. llm_trainer/fsdp_model_params.py +0 -39
  18. llm_trainer/model_params.py +0 -28
  19. llm_trainer/parallel_fsdp.py +0 -121
  20. project_llm_trainer-0.4.15.dist-info/RECORD +0 -38
  21. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.0.data}/scripts/calc_intermediate_size +0 -0
  22. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.0.data}/scripts/ddp_train +0 -0
  23. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.0.data}/scripts/ds_train +0 -0
  24. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.0.data}/scripts/plot_loss +0 -0
  25. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.0.data}/scripts/plot_lr +0 -0
  26. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.0.data}/scripts/py_train +0 -0
  27. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.0.data}/scripts/smart_train +0 -0
  28. {project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.0.dist-info}/WHEEL +0 -0
  29. {project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.0.dist-info}/top_level.txt +0 -0
llm_trainer/checkpoint.py CHANGED
@@ -6,35 +6,11 @@ from torch.optim import Optimizer
  from torch.nn.parallel import DistributedDataParallel as DDP

  from .parallel_ds import DsParallel
- from .parallel_fsdp import FsdpParallel
- from .parallel_ddp import DdpParallel
  from .scheduler import LRScheduler
  from .tools import TrainerTools

- try:
-     from .dcp import save_dcp, load_dcp, convert_dcp_to_pth
- except:
-     os.environ['ENABLE_DCP'] = "0"
-
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-
- # https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
-
  DEFAULT_CHECKPOINT_NAME = "checkpoint.pth"

-
- def _can_use_dcp(model: nn.Module) -> bool:
-     if os.environ.get('ENABLE_DCP', '1') != '1':
-         return False
-
-     # DCP saving can only be used with FSDP or DDP
-     if (isinstance(TrainerTools().parallel, FsdpParallel)
-             or isinstance(TrainerTools().parallel, DdpParallel)):
-         return True
-
-     return False
-
-
  def save_checkpoint(
          model: nn.Module,
          optimizer: Optional[Optimizer] = None,
@@ -43,11 +19,6 @@ def save_checkpoint(
      if isinstance(TrainerTools().parallel, DsParallel):
          from .ds_checkpoint import save_ds_checkpoint
          save_ds_checkpoint(model, suffix)
-     elif _can_use_dcp(model):
-         save_dcp(model, optimizer, suffix)
-     elif isinstance(model, FSDP):
-         from .fsdp_checkpoint import save_fsdp_checkpoint
-         save_fsdp_checkpoint(model, optimizer, suffix)
      else:
          if TrainerTools().parallel.is_main_process:
              checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
@@ -73,11 +44,6 @@ def load_checkpoint(
      if isinstance(TrainerTools().parallel, DsParallel):
          from .ds_checkpoint import load_ds_checkpoint
          load_ds_checkpoint(model, load_module_only=load_module_only, suffix=suffix)
-     elif _can_use_dcp(model):
-         load_dcp(model, optimizer, suffix)
-     elif isinstance(model, FSDP):
-         from .fsdp_checkpoint import load_fsdp_checkpoint
-         load_fsdp_checkpoint(model, optimizer, device, suffix)
      else:
          checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
          if suffix:
@@ -99,22 +65,6 @@ def load_checkpoint_for_eval(
      if isinstance(TrainerTools().parallel, DsParallel):
          from .ds_checkpoint import load_ds_checkpoint_for_eval
          load_ds_checkpoint_for_eval(model)
-     elif _can_use_dcp(model):
-         checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
-
-         # load_dcp fails on CPU, so convert the checkpoint to a .pth file first and load that instead
-         # load_dcp(model, optimizer)
-         pth_name = os.environ.get('EVAL_CHECKPOINT_NAME', checkpoint_name)
-         if suffix:
-             pth_name = f'{pth_name}_{suffix}'
-
-         convert_dcp_to_pth(pth_name)
-
-         if os.path.exists(pth_name):
-             ckpt = torch.load(pth_name, map_location=device, weights_only=True)
-             model.load_state_dict(ckpt['app']['model_state_dict'])
-             # delete after use
-             os.remove(pth_name)
      else:
          load_checkpoint(model, None, device, suffix=suffix)

llm_trainer/dpo_trainer.py CHANGED
@@ -12,7 +12,7 @@ from .dataset import DPODataset
  from .loss import DPOLoss
  from .tools import TrainerTools
  from .utils import get_dpo_collate_fn
- from .model_params import copy_model_params
+ from .partition_utils import sync_model_params

  from .checkpoint import (
      save_checkpoint,
@@ -38,7 +38,6 @@ class DPOTrainer(Trainer):

      def _init_reference_model(self):
          reference_model = self._new_model(self.train_config)
-         copy_model_params(_from=self.train_model, _to=reference_model)

          reference_model, _ = TrainerTools().parallel.process(
              model=reference_model,
@@ -51,6 +50,11 @@ class DPOTrainer(Trainer):
          for param in reference_model.parameters():
              param.requires_grad = False

+         sync_model_params(
+             _from=self.train_model,
+             _to=reference_model
+         )
+
          return reference_model

      def _init_loss(self):
@@ -210,7 +214,6 @@ class DPOTrainer(Trainer):
          if need_update_grad:
              loss_tensor = torch.tensor(loss_accumulation, device=TrainerTools().parallel.device)

-             # todo check all_reduce??
              if TrainerTools().parallel.parallel_train:
                  dist.all_reduce(loss_tensor, dist.ReduceOp.AVG)

llm_trainer/eval.py CHANGED
@@ -5,16 +5,14 @@ from .log import get_log_dir
  from .tools import TrainerTools
  from .train_configs import EvalConfig

-
- def _eval_task(
+ def submit_gen_task(
      eval_model: torch.nn.Module,
      eval_config: EvalConfig,
      tag,
      prompt,
      pixel_values,
      max_position_embeddings,
-     tokens_per_image,
-     device
+     tokens_per_image
  ):
      log_dir = get_log_dir()

@@ -28,33 +26,8 @@ def _eval_task(
          p=eval_config.top_p,
          pixel_values=pixel_values,
          tokens_per_image=tokens_per_image,
-         device=device
+         device=TrainerTools().parallel.device
      )

      with open(f'{log_dir}gen.txt', 'a') as f:
          f.write(f"{tag}, gen->{gen_result}\n")
-
-
- def submit_gen_task(
-     eval_model: torch.nn.Module,
-     eval_config: EvalConfig,
-     tag,
-     prompt,
-     pixel_values,
-     max_position_embeddings,
-     tokens_per_image
- ):
-     eval_model.to(TrainerTools().parallel.device)
-     _eval_task(
-         eval_model=eval_model,
-         eval_config=eval_config,
-         tag=tag,
-         prompt=prompt,
-         pixel_values=pixel_values,
-         max_position_embeddings=max_position_embeddings,
-         tokens_per_image=tokens_per_image,
-         device=TrainerTools().parallel.device
-     )
-     eval_model.to('cpu')
-
-     # threading.Thread(target=_eval_task, args=args).start()
llm_trainer/generate_utils.py CHANGED
@@ -1,7 +1,6 @@
  from typing import Union, Optional, List
  from contextlib import nullcontext
  import torch
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
  from llm_model import VlmModel, KVCache
  from .tools import TrainerTools
  from .utils import batch_repeat_image_tok
@@ -131,8 +130,7 @@ def _generate(
          device_type=device,
          dtype=TrainerTools().dtype,
          enabled=True,
-         # in fsdp mode, cache_enabled needs to be set to False
-         cache_enabled=False if isinstance(model, FSDP) else None
+         cache_enabled=None
      ) if TrainerTools().use_amp else nullcontext()

      if isinstance(model, VlmModel):
@@ -165,7 +163,6 @@ def _generate(
          in_reasoning_block = True
          reasoning_step_count = len(prompt_tokens) - 1 - last_start_idx

-     model.eval()
      with torch.inference_mode():
          for _ in range(max_new_tokens):
              # do we need to truncate here??
@@ -386,7 +383,7 @@ def batch_generate(
          device_type=device,
          dtype=TrainerTools().dtype,
          enabled=True,
-         cache_enabled=False if isinstance(model, FSDP) else None
+         cache_enabled=None
      ) if TrainerTools().use_amp else nullcontext()

      if isinstance(model, VlmModel):
@@ -403,7 +400,6 @@ def batch_generate(
      end_token = TrainerTools().tokenizer.end
      done = torch.zeros(batch_size, dtype=torch.bool, device=device)

-     model.eval()
      with torch.inference_mode():
          for _ in range(max_new_tokens):
              # only process samples that are not finished yet
llm_trainer/grpo_trainer.py CHANGED
@@ -1,5 +1,4 @@
  import time
- import copy
  from typing import Tuple, List, Union, Callable, Optional
  import torch
  from torch.utils.data import Dataset
@@ -15,7 +14,11 @@ from .loss import GRPOLoss
  from .tools import TrainerTools
  from .generate_utils import batch_generate
  from .log import log
- from .model_params import copy_model_params
+
+ from .partition_utils import (
+     sync_model_params,
+     unwrap_model_for_generation
+ )

  from .checkpoint import (
      save_checkpoint,
@@ -39,7 +42,6 @@ class GRPOTrainer(Trainer):

          self.reward_func = reward_func
          self.reference_model = self._init_reference_model()
-         self.generate_model = self._init_generate_model()

          # use torch's pad_sequence by default;
          # if pad_sequence does not support the padding_side argument, set this flag to False and pad by reversing instead
@@ -47,17 +49,20 @@ class GRPOTrainer(Trainer):

      def _init_reference_model(self):
          reference_model = self._new_model(self.train_config)
-         reference_model.to('cpu')
-         reference_model.eval()

+         reference_model, _ = TrainerTools().parallel.process(
+             model=reference_model,
+             optimizer=None,
+             kwargs=self._init_reference_args(),
+             save_instance=False
+         )
+
+         reference_model.eval()
          for param in reference_model.parameters():
              param.requires_grad = False

          return reference_model

-     def _init_generate_model(self):
-         return copy.deepcopy(self.reference_model)
-
      def _init_loss(self):
          criterion = GRPOLoss(
              clip_eps=self.train_config.grpo_config.clip_eps,
@@ -163,7 +168,7 @@ class GRPOTrainer(Trainer):
          # [batch*group_size, 1]
          return advantages.unsqueeze(1)  # Add dimension for token-wise operations

-     def _generate_completions(self, prompts, group_size: int):
+     def _generate_completions(self, model, prompts, group_size: int):
          pad_token_id = TrainerTools().tokenizer.pad
          device = TrainerTools().parallel.device

@@ -181,7 +186,7 @@ class GRPOTrainer(Trainer):

          # [batch*group_size, max_prompt_len+max_gen_len]
          outputs: torch.Tensor = batch_generate(
-             model=self.generate_model,
+             model=model,
              tokens=prompt_ids,
              pad_token_id=pad_token_id,
              attention_mask=prompt_masks,
@@ -201,7 +206,7 @@ class GRPOTrainer(Trainer):

          return prompt_ids, prompt_masks, completion_ids, completion_masks

-     def _generate_rollout_data(self, batch_data: List[dict]):
+     def _generate_rollout_data(self, generate_model, batch_data: List[dict]):
          prompts = [item["prompt"] for item in batch_data]
          answers = [item["answer"] for item in batch_data]
          group_size = self.train_config.grpo_config.group_size
@@ -210,13 +215,13 @@ class GRPOTrainer(Trainer):
          # Fix for: "Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal ..."
          with torch.no_grad():
          # with torch.inference_mode():
-             prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_completions(prompts, group_size)
+             prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_completions(generate_model, prompts, group_size)
              input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
              attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
              logits_to_keep = completion_ids.shape[1]

              # Compute old_log_probs from the current model, with gradients disabled.
-             old_log_probs, _ = self._compute_log_probabilities(self.generate_model, input_ids, attention_mask, logits_to_keep)
+             old_log_probs, _ = self._compute_log_probabilities(generate_model, input_ids, attention_mask, logits_to_keep)

              # Compute ref_log_probs from the reference model, which remains static.
              ref_log_probs, _ = self._compute_log_probabilities(self.reference_model, input_ids, attention_mask, logits_to_keep)
@@ -275,12 +280,15 @@ class GRPOTrainer(Trainer):
      def train(self):
          global_steps = 0
          skipping_train = False
-         device = TrainerTools().parallel.device
          aux_loss_coef = self.train_config.loss_config.aux_loss_coef

          for epoch in range(self.train_config.n_epochs):
-             copy_model_params(_from=self.train_model, _to=self.reference_model)
-             self.train_model.train()
+             sync_model_params(
+                 _from=self.train_model,
+                 _to=self.reference_model,
+                 mixup_alpha=self.train_config.grpo_config.mixup_alpha
+             )
+
              file_count = len(self.train_config.file_dataset)

              for file_idx in range(file_count):
@@ -307,22 +315,13 @@ class GRPOTrainer(Trainer):
                      skipping_train = False

                  # start generate
-                 # Use a separate model to generate data, because generating with train_model hangs under deepspeed parallel training
-                 self.generate_model.to(device)
-                 self.reference_model.to(device)
-
                  if TrainerTools().parallel.is_main_process:
                      log(f'start generate for batch {batch}/{batch_count_per_file}')

                  # generate data
-                 with torch.no_grad():
-                     # After the train_model checkpoint is saved, this keeps the generation model on the latest parameters
-                     copy_model_params(_from=self.train_model, _to=self.generate_model)
-                     rollout_data = self._generate_rollout_data(batch_data)
-
-                 # Offload back to CPU; move to GPU again the next time they are used
-                 self.generate_model.to('cpu')
-                 self.reference_model.to('cpu')
+                 with unwrap_model_for_generation(self.train_model) as generate_model:
+                     rollout_data = self._generate_rollout_data(generate_model, batch_data)
+
                  torch.cuda.empty_cache()
                  # end generate

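The dedicated generate_model copy is gone in 0.5.0: rollouts are produced by temporarily unwrapping the training model itself. A rough sketch of the new pattern, mirroring GRPOTrainer.train above; `trainer` and `batch_data` are placeholders, and `_generate_rollout_data` is the trainer's own (private) method:

```python
# Sketch of the 0.5.0 rollout pattern; illustrative only, not a drop-in implementation.
import torch
from llm_trainer.partition_utils import unwrap_model_for_generation

def generate_rollouts(trainer, batch_data):
    # Temporarily strip the DeepSpeed/DDP wrapper (and ZeRO-3 hooks) so generation
    # runs on the bare nn.Module; hooks and sharding are restored on exit.
    with unwrap_model_for_generation(trainer.train_model) as generate_model:
        rollout_data = trainer._generate_rollout_data(generate_model, batch_data)
    torch.cuda.empty_cache()
    return rollout_data
```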
llm_trainer/loss.py CHANGED
@@ -1,4 +1,4 @@
- from typing import List, Optional, Tuple
+ from typing import List, Optional
  import torch
  from torch import nn
  import torch.nn.functional as F
llm_trainer/partition_utils.py ADDED
@@ -0,0 +1,146 @@
+ from typing import Optional
+ from contextlib import contextmanager
+ import itertools
+ from packaging import version
+ from torch import nn
+ from torch.nn.parallel import DistributedDataParallel as DDP
+
+ from .tools import TrainerTools
+ from .parallel_ds import DsParallel
+ from .parallel_ddp import DdpParallel
+
+
+ @contextmanager
+ def unwrap_model_for_generation(model: nn.Module):
+     """
+     Context manager to unwrap distributed or accelerated models for generation tasks.
+
+     Args:
+         model:
+             Model to be unwrapped.
+     Yields:
+         Unwrapped model.
+
+     Example:
+     ```python
+     with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
+         generated_outputs = unwrapped_model.generate(input_ids)
+     ```
+     """
+     if isinstance(TrainerTools().parallel, DsParallel):
+         import deepspeed
+         assert isinstance(model, deepspeed.DeepSpeedEngine)
+
+         if model.zero_optimization_stage() == 3:
+             with deepspeed.zero.GatheredParameters(model.parameters()):
+                 _remove_hooks(model)
+                 yield unwrap_model(model)
+                 _add_hooks(model)
+         else:
+             yield unwrap_model(model)
+     elif isinstance(TrainerTools().parallel, DdpParallel):
+         yield unwrap_model(model)
+     else:
+         yield model
+
+
+ def sync_model_params(_from: nn.Module, _to: Optional[nn.Module], mixup_alpha: float = 1.0):
+     if isinstance(TrainerTools().parallel, DsParallel):
+         _sync_ds_model_params(_from, _to, mixup_alpha)
+     elif isinstance(TrainerTools().parallel, DdpParallel):
+         _sync_ddp_model_params(_from, _to, mixup_alpha)
+     else:
+         _copy_params(_from, _to, mixup_alpha)
+
+
+ def unwrap_model(model) -> nn.Module:
+     try:
+         import deepspeed
+         if isinstance(model, deepspeed.DeepSpeedEngine):
+             return model.module
+     except: ...
+
+     if isinstance(model, DDP):
+         return model.module
+
+     return model
+
+
+ def _copy_params(model, target_model, mixup_alpha):
+     for target_param, copy_param in zip(target_model.parameters(), model.parameters()):
+         target_param.data.mul_(1.0 - mixup_alpha).add_(copy_param.data, alpha=mixup_alpha)
+
+
+ def _sync_ds_model_params(_from: nn.Module, _to: Optional[nn.Module], mixup_alpha: float = 1.0):
+     import deepspeed
+     assert isinstance(_from, deepspeed.DeepSpeedEngine)
+
+     origin_from = unwrap_model(_from)
+
+     if _from.zero_optimization_stage() == 3:
+         with deepspeed.zero.GatheredParameters(list(origin_from.parameters()) + list(_to.parameters()), modifier_rank=0):
+             if TrainerTools().parallel.is_main_process:
+                 _copy_params(origin_from, _to, mixup_alpha)
+     else:
+         _copy_params(origin_from, _to, mixup_alpha)
+
+
+ def _sync_ddp_model_params(_from: nn.Module, _to: Optional[nn.Module], mixup_alpha: float = 1.0):
+     assert isinstance(_from, DDP)
+
+     origin_from = unwrap_model(_from)
+     _copy_params(origin_from, _to, mixup_alpha)
+
+
+ def _add_hooks(model: nn.Module) -> None:
+     """Adds the optimizer hooks from a DeepSpeed ZeRO-3 model."""
+     import deepspeed
+     assert isinstance(model, deepspeed.DeepSpeedEngine)
+
+     if not hasattr(model, "optimizer"):  # before the first training step, the model has no optimizer
+         return
+     if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
+         optimizer_offload = model.optimizer.parameter_offload
+     elif model.optimizer is not None:
+         optimizer_offload = model.optimizer
+     else:
+         raise RuntimeError("The model optimizer is None, which is not yet supported.")
+     if version.parse(deepspeed.__version__) >= version.parse("0.16.4"):
+         # Account for renaming in https://github.com/deepspeedai/DeepSpeed/pull/6847
+         optimizer_offload._register_deepspeed_module(optimizer_offload.module)
+     else:
+         optimizer_offload._register_hooks_recursively(optimizer_offload.module)
+
+
+ def _remove_hooks(model: nn.Module) -> None:
+     """Removes the optimizer hooks from a DeepSpeed ZeRO-3 model."""
+     import deepspeed
+     assert isinstance(model, deepspeed.DeepSpeedEngine)
+
+     if not hasattr(model, "optimizer"):  # before the first training step, the model has no optimizer
+         return
+     if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
+         optimizer_offload = model.optimizer.parameter_offload
+     elif model.optimizer is not None:
+         optimizer_offload = model.optimizer
+     else:
+         raise RuntimeError("The model optimizer is None, which is not yet supported.")
+
+     for param in _iter_params(optimizer_offload.module, recurse=True):
+         param.ds_active_sub_modules.clear()
+
+     for hook in optimizer_offload.forward_hooks:
+         hook.remove()
+     for hook in optimizer_offload.backward_hooks:
+         hook.remove()
+
+     optimizer_offload.forward_hooks = []
+     optimizer_offload.backward_hooks = []
+
+
+ def _iter_params(module, recurse=False):
+     return [param for _, param in _get_all_parameters(module, recurse)]
+
+
+ def _get_all_parameters(sub_module, recurse=False):
+     return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters())
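sync_model_params replaces the old copy_model_params helper. With the default mixup_alpha=1.0 it is a hard copy; smaller values give a soft, EMA-style update, since _copy_params computes target = (1 - alpha) * target + alpha * source. A toy sketch of that semantics, assuming a plain local run where TrainerTools resolves to the non-parallel backend (no DeepSpeed or DDP wrapper), using made-up stand-in modules:

```python
# Toy illustration of the mixup/EMA semantics of sync_model_params (assumption: no DS/DDP
# wrapper is active, so the plain _copy_params path is taken).
import torch
from torch import nn
from llm_trainer.partition_utils import sync_model_params

torch.manual_seed(0)
train_model = nn.Linear(4, 4)       # stand-in for the training model
reference_model = nn.Linear(4, 4)   # stand-in for the reference model

# Hard copy: the reference weights now equal the training weights.
sync_model_params(_from=train_model, _to=reference_model, mixup_alpha=1.0)

# Soft update: the reference moves 10% of the way toward the training weights.
sync_model_params(_from=train_model, _to=reference_model, mixup_alpha=0.1)
```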
llm_trainer/tools.py CHANGED
@@ -3,7 +3,6 @@ from abc import ABC, abstractmethod
  import torch
  from .tokenizer import Tokenizer
  from .parallel_ds import DsParallel
- from .parallel_fsdp import FsdpParallel
  from .parallel_ddp import DdpParallel
  from .parallel_none import NoneParallel
  from .log import log
@@ -11,7 +10,6 @@ from .log import log

  parallel_types = {
      'ds': DsParallel,
-     'fsdp': FsdpParallel,
      'ddp': DdpParallel,
      'none': NoneParallel
  }
llm_trainer/train_configs.py CHANGED
@@ -1,8 +1,7 @@
- from typing import Optional, Union, Set, Type, Callable, List, Mapping, Any
+ from typing import Optional, Union, Callable, List, Mapping, Any
  from dataclasses import dataclass, field

  import torch
- from torch import nn
  from llm_model import ModelConfig, VLMConfig
  from .tools import FileDataset

@@ -33,6 +32,9 @@ class DsZeROConfig:
      reduce_bucket_size: Optional[Union[str, int]] = 5e8
      contiguous_gradients: Optional[bool] = True

+ @dataclass(kw_only=True)
+ class DsZero0Config(DsZeROConfig):
+     stage: int = field(default=0, init=False)

  @dataclass(kw_only=True)
  class DsZero1Config(DsZeROConfig):
@@ -84,26 +86,6 @@ class DsConfig:
      activation_checkpointing: Optional[DsActivationCheckpointingConfig] = None


- @dataclass(kw_only=True)
- class FsdpConfig:
-     """
-     Configuration for FSDP training mode
-     Args:
-         transformer_layer_cls (`Set[Type[nn.Module]]`, *optional*, default is None):
-             class(es) of the transformer layers
-         wrap_policy_num_params (`int`, *optional*, default is -1):
-             min_num_params argument of size_based_auto_wrap_policy; -1 disables this policy
-         cpu_offload (`bool`, *optional*, default is False):
-             whether to use CPU offload
-         offload_params (`bool`, default is False):
-             whether to offload parameters; only effective when cpu_offload is True
-     """
-     transformer_layer_cls: Optional[Set[Type[nn.Module]]] = None
-     wrap_policy_num_params: int = -1
-     cpu_offload: bool = False
-     offload_params: bool = False
-
-
  @dataclass(kw_only=True)
  class DataLoaderConfig:
      """
@@ -157,6 +139,7 @@ class GRPOConfig:
      clip_eps: float = 0.2
      kl_weight: float = 0.01
      group_size: int = 12
+     mixup_alpha: float = 1.0
      gen_max_new_tokens: Optional[int] = None
      gen_temperature: Optional[float] = None
      gen_k: Optional[int] = None
@@ -210,8 +193,6 @@ class TrainConfig:
              how many batches between model evals
          lr_config (`LrConfig`):
              learning-rate configuration
-         fsdp_config: (`FsdpConfig`):
-             configuration for FSDP training mode
          data_loader_config: (`DataLoaderConfig`):
             data loader configuration
          kd_config: (`KDConfig`, *Optional*, default is None):
@@ -231,7 +212,6 @@ class TrainConfig:
      lr_config: LrConfig = field(default_factory=LrConfig)

      ds_config: DsConfig = field(default_factory=DsConfig)
-     fsdp_config: FsdpConfig = field(default_factory=FsdpConfig)

      kd_config: Optional[KDConfig] = None
      dpo_config: Optional[DPOConfig] = None
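The new mixup_alpha field on GRPOConfig feeds the sync_model_params call in GRPOTrainer.train, controlling how strongly the reference model tracks the training model each epoch. A hedged configuration sketch; only fields visible in this diff are set, and a real training run would need the rest of TrainConfig as well:

```python
# Illustrative GRPO configuration; values are examples, not recommendations.
from llm_trainer.train_configs import GRPOConfig

grpo_config = GRPOConfig(
    clip_eps=0.2,
    kl_weight=0.01,
    group_size=12,
    mixup_alpha=0.5,  # new in 0.5.0: 1.0 = hard copy into the reference model, <1.0 = soft/EMA-style update
)
```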
llm_trainer/trainer.py CHANGED
@@ -1,21 +1,18 @@
  import time
  from contextlib import nullcontext
- from typing import Optional, Tuple, List, Dict, Any, Union
+ from typing import Optional, Tuple, List, Dict, Any

  import torch
- from torch import nn
  import torch.distributed as dist
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
  from torch.utils.data import Dataset
  from llm_model import LlmModel, VlmModel

  from .parallel_ds import DsParallel
- from .parallel_fsdp import FsdpParallel
  from .tools import TrainerTools
  from .loss import LMLoss, KDLoss
  from .dataset import TextDataset
- from .model_params import copy_model_params
  from .eval import submit_gen_task
+ from .partition_utils import unwrap_model_for_generation

  from .train_configs import (
      TrainConfig,
@@ -78,7 +75,6 @@ class Trainer:

          self.train_model, self.optimizer = self._init_train_model_and_optim(initial_lr, parallel_kwargs, use_ds_optim)
          self.lr_scheduler = self._init_lr_scheduler(initial_lr)
-         self.eval_model: Optional[nn.Module] = self._init_eval_model()

          self.criterion, self.kd_loss = self._init_loss()

@@ -86,9 +82,7 @@ class Trainer:
              device_type=TrainerTools().parallel.device_type,
              dtype=TrainerTools().dtype,
              enabled=True,
-             # in fsdp mode, cache_enabled needs to be set to False
-             # https://www.zhihu.com/question/642793891
-             cache_enabled=False if isinstance(self.train_model, FSDP) else None
+             cache_enabled=None
          ) if TrainerTools().use_amp else nullcontext()

          load_checkpoint(
@@ -176,12 +170,6 @@ class Trainer:

          return model, optim

-     def _init_eval_model(self) -> Optional[nn.Module]:
-         if TrainerTools().parallel.is_main_process:
-             return self._new_model(self.train_config).to(device='cpu', dtype=TrainerTools().dtype)
-
-         return None
-
      def _init_lr_scheduler(self, initial_lr: float) -> LRScheduler:
          if self.train_config.lr_config.enable_lr_scheduler:
              min_lr = self.train_config.lr_config.min_lr
@@ -313,13 +301,6 @@ class Trainer:
              activation_checkpointing['number_checkpoints'] = activation_checkpointing_config.number_checkpoints

              parallel_kwargs['activation_checkpointing'] = activation_checkpointing
-         elif isinstance(TrainerTools().parallel, FsdpParallel) and self.train_config.fsdp_config:
-             parallel_kwargs = {
-                 'transformer_layer_cls': self.train_config.fsdp_config.transformer_layer_cls,
-                 'wrap_policy_num_params': self.train_config.fsdp_config.wrap_policy_num_params,
-                 'cpu_offload': self.train_config.fsdp_config.cpu_offload,
-                 'offload_params': self.train_config.fsdp_config.offload_params
-             }

          dataloader_args = self.train_config.data_loader_config
          data_loader_kwargs = {
@@ -441,54 +422,35 @@ class Trainer:

              raise e

-     def _on_batch_end(
-         self,
-         tag: str
-     ):
-         copy_model_params(_from=self.train_model, _to=self.eval_model)
+     def _eval(self, tag: str):
+         with unwrap_model_for_generation(self.train_model) as generate_model:
+             if TrainerTools().parallel.is_main_process:
+                 generate_model.eval()
+                 eval_prompt, eval_image_tag = self._get_eval_data()
+
+                 if isinstance(self.train_config, VLMConfig) and self.pixel_values_provider and eval_image_tag:
+                     eval_pixel_values = self.pixel_values_provider([eval_image_tag])
+                 else:
+                     eval_pixel_values = None
+
+                 submit_gen_task(
+                     generate_model,
+                     self.train_config.eval_config,
+                     tag=tag,
+                     prompt=eval_prompt,
+                     pixel_values=eval_pixel_values,
+                     max_position_embeddings=self.train_config.model_config.max_position_embeddings,
+                     tokens_per_image=self.tokens_per_image
+                 )
+                 generate_model.train()

-         if TrainerTools().parallel.is_main_process:
-             eval_prompt, eval_image_tag = self._get_eval_data()
-             if isinstance(self.train_config, VLMConfig) and self.pixel_values_provider and eval_image_tag:
-                 eval_pixel_values = self.pixel_values_provider([eval_image_tag])
-             else:
-                 eval_pixel_values = None
-
-             submit_gen_task(
-                 self.eval_model,
-                 self.train_config.eval_config,
-                 tag=f'sign:batch/{tag}',
-                 prompt=eval_prompt,
-                 pixel_values=eval_pixel_values,
-                 max_position_embeddings=self.train_config.model_config.max_position_embeddings,
-                 tokens_per_image=self.tokens_per_image
-             )
          TrainerTools().parallel.wait()

-     def _on_epoch_end(
-         self,
-         tag: str
-     ):
-         copy_model_params(_from=self.train_model, _to=self.eval_model)
-
-         if TrainerTools().parallel.is_main_process:
-             eval_prompt, eval_image_tag = self._get_eval_data()
-             if isinstance(self.train_config, VLMConfig) and self.pixel_values_provider and eval_image_tag:
-                 eval_pixel_values = self.pixel_values_provider([eval_image_tag])
-             else:
-                 eval_pixel_values = None
-
-             submit_gen_task(
-                 self.eval_model,
-                 self.train_config.eval_config,
-                 tag=f'sign:epoch/{tag}',
-                 prompt=eval_prompt,
-                 pixel_values=eval_pixel_values,
-                 max_position_embeddings=self.train_config.model_config.max_position_embeddings,
-                 tokens_per_image=self.tokens_per_image
-             )
+     def _on_batch_end(self, tag: str):
+         self._eval(f'sign:batch/{tag}')

-         TrainerTools().parallel.wait()
+     def _on_epoch_end(self, tag: str):
+         self._eval(f'sign:epoch/{tag}')

      def _on_file_start(
          self,
@@ -574,7 +536,6 @@ class Trainer:
          if need_update_grad:
              loss_tensor = torch.tensor(loss_accumulation, device=TrainerTools().parallel.device)

-             # todo check all_reduce??
              if TrainerTools().parallel.parallel_train:
                  dist.all_reduce(loss_tensor, dist.ReduceOp.AVG)

llm_trainer/utils.py CHANGED
@@ -1,5 +1,4 @@
  import random
- from typing import Tuple, Optional
  import torch
  from torch.nn.utils.rnn import pad_sequence
  import torch.nn.functional as F
{project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.0.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: project_llm_trainer
- Version: 0.4.15
+ Version: 0.5.0
  Summary: LLM and VLM trainer
  Author: qibin
  Author-email: qibin0506@gmail.com
project_llm_trainer-0.5.0.dist-info/RECORD ADDED
@@ -0,0 +1,33 @@
+ llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
+ llm_trainer/checkpoint.py,sha256=xTmmQSJ_jQDVSTT3km1p_8eRrc7yE_dEsi92z9OX5ec,3251
+ llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
+ llm_trainer/dpo_trainer.py,sha256=wMREatLt0I8Ajdm_sI2U8Zj-IN1L6txP9s_tH1oI3-s,11431
+ llm_trainer/ds_checkpoint.py,sha256=wz48HoLBBt8QGO1tXfvJwrXoiGtPG_gjwHfEqARllso,2175
+ llm_trainer/eval.py,sha256=fjASCILU3fSPJxo9cP3rIXEEnkc5ZlUyHqXlZtUiHrw,888
+ llm_trainer/generate_utils.py,sha256=2MoEGEpoTzx7khO3dPcC2akFLyjtbFFpdJtuB_QQ3OY,17708
+ llm_trainer/grpo_trainer.py,sha256=qiC3KwxYPSB9UKqyk4eSRvORP3b6GM-2ozqI8u3QvI0,15568
+ llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
+ llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
+ llm_trainer/parallel.py,sha256=DQu8GqEFxD99HQ6hKuIxxyKi-05dMO33eMhImYlPuOI,4468
+ llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
+ llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
+ llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
+ llm_trainer/partition_utils.py,sha256=xzv8kwlbKp3dai2pBwX89gN5ymeHk1bGbTkGru5H-UM,5167
+ llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
+ llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
+ llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
+ llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
+ llm_trainer/train_configs.py,sha256=m57W71SI5VCCU9aJ_nJkB-3AJrSGiNXmV28rdpuYmLg,7332
+ llm_trainer/trainer.py,sha256=zTJVyY1cAjJdTkyXCOy2ZPVP18SOMLdWhD54Mz2JRe4,25314
+ llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
+ project_llm_trainer-0.5.0.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+ project_llm_trainer-0.5.0.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+ project_llm_trainer-0.5.0.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+ project_llm_trainer-0.5.0.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+ project_llm_trainer-0.5.0.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+ project_llm_trainer-0.5.0.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+ project_llm_trainer-0.5.0.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+ project_llm_trainer-0.5.0.dist-info/METADATA,sha256=YDj-N4VL8O_AqNanwfU6Yt38J97p3RgtUSzmwl0Y-GM,195
+ project_llm_trainer-0.5.0.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+ project_llm_trainer-0.5.0.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+ project_llm_trainer-0.5.0.dist-info/RECORD,,
llm_trainer/dcp.py DELETED
@@ -1,93 +0,0 @@
- import os
- from typing import Optional, Dict, Any
- from torch import nn
- from torch.optim import Optimizer
- import torch.distributed.checkpoint as dcp
- from torch.distributed.checkpoint.stateful import Stateful
- from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
- from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp
-
- DEFAULT_CHECKPOINT_DIR = "checkpoint"
-
- class AppState(Stateful):
-     def __init__(self, model: nn.Module, optimizer: Optimizer):
-         self.model = model
-         self.optimizer = optimizer
-
-     def state_dict(self) -> Dict[str, Any]:
-         model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
-         return {
-             'model_state_dict': model_state_dict,
-             'optim_state_dict': optimizer_state_dict
-         }
-
-     def load_state_dict(self, state_dict: Dict[str, Any]):
-         set_state_dict(
-             model=self.model,
-             optimizers=self.optimizer,
-             model_state_dict=state_dict['model_state_dict'],
-             optim_state_dict=state_dict['optim_state_dict']
-         )
-
-
- def save_dcp(
-     model: nn.Module,
-     optimizer: Optimizer,
-     suffix: Optional[str] = None
- ):
-     checkpoint_id = os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR)
-     if suffix:
-         checkpoint_id = f"{checkpoint_id}_{suffix}"
-
-     state_dict = {'app': AppState(model, optimizer)}
-
-     # fs_storage_writer = dcp.FileSystemWriter(checkpoint_id, overwrite=True)
-     # dcp.save(state_dict=state_dict, storage_writer=fs_storage_writer)
-     dcp.save(state_dict=state_dict, checkpoint_id=checkpoint_id)
-
-
- def load_dcp(
-     model: nn.Module,
-     optimizer: Optional[Optimizer] = None,
-     suffix: Optional[str] = None
- ):
-     checkpoint_id = os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR)
-     if suffix:
-         checkpoint_id = f"{checkpoint_id}_{suffix}"
-
-     if os.path.exists(checkpoint_id):
-         state_dict = {'app': AppState(model, optimizer)}
-         # AppState loads into state_dict, which is then loaded into the model
-         dcp.load(state_dict=state_dict, checkpoint_id=checkpoint_id)
-
-     # if isinstance(model, FSDP):
-     #     state_dict = {'app': AppState(model, optimizer)}
-     #     # AppState loads into state_dict, which is then loaded into the model
-     #     dcp.load(state_dict=state_dict, checkpoint_id=checkpoint_id)
-     # else:
-     #     state_dict = {"model_state_dict": model.state_dict()}
-     #
-     #     if optimizer:
-     #         state_dict.update({'optim_state_dict': optimizer.state_dict()})
-     #
-     #     # since no progress group is initialized, DCP will disable any collectives.
-     #     # load into state_dict, then into the model via model.load_state_dict
-     #     dcp.load(
-     #         state_dict=state_dict,
-     #         checkpoint_id=checkpoint_id,
-     #     )
-     #
-     #     model.load_state_dict(state_dict["model_state_dict"])
-     #     if optimizer:
-     #         optimizer.load_state_dict(state_dict["optim_state_dict"])
-
- def convert_dcp_to_pth(pth_path: str):
-     dcp_path = os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR)
-     if os.path.exists(dcp_path):
-         # convert dcp model to torch.save (assumes checkpoint was generated as above)
-         dcp_to_torch_save(dcp_path, pth_path)
-
- def convert_pth_to_dcp(pth_path: str):
-     if os.path.exists(pth_path):
-         # converts the torch.save model back to DCP
-         torch_save_to_dcp(pth_path, os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR))
llm_trainer/ds_model_params.py DELETED
@@ -1,72 +0,0 @@
- from typing import Optional
- from torch import nn
- import torch.distributed as dist
-
- from .tools import TrainerTools
-
- try:
-     import deepspeed
-     from deepspeed import DeepSpeedEngine
-     from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
- except: ...
-
-
- def _get_ds_full_state_dict_on_rank0(model: DeepSpeedEngine) -> Optional[dict]:
-     """
-     Must be called on all ranks; only rank 0 gets a value.
-     """
-
-     if model.zero_optimization_stage() != 3:
-         if TrainerTools().parallel.is_main_process:
-             return {k: v.cpu().clone() for k, v in model.module.state_dict().items()}
-         return None
-
-     # --- ZeRO-3 ---
-     # Call GatheredParameters only once, passing in all parameters
-     with deepspeed.zero.GatheredParameters(model.parameters(), modifier_rank=0):
-         if TrainerTools().parallel.is_main_process:
-             # Inside this 'with' block, model.module on rank 0 holds the full parameters,
-             # so state_dict() can be called as on an ordinary model
-             full_state_dict = model.module.state_dict()
-
-             # Clone it to the CPU and return it
-             return {k: v.cpu().clone() for k, v in full_state_dict.items()}
-
-     # When the other ranks reach this point the context has ended; simply return None
-     return None
-
-     # # ZeRO-3
-     # state_dict_on_rank_0 = {}
-     # for param_name, param in model.module.named_parameters():
-     #     if hasattr(param, 'ds_id'):
-     #         with deepspeed.zero.GatheredParameters(param, modifier_rank=0):
-     #             if TrainerTools().parallel.is_main_process:
-     #                 state_dict_on_rank_0[param_name] = param.data.to(torch.float32).cpu().clone()
-     #     else:
-     #         if TrainerTools().parallel.is_main_process:
-     #             state_dict_on_rank_0[param_name] = param.data.to(torch.float32).cpu().clone()
-     #
-     # return state_dict_on_rank_0 if TrainerTools().parallel.is_main_process else None
-
-
- def get_ds_model_params(model: nn.Module, only_rank0=False):
-     """
-     Efficiently extract a complete FP32 state_dict from a running DeepSpeedEngine,
-     compatible with ZeRO stages 0, 1, 2 and 3.
-     Includes correct handling of the sharded parameters in ZeRO-3.
-     """
-
-     assert isinstance(model, DeepSpeedEngine)
-     state_dict = _get_ds_full_state_dict_on_rank0(model)
-
-     # At this point only the state_dict on rank 0 is a valid dict; on the other ranks it is None.
-     # It needs to be broadcast to every process.
-     if not only_rank0 and TrainerTools().parallel.world_size > 1:
-         # Prepare a list: rank 0 holds the data, the other ranks hold a placeholder
-         object_list = [state_dict] if TrainerTools().parallel.is_main_process else [None]
-         # The broadcast is blocking and synchronizes all processes
-         dist.broadcast_object_list(object_list, src=0)
-         # Every process takes the broadcast copy of the state_dict from the list
-         state_dict = object_list[0]
-
-     return state_dict
llm_trainer/fsdp_checkpoint.py DELETED
@@ -1,52 +0,0 @@
- import os
- from typing import Optional, Union, Tuple
- import torch
- from torch import nn
- from torch.optim import Optimizer
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-
- from .tools import TrainerTools
-
- DEFAULT_CHECKPOINT_NAME = "checkpoint.pth"
-
- def save_fsdp_checkpoint(
-     model: nn.Module,
-     optimizer: Optional[Optimizer] = None,
-     suffix: Optional[str] = None
- ):
-     # Untested; reference: https://doc.hfai.high-flyer.cn/haiscale/haiscale_fsdp.html
-     # Should rank0_only=True be used?
-     with FSDP.summon_full_params(
-         module=model,
-         rank0_only=True,
-         writeback=False,
-         offload_to_cpu=True
-     ):
-         if TrainerTools().parallel.is_main_process:
-             checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
-             if suffix:
-                 checkpoint_name = f"{checkpoint_name}_{suffix}"
-
-             ckpt = {'model_state_dict': model.state_dict()}
-             if optimizer:
-                 ckpt.update({'optim_state_dict': optimizer.state_dict()})
-
-             torch.save(ckpt, checkpoint_name)
-
-
- def load_fsdp_checkpoint(
-     model: nn.Module,
-     optimizer: Optional[Optimizer] = None,
-     device: Optional[Union[torch.device, str]] = None,
-     suffix: Optional[str] = None
- ):
-     checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
-     if suffix:
-         checkpoint_name = f"{checkpoint_name}_{suffix}"
-
-     with FSDP.summon_full_params(module=model):
-         state_dict = torch.load(checkpoint_name, weights_only=True, map_location=device)
-         model.load_state_dict(state_dict['model_state_dict'])
-
-         if optimizer:
-             optimizer.load_state_dict(state_dict['optim_state_dict'])
llm_trainer/fsdp_model_params.py DELETED
@@ -1,39 +0,0 @@
- from typing import Optional
- from torch import nn
- import torch.distributed as dist
-
- from .tools import TrainerTools
-
-
- def _get_fsdp_full_state_dict_on_rank0(model: nn.Module) -> Optional[dict]:
-     """
-     Can be called on any rank; only rank 0 gets a value.
-     """
-
-     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-     with FSDP.summon_full_params(model, writeback=False, offload_to_cpu=True):
-         if TrainerTools().parallel.is_main_process:
-             return {k: v.clone() for k, v in model.state_dict().items()}
-
-     return None
-
-
- def get_fsdp_model_params(model: nn.Module, only_rank0=False):
-     """
-     Efficiently extract a complete FP32 state_dict from an FSDP-wrapped model.
-     Gathers all sharded parameters and makes sure every rank receives a full copy.
-     """
-
-     state_dict = _get_fsdp_full_state_dict_on_rank0(model)
-
-     # At this point only the state_dict on rank 0 is a valid dict; on the other ranks it is None.
-     # It needs to be broadcast to every process.
-     if not only_rank0 and TrainerTools().parallel.world_size > 1:
-         # Prepare a list: rank 0 holds the data, the other ranks hold a placeholder
-         object_list = [state_dict] if TrainerTools().parallel.is_main_process else [None]
-         # The broadcast is blocking and synchronizes all processes
-         dist.broadcast_object_list(object_list, src=0)
-         # Every process takes the broadcast copy of the state_dict from the list
-         state_dict = object_list[0]
-
-     return state_dict
llm_trainer/model_params.py DELETED
@@ -1,28 +0,0 @@
- from typing import Optional
- from torch import nn
- from torch.nn.parallel import DistributedDataParallel as DDP
-
- from .tools import TrainerTools
- from .parallel_ds import DsParallel
- from .parallel_fsdp import FsdpParallel
-
- def copy_model_params(
-     _from: nn.Module,
-     _to: Optional[nn.Module]
- ):
-     """
-     Must be called on all ranks; on non-rank0 processes, _to may be None.
-     """
-     if isinstance(TrainerTools().parallel, DsParallel):
-         from .ds_model_params import get_ds_model_params
-         state_dict = get_ds_model_params(_from, only_rank0=_to is None)
-     elif isinstance(TrainerTools().parallel, FsdpParallel):
-         from .fsdp_model_params import get_fsdp_model_params
-         state_dict = get_fsdp_model_params(_from, only_rank0=_to is None)
-     elif isinstance(_from, DDP):
-         state_dict = _from.module.state_dict()
-     else:
-         state_dict = _from.state_dict()
-
-     if _to and state_dict:
-         _to.load_state_dict(state_dict)
llm_trainer/parallel_fsdp.py DELETED
@@ -1,121 +0,0 @@
- from typing import Optional, Tuple
- import functools
- import torch
- from torch import nn
- from torch.distributed.fsdp import (
-     FullyShardedDataParallel as FSDP,
-     MixedPrecision,
-     ShardingStrategy,
-     BackwardPrefetch,
-     CPUOffload,
- )
-
- from torch.distributed.fsdp.wrap import (
-     size_based_auto_wrap_policy,
-     transformer_auto_wrap_policy,
-     always_wrap_policy,
-     enable_wrap,
-     wrap,
- )
-
- from .parallel import Parallel
-
- class FsdpParallel(Parallel):
-     def __init__(self):
-         super().__init__()
-
-     def process(
-         self,
-         model: nn.Module,
-         optimizer: torch.optim.Optimizer,
-         kwargs: Optional[dict] = None,
-         save_instance: bool = True
-     ) -> Tuple[nn.Module, torch.optim.Optimizer]:
-         """
-         :param model:
-         :param optimizer:
-         :param kwargs:
-             "wrap_policy_num_params" int: min_num_params for size_based_auto_wrap_policy
-             "cpu_offload" bool: whether to use CPU offload
-             "offload_params" bool: whether to offload parameters; only effective when cpu_offload is True
-         :param save_instance
-         :return:
-         """
-
-         model.to(self.device)
-
-         if self._use_compile:
-             model = torch.compile(model)
-
-         if self._use_parallel:
-             if 'transformer_layer_cls' in kwargs:
-                 auto_wrap_policy = functools.partial(
-                     transformer_auto_wrap_policy,
-                     transformer_layer_cls=kwargs['transformer_layer_cls']
-                 )
-             elif 'wrap_policy_num_params' in kwargs:
-                 auto_wrap_policy = functools.partial(
-                     size_based_auto_wrap_policy,
-                     min_num_params=kwargs['wrap_policy_num_params']
-                 )
-             else:
-                 auto_wrap_policy = None
-
-             if 'cpu_offload' in kwargs:
-                 offload_params = False
-                 if 'offload_params' in kwargs:
-                     offload_params = kwargs['offload_params']
-
-                 # Optionally configure cpu_offload so that wrapped parameters not used in computation
-                 # are offloaded to the CPU. This can further improve memory efficiency at the cost of
-                 # host/device data-transfer overhead.
-                 cpu_offload = CPUOffload(offload_params=offload_params)
-             else:
-                 cpu_offload = None
-
-             if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
-                 mixed_precision = MixedPrecision(
-                     param_dtype=torch.bfloat16,
-                     # Gradient communication precision.
-                     reduce_dtype=torch.bfloat16,
-                     # Buffer precision.
-                     buffer_dtype=torch.bfloat16,
-                 )
-             else:
-                 mixed_precision = None
-
-             raw_model = model
-
-             # device_mesh = init_device_mesh("cuda", (self.world_size,))
-             # model = FSDP(
-             #     model,
-             #     auto_wrap_policy=auto_wrap_policy,
-             #     mixed_precision=mixed_precision,
-             #     cpu_offload=cpu_offload,
-             #     device_id=torch.cuda.current_device(),
-             #     device_mesh=device_mesh
-             # )
-
-             model = FSDP(
-                 model,
-                 sharding_strategy=ShardingStrategy.FULL_SHARD,
-                 auto_wrap_policy=auto_wrap_policy,
-                 mixed_precision=mixed_precision,
-                 cpu_offload=cpu_offload,
-                 device_id=torch.cuda.current_device(),
-                 process_group=None,
-                 # use_orig_params=True,
-                 # backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # bit faster async comms, bit higher memory
-                 # limit_all_gathers=False,
-                 # forward_prefetch=True,
-             )
-         else:
-             model = model
-             raw_model = model
-
-         if save_instance:
-             self.raw_model = raw_model
-             self.model = model
-
-         return model, optimizer
-
-
project_llm_trainer-0.4.15.dist-info/RECORD DELETED
@@ -1,38 +0,0 @@
- llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
- llm_trainer/checkpoint.py,sha256=AvUC1JLxuahKtg3VNW20VHIE3iIjpaMHIi_pyyDYVJ0,5043
- llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
- llm_trainer/dcp.py,sha256=PkD97DyrOtoTKn4FJsfL3VqAy4dxufgjdzJEz8-Cnoc,3635
- llm_trainer/dpo_trainer.py,sha256=o5lYxt6yVMCvoBqW_yTu9l6Ff-xjEu-CwdPVttu3H8E,11447
- llm_trainer/ds_checkpoint.py,sha256=wz48HoLBBt8QGO1tXfvJwrXoiGtPG_gjwHfEqARllso,2175
- llm_trainer/ds_model_params.py,sha256=Nwmv0YcBtO6ynC0dXallAD1rWkN22-elGfVjLaWp2Yg,2988
- llm_trainer/eval.py,sha256=NDm8PbXLch7xT81xPYPRCNrcrB_Xj5GDJSCxyVwUOp4,1524
- llm_trainer/fsdp_checkpoint.py,sha256=xsm71s9WeTaBvBvv6CbuGpwkmX3V6i3xmBcMTDfGxKc,1770
- llm_trainer/fsdp_model_params.py,sha256=MRjrs9zmMl-61a1l6188Ij5PSalzztOSp8E4evDvJXo,1541
- llm_trainer/generate_utils.py,sha256=tSbA_tLqSq5qJGHSOlPv5T3iRDZkbFg5ZvDAgJ_i_SE,17946
- llm_trainer/grpo_trainer.py,sha256=1gZXiL1pogLFecFQUGj9zCU_k66ryVjZciYyd8J5ph4,15998
- llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
- llm_trainer/loss.py,sha256=Yv3fsaVuZ5AhnGPJOr5vEMb_tM2urR6mCb4DBbrHHI8,6030
- llm_trainer/model_params.py,sha256=2f2W9KRCjyqSfEwxI3w5f6TPZaqq25WzY-nEc7aJxcs,970
- llm_trainer/parallel.py,sha256=DQu8GqEFxD99HQ6hKuIxxyKi-05dMO33eMhImYlPuOI,4468
- llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
- llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
- llm_trainer/parallel_fsdp.py,sha256=cQOdY8ou6m8OsR06PpFVn6GiyZlK9nefkcGyszUOIJk,4055
- llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
- llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
- llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
- llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
- llm_trainer/tools.py,sha256=O45-20wRmh-nyTfU-U-XtjbKAoe7boEIsUvWT_NaKx4,3041
- llm_trainer/train_configs.py,sha256=HKzH3nfMT1-SW4Htwa0KqYtMd6FAJcthR5IEo6di8us,8168
- llm_trainer/trainer.py,sha256=95ARdNDfalhZ7Ug-fDj3qIhWEiZQeX9n5WANhijIRLE,27140
- llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
- project_llm_trainer-0.4.15.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
- project_llm_trainer-0.4.15.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
- project_llm_trainer-0.4.15.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
- project_llm_trainer-0.4.15.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
- project_llm_trainer-0.4.15.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
- project_llm_trainer-0.4.15.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
- project_llm_trainer-0.4.15.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
- project_llm_trainer-0.4.15.dist-info/METADATA,sha256=5sveZ3kkRMVCz9dI5_NI64o9tFBVsJhHhun9vwzzL9Q,196
- project_llm_trainer-0.4.15.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
- project_llm_trainer-0.4.15.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
- project_llm_trainer-0.4.15.dist-info/RECORD,,