project-llm-trainer 0.3.5__py3-none-any.whl → 0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (26)
  1. llm_trainer/checkpoint.py +40 -43
  2. llm_trainer/dpo_trainer.py +8 -13
  3. llm_trainer/ds_checkpoint.py +52 -0
  4. llm_trainer/eval.py +9 -26
  5. llm_trainer/fsdp_checkpoint.py +91 -0
  6. llm_trainer/grpo_trainer.py +6 -9
  7. llm_trainer/parallel.py +2 -1
  8. llm_trainer/parallel_ddp.py +10 -5
  9. llm_trainer/parallel_ds.py +14 -8
  10. llm_trainer/parallel_fsdp.py +19 -13
  11. llm_trainer/parallel_none.py +6 -4
  12. llm_trainer/tools.py +2 -2
  13. llm_trainer/train_configs.py +3 -1
  14. llm_trainer/trainer.py +23 -3
  15. {project_llm_trainer-0.3.5.dist-info → project_llm_trainer-0.4.dist-info}/METADATA +1 -1
  16. project_llm_trainer-0.4.dist-info/RECORD +35 -0
  17. project_llm_trainer-0.3.5.dist-info/RECORD +0 -34
  18. {project_llm_trainer-0.3.5.data → project_llm_trainer-0.4.data}/scripts/calc_intermediate_size +0 -0
  19. {project_llm_trainer-0.3.5.data → project_llm_trainer-0.4.data}/scripts/ddp_train +0 -0
  20. {project_llm_trainer-0.3.5.data → project_llm_trainer-0.4.data}/scripts/ds_train +0 -0
  21. {project_llm_trainer-0.3.5.data → project_llm_trainer-0.4.data}/scripts/plot_loss +0 -0
  22. {project_llm_trainer-0.3.5.data → project_llm_trainer-0.4.data}/scripts/plot_lr +0 -0
  23. {project_llm_trainer-0.3.5.data → project_llm_trainer-0.4.data}/scripts/py_train +0 -0
  24. {project_llm_trainer-0.3.5.data → project_llm_trainer-0.4.data}/scripts/smart_train +0 -0
  25. {project_llm_trainer-0.3.5.dist-info → project_llm_trainer-0.4.dist-info}/WHEEL +0 -0
  26. {project_llm_trainer-0.3.5.dist-info → project_llm_trainer-0.4.dist-info}/top_level.txt +0 -0
llm_trainer/checkpoint.py CHANGED
@@ -3,6 +3,7 @@ from typing import Optional, Union, Tuple
  import torch
  from torch import nn
  from torch.optim import Optimizer
+ from torch.nn.parallel import DistributedDataParallel as DDP
 
  from .parallel_ds import DsParallel
  from .parallel_fsdp import FsdpParallel
@@ -44,39 +45,22 @@ def save_checkpoint(
          save_ds_checkpoint(model, suffix)
      elif _can_use_dcp(model):
          save_dcp(model, optimizer, suffix)
+     elif isinstance(model, FSDP):
+         from .fsdp_checkpoint import save_fsdp_checkpoint
+         save_fsdp_checkpoint(model, optimizer, suffix)
      else:
-         if isinstance(model, FSDP):
-             # Untested. Reference: https://doc.hfai.high-flyer.cn/haiscale/haiscale_fsdp.html
-             # Should rank0_only=True be used?
-             with FSDP.summon_full_params(
-                 module=model,
-                 rank0_only=True,
-                 writeback=False,
-                 offload_to_cpu=True
-             ):
-                 if TrainerTools().parallel.is_main_process:
-                     checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
-                     if suffix:
-                         checkpoint_name = f"{checkpoint_name}_{suffix}"
+         if TrainerTools().parallel.is_main_process:
+             checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
+             if suffix:
+                 checkpoint_name = f"{checkpoint_name}_{suffix}"
 
-                     ckpt = {'model_state_dict': model.state_dict()}
+             raw_model = model if not isinstance(model, DDP) else model.module
+             ckpt = {'model_state_dict': raw_model.state_dict()}
 
-                     if optimizer:
-                         ckpt.update({'optim_state_dict': optimizer.state_dict()})
+             if optimizer:
+                 ckpt.update({'optim_state_dict': optimizer.state_dict()})
 
-                     torch.save(ckpt, checkpoint_name)
-         else:
-             if TrainerTools().parallel.is_main_process:
-                 checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
-                 if suffix:
-                     checkpoint_name = f"{checkpoint_name}_{suffix}"
-
-                 ckpt = {'model_state_dict': TrainerTools().parallel.raw_model.state_dict()}
-
-                 if optimizer:
-                     ckpt.update({'optim_state_dict': optimizer.state_dict()})
-
-                 torch.save(ckpt, checkpoint_name)
+             torch.save(ckpt, checkpoint_name)
 
 
  def load_checkpoint(
@@ -91,26 +75,20 @@ def load_checkpoint(
          load_ds_checkpoint(model, load_module_only=load_module_only, suffix=suffix)
      elif _can_use_dcp(model):
          load_dcp(model, optimizer, suffix)
+     elif isinstance(model, FSDP):
+         from .fsdp_checkpoint import load_fsdp_checkpoint
+         load_fsdp_checkpoint(model, optimizer, device, suffix)
      else:
          checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
          if suffix:
              checkpoint_name = f"{checkpoint_name}_{suffix}"
 
-         if os.path.exists(checkpoint_name):
-             # Untested; the else branch below has been tested and also works under FSDP
-             if isinstance(model, FSDP):
-                 with FSDP.summon_full_params(module=model):
-                     state_dict = torch.load(checkpoint_name, weights_only=True, map_location=device)
-                     model.load_state_dict(state_dict['model_state_dict'])
+         state_dict = torch.load(checkpoint_name, weights_only=True, map_location=device)
+         raw_model = model.module if isinstance(model, DDP) else model
+         raw_model.load_state_dict(state_dict['model_state_dict'])
 
-                     if optimizer:
-                         optimizer.load_state_dict(state_dict['optim_state_dict'])
-             else:
-                 state_dict = torch.load(checkpoint_name, weights_only=True, map_location=device)
-                 model.load_state_dict(state_dict['model_state_dict'])
-
-                 if optimizer:
-                     optimizer.load_state_dict(state_dict['optim_state_dict'])
+         if optimizer:
+             optimizer.load_state_dict(state_dict['optim_state_dict'])
 
 
  def load_checkpoint_for_eval(
@@ -141,6 +119,25 @@ def load_checkpoint_for_eval(
          load_checkpoint(model, None, device, suffix=suffix)
 
 
+ def copy_model_params(
+         _from: nn.Module,
+         _to: nn.Module
+ ):
+     if isinstance(TrainerTools().parallel, DsParallel):
+         from .ds_checkpoint import get_ds_model_params
+         state_dict = get_ds_model_params(_from)
+     elif isinstance(TrainerTools().parallel, FsdpParallel):
+         from .fsdp_checkpoint import get_fsdp_model_params
+         state_dict = get_fsdp_model_params(_from)
+     elif isinstance(_from, DDP):
+         state_dict = _from.module.state_dict()
+     else:
+         state_dict = _from.state_dict()
+
+     if state_dict:
+         _to.load_state_dict(state_dict)
+
+
  def save_steps(global_steps: int, lr_scheduler: Optional[LRScheduler] = None):
      # For now, only the main process saves
      if TrainerTools().parallel.is_main_process:
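
The new copy_model_params helper replaces the earlier checkpoint round-trip (save_checkpoint followed by load_checkpoint_for_eval) whenever a second model needs the training model's current weights. A minimal usage sketch, assuming TrainerTools has already been initialized and train_model is the (possibly wrapped) training model; build_model() is a hypothetical constructor for the same architecture:

    from llm_trainer.checkpoint import copy_model_params

    # train_model may be wrapped (DDP, FSDP, or a DeepSpeedEngine); the target is a plain nn.Module
    reference_model = build_model()  # hypothetical: construct an identical, unwrapped model
    copy_model_params(_from=train_model, _to=reference_model)

    # freeze the copy so it can serve as a reference or eval model
    reference_model.eval()
    for param in reference_model.parameters():
        param.requires_grad = False
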
llm_trainer/dpo_trainer.py CHANGED
@@ -16,7 +16,7 @@ from .utils import get_dpo_collate_fn
 
  from .checkpoint import (
      save_checkpoint,
-     load_checkpoint_for_eval,
+     copy_model_params,
      save_steps,
  )
 
@@ -37,23 +37,18 @@ class DPOTrainer(Trainer):
          self.reference_model = self._init_reference_model()
 
      def _init_reference_model(self):
-         parallel = TrainerTools().new_parallel()
-
          reference_model = self._new_model(self.train_config)
-         if self.train_config.init_state_dict:
-             reference_model.load_state_dict(self.train_config.init_state_dict, strict=False)
-             self.train_config.init_state_dict = None
-         else:
-             load_checkpoint_for_eval(model=reference_model, device=parallel.device)
+         copy_model_params(_from=self.train_model, _to=reference_model)
 
-         reference_model, _ = parallel.process(
+         reference_model, _ = TrainerTools().parallel.process(
              model=reference_model,
              optimizer=None,
-             kwargs=self._init_reference_args()
+             kwargs=self._init_reference_args(),
+             save_instance=False
          )
 
-         parallel.raw_model.eval()
-         for param in parallel.raw_model.parameters():
+         reference_model.eval()
+         for param in reference_model.parameters():
              param.requires_grad = False
 
          return reference_model
@@ -267,7 +262,7 @@ class DPOTrainer(Trainer):
              if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
                  # clip grad
                  self.scalar.unscale_(self.optimizer)
-                 torch.nn.utils.clip_grad_norm_(self.train_model.parameters(), 1.0)
+                 torch.nn.utils.clip_grad_norm_(self._get_trainable_params(self.train_model), 1.0)
 
              self._step()
 
llm_trainer/ds_checkpoint.py CHANGED
@@ -2,8 +2,14 @@ import os
  from typing import Optional
  from glob import glob
  import shutil
+ import torch
  from torch import nn
+ import torch.distributed as dist
+
+ from .tools import TrainerTools
+
  try:
+     import deepspeed
      from deepspeed import DeepSpeedEngine
      from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
  except: ...
@@ -59,3 +65,49 @@ def load_ds_checkpoint_for_eval(model: nn.Module):
      ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
      state_dict = get_fp32_state_dict_from_zero_checkpoint(ckpt_dir)
      model.load_state_dict(state_dict)
+
+
+ def get_ds_model_params(model: nn.Module):
+     """
+     Efficiently extract a full FP32 state_dict from a running DeepSpeedEngine,
+     compatible with ZeRO stages 0, 1, 2 and 3.
+     This version correctly handles non-sharded parameters under ZeRO-3.
+     """
+
+     assert isinstance(model, DeepSpeedEngine)
+     zero_stage = model.zero_optimization_stage()
+     state_dict = None
+
+     if TrainerTools().parallel.is_main_process:
+         if zero_stage == 3:
+             # ZeRO-3: rank 0 gathers parameters to build the full state_dict
+             state_dict = {}
+             for param in model.module.parameters():
+                 # Key check: is this parameter sharded/managed by ZeRO-3?
+                 if hasattr(param, 'ds_id'):
+                     # Sharded parameter: gather it with GatheredParameters
+                     with deepspeed.zero.GatheredParameters(param, modifier_rank=0):
+                         # .clone() makes an independent copy, .cpu() moves it off the GPU, .to(torch.float32) fixes the dtype
+                         state_dict[param.ds_name] = param.data.to(torch.float32).cpu().clone()
+                 else:
+                     # Non-sharded parameter (e.g. tied weights, buffers): copy it directly from rank 0
+                     state_dict[param.ds_name] = param.data.to(torch.float32).cpu().clone()
+         else:  # zero_stage in [0, 1, 2]
+             # In these stages rank 0 already holds the full model,
+             # so the raw model state is read straight from model.module.
+             state_dict = {k: v.cpu().clone() for k, v in model.module.state_dict().items()}
+
+     # At this point only the state_dict on rank 0 is a valid dict; on other ranks it is None.
+     # It has to be broadcast to every process.
+     if TrainerTools().parallel.world_size > 1:
+         # Prepare a list: rank 0 carries the data, the other ranks hold a placeholder
+         object_list = [state_dict] if TrainerTools().parallel.is_main_process else [None]
+
+         # The broadcast blocks and synchronizes all processes
+         dist.broadcast_object_list(object_list, src=0)
+
+         # Every rank takes its copy of the broadcast state_dict from the list
+         state_dict = object_list[0]
+
+     return state_dict
+
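
Both get_ds_model_params here and get_fsdp_model_params in the new fsdp_checkpoint.py below rely on the same pattern: rank 0 materializes a full CPU state_dict, and dist.broadcast_object_list then ships it to every rank. A stripped-down sketch of just that pattern, assuming the default process group is already initialized; build_state_dict is a hypothetical stand-in for the rank-0-only gathering work:

    import torch.distributed as dist

    payload = build_state_dict() if dist.get_rank() == 0 else None
    object_list = [payload]
    # blocking collective: every rank must reach this call, or the job hangs
    dist.broadcast_object_list(object_list, src=0)
    state_dict = object_list[0]  # now identical on every rank
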
llm_trainer/eval.py CHANGED
@@ -1,9 +1,7 @@
- import time
-
  import torch
 
  from .generate_utils import generate
- from .checkpoint import load_checkpoint_for_eval
+ from .checkpoint import copy_model_params
  from .log import get_log_dir
  from .tools import TrainerTools
  from .train_configs import EvalConfig
@@ -21,27 +19,6 @@ def _eval_task(
  ):
      log_dir = get_log_dir()
 
-     # This can be tried when eval_model is not an independent model
-     # if isinstance(eval_model, FSDP):
-     #     with FSDP.summon_full_params(module=eval_model, writeback=False, recurse=False):
-     #         gen = generate(
-     #             eval_model,
-     #             prompt=prompt,
-     #             max_position_embeddings=max_position_embeddings,
-     #             max_new_tokens=max_new_tokens,
-     #             # temperature=None,
-     #             # k=None,
-     #             # p=None,
-     #             device='cpu',
-     #             item_callback=lambda item: write_temp(item)
-     #         )
-
-     # ---------
-     try:
-         load_checkpoint_for_eval(eval_model, device=device)
-     except:
-         return
-
 
      gen_result = generate(
          eval_model,
@@ -60,6 +37,7 @@ def _eval_task(
 
 
  def submit_gen_task(
+     train_model: torch.nn.Module,
      eval_model: torch.nn.Module,
      eval_config: EvalConfig,
      tag,
@@ -68,8 +46,13 @@ def submit_gen_task(
      max_position_embeddings,
      tokens_per_image
  ):
-     # Wait 1s to avoid the checkpoint-not-found problem in deepspeed mode
-     time.sleep(1)
+     try:
+         copy_model_params(_from=train_model, _to=eval_model)
+     except Exception as e:
+         if isinstance(e, KeyboardInterrupt):
+             raise e
+         return
+
      eval_model.to(TrainerTools().parallel.device)
      _eval_task(
          eval_model=eval_model,
llm_trainer/fsdp_checkpoint.py ADDED
@@ -0,0 +1,91 @@
+ import os
+ from typing import Optional, Union, Tuple
+ import torch
+ from torch import nn
+ from torch.optim import Optimizer
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+ import torch.distributed as dist
+
+ from .tools import TrainerTools
+
+ DEFAULT_CHECKPOINT_NAME = "checkpoint.pth"
+
+ def save_fsdp_checkpoint(
+         model: nn.Module,
+         optimizer: Optional[Optimizer] = None,
+         suffix: Optional[str] = None
+ ):
+     # Untested. Reference: https://doc.hfai.high-flyer.cn/haiscale/haiscale_fsdp.html
+     # Should rank0_only=True be used?
+     with FSDP.summon_full_params(
+         module=model,
+         rank0_only=True,
+         writeback=False,
+         offload_to_cpu=True
+     ):
+         if TrainerTools().parallel.is_main_process:
+             checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
+             if suffix:
+                 checkpoint_name = f"{checkpoint_name}_{suffix}"
+
+             ckpt = {'model_state_dict': model.state_dict()}
+             if optimizer:
+                 ckpt.update({'optim_state_dict': optimizer.state_dict()})
+
+             torch.save(ckpt, checkpoint_name)
+
+
+ def load_fsdp_checkpoint(
+         model: nn.Module,
+         optimizer: Optional[Optimizer] = None,
+         device: Optional[Union[torch.device, str]] = None,
+         suffix: Optional[str] = None
+ ):
+     checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
+     if suffix:
+         checkpoint_name = f"{checkpoint_name}_{suffix}"
+
+     with FSDP.summon_full_params(module=model):
+         state_dict = torch.load(checkpoint_name, weights_only=True, map_location=device)
+         model.load_state_dict(state_dict['model_state_dict'])
+
+         if optimizer:
+             optimizer.load_state_dict(state_dict['optim_state_dict'])
+
+
+
+ def get_fsdp_model_params(model: nn.Module):
+     """
+     Efficiently extract a full FP32 state_dict from an FSDP-wrapped model.
+     This gathers all sharded parameters and makes sure every rank receives a complete copy.
+     """
+
+     # FSDP requires summon_full_params to be called on every rank, even though only rank 0 does the work.
+     # writeback=False: the parameters are only read, never written back, which saves overhead.
+     # offload_to_cpu=True: offload the gathered parameters straight to CPU, avoiding a large peak in GPU memory
+     #                      and saving an explicit .cpu() step. This is a very useful optimization.
+     # rank0_only=False: for offload_to_cpu to take effect on every rank this is usually left False;
+     #                   the rank check below ensures that only rank 0 actually builds the dict.
+     with FSDP.summon_full_params(model, writeback=False, offload_to_cpu=True):
+
+         state_dict = None
+         if TrainerTools().parallel.is_main_process:
+             # Inside this with-block, model.state_dict() returns a complete state dict that already lives on the CPU,
+             # because offload_to_cpu=True was set.
+             # .clone() guarantees an independent copy,
+             # even though offload_to_cpu already covers most of that.
+             state_dict = {k: v.clone() for k, v in model.state_dict().items()}
+
+         # At this point only the state_dict on rank 0 is a valid dict; on other ranks it is None.
+         # It has to be broadcast to every process.
+         if TrainerTools().parallel.world_size > 1:
+             # Prepare a list: rank 0 carries the data, the other ranks hold a placeholder
+             object_list = [state_dict] if TrainerTools().parallel.is_main_process else [None]
+
+             # The broadcast blocks and synchronizes all processes
+             dist.broadcast_object_list(object_list, src=0)
+
+             # Every rank takes its copy of the broadcast state_dict from the list
+             state_dict = object_list[0]
+
+     return state_dict
llm_trainer/grpo_trainer.py CHANGED
@@ -17,7 +17,7 @@ from .generate_utils import batch_generate
 
  from .checkpoint import (
      save_checkpoint,
-     load_checkpoint_for_eval,
+     copy_model_params,
      save_steps,
  )
 
@@ -44,9 +44,6 @@ class GRPOTrainer(Trainer):
          # If pad_sequence does not support the padding_side argument, set this to False and use the reversal approach instead
          self._use_origin_pad_sequence = True
 
-         # Save a checkpoint of the train model so the reference_model below can use it
-         save_checkpoint(self.train_model, self.optimizer)
-
      def _init_reference_model(self):
          reference_model = self._new_model(self.train_config)
 
@@ -296,7 +293,7 @@ class GRPOTrainer(Trainer):
          aux_loss_coef = self.train_config.loss_config.aux_loss_coef
 
          for epoch in range(self.train_config.n_epochs):
-             load_checkpoint_for_eval(model=self.reference_model, device=device)
+             copy_model_params(_from=self.train_model, _to=self.reference_model)
              self.train_model.train()
              file_count = len(self.train_config.file_dataset)
 
@@ -325,11 +322,11 @@ class GRPOTrainer(Trainer):
 
                  # start generate
                  # Use a separate model to generate data; in deepspeed parallel training, generating with train_model deadlocks
-                 self.generate_model.to(TrainerTools().parallel.device)
-                 self.reference_model.to(TrainerTools().parallel.device)
+                 self.generate_model.to(device)
+                 self.reference_model.to(device)
 
                  # After the train_model checkpoint has been saved, this guarantees the generate model uses the latest parameters
-                 load_checkpoint_for_eval(self.generate_model, TrainerTools().parallel.device)
+                 copy_model_params(_from=self.train_model, _to=self.generate_model)
                  # generate data
                  rollout_data = self._generate_rollout_data(batch_data)
 
@@ -355,7 +352,7 @@ class GRPOTrainer(Trainer):
              if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
                  # clip grad
                  self.scalar.unscale_(self.optimizer)
-                 torch.nn.utils.clip_grad_norm_(self.train_model.parameters(), 1.0)
+                 torch.nn.utils.clip_grad_norm_(self._get_trainable_params(self.train_model), 1.0)
 
              self._step()
 
llm_trainer/parallel.py CHANGED
@@ -64,7 +64,8 @@ class Parallel(ABC):
          self,
          model: nn.Module,
          optimizer: torch.optim.Optimizer,
-         kwargs: Optional[dict] = None
+         kwargs: Optional[dict] = None,
+         save_instance: bool = True
      ) -> Tuple[nn.Module, torch.optim.Optimizer]: ...
 
      def process_dataloader(
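
The new save_instance flag addresses a conflict: Parallel.process previously always stored the wrapped model on self.model / self.raw_model, so wrapping a second model (such as the DPO reference model above) overwrote the parallel context's record of the training model. A sketch of the intended call pattern, based on the DPOTrainer change in this release (variable names are illustrative):

    # training model: keep it as the parallel context's tracked instance (default behaviour)
    train_model, optimizer = TrainerTools().parallel.process(model=train_model, optimizer=optimizer)

    # auxiliary model: wrap it for the same backend, but leave parallel.model / parallel.raw_model untouched
    reference_model, _ = TrainerTools().parallel.process(
        model=reference_model,
        optimizer=None,
        save_instance=False
    )
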
llm_trainer/parallel_ddp.py CHANGED
@@ -21,7 +21,8 @@ class DdpParallel(Parallel):
          self,
          model: nn.Module,
          optimizer: torch.optim.Optimizer,
-         kwargs: Optional[dict] = None
+         kwargs: Optional[dict] = None,
+         save_instance: bool = True
      ) -> Tuple[nn.Module, torch.optim.Optimizer]:
          model.to(self.device)
 
@@ -30,10 +31,14 @@ class DdpParallel(Parallel):
 
          if self._use_parallel:
              # self.model = DDP(module=model, broadcast_buffers=False, find_unused_parameters=True)
-             self.model = DDP(module=model, device_ids=[self._local_rank], output_device=self._local_rank)
-             self.raw_model = self.model.module
+             model = DDP(module=model, device_ids=[self._local_rank], output_device=self._local_rank)
+             raw_model = model.module
          else:
+             model = model
+             raw_model = model
+
+         if save_instance:
              self.model = model
-             self.raw_model = model
+             self.raw_model = raw_model
 
-         return self.model, optimizer
+         return model, optimizer
llm_trainer/parallel_ds.py CHANGED
@@ -16,16 +16,20 @@ class DsParallel(Parallel):
          self,
          model: nn.Module,
          optimizer: torch.optim.Optimizer,
-         kwargs: Optional[dict] = None
+         kwargs: Optional[dict] = None,
+         save_instance: bool = True
      ) -> Tuple[nn.Module, torch.optim.Optimizer]:
          """
-         :param model:
-         :param optimizer:
-         :param kwargs:
-             refer to the deepspeed configuration
-         :return:
+         :param model:
+         :param optimizer:
+         :param kwargs:
+             refer to the deepspeed configuration
+         :param save_instance
+         :return:
          """
-         self.raw_model = model
+
+         if save_instance:
+             self.raw_model = model
 
          model, optim, _, _ = deepspeed.initialize(
              model=model,
@@ -34,7 +38,9 @@ class DsParallel(Parallel):
              config_params=kwargs
          )
 
-         self.model = model
+         if save_instance:
+             self.model = model
+
          return model, optim
 
      def synchronize(self): ...
llm_trainer/parallel_fsdp.py CHANGED
@@ -28,16 +28,18 @@ class FsdpParallel(Parallel):
          self,
          model: nn.Module,
          optimizer: torch.optim.Optimizer,
-         kwargs: Optional[dict] = None
+         kwargs: Optional[dict] = None,
+         save_instance: bool = True
      ) -> Tuple[nn.Module, torch.optim.Optimizer]:
          """
-         :param model:
-         :param optimizer:
-         :param kwargs:
-             "wrap_policy_num_params" int minimum parameter count for size_based_auto_wrap_policy
-             "cpu_offload" bool whether to use CPU offload
-             "offload_params" bool whether to offload parameters; takes effect when cpu_offload is True
-         :return:
+         :param model:
+         :param optimizer:
+         :param kwargs:
+             "wrap_policy_num_params" int minimum parameter count for size_based_auto_wrap_policy
+             "cpu_offload" bool whether to use CPU offload
+             "offload_params" bool whether to offload parameters; takes effect when cpu_offload is True
+         :param save_instance
+         :return:
          """
 
          model.to(self.device)
@@ -81,10 +83,10 @@ class FsdpParallel(Parallel):
          else:
              mixed_precision = None
 
-         self.raw_model = model
+         raw_model = model
 
          # device_mesh = init_device_mesh("cuda", (self.world_size,))
-         # self.model = FSDP(
+         # model = FSDP(
          #     model,
          #     auto_wrap_policy=auto_wrap_policy,
          #     mixed_precision=mixed_precision,
@@ -93,7 +95,7 @@ class FsdpParallel(Parallel):
          #     device_mesh=device_mesh
          # )
 
-             self.model = FSDP(
+             model = FSDP(
                  model,
                  sharding_strategy=ShardingStrategy.FULL_SHARD,
                  auto_wrap_policy=auto_wrap_policy,
@@ -107,9 +109,13 @@ class FsdpParallel(Parallel):
                  # forward_prefetch=True,
              )
          else:
+             model = model
+             raw_model = model
+
+         if save_instance:
+             self.raw_model = raw_model
              self.model = model
-             self.raw_model = model
 
-         return self.model, optimizer
+         return model, optimizer
 
 
llm_trainer/parallel_none.py CHANGED
@@ -12,17 +12,19 @@ class NoneParallel(Parallel):
          self,
          model: nn.Module,
          optimizer: torch.optim.Optimizer,
-         kwargs: Optional[dict] = None
+         kwargs: Optional[dict] = None,
+         save_instance: bool = True
      ) -> Tuple[nn.Module, torch.optim.Optimizer]:
          model.to(self.device)
 
          if self._use_compile:
              model = torch.compile(model)
 
-         self.raw_model = model
-         self.model = model
+         if save_instance:
+             self.raw_model = model
+             self.model = model
 
-         return self.model, optimizer
+         return model, optimizer
 
 
 
llm_trainer/tools.py CHANGED
@@ -28,7 +28,7 @@ class TrainerTools:
          if not hasattr(TrainerTools, "_first_init"):
              TrainerTools._first_init = True
 
-             self.parallel = self.new_parallel()
+             self.parallel = self._new_parallel()
 
              self.tokenizer = Tokenizer(os.environ.get('TOKENIZERS_TYPE', 'zh_llama'))
              self.use_amp = 'cuda' in self.parallel.device and not isinstance(self.parallel, DsParallel)
@@ -43,7 +43,7 @@ class TrainerTools:
                  f' use_amp={self.use_amp},'
                  f' dtype={self.dtype}')
 
-     def new_parallel(self):
+     def _new_parallel(self):
          parallel_type = os.environ.get('PARALLEL_TYPE', 'none')
          log(f'parallel_type={parallel_type}')
          return parallel_types[parallel_type]()
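
With new_parallel renamed to the private _new_parallel, the parallel backend is now created exactly once, when TrainerTools is first constructed, and it is selected from the PARALLEL_TYPE environment variable (DPOTrainer no longer builds a second parallel context, see the diff above). A minimal sketch of driving this from a launcher script; 'none' is the default shown in this diff, while the other key names in parallel_types are not shown here and would need to be checked against the source:

    import os

    os.environ['PARALLEL_TYPE'] = 'none'      # default backend; other keys are defined in parallel_types
    from llm_trainer.tools import TrainerTools

    tools = TrainerTools()                     # first construction creates the parallel context
    print(tools.parallel.device)
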
llm_trainer/train_configs.py CHANGED
@@ -422,7 +422,8 @@ class TrainConfig:
          kd_config: Optional[KDConfig] = None,
          pixel_values_provider: Optional[Callable[[list[str]], torch.Tensor]] = None,
          init_state_dict: Optional[Mapping[str, Any]] = None,
-         eval_config: EvalConfig = EvalConfig()
+         eval_config: EvalConfig = EvalConfig(),
+         freeze_llm_model: bool = False
      ):
          self.n_epochs = n_epochs
          self.batch_size = batch_size
@@ -443,5 +444,6 @@ class TrainConfig:
          self.pixel_values_provider = pixel_values_provider
          self.init_state_dict = init_state_dict
          self.eval_config = eval_config
+         self.freeze_llm_model = freeze_llm_model
 
 
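
The new freeze_llm_model flag targets VLM training: when it is set, the Trainer freezes every parameter whose name does not contain vision_tower or multi_modal_projector, and the new _get_trainable_params helper makes sure the optimizer and gradient clipping only see the remaining trainable parameters (see the trainer.py changes below). A self-contained sketch of that same freezing pattern on a toy module; ToyVlm and its layer names are illustrative, not the library's LlmModel:

    import torch
    from torch import nn

    class ToyVlm(nn.Module):
        def __init__(self):
            super().__init__()
            self.vision_tower = nn.Linear(8, 8)
            self.multi_modal_projector = nn.Linear(8, 8)
            self.lm_head = nn.Linear(8, 8)

    model = ToyVlm()

    # freeze everything except the vision tower and the projector, as the trainer does
    for name, param in model.named_parameters():
        if not any(key in name for key in ['vision_tower', 'multi_modal_projector']):
            param.requires_grad = False

    # only the unfrozen parameters go to the optimizer
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.AdamW(trainable, lr=1e-4, weight_decay=0.01)
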
llm_trainer/trainer.py CHANGED
@@ -116,6 +116,10 @@ class Trainer:
          else:
              return LlmModel(train_config.model_config)
 
+     def _get_trainable_params(self, model):
+         freeze_llm_model = self.train_config.freeze_llm_model
+         return model.parameters() if not freeze_llm_model else filter(lambda p: p.requires_grad, model.parameters())
+
      def _init_train_model_and_optim(
          self,
          initial_lr: float,
@@ -128,10 +132,24 @@ class Trainer:
              model.load_state_dict(self.train_config.init_state_dict, strict=False)
              self.train_config.init_state_dict = None
 
+         # freeze llm model for vlm training
+         if self.train_config.freeze_llm_model:
+             for name, param in model.named_parameters():
+                 if not any(sub_module in name for sub_module in ['vision_tower', 'multi_modal_projector']):
+                     param.requires_grad = False
+
+             model.embed_tokens.eval()
+             model.layers.eval()
+             model.head_norm.eval()
+             model.lm_head.eval()
+
          if TrainerTools().parallel.is_main_process:
              total_params = sum(p.numel() for p in model.parameters())
              log(f"Total number of parameters: {total_params:,}")
 
+             trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+             log(f"Trainable number of parameters: {trainable_params:,}")
+
              total_size_bytes = total_params * 4
              total_size_mb = total_size_bytes / (1024 * 1024)
              log(f"Total size of the model: {total_size_mb:.2f} MB")
@@ -139,13 +157,13 @@ class Trainer:
          if use_ds_optim:
              import deepspeed
              origin_optim = deepspeed.ops.adam.DeepSpeedCPUAdam(
-                 model.parameters(),
+                 self._get_trainable_params(model),
                  lr=initial_lr,
                  weight_decay=self.train_config.lr_config.weight_decay
              )
          else:
              origin_optim = torch.optim.AdamW(
-                 model.parameters(),
+                 self._get_trainable_params(model),
                  lr=initial_lr,
                  weight_decay=self.train_config.lr_config.weight_decay
              )
@@ -406,6 +424,7 @@ class Trainer:
                  eval_pixel_values = None
 
              submit_gen_task(
+                 self.train_model,
                  self.eval_model,
                  self.train_config.eval_config,
                  tag=f'sign:batch/{tag}',
@@ -428,6 +447,7 @@ class Trainer:
                  eval_pixel_values = None
 
              submit_gen_task(
+                 self.train_model,
                  self.eval_model,
                  self.train_config.eval_config,
                  tag=f'sign:epoch/{tag}',
@@ -529,7 +549,7 @@ class Trainer:
              if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
                  # clip grad
                  self.scalar.unscale_(self.optimizer)
-                 torch.nn.utils.clip_grad_norm_(self.train_model.parameters(), 1.0)
+                 torch.nn.utils.clip_grad_norm_(self._get_trainable_params(self.train_model), 1.0)
 
              self._step()
 
{project_llm_trainer-0.3.5.dist-info → project_llm_trainer-0.4.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: project_llm_trainer
- Version: 0.3.5
+ Version: 0.4
  Summary: LLM and VLM trainer
  Author: qibin
  Author-email: qibin0506@gmail.com
project_llm_trainer-0.4.dist-info/RECORD ADDED
@@ -0,0 +1,35 @@
+ llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
+ llm_trainer/checkpoint.py,sha256=GPaSGvnLCGMgsIA_vfjuw34tTQY26EuNwu7c08fhJHQ,5638
+ llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
+ llm_trainer/dcp.py,sha256=PkD97DyrOtoTKn4FJsfL3VqAy4dxufgjdzJEz8-Cnoc,3635
+ llm_trainer/dpo_trainer.py,sha256=rC_I5ipesSlP3gFK_SG2GB8NbgJAMu4K7KLxkAS-aRY,13406
+ llm_trainer/ds_checkpoint.py,sha256=H0BxYQixOWKRC20t55cFqNDTPzalD3AGTVt-owIB0_4,4488
+ llm_trainer/eval.py,sha256=CsB3TpSVwhYVS9SP4Kuj_JhFUUvLcZUkvd8hvEIkPDU,1782
+ llm_trainer/fsdp_checkpoint.py,sha256=xPQnAfXbx1SRKcVDLLgOtVrqjk0CjIRleVY0ZrwOAJU,3876
+ llm_trainer/generate_utils.py,sha256=4iM0vyc_1C_iTL31GlS9PR4eZtYaELPRZ02KDSPZA9U,15158
+ llm_trainer/grpo_trainer.py,sha256=fqLT48ORSCece_e8dpyt8J7EarDuTnGoJ_eHk7Oy-1k,16177
+ llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
+ llm_trainer/loss.py,sha256=Yv3fsaVuZ5AhnGPJOr5vEMb_tM2urR6mCb4DBbrHHI8,6030
+ llm_trainer/parallel.py,sha256=DQu8GqEFxD99HQ6hKuIxxyKi-05dMO33eMhImYlPuOI,4468
+ llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
+ llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
+ llm_trainer/parallel_fsdp.py,sha256=cQOdY8ou6m8OsR06PpFVn6GiyZlK9nefkcGyszUOIJk,4055
+ llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
+ llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
+ llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
+ llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
+ llm_trainer/tools.py,sha256=O45-20wRmh-nyTfU-U-XtjbKAoe7boEIsUvWT_NaKx4,3041
+ llm_trainer/train_configs.py,sha256=arnet3tIzgVnwshod08F1jE7r4I7e-SIgMy55IagPnE,15971
+ llm_trainer/trainer.py,sha256=DujZR1KOHyP3EHR8uIQPEsnX_5b7YC9Cto_eH7zxWqc,25256
+ llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
+ project_llm_trainer-0.4.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+ project_llm_trainer-0.4.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+ project_llm_trainer-0.4.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+ project_llm_trainer-0.4.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+ project_llm_trainer-0.4.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+ project_llm_trainer-0.4.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+ project_llm_trainer-0.4.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+ project_llm_trainer-0.4.dist-info/METADATA,sha256=-xxg-UyXn5MhW5OdYGFUcL5DtbIkgnQoUZS5b5bcEio,193
+ project_llm_trainer-0.4.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+ project_llm_trainer-0.4.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+ project_llm_trainer-0.4.dist-info/RECORD,,
project_llm_trainer-0.3.5.dist-info/RECORD DELETED
@@ -1,34 +0,0 @@
- llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
- llm_trainer/checkpoint.py,sha256=Dlkcit0o7Gx6S9QUrIrVp2pTurP9X0zVA7w7ImSuVQU,6049
- llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
- llm_trainer/dcp.py,sha256=PkD97DyrOtoTKn4FJsfL3VqAy4dxufgjdzJEz8-Cnoc,3635
- llm_trainer/dpo_trainer.py,sha256=kDpzHrxP3qWbBxDGi9Rkus1kw8P3-bAtw9IyuYINgk0,13625
- llm_trainer/ds_checkpoint.py,sha256=_svpzqRaa43--DKPputoXAelc6X9vPM0gNQu-hlh6NI,2153
- llm_trainer/eval.py,sha256=sCvdYnqWWf5_nuDQN5BHb_YivXLOQW-V0ET9mPu0tPU,2389
- llm_trainer/generate_utils.py,sha256=4iM0vyc_1C_iTL31GlS9PR4eZtYaELPRZ02KDSPZA9U,15158
- llm_trainer/grpo_trainer.py,sha256=M6vp6QjxhBQVaw3e_3BJ4earuezQNKQ3JeZfQLBaSLQ,16370
- llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
- llm_trainer/loss.py,sha256=Yv3fsaVuZ5AhnGPJOr5vEMb_tM2urR6mCb4DBbrHHI8,6030
- llm_trainer/parallel.py,sha256=2VJtW3Gq2c1yS_LdcrNhk7B12prFwBmFnKhvV8FS2d8,4428
- llm_trainer/parallel_ddp.py,sha256=Gz-3LZ6LKmqlNwxrnGRC4uKoqoSxCvp9JHejIBSQp3c,1238
- llm_trainer/parallel_ds.py,sha256=W_PkczyAlgffCRcQadN-Pf7H7HM7TU26v5W63jKELFM,990
- llm_trainer/parallel_fsdp.py,sha256=u9XbbVTzcsMcaf-aQFrC_QwWsDRGoEpRmgvu1cKNtgk,3887
- llm_trainer/parallel_none.py,sha256=a6tt3aBmCq5rSP7n2I-sF-hsZ992BbLbpbxutDCFJfs,607
- llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
- llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
- llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
- llm_trainer/tools.py,sha256=AhfjN9oln5Pyif1SgCWwgQg-Q5acTCd9xpz4L26QUjA,3039
- llm_trainer/train_configs.py,sha256=cadfo8RgxNUR-L3ZLyjiRXTQvhjUl4A1qENaq-ol8h4,15878
- llm_trainer/trainer.py,sha256=5DgDzg0TReZrXsIaM6A4DzeJnzePNybGdfoVSDybQ2U,24308
- llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
- project_llm_trainer-0.3.5.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
- project_llm_trainer-0.3.5.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
- project_llm_trainer-0.3.5.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
- project_llm_trainer-0.3.5.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
- project_llm_trainer-0.3.5.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
- project_llm_trainer-0.3.5.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
- project_llm_trainer-0.3.5.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
- project_llm_trainer-0.3.5.dist-info/METADATA,sha256=jfnJI_XqE7U89-8tLEGPmLpuzwp-3qw-aERIgV8GJpk,195
- project_llm_trainer-0.3.5.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
- project_llm_trainer-0.3.5.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
- project_llm_trainer-0.3.5.dist-info/RECORD,,