project-llm-trainer 0.5.16__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic.
- llm_trainer/dpo_trainer.py +17 -16
- llm_trainer/grpo_trainer.py +17 -13
- llm_trainer/loss.py +46 -27
- llm_trainer/parallel.py +1 -1
- llm_trainer/partition_utils.py +85 -4
- llm_trainer/train_configs.py +5 -2
- llm_trainer/trainer.py +11 -5
- {project_llm_trainer-0.5.16.dist-info → project_llm_trainer-0.6.0.dist-info}/METADATA +1 -1
- project_llm_trainer-0.6.0.dist-info/RECORD +33 -0
- project_llm_trainer-0.5.16.dist-info/RECORD +0 -33
- {project_llm_trainer-0.5.16.data → project_llm_trainer-0.6.0.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.5.16.data → project_llm_trainer-0.6.0.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.5.16.data → project_llm_trainer-0.6.0.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.5.16.data → project_llm_trainer-0.6.0.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.5.16.data → project_llm_trainer-0.6.0.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.5.16.data → project_llm_trainer-0.6.0.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.5.16.data → project_llm_trainer-0.6.0.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.5.16.dist-info → project_llm_trainer-0.6.0.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.5.16.dist-info → project_llm_trainer-0.6.0.dist-info}/top_level.txt +0 -0
llm_trainer/dpo_trainer.py
CHANGED
@@ -35,28 +35,28 @@ class DPOTrainer(Trainer):
             eval_image_tags=eval_image_tags
         )

-        self.
+        self.ref_model = self._init_ref_model()

-    def
-
+    def _init_ref_model(self):
+        ref_model = self._new_model(self.train_config)

-
-            model=
+        ref_model, _ = TrainerTools().parallel.process(
+            model=ref_model,
             optimizer=None,
-            kwargs=self.
+            kwargs=self._init_ref_model_args(),
             save_instance=False
         )

-
-        for param in
+        ref_model.eval()
+        for param in ref_model.parameters():
             param.requires_grad = False

         sync_model_params(
             _from=self.train_model,
-            _to=
+            _to=ref_model
         )

-        return
+        return ref_model

     def _init_loss(self):
         criterion = DPOLoss(

@@ -203,17 +203,18 @@ class DPOTrainer(Trainer):

         with self.ctx:
             policy_outputs = self.train_model(concat_inputs, attention_mask=concat_mask)
-            with torch.inference_mode():
-                ref_outputs = self.reference_model(concat_inputs, attention_mask=concat_mask)
-
             policy_probs = self._logprobs(policy_outputs['logits'], concat_labels, concat_mask)
-
+            aux_loss = policy_outputs.get('aux_loss')
+
+            with torch.no_grad():
+                ref_outputs = self.ref_model(concat_inputs, attention_mask=concat_mask)
+                ref_probs = self._logprobs(ref_outputs['logits'], concat_labels, concat_mask)

         # calc loss
         loss = self.criterion(policy_probs, ref_probs)

-        if aux_loss_coef and
-            loss += aux_loss_coef *
+        if aux_loss_coef and aux_loss:
+            loss += aux_loss_coef * aux_loss

         if gradient_accumulation_steps > 1:
             loss = loss / gradient_accumulation_steps
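Beyond the rename from reference_model to ref_model, the second hunk switches the frozen reference forward from torch.inference_mode() to torch.no_grad(). A minimal standalone sketch (not package code) of the practical difference: values produced under no_grad can be reused as constants in a later autograd graph, while inference-mode tensors generally cannot.

    import torch

    w = torch.randn(3, requires_grad=True)

    # torch.no_grad(): the result is an ordinary tensor with requires_grad=False,
    # so it can safely appear as a constant in a differentiable expression.
    with torch.no_grad():
        ref = (w * 2.0).sum()

    loss = (w.sum() - ref) ** 2
    loss.backward()  # works

    # torch.inference_mode(): the result is an "inference tensor"; using it in an
    # op that autograd needs to save typically raises
    # "RuntimeError: Inference tensors cannot be saved for backward".
    with torch.inference_mode():
        ref_inf = (w * 2.0).sum()
    # (w * ref_inf).sum().backward()  # would commonly fail for that reason
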
llm_trainer/grpo_trainer.py
CHANGED
@@ -42,32 +42,36 @@ class GRPOTrainer(Trainer):
         )

         self.reward_func = reward_func
-        self.
+        self.ref_model = self._init_ref_model()

         # Use torch's pad_sequence by default.
         # If pad_sequence does not support the padding_side argument, set this flag to False and fall back to the reversal approach.
         self._use_origin_pad_sequence = True

-    def
-
+    def _init_ref_model(self):
+        ref_model = self._new_model(self.train_config)

-
-            model=
+        ref_model, _ = TrainerTools().parallel.process(
+            model=ref_model,
             optimizer=None,
-            kwargs=self.
+            kwargs=self._init_ref_model_args(),
             save_instance=False
         )

-
-        for param in
+        ref_model.eval()
+        for param in ref_model.parameters():
             param.requires_grad = False

-        return
+        return ref_model

     def _init_loss(self):
         criterion = GRPOLoss(
-
-
+            beta=self.train_config.grpo_config.loss_beta,
+            clip_eps=self.train_config.grpo_config.loss_clip_eps,
+            delta=self.train_config.grpo_config.loss_delta,
+            importance_sampling_level=self.train_config.grpo_config.loss_importance_sampling_level,
+            loss_type=self.train_config.grpo_config.loss_type,
+            gen_max_new_tokens=self.train_config.grpo_config.gen_max_new_tokens
         )

         return criterion, None

@@ -225,7 +229,7 @@ class GRPOTrainer(Trainer):
         old_log_probs, _ = self._compute_log_probabilities(generate_model, input_ids, attention_mask, logits_to_keep)

         # Compute ref_log_probs from the reference model, which remains static.
-        ref_log_probs, _ = self._compute_log_probabilities(self.
+        ref_log_probs, _ = self._compute_log_probabilities(self.ref_model, input_ids, attention_mask, logits_to_keep)

         repeated_prompts = [p for p in prompts for _ in range(group_size)]
         repeated_answers = [a for a in answers for _ in range(group_size)]

@@ -290,7 +294,7 @@ class GRPOTrainer(Trainer):
         for epoch in range(self.train_config.n_epochs):
             sync_model_params(
                 _from=self.train_model,
-                _to=self.
+                _to=self.ref_model,
                 mixup_alpha=self.train_config.grpo_config.mixup_alpha
             )

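The rewritten _init_loss forwards the new GRPOConfig loss_* fields straight into GRPOLoss. A hedged usage sketch of the same wiring done by hand; the literal values mirror the 0.6.0 defaults shown in train_configs.py below, and the comments describe the roles these arguments play in loss.py.

    from llm_trainer.loss import GRPOLoss

    criterion = GRPOLoss(
        beta=0.04,                          # KL-penalty weight (loss_beta); 0.0 disables the KL term
        clip_eps=0.1,                       # PPO-style ratio clipping (loss_clip_eps)
        delta=None,                         # optional upper clamp on the unclipped ratio (loss_delta)
        importance_sampling_level='token',  # 'token' (GRPO) or 'seq' (GSPO)
        loss_type='grpo',                   # 'grpo', 'bnpo' or 'dr_grpo'
        gen_max_new_tokens=None             # only required when loss_type == 'dr_grpo'
    )
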
llm_trainer/loss.py
CHANGED
@@ -2,6 +2,7 @@ from typing import List, Optional
 import torch
 from torch import nn
 import torch.nn.functional as F
+from .tools import TrainerTools


 class LMLoss(nn.Module):

@@ -115,6 +116,7 @@ class DPOLoss(nn.Module):
         )

         loss = losses.mean()
+
         # chosen_rewards = self.beta * (policy_chosen_probs - ref_chosen_probs).detach()
         # rejected_rewards = self.beta * (policy_reject_probs - ref_reject_probs).detach()

@@ -124,12 +126,21 @@ class DPOLoss(nn.Module):
 class GRPOLoss(nn.Module):
     def __init__(
         self,
+        beta: float,
         clip_eps: float,
-
+        delta: Optional[float] = None,
+        importance_sampling_level: str = 'token',
+        loss_type: str = 'grpo',
+        gen_max_new_tokens: Optional[float] = None
     ):
         super().__init__()
+
+        self.beta = beta
         self.clip_eps = clip_eps
-        self.
+        self.delta = delta
+        self.importance_sampling_level = importance_sampling_level
+        self.loss_type = loss_type
+        self.gen_max_new_tokens = gen_max_new_tokens

     def forward(
         self,

@@ -139,33 +150,41 @@ class GRPOLoss(nn.Module):
         completion_mask: torch.Tensor,
         advantages: torch.Tensor
     ) -> torch.Tensor:
-        # Compute policy ratio
-        ratio = torch.exp(log_probs - old_log_probs)

-
-
-
-
+        if self.beta != 0.0:
+            per_token_kl = torch.exp(ref_log_probs - log_probs) - (ref_log_probs - log_probs) - 1
+        else:
+            per_token_kl = None
+
+        log_ratio = log_probs - old_log_probs
+        if self.importance_sampling_level == "seq":
+            # GSPO
+            log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
+            log_importance_weights = log_importance_weights.unsqueeze(-1)
+        else:
+            # GRPO
+            log_importance_weights = log_ratio

-
-
+        coef_1 = torch.exp(log_importance_weights)
+        coef_2 = torch.clamp(coef_1, 1 - self.clip_eps, 1 + self.clip_eps)

-        #
-
-
+        # Two-sided clipping
+        if self.delta is not None:
+            coef_1 = torch.clamp(coef_1, max=self.delta)

-
+        per_token_loss1 = coef_1 * advantages
+        per_token_loss2 = coef_2 * advantages
+        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+
+        if self.beta != 0.0:
+            per_token_loss = per_token_loss + self.beta * per_token_kl

-
-
-
-
-
-
-
-
-
-        # loss = -torch.min(surr1, surr2) + self.kl_weight * kl
-        #
-        # loss = self._masked_mean(loss, mask, dim=-1).mean()
-        # return loss, kl.mean()
+        if self.loss_type == "bnpo":
+            loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
+        elif self.loss_type == "dr_grpo":
+            assert self.gen_max_new_tokens is not None
+            loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.gen_max_new_tokens)
+        else:
+            loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
+
+        return loss
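The new loss_type switch changes only how the masked per-token loss is reduced. A small standalone numerical sketch (not package code) of the three reductions for a batch of two completions of lengths 2 and 1:

    import torch

    per_token_loss = torch.tensor([[0.2, 0.4, 0.0],
                                   [0.1, 0.0, 0.0]])
    completion_mask = torch.tensor([[1., 1., 0.],
                                    [1., 0., 0.]])
    gen_max_new_tokens = 3

    masked = per_token_loss * completion_mask

    # 'grpo': mean over each sequence's tokens, then mean over sequences
    grpo = (masked.sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()   # (0.3 + 0.1) / 2 = 0.2

    # 'bnpo': single mean over all unmasked tokens in the batch
    bnpo = masked.sum() / completion_mask.sum().clamp(min=1.0)                # 0.7 / 3 ≈ 0.2333

    # 'dr_grpo': normalize by batch_size * gen_max_new_tokens (a constant)
    dr_grpo = masked.sum() / (per_token_loss.size(0) * gen_max_new_tokens)    # 0.7 / 6 ≈ 0.1167
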
llm_trainer/parallel.py
CHANGED
llm_trainer/partition_utils.py
CHANGED
@@ -4,6 +4,7 @@ import itertools
 from packaging import version
 from torch import nn
 from torch.nn.parallel import DistributedDataParallel as DDP
+import torch.distributed as dist

 from .tools import TrainerTools
 from .parallel_ds import DsParallel

@@ -45,12 +46,40 @@ def unwrap_model_for_generation(model: nn.Module):


 def sync_model_params(_from: nn.Module, _to: Optional[nn.Module], mixup_alpha: float = 1.0):
+    """
+    Must be called on all ranks; on non-rank-0 ranks, _to may be set to None.
+    This function does not currently support _to being a ZeRO-3 model.
+    """
     if isinstance(TrainerTools().parallel, DsParallel):
-
-    elif isinstance(
-
+        state_dict = _get_ds_model_params(_from, only_rank0=_to is None)
+    elif isinstance(_from, DDP):
+        state_dict = _from.module.state_dict()
+    else:
+        state_dict = _from.state_dict()
+
+    if not _to or not state_dict:
+        return
+
+    unwrap_to_model = unwrap_model(_to)
+    if mixup_alpha == 1.0:
+        # Overwrite directly
+        unwrap_to_model.load_state_dict(state_dict)
     else:
-
+        # Blend the parameters
+        for param_name, target_param in unwrap_to_model.named_parameters():
+            if param_name in state_dict:
+                from_param_tensor = state_dict[param_name]
+                target_param.data.mul_(1.0 - mixup_alpha).add_(
+                    from_param_tensor.data.to(target_param.device),
+                    alpha=mixup_alpha
+                )
+
+    # if isinstance(TrainerTools().parallel, DsParallel):
+    #     _sync_ds_model_params(_from, _to, mixup_alpha)
+    # elif isinstance(TrainerTools().parallel, DdpParallel):
+    #     _sync_ddp_model_params(_from, _to, mixup_alpha)
+    # else:
+    #     _copy_params(_from, _to, mixup_alpha)


 def unwrap_model(model) -> nn.Module:

@@ -66,6 +95,57 @@ def unwrap_model(model) -> nn.Module:
     return model


+def _get_ds_full_state_dict_on_rank0(model: nn.Module) -> Optional[dict]:
+    """
+    Must be called on all ranks; only rank 0 ends up with a value.
+    """
+    import deepspeed
+    assert isinstance(model, deepspeed.DeepSpeedEngine)
+
+    if model.zero_optimization_stage() != 3:
+        if TrainerTools().parallel.is_main_process:
+            return {k: v.cpu().clone() for k, v in model.module.state_dict().items()}
+        return None
+
+    # --- ZeRO-3 ---
+    # Call GatheredParameters only once, passing in all parameters
+    with deepspeed.zero.GatheredParameters(model.parameters(), modifier_rank=0):
+        if TrainerTools().parallel.is_main_process:
+            # Inside this 'with' block, model.module on rank 0 holds the full parameters,
+            # so we can call state_dict() directly, just like on a regular model.
+            full_state_dict = model.module.state_dict()
+
+            # Clone it to the CPU and return it
+            return {k: v.cpu().clone() for k, v in full_state_dict.items()}
+
+    # When the other ranks reach this point the context has ended; simply return None
+    return None
+
+
+def _get_ds_model_params(model: nn.Module, only_rank0=False):
+    """
+    Efficiently extract the full FP32 state_dict from a running DeepSpeedEngine,
+    compatible with ZeRO Stages 0, 1, 2 and 3.
+    Includes correct handling of parameters sharded under ZeRO-3.
+    """
+
+    import deepspeed
+    assert isinstance(model, deepspeed.DeepSpeedEngine)
+    state_dict = _get_ds_full_state_dict_on_rank0(model)
+
+    # At this point only rank 0 holds a valid dict in state_dict; it is None on the other ranks.
+    # We need to broadcast it to all processes.
+    if not only_rank0 and TrainerTools().parallel.world_size > 1:
+        # Prepare a list: rank 0 holds the data, the other ranks hold a placeholder
+        object_list = [state_dict] if TrainerTools().parallel.is_main_process else [None]
+        # Perform the broadcast; this is a blocking operation that synchronizes all processes
+        dist.broadcast_object_list(object_list, src=0)
+        # Every process takes its copy of the broadcast state_dict from the list
+        state_dict = object_list[0]
+
+    return state_dict
+
+
 def _copy_params(model, target_model, mixup_alpha):
     for target_param, copy_param in zip(target_model.parameters(), model.parameters()):
         target_param.data.mul_(1.0 - mixup_alpha).add_(copy_param.data, alpha=mixup_alpha)

@@ -79,6 +159,7 @@ def _sync_ds_model_params(_from: nn.Module, _to: Optional[nn.Module], mixup_alph

     if _from.zero_optimization_stage() == 3:
         with deepspeed.zero.GatheredParameters(list(origin_from.parameters()) + list(_to.parameters()), modifier_rank=0):
+            # why only rank 0?
             if TrainerTools().parallel.is_main_process:
                 _copy_params(origin_from, _to, mixup_alpha)
     else:
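The blending branch of the rewritten sync_model_params is an EMA-style update, target ← (1 − mixup_alpha) · target + mixup_alpha · source, with mixup_alpha == 1.0 degenerating to a plain copy. A tiny standalone illustration of that in-place update:

    import torch

    mixup_alpha = 0.25
    source = torch.tensor([1.0, 1.0])   # parameters coming from the training model
    target = torch.tensor([0.0, 2.0])   # parameters of the reference model

    # same in-place pattern as sync_model_params / _copy_params
    target.mul_(1.0 - mixup_alpha).add_(source, alpha=mixup_alpha)
    # target is now tensor([0.2500, 1.7500])
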
llm_trainer/train_configs.py
CHANGED
@@ -136,10 +136,13 @@ class DPOConfig:
 @dataclass(kw_only=True)
 class GRPOConfig:
     grpo_steps: int = 1
-    clip_eps: float = 0.1
-    kl_weight: float = 0.04
     group_size: int = 12
     mixup_alpha: float = 1.0
+    loss_beta: float = 0.04
+    loss_clip_eps: float = 0.1
+    loss_delta: Optional[float] = None
+    loss_importance_sampling_level: str = 'token'  # token or seq
+    loss_type: str = 'grpo'  # grpo or bnpo or dr_grpo
     gen_max_new_tokens: Optional[int] = None
     gen_temperature: Optional[float] = None
     gen_k: Optional[int] = None
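In 0.6.0 the old clip_eps and kl_weight fields are replaced by the prefixed loss_* fields consumed by GRPOLoss. A hedged sketch of constructing the config with the new names, assuming the remaining GRPOConfig fields keep their defaults:

    from llm_trainer.train_configs import GRPOConfig

    grpo_config = GRPOConfig(
        group_size=12,
        mixup_alpha=1.0,
        loss_beta=0.04,                          # was kl_weight in 0.5.x
        loss_clip_eps=0.1,                       # was clip_eps in 0.5.x
        loss_delta=None,                         # new: optional two-sided clip on the ratio
        loss_importance_sampling_level='token',  # new: 'token' (GRPO) or 'seq' (GSPO)
        loss_type='grpo',                        # new: 'grpo', 'bnpo' or 'dr_grpo'
    )
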
llm_trainer/trainer.py
CHANGED
@@ -1,6 +1,6 @@
-import time
 from contextlib import nullcontext
 from typing import Optional, Tuple, List, Dict, Any
+import copy

 import torch
 import torch.distributed as dist

@@ -65,6 +65,7 @@ class Trainer:
         assert len(self.eval_prompts) == len(self.eval_image_tags)

         parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim = self._convert_train_args()
+        self.parallel_kwargs = parallel_kwargs
         self.data_loader_kwargs: dict[str, Any] = data_loader_kwargs
         self.sampler_kwargs: dict[str, Any] = sampler_kwargs

@@ -323,8 +324,8 @@ class Trainer:

         return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim

-    def
-        parallel_kwargs
+    def _init_ref_model_args(self) -> dict:
+        parallel_kwargs = copy.deepcopy(self.parallel_kwargs)

         if parallel_kwargs and isinstance(TrainerTools().parallel, DsParallel):
             # reference to https://github.com/huggingface/trl/blob/main/trl/models/utils.py:prepare_deepspeed

@@ -346,8 +347,13 @@ class Trainer:
             # }
             # )

-
-
+            parallel_kwargs.pop('activation_checkpointing', None)
+            parallel_kwargs.pop('gradient_clipping', None)
+
+            # For now the ref_model uses ZeRO stage 0, to work around training getting stuck
+            parallel_kwargs["zero_optimization"] = {"stage": 0}
+            # if parallel_kwargs.get("zero_optimization", {}).get("stage", 0) != 3:
+            #     parallel_kwargs["zero_optimization"] = {"stage": 0}

         return parallel_kwargs

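_init_ref_model_args derives the reference model's DeepSpeed kwargs from the training kwargs: deep-copy, drop the keys that only matter when gradients flow, and pin ZeRO to stage 0. A standalone sketch of that transformation on an illustrative kwargs dict; only the popped keys and the stage-0 override come from the diff, the other entries are assumed:

    import copy

    train_parallel_kwargs = {
        "zero_optimization": {"stage": 2},                            # assumed training setting
        "gradient_clipping": 1.0,                                     # assumed training setting
        "activation_checkpointing": {"partition_activations": True}, # assumed training setting
        "bf16": {"enabled": True},                                    # assumed training setting
    }

    ref_kwargs = copy.deepcopy(train_parallel_kwargs)
    ref_kwargs.pop('activation_checkpointing', None)  # the frozen ref model never backpropagates
    ref_kwargs.pop('gradient_clipping', None)         # no optimizer, so nothing to clip
    ref_kwargs["zero_optimization"] = {"stage": 0}    # stage 0 to work around the reported hang

    # ref_kwargs == {'zero_optimization': {'stage': 0}, 'bf16': {'enabled': True}}
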
project_llm_trainer-0.6.0.dist-info/RECORD
ADDED

@@ -0,0 +1,33 @@
+llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
+llm_trainer/checkpoint.py,sha256=gz31pZbbQvRTYrBhxV-MFaBAIFeqpe7rM6nFsjwT9lY,4328
+llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
+llm_trainer/dpo_trainer.py,sha256=mETXpU1ZSasg1UM72wnh9NaoTuXBibuNuodfuW7u8Iw,12269
+llm_trainer/ds_checkpoint.py,sha256=Wzy7PvVVWR794-BW4uragWFTAkkgDvjvkF-qMdyB4fc,2141
+llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
+llm_trainer/generate_utils.py,sha256=wrZoG2g7CsOyG4sb3px9vURHQFV6_9j5kQmpFc5A8yg,15335
+llm_trainer/grpo_trainer.py,sha256=-wbozslll_bcGUMqrbS0a73jhosyjc3oC3PHLSev6lw,16344
+llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
+llm_trainer/loss.py,sha256=eYvOlCoguKnLvdGuqvQpGUoLVSADQ5coaU3DWYbJEdM,6811
+llm_trainer/parallel.py,sha256=G9X0FddIJwd9j-5XOknB4AlBe4G2W6fUCaQH6ycC2Fo,4490
+llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
+llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
+llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
+llm_trainer/partition_utils.py,sha256=eEYNhfEIF4hGzZ3OLa6sEBIECz261drptEz_n7fZYtk,8396
+llm_trainer/scheduler.py,sha256=LAI_0VxClsIQkix0bRoduRD4vPfVuIZDhZgTAT_KK8k,4901
+llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
+llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
+llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
+llm_trainer/train_configs.py,sha256=U4hwXWKI6svDqiDOu6RPTitCzpxEYyjZUN6gwh_co8c,7510
+llm_trainer/trainer.py,sha256=Q821nlLDKRZVpaRoiZ7DiJplpAJRRLtvR_33FbClGA0,26729
+llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
+project_llm_trainer-0.6.0.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.6.0.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.6.0.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.6.0.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.6.0.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.6.0.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.6.0.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.6.0.dist-info/METADATA,sha256=_F0QQHrdQNGXG8eDGRDsgEvdX6fYWXSDg5Ad089CXHk,195
+project_llm_trainer-0.6.0.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.6.0.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.6.0.dist-info/RECORD,,
project_llm_trainer-0.5.16.dist-info/RECORD
REMOVED

@@ -1,33 +0,0 @@
-llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
-llm_trainer/checkpoint.py,sha256=gz31pZbbQvRTYrBhxV-MFaBAIFeqpe7rM6nFsjwT9lY,4328
-llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
-llm_trainer/dpo_trainer.py,sha256=--ItH-rkkq24Da3M_Kf0VxpQ3t-k0fpZrzFGqkYsjks,12304
-llm_trainer/ds_checkpoint.py,sha256=Wzy7PvVVWR794-BW4uragWFTAkkgDvjvkF-qMdyB4fc,2141
-llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
-llm_trainer/generate_utils.py,sha256=wrZoG2g7CsOyG4sb3px9vURHQFV6_9j5kQmpFc5A8yg,15335
-llm_trainer/grpo_trainer.py,sha256=g_ivzQop2SkvhlKAEWb0zUnIvNuHTfsOoIG6y29oTCw,16106
-llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
-llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
-llm_trainer/parallel.py,sha256=j1L4n-JmDkDZblURrNKpEAWEqqGIAXAN9PT_fSS_OnE,4492
-llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
-llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
-llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
-llm_trainer/partition_utils.py,sha256=xzv8kwlbKp3dai2pBwX89gN5ymeHk1bGbTkGru5H-UM,5167
-llm_trainer/scheduler.py,sha256=LAI_0VxClsIQkix0bRoduRD4vPfVuIZDhZgTAT_KK8k,4901
-llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
-llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
-llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
-llm_trainer/train_configs.py,sha256=992wy0YhBG2WvxwdLEPL4_-JUl4NkwMPT-jj_BIHo6A,7347
-llm_trainer/trainer.py,sha256=YqWhD9jXbrUdm3KEjEHLyg_qHiXCy5R7PK-arCXxJ6M,26399
-llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
-project_llm_trainer-0.5.16.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
-project_llm_trainer-0.5.16.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
-project_llm_trainer-0.5.16.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
-project_llm_trainer-0.5.16.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
-project_llm_trainer-0.5.16.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
-project_llm_trainer-0.5.16.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
-project_llm_trainer-0.5.16.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
-project_llm_trainer-0.5.16.dist-info/METADATA,sha256=h0TMNrZMUU875tVasbuqt69EuPPMbo_nv6tHQLKeNbQ,196
-project_llm_trainer-0.5.16.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-project_llm_trainer-0.5.16.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
-project_llm_trainer-0.5.16.dist-info/RECORD,,
{project_llm_trainer-0.5.16.data → project_llm_trainer-0.6.0.data}/scripts/calc_intermediate_size
RENAMED
File without changes

{project_llm_trainer-0.5.16.data → project_llm_trainer-0.6.0.data}/scripts/ddp_train
RENAMED
File without changes

{project_llm_trainer-0.5.16.data → project_llm_trainer-0.6.0.data}/scripts/ds_train
RENAMED
File without changes

{project_llm_trainer-0.5.16.data → project_llm_trainer-0.6.0.data}/scripts/plot_loss
RENAMED
File without changes

{project_llm_trainer-0.5.16.data → project_llm_trainer-0.6.0.data}/scripts/plot_lr
RENAMED
File without changes

{project_llm_trainer-0.5.16.data → project_llm_trainer-0.6.0.data}/scripts/py_train
RENAMED
File without changes

{project_llm_trainer-0.5.16.data → project_llm_trainer-0.6.0.data}/scripts/smart_train
RENAMED
File without changes

{project_llm_trainer-0.5.16.dist-info → project_llm_trainer-0.6.0.dist-info}/WHEEL
RENAMED
File without changes

{project_llm_trainer-0.5.16.dist-info → project_llm_trainer-0.6.0.dist-info}/top_level.txt
RENAMED
File without changes