project-llm-trainer 0.4.12__py3-none-any.whl → 0.4.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of project-llm-trainer might be problematic.

llm_trainer/checkpoint.py CHANGED
@@ -129,10 +129,10 @@ def copy_model_params(
 
     if isinstance(TrainerTools().parallel, DsParallel):
         from .ds_checkpoint import get_ds_model_params
-        state_dict = get_ds_model_params(_from)
+        state_dict = get_ds_model_params(_from, only_rank0=_to is None)
     elif isinstance(TrainerTools().parallel, FsdpParallel):
         from .fsdp_checkpoint import get_fsdp_model_params
-        state_dict = get_fsdp_model_params(_from)
+        state_dict = get_fsdp_model_params(_from, only_rank0=_to is None)
     elif isinstance(_from, DDP):
         state_dict = _from.module.state_dict()
     else:
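
Both helpers now receive an `only_rank0` flag derived from whether a destination module was supplied: when `_to` is `None`, only rank 0 needs the gathered parameters, so the cross-rank broadcast can be skipped. The surrounding code of `copy_model_params` is not shown in this diff; the following is a minimal sketch of the intended calling pattern, where `get_params` stands in for `get_ds_model_params` / `get_fsdp_model_params` and the helper name is illustrative:

```python
from typing import Optional

import torch.nn as nn


def copy_model_params_sketch(_from: nn.Module,
                             _to: Optional[nn.Module],
                             get_params) -> Optional[dict]:
    # If no destination model is given, the caller only needs the weights on
    # rank 0 (e.g. for saving), so skip the expensive broadcast to other ranks.
    state_dict = get_params(_from, only_rank0=_to is None)

    if _to is not None and state_dict is not None:
        # Load the gathered full-precision weights into the destination model.
        _to.load_state_dict(state_dict, strict=True)

    return state_dict
```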
llm_trainer/dpo_trainer.py CHANGED
@@ -6,7 +6,6 @@ import torch.distributed as dist
 import torch.nn.functional as F
 
 from .parallel_ds import DsParallel
-from .parallel_fsdp import FsdpParallel
 from .trainer import Trainer
 from .train_configs import TrainConfig
 from .dataset import DPODataset
@@ -53,52 +52,6 @@ class DPOTrainer(Trainer):
 
         return reference_model
 
-    def _init_reference_args(self):
-        if isinstance(TrainerTools().parallel, DsParallel) and self.train_config.ds_config:
-            parallel_kwargs = {
-                'gradient_accumulation_steps': 1,
-                'train_micro_batch_size_per_gpu': 1
-            }
-
-            if self.train_config.ds_config.zero_config:
-                zero_optimization = {'stage': 0}
-                parallel_kwargs['zero_optimization'] = zero_optimization
-
-
-            if (self.train_config.ds_config.bf16_config is not None
-                    and self.train_config.ds_config.bf16_config.enabled):
-                bf16_config = self.train_config.ds_config.bf16_config
-                bf16 = {
-                    'enabled': bf16_config.enabled
-                }
-                parallel_kwargs['bf16'] = bf16
-            elif self.train_config.ds_config.fp16_config:
-                fb16_config = self.train_config.ds_config.fp16_config
-                fp16 = {
-                    'enabled': fb16_config.enabled,
-                    'loss_scale': fb16_config.loss_scale,
-                    'loss_scale_window': fb16_config.loss_scale_window,
-                    'initial_scale_power': fb16_config.initial_scale_power,
-                    'hysteresis': fb16_config.hysteresis,
-                    'min_loss_scale': fb16_config.min_loss_scale
-                }
-
-                if fb16_config.fp16_opt_level is not None:
-                    fp16['fp16_opt_level'] = fb16_config.fp16_opt_level
-
-                parallel_kwargs['fp16'] = fp16
-        elif isinstance(TrainerTools().parallel, FsdpParallel) and self.train_config.fsdp_config:
-            parallel_kwargs = {
-                'transformer_layer_cls': self.train_config.fsdp_config.transformer_layer_cls,
-                'wrap_policy_num_params': self.train_config.fsdp_config.wrap_policy_num_params,
-                'cpu_offload': self.train_config.fsdp_config.cpu_offload,
-                'offload_params': self.train_config.fsdp_config.offload_params
-            }
-        else:
-            parallel_kwargs = None
-
-        return parallel_kwargs
-
     def _init_loss(self):
         criterion = DPOLoss(
             beta=self.train_config.dpo_config.loss_beta,
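
The backend-specific argument construction for the frozen DPO reference model is removed from `DPOTrainer`; a generalized `_init_reference_args` now lives on the base `Trainer` (added further down in this diff), which is also why the `FsdpParallel` import is no longer needed here. The removed method essentially built a minimal engine config for a model that never takes optimizer steps. A hedged sketch of the kind of DeepSpeed config dict it produced, with illustrative values rather than the package's own (the real values came from `self.train_config.ds_config`):

```python
# Sketch only: an example of the minimal DeepSpeed config the removed
# DPOTrainer._init_reference_args built for the frozen reference model.
# Keys follow the DeepSpeed config schema; the values are illustrative.
reference_ds_config = {
    # The reference model is inference-only, so no gradient accumulation
    # and a micro-batch of 1 are sufficient.
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 1,
    # ZeRO stage 0: keep full (unsharded) parameters for the frozen model.
    "zero_optimization": {"stage": 0},
    # Match the training precision so reference logits are comparable.
    "bf16": {"enabled": True},
}
```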
llm_trainer/ds_checkpoint.py CHANGED
@@ -105,7 +105,7 @@ def _get_ds_full_state_dict_on_rank0(model: DeepSpeedEngine) -> Optional[dict]:
     # return state_dict_on_rank_0 if TrainerTools().parallel.is_main_process else None
 
 
-def get_ds_model_params(model: nn.Module):
+def get_ds_model_params(model: nn.Module, only_rank0=False):
     """
     Efficiently extract the full FP32 state_dict from a running DeepSpeedEngine,
     compatible with ZeRO Stages 0, 1, 2 and 3.
@@ -117,7 +117,7 @@ def get_ds_model_params(model: nn.Module):
 
     # At this point only the state_dict on rank 0 is a valid dict; on the other ranks it is None.
     # We need to broadcast it to all processes.
-    if TrainerTools().parallel.world_size > 1:
+    if not only_rank0 and TrainerTools().parallel.world_size > 1:
         # Prepare a list: rank 0 holds the data, the other ranks hold a placeholder
         object_list = [state_dict] if TrainerTools().parallel.is_main_process else [None]
         # Run the broadcast; this call blocks and synchronizes all processes
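
With `only_rank0=True` the broadcast branch is skipped entirely and non-main ranks simply keep `None`. A self-contained sketch of the gather-then-broadcast pattern the comments describe, using `torch.distributed.broadcast_object_list`; the helper name and argument names here are illustrative, not the package's own:

```python
from typing import Optional

import torch.distributed as dist


def broadcast_state_dict_sketch(state_dict: Optional[dict],
                                is_main_process: bool,
                                world_size: int,
                                only_rank0: bool = False) -> Optional[dict]:
    """Sketch: rank 0 already holds a full state_dict, the other ranks hold None."""
    if only_rank0 or world_size <= 1:
        # The caller only needs the weights on rank 0 (e.g. for saving),
        # so skip the expensive pickling and broadcast.
        return state_dict

    # Rank 0 contributes the payload, every other rank a placeholder.
    object_list = [state_dict] if is_main_process else [None]
    # Blocking collective: after this call every rank holds the same object.
    dist.broadcast_object_list(object_list, src=0)
    return object_list[0]
```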
llm_trainer/fsdp_checkpoint.py CHANGED
@@ -66,7 +66,7 @@ def _get_fsdp_full_state_dict_on_rank0(model: nn.Module) -> Optional[dict]:
     return None
 
 
-def get_fsdp_model_params(model: nn.Module):
+def get_fsdp_model_params(model: nn.Module, only_rank0=False):
     """
     Efficiently extract the full FP32 state_dict from an FSDP-wrapped model.
     This function gathers all sharded parameters and ensures every rank receives a complete copy.
@@ -76,7 +76,7 @@ def get_fsdp_model_params(model: nn.Module):
 
     # At this point only the state_dict on rank 0 is a valid dict; on the other ranks it is None.
     # We need to broadcast it to all processes.
-    if TrainerTools().parallel.world_size > 1:
+    if not only_rank0 and TrainerTools().parallel.world_size > 1:
         # Prepare a list: rank 0 holds the data, the other ranks hold a placeholder
         object_list = [state_dict] if TrainerTools().parallel.is_main_process else [None]
         # Run the broadcast; this call blocks and synchronizes all processes
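
The rank-0 gather that `_get_fsdp_full_state_dict_on_rank0` performs is not shown in this diff. A common way to do it with PyTorch's FSDP APIs is sketched below, under the assumption that the package uses the standard `FULL_STATE_DICT` path; this is an assumption, not the package's verified implementation:

```python
from typing import Optional

import torch.nn as nn
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)


def gather_full_state_dict_on_rank0(model: nn.Module,
                                    is_main_process: bool) -> Optional[dict]:
    """Sketch: gather a full, unsharded state_dict onto rank 0 only."""
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        state_dict = model.state_dict()
    # Only rank 0 is guaranteed to hold the full weights; normalize other ranks to None.
    return state_dict if is_main_process else None
```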
llm_trainer/trainer.py CHANGED
@@ -178,7 +178,7 @@ class Trainer:
 
     def _init_eval_model(self) -> Optional[nn.Module]:
         if TrainerTools().parallel.is_main_process:
-            return self._new_model(self.train_config).to('cpu')
+            return self._new_model(self.train_config).to(device='cpu', dtype=TrainerTools().dtype)
 
         return None
 
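
The CPU eval copy is now created in the trainer's working dtype rather than the module default of float32, presumably so the eval model matches the training precision. A small illustration of the difference; `torch.bfloat16` stands in for whatever `TrainerTools().dtype` resolves to (an assumption):

```python
import torch
import torch.nn as nn

# Old behaviour: the eval copy stays in the default parameter dtype (float32).
eval_fp32 = nn.Linear(8, 8).to('cpu')
assert eval_fp32.weight.dtype == torch.float32

# New behaviour: create the eval copy directly in the training dtype,
# so weights copied from a bf16 run are not silently upcast.
eval_bf16 = nn.Linear(8, 8).to(device='cpu', dtype=torch.bfloat16)
assert eval_bf16.weight.dtype == torch.bfloat16
```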
@@ -337,6 +337,34 @@ class Trainer:
 
         return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim
 
+    def _init_reference_args(self) -> dict:
+        parallel_kwargs, _, _, _ = self._convert_train_args()
+
+        if parallel_kwargs and isinstance(TrainerTools().parallel, DsParallel):
+            # reference to https://github.com/huggingface/trl/blob/main/trl/models/utils.py:prepare_deepspeed
+            # if model is not None:
+            #     hidden_size = (
+            #         max(model.config.hidden_sizes)
+            #         if getattr(model.config, "hidden_sizes", None)
+            #         else getattr(model.config, "hidden_size", None)
+            #     )
+            #     if hidden_size is not None and stage == 3:
+            #         # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache
+            #         # @ step 0: expected module 1, but got module 0`
+            #         # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
+            #         config_kwargs.update(
+            #             {
+            #                 "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
+            #                 "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
+            #                 "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
+            #             }
+            #         )
+
+            if parallel_kwargs['zero_optimization']['stage'] != 3:
+                parallel_kwargs['zero_optimization']['stage'] = 0
+
+        return parallel_kwargs
+
     def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
         file_path = self.train_config.file_dataset[file_idx]
         max_position_embeddings = self.train_config.model_config.max_position_embeddings
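
The base-class version derives the reference-model kwargs from the regular training args and then forces ZeRO to stage 0 unless training itself runs stage 3 (where the frozen model must also stay sharded to fit in memory), following the TRL `prepare_deepspeed` pattern cited in the comment. A hedged sketch of how the returned dict could be consumed to wrap the frozen reference model; `build_reference_model` and the exact config shape are assumptions, while `deepspeed.initialize` is the standard entry point:

```python
import deepspeed


def prepare_reference_model(build_reference_model, reference_kwargs: dict):
    """Sketch: wrap a frozen reference model using the kwargs from _init_reference_args."""
    ref_model = build_reference_model()
    ref_model.eval()
    for p in ref_model.parameters():
        p.requires_grad_(False)  # the reference model never takes optimizer steps

    # No optimizer or scheduler: DeepSpeed is only used here for sharding and dtype handling.
    engine, *_ = deepspeed.initialize(model=ref_model, config=reference_kwargs)
    return engine
```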
project_llm_trainer-0.4.12.dist-info/METADATA → project_llm_trainer-0.4.14.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: project_llm_trainer
-Version: 0.4.12
+Version: 0.4.14
 Summary: LLM and VLM trainer
 Author: qibin
 Author-email: qibin0506@gmail.com
project_llm_trainer-0.4.12.dist-info/RECORD → project_llm_trainer-0.4.14.dist-info/RECORD CHANGED
@@ -1,11 +1,11 @@
 llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
-llm_trainer/checkpoint.py,sha256=yZcExxneN2yzvWxRiK-pstMWs35LV7GiOfqcLq-S6vc,5745
+llm_trainer/checkpoint.py,sha256=ItDzuXVikk-0gWSw-IS7SrODEdlJEb5nZs15dBFkPdk,5793
 llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
 llm_trainer/dcp.py,sha256=PkD97DyrOtoTKn4FJsfL3VqAy4dxufgjdzJEz8-Cnoc,3635
-llm_trainer/dpo_trainer.py,sha256=34E2b-t0GZYutaw6bESgARe9C12PUMWcY4aGZ34eAZU,13576
-llm_trainer/ds_checkpoint.py,sha256=x_tjgJR47P8gVwV4qAnTUCGwx7eVq2Epw0vOVV7fkYo,4925
+llm_trainer/dpo_trainer.py,sha256=djBhvI_ixTV1nLNg84tgCpfV--pu6IRiOhO28V-aANQ,11425
+llm_trainer/ds_checkpoint.py,sha256=fprJlbSgtyKmmpytyMOZBs3pcjZA13SeWao0llnLpNQ,4962
 llm_trainer/eval.py,sha256=NDm8PbXLch7xT81xPYPRCNrcrB_Xj5GDJSCxyVwUOp4,1524
-llm_trainer/fsdp_checkpoint.py,sha256=lqZFzHyWyfzuCq_81kQNtJd2qaiMeY1N5BCEMnrJTBw,3192
+llm_trainer/fsdp_checkpoint.py,sha256=dAHIGHfuvTA6OC0jV9Ls-oD4ZR9CPGa31mjtoh-2dZE,3229
 llm_trainer/generate_utils.py,sha256=tSbA_tLqSq5qJGHSOlPv5T3iRDZkbFg5ZvDAgJ_i_SE,17946
 llm_trainer/grpo_trainer.py,sha256=bZPrxhyPQLAnFzWhI7hhA6fpuKVNwj7nOm9k0ku9aK4,15977
 llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
@@ -20,16 +20,16 @@ llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,17
 llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
 llm_trainer/tools.py,sha256=O45-20wRmh-nyTfU-U-XtjbKAoe7boEIsUvWT_NaKx4,3041
 llm_trainer/train_configs.py,sha256=HKzH3nfMT1-SW4Htwa0KqYtMd6FAJcthR5IEo6di8us,8168
-llm_trainer/trainer.py,sha256=pUtJVRosn54j1hn76CFAptJcAsrDo59H6p8NMkg2zt4,25521
+llm_trainer/trainer.py,sha256=j5fDqMzvU6SYwxHsv9wX0UVX4JXS-8eP1AkHgVxKf9U,27119
 llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
-project_llm_trainer-0.4.12.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
-project_llm_trainer-0.4.12.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
-project_llm_trainer-0.4.12.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
-project_llm_trainer-0.4.12.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
-project_llm_trainer-0.4.12.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
-project_llm_trainer-0.4.12.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
-project_llm_trainer-0.4.12.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
-project_llm_trainer-0.4.12.dist-info/METADATA,sha256=W-HeRGlXi3bFsKIVE1FyQAh4Lcvo0SOXMNu-9YnACKQ,196
-project_llm_trainer-0.4.12.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-project_llm_trainer-0.4.12.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
-project_llm_trainer-0.4.12.dist-info/RECORD,,
+project_llm_trainer-0.4.14.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.4.14.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.4.14.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.4.14.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.4.14.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.4.14.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.4.14.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.4.14.dist-info/METADATA,sha256=VMEWVv8pBqFUAhIAiH4_S4ECUHln31gchHLhTtUAM1o,196
+project_llm_trainer-0.4.14.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.4.14.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.4.14.dist-info/RECORD,,