project-llm-trainer 0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of project-llm-trainer might be problematic.

@@ -0,0 +1,445 @@
+ from typing import Optional, Union, Set, Type, Callable, List, Mapping, Any
+
+ import torch
+ from torch import nn
+ from llm_model import ModelConfig, VLMConfig
+ from .tools import FileDataset
+
+
+ class DsOffloadConfig:
+     def __init__(
+         self,
+         *,
+         device: str = 'cpu',
+         pin_memory: bool = True
+     ):
+         self.device = device
+         self.pin_memory = pin_memory
+
+
+ class DsActivationCheckpointingConfig:
+     def __init__(
+         self,
+         *,
+         partition_activations: bool = True,
+         cpu_checkpointing: bool = True,
+         contiguous_memory_optimization: bool = True,
+         number_checkpoints: Optional[int] = None,
+         synchronize_checkpoint_boundary: bool = True,
+         profile: bool = True
+     ):
+         self.partition_activations = partition_activations
+         self.cpu_checkpointing = cpu_checkpointing
+         self.contiguous_memory_optimization = contiguous_memory_optimization
+         self.number_checkpoints = number_checkpoints
+         self.synchronize_checkpoint_boundary = synchronize_checkpoint_boundary
+         self.profile = profile
+
+
+ class DsZeROConfig:
+     def __init__(
+         self,
+         *,
+         stage: int,
+         allgather_partitions: Optional[bool] = True,
+         allgather_bucket_size: Optional[int] = 500_000_000,
+         overlap_comm: Optional[bool] = True,
+         reduce_scatter: Optional[bool] = True,
+         reduce_bucket_size: Optional[Union[str, int]] = 500_000_000,
+         contiguous_gradients: Optional[bool] = True
+     ):
+         self.stage = stage
+         self.allgather_partitions = allgather_partitions
+         self.allgather_bucket_size = allgather_bucket_size
+         self.overlap_comm = overlap_comm
+         self.reduce_scatter = reduce_scatter
+         self.reduce_bucket_size = reduce_bucket_size
+         self.contiguous_gradients = contiguous_gradients
+
+
+ class DsZero1Config(DsZeROConfig):
+     def __init__(
+         self,
+         *,
+         allgather_partitions: Optional[bool] = True,
+         allgather_bucket_size: Optional[int] = 500_000_000,
+         overlap_comm: Optional[bool] = True,
+         reduce_scatter: Optional[bool] = True,
+         reduce_bucket_size: Optional[Union[str, int]] = 500_000_000,
+         contiguous_gradients: Optional[bool] = True
+     ):
+         super().__init__(
+             stage=1,
+             allgather_partitions=allgather_partitions,
+             allgather_bucket_size=allgather_bucket_size,
+             overlap_comm=overlap_comm,
+             reduce_scatter=reduce_scatter,
+             reduce_bucket_size=reduce_bucket_size,
+             contiguous_gradients=contiguous_gradients
+         )
+
+
+ class DsZero2Config(DsZeROConfig):
+     def __init__(
+         self,
+         *,
+         allgather_partitions: Optional[bool] = True,
+         allgather_bucket_size: Optional[int] = 500_000_000,
+         overlap_comm: Optional[bool] = True,
+         reduce_scatter: Optional[bool] = True,
+         reduce_bucket_size: Optional[Union[str, int]] = 500_000_000,
+         contiguous_gradients: Optional[bool] = True,
+         offload_optimizer: Optional[DsOffloadConfig] = None,
+         offload_param: Optional[DsOffloadConfig] = None
+     ):
+         super().__init__(
+             stage=2,
+             allgather_partitions=allgather_partitions,
+             allgather_bucket_size=allgather_bucket_size,
+             overlap_comm=overlap_comm,
+             reduce_scatter=reduce_scatter,
+             reduce_bucket_size=reduce_bucket_size,
+             contiguous_gradients=contiguous_gradients
+         )
+
+         self.offload_optimizer = offload_optimizer
+         self.offload_param = offload_param
+
+
+ class DsZero3Config(DsZeROConfig):
+     def __init__(
+         self,
+         *,
+         allgather_partitions: Optional[bool] = None,
+         allgather_bucket_size: Optional[int] = None,
+         overlap_comm: Optional[bool] = True,
+         reduce_scatter: Optional[bool] = None,
+         reduce_bucket_size: Optional[Union[str, int]] = 'auto',
+         contiguous_gradients: Optional[bool] = True,
+         sub_group_size: Optional[int] = 1_000_000_000,
+         stage3_prefetch_bucket_size: Optional[Union[str, int]] = 'auto',
+         stage3_param_persistence_threshold: Optional[Union[str, int]] = 'auto',
+         stage3_max_live_parameters: Optional[int] = 1_000_000_000,
+         stage3_max_reuse_distance: Optional[int] = 1_000_000_000,
+         stage3_gather_16bit_weights_on_model_save: Optional[bool] = True,
+         offload_optimizer: Optional[DsOffloadConfig] = None,
+         offload_param: Optional[DsOffloadConfig] = None
+     ):
+         super().__init__(
+             stage=3,
+             allgather_partitions=allgather_partitions,
+             allgather_bucket_size=allgather_bucket_size,
+             overlap_comm=overlap_comm,
+             reduce_scatter=reduce_scatter,
+             reduce_bucket_size=reduce_bucket_size,
+             contiguous_gradients=contiguous_gradients
+         )
+
+         self.sub_group_size = sub_group_size
+         self.stage3_prefetch_bucket_size = stage3_prefetch_bucket_size
+         self.stage3_param_persistence_threshold = stage3_param_persistence_threshold
+         self.stage3_max_live_parameters = stage3_max_live_parameters
+         self.stage3_max_reuse_distance = stage3_max_reuse_distance
+         self.stage3_gather_16bit_weights_on_model_save = stage3_gather_16bit_weights_on_model_save
+
+         self.offload_optimizer = offload_optimizer
+         self.offload_param = offload_param
+
+
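The ZeRO config classes above track DeepSpeed's `zero_optimization` JSON keys one-for-one. As a minimal sketch (editor's illustration; `zero_to_dict` is a hypothetical helper, not a function of this package), rendering a config into the dict DeepSpeed expects could look like:

    def zero_to_dict(cfg: DsZeROConfig) -> dict:
        # Drop unset fields; DeepSpeed applies its own defaults for missing keys.
        out = {k: v for k, v in vars(cfg).items()
               if v is not None and not isinstance(v, DsOffloadConfig)}
        # Offload sub-configs become nested {"device": ..., "pin_memory": ...} blocks.
        for key in ('offload_optimizer', 'offload_param'):
            sub = getattr(cfg, key, None)
            if sub is not None:
                out[key] = {'device': sub.device, 'pin_memory': sub.pin_memory}
        return out

    # e.g. {'stage': 3, 'overlap_comm': True, 'reduce_bucket_size': 'auto', ...}
    print(zero_to_dict(DsZero3Config(offload_param=DsOffloadConfig())))
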
+ class DsFp16Config:
+     """
+     DeepSpeed fp16 configuration.
+     Parameter reference: https://deepspeed.org.cn/docs/config-json/
+     """
+     def __init__(
+         self,
+         *,
+         enabled: Union[str, bool] = 'auto',
+         loss_scale: int = 0,
+         loss_scale_window: int = 1000,
+         initial_scale_power: int = 16,
+         hysteresis: int = 2,
+         min_loss_scale: int = 1,
+         fp16_opt_level: Optional[str] = 'O2'
+     ):
+         self.enabled = enabled
+         self.loss_scale = loss_scale
+         self.loss_scale_window = loss_scale_window
+         self.initial_scale_power = initial_scale_power
+         self.hysteresis = hysteresis
+         self.min_loss_scale = min_loss_scale
+         self.fp16_opt_level = fp16_opt_level
+
+
+ class DsBf16Config:
+     def __init__(
+         self,
+         *,
+         enabled: bool = True
+     ):
+         self.enabled = enabled
+
+
+ class DsConfig:
+     """
+     DeepSpeed training-mode configuration.
+     """
+     def __init__(
+         self,
+         *,
+         zero_config: Optional[DsZeROConfig] = DsZero3Config(),
+         fp16_config: Optional[DsFp16Config] = DsFp16Config(),
+         bf16_config: Optional[DsBf16Config] = DsBf16Config(),
+         gradient_clipping: Optional[float] = 1.0,
+         activation_checkpointing: Optional[DsActivationCheckpointingConfig] = None
+     ):
+         self.zero_config = zero_config
+         self.fp16_config = fp16_config
+         self.bf16_config = bf16_config
+         self.gradient_clipping = gradient_clipping
+         self.activation_checkpointing = activation_checkpointing
+
+
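Putting the pieces together, a ZeRO-2 setup with optimizer-state offload and bf16 might be configured as follows (a usage sketch under the defaults above, not package documentation):

    ds = DsConfig(
        zero_config=DsZero2Config(offload_optimizer=DsOffloadConfig(device='cpu')),
        fp16_config=None,                       # disable fp16 when bf16 is on
        bf16_config=DsBf16Config(enabled=True),
        gradient_clipping=1.0
    )
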
+ class FsdpConfig:
+     """
+     FSDP training-mode configuration.
+     Args:
+         transformer_layer_cls (`Set[Type[nn.Module]]`, *optional*, default is None):
+             Classes of the transformer layers to wrap.
+         wrap_policy_num_params (`int`, *optional*, default is -1):
+             min_num_params for size_based_auto_wrap_policy; -1 disables that policy.
+         cpu_offload (`bool`, *optional*, default is False):
+             Whether to enable CPU offload.
+         offload_params (`bool`, default is False):
+             Whether to offload parameters; only takes effect when cpu_offload is True.
+     """
+
+     def __init__(
+         self,
+         *,
+         transformer_layer_cls: Optional[Set[Type[nn.Module]]] = None,
+         wrap_policy_num_params: int = -1,
+         cpu_offload: bool = False,
+         offload_params: bool = False,
+     ):
+         self.transformer_layer_cls = transformer_layer_cls
+         self.wrap_policy_num_params = wrap_policy_num_params
+         self.cpu_offload = cpu_offload
+         self.offload_params = offload_params
+
+
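The two wrap-policy fields line up with torch's built-in FSDP auto-wrap policies. A sketch of how a trainer might consume FsdpConfig — the torch imports are real APIs, but the translation itself is the editor's assumption, not code from this package:

    import functools
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload
    from torch.distributed.fsdp.wrap import (
        size_based_auto_wrap_policy, transformer_auto_wrap_policy
    )

    def build_fsdp(model: nn.Module, cfg: FsdpConfig) -> FSDP:
        if cfg.transformer_layer_cls is not None:
            # Wrap each listed transformer layer class in its own FSDP unit.
            policy = functools.partial(transformer_auto_wrap_policy,
                                       transformer_layer_cls=cfg.transformer_layer_cls)
        elif cfg.wrap_policy_num_params > 0:
            # -1 disables the size-based policy, per the docstring.
            policy = functools.partial(size_based_auto_wrap_policy,
                                       min_num_params=cfg.wrap_policy_num_params)
        else:
            policy = None
        offload = CPUOffload(offload_params=cfg.offload_params) if cfg.cpu_offload else None
        return FSDP(model, auto_wrap_policy=policy, cpu_offload=offload)
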
+ class DataLoaderConfig:
+     """
+     Data loader configuration.
+     Args:
+         data_loader_pin_memory (`bool`, *optional*, default is False):
+             DataLoader pin_memory setting.
+         data_loader_num_workers (`int`, *optional*, default is 0):
+             DataLoader num_workers setting.
+         data_loader_shuffle (`bool`, *optional*, default is False):
+             Whether to shuffle the data.
+         data_loader_drop_last (`bool`, default is True):
+             Whether to drop the last batch when it is smaller than batch_size.
+     """
+
+     def __init__(
+         self,
+         *,
+         data_loader_pin_memory: bool = False,
+         data_loader_num_workers: int = 0,
+         data_loader_shuffle: bool = False,
+         data_loader_drop_last: bool = True,
+     ):
+         self.data_loader_pin_memory = data_loader_pin_memory
+         self.data_loader_num_workers = data_loader_num_workers
+         self.data_loader_shuffle = data_loader_shuffle
+         self.data_loader_drop_last = data_loader_drop_last
+
+
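These four fields pass straight through to torch's DataLoader; a sketch of the obvious mapping (assumed, not taken from the package):

    from torch.utils.data import DataLoader

    def build_loader(dataset, cfg: DataLoaderConfig, batch_size: int) -> DataLoader:
        return DataLoader(
            dataset,
            batch_size=batch_size,
            pin_memory=cfg.data_loader_pin_memory,
            num_workers=cfg.data_loader_num_workers,
            shuffle=cfg.data_loader_shuffle,
            drop_last=cfg.data_loader_drop_last
        )
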
+ class LrConfig:
+     def __init__(
+         self,
+         *,
+         enable_lr_scheduler: bool = False,
+         initial_lr: Optional[float] = None,
+         weight_decay: float = 0.1,
+         max_lr: Optional[float] = None,
+         min_lr: Optional[float] = None,
+         period: Optional[int] = None,
+         period_mul: Optional[int] = None,
+         warmup_iters: Optional[int] = None
+     ):
+         self.enable_lr_scheduler = enable_lr_scheduler
+         self.initial_lr = initial_lr
+         self.weight_decay = weight_decay
+         self.max_lr = max_lr
+         self.min_lr = min_lr
+         self.period = period
+         self.period_mul = period_mul
+         self.warmup_iters = warmup_iters
+
+
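The period / period_mul / warmup_iters trio reads like linear warmup followed by cosine annealing with warm restarts. One plausible interpretation with stock torch schedulers — the mapping is the editor's guess; the package may wire these fields differently:

    from torch.optim import AdamW
    from torch.optim.lr_scheduler import (
        CosineAnnealingWarmRestarts, LinearLR, SequentialLR
    )

    def build_optimizer(params, cfg: LrConfig):
        lr = cfg.max_lr if cfg.enable_lr_scheduler else cfg.initial_lr
        opt = AdamW(params, lr=lr, weight_decay=cfg.weight_decay)
        if not cfg.enable_lr_scheduler:
            return opt, None
        warmup = LinearLR(opt, start_factor=0.01, total_iters=cfg.warmup_iters)
        cosine = CosineAnnealingWarmRestarts(opt, T_0=cfg.period,
                                             T_mult=cfg.period_mul or 1,
                                             eta_min=cfg.min_lr or 0.0)
        return opt, SequentialLR(opt, [warmup, cosine], milestones=[cfg.warmup_iters])
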
+ class LossConfig:
+     def __init__(
+         self,
+         *,
+         critical_tokens: Optional[List[int]] = None,
+         critical_alpha: float = 1.0,
+         aux_loss_coef: Optional[float] = 1.0
+     ):
+         self.critical_tokens = critical_tokens
+         self.critical_alpha = critical_alpha
+         self.aux_loss_coef = aux_loss_coef
+
+
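critical_tokens / critical_alpha suggest up-weighting the loss wherever the target id is one of the critical tokens. A minimal sketch of that reading (an assumption; weighted_lm_loss is not a package function):

    import torch.nn.functional as F

    def weighted_lm_loss(logits: torch.Tensor, labels: torch.Tensor, cfg: LossConfig):
        # Per-token cross entropy; positions labelled -100 are ignored as usual.
        ce = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1),
                             reduction='none')
        if cfg.critical_tokens:
            critical = torch.isin(labels.view(-1),
                                  torch.tensor(cfg.critical_tokens, device=labels.device))
            ce = torch.where(critical, ce * cfg.critical_alpha, ce)
        return ce.mean()
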
+ class DPOConfig:
+     def __init__(
+         self,
+         loss_beta: float,
+         loss_label_smoothing: float = 0.0,
+         loss_ipo: bool = False,
+         nll_loss_coef: Optional[float] = None
+     ):
+         self.loss_beta = loss_beta
+         self.loss_label_smoothing = loss_label_smoothing
+         self.loss_ipo = loss_ipo
+         self.nll_loss_coef = nll_loss_coef
+
+
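loss_beta, loss_label_smoothing and loss_ipo match the standard DPO objective and its IPO variant. For reference, the textbook formulas (standard losses, not lifted from this package):

    import torch.nn.functional as F

    def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, cfg: DPOConfig):
        # Log-ratio of policy vs. reference, chosen minus rejected completions.
        logits = (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected)
        if cfg.loss_ipo:
            # IPO: squared distance to the 1/(2*beta) margin.
            return ((logits - 1 / (2 * cfg.loss_beta)) ** 2).mean()
        s = cfg.loss_label_smoothing
        return (-F.logsigmoid(cfg.loss_beta * logits) * (1 - s)
                - F.logsigmoid(-cfg.loss_beta * logits) * s).mean()
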
+ class GRPOConfig:
+     def __init__(
+         self,
+         grpo_steps: int = 1,
+         clip_eps: float = 0.2,
+         kl_weight: float = 0.01,
+         group_size: int = 12,
+         gen_max_new_tokens: Optional[int] = None,
+         gen_temperature: Optional[float] = None,
+         gen_k: Optional[int] = None,
+         gen_p: Optional[float] = None,
+         gen_suppress_tokens: Optional[list[int]] = None,
+     ):
+         self.grpo_steps = grpo_steps
+         self.clip_eps = clip_eps
+         self.kl_weight = kl_weight
+         self.group_size = group_size
+         self.gen_max_new_tokens = gen_max_new_tokens
+         self.gen_temperature = gen_temperature
+         self.gen_k = gen_k
+         self.gen_p = gen_p
+         self.gen_suppress_tokens = gen_suppress_tokens
+
+
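group_size and clip_eps are the heart of GRPO: rewards are normalized within each group of completions sampled from the same prompt, and the policy update uses a PPO-style clipped surrogate, with kl_weight scaling an added KL penalty. A sketch of those two steps as standardly defined, not as this package necessarily implements them:

    def group_advantages(rewards: torch.Tensor, group_size: int) -> torch.Tensor:
        # Normalize rewards within each group drawn from one prompt.
        r = rewards.view(-1, group_size)
        adv = (r - r.mean(dim=1, keepdim=True)) / (r.std(dim=1, keepdim=True) + 1e-8)
        return adv.view(-1)

    def clipped_objective(logp_new, logp_old, adv, clip_eps: float) -> torch.Tensor:
        ratio = (logp_new - logp_old).exp()
        clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps)
        return -torch.min(ratio * adv, clipped * adv).mean()
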
+ class KDConfig:
+     """
+     Knowledge-distillation configuration.
+
+     Args:
+         teacher_logits_provider (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
+             Provider of the teacher model's logits for knowledge distillation.
+         kd_coef (`float`, *optional*, default is 0.4):
+             Weight of the distillation loss: loss = kd_coef * kd_loss + (1 - kd_coef) * lm_loss.
+     """
+
+     def __init__(
+         self,
+         *,
+         teacher_logits_provider: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+         kd_coef: float = 0.4
+     ):
+         self.teacher_logits_provider = teacher_logits_provider
+         self.kd_coef = kd_coef
+
+
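The docstring pins down the blend exactly: loss = kd_coef * kd_loss + (1 - kd_coef) * lm_loss. A sketch of a matching distillation loss — the softmax temperature T is an extra assumption, not a field of KDConfig:

    import torch.nn.functional as F

    def kd_loss(student_logits, teacher_logits, lm_loss, kd_coef: float, T: float = 1.0):
        # Soft-target KL between student and teacher distributions.
        kd = F.kl_div(F.log_softmax(student_logits / T, dim=-1),
                      F.softmax(teacher_logits / T, dim=-1),
                      reduction='batchmean') * (T * T)
        return kd_coef * kd + (1 - kd_coef) * lm_loss
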
+ class EvalConfig:
+     def __init__(
+         self,
+         max_new_tokens: int = 512,
+         temperature: float = 1.0,
+         top_p: float = 0.95,
+         top_k: Optional[int] = None
+     ):
+         self.max_new_tokens = max_new_tokens
+         self.temperature = temperature
+         self.top_p = top_p
+         self.top_k = top_k
+
+
+ class TrainConfig:
+     """
+     Training configuration.
+
+     Args:
+         n_epochs (`int`):
+             Number of training epochs.
+         batch_size (`int`):
+             Size of each batch.
+         model_config (`ModelConfig`):
+             Model configuration.
+         file_dataset (`FileDataset`):
+             Dataset of training files.
+         mask_prompt (`bool`):
+             Whether to mask the prompt tokens in the loss.
+         gradient_accumulation_steps (`int`, *optional*, default is 0):
+             Number of gradient-accumulation steps; 0 disables gradient accumulation.
+             Has no effect during GRPO training!
+         eval_batch_interval (`int`, default is 100):
+             Run model eval every this many batches.
+         lr_config (`LrConfig`):
+             Learning-rate configuration.
+         fsdp_config (`FsdpConfig`):
+             FSDP training-mode configuration.
+         data_loader_config (`DataLoaderConfig`):
+             Data loader configuration.
+         kd_config (`KDConfig`, *optional*, default is None):
+             Knowledge-distillation configuration; None disables distillation.
+         pixel_values_provider (`Callable[[list[str]], torch.Tensor]`, *optional*, default is None):
+             Provides pixel_values for the given image_tags when training a VLM.
+     """
+
+     def __init__(
+         self,
+         n_epochs: int,
+         batch_size: int,
+         *,
+         model_config: Union[ModelConfig, VLMConfig],
+         file_dataset: FileDataset,
+         mask_prompt: bool = True,
+         gradient_accumulation_steps: int = 0,
+         eval_batch_interval: int = 100,
+         loss_config: LossConfig = LossConfig(),
+         dpo_config: Optional[DPOConfig] = None,
+         grpo_config: Optional[GRPOConfig] = None,
+         lr_config: LrConfig = LrConfig(),
+         ds_config: DsConfig = DsConfig(),
+         fsdp_config: FsdpConfig = FsdpConfig(),
+         data_loader_config: DataLoaderConfig = DataLoaderConfig(),
+         kd_config: Optional[KDConfig] = None,
+         pixel_values_provider: Optional[Callable[[list[str]], torch.Tensor]] = None,
+         init_state_dict: Optional[Mapping[str, Any]] = None,
+         eval_config: EvalConfig = EvalConfig()
+     ):
+         self.n_epochs = n_epochs
+         self.batch_size = batch_size
+         self.model_config = model_config
+         self.file_dataset = file_dataset
+         self.mask_prompt = mask_prompt
+         self.gradient_accumulation_steps = gradient_accumulation_steps
+         self.eval_batch_interval = eval_batch_interval
+         self.loss_config = loss_config
+         self.dpo_config = dpo_config
+         self.grpo_config = grpo_config
+         self.lr_config = lr_config
+         self.ds_config = ds_config
+         self.fsdp_config = fsdp_config
+         self.data_loader_config = data_loader_config
+         self.kd_config = kd_config
+         self.pixel_values_provider = pixel_values_provider
+         self.init_state_dict = init_state_dict
+         self.eval_config = eval_config
+
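Finally, an end-to-end wiring sketch. The ModelConfig and FileDataset constructor arguments below are invented for illustration; their real signatures live in llm_model and .tools, which are not shown in this diff:

    config = TrainConfig(
        n_epochs=3,
        batch_size=8,
        model_config=ModelConfig(),                 # hypothetical: real args not shown here
        file_dataset=FileDataset('data/pretrain'),  # hypothetical path argument
        gradient_accumulation_steps=4,
        lr_config=LrConfig(enable_lr_scheduler=True, max_lr=3e-4, min_lr=3e-5,
                           period=1000, period_mul=2, warmup_iters=200),
        ds_config=DsConfig(zero_config=DsZero2Config()),
    )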