project-llm-trainer 0.4.11__py3-none-any.whl → 0.4.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of project-llm-trainer might be problematic.
llm_trainer/dpo_trainer.py CHANGED
@@ -6,7 +6,6 @@ import torch.distributed as dist
 import torch.nn.functional as F
 
 from .parallel_ds import DsParallel
-from .parallel_fsdp import FsdpParallel
 from .trainer import Trainer
 from .train_configs import TrainConfig
 from .dataset import DPODataset
@@ -53,52 +52,6 @@ class DPOTrainer(Trainer):
 
         return reference_model
 
-    def _init_reference_args(self):
-        if isinstance(TrainerTools().parallel, DsParallel) and self.train_config.ds_config:
-            parallel_kwargs = {
-                'gradient_accumulation_steps': 1,
-                'train_micro_batch_size_per_gpu': 1
-            }
-
-            if self.train_config.ds_config.zero_config:
-                zero_optimization = {'stage': 0}
-                parallel_kwargs['zero_optimization'] = zero_optimization
-
-
-            if (self.train_config.ds_config.bf16_config is not None
-                    and self.train_config.ds_config.bf16_config.enabled):
-                bf16_config = self.train_config.ds_config.bf16_config
-                bf16 = {
-                    'enabled': bf16_config.enabled
-                }
-                parallel_kwargs['bf16'] = bf16
-            elif self.train_config.ds_config.fp16_config:
-                fb16_config = self.train_config.ds_config.fp16_config
-                fp16 = {
-                    'enabled': fb16_config.enabled,
-                    'loss_scale': fb16_config.loss_scale,
-                    'loss_scale_window': fb16_config.loss_scale_window,
-                    'initial_scale_power': fb16_config.initial_scale_power,
-                    'hysteresis': fb16_config.hysteresis,
-                    'min_loss_scale': fb16_config.min_loss_scale
-                }
-
-                if fb16_config.fp16_opt_level is not None:
-                    fp16['fp16_opt_level'] = fb16_config.fp16_opt_level
-
-                parallel_kwargs['fp16'] = fp16
-        elif isinstance(TrainerTools().parallel, FsdpParallel) and self.train_config.fsdp_config:
-            parallel_kwargs = {
-                'transformer_layer_cls': self.train_config.fsdp_config.transformer_layer_cls,
-                'wrap_policy_num_params': self.train_config.fsdp_config.wrap_policy_num_params,
-                'cpu_offload': self.train_config.fsdp_config.cpu_offload,
-                'offload_params': self.train_config.fsdp_config.offload_params
-            }
-        else:
-            parallel_kwargs = None
-
-        return parallel_kwargs
-
     def _init_loss(self):
         criterion = DPOLoss(
             beta=self.train_config.dpo_config.loss_beta,
llm_trainer/train_configs.py CHANGED
@@ -67,7 +67,7 @@ class DsFp16Config:
     initial_scale_power: int = 16
     hysteresis: int = 2
     min_loss_scale: int = 1
-    fp16_opt_level: Optional[str] = '02'
+    fp16_opt_level: Optional[str] = 'O2'
 
 
 @dataclass(kw_only=True)
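The one-character fix above replaces the digit-zero string '02' with 'O2' (capital letter O), the Apex/AMP-style mixed-precision optimization level; those levels run 'O0' through 'O3'. Because the two strings look identical in many fonts, a defensive check along these lines can catch the typo at config time (VALID_OPT_LEVELS and check_opt_level are illustrative names, not part of this package):

VALID_OPT_LEVELS = {'O0', 'O1', 'O2', 'O3'}  # capital letter O, not digit zero

def check_opt_level(level: str) -> str:
    # Fail fast on look-alikes such as '02' before they reach the engine.
    if level not in VALID_OPT_LEVELS:
        raise ValueError(f"invalid fp16 opt level {level!r}; expected one of {sorted(VALID_OPT_LEVELS)}")
    return level

check_opt_level('O2')    # ok
# check_opt_level('02')  # raises ValueError: digit zero, not letter O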
@@ -77,9 +77,9 @@ class DsBf16Config:
 
 @dataclass(kw_only=True)
 class DsConfig:
-    zero_config: Optional[DsZeROConfig] = DsZero3Config()
-    fp16_config: Optional[DsFp16Config] = DsFp16Config()
-    bf16_config: Optional[DsBf16Config] = DsBf16Config()
+    zero_config: Optional[DsZeROConfig] = field(default_factory=DsZero3Config)
+    fp16_config: Optional[DsFp16Config] = field(default_factory=DsFp16Config)
+    bf16_config: Optional[DsBf16Config] = field(default_factory=DsBf16Config)
     gradient_clipping: Optional[float] = 1.0
     activation_checkpointing: Optional[DsActivationCheckpointingConfig] = None
 
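This hunk (and the TrainConfig hunk below) fixes a classic dataclass pitfall: a default like DsFp16Config() is evaluated once at class-definition time, so every instance shares one nested config object, and on Python 3.11+ dataclasses reject such unhashable instance defaults with a ValueError at import time; field(default_factory=...) builds a fresh object per instance instead. A minimal self-contained sketch of both behaviors, using illustrative names (Inner, Broken, Fixed) rather than the package's classes:

from dataclasses import dataclass, field

@dataclass
class Inner:
    enabled: bool = False

# Python 3.11+ rejects the instance default outright (Inner sets __hash__ to
# None via eq=True); on 3.10 and earlier it is accepted but silently shared.
try:
    @dataclass
    class Broken:
        inner: Inner = Inner()
except ValueError as err:
    print(err)

@dataclass
class Fixed:
    inner: Inner = field(default_factory=Inner)  # fresh Inner per instance

a, b = Fixed(), Fixed()
a.inner.enabled = True
print(b.inner.enabled)  # False: no state leaks between instances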
@@ -224,14 +224,14 @@ class TrainConfig:
     model_config: Union[ModelConfig, VLMConfig]
 
     file_dataset: FileDataset
-    data_loader_config: DataLoaderConfig = DataLoaderConfig()
+    data_loader_config: DataLoaderConfig = field(default_factory=DataLoaderConfig)
     image_tags_file_dataset: Optional[FileDataset] = None
 
-    loss_config: LossConfig = LossConfig()
-    lr_config: LrConfig = LrConfig()
+    loss_config: LossConfig = field(default_factory=LossConfig)
+    lr_config: LrConfig = field(default_factory=LrConfig)
 
-    ds_config: DsConfig = DsConfig()
-    fsdp_config: FsdpConfig = FsdpConfig()
+    ds_config: DsConfig = field(default_factory=DsConfig)
+    fsdp_config: FsdpConfig = field(default_factory=FsdpConfig)
 
     kd_config: Optional[KDConfig] = None
     dpo_config: Optional[DPOConfig] = None
@@ -241,7 +241,7 @@ class TrainConfig:
     gradient_accumulation_steps: int = 0
     eval_batch_interval: int = 100
 
-    eval_config: EvalConfig = EvalConfig()
+    eval_config: EvalConfig = field(default_factory=EvalConfig)
     pixel_values_provider: Optional[Callable[[list[str]], torch.Tensor]] = None
 
     init_state_dict: Optional[Mapping[str, Any]] = None
llm_trainer/trainer.py CHANGED
@@ -178,7 +178,7 @@ class Trainer:
 
     def _init_eval_model(self) -> Optional[nn.Module]:
         if TrainerTools().parallel.is_main_process:
-            return self._new_model(self.train_config).to('cpu')
+            return self._new_model(self.train_config).to(device='cpu', dtype=TrainerTools().dtype)
 
         return None
 
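The change above creates the main-process eval copy directly in the training dtype instead of leaving it in PyTorch's default float32, so evaluation runs at the same precision as training. A standalone illustration, with nn.Linear standing in for the trainer's model and train_dtype for TrainerTools().dtype:

import torch
import torch.nn as nn

train_dtype = torch.bfloat16  # stand-in for TrainerTools().dtype

old_style = nn.Linear(8, 8).to('cpu')                            # stays float32
new_style = nn.Linear(8, 8).to(device='cpu', dtype=train_dtype)  # matches training

print(next(old_style.parameters()).dtype)   # torch.float32
print(next(new_style.parameters()).dtype)   # torch.bfloat16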
@@ -337,6 +337,34 @@ class Trainer:
 
         return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim
 
+    def _init_reference_args(self) -> dict:
+        parallel_kwargs, _, _, _ = self._convert_train_args()
+
+        if parallel_kwargs and isinstance(TrainerTools().parallel, DsParallel):
+            # reference to https://github.com/huggingface/trl/blob/main/trl/models/utils.py:prepare_deepspeed
+            # if model is not None:
+            #     hidden_size = (
+            #         max(model.config.hidden_sizes)
+            #         if getattr(model.config, "hidden_sizes", None)
+            #         else getattr(model.config, "hidden_size", None)
+            #     )
+            #     if hidden_size is not None and stage == 3:
+            #         # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache
+            #         # @ step 0: expected module 1, but got module 0`
+            #         # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
+            #         config_kwargs.update(
+            #             {
+            #                 "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
+            #                 "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
+            #                 "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
+            #             }
+            #         )
+
+            if parallel_kwargs['zero_optimization']['stage'] != 3:
+                parallel_kwargs['zero_optimization']['stage'] = 0
+
+        return parallel_kwargs
+
     def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
         file_path = self.train_config.file_dataset[file_idx]
         max_position_embeddings = self.train_config.model_config.max_position_embeddings
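The _init_reference_args method added above (replacing the hand-rolled DPOTrainer version removed earlier in this diff) reuses the kwargs from _convert_train_args() and then downgrades ZeRO to stage 0 unless training runs stage 3. The rationale, following TRL's prepare_deepspeed pattern: a frozen reference model carries no optimizer states or gradients, so ZeRO stages 1-2 add communication overhead without saving memory, while stage 3's parameter sharding still pays off. A tiny sketch of that policy (the function name is illustrative, not the package's API):

def reference_zero_stage(train_stage: int) -> int:
    # Keep parameter sharding for ZeRO-3; otherwise disable ZeRO entirely.
    return 3 if train_stage == 3 else 0

assert reference_zero_stage(3) == 3
assert reference_zero_stage(2) == 0  # optimizer/gradient sharding is moot here
assert reference_zero_stage(0) == 0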
{project_llm_trainer-0.4.11.dist-info → project_llm_trainer-0.4.13.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: project_llm_trainer
-Version: 0.4.11
+Version: 0.4.13
 Summary: LLM and VLM trainer
 Author: qibin
 Author-email: qibin0506@gmail.com
{project_llm_trainer-0.4.11.dist-info → project_llm_trainer-0.4.13.dist-info}/RECORD RENAMED
@@ -2,7 +2,7 @@ llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
 llm_trainer/checkpoint.py,sha256=yZcExxneN2yzvWxRiK-pstMWs35LV7GiOfqcLq-S6vc,5745
 llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
 llm_trainer/dcp.py,sha256=PkD97DyrOtoTKn4FJsfL3VqAy4dxufgjdzJEz8-Cnoc,3635
-llm_trainer/dpo_trainer.py,sha256=34E2b-t0GZYutaw6bESgARe9C12PUMWcY4aGZ34eAZU,13576
+llm_trainer/dpo_trainer.py,sha256=djBhvI_ixTV1nLNg84tgCpfV--pu6IRiOhO28V-aANQ,11425
 llm_trainer/ds_checkpoint.py,sha256=x_tjgJR47P8gVwV4qAnTUCGwx7eVq2Epw0vOVV7fkYo,4925
 llm_trainer/eval.py,sha256=NDm8PbXLch7xT81xPYPRCNrcrB_Xj5GDJSCxyVwUOp4,1524
 llm_trainer/fsdp_checkpoint.py,sha256=lqZFzHyWyfzuCq_81kQNtJd2qaiMeY1N5BCEMnrJTBw,3192
@@ -19,17 +19,17 @@ llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
 llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
 llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
 llm_trainer/tools.py,sha256=O45-20wRmh-nyTfU-U-XtjbKAoe7boEIsUvWT_NaKx4,3041
-llm_trainer/train_configs.py,sha256=4sM96SOgwcn6jBGtbG5-qDZbJjiHVB6l7FWqdq7hbj0,7979
-llm_trainer/trainer.py,sha256=pUtJVRosn54j1hn76CFAptJcAsrDo59H6p8NMkg2zt4,25521
+llm_trainer/train_configs.py,sha256=HKzH3nfMT1-SW4Htwa0KqYtMd6FAJcthR5IEo6di8us,8168
+llm_trainer/trainer.py,sha256=j5fDqMzvU6SYwxHsv9wX0UVX4JXS-8eP1AkHgVxKf9U,27119
 llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
-project_llm_trainer-0.4.11.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
-project_llm_trainer-0.4.11.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
-project_llm_trainer-0.4.11.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
-project_llm_trainer-0.4.11.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
-project_llm_trainer-0.4.11.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
-project_llm_trainer-0.4.11.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
-project_llm_trainer-0.4.11.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
-project_llm_trainer-0.4.11.dist-info/METADATA,sha256=JEZo2-np0t_K-J6yapyAXsArpvYTmrSNGDsdy32kWas,196
-project_llm_trainer-0.4.11.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-project_llm_trainer-0.4.11.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
-project_llm_trainer-0.4.11.dist-info/RECORD,,
+project_llm_trainer-0.4.13.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.4.13.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.4.13.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.4.13.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.4.13.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.4.13.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.4.13.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.4.13.dist-info/METADATA,sha256=hiW-7qgWuPizKVz4cU8mEHoqiuT6ZqNlCBb7nwVfFQ4,196
+project_llm_trainer-0.4.13.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.4.13.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.4.13.dist-info/RECORD,,