project-llm-trainer 0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic. Click here for more details.
- llm_trainer/__init__.py +6 -0
- llm_trainer/checkpoint.py +161 -0
- llm_trainer/dataset.py +140 -0
- llm_trainer/dcp.py +93 -0
- llm_trainer/dpo_trainer.py +300 -0
- llm_trainer/ds_checkpoint.py +61 -0
- llm_trainer/eval.py +86 -0
- llm_trainer/generate_utils.py +424 -0
- llm_trainer/grpo_trainer.py +393 -0
- llm_trainer/log.py +16 -0
- llm_trainer/loss.py +171 -0
- llm_trainer/parallel.py +146 -0
- llm_trainer/parallel_ddp.py +39 -0
- llm_trainer/parallel_ds.py +45 -0
- llm_trainer/parallel_fsdp.py +115 -0
- llm_trainer/parallel_none.py +28 -0
- llm_trainer/scheduler.py +138 -0
- llm_trainer/sft_trainer.py +39 -0
- llm_trainer/tokenizer.py +166 -0
- llm_trainer/tools.py +102 -0
- llm_trainer/train_configs.py +445 -0
- llm_trainer/trainer.py +569 -0
- llm_trainer/utils.py +262 -0
- project_llm_trainer-0.3.data/scripts/calc_intermediate_size +15 -0
- project_llm_trainer-0.3.data/scripts/ddp_train +12 -0
- project_llm_trainer-0.3.data/scripts/ds_train +12 -0
- project_llm_trainer-0.3.data/scripts/plot_loss +39 -0
- project_llm_trainer-0.3.data/scripts/plot_lr +41 -0
- project_llm_trainer-0.3.data/scripts/py_train +12 -0
- project_llm_trainer-0.3.data/scripts/smart_train +28 -0
- project_llm_trainer-0.3.dist-info/METADATA +9 -0
- project_llm_trainer-0.3.dist-info/RECORD +34 -0
- project_llm_trainer-0.3.dist-info/WHEEL +5 -0
- project_llm_trainer-0.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,445 @@
|
|
|
1
|
+
from typing import Optional, Union, Set, Type, Callable, List, Mapping, Any
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn
|
|
5
|
+
from llm_model import ModelConfig, VLMConfig
|
|
6
|
+
from .tools import FileDataset
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class DsOffloadConfig:
    """Value holder for a DeepSpeed offload target.

    Instances are accepted by the ZeRO stage 2/3 configs in this module as
    their ``offload_optimizer`` / ``offload_param`` arguments.

    Args:
        device: Name of the offload device. Defaults to ``'cpu'``.
        pin_memory: Pinned-memory flag. Defaults to ``True``.
    """

    def __init__(self, *, device: str = 'cpu', pin_memory: bool = True):
        # Both options are stored verbatim; no validation is performed here.
        self.device, self.pin_memory = device, pin_memory
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class DsActivationCheckpointingConfig:
    """Value holder for DeepSpeed activation-checkpointing options.

    The attribute names mirror the keys of DeepSpeed's
    ``activation_checkpointing`` config section (presumably forwarded there
    unchanged -- confirm at the call site).

    Args:
        partition_activations: Defaults to ``True``.
        cpu_checkpointing: Defaults to ``True``.
        contiguous_memory_optimization: Defaults to ``True``.
        number_checkpoints: Optional count; ``None`` by default.
        synchronize_checkpoint_boundary: Defaults to ``True``.
        profile: Defaults to ``True``.
    """

    def __init__(
        self,
        *,
        partition_activations: bool = True,
        cpu_checkpointing: bool = True,
        contiguous_memory_optimization: bool = True,
        number_checkpoints: Optional[int] = None,
        synchronize_checkpoint_boundary: bool = True,
        profile: bool = True
    ):
        # Boolean switches.
        self.profile = profile
        self.cpu_checkpointing = cpu_checkpointing
        self.partition_activations = partition_activations
        self.contiguous_memory_optimization = contiguous_memory_optimization
        self.synchronize_checkpoint_boundary = synchronize_checkpoint_boundary

        # The only non-boolean option.
        self.number_checkpoints = number_checkpoints
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class DsZeROConfig:
    """Base holder for the DeepSpeed ZeRO options shared by every stage.

    All arguments are keyword-only and are stored unchanged on the instance
    under the same names.

    Args:
        stage: ZeRO stage number (the subclasses in this module pass 1, 2, 3).
        allgather_partitions: Defaults to ``True``.
        allgather_bucket_size: Bucket size; defaults to ``500_000_000``.
        overlap_comm: Defaults to ``True``.
        reduce_scatter: Defaults to ``True``.
        reduce_bucket_size: Bucket size or a string; defaults to
            ``500_000_000``.
        contiguous_gradients: Defaults to ``True``.
    """

    def __init__(
        self,
        *,
        stage: int,
        allgather_partitions: Optional[bool] = True,
        # 500_000_000 == 5e8; spelled as an int literal so the default agrees
        # with the ``int`` annotation instead of being a float.
        allgather_bucket_size: Optional[int] = 500_000_000,
        overlap_comm: Optional[bool] = True,
        reduce_scatter: Optional[bool] = True,
        reduce_bucket_size: Optional[Union[str, int]] = 500_000_000,
        contiguous_gradients: Optional[bool] = True
    ):
        self.stage = stage
        self.allgather_partitions = allgather_partitions
        self.allgather_bucket_size = allgather_bucket_size
        self.overlap_comm = overlap_comm
        self.reduce_scatter = reduce_scatter
        self.reduce_bucket_size = reduce_bucket_size
        self.contiguous_gradients = contiguous_gradients
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class DsZero1Config(DsZeROConfig):
    """ZeRO configuration pinned to stage 1.

    Accepts the shared ZeRO options and forwards them to ``DsZeROConfig``
    together with ``stage=1``.

    Args:
        allgather_partitions: Defaults to ``True``.
        allgather_bucket_size: Defaults to ``500_000_000``.
        overlap_comm: Defaults to ``True``.
        reduce_scatter: Defaults to ``True``.
        reduce_bucket_size: Bucket size or a string; defaults to
            ``500_000_000``.
        contiguous_gradients: Defaults to ``True``.
    """

    def __init__(
        self,
        *,
        allgather_partitions: Optional[bool] = True,
        # 500_000_000 == 5e8; an int literal keeps the default consistent
        # with the ``int`` annotation.
        allgather_bucket_size: Optional[int] = 500_000_000,
        overlap_comm: Optional[bool] = True,
        reduce_scatter: Optional[bool] = True,
        reduce_bucket_size: Optional[Union[str, int]] = 500_000_000,
        contiguous_gradients: Optional[bool] = True
    ):
        super().__init__(
            stage=1,
            allgather_partitions=allgather_partitions,
            allgather_bucket_size=allgather_bucket_size,
            overlap_comm=overlap_comm,
            reduce_scatter=reduce_scatter,
            reduce_bucket_size=reduce_bucket_size,
            contiguous_gradients=contiguous_gradients
        )
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class DsZero2Config(DsZeROConfig):
    """ZeRO configuration pinned to stage 2, with optional offload settings.

    The shared ZeRO options are forwarded to ``DsZeROConfig`` together with
    ``stage=2``; the two offload configs are stored on the instance.

    Args:
        allgather_partitions: Defaults to ``True``.
        allgather_bucket_size: Defaults to ``500_000_000``.
        overlap_comm: Defaults to ``True``.
        reduce_scatter: Defaults to ``True``.
        reduce_bucket_size: Bucket size or a string; defaults to
            ``500_000_000``.
        contiguous_gradients: Defaults to ``True``.
        offload_optimizer: Optional ``DsOffloadConfig``; ``None`` by default.
        offload_param: Optional ``DsOffloadConfig``; ``None`` by default.
    """

    def __init__(
        self,
        *,
        allgather_partitions: Optional[bool] = True,
        # 500_000_000 == 5e8; an int literal keeps the default consistent
        # with the ``int`` annotation.
        allgather_bucket_size: Optional[int] = 500_000_000,
        overlap_comm: Optional[bool] = True,
        reduce_scatter: Optional[bool] = True,
        reduce_bucket_size: Optional[Union[str, int]] = 500_000_000,
        contiguous_gradients: Optional[bool] = True,
        offload_optimizer: Optional[DsOffloadConfig] = None,
        offload_param: Optional[DsOffloadConfig] = None,
    ):
        super().__init__(
            stage=2,
            allgather_partitions=allgather_partitions,
            allgather_bucket_size=allgather_bucket_size,
            overlap_comm=overlap_comm,
            reduce_scatter=reduce_scatter,
            reduce_bucket_size=reduce_bucket_size,
            contiguous_gradients=contiguous_gradients
        )

        self.offload_optimizer = offload_optimizer
        self.offload_param = offload_param
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class DsZero3Config(DsZeROConfig):
    """ZeRO configuration pinned to stage 3.

    The shared ZeRO options are forwarded to ``DsZeROConfig`` together with
    ``stage=3``; the stage-3 specific options and the two offload configs are
    stored on the instance under the same names.

    Args:
        allgather_partitions: Defaults to ``None``.
        allgather_bucket_size: Defaults to ``None``.
        overlap_comm: Defaults to ``True``.
        reduce_scatter: Defaults to ``None``.
        reduce_bucket_size: Bucket size or a string; defaults to ``'auto'``.
        contiguous_gradients: Defaults to ``True``.
        sub_group_size: Defaults to ``1_000_000_000``.
        stage3_prefetch_bucket_size: Size or a string; defaults to ``'auto'``.
        stage3_param_persistence_threshold: Size or a string; defaults to
            ``'auto'``.
        stage3_max_live_parameters: Defaults to ``1_000_000_000``.
        stage3_max_reuse_distance: Defaults to ``1_000_000_000``.
        stage3_gather_16bit_weights_on_model_save: Defaults to ``True``.
        offload_optimizer: Optional ``DsOffloadConfig``; ``None`` by default.
        offload_param: Optional ``DsOffloadConfig``; ``None`` by default.
    """

    def __init__(
        self,
        *,
        allgather_partitions: Optional[bool] = None,
        # A bucket size is a number, not a flag: annotated as int to match
        # the base class and the other stages.
        allgather_bucket_size: Optional[int] = None,
        overlap_comm: Optional[bool] = True,
        reduce_scatter: Optional[bool] = None,
        reduce_bucket_size: Optional[Union[str, int]] = 'auto',
        contiguous_gradients: Optional[bool] = True,
        # 1_000_000_000 == 1e9; int literals keep the defaults consistent
        # with the ``int`` annotations.
        sub_group_size: Optional[int] = 1_000_000_000,
        stage3_prefetch_bucket_size: Optional[Union[str, int]] = 'auto',
        stage3_param_persistence_threshold: Optional[Union[str, int]] = 'auto',
        stage3_max_live_parameters: Optional[int] = 1_000_000_000,
        stage3_max_reuse_distance: Optional[int] = 1_000_000_000,
        stage3_gather_16bit_weights_on_model_save: Optional[bool] = True,
        offload_optimizer: Optional[DsOffloadConfig] = None,
        offload_param: Optional[DsOffloadConfig] = None,
    ):
        super().__init__(
            stage=3,
            allgather_partitions=allgather_partitions,
            allgather_bucket_size=allgather_bucket_size,
            overlap_comm=overlap_comm,
            reduce_scatter=reduce_scatter,
            reduce_bucket_size=reduce_bucket_size,
            contiguous_gradients=contiguous_gradients
        )

        self.sub_group_size = sub_group_size
        self.stage3_prefetch_bucket_size = stage3_prefetch_bucket_size
        self.stage3_param_persistence_threshold = stage3_param_persistence_threshold
        self.stage3_max_live_parameters = stage3_max_live_parameters
        self.stage3_max_reuse_distance = stage3_max_reuse_distance
        self.stage3_gather_16bit_weights_on_model_save = stage3_gather_16bit_weights_on_model_save

        self.offload_optimizer = offload_optimizer
        self.offload_param = offload_param
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class DsFp16Config:
    """DeepSpeed fp16 configuration options.

    Parameter reference: https://deepspeed.org.cn/docs/config-json/

    Every argument is keyword-only and is stored unchanged on the instance
    under the same name.

    Args:
        enabled: ``True``/``False`` or a string; defaults to ``'auto'``.
        loss_scale: Defaults to ``0``.
        loss_scale_window: Defaults to ``1000``.
        initial_scale_power: Defaults to ``16``.
        hysteresis: Defaults to ``2``.
        min_loss_scale: Defaults to ``1``.
        fp16_opt_level: Defaults to the string ``'02'``.
    """

    def __init__(
        self,
        *,
        enabled: Union[str, bool] = 'auto',
        loss_scale: int = 0,
        loss_scale_window: int = 1000,
        initial_scale_power: int = 16,
        hysteresis: int = 2,
        min_loss_scale: int = 1,
        # NOTE(review): this default is '02' with the digit zero. Apex-style
        # opt levels are spelled 'O2' with the letter O -- confirm which one
        # the consumer of this value expects before changing it.
        fp16_opt_level: Optional[str] = '02'
    ):
        self.enabled = enabled
        self.fp16_opt_level = fp16_opt_level

        # Loss-scaling related values.
        self.loss_scale, self.loss_scale_window = loss_scale, loss_scale_window
        self.initial_scale_power = initial_scale_power
        self.hysteresis, self.min_loss_scale = hysteresis, min_loss_scale
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class DsBf16Config:
    """DeepSpeed bf16 configuration: a single on/off switch.

    Args:
        enabled: Stored as ``self.enabled``. Defaults to ``True``.
    """

    def __init__(self, *, enabled: bool = True):
        self.enabled = enabled
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
class DsConfig:
    """DeepSpeed training-mode configuration.

    Args:
        zero_config: ZeRO section. A new ``DsZero3Config()`` is created when
            the argument is omitted; an explicit ``None`` is stored as-is.
        fp16_config: fp16 section. A new ``DsFp16Config()`` is created when
            the argument is omitted; an explicit ``None`` is stored as-is.
        bf16_config: bf16 section. A new ``DsBf16Config()`` is created when
            the argument is omitted; an explicit ``None`` is stored as-is.
        gradient_clipping: Defaults to ``1.0``.
        activation_checkpointing: Optional
            ``DsActivationCheckpointingConfig``; ``None`` by default.
    """

    # Private marker meaning "argument not supplied". ``None`` cannot play
    # that role because it is itself an accepted value for these sections.
    _UNSET = object()

    def __init__(
        self,
        *,
        zero_config: Optional[DsZeROConfig] = _UNSET,
        fp16_config: Optional[DsFp16Config] = _UNSET,
        bf16_config: Optional[DsBf16Config] = _UNSET,
        gradient_clipping: Optional[float] = 1.0,
        activation_checkpointing: Optional[DsActivationCheckpointingConfig] = None
    ):
        unset = DsConfig._UNSET

        # Build the default sections per instance. Creating them in the
        # signature would evaluate them once and share the same mutable
        # objects between every DsConfig that relies on the defaults.
        self.zero_config = DsZero3Config() if zero_config is unset else zero_config
        self.fp16_config = DsFp16Config() if fp16_config is unset else fp16_config
        self.bf16_config = DsBf16Config() if bf16_config is unset else bf16_config

        self.gradient_clipping = gradient_clipping
        self.activation_checkpointing = activation_checkpointing
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class FsdpConfig:
    """
    Configuration for the fsdp training mode.

    Args:
        transformer_layer_cls (`Set[Type[nn.Module]]`, *optional*, default is None):
            The classes of the transformer layers.
        wrap_policy_num_params (`int`, *optional*, default is -1):
            The `min_num_params` argument of `size_based_auto_wrap_policy`;
            -1 means that policy is not applied.
        cpu_offload (`bool`, *optional*, default is False):
            Whether to use cpu offload.
        offload_params (`bool`, default is False):
            Whether to offload parameters; takes effect when `cpu_offload`
            is True.
    """

    def __init__(
        self,
        *,
        transformer_layer_cls: Optional[Set[Type[nn.Module]]] = None,
        wrap_policy_num_params: int = -1,
        cpu_offload: bool = False,
        offload_params: bool = False,
    ):
        # Wrapping-related options.
        self.transformer_layer_cls, self.wrap_policy_num_params = (
            transformer_layer_cls,
            wrap_policy_num_params,
        )
        # Offload-related options.
        self.cpu_offload, self.offload_params = cpu_offload, offload_params
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
class DataLoaderConfig:
    """
    Data loader configuration.

    Args:
        data_loader_pin_memory (`bool`, *optional*, default is False):
            data_loader pin_memory config
        data_loader_num_workers (`int`, *optional*, default is 0):
            data_loader num_workers config
        data_loader_shuffle (`bool`, *optional*, default is False):
            Whether the data should be shuffled.
        data_loader_drop_last (`bool`, *optional*, default is True):
            Whether to drop the last batch when it is smaller than
            batch_size.
    """

    def __init__(
        self,
        *,
        data_loader_pin_memory: bool = False,
        data_loader_num_workers: int = 0,
        data_loader_shuffle: bool = False,
        data_loader_drop_last: bool = True,
    ):
        self.data_loader_num_workers = data_loader_num_workers
        self.data_loader_pin_memory = data_loader_pin_memory
        self.data_loader_drop_last = data_loader_drop_last
        self.data_loader_shuffle = data_loader_shuffle
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
class LrConfig:
    """Learning-rate related settings.

    Every argument is keyword-only and is stored unchanged on the instance
    under the same name. How the scheduler interprets ``period``,
    ``period_mul`` and ``warmup_iters`` is not visible here -- see the
    scheduler implementation.

    Args:
        enable_lr_scheduler: Defaults to ``False``.
        initial_lr: Defaults to ``None``.
        weight_decay: Defaults to ``0.1``.
        max_lr: Defaults to ``None``.
        min_lr: Defaults to ``None``.
        period: Defaults to ``None``.
        period_mul: Defaults to ``None``.
        warmup_iters: Defaults to ``None``.
    """

    def __init__(
        self,
        *,
        enable_lr_scheduler: bool = False,
        initial_lr: Optional[float] = None,
        weight_decay: float = 0.1,
        max_lr: Optional[float] = None,
        min_lr: Optional[float] = None,
        period: Optional[int] = None,
        period_mul: Optional[int] = None,
        warmup_iters: Optional[int] = None
    ):
        self.enable_lr_scheduler = enable_lr_scheduler
        self.weight_decay = weight_decay

        # Learning-rate values.
        self.initial_lr, self.max_lr, self.min_lr = initial_lr, max_lr, min_lr

        # Iteration-count values.
        self.period, self.period_mul = period, period_mul
        self.warmup_iters = warmup_iters
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
class LossConfig:
    """Loss related settings.

    Every argument is keyword-only and is stored unchanged on the instance
    under the same name.

    Args:
        critical_tokens: Optional list of token ids; ``None`` by default.
        critical_alpha: Defaults to ``1.0``.
        aux_loss_coef: Defaults to ``1.0``.
    """

    def __init__(
        self,
        *,
        critical_tokens: Optional[List[int]] = None,
        critical_alpha: float = 1.0,
        aux_loss_coef: Optional[float] = 1.0
    ):
        super().__init__()
        self.critical_tokens, self.critical_alpha = critical_tokens, critical_alpha
        self.aux_loss_coef = aux_loss_coef
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
class DPOConfig:
    """DPO training settings.

    Arguments may be passed positionally or by keyword and are stored
    unchanged on the instance under the same names.

    Args:
        loss_beta: Required.
        loss_label_smoothing: Defaults to ``0.0``.
        loss_ipo: Defaults to ``False``.
        nll_loss_coef: Defaults to ``None``.
    """

    def __init__(
        self,
        loss_beta: float,
        loss_label_smoothing: float = 0.0,
        loss_ipo: bool = False,
        nll_loss_coef: Optional[float] = None
    ):
        super().__init__()
        self.loss_beta, self.loss_label_smoothing = loss_beta, loss_label_smoothing
        self.loss_ipo, self.nll_loss_coef = loss_ipo, nll_loss_coef
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
class GRPOConfig:
    """GRPO training settings.

    Arguments may be passed positionally or by keyword and are stored
    unchanged on the instance under the same names. The ``gen_*`` values
    presumably control completion sampling -- confirm in the GRPO trainer.

    Args:
        grpo_steps: Defaults to ``1``.
        clip_eps: Defaults to ``0.2``.
        kl_weight: Defaults to ``0.01``.
        group_size: Defaults to ``12``.
        gen_max_new_tokens: Defaults to ``None``.
        gen_temperature: Defaults to ``None``.
        gen_k: Defaults to ``None``.
        gen_p: Defaults to ``None``.
        gen_suppress_tokens: Optional list of token ids; ``None`` by default.
    """

    def __init__(
        self,
        grpo_steps: int = 1,
        clip_eps: float = 0.2,
        kl_weight: float = 0.01,
        group_size: int = 12,
        gen_max_new_tokens: Optional[int] = None,
        gen_temperature: Optional[float] = None,
        gen_k: Optional[int] = None,
        gen_p: Optional[float] = None,
        gen_suppress_tokens: Optional[list[int]] = None,
    ):
        # Optimisation-side values.
        self.grpo_steps, self.group_size = grpo_steps, group_size
        self.clip_eps, self.kl_weight = clip_eps, kl_weight

        # Generation-side values.
        self.gen_max_new_tokens = gen_max_new_tokens
        self.gen_temperature = gen_temperature
        self.gen_k, self.gen_p = gen_k, gen_p
        self.gen_suppress_tokens = gen_suppress_tokens
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
class KDConfig:
    """
    Configuration for the knowledge-distillation mode.

    Args:
        teacher_logits_provider (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            Provider of the teacher model's logits used for distillation.
        kd_coef (`float`, *optional*, default is 0.4):
            Share of the distillation loss:
            loss = kd_coef * kd_loss + (1 - kd_coef) * lm_loss
    """

    def __init__(
        self,
        *,
        teacher_logits_provider: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        kd_coef: float = 0.4
    ):
        # The provider is only stored here; it is not called by this class.
        self.kd_coef = kd_coef
        self.teacher_logits_provider = teacher_logits_provider
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
class EvalConfig:
    """Generation settings used for evaluation.

    Arguments may be passed positionally or by keyword and are stored
    unchanged on the instance under the same names.

    Args:
        max_new_tokens: Defaults to ``512``.
        temperature: Defaults to ``1.0``.
        top_p: Defaults to ``0.95``.
        top_k: Defaults to ``None``.
    """

    def __init__(
        self,
        max_new_tokens: int = 512,
        temperature: float = 1.0,
        top_p: float = 0.95,
        # top-k is a count, so it is annotated as int (matching
        # ``GRPOConfig.gen_k``) rather than float.
        top_k: Optional[int] = None
    ):
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
class TrainConfig:
    """
    Training parameter configuration.

    Args:
        n_epochs (`int`):
            Number of training epochs.
        batch_size (`int`):
            Size of each batch.
        model_config (`ModelConfig` or `VLMConfig`):
            Configuration of the model.
        file_dataset (`FileDataset`):
            Dataset of training files.
        mask_prompt (`bool`, default is True):
            Whether to mask the tokens of the prompt part.
        gradient_accumulation_steps (`int`, *Optional*, default is 0):
            Gradient accumulation steps; 0 disables gradient accumulation.
            This setting has no effect for grpo training!
        eval_batch_interval (`int`, default is 100):
            Run model eval every this many batches.
        loss_config (`LossConfig`):
            Loss configuration; a new `LossConfig()` when omitted.
        dpo_config (`DPOConfig`, *Optional*, default is None):
            DPO configuration.
        grpo_config (`GRPOConfig`, *Optional*, default is None):
            GRPO configuration.
        lr_config (`LrConfig`):
            lr configuration; a new `LrConfig()` when omitted.
        ds_config (`DsConfig`):
            DeepSpeed configuration; a new `DsConfig()` when omitted.
        fsdp_config: (`FsdpConfig`):
            Configuration for the fsdp training mode; a new `FsdpConfig()`
            when omitted.
        data_loader_config: (`DataLoaderConfig`):
            Data loader configuration; a new `DataLoaderConfig()` when
            omitted.
        kd_config: (`KDConfig`, *Optional*, default is None):
            Knowledge-distillation configuration; knowledge distillation is
            not used when it is None.
        pixel_values_provider: (`Callable`, *Optional*, default is None):
            Provides the pixel_values for an image_tag when training a vlm.
            NOTE(review): the signature annotates the argument as
            `list[int]` while the original documentation said `list[str]`
            -- confirm which one the callers pass.
        init_state_dict (`Mapping[str, Any]`, *Optional*, default is None):
            Stored as-is; presumably an initial model state dict -- confirm
            in the trainer.
        eval_config (`EvalConfig`):
            Eval generation settings; a new `EvalConfig()` when omitted.
    """

    # Private marker meaning "argument not supplied". A sentinel is used
    # instead of ``None`` so that an explicitly passed ``None`` is still
    # stored unchanged, exactly as before.
    _UNSET = object()

    def __init__(
        self,
        n_epochs: int,
        batch_size: int,
        *,
        model_config: Union[ModelConfig, VLMConfig],
        file_dataset: FileDataset,
        mask_prompt: bool = True,
        gradient_accumulation_steps: int = 0,
        eval_batch_interval: int = 100,
        loss_config: LossConfig = _UNSET,
        dpo_config: Optional[DPOConfig] = None,
        grpo_config: Optional[GRPOConfig] = None,
        lr_config: LrConfig = _UNSET,
        ds_config: DsConfig = _UNSET,
        fsdp_config: FsdpConfig = _UNSET,
        data_loader_config: DataLoaderConfig = _UNSET,
        kd_config: Optional[KDConfig] = None,
        pixel_values_provider: Optional[Callable[[list[int]], torch.Tensor]] = None,
        init_state_dict: Optional[Mapping[str, Any]] = None,
        eval_config: EvalConfig = _UNSET
    ):
        unset = TrainConfig._UNSET

        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.model_config = model_config
        self.file_dataset = file_dataset
        self.mask_prompt = mask_prompt
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.eval_batch_interval = eval_batch_interval

        # The default sub-configs are created per instance. Building them in
        # the signature would evaluate them once at definition time and share
        # the same mutable objects between every TrainConfig.
        self.loss_config = LossConfig() if loss_config is unset else loss_config
        self.dpo_config = dpo_config
        self.grpo_config = grpo_config
        self.lr_config = LrConfig() if lr_config is unset else lr_config
        self.ds_config = DsConfig() if ds_config is unset else ds_config
        self.fsdp_config = FsdpConfig() if fsdp_config is unset else fsdp_config
        self.data_loader_config = (
            DataLoaderConfig() if data_loader_config is unset else data_loader_config
        )
        self.kd_config = kd_config
        self.pixel_values_provider = pixel_values_provider
        self.init_state_dict = init_state_dict
        self.eval_config = EvalConfig() if eval_config is unset else eval_config
|
|
444
|
+
|
|
445
|
+
|