project-llm-trainer 0.13.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of project-llm-trainer might be problematic.

llm_trainer/train_configs.py ADDED
@@ -0,0 +1,327 @@
+ from typing import Optional, Union, Callable, List, Mapping, Any, Tuple
+ from dataclasses import dataclass, field
+
+ import torch
+ from llm_model import ModelConfig, VLMConfig
+ from .tools import FileDataset
+
+
+ @dataclass(kw_only=True)
+ class DsOffloadConfig:
+     device: str = 'cpu'
+     pin_memory: bool = True
+
+
+ @dataclass(kw_only=True)
+ class DsActivationCheckpointingConfig:
+     partition_activations: bool = True
+     cpu_checkpointing: bool = False
+     contiguous_memory_optimization: bool = True
+     number_checkpoints: Optional[int] = None
+     synchronize_checkpoint_boundary: bool = False
+     profile: bool = False
+
+
+ @dataclass(kw_only=True)
+ class DsZeROConfig:
+     stage: int
+     allgather_partitions: Optional[bool] = True
+     allgather_bucket_size: Optional[int] = int(5e8)
+     overlap_comm: Optional[bool] = True
+     reduce_scatter: Optional[bool] = True
+     reduce_bucket_size: Optional[Union[str, int]] = int(5e8)
+     contiguous_gradients: Optional[bool] = True
+
+
+ @dataclass(kw_only=True)
+ class DsZero0Config(DsZeROConfig):
+     stage: int = field(default=0, init=False)
+
+
+ @dataclass(kw_only=True)
+ class DsZero1Config(DsZeROConfig):
+     stage: int = field(default=1, init=False)
+
+
+ @dataclass(kw_only=True)
+ class DsZero2Config(DsZeROConfig):
+     stage: int = field(default=2, init=False)
+     offload_optimizer: Optional[DsOffloadConfig] = None
+     offload_param: Optional[DsOffloadConfig] = None
+
+
+ @dataclass(kw_only=True)
+ class DsZero3Config(DsZeROConfig):
+     stage: int = field(default=3, init=False)
+     sub_group_size: Optional[int] = int(1e9)
+     stage3_prefetch_bucket_size: Optional[Union[str, int]] = 'auto'
+     stage3_param_persistence_threshold: Optional[Union[str, int]] = 'auto'
+     stage3_max_live_parameters: Optional[int] = int(1e9)
+     stage3_max_reuse_distance: Optional[int] = int(1e9)
+     stage3_gather_16bit_weights_on_model_save: Optional[bool] = True
+     offload_optimizer: Optional[DsOffloadConfig] = None
+     offload_param: Optional[DsOffloadConfig] = None
+
+
+ @dataclass(kw_only=True)
+ class DsFp16Config:
+     enabled: Union[str, bool] = 'auto'
+     loss_scale: int = 0
+     loss_scale_window: int = 1000
+     initial_scale_power: int = 16
+     hysteresis: int = 2
+     min_loss_scale: int = 1
+     fp16_opt_level: Optional[str] = 'O2'
+
+
+ @dataclass(kw_only=True)
+ class DsBf16Config:
+     enabled: bool = True
+
+
+ @dataclass(kw_only=True)
+ class DsConfig:
+     zero_config: Optional[DsZeROConfig] = field(default_factory=DsZero3Config)
+     fp16_config: Optional[DsFp16Config] = field(default_factory=DsFp16Config)
+     bf16_config: Optional[DsBf16Config] = field(default_factory=DsBf16Config)
+     gradient_clipping: Optional[float] = 1.0
+     activation_checkpointing: Optional[DsActivationCheckpointingConfig] = None
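These dataclasses mirror sections of a DeepSpeed JSON config. A minimal construction sketch, assuming the trainer serializes the fields into that JSON (the mapping itself is not shown in this diff):

    # ZeRO-2 with optimizer state offloaded to CPU, bf16 on, fp16 off.
    ds_config = DsConfig(
        zero_config=DsZero2Config(
            offload_optimizer=DsOffloadConfig(device='cpu', pin_memory=True),
        ),
        fp16_config=DsFp16Config(enabled=False),
        bf16_config=DsBf16Config(enabled=True),
        gradient_clipping=1.0,
    )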
+
+
+ @dataclass(kw_only=True)
+ class DataLoaderConfig:
+     """
+     Data loader configuration.
+
+     Args:
+         data_loader_pin_memory (`bool`, *optional*, defaults to False):
+             The data loader's pin_memory setting.
+         data_loader_num_workers (`int`, *optional*, defaults to 0):
+             The data loader's num_workers setting.
+         data_loader_shuffle (`bool`, *optional*, defaults to False):
+             Whether to shuffle the data.
+         data_loader_drop_last (`bool`, *optional*, defaults to True):
+             Whether to drop the last batch when it is smaller than batch_size.
+     """
+     data_loader_pin_memory: bool = False
+     data_loader_num_workers: int = 0
+     data_loader_shuffle: bool = False
+     data_loader_drop_last: bool = True
+
+
+ @dataclass(kw_only=True)
+ class OptimConfig:
+     optim_type: str = 'adam'  # or 'lion'
+     enable_lr_scheduler: bool = False
+     initial_lr: float
+     weight_decay: Optional[float] = None
+     betas: Optional[Tuple[float, float]] = None
+     warmup_iters: Optional[int] = None
+     max_lr: Optional[float] = None
+     min_lr: Optional[float] = None
+     cosine_annealing_period: Optional[int] = None
+     cosine_annealing_period_mul: int = 0
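A construction sketch for a warmup-plus-cosine schedule; how warmup_iters, max_lr, min_lr and cosine_annealing_period interact is inferred from the field names, not confirmed by this diff:

    # Hypothetical Adam setup with an LR schedule enabled.
    optim_config = OptimConfig(
        optim_type='adam',
        initial_lr=1e-4,
        weight_decay=0.01,
        betas=(0.9, 0.95),
        enable_lr_scheduler=True,
        warmup_iters=1000,               # assumed: linear warmup to max_lr
        max_lr=3e-4,
        min_lr=3e-5,                     # assumed: cosine floor
        cosine_annealing_period=10_000,
    )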
+
+
+ @dataclass(kw_only=True)
+ class LossConfig:
+     critical_tokens: Optional[List[int]] = None
+     critical_alpha: float = 1.0
+     aux_loss_coef: Optional[float] = 0.001
+
+
+ @dataclass(kw_only=True)
+ class KDConfig:
+     """
+     Knowledge-distillation configuration.
+
+     Args:
+         teacher_logits_provider (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
+             Provider of the teacher model's logits for knowledge distillation.
+         kd_coef (`float`, *optional*, defaults to 0.4):
+             Weight of the distillation loss: loss = kd_coef * kd_loss + (1 - kd_coef) * lm_loss.
+     """
+     teacher_logits_provider: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+     kd_coef: float = 0.4
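A sketch of a teacher_logits_provider, assuming the two tensor arguments are input ids and an attention mask (their meaning is not documented in this diff); teacher_model is a placeholder for a frozen, pre-loaded teacher:

    import torch

    @torch.no_grad()
    def teacher_logits_provider(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Return the teacher's logits; the trainer blends them into
        # loss = kd_coef * kd_loss + (1 - kd_coef) * lm_loss.
        return teacher_model(input_ids, attention_mask=attention_mask)

    kd_config = KDConfig(teacher_logits_provider=teacher_logits_provider, kd_coef=0.4)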
+
+
+ @dataclass(kw_only=True)
+ class EvalConfig:
+     """
+     Evaluation configuration.
+
+     Args:
+         eval_batch_interval (`int`, *optional*, defaults to 100):
+             Run model eval every this many batches.
+     """
+     max_seq_len: int
+
+     eval_batch_interval: int = 100
+     temperature: float = 1.0
+     top_p: float = 0.95
+     top_k: Optional[int] = None
+
+
+ @dataclass(kw_only=True)
+ class PretrainConfig:
+     """
+     Pretraining configuration.
+
+     Args:
+         gradient_accumulation_steps (`int`, *optional*, defaults to 1):
+             Number of gradient-accumulation steps; 0 disables gradient accumulation.
+             Currently only applies to pretrain/sft/dpo, not to ppo/grpo/gspo.
+         kd_config (`KDConfig`, *optional*, defaults to None):
+             Knowledge-distillation configuration; None disables knowledge distillation.
+     """
+     gradient_accumulation_steps: int = 1
+     kd_config: Optional[KDConfig] = None
+
+
+ @dataclass(kw_only=True)
+ class SFTConfig:
+     """
+     SFT (supervised fine-tuning) configuration.
+
+     Args:
+         mask_prompt (`bool`, *optional*, defaults to True):
+             Whether to mask the prompt tokens out of the loss.
+         gradient_accumulation_steps (`int`, *optional*, defaults to 1):
+             Number of gradient-accumulation steps; 0 disables gradient accumulation.
+             Currently only applies to pretrain/sft/dpo, not to ppo/grpo/gspo.
+         kd_config (`KDConfig`, *optional*, defaults to None):
+             Knowledge-distillation configuration; None disables knowledge distillation.
+         pixel_values_provider (`Callable[[list[str]], torch.Tensor]`, *optional*, defaults to None):
+             When training a VLM, provides pixel_values for the given image tags.
+         freeze_llm_model (`bool`, *optional*, defaults to False):
+             Whether to freeze the LLM part of the model parameters when training a VLM.
+     """
+     mask_prompt: bool = True
+     gradient_accumulation_steps: int = 1
+     kd_config: Optional[KDConfig] = None
+     image_tags_file_dataset: Optional[FileDataset] = None
+     pixel_values_provider: Optional[Callable[[list[str]], torch.Tensor]] = None
+     freeze_llm_model: bool = False
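A VLM fine-tuning sketch: pixel_values_provider turns a batch of image tags into a stacked pixel tensor. Treating tags as file paths and the load_and_preprocess helper are placeholders, not part of this package:

    import torch

    def pixel_values_provider(image_tags: list[str]) -> torch.Tensor:
        # Hypothetical: load each tagged image and stack into (B, C, H, W).
        return torch.stack([load_and_preprocess(tag) for tag in image_tags])

    sft_config = SFTConfig(
        image_tags_file_dataset=image_tags_file_dataset,  # placeholder FileDataset of tags
        pixel_values_provider=pixel_values_provider,
        freeze_llm_model=True,  # train only the vision side
    )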
+
+
+ @dataclass(kw_only=True)
+ class DPOConfig:
+     """
+     DPO (direct preference optimization) configuration.
+
+     Args:
+         mask_prompt (`bool`, *optional*, defaults to True):
+             Whether to mask the prompt tokens out of the loss.
+         gradient_accumulation_steps (`int`, *optional*, defaults to 1):
+             Number of gradient-accumulation steps; 0 disables gradient accumulation.
+             Currently only applies to pretrain/sft/dpo, not to ppo/grpo/gspo.
+     """
+     ref_model_checkpoint: Mapping[str, Any]
+     mask_prompt: bool = True
+     gradient_accumulation_steps: int = 1
+     loss_beta: float
+     loss_label_smoothing: float = 0.0
+     loss_ipo: bool = False
+     nll_loss_coef: Optional[float] = None
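The loss fields line up with the published DPO/IPO objectives; a reference sketch of how loss_beta, loss_label_smoothing and loss_ipo conventionally combine (the package's own loss code is not shown in this diff):

    import torch.nn.functional as F

    def dpo_loss_sketch(policy_chosen_logps, policy_rejected_logps,
                        ref_chosen_logps, ref_rejected_logps,
                        beta, label_smoothing=0.0, ipo=False):
        # Policy-vs-reference margin between chosen and rejected completions.
        logits = (policy_chosen_logps - policy_rejected_logps) \
                 - (ref_chosen_logps - ref_rejected_logps)
        if ipo:
            # IPO: squared distance to the 1/(2*beta) margin.
            return ((logits - 1 / (2 * beta)) ** 2).mean()
        # DPO with conservative label smoothing.
        return (-F.logsigmoid(beta * logits) * (1 - label_smoothing)
                - F.logsigmoid(-beta * logits) * label_smoothing).mean()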
+
+
+ @dataclass(kw_only=True)
+ class PPOConfig:
+     ppo_epochs: int
+     ppo_batch_size: int
+     ref_model_checkpoint: Mapping[str, Any]
+     value_model_checkpoint: Optional[Mapping[str, Any]] = None
+     value_optim_config: Optional['OptimConfig'] = None
+     gradient_accumulation_steps: int = 1
+     gamma: float = 1.0
+     lam: float = 0.95
+     clip_eps: float = 0.1
+     vf_coef: float = 0.5
+     kl_beta: float = 0.02
+     kl_estimator: str = 'k1'  # or 'k3'
+     missing_eos_penalty: Optional[float] = None
+     normalize_rewards: bool = False
+     normalize_method: str = 'RunningMeanStd'  # 'RunningMeanStd' or 'BatchStd'
+     whiten_rewards: bool = False
+     gen_max_seq_len: int
+     gen_temperature: Optional[float] = None
+     gen_k: Optional[int] = None
+     gen_p: Optional[float] = None
+     gen_suppress_tokens: Optional[list[int]] = None
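gamma and lam are the usual GAE parameters, and kl_estimator 'k1'/'k3' presumably refer to the standard per-token KL approximations (k1 = log pi - log pi_ref; k3 = r - 1 - log r with r = pi_ref/pi). A reference sketch of the advantage estimator these fields parameterize, not necessarily this package's exact code:

    import torch

    def gae_sketch(rewards: torch.Tensor, values: torch.Tensor,
                   gamma: float = 1.0, lam: float = 0.95):
        # Standard GAE over one trajectory; values carries one extra
        # bootstrap entry, so values[t + 1] is the next state's value.
        advantages = torch.zeros_like(rewards)
        last_adv = 0.0
        for t in reversed(range(rewards.shape[0])):
            delta = rewards[t] + gamma * values[t + 1] - values[t]
            last_adv = delta + gamma * lam * last_adv
            advantages[t] = last_adv
        return advantages, advantages + values[:-1]  # advantages, returns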
+
+
+ @dataclass(kw_only=True)
+ class GRPOConfig:
+     grpo_steps: int = 1
+     group_size: int = 12
+     mixup_alpha: float = 1.0
+     loss_beta: float = 0.0  # or 0.04 for grpo
+     loss_clip_eps: float = 3e-4
+     loss_clip_eps_high: Optional[float] = 4e-4
+     loss_delta: Optional[float] = None
+     loss_importance_sampling_level: str = 'seq'  # 'token' or 'seq'
+     loss_type: str = 'grpo'  # 'grpo', 'bnpo' or 'dr_grpo'
+     gen_max_seq_len: int
+     gen_temperature: Optional[float] = None
+     gen_k: Optional[int] = None
+     gen_p: Optional[float] = None
+     gen_suppress_tokens: Optional[list[int]] = None
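group_size is the number of completions sampled per prompt; a reference sketch of the usual GRPO group-normalized advantage (per the GRPO formulation, not necessarily this package's exact code):

    import torch

    def grpo_advantages_sketch(rewards: torch.Tensor, group_size: int) -> torch.Tensor:
        # rewards: one scalar per completion, grouped by prompt.
        grouped = rewards.view(-1, group_size)
        mean = grouped.mean(dim=1, keepdim=True)
        std = grouped.std(dim=1, keepdim=True)
        return ((grouped - mean) / (std + 1e-4)).view(-1)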
+
+
+ @dataclass(kw_only=True)
+ class TrainConfig:
+     """
+     Training parameter configuration.
+
+     Args:
+         n_epochs (`int`):
+             Number of training epochs.
+         batch_size (`int`):
+             Size of each batch.
+         model_config (`ModelConfig` or `VLMConfig`):
+             Model configuration.
+         init_state_dict (`Mapping[str, Any]`, *optional*, defaults to None):
+             Initial checkpoint state dict.
+         file_dataset (`FileDataset`):
+             Dataset of training files.
+         dataset_block_size (`int`):
+             Maximum training sequence length.
+         data_loader_config (`DataLoaderConfig`):
+             Data loader configuration.
+         loss_config (`LossConfig`):
+             Loss configuration.
+         ds_config (`DsConfig`):
+             DeepSpeed configuration.
+         eval_config (`EvalConfig`):
+             Eval configuration.
+         optim_config (`OptimConfig`):
+             Optimizer configuration.
+         pretrain_config (`PretrainConfig`, *optional*, defaults to None):
+             Pretraining configuration; only used with Trainer.
+         sft_config (`SFTConfig`, *optional*, defaults to None):
+             SFT configuration; only used with SFTTrainer.
+         dpo_config (`DPOConfig`, *optional*, defaults to None):
+             DPO configuration; only used with DPOTrainer.
+         ppo_config (`PPOConfig`, *optional*, defaults to None):
+             PPO configuration; only used with PPOTrainer.
+         grpo_config (`GRPOConfig`, *optional*, defaults to None):
+             GRPO configuration; only used with GRPOTrainer.
+     """
+     n_epochs: int
+     batch_size: int
+     model_config: Union[ModelConfig, VLMConfig]
+     init_state_dict: Optional[Mapping[str, Any]] = None
+
+     file_dataset: FileDataset
+     dataset_block_size: int
+     data_loader_config: DataLoaderConfig = field(default_factory=DataLoaderConfig)
+
+     loss_config: LossConfig = field(default_factory=LossConfig)
+     # OptimConfig and EvalConfig have required fields (initial_lr, max_seq_len),
+     # so no usable default exists; both must be supplied by the caller.
+     optim_config: OptimConfig
+     ds_config: DsConfig = field(default_factory=DsConfig)
+
+     eval_config: EvalConfig
+
+     pretrain_config: Optional[PretrainConfig] = None
+     sft_config: Optional[SFTConfig] = None
+     dpo_config: Optional[DPOConfig] = None
+     ppo_config: Optional[PPOConfig] = None
+     grpo_config: Optional[GRPOConfig] = None
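Putting it together, a pretraining setup might look like this sketch; model_config and file_dataset are placeholders built elsewhere:

    train_config = TrainConfig(
        n_epochs=1,
        batch_size=8,
        model_config=model_config,   # placeholder ModelConfig
        file_dataset=file_dataset,   # placeholder FileDataset of training files
        dataset_block_size=1024,
        optim_config=OptimConfig(initial_lr=3e-4),
        eval_config=EvalConfig(max_seq_len=512),
        pretrain_config=PretrainConfig(gradient_accumulation_steps=4),
    )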
llm_trainer/trainer.py ADDED
@@ -0,0 +1,34 @@
+ from typing import List, Tuple
+
+ from torch.utils.data import Dataset
+
+ from .base_trainer import BaseTrainer
+ from .train_configs import TrainConfig
+ from .utils import pretrain_collate_fn
+ from .dataset import PretrainDataset
+
+
+ class Trainer(BaseTrainer):
+     def __init__(
+         self,
+         *,
+         train_config: TrainConfig,
+         eval_prompts: List[str],
+     ):
+         # NOTE: train_config.pretrain_config must be set; it is dereferenced
+         # below without a None check.
+         super().__init__(
+             train_config=train_config,
+             eval_prompts=eval_prompts,
+             kd_config=train_config.pretrain_config.kd_config,
+             gradient_accumulation_steps=train_config.pretrain_config.gradient_accumulation_steps
+         )
+
+     def _convert_train_args(self) -> Tuple[dict, dict, dict]:
+         # Reuse the base trainer's kwargs, swapping in the pretraining collate_fn.
+         parallel_kwargs, data_loader_kwargs, sampler_kwargs = super()._convert_train_args()
+         data_loader_kwargs.update({"collate_fn": pretrain_collate_fn})
+
+         return parallel_kwargs, data_loader_kwargs, sampler_kwargs
+
+     def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
+         file_path = self.train_config.file_dataset[file_idx]
+         block_size = self.train_config.dataset_block_size
+         return PretrainDataset(file_path, block_size, block_size), file_path
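Usage sketch; the train() entry point is assumed to live on BaseTrainer, which is not shown in this diff:

    trainer = Trainer(
        train_config=train_config,  # pretrain_config must be set, per __init__ above
        eval_prompts=["The capital of France is"],
    )
    trainer.train()  # assumed BaseTrainer method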