project-llm-trainer 0.4.12__py3-none-any.whl → 0.4.13__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

llm_trainer/dpo_trainer.py CHANGED
@@ -6,7 +6,6 @@ import torch.distributed as dist
 import torch.nn.functional as F
 
 from .parallel_ds import DsParallel
-from .parallel_fsdp import FsdpParallel
 from .trainer import Trainer
 from .train_configs import TrainConfig
 from .dataset import DPODataset
@@ -53,52 +52,6 @@ class DPOTrainer(Trainer):
 
         return reference_model
 
-    def _init_reference_args(self):
-        if isinstance(TrainerTools().parallel, DsParallel) and self.train_config.ds_config:
-            parallel_kwargs = {
-                'gradient_accumulation_steps': 1,
-                'train_micro_batch_size_per_gpu': 1
-            }
-
-            if self.train_config.ds_config.zero_config:
-                zero_optimization = {'stage': 0}
-                parallel_kwargs['zero_optimization'] = zero_optimization
-
-
-            if (self.train_config.ds_config.bf16_config is not None
-                    and self.train_config.ds_config.bf16_config.enabled):
-                bf16_config = self.train_config.ds_config.bf16_config
-                bf16 = {
-                    'enabled': bf16_config.enabled
-                }
-                parallel_kwargs['bf16'] = bf16
-            elif self.train_config.ds_config.fp16_config:
-                fb16_config = self.train_config.ds_config.fp16_config
-                fp16 = {
-                    'enabled': fb16_config.enabled,
-                    'loss_scale': fb16_config.loss_scale,
-                    'loss_scale_window': fb16_config.loss_scale_window,
-                    'initial_scale_power': fb16_config.initial_scale_power,
-                    'hysteresis': fb16_config.hysteresis,
-                    'min_loss_scale': fb16_config.min_loss_scale
-                }
-
-                if fb16_config.fp16_opt_level is not None:
-                    fp16['fp16_opt_level'] = fb16_config.fp16_opt_level
-
-                parallel_kwargs['fp16'] = fp16
-        elif isinstance(TrainerTools().parallel, FsdpParallel) and self.train_config.fsdp_config:
-            parallel_kwargs = {
-                'transformer_layer_cls': self.train_config.fsdp_config.transformer_layer_cls,
-                'wrap_policy_num_params': self.train_config.fsdp_config.wrap_policy_num_params,
-                'cpu_offload': self.train_config.fsdp_config.cpu_offload,
-                'offload_params': self.train_config.fsdp_config.offload_params
-            }
-        else:
-            parallel_kwargs = None
-
-        return parallel_kwargs
-
     def _init_loss(self):
         criterion = DPOLoss(
             beta=self.train_config.dpo_config.loss_beta,
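This hunk deletes DPOTrainer's hand-rolled _init_reference_args (and the FsdpParallel import it needed); 0.4.13 moves that logic into the base Trainer, shown in the trainer.py section below, where it is derived from _convert_train_args() instead of being rebuilt by hand. For orientation, the surrounding reference_model is DPO's frozen copy of the policy model. A minimal sketch of what preparing such a reference model involves, assuming nothing beyond plain PyTorch (make_reference_model is an illustrative name, not part of this package):

import copy
import torch.nn as nn

def make_reference_model(policy: nn.Module) -> nn.Module:
    # The DPO reference model is a frozen snapshot of the policy: same
    # weights, no gradients, eval mode so dropout is disabled during scoring.
    reference = copy.deepcopy(policy)
    for param in reference.parameters():
        param.requires_grad_(False)
    return reference.eval()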
llm_trainer/trainer.py CHANGED
@@ -178,7 +178,7 @@ class Trainer:
 
     def _init_eval_model(self) -> Optional[nn.Module]:
         if TrainerTools().parallel.is_main_process:
-            return self._new_model(self.train_config).to('cpu')
+            return self._new_model(self.train_config).to(device='cpu', dtype=TrainerTools().dtype)
 
         return None
 
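The one-line change above fixes a dtype mismatch: a freshly constructed model defaults to fp32, so the evaluation copy on the main process could end up in a different precision than the trained weights; passing dtype=TrainerTools().dtype casts it to the training dtype up front. A small standalone illustration of the failure mode, assuming plain PyTorch (the bf16 training dtype here is hypothetical):

import torch
import torch.nn as nn

train_dtype = torch.bfloat16          # assume training ran in bf16
eval_model = nn.Linear(8, 8)          # new modules are created in fp32
print(eval_model.weight.dtype)        # torch.float32

# load_state_dict copies tensors into the existing fp32 parameters, silently
# upcasting bf16 weights; casting the module first keeps dtypes consistent:
eval_model = eval_model.to(device='cpu', dtype=train_dtype)
print(eval_model.weight.dtype)        # torch.bfloat16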
@@ -337,6 +337,34 @@ class Trainer:
 
         return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim
 
+    def _init_reference_args(self) -> dict:
+        parallel_kwargs, _, _, _ = self._convert_train_args()
+
+        if parallel_kwargs and isinstance(TrainerTools().parallel, DsParallel):
+            # reference to https://github.com/huggingface/trl/blob/main/trl/models/utils.py:prepare_deepspeed
+            # if model is not None:
+            #     hidden_size = (
+            #         max(model.config.hidden_sizes)
+            #         if getattr(model.config, "hidden_sizes", None)
+            #         else getattr(model.config, "hidden_size", None)
+            #     )
+            #     if hidden_size is not None and stage == 3:
+            #         # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache
+            #         # @ step 0: expected module 1, but got module 0`
+            #         # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
+            #         config_kwargs.update(
+            #             {
+            #                 "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
+            #                 "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
+            #                 "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
+            #             }
+            #         )
+
+            if parallel_kwargs['zero_optimization']['stage'] != 3:
+                parallel_kwargs['zero_optimization']['stage'] = 0
+
+        return parallel_kwargs
+
     def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
         file_path = self.train_config.file_dataset[file_idx]
         max_position_embeddings = self.train_config.model_config.max_position_embeddings
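The new base-class _init_reference_args reuses the DeepSpeed kwargs produced by _convert_train_args() and, unless the run is already on ZeRO stage 3, drops the reference model down to stage 0, since a frozen model has no optimizer or gradient state worth sharding (the removed DPOTrainer version always forced stage 0). The commented-out block cites TRL's prepare_deepspeed, which additionally scales the stage-3 buckets with the model's hidden size. A hedged sketch of that combined policy on a plain config dict (tune_reference_zero is an illustrative name; TRL itself uses flat dotted keys rather than this nested form):

def tune_reference_zero(config: dict, hidden_size: int) -> dict:
    # config is a DeepSpeed config dict intended for the frozen reference model.
    zero = config.setdefault('zero_optimization', {})
    if zero.get('stage') == 3:
        # Stage 3 must stay enabled so sharded parameters remain gatherable;
        # bucket sizes scale with hidden_size, per TRL's prepare_deepspeed.
        zero['reduce_bucket_size'] = hidden_size * hidden_size
        zero['stage3_param_persistence_threshold'] = 10 * hidden_size
        zero['stage3_prefetch_bucket_size'] = int(0.9 * hidden_size * hidden_size)
    else:
        # No optimizer/gradient state to shard on a frozen model.
        zero['stage'] = 0
    return config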
{project_llm_trainer-0.4.12.dist-info → project_llm_trainer-0.4.13.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: project_llm_trainer
-Version: 0.4.12
+Version: 0.4.13
 Summary: LLM and VLM trainer
 Author: qibin
 Author-email: qibin0506@gmail.com
{project_llm_trainer-0.4.12.dist-info → project_llm_trainer-0.4.13.dist-info}/RECORD RENAMED
@@ -2,7 +2,7 @@ llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
 llm_trainer/checkpoint.py,sha256=yZcExxneN2yzvWxRiK-pstMWs35LV7GiOfqcLq-S6vc,5745
 llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
 llm_trainer/dcp.py,sha256=PkD97DyrOtoTKn4FJsfL3VqAy4dxufgjdzJEz8-Cnoc,3635
-llm_trainer/dpo_trainer.py,sha256=34E2b-t0GZYutaw6bESgARe9C12PUMWcY4aGZ34eAZU,13576
+llm_trainer/dpo_trainer.py,sha256=djBhvI_ixTV1nLNg84tgCpfV--pu6IRiOhO28V-aANQ,11425
 llm_trainer/ds_checkpoint.py,sha256=x_tjgJR47P8gVwV4qAnTUCGwx7eVq2Epw0vOVV7fkYo,4925
 llm_trainer/eval.py,sha256=NDm8PbXLch7xT81xPYPRCNrcrB_Xj5GDJSCxyVwUOp4,1524
 llm_trainer/fsdp_checkpoint.py,sha256=lqZFzHyWyfzuCq_81kQNtJd2qaiMeY1N5BCEMnrJTBw,3192
@@ -20,16 +20,16 @@ llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,17
 llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
 llm_trainer/tools.py,sha256=O45-20wRmh-nyTfU-U-XtjbKAoe7boEIsUvWT_NaKx4,3041
 llm_trainer/train_configs.py,sha256=HKzH3nfMT1-SW4Htwa0KqYtMd6FAJcthR5IEo6di8us,8168
-llm_trainer/trainer.py,sha256=pUtJVRosn54j1hn76CFAptJcAsrDo59H6p8NMkg2zt4,25521
+llm_trainer/trainer.py,sha256=j5fDqMzvU6SYwxHsv9wX0UVX4JXS-8eP1AkHgVxKf9U,27119
 llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
-project_llm_trainer-0.4.12.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
-project_llm_trainer-0.4.12.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
-project_llm_trainer-0.4.12.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
-project_llm_trainer-0.4.12.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
-project_llm_trainer-0.4.12.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
-project_llm_trainer-0.4.12.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
-project_llm_trainer-0.4.12.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
-project_llm_trainer-0.4.12.dist-info/METADATA,sha256=W-HeRGlXi3bFsKIVE1FyQAh4Lcvo0SOXMNu-9YnACKQ,196
-project_llm_trainer-0.4.12.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-project_llm_trainer-0.4.12.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
-project_llm_trainer-0.4.12.dist-info/RECORD,,
+project_llm_trainer-0.4.13.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.4.13.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.4.13.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.4.13.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.4.13.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.4.13.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.4.13.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.4.13.dist-info/METADATA,sha256=hiW-7qgWuPizKVz4cU8mEHoqiuT6ZqNlCBb7nwVfFQ4,196
+project_llm_trainer-0.4.13.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.4.13.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.4.13.dist-info/RECORD,,