project-llm-trainer 0.4.15__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of project-llm-trainer might be problematic.
- llm_trainer/checkpoint.py +0 -50
- llm_trainer/dpo_trainer.py +6 -3
- llm_trainer/eval.py +3 -30
- llm_trainer/generate_utils.py +9 -74
- llm_trainer/grpo_trainer.py +27 -28
- llm_trainer/loss.py +1 -1
- llm_trainer/partition_utils.py +146 -0
- llm_trainer/tokenizer.py +10 -10
- llm_trainer/tools.py +0 -2
- llm_trainer/train_configs.py +5 -25
- llm_trainer/trainer.py +28 -67
- llm_trainer/utils.py +0 -1
- {project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.1.dist-info}/METADATA +1 -1
- project_llm_trainer-0.5.1.dist-info/RECORD +33 -0
- llm_trainer/dcp.py +0 -93
- llm_trainer/ds_model_params.py +0 -72
- llm_trainer/fsdp_checkpoint.py +0 -52
- llm_trainer/fsdp_model_params.py +0 -39
- llm_trainer/model_params.py +0 -28
- llm_trainer/parallel_fsdp.py +0 -121
- project_llm_trainer-0.4.15.dist-info/RECORD +0 -38
- {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.1.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.1.dist-info}/top_level.txt +0 -0
llm_trainer/checkpoint.py
CHANGED
@@ -6,35 +6,11 @@ from torch.optim import Optimizer
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from .parallel_ds import DsParallel
-from .parallel_fsdp import FsdpParallel
-from .parallel_ddp import DdpParallel
 from .scheduler import LRScheduler
 from .tools import TrainerTools
 
-try:
-    from .dcp import save_dcp, load_dcp, convert_dcp_to_pth
-except:
-    os.environ['ENABLE_DCP'] = "0"
-
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-
-# https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
-
 DEFAULT_CHECKPOINT_NAME = "checkpoint.pth"
 
-
-def _can_use_dcp(model: nn.Module) -> bool:
-    if os.environ.get('ENABLE_DCP', '1') != '1':
-        return False
-
-    # DCP saving can only be used with FSDP or DDP
-    if (isinstance(TrainerTools().parallel, FsdpParallel)
-            or isinstance(TrainerTools().parallel, DdpParallel)):
-        return True
-
-    return False
-
-
 def save_checkpoint(
     model: nn.Module,
     optimizer: Optional[Optimizer] = None,
@@ -43,11 +19,6 @@ def save_checkpoint(
     if isinstance(TrainerTools().parallel, DsParallel):
         from .ds_checkpoint import save_ds_checkpoint
         save_ds_checkpoint(model, suffix)
-    elif _can_use_dcp(model):
-        save_dcp(model, optimizer, suffix)
-    elif isinstance(model, FSDP):
-        from .fsdp_checkpoint import save_fsdp_checkpoint
-        save_fsdp_checkpoint(model, optimizer, suffix)
     else:
         if TrainerTools().parallel.is_main_process:
             checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
@@ -73,11 +44,6 @@ def load_checkpoint(
     if isinstance(TrainerTools().parallel, DsParallel):
         from .ds_checkpoint import load_ds_checkpoint
         load_ds_checkpoint(model, load_module_only=load_module_only, suffix=suffix)
-    elif _can_use_dcp(model):
-        load_dcp(model, optimizer, suffix)
-    elif isinstance(model, FSDP):
-        from .fsdp_checkpoint import load_fsdp_checkpoint
-        load_fsdp_checkpoint(model, optimizer, device, suffix)
    else:
        checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
        if suffix:
@@ -99,22 +65,6 @@ def load_checkpoint_for_eval(
     if isinstance(TrainerTools().parallel, DsParallel):
         from .ds_checkpoint import load_ds_checkpoint_for_eval
         load_ds_checkpoint_for_eval(model)
-    elif _can_use_dcp(model):
-        checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
-
-        # load_dcp raises errors on CPU, so first convert the checkpoint to a .pth file and then load that
-        # load_dcp(model, optimizer)
-        pth_name = os.environ.get('EVAL_CHECKPOINT_NAME', checkpoint_name)
-        if suffix:
-            pth_name = f'{pth_name}_{suffix}'
-
-        convert_dcp_to_pth(pth_name)
-
-        if os.path.exists(pth_name):
-            ckpt = torch.load(pth_name, map_location=device, weights_only=True)
-            model.load_state_dict(ckpt['app']['model_state_dict'])
-            # delete the temporary file after use
-            os.remove(pth_name)
     else:
         load_checkpoint(model, None, device, suffix=suffix)
 
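With the DCP and FSDP branches removed, the non-DeepSpeed path falls back to a plain torch.save/torch.load checkpoint. A minimal sketch of that fallback, assuming a simple checkpoint dict (the key names below are assumptions for illustration; the diff does not show the exact dict layout used by save_checkpoint):

```python
# Minimal sketch of the plain (non-DeepSpeed) checkpoint path that remains
# after the DCP/FSDP branches were removed. Key names are assumptions.
import os
import torch

def save_plain_checkpoint(model, optimizer=None, path="checkpoint.pth"):
    state = {"model_state_dict": model.state_dict()}
    if optimizer is not None:
        state["optimizer_state_dict"] = optimizer.state_dict()
    torch.save(state, path)

def load_plain_checkpoint(model, optimizer=None, path="checkpoint.pth", device="cpu"):
    if not os.path.exists(path):
        return
    ckpt = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    if optimizer is not None and "optimizer_state_dict" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
```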
llm_trainer/dpo_trainer.py
CHANGED
@@ -12,7 +12,7 @@ from .dataset import DPODataset
 from .loss import DPOLoss
 from .tools import TrainerTools
 from .utils import get_dpo_collate_fn
-from .
+from .partition_utils import sync_model_params
 
 from .checkpoint import (
     save_checkpoint,
@@ -38,7 +38,6 @@ class DPOTrainer(Trainer):
 
     def _init_reference_model(self):
         reference_model = self._new_model(self.train_config)
-        copy_model_params(_from=self.train_model, _to=reference_model)
 
         reference_model, _ = TrainerTools().parallel.process(
             model=reference_model,
@@ -51,6 +50,11 @@ class DPOTrainer(Trainer):
         for param in reference_model.parameters():
             param.requires_grad = False
 
+        sync_model_params(
+            _from=self.train_model,
+            _to=reference_model
+        )
+
         return reference_model
 
     def _init_loss(self):
@@ -210,7 +214,6 @@ class DPOTrainer(Trainer):
                 if need_update_grad:
                     loss_tensor = torch.tensor(loss_accumulation, device=TrainerTools().parallel.device)
 
-                    # todo check all_reduce??
                     if TrainerTools().parallel.parallel_train:
                         dist.all_reduce(loss_tensor, dist.ReduceOp.AVG)
 
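The reference model is now populated from the training model only after TrainerTools().parallel.process has wrapped it, via the new sync_model_params helper. A stand-alone sketch of the same idea in plain PyTorch (deepcopy/load_state_dict stand in for sync_model_params here; the real helper also handles DeepSpeed ZeRO-3 and DDP wrappers):

```python
# Sketch: build a frozen reference model and copy the policy's weights into it.
import copy
import torch.nn as nn

def init_reference_model(train_model: nn.Module) -> nn.Module:
    reference_model = copy.deepcopy(train_model)
    reference_model.eval()
    for param in reference_model.parameters():
        param.requires_grad = False
    # rough equivalent of sync_model_params(_from=train_model, _to=reference_model)
    reference_model.load_state_dict(train_model.state_dict())
    return reference_model
```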
llm_trainer/eval.py
CHANGED
@@ -5,16 +5,14 @@ from .log import get_log_dir
 from .tools import TrainerTools
 from .train_configs import EvalConfig
 
-
-def _eval_task(
+def submit_gen_task(
     eval_model: torch.nn.Module,
     eval_config: EvalConfig,
     tag,
     prompt,
     pixel_values,
     max_position_embeddings,
-    tokens_per_image,
-    device
+    tokens_per_image
 ):
     log_dir = get_log_dir()
 
@@ -28,33 +26,8 @@ def _eval_task(
         p=eval_config.top_p,
         pixel_values=pixel_values,
         tokens_per_image=tokens_per_image,
-        device=device
+        device=TrainerTools().parallel.device
     )
 
     with open(f'{log_dir}gen.txt', 'a') as f:
         f.write(f"{tag}, gen->{gen_result}\n")
-
-
-def submit_gen_task(
-    eval_model: torch.nn.Module,
-    eval_config: EvalConfig,
-    tag,
-    prompt,
-    pixel_values,
-    max_position_embeddings,
-    tokens_per_image
-):
-    eval_model.to(TrainerTools().parallel.device)
-    _eval_task(
-        eval_model=eval_model,
-        eval_config=eval_config,
-        tag=tag,
-        prompt=prompt,
-        pixel_values=pixel_values,
-        max_position_embeddings=max_position_embeddings,
-        tokens_per_image=tokens_per_image,
-        device=TrainerTools().parallel.device
-    )
-    eval_model.to('cpu')
-
-    # threading.Thread(target=_eval_task, args=args).start()
llm_trainer/generate_utils.py
CHANGED
@@ -1,7 +1,6 @@
 from typing import Union, Optional, List
 from contextlib import nullcontext
 import torch
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from llm_model import VlmModel, KVCache
 from .tools import TrainerTools
 from .utils import batch_repeat_image_tok
@@ -108,8 +107,7 @@ def _generate(
     pixel_values: Optional[torch.Tensor] = None,
     tokens_per_image: int = -1,
     suppress_tokens: Optional[List[int]] = None,
-    device: Union[str, torch.device, int],
-    reasoning_budget: Optional[int] = None
+    device: Union[str, torch.device, int]
 ):
     """
     :param model:
@@ -131,8 +129,7 @@ def _generate(
         device_type=device,
         dtype=TrainerTools().dtype,
         enabled=True,
-
-        cache_enabled=False if isinstance(model, FSDP) else None
+        cache_enabled=None
     ) if TrainerTools().use_amp else nullcontext()
 
     if isinstance(model, VlmModel):
@@ -144,28 +141,6 @@ def _generate(
     kv_cache: Optional[KVCache] = None
     generate_tokens = tokens.clone()
 
-    reasoning_start = TrainerTools().tokenizer.reasoning_start
-    reasoning_end = TrainerTools().tokenizer.reasoning_end
-
-    # --- state initialization ---
-    in_reasoning_block = False
-    reasoning_step_count = 0
-    # "cool-down" flag: after reasoning is forcibly ended, suppress generating <reasoning> on the next step.
-    suppress_reasoning_start_next = False
-
-    if reasoning_budget is not None:
-        prompt_tokens = tokens[0]
-        start_indices = (prompt_tokens == reasoning_start).nonzero(as_tuple=True)[0]
-        end_indices = (prompt_tokens == reasoning_end).nonzero(as_tuple=True)[0]
-
-        last_start_idx = start_indices[-1].item() if len(start_indices) > 0 else -1
-        last_end_idx = end_indices[-1].item() if len(end_indices) > 0 else -1
-
-        if last_start_idx > last_end_idx:
-            in_reasoning_block = True
-            reasoning_step_count = len(prompt_tokens) - 1 - last_start_idx
-
-    model.eval()
     with torch.inference_mode():
         for _ in range(max_new_tokens):
             # do we need to truncate here??
@@ -185,23 +160,6 @@ def _generate(
             # (batch, vocab_size)
             logits = logits[:, -1, :]
 
-            # --- reasoning budget logic ---
-            force_end_reasoning_token = False
-            if reasoning_budget is not None:
-                # check whether <reasoning> should be suppressed at this step
-                should_suppress_this_step = suppress_reasoning_start_next
-                suppress_reasoning_start_next = False  # reset the flag immediately
-
-                # check whether the budget has been exceeded
-                if in_reasoning_block and reasoning_step_count >= reasoning_budget:
-                    force_end_reasoning_token = True
-                    # set the flag so <reasoning> is suppressed on the next step
-                    suppress_reasoning_start_next = True
-
-                # if the suppression flag was set in the previous step, apply it now
-                if should_suppress_this_step:
-                    logits[:, reasoning_start] = -float("inf")
-
             # suppress special-token output
             if suppress_tokens and len(suppress_tokens) != 0:
                 logits = _suppress_warper(logits, suppress_tokens)
@@ -217,10 +175,6 @@ def _generate(
             if p and 0 < p <= 1:
                 logits = _top_p_warper(logits, p)
 
-            if force_end_reasoning_token:
-                logits[:] = -float("inf")
-                logits[:, reasoning_end] = 0.0
-
             if multinomial:
                 prob = logits.softmax(dim=-1)
                 # returns indices
@@ -229,18 +183,6 @@ def _generate(
                 # returns indices
                 next_token = logits.argmax(dim=-1, keepdim=True)
 
-            if reasoning_budget is not None:
-                current_token_id = next_token.item()
-                if not in_reasoning_block and current_token_id == reasoning_start:
-                    in_reasoning_block = True
-                    reasoning_step_count = 0
-                elif in_reasoning_block:
-                    if current_token_id == reasoning_end:
-                        in_reasoning_block = False
-                        reasoning_step_count = 0
-                    else:
-                        reasoning_step_count += 1
-
             # token, is_full_result
             yield next_token, False
 
@@ -269,8 +211,7 @@ def _streaming_generate(
     pixel_values: Optional[torch.Tensor] = None,
     tokens_per_image: int = -1,
     suppress_tokens: Optional[List[int]] = None,
-    device: Union[str, torch.device, int] = None,
-    reasoning_budget: Optional[int] = None
+    device: Union[str, torch.device, int] = None
 ):
     device = TrainerTools().parallel.device if not device else device
     encoded_tokens = TrainerTools().tokenizer.encode(prompt, unsqueeze=True, covert_tensor=True).to(device)
@@ -286,8 +227,7 @@ def _streaming_generate(
         pixel_values=pixel_values,
         tokens_per_image=tokens_per_image,
         suppress_tokens=suppress_tokens,
-        device=device,
-        reasoning_budget=reasoning_budget
+        device=device
     )
 
     for (token, is_full_result) in generate_text_iterator:
@@ -306,8 +246,7 @@ def streaming_generate(
     pixel_values: Optional[torch.Tensor] = None,
     tokens_per_image: int = -1,
     suppress_tokens: Optional[List[int]] = None,
-    device: Union[str, torch.device, int] = None,
-    reasoning_budget: Optional[int] = None
+    device: Union[str, torch.device, int] = None
 ):
     text_iterator = _streaming_generate(
         model=model,
@@ -320,8 +259,7 @@ def streaming_generate(
         pixel_values=pixel_values,
         tokens_per_image=tokens_per_image,
         suppress_tokens=suppress_tokens,
-        device=device,
-        reasoning_budget=reasoning_budget
+        device=device
    )
 
     for (token, is_full_result) in text_iterator:
@@ -341,8 +279,7 @@ def generate(
     pixel_values: Optional[torch.Tensor] = None,
     tokens_per_image: int = -1,
     suppress_tokens: Optional[List[int]] = None,
-    device: Union[str, torch.device, int] = None,
-    reasoning_budget: Optional[int] = None
+    device: Union[str, torch.device, int] = None
 ):
     text_iterator = _streaming_generate(
         model=model,
@@ -355,8 +292,7 @@ def generate(
         suppress_tokens=suppress_tokens,
         pixel_values=pixel_values,
         tokens_per_image=tokens_per_image,
-        device=device,
-        reasoning_budget=reasoning_budget
+        device=device
     )
 
     for (token, is_full_result) in text_iterator:
@@ -386,7 +322,7 @@ def batch_generate(
         device_type=device,
         dtype=TrainerTools().dtype,
         enabled=True,
-        cache_enabled=
+        cache_enabled=None
     ) if TrainerTools().use_amp else nullcontext()
 
     if isinstance(model, VlmModel):
@@ -403,7 +339,6 @@ def batch_generate(
     end_token = TrainerTools().tokenizer.end
     done = torch.zeros(batch_size, dtype=torch.bool, device=device)
 
-    model.eval()
     with torch.inference_mode():
         for _ in range(max_new_tokens):
             # only process samples that are not yet finished
llm_trainer/grpo_trainer.py
CHANGED
@@ -1,5 +1,4 @@
 import time
-import copy
 from typing import Tuple, List, Union, Callable, Optional
 import torch
 from torch.utils.data import Dataset
@@ -15,7 +14,11 @@ from .loss import GRPOLoss
 from .tools import TrainerTools
 from .generate_utils import batch_generate
 from .log import log
-
+
+from .partition_utils import (
+    sync_model_params,
+    unwrap_model_for_generation
+)
 
 from .checkpoint import (
     save_checkpoint,
@@ -39,7 +42,6 @@ class GRPOTrainer(Trainer):
 
         self.reward_func = reward_func
         self.reference_model = self._init_reference_model()
-        self.generate_model = self._init_generate_model()
 
         # by default, use the pad_sequence provided by torch
         # if pad_sequence does not support the padding_side argument, set this flag to False and use the reversal approach instead
@@ -47,17 +49,20 @@ class GRPOTrainer(Trainer):
 
     def _init_reference_model(self):
         reference_model = self._new_model(self.train_config)
-        reference_model.to('cpu')
-        reference_model.eval()
 
+        reference_model, _ = TrainerTools().parallel.process(
+            model=reference_model,
+            optimizer=None,
+            kwargs=self._init_reference_args(),
+            save_instance=False
+        )
+
+        reference_model.eval()
         for param in reference_model.parameters():
             param.requires_grad = False
 
         return reference_model
 
-    def _init_generate_model(self):
-        return copy.deepcopy(self.reference_model)
-
     def _init_loss(self):
         criterion = GRPOLoss(
             clip_eps=self.train_config.grpo_config.clip_eps,
@@ -163,7 +168,7 @@ class GRPOTrainer(Trainer):
         # [batch*group_size, 1]
         return advantages.unsqueeze(1)  # Add dimension for token-wise operations
 
-    def _generate_completions(self, prompts, group_size: int):
+    def _generate_completions(self, model, prompts, group_size: int):
         pad_token_id = TrainerTools().tokenizer.pad
         device = TrainerTools().parallel.device
 
@@ -181,7 +186,7 @@ class GRPOTrainer(Trainer):
 
         # [batch*group_size, max_prompt_len+max_gen_len]
         outputs: torch.Tensor = batch_generate(
-            model=
+            model=model,
             tokens=prompt_ids,
             pad_token_id=pad_token_id,
             attention_mask=prompt_masks,
@@ -201,7 +206,7 @@ class GRPOTrainer(Trainer):
 
         return prompt_ids, prompt_masks, completion_ids, completion_masks
 
-    def _generate_rollout_data(self, batch_data: List[dict]):
+    def _generate_rollout_data(self, generate_model, batch_data: List[dict]):
         prompts = [item["prompt"] for item in batch_data]
         answers = [item["answer"] for item in batch_data]
         group_size = self.train_config.grpo_config.group_size
@@ -210,13 +215,13 @@ class GRPOTrainer(Trainer):
         # Fix for: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal
         with torch.no_grad():
             # with torch.inference_mode():
-            prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_completions(prompts, group_size)
+            prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_completions(generate_model, prompts, group_size)
             input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
             attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
             logits_to_keep = completion_ids.shape[1]
 
             # Compute old_log_probs from the current model, with gradients disabled.
-            old_log_probs, _ = self._compute_log_probabilities(
+            old_log_probs, _ = self._compute_log_probabilities(generate_model, input_ids, attention_mask, logits_to_keep)
 
             # Compute ref_log_probs from the reference model, which remains static.
             ref_log_probs, _ = self._compute_log_probabilities(self.reference_model, input_ids, attention_mask, logits_to_keep)
@@ -275,12 +280,15 @@ class GRPOTrainer(Trainer):
     def train(self):
         global_steps = 0
         skipping_train = False
-        device = TrainerTools().parallel.device
         aux_loss_coef = self.train_config.loss_config.aux_loss_coef
 
         for epoch in range(self.train_config.n_epochs):
-
-
+            sync_model_params(
+                _from=self.train_model,
+                _to=self.reference_model,
+                mixup_alpha=self.train_config.grpo_config.mixup_alpha
+            )
+
             file_count = len(self.train_config.file_dataset)
 
             for file_idx in range(file_count):
@@ -307,22 +315,13 @@ class GRPOTrainer(Trainer):
                    skipping_train = False
 
                    # start generate
-                   # use a separate model for data generation, because generating with train_model hangs during DeepSpeed parallel training
-                   self.generate_model.to(device)
-                   self.reference_model.to(device)
-
                    if TrainerTools().parallel.is_main_process:
                        log(f'start generate for batch {batch}/{batch_count_per_file}')
 
                    # generate data
-                   with
-
-
-                       rollout_data = self._generate_rollout_data(batch_data)
-
-                   # offload to CPU and move back to the GPU the next time it is needed
-                   self.generate_model.to('cpu')
-                   self.reference_model.to('cpu')
+                   with unwrap_model_for_generation(self.train_model) as generate_model:
+                       rollout_data = self._generate_rollout_data(generate_model, batch_data)
+
                    torch.cuda.empty_cache()
                    # end generate
 
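Instead of a separately deep-copied generate model that was shuttled between CPU and GPU, the trainer now unwraps the training model in place for rollouts and refreshes the reference model from the policy once per epoch. A condensed sketch of that flow, paraphrased from the diff (self.dataloader and the trailing optimization step are assumptions standing in for the full train() loop):

```python
# Condensed sketch of the GRPO rollout flow after this change.
from llm_trainer.partition_utils import sync_model_params, unwrap_model_for_generation

def grpo_epoch(self):
    # refresh the frozen reference model from the current policy once per epoch,
    # optionally blending old and new weights via mixup_alpha
    sync_model_params(
        _from=self.train_model,
        _to=self.reference_model,
        mixup_alpha=self.train_config.grpo_config.mixup_alpha,
    )

    for batch_data in self.dataloader:  # assumed iterable of prompt/answer dicts
        # temporarily unwrap the (possibly ZeRO-3/DDP-wrapped) policy for generation
        with unwrap_model_for_generation(self.train_model) as generate_model:
            rollout_data = self._generate_rollout_data(generate_model, batch_data)
        # ... GRPO optimization step on rollout_data ...
```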
llm_trainer/partition_utils.py
ADDED
@@ -0,0 +1,146 @@
+from typing import Optional
+from contextlib import contextmanager
+import itertools
+from packaging import version
+from torch import nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from .tools import TrainerTools
+from .parallel_ds import DsParallel
+from .parallel_ddp import DdpParallel
+
+
+@contextmanager
+def unwrap_model_for_generation(model: nn.Module):
+    """
+    Context manager to unwrap distributed or accelerated models for generation tasks.
+
+    Args:
+        model:
+            Model to be unwrapped.
+    Yields:
+        Unwrapped model.
+
+    Example:
+    ```python
+    with unwrap_model_for_generation(model) as unwrapped_model:
+        generated_outputs = unwrapped_model.generate(input_ids)
+    ```
+    """
+    if isinstance(TrainerTools().parallel, DsParallel):
+        import deepspeed
+        assert isinstance(model, deepspeed.DeepSpeedEngine)
+
+        if model.zero_optimization_stage() == 3:
+            with deepspeed.zero.GatheredParameters(model.parameters()):
+                _remove_hooks(model)
+                yield unwrap_model(model)
+                _add_hooks(model)
+        else:
+            yield unwrap_model(model)
+    elif isinstance(TrainerTools().parallel, DdpParallel):
+        yield unwrap_model(model)
+    else:
+        yield model
+
+
+def sync_model_params(_from: nn.Module, _to: Optional[nn.Module], mixup_alpha: float = 1.0):
+    if isinstance(TrainerTools().parallel, DsParallel):
+        _sync_ds_model_params(_from, _to, mixup_alpha)
+    elif isinstance(TrainerTools().parallel, DdpParallel):
+        _sync_ddp_model_params(_from, _to, mixup_alpha)
+    else:
+        _copy_params(_from, _to, mixup_alpha)
+
+
+def unwrap_model(model) -> nn.Module:
+    try:
+        import deepspeed
+        if isinstance(model, deepspeed.DeepSpeedEngine):
+            return model.module
+    except: ...
+
+    if isinstance(model, DDP):
+        return model.module
+
+    return model
+
+
+def _copy_params(model, target_model, mixup_alpha):
+    for target_param, copy_param in zip(target_model.parameters(), model.parameters()):
+        target_param.data.mul_(1.0 - mixup_alpha).add_(copy_param.data, alpha=mixup_alpha)
+
+
+def _sync_ds_model_params(_from: nn.Module, _to: Optional[nn.Module], mixup_alpha: float = 1.0):
+    import deepspeed
+    assert isinstance(_from, deepspeed.DeepSpeedEngine)
+
+    origin_from = unwrap_model(_from)
+
+    if _from.zero_optimization_stage() == 3:
+        with deepspeed.zero.GatheredParameters(list(origin_from.parameters()) + list(_to.parameters()), modifier_rank=0):
+            if TrainerTools().parallel.is_main_process:
+                _copy_params(origin_from, _to, mixup_alpha)
+    else:
+        _copy_params(origin_from, _to, mixup_alpha)
+
+
+def _sync_ddp_model_params(_from: nn.Module, _to: Optional[nn.Module], mixup_alpha: float = 1.0):
+    assert isinstance(_from, DDP)
+
+    origin_from = unwrap_model(_from)
+    _copy_params(origin_from, _to, mixup_alpha)
+
+
+def _add_hooks(model: nn.Module) -> None:
+    """Adds the optimizer hooks from a DeepSpeed ZeRO-3 model."""
+    import deepspeed
+    assert isinstance(model, deepspeed.DeepSpeedEngine)
+
+    if not hasattr(model, "optimizer"):  # before the first training step, the model has no optimizer
+        return
+    if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
+        optimizer_offload = model.optimizer.parameter_offload
+    elif model.optimizer is not None:
+        optimizer_offload = model.optimizer
+    else:
+        raise RuntimeError("The model optimizer is None, which is not yet supported.")
+    if version.parse(deepspeed.__version__) >= version.parse("0.16.4"):
+        # Account for renaming in https://github.com/deepspeedai/DeepSpeed/pull/6847
+        optimizer_offload._register_deepspeed_module(optimizer_offload.module)
+    else:
+        optimizer_offload._register_hooks_recursively(optimizer_offload.module)
+
+
+def _remove_hooks(model: nn.Module) -> None:
+    """Removes the optimizer hooks from a DeepSpeed ZeRO-3 model."""
+    import deepspeed
+    assert isinstance(model, deepspeed.DeepSpeedEngine)
+
+    if not hasattr(model, "optimizer"):  # before the first training step, the model has no optimizer
+        return
+    if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
+        optimizer_offload = model.optimizer.parameter_offload
+    elif model.optimizer is not None:
+        optimizer_offload = model.optimizer
+    else:
+        raise RuntimeError("The model optimizer is None, which is not yet supported.")
+
+    for param in _iter_params(optimizer_offload.module, recurse=True):
+        param.ds_active_sub_modules.clear()
+
+    for hook in optimizer_offload.forward_hooks:
+        hook.remove()
+    for hook in optimizer_offload.backward_hooks:
+        hook.remove()
+
+    optimizer_offload.forward_hooks = []
+    optimizer_offload.backward_hooks = []
+
+
+def _iter_params(module, recurse=False):
+    return [param for _, param in _get_all_parameters(module, recurse)]
+
+
+def _get_all_parameters(sub_module, recurse=False):
+    return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters())