project-llm-trainer 0.4.10__py3-none-any.whl → 0.4.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic. Click here for more details.
- llm_trainer/dpo_trainer.py +7 -3
- llm_trainer/train_configs.py +131 -332
- {project_llm_trainer-0.4.10.dist-info → project_llm_trainer-0.4.11.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.4.10.dist-info → project_llm_trainer-0.4.11.dist-info}/RECORD +13 -13
- {project_llm_trainer-0.4.10.data → project_llm_trainer-0.4.11.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.4.10.data → project_llm_trainer-0.4.11.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.4.10.data → project_llm_trainer-0.4.11.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.4.10.data → project_llm_trainer-0.4.11.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.4.10.data → project_llm_trainer-0.4.11.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.4.10.data → project_llm_trainer-0.4.11.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.4.10.data → project_llm_trainer-0.4.11.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.4.10.dist-info → project_llm_trainer-0.4.11.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.4.10.dist-info → project_llm_trainer-0.4.11.dist-info}/top_level.txt +0 -0
llm_trainer/dpo_trainer.py
CHANGED
|
@@ -250,13 +250,17 @@ class DPOTrainer(Trainer):
|
|
|
250
250
|
if gradient_accumulation_steps > 1:
|
|
251
251
|
loss = loss / gradient_accumulation_steps
|
|
252
252
|
|
|
253
|
-
loss_accumulation += loss.detach()
|
|
253
|
+
loss_accumulation += loss.detach().item()
|
|
254
254
|
self._backward_loss(loss)
|
|
255
255
|
|
|
256
256
|
if need_update_grad:
|
|
257
|
+
loss_tensor = torch.tensor(loss_accumulation, device=TrainerTools().parallel.device)
|
|
258
|
+
|
|
257
259
|
# todo check all_reduce??
|
|
258
260
|
if TrainerTools().parallel.parallel_train:
|
|
259
|
-
dist.all_reduce(
|
|
261
|
+
dist.all_reduce(loss_tensor, dist.ReduceOp.AVG)
|
|
262
|
+
|
|
263
|
+
final_log_loss = loss_tensor.item()
|
|
260
264
|
|
|
261
265
|
# ds模式已经集成gradient_clipping
|
|
262
266
|
if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
|
|
@@ -270,7 +274,7 @@ class DPOTrainer(Trainer):
|
|
|
270
274
|
epoch_tag=f'epoch: {epoch}',
|
|
271
275
|
file_tag=f'file: {file_idx + 1}/{file_count}',
|
|
272
276
|
batch_tag=f'batch: {batch}/{batch_count_per_file}',
|
|
273
|
-
loss=
|
|
277
|
+
loss=final_log_loss
|
|
274
278
|
)
|
|
275
279
|
# reset to default
|
|
276
280
|
loss_accumulation = 0.0
|
llm_trainer/train_configs.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from typing import Optional, Union, Set, Type, Callable, List, Mapping, Any
|
|
2
|
+
from dataclasses import dataclass, field
|
|
2
3
|
|
|
3
4
|
import torch
|
|
4
5
|
from torch import nn
|
|
@@ -6,202 +7,84 @@ from llm_model import ModelConfig, VLMConfig
|
|
|
6
7
|
from .tools import FileDataset
|
|
7
8
|
|
|
8
9
|
|
|
10
|
+
@dataclass(kw_only=True)
|
|
9
11
|
class DsOffloadConfig:
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
*,
|
|
13
|
-
device: str = 'cpu',
|
|
14
|
-
pin_memory: bool = True
|
|
15
|
-
):
|
|
16
|
-
self.device = device
|
|
17
|
-
self.pin_memory = pin_memory
|
|
12
|
+
device: str = 'cpu'
|
|
13
|
+
pin_memory: bool = True
|
|
18
14
|
|
|
19
15
|
|
|
16
|
+
@dataclass(kw_only=True)
|
|
20
17
|
class DsActivationCheckpointingConfig:
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
number_checkpoints: Optional[int] = None,
|
|
28
|
-
synchronize_checkpoint_boundary: bool = False,
|
|
29
|
-
profile: bool = False
|
|
30
|
-
):
|
|
31
|
-
self.partition_activations =partition_activations
|
|
32
|
-
self.cpu_checkpointing = cpu_checkpointing
|
|
33
|
-
self.contiguous_memory_optimization = contiguous_memory_optimization
|
|
34
|
-
self.number_checkpoints = number_checkpoints
|
|
35
|
-
self.synchronize_checkpoint_boundary = synchronize_checkpoint_boundary
|
|
36
|
-
self.profile = profile
|
|
18
|
+
partition_activations: bool = True
|
|
19
|
+
cpu_checkpointing: bool = False
|
|
20
|
+
contiguous_memory_optimization: bool = True
|
|
21
|
+
number_checkpoints: Optional[int] = None
|
|
22
|
+
synchronize_checkpoint_boundary: bool = False
|
|
23
|
+
profile: bool = False
|
|
37
24
|
|
|
38
25
|
|
|
26
|
+
@dataclass(kw_only=True)
|
|
39
27
|
class DsZeROConfig:
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
reduce_scatter: Optional[bool] = True,
|
|
48
|
-
reduce_bucket_size: Optional[Union[str, int]] = 5e8,
|
|
49
|
-
contiguous_gradients: Optional[bool] = True
|
|
50
|
-
):
|
|
51
|
-
self.stage = stage
|
|
52
|
-
self.allgather_partitions = allgather_partitions
|
|
53
|
-
self.allgather_bucket_size = allgather_bucket_size
|
|
54
|
-
self.overlap_comm = overlap_comm
|
|
55
|
-
self.reduce_scatter = reduce_scatter
|
|
56
|
-
self.reduce_bucket_size = reduce_bucket_size
|
|
57
|
-
self.contiguous_gradients = contiguous_gradients
|
|
28
|
+
stage: int
|
|
29
|
+
allgather_partitions: Optional[bool] = True
|
|
30
|
+
allgather_bucket_size: Optional[int] = 5e8
|
|
31
|
+
overlap_comm: Optional[bool] = True
|
|
32
|
+
reduce_scatter: Optional[bool] = True
|
|
33
|
+
reduce_bucket_size: Optional[Union[str, int]] = 5e8
|
|
34
|
+
contiguous_gradients: Optional[bool] = True
|
|
58
35
|
|
|
59
36
|
|
|
37
|
+
@dataclass(kw_only=True)
|
|
60
38
|
class DsZero1Config(DsZeROConfig):
|
|
61
|
-
|
|
62
|
-
self,
|
|
63
|
-
*,
|
|
64
|
-
allgather_partitions: Optional[bool] = True,
|
|
65
|
-
allgather_bucket_size: Optional[int] = 5e8,
|
|
66
|
-
overlap_comm: Optional[bool] = True,
|
|
67
|
-
reduce_scatter: Optional[bool] = True,
|
|
68
|
-
reduce_bucket_size: Optional[Union[str, int]] = 5e8,
|
|
69
|
-
contiguous_gradients: Optional[bool] = True
|
|
70
|
-
):
|
|
71
|
-
super().__init__(
|
|
72
|
-
stage=1,
|
|
73
|
-
allgather_partitions=allgather_partitions,
|
|
74
|
-
allgather_bucket_size=allgather_bucket_size,
|
|
75
|
-
overlap_comm=overlap_comm,
|
|
76
|
-
reduce_scatter=reduce_scatter,
|
|
77
|
-
reduce_bucket_size=reduce_bucket_size,
|
|
78
|
-
contiguous_gradients=contiguous_gradients
|
|
79
|
-
)
|
|
39
|
+
stage: int = field(default=1, init=False)
|
|
80
40
|
|
|
81
41
|
|
|
42
|
+
@dataclass(kw_only=True)
|
|
82
43
|
class DsZero2Config(DsZeROConfig):
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
allgather_partitions: Optional[bool] = True,
|
|
87
|
-
allgather_bucket_size: Optional[int] = 5e8,
|
|
88
|
-
overlap_comm: Optional[bool] = True,
|
|
89
|
-
reduce_scatter: Optional[bool] = True,
|
|
90
|
-
reduce_bucket_size: Optional[Union[str, int]] = 5e8,
|
|
91
|
-
contiguous_gradients: Optional[bool] = True,
|
|
92
|
-
offload_optimizer: Optional[DsOffloadConfig] = None,
|
|
93
|
-
offload_param: Optional[DsOffloadConfig] = None,
|
|
94
|
-
|
|
95
|
-
):
|
|
96
|
-
super().__init__(
|
|
97
|
-
stage=2,
|
|
98
|
-
allgather_partitions=allgather_partitions,
|
|
99
|
-
allgather_bucket_size=allgather_bucket_size,
|
|
100
|
-
overlap_comm=overlap_comm,
|
|
101
|
-
reduce_scatter=reduce_scatter,
|
|
102
|
-
reduce_bucket_size=reduce_bucket_size,
|
|
103
|
-
contiguous_gradients=contiguous_gradients
|
|
104
|
-
)
|
|
105
|
-
|
|
106
|
-
self.offload_optimizer = offload_optimizer
|
|
107
|
-
self.offload_param = offload_param
|
|
44
|
+
stage: int = field(default=2, init=False)
|
|
45
|
+
offload_optimizer: Optional[DsOffloadConfig] = None
|
|
46
|
+
offload_param: Optional[DsOffloadConfig] = None
|
|
108
47
|
|
|
109
48
|
|
|
49
|
+
@dataclass(kw_only=True)
|
|
110
50
|
class DsZero3Config(DsZeROConfig):
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
stage3_max_live_parameters: Optional[int] = 1e9,
|
|
124
|
-
stage3_max_reuse_distance: Optional[int] = 1e9,
|
|
125
|
-
stage3_gather_16bit_weights_on_model_save: Optional[bool] = True,
|
|
126
|
-
offload_optimizer: Optional[DsOffloadConfig] = None,
|
|
127
|
-
offload_param: Optional[DsOffloadConfig] = None,
|
|
128
|
-
|
|
129
|
-
):
|
|
130
|
-
super().__init__(
|
|
131
|
-
stage=3,
|
|
132
|
-
allgather_partitions=allgather_partitions,
|
|
133
|
-
allgather_bucket_size=allgather_bucket_size,
|
|
134
|
-
overlap_comm=overlap_comm,
|
|
135
|
-
reduce_scatter=reduce_scatter,
|
|
136
|
-
reduce_bucket_size=reduce_bucket_size,
|
|
137
|
-
contiguous_gradients=contiguous_gradients
|
|
138
|
-
)
|
|
139
|
-
|
|
140
|
-
self.sub_group_size = sub_group_size
|
|
141
|
-
self.stage3_prefetch_bucket_size = stage3_prefetch_bucket_size
|
|
142
|
-
self.stage3_param_persistence_threshold = stage3_param_persistence_threshold
|
|
143
|
-
self.stage3_max_live_parameters = stage3_max_live_parameters
|
|
144
|
-
self.stage3_max_reuse_distance = stage3_max_reuse_distance
|
|
145
|
-
self.stage3_gather_16bit_weights_on_model_save = stage3_gather_16bit_weights_on_model_save
|
|
146
|
-
|
|
147
|
-
self.offload_optimizer = offload_optimizer
|
|
148
|
-
self.offload_param = offload_param
|
|
149
|
-
|
|
150
|
-
|
|
51
|
+
stage: int = field(default=3, init=False)
|
|
52
|
+
sub_group_size: Optional[int] = 1e9
|
|
53
|
+
stage3_prefetch_bucket_size: Optional[Union[str, int]] = 'auto'
|
|
54
|
+
stage3_param_persistence_threshold: Optional[Union[str, int]] = 'auto'
|
|
55
|
+
stage3_max_live_parameters: Optional[int] = 1e9
|
|
56
|
+
stage3_max_reuse_distance: Optional[int] = 1e9
|
|
57
|
+
stage3_gather_16bit_weights_on_model_save: Optional[bool] = True
|
|
58
|
+
offload_optimizer: Optional[DsOffloadConfig] = None
|
|
59
|
+
offload_param: Optional[DsOffloadConfig] = None
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@dataclass(kw_only=True)
|
|
151
63
|
class DsFp16Config:
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
enabled: Union[str, bool] = 'auto',
|
|
160
|
-
loss_scale: int = 0,
|
|
161
|
-
loss_scale_window: int = 1000,
|
|
162
|
-
initial_scale_power: int = 16,
|
|
163
|
-
hysteresis: int = 2,
|
|
164
|
-
min_loss_scale: int = 1,
|
|
165
|
-
fp16_opt_level: Optional[str] = '02'
|
|
166
|
-
):
|
|
167
|
-
self.enabled = enabled
|
|
168
|
-
self.loss_scale = loss_scale
|
|
169
|
-
self.loss_scale_window = loss_scale_window
|
|
170
|
-
self.initial_scale_power = initial_scale_power
|
|
171
|
-
self.hysteresis = hysteresis
|
|
172
|
-
self.min_loss_scale = min_loss_scale
|
|
173
|
-
self.fp16_opt_level = fp16_opt_level
|
|
64
|
+
enabled: Union[str, bool] = 'auto'
|
|
65
|
+
loss_scale: int = 0
|
|
66
|
+
loss_scale_window: int = 1000
|
|
67
|
+
initial_scale_power: int = 16
|
|
68
|
+
hysteresis: int = 2
|
|
69
|
+
min_loss_scale: int = 1
|
|
70
|
+
fp16_opt_level: Optional[str] = '02'
|
|
174
71
|
|
|
175
72
|
|
|
73
|
+
@dataclass(kw_only=True)
|
|
176
74
|
class DsBf16Config:
|
|
177
|
-
|
|
178
|
-
self,
|
|
179
|
-
*,
|
|
180
|
-
enabled: bool = True
|
|
181
|
-
):
|
|
182
|
-
self.enabled = enabled
|
|
75
|
+
enabled: bool = True
|
|
183
76
|
|
|
184
77
|
|
|
78
|
+
@dataclass(kw_only=True)
|
|
185
79
|
class DsConfig:
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
*,
|
|
192
|
-
zero_config: Optional[DsZeROConfig] = DsZero3Config(),
|
|
193
|
-
fp16_config: Optional[DsFp16Config] = DsFp16Config(),
|
|
194
|
-
bf16_config: Optional[DsBf16Config] = DsBf16Config(),
|
|
195
|
-
gradient_clipping: Optional[float] = 1.0,
|
|
196
|
-
activation_checkpointing: Optional[DsActivationCheckpointingConfig] = None
|
|
197
|
-
):
|
|
198
|
-
self.zero_config = zero_config
|
|
199
|
-
self.fp16_config = fp16_config
|
|
200
|
-
self.bf16_config = bf16_config
|
|
201
|
-
self.gradient_clipping = gradient_clipping
|
|
202
|
-
self.activation_checkpointing = activation_checkpointing
|
|
80
|
+
zero_config: Optional[DsZeROConfig] = DsZero3Config()
|
|
81
|
+
fp16_config: Optional[DsFp16Config] = DsFp16Config()
|
|
82
|
+
bf16_config: Optional[DsBf16Config] = DsBf16Config()
|
|
83
|
+
gradient_clipping: Optional[float] = 1.0
|
|
84
|
+
activation_checkpointing: Optional[DsActivationCheckpointingConfig] = None
|
|
203
85
|
|
|
204
86
|
|
|
87
|
+
@dataclass(kw_only=True)
|
|
205
88
|
class FsdpConfig:
|
|
206
89
|
"""
|
|
207
90
|
fsdp训练模式配置项
|
|
@@ -214,22 +97,14 @@ class FsdpConfig:
|
|
|
214
97
|
是否使用cpu卸载
|
|
215
98
|
offload_params (`bool`, default is False):
|
|
216
99
|
是否卸载参数,在cpu_offload为True时生效
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
transformer_layer_cls: Optional[Set[Type[nn.Module]]] = None,
|
|
223
|
-
wrap_policy_num_params: int = -1,
|
|
224
|
-
cpu_offload: bool = False,
|
|
225
|
-
offload_params: bool = False,
|
|
226
|
-
):
|
|
227
|
-
self.transformer_layer_cls = transformer_layer_cls
|
|
228
|
-
self.wrap_policy_num_params = wrap_policy_num_params
|
|
229
|
-
self.cpu_offload = cpu_offload
|
|
230
|
-
self.offload_params = offload_params
|
|
100
|
+
"""
|
|
101
|
+
transformer_layer_cls: Optional[Set[Type[nn.Module]]] = None
|
|
102
|
+
wrap_policy_num_params: int = -1
|
|
103
|
+
cpu_offload: bool = False
|
|
104
|
+
offload_params: bool = False
|
|
231
105
|
|
|
232
106
|
|
|
107
|
+
@dataclass(kw_only=True)
|
|
233
108
|
class DataLoaderConfig:
|
|
234
109
|
"""
|
|
235
110
|
data loader配置项
|
|
@@ -242,98 +117,54 @@ class DataLoaderConfig:
|
|
|
242
117
|
是否需要shuffle数据
|
|
243
118
|
data_loader_drop_last (`bool`, default is False):
|
|
244
119
|
最后一个batch不满足batch_size时,是否丢弃
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
data_loader_pin_memory: bool = False,
|
|
251
|
-
data_loader_num_workers: int = 0,
|
|
252
|
-
data_loader_shuffle: bool = False,
|
|
253
|
-
data_loader_drop_last: bool = True,
|
|
254
|
-
):
|
|
255
|
-
self.data_loader_pin_memory = data_loader_pin_memory
|
|
256
|
-
self.data_loader_num_workers = data_loader_num_workers
|
|
257
|
-
self.data_loader_shuffle = data_loader_shuffle
|
|
258
|
-
self.data_loader_drop_last = data_loader_drop_last
|
|
120
|
+
"""
|
|
121
|
+
data_loader_pin_memory: bool = False
|
|
122
|
+
data_loader_num_workers: int = 0
|
|
123
|
+
data_loader_shuffle: bool = False
|
|
124
|
+
data_loader_drop_last: bool = True
|
|
259
125
|
|
|
260
126
|
|
|
127
|
+
@dataclass(kw_only=True)
|
|
261
128
|
class LrConfig:
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
period: Optional[int] = None,
|
|
271
|
-
period_mul: Optional[int] = None,
|
|
272
|
-
warmup_iters: Optional[int] = None
|
|
273
|
-
):
|
|
274
|
-
self.enable_lr_scheduler = enable_lr_scheduler
|
|
275
|
-
self.initial_lr = initial_lr
|
|
276
|
-
self.weight_decay = weight_decay
|
|
277
|
-
self.max_lr = max_lr
|
|
278
|
-
self.min_lr = min_lr
|
|
279
|
-
self.period = period
|
|
280
|
-
self.period_mul = period_mul
|
|
281
|
-
self.warmup_iters = warmup_iters
|
|
129
|
+
enable_lr_scheduler: bool = False
|
|
130
|
+
initial_lr: Optional[float] = None
|
|
131
|
+
weight_decay: float = 0.1
|
|
132
|
+
max_lr: Optional[float] = None
|
|
133
|
+
min_lr: Optional[float] = None
|
|
134
|
+
period: Optional[int] = None
|
|
135
|
+
period_mul: Optional[int] = None
|
|
136
|
+
warmup_iters: Optional[int] = None
|
|
282
137
|
|
|
283
138
|
|
|
139
|
+
@dataclass(kw_only=True)
|
|
284
140
|
class LossConfig:
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
critical_tokens: Optional[List[int]] = None,
|
|
289
|
-
critical_alpha: float = 1.0,
|
|
290
|
-
aux_loss_coef: Optional[float] = 1.0
|
|
291
|
-
):
|
|
292
|
-
super().__init__()
|
|
293
|
-
self.critical_tokens = critical_tokens
|
|
294
|
-
self.critical_alpha = critical_alpha
|
|
295
|
-
self.aux_loss_coef = aux_loss_coef
|
|
141
|
+
critical_tokens: Optional[List[int]] = None
|
|
142
|
+
critical_alpha: float = 1.0
|
|
143
|
+
aux_loss_coef: Optional[float] = 1.0
|
|
296
144
|
|
|
297
145
|
|
|
146
|
+
@dataclass(kw_only=True)
|
|
298
147
|
class DPOConfig:
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
loss_ipo: bool = False,
|
|
304
|
-
nll_loss_coef: Optional[float] = None
|
|
305
|
-
):
|
|
306
|
-
super().__init__()
|
|
307
|
-
self.loss_beta = loss_beta
|
|
308
|
-
self.loss_label_smoothing = loss_label_smoothing
|
|
309
|
-
self.loss_ipo = loss_ipo
|
|
310
|
-
self.nll_loss_coef = nll_loss_coef
|
|
148
|
+
loss_beta: float
|
|
149
|
+
loss_label_smoothing: float = 0.0
|
|
150
|
+
loss_ipo: bool = False
|
|
151
|
+
nll_loss_coef: Optional[float] = None
|
|
311
152
|
|
|
312
153
|
|
|
154
|
+
@dataclass(kw_only=True)
|
|
313
155
|
class GRPOConfig:
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
self.grpo_steps = grpo_steps
|
|
327
|
-
self.clip_eps = clip_eps
|
|
328
|
-
self.kl_weight = kl_weight
|
|
329
|
-
self.group_size = group_size
|
|
330
|
-
self.gen_max_new_tokens = gen_max_new_tokens
|
|
331
|
-
self.gen_temperature = gen_temperature
|
|
332
|
-
self.gen_k = gen_k
|
|
333
|
-
self.gen_p = gen_p
|
|
334
|
-
self.gen_suppress_tokens = gen_suppress_tokens
|
|
335
|
-
|
|
336
|
-
|
|
156
|
+
grpo_steps: int = 1
|
|
157
|
+
clip_eps: float = 0.2
|
|
158
|
+
kl_weight: float = 0.01
|
|
159
|
+
group_size: int = 12
|
|
160
|
+
gen_max_new_tokens: Optional[int] = None
|
|
161
|
+
gen_temperature: Optional[float] = None
|
|
162
|
+
gen_k: Optional[int] = None
|
|
163
|
+
gen_p: Optional[float] = None
|
|
164
|
+
gen_suppress_tokens: Optional[list[int]] = None
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
@dataclass(kw_only=True)
|
|
337
168
|
class KDConfig:
|
|
338
169
|
"""
|
|
339
170
|
知识蒸馏模式配置项
|
|
@@ -343,32 +174,20 @@ class KDConfig:
|
|
|
343
174
|
知识蒸馏教师模型logits的提供者
|
|
344
175
|
kd_coef (`float`, *optional*, default is 0.4):
|
|
345
176
|
蒸馏loss的占比,loss = kd_coef * kd_loss + (1 - kd_coef) * lm_loss
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
self,
|
|
350
|
-
*,
|
|
351
|
-
teacher_logits_provider: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
|
352
|
-
kd_coef: float = 0.4
|
|
353
|
-
):
|
|
354
|
-
self.teacher_logits_provider = teacher_logits_provider
|
|
355
|
-
self.kd_coef = kd_coef
|
|
177
|
+
"""
|
|
178
|
+
teacher_logits_provider: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
|
|
179
|
+
kd_coef: float = 0.4
|
|
356
180
|
|
|
357
181
|
|
|
182
|
+
@dataclass(kw_only=True)
|
|
358
183
|
class EvalConfig:
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
top_p: float = 0.95,
|
|
364
|
-
top_k: Optional[float] = None
|
|
365
|
-
):
|
|
366
|
-
self.max_new_tokens = max_new_tokens
|
|
367
|
-
self.temperature = temperature
|
|
368
|
-
self.top_p = top_p
|
|
369
|
-
self.top_k = top_k
|
|
184
|
+
max_new_tokens: int = 512
|
|
185
|
+
temperature: float = 1.0
|
|
186
|
+
top_p: float = 0.95
|
|
187
|
+
top_k: Optional[float] = None
|
|
370
188
|
|
|
371
189
|
|
|
190
|
+
@dataclass(kw_only=True)
|
|
372
191
|
class TrainConfig:
|
|
373
192
|
"""
|
|
374
193
|
训练参数配置项
|
|
@@ -399,51 +218,31 @@ class TrainConfig:
|
|
|
399
218
|
知识蒸馏配置项,为None时不使用知识蒸馏
|
|
400
219
|
pixel_values_provider: (`Callable[[list[str]], torch.Tensor]`, *Optional*, default is None):
|
|
401
220
|
训练vlm时根据image_tag提供pixel_values信息
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
init_state_dict: Optional[Mapping[str, Any]] = None,
|
|
425
|
-
eval_config: EvalConfig = EvalConfig(),
|
|
426
|
-
freeze_llm_model: bool = False
|
|
427
|
-
):
|
|
428
|
-
self.n_epochs = n_epochs
|
|
429
|
-
self.batch_size = batch_size
|
|
430
|
-
self.model_config = model_config
|
|
431
|
-
self.file_dataset = file_dataset
|
|
432
|
-
self.image_tags_file_dataset = image_tags_file_dataset
|
|
433
|
-
self.mask_prompt = mask_prompt
|
|
434
|
-
self.gradient_accumulation_steps = gradient_accumulation_steps
|
|
435
|
-
self.eval_batch_interval = eval_batch_interval
|
|
436
|
-
self.loss_config = loss_config
|
|
437
|
-
self.dpo_config = dpo_config
|
|
438
|
-
self.grpo_config = grpo_config
|
|
439
|
-
self.lr_config = lr_config
|
|
440
|
-
self.ds_config = ds_config
|
|
441
|
-
self.fsdp_config = fsdp_config
|
|
442
|
-
self.data_loader_config = data_loader_config
|
|
443
|
-
self.kd_config = kd_config
|
|
444
|
-
self.pixel_values_provider = pixel_values_provider
|
|
445
|
-
self.init_state_dict = init_state_dict
|
|
446
|
-
self.eval_config = eval_config
|
|
447
|
-
self.freeze_llm_model = freeze_llm_model
|
|
221
|
+
"""
|
|
222
|
+
n_epochs: int
|
|
223
|
+
batch_size: int
|
|
224
|
+
model_config: Union[ModelConfig, VLMConfig]
|
|
225
|
+
|
|
226
|
+
file_dataset: FileDataset
|
|
227
|
+
data_loader_config: DataLoaderConfig = DataLoaderConfig()
|
|
228
|
+
image_tags_file_dataset: Optional[FileDataset] = None
|
|
229
|
+
|
|
230
|
+
loss_config: LossConfig = LossConfig()
|
|
231
|
+
lr_config: LrConfig = LrConfig()
|
|
232
|
+
|
|
233
|
+
ds_config: DsConfig = DsConfig()
|
|
234
|
+
fsdp_config: FsdpConfig = FsdpConfig()
|
|
235
|
+
|
|
236
|
+
kd_config: Optional[KDConfig] = None
|
|
237
|
+
dpo_config: Optional[DPOConfig] = None
|
|
238
|
+
grpo_config: Optional[GRPOConfig] = None
|
|
239
|
+
|
|
240
|
+
mask_prompt: bool = True
|
|
241
|
+
gradient_accumulation_steps: int = 0
|
|
242
|
+
eval_batch_interval: int = 100
|
|
448
243
|
|
|
244
|
+
eval_config: EvalConfig = EvalConfig()
|
|
245
|
+
pixel_values_provider: Optional[Callable[[list[str]], torch.Tensor]] = None
|
|
449
246
|
|
|
247
|
+
init_state_dict: Optional[Mapping[str, Any]] = None
|
|
248
|
+
freeze_llm_model: bool = False
|
|
@@ -2,7 +2,7 @@ llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
|
|
|
2
2
|
llm_trainer/checkpoint.py,sha256=yZcExxneN2yzvWxRiK-pstMWs35LV7GiOfqcLq-S6vc,5745
|
|
3
3
|
llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
|
|
4
4
|
llm_trainer/dcp.py,sha256=PkD97DyrOtoTKn4FJsfL3VqAy4dxufgjdzJEz8-Cnoc,3635
|
|
5
|
-
llm_trainer/dpo_trainer.py,sha256=
|
|
5
|
+
llm_trainer/dpo_trainer.py,sha256=34E2b-t0GZYutaw6bESgARe9C12PUMWcY4aGZ34eAZU,13576
|
|
6
6
|
llm_trainer/ds_checkpoint.py,sha256=x_tjgJR47P8gVwV4qAnTUCGwx7eVq2Epw0vOVV7fkYo,4925
|
|
7
7
|
llm_trainer/eval.py,sha256=NDm8PbXLch7xT81xPYPRCNrcrB_Xj5GDJSCxyVwUOp4,1524
|
|
8
8
|
llm_trainer/fsdp_checkpoint.py,sha256=lqZFzHyWyfzuCq_81kQNtJd2qaiMeY1N5BCEMnrJTBw,3192
|
|
@@ -19,17 +19,17 @@ llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
|
|
|
19
19
|
llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
|
|
20
20
|
llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
|
|
21
21
|
llm_trainer/tools.py,sha256=O45-20wRmh-nyTfU-U-XtjbKAoe7boEIsUvWT_NaKx4,3041
|
|
22
|
-
llm_trainer/train_configs.py,sha256=
|
|
22
|
+
llm_trainer/train_configs.py,sha256=4sM96SOgwcn6jBGtbG5-qDZbJjiHVB6l7FWqdq7hbj0,7979
|
|
23
23
|
llm_trainer/trainer.py,sha256=pUtJVRosn54j1hn76CFAptJcAsrDo59H6p8NMkg2zt4,25521
|
|
24
24
|
llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
|
|
25
|
-
project_llm_trainer-0.4.
|
|
26
|
-
project_llm_trainer-0.4.
|
|
27
|
-
project_llm_trainer-0.4.
|
|
28
|
-
project_llm_trainer-0.4.
|
|
29
|
-
project_llm_trainer-0.4.
|
|
30
|
-
project_llm_trainer-0.4.
|
|
31
|
-
project_llm_trainer-0.4.
|
|
32
|
-
project_llm_trainer-0.4.
|
|
33
|
-
project_llm_trainer-0.4.
|
|
34
|
-
project_llm_trainer-0.4.
|
|
35
|
-
project_llm_trainer-0.4.
|
|
25
|
+
project_llm_trainer-0.4.11.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
|
|
26
|
+
project_llm_trainer-0.4.11.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
|
|
27
|
+
project_llm_trainer-0.4.11.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
|
|
28
|
+
project_llm_trainer-0.4.11.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
|
|
29
|
+
project_llm_trainer-0.4.11.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
|
|
30
|
+
project_llm_trainer-0.4.11.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
|
|
31
|
+
project_llm_trainer-0.4.11.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
|
|
32
|
+
project_llm_trainer-0.4.11.dist-info/METADATA,sha256=JEZo2-np0t_K-J6yapyAXsArpvYTmrSNGDsdy32kWas,196
|
|
33
|
+
project_llm_trainer-0.4.11.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
|
34
|
+
project_llm_trainer-0.4.11.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
|
|
35
|
+
project_llm_trainer-0.4.11.dist-info/RECORD,,
|
{project_llm_trainer-0.4.10.data → project_llm_trainer-0.4.11.data}/scripts/calc_intermediate_size
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|