project-llm-trainer 0.4.14__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_trainer/checkpoint.py +0 -73
- llm_trainer/dpo_trainer.py +7 -3
- llm_trainer/ds_checkpoint.py +0 -66
- llm_trainer/eval.py +3 -30
- llm_trainer/generate_utils.py +2 -6
- llm_trainer/grpo_trainer.py +27 -28
- llm_trainer/loss.py +1 -1
- llm_trainer/partition_utils.py +146 -0
- llm_trainer/tools.py +0 -2
- llm_trainer/train_configs.py +5 -25
- llm_trainer/trainer.py +30 -69
- llm_trainer/utils.py +0 -1
- {project_llm_trainer-0.4.14.dist-info → project_llm_trainer-0.5.0.dist-info}/METADATA +1 -1
- project_llm_trainer-0.5.0.dist-info/RECORD +33 -0
- llm_trainer/dcp.py +0 -93
- llm_trainer/fsdp_checkpoint.py +0 -87
- llm_trainer/parallel_fsdp.py +0 -121
- project_llm_trainer-0.4.14.dist-info/RECORD +0 -35
- {project_llm_trainer-0.4.14.data → project_llm_trainer-0.5.0.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.4.14.data → project_llm_trainer-0.5.0.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.4.14.data → project_llm_trainer-0.5.0.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.4.14.data → project_llm_trainer-0.5.0.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.4.14.data → project_llm_trainer-0.5.0.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.4.14.data → project_llm_trainer-0.5.0.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.4.14.data → project_llm_trainer-0.5.0.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.4.14.dist-info → project_llm_trainer-0.5.0.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.4.14.dist-info → project_llm_trainer-0.5.0.dist-info}/top_level.txt +0 -0
llm_trainer/checkpoint.py
CHANGED
@@ -6,35 +6,11 @@ from torch.optim import Optimizer
 from torch.nn.parallel import DistributedDataParallel as DDP

 from .parallel_ds import DsParallel
-from .parallel_fsdp import FsdpParallel
-from .parallel_ddp import DdpParallel
 from .scheduler import LRScheduler
 from .tools import TrainerTools

-try:
-    from .dcp import save_dcp, load_dcp, convert_dcp_to_pth
-except:
-    os.environ['ENABLE_DCP'] = "0"
-
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-
-# https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
-
 DEFAULT_CHECKPOINT_NAME = "checkpoint.pth"

-
-def _can_use_dcp(model: nn.Module) -> bool:
-    if os.environ.get('ENABLE_DCP', '1') != '1':
-        return False
-
-    # DCP saving can only be used when the parallel mode is FSDP or DDP
-    if (isinstance(TrainerTools().parallel, FsdpParallel)
-            or isinstance(TrainerTools().parallel, DdpParallel)):
-        return True
-
-    return False
-
-
 def save_checkpoint(
         model: nn.Module,
         optimizer: Optional[Optimizer] = None,
@@ -43,11 +19,6 @@ def save_checkpoint(
     if isinstance(TrainerTools().parallel, DsParallel):
         from .ds_checkpoint import save_ds_checkpoint
         save_ds_checkpoint(model, suffix)
-    elif _can_use_dcp(model):
-        save_dcp(model, optimizer, suffix)
-    elif isinstance(model, FSDP):
-        from .fsdp_checkpoint import save_fsdp_checkpoint
-        save_fsdp_checkpoint(model, optimizer, suffix)
     else:
         if TrainerTools().parallel.is_main_process:
             checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
@@ -73,11 +44,6 @@ def load_checkpoint(
     if isinstance(TrainerTools().parallel, DsParallel):
         from .ds_checkpoint import load_ds_checkpoint
         load_ds_checkpoint(model, load_module_only=load_module_only, suffix=suffix)
-    elif _can_use_dcp(model):
-        load_dcp(model, optimizer, suffix)
-    elif isinstance(model, FSDP):
-        from .fsdp_checkpoint import load_fsdp_checkpoint
-        load_fsdp_checkpoint(model, optimizer, device, suffix)
     else:
         checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
         if suffix:
@@ -99,49 +65,10 @@ def load_checkpoint_for_eval(
     if isinstance(TrainerTools().parallel, DsParallel):
         from .ds_checkpoint import load_ds_checkpoint_for_eval
         load_ds_checkpoint_for_eval(model)
-    elif _can_use_dcp(model):
-        checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
-
-        # load_dcp fails on CPU, so convert the checkpoint to a pth file first and then load the pth
-        # load_dcp(model, optimizer)
-        pth_name = os.environ.get('EVAL_CHECKPOINT_NAME', checkpoint_name)
-        if suffix:
-            pth_name = f'{pth_name}_{suffix}'
-
-        convert_dcp_to_pth(pth_name)
-
-        if os.path.exists(pth_name):
-            ckpt = torch.load(pth_name, map_location=device, weights_only=True)
-            model.load_state_dict(ckpt['app']['model_state_dict'])
-            # delete the file once it has been used
-            os.remove(pth_name)
     else:
         load_checkpoint(model, None, device, suffix=suffix)


-def copy_model_params(
-        _from: nn.Module,
-        _to: Optional[nn.Module]
-):
-    """
-    Must be called on all ranks; on non-rank0 processes, _to may be set to None
-    """
-
-    if isinstance(TrainerTools().parallel, DsParallel):
-        from .ds_checkpoint import get_ds_model_params
-        state_dict = get_ds_model_params(_from, only_rank0=_to is None)
-    elif isinstance(TrainerTools().parallel, FsdpParallel):
-        from .fsdp_checkpoint import get_fsdp_model_params
-        state_dict = get_fsdp_model_params(_from, only_rank0=_to is None)
-    elif isinstance(_from, DDP):
-        state_dict = _from.module.state_dict()
-    else:
-        state_dict = _from.state_dict()
-
-    if _to and state_dict:
-        _to.load_state_dict(state_dict)
-
-
 def save_steps(global_steps: int, lr_scheduler: Optional[LRScheduler] = None):
     # for now, only the main process saves
     if TrainerTools().parallel.is_main_process:
llm_trainer/dpo_trainer.py
CHANGED
@@ -12,13 +12,14 @@ from .dataset import DPODataset
 from .loss import DPOLoss
 from .tools import TrainerTools
 from .utils import get_dpo_collate_fn
+from .partition_utils import sync_model_params

 from .checkpoint import (
     save_checkpoint,
-    copy_model_params,
     save_steps,
 )

+
 class DPOTrainer(Trainer):
     def __init__(
         self,
@@ -37,7 +38,6 @@ class DPOTrainer(Trainer):

     def _init_reference_model(self):
         reference_model = self._new_model(self.train_config)
-        copy_model_params(_from=self.train_model, _to=reference_model)

         reference_model, _ = TrainerTools().parallel.process(
             model=reference_model,
@@ -50,6 +50,11 @@ class DPOTrainer(Trainer):
         for param in reference_model.parameters():
             param.requires_grad = False

+        sync_model_params(
+            _from=self.train_model,
+            _to=reference_model
+        )
+
         return reference_model

     def _init_loss(self):
@@ -209,7 +214,6 @@ class DPOTrainer(Trainer):
         if need_update_grad:
             loss_tensor = torch.tensor(loss_accumulation, device=TrainerTools().parallel.device)

-            # todo check all_reduce??
             if TrainerTools().parallel.parallel_train:
                 dist.all_reduce(loss_tensor, dist.ReduceOp.AVG)

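Both trainers keep the cross-rank loss averaging shown above after dropping the old `# todo check all_reduce??` note. Below is a minimal, self-contained sketch of that step; the helper name and the process-group guard are illustrative and not part of the package.

```python
import torch
import torch.distributed as dist


def average_loss_across_ranks(loss_accumulation: float, device: torch.device) -> float:
    # Every rank contributes its accumulated loss and receives the mean back.
    # ReduceOp.AVG requires a backend that supports it (e.g. NCCL).
    loss_tensor = torch.tensor(loss_accumulation, device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
    return loss_tensor.item()
```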
llm_trainer/ds_checkpoint.py
CHANGED
@@ -2,11 +2,7 @@ import os
 from typing import Optional
 from glob import glob
 import shutil
-import torch
 from torch import nn
-import torch.distributed as dist
-
-from .tools import TrainerTools

 try:
     import deepspeed
@@ -65,65 +61,3 @@ def load_ds_checkpoint_for_eval(model: nn.Module):
     ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
     state_dict = get_fp32_state_dict_from_zero_checkpoint(ckpt_dir)
     model.load_state_dict(state_dict)
-
-
-def _get_ds_full_state_dict_on_rank0(model: DeepSpeedEngine) -> Optional[dict]:
-    """
-    Must be called on all ranks; only rank 0 ends up with a value
-    """
-
-    if model.zero_optimization_stage() != 3:
-        if TrainerTools().parallel.is_main_process:
-            return {k: v.cpu().clone() for k, v in model.module.state_dict().items()}
-        return None
-
-    # --- ZeRO-3 ---
-    # Call GatheredParameters only once, passing in all parameters
-    with deepspeed.zero.GatheredParameters(model.parameters(), modifier_rank=0):
-        if TrainerTools().parallel.is_main_process:
-            # Inside this 'with' block, model.module on rank 0 holds the full parameters,
-            # so state_dict() can be called directly, just like on an ordinary model
-            full_state_dict = model.module.state_dict()
-
-            # Clone it to the CPU and return it
-            return {k: v.cpu().clone() for k, v in full_state_dict.items()}
-
-    # When the other ranks reach this point the context has ended; simply return None
-    return None
-
-    # # ZeRO-3
-    # state_dict_on_rank_0 = {}
-    # for param_name, param in model.module.named_parameters():
-    #     if hasattr(param, 'ds_id'):
-    #         with deepspeed.zero.GatheredParameters(param, modifier_rank=0):
-    #             if TrainerTools().parallel.is_main_process:
-    #                 state_dict_on_rank_0[param_name] = param.data.to(torch.float32).cpu().clone()
-    #     else:
-    #         if TrainerTools().parallel.is_main_process:
-    #             state_dict_on_rank_0[param_name] = param.data.to(torch.float32).cpu().clone()
-    #
-    # return state_dict_on_rank_0 if TrainerTools().parallel.is_main_process else None
-
-
-def get_ds_model_params(model: nn.Module, only_rank0=False):
-    """
-    Efficiently extract the full FP32 state_dict from a running DeepSpeedEngine,
-    compatible with ZeRO stages 0, 1, 2 and 3.
-    Includes correct handling of parameters that are sharded under ZeRO-3.
-    """
-
-    assert isinstance(model, DeepSpeedEngine)
-    state_dict = _get_ds_full_state_dict_on_rank0(model)
-
-    # At this point only the state_dict on rank 0 is a valid dict; on the other ranks it is None.
-    # It has to be broadcast to all processes.
-    if not only_rank0 and TrainerTools().parallel.world_size > 1:
-        # Prepare a list: rank 0 holds the data, the other ranks hold a placeholder
-        object_list = [state_dict] if TrainerTools().parallel.is_main_process else [None]
-        # Perform the broadcast; the call blocks and synchronises all processes
-        dist.broadcast_object_list(object_list, src=0)
-        # Every process takes its copy of the broadcast state_dict from the list
-        state_dict = object_list[0]
-
-    return state_dict
-
llm_trainer/eval.py
CHANGED
@@ -5,16 +5,14 @@ from .log import get_log_dir
 from .tools import TrainerTools
 from .train_configs import EvalConfig

-
-def _eval_task(
+def submit_gen_task(
         eval_model: torch.nn.Module,
         eval_config: EvalConfig,
         tag,
         prompt,
         pixel_values,
         max_position_embeddings,
-        tokens_per_image,
-        device
+        tokens_per_image
 ):
     log_dir = get_log_dir()

@@ -28,33 +26,8 @@ def _eval_task(
         p=eval_config.top_p,
         pixel_values=pixel_values,
         tokens_per_image=tokens_per_image,
-        device=device
+        device=TrainerTools().parallel.device
     )

     with open(f'{log_dir}gen.txt', 'a') as f:
         f.write(f"{tag}, gen->{gen_result}\n")
-
-
-def submit_gen_task(
-        eval_model: torch.nn.Module,
-        eval_config: EvalConfig,
-        tag,
-        prompt,
-        pixel_values,
-        max_position_embeddings,
-        tokens_per_image
-):
-    eval_model.to(TrainerTools().parallel.device)
-    _eval_task(
-        eval_model=eval_model,
-        eval_config=eval_config,
-        tag=tag,
-        prompt=prompt,
-        pixel_values=pixel_values,
-        max_position_embeddings=max_position_embeddings,
-        tokens_per_image=tokens_per_image,
-        device=TrainerTools().parallel.device
-    )
-    eval_model.to('cpu')
-
-    # threading.Thread(target=_eval_task, args=args).start()
llm_trainer/generate_utils.py
CHANGED
@@ -1,7 +1,6 @@
 from typing import Union, Optional, List
 from contextlib import nullcontext
 import torch
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from llm_model import VlmModel, KVCache
 from .tools import TrainerTools
 from .utils import batch_repeat_image_tok
@@ -131,8 +130,7 @@ def _generate(
         device_type=device,
         dtype=TrainerTools().dtype,
         enabled=True,
-
-        cache_enabled=False if isinstance(model, FSDP) else None
+        cache_enabled=None
     ) if TrainerTools().use_amp else nullcontext()

     if isinstance(model, VlmModel):
@@ -165,7 +163,6 @@ def _generate(
         in_reasoning_block = True
         reasoning_step_count = len(prompt_tokens) - 1 - last_start_idx

-    model.eval()
     with torch.inference_mode():
         for _ in range(max_new_tokens):
             # do we need to truncate here??
@@ -386,7 +383,7 @@ def batch_generate(
         device_type=device,
         dtype=TrainerTools().dtype,
         enabled=True,
-        cache_enabled=
+        cache_enabled=None
     ) if TrainerTools().use_amp else nullcontext()

     if isinstance(model, VlmModel):
@@ -403,7 +400,6 @@ def batch_generate(
     end_token = TrainerTools().tokenizer.end
     done = torch.zeros(batch_size, dtype=torch.bool, device=device)

-    model.eval()
     with torch.inference_mode():
         for _ in range(max_new_tokens):
             # only process samples that are not yet finished
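With the FSDP special case gone, the autocast context around generation reduces to the stock PyTorch call. A rough, self-contained sketch of that setup follows; the device and dtype choices here are illustrative, whereas the package reads them from TrainerTools().

```python
import torch
from contextlib import nullcontext

use_amp = torch.cuda.is_available()
autocast_ctx = torch.autocast(
    device_type='cuda',
    dtype=torch.bfloat16,
    enabled=True,
    cache_enabled=None,  # let PyTorch use its default weight-cast caching
) if use_amp else nullcontext()

with autocast_ctx, torch.inference_mode():
    pass  # token-by-token generation runs here in the real code
```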
llm_trainer/grpo_trainer.py
CHANGED
@@ -1,5 +1,4 @@
 import time
-import copy
 from typing import Tuple, List, Union, Callable, Optional
 import torch
 from torch.utils.data import Dataset
@@ -16,9 +15,13 @@ from .tools import TrainerTools
 from .generate_utils import batch_generate
 from .log import log

+from .partition_utils import (
+    sync_model_params,
+    unwrap_model_for_generation
+)
+
 from .checkpoint import (
     save_checkpoint,
-    copy_model_params,
     save_steps,
 )

@@ -39,7 +42,6 @@ class GRPOTrainer(Trainer):

         self.reward_func = reward_func
         self.reference_model = self._init_reference_model()
-        self.generate_model = self._init_generate_model()

         # by default, use the pad_sequence provided by torch
         # if pad_sequence does not support the padding_side argument, set this flag to False and use the reversed approach instead
@@ -47,17 +49,20 @@ class GRPOTrainer(Trainer):

     def _init_reference_model(self):
         reference_model = self._new_model(self.train_config)
-        reference_model.to('cpu')
-        reference_model.eval()

+        reference_model, _ = TrainerTools().parallel.process(
+            model=reference_model,
+            optimizer=None,
+            kwargs=self._init_reference_args(),
+            save_instance=False
+        )
+
+        reference_model.eval()
         for param in reference_model.parameters():
             param.requires_grad = False

         return reference_model

-    def _init_generate_model(self):
-        return copy.deepcopy(self.reference_model)
-
     def _init_loss(self):
         criterion = GRPOLoss(
             clip_eps=self.train_config.grpo_config.clip_eps,
@@ -163,7 +168,7 @@
         # [batch*group_size, 1]
         return advantages.unsqueeze(1)  # Add dimension for token-wise operations

-    def _generate_completions(self, prompts, group_size: int):
+    def _generate_completions(self, model, prompts, group_size: int):
         pad_token_id = TrainerTools().tokenizer.pad
         device = TrainerTools().parallel.device

@@ -181,7 +186,7 @@

         # [batch*group_size, max_prompt_len+max_gen_len]
         outputs: torch.Tensor = batch_generate(
-            model=
+            model=model,
             tokens=prompt_ids,
             pad_token_id=pad_token_id,
             attention_mask=prompt_masks,
@@ -201,7 +206,7 @@

         return prompt_ids, prompt_masks, completion_ids, completion_masks

-    def _generate_rollout_data(self, batch_data: List[dict]):
+    def _generate_rollout_data(self, generate_model, batch_data: List[dict]):
         prompts = [item["prompt"] for item in batch_data]
         answers = [item["answer"] for item in batch_data]
         group_size = self.train_config.grpo_config.group_size
@@ -210,13 +215,13 @@
         # fixes: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal
         with torch.no_grad():
             # with torch.inference_mode():
-            prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_completions(prompts, group_size)
+            prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_completions(generate_model, prompts, group_size)
             input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
             attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
             logits_to_keep = completion_ids.shape[1]

             # Compute old_log_probs from the current model, with gradients disabled.
-            old_log_probs, _ = self._compute_log_probabilities(
+            old_log_probs, _ = self._compute_log_probabilities(generate_model, input_ids, attention_mask, logits_to_keep)

             # Compute ref_log_probs from the reference model, which remains static.
             ref_log_probs, _ = self._compute_log_probabilities(self.reference_model, input_ids, attention_mask, logits_to_keep)
@@ -275,12 +280,15 @@
     def train(self):
         global_steps = 0
         skipping_train = False
-        device = TrainerTools().parallel.device
         aux_loss_coef = self.train_config.loss_config.aux_loss_coef

         for epoch in range(self.train_config.n_epochs):
-
-
+            sync_model_params(
+                _from=self.train_model,
+                _to=self.reference_model,
+                mixup_alpha=self.train_config.grpo_config.mixup_alpha
+            )
+
             file_count = len(self.train_config.file_dataset)

             for file_idx in range(file_count):
@@ -307,22 +315,13 @@
                     skipping_train = False

                     # start generate
-                    # use a separate model to generate data, because generating with train_model deadlocks under deepspeed parallel training
-                    self.generate_model.to(device)
-                    self.reference_model.to(device)
-
                     if TrainerTools().parallel.is_main_process:
                         log(f'start generate for batch {batch}/{batch_count_per_file}')

                     # generate the data
-                    with
-
-
-                    rollout_data = self._generate_rollout_data(batch_data)
-
-                    # offload to the cpu and move back to the gpu the next time it is used
-                    self.generate_model.to('cpu')
-                    self.reference_model.to('cpu')
+                    with unwrap_model_for_generation(self.train_model) as generate_model:
+                        rollout_data = self._generate_rollout_data(generate_model, batch_data)
+
                     torch.cuda.empty_cache()
                     # end generate

llm_trainer/partition_utils.py
ADDED
@@ -0,0 +1,146 @@
+from typing import Optional
+from contextlib import contextmanager
+import itertools
+from packaging import version
+from torch import nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from .tools import TrainerTools
+from .parallel_ds import DsParallel
+from .parallel_ddp import DdpParallel
+
+
+@contextmanager
+def unwrap_model_for_generation(model: nn.Module):
+    """
+    Context manager to unwrap distributed or accelerated models for generation tasks.
+
+    Args:
+        model:
+            Model to be unwrapped.
+    Yields:
+        Unwrapped model.
+
+    Example:
+    ```python
+    with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
+        generated_outputs = unwrapped_model.generate(input_ids)
+    ```
+    """
+    if isinstance(TrainerTools().parallel, DsParallel):
+        import deepspeed
+        assert isinstance(model, deepspeed.DeepSpeedEngine)
+
+        if model.zero_optimization_stage() == 3:
+            with deepspeed.zero.GatheredParameters(model.parameters()):
+                _remove_hooks(model)
+                yield unwrap_model(model)
+                _add_hooks(model)
+        else:
+            yield unwrap_model(model)
+    elif isinstance(TrainerTools().parallel, DdpParallel):
+        yield unwrap_model(model)
+    else:
+        yield model
+
+
+def sync_model_params(_from: nn.Module, _to: Optional[nn.Module], mixup_alpha: float = 1.0):
+    if isinstance(TrainerTools().parallel, DsParallel):
+        _sync_ds_model_params(_from, _to, mixup_alpha)
+    elif isinstance(TrainerTools().parallel, DdpParallel):
+        _sync_ddp_model_params(_from, _to, mixup_alpha)
+    else:
+        _copy_params(_from, _to, mixup_alpha)
+
+
+def unwrap_model(model) -> nn.Module:
+    try:
+        import deepspeed
+        if isinstance(model, deepspeed.DeepSpeedEngine):
+            return model.module
+    except: ...
+
+    if isinstance(model, DDP):
+        return model.module
+
+    return model
+
+
+def _copy_params(model, target_model, mixup_alpha):
+    for target_param, copy_param in zip(target_model.parameters(), model.parameters()):
+        target_param.data.mul_(1.0 - mixup_alpha).add_(copy_param.data, alpha=mixup_alpha)
+
+
+def _sync_ds_model_params(_from: nn.Module, _to: Optional[nn.Module], mixup_alpha: float = 1.0):
+    import deepspeed
+    assert isinstance(_from, deepspeed.DeepSpeedEngine)
+
+    origin_from = unwrap_model(_from)
+
+    if _from.zero_optimization_stage() == 3:
+        with deepspeed.zero.GatheredParameters(list(origin_from.parameters()) + list(_to.parameters()), modifier_rank=0):
+            if TrainerTools().parallel.is_main_process:
+                _copy_params(origin_from, _to, mixup_alpha)
+    else:
+        _copy_params(origin_from, _to, mixup_alpha)
+
+
+def _sync_ddp_model_params(_from: nn.Module, _to: Optional[nn.Module], mixup_alpha: float = 1.0):
+    assert isinstance(_from, DDP)
+
+    origin_from = unwrap_model(_from)
+    _copy_params(origin_from, _to, mixup_alpha)
+
+
+def _add_hooks(model: nn.Module) -> None:
+    """Adds the optimizer hooks from a DeepSpeed ZeRO-3 model."""
+    import deepspeed
+    assert isinstance(model, deepspeed.DeepSpeedEngine)
+
+    if not hasattr(model, "optimizer"):  # before the first training step, the model has no optimizer
+        return
+    if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
+        optimizer_offload = model.optimizer.parameter_offload
+    elif model.optimizer is not None:
+        optimizer_offload = model.optimizer
+    else:
+        raise RuntimeError("The model optimizer is None, which is not yet supported.")
+    if version.parse(deepspeed.__version__) >= version.parse("0.16.4"):
+        # Account for renaming in https://github.com/deepspeedai/DeepSpeed/pull/6847
+        optimizer_offload._register_deepspeed_module(optimizer_offload.module)
+    else:
+        optimizer_offload._register_hooks_recursively(optimizer_offload.module)
+
+
+def _remove_hooks(model: nn.Module) -> None:
+    """Removes the optimizer hooks from a DeepSpeed ZeRO-3 model."""
+    import deepspeed
+    assert isinstance(model, deepspeed.DeepSpeedEngine)
+
+    if not hasattr(model, "optimizer"):  # before the first training step, the model has no optimizer
+        return
+    if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
+        optimizer_offload = model.optimizer.parameter_offload
+    elif model.optimizer is not None:
+        optimizer_offload = model.optimizer
+    else:
+        raise RuntimeError("The model optimizer is None, which is not yet supported.")
+
+    for param in _iter_params(optimizer_offload.module, recurse=True):
+        param.ds_active_sub_modules.clear()
+
+    for hook in optimizer_offload.forward_hooks:
+        hook.remove()
+    for hook in optimizer_offload.backward_hooks:
+        hook.remove()
+
+    optimizer_offload.forward_hooks = []
+    optimizer_offload.backward_hooks = []
+
+
+def _iter_params(module, recurse=False):
+    return [param for _, param in _get_all_parameters(module, recurse)]
+
+
+def _get_all_parameters(sub_module, recurse=False):
+    return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters())
llm_trainer/tools.py
CHANGED
@@ -3,7 +3,6 @@ from abc import ABC, abstractmethod
 import torch
 from .tokenizer import Tokenizer
 from .parallel_ds import DsParallel
-from .parallel_fsdp import FsdpParallel
 from .parallel_ddp import DdpParallel
 from .parallel_none import NoneParallel
 from .log import log
@@ -11,7 +10,6 @@ from .log import log

 parallel_types = {
     'ds': DsParallel,
-    'fsdp': FsdpParallel,
     'ddp': DdpParallel,
     'none': NoneParallel
 }
llm_trainer/train_configs.py
CHANGED
@@ -1,8 +1,7 @@
-from typing import Optional, Union,
+from typing import Optional, Union, Callable, List, Mapping, Any
 from dataclasses import dataclass, field

 import torch
-from torch import nn
 from llm_model import ModelConfig, VLMConfig
 from .tools import FileDataset

@@ -33,6 +32,9 @@ class DsZeROConfig:
     reduce_bucket_size: Optional[Union[str, int]] = 5e8
     contiguous_gradients: Optional[bool] = True

+@dataclass(kw_only=True)
+class DsZero0Config(DsZeROConfig):
+    stage: int = field(default=0, init=False)

 @dataclass(kw_only=True)
 class DsZero1Config(DsZeROConfig):
@@ -84,26 +86,6 @@ class DsConfig:
     activation_checkpointing: Optional[DsActivationCheckpointingConfig] = None


-@dataclass(kw_only=True)
-class FsdpConfig:
-    """
-    Configuration options for the FSDP training mode
-    Args:
-        transformer_layer_cls (`Set[Type[nn.Module]]`, *optional*, default is None):
-            the classes of the transformer layers
-        wrap_policy_num_params (`int`, *optional*, default is -1):
-            the min_num_params argument of size_based_auto_wrap_policy; -1 disables this policy
-        cpu_offload (`bool`, *optional*, default is False):
-            whether to use CPU offload
-        offload_params (`bool`, default is False):
-            whether to offload parameters; only takes effect when cpu_offload is True
-    """
-    transformer_layer_cls: Optional[Set[Type[nn.Module]]] = None
-    wrap_policy_num_params: int = -1
-    cpu_offload: bool = False
-    offload_params: bool = False
-
-
 @dataclass(kw_only=True)
 class DataLoaderConfig:
     """
@@ -157,6 +139,7 @@ class GRPOConfig:
     clip_eps: float = 0.2
     kl_weight: float = 0.01
     group_size: int = 12
+    mixup_alpha: float = 1.0
     gen_max_new_tokens: Optional[int] = None
     gen_temperature: Optional[float] = None
     gen_k: Optional[int] = None
@@ -210,8 +193,6 @@ class TrainConfig:
            how many batches between model eval runs
        lr_config (`LrConfig`):
            lr configuration options
-       fsdp_config: (`FsdpConfig`):
-           configuration options for the FSDP training mode
        data_loader_config: (`DataLoaderConfig`):
            data loader configuration options
        kd_config: (`KDConfig`, *Optional*, default is None):
@@ -231,7 +212,6 @@ class TrainConfig:
     lr_config: LrConfig = field(default_factory=LrConfig)

     ds_config: DsConfig = field(default_factory=DsConfig)
-    fsdp_config: FsdpConfig = field(default_factory=FsdpConfig)

     kd_config: Optional[KDConfig] = None
     dpo_config: Optional[DPOConfig] = None
llm_trainer/trainer.py
CHANGED
@@ -1,19 +1,18 @@
 import time
 from contextlib import nullcontext
-from typing import Optional, Tuple, List, Dict, Any
+from typing import Optional, Tuple, List, Dict, Any

 import torch
-from torch import nn
 import torch.distributed as dist
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.utils.data import Dataset
 from llm_model import LlmModel, VlmModel

 from .parallel_ds import DsParallel
-from .parallel_fsdp import FsdpParallel
 from .tools import TrainerTools
 from .loss import LMLoss, KDLoss
 from .dataset import TextDataset
+from .eval import submit_gen_task
+from .partition_utils import unwrap_model_for_generation

 from .train_configs import (
     TrainConfig,
@@ -31,10 +30,10 @@ from .scheduler import (
 from .checkpoint import (
     load_checkpoint,
     save_checkpoint,
-    copy_model_params,
     load_steps,
     save_steps,
 )
+
 from .utils import (
     set_seed,
     pretrain_collate_fn,
@@ -45,8 +44,6 @@ from .log import(
     get_log_dir
 )

-from .eval import submit_gen_task
-
 class Trainer:
     def __init__(
         self,
@@ -78,7 +75,6 @@ class Trainer:

         self.train_model, self.optimizer = self._init_train_model_and_optim(initial_lr, parallel_kwargs, use_ds_optim)
         self.lr_scheduler = self._init_lr_scheduler(initial_lr)
-        self.eval_model: Optional[nn.Module] = self._init_eval_model()

         self.criterion, self.kd_loss = self._init_loss()

@@ -86,9 +82,7 @@ class Trainer:
             device_type=TrainerTools().parallel.device_type,
             dtype=TrainerTools().dtype,
             enabled=True,
-
-            # https://www.zhihu.com/question/642793891
-            cache_enabled=False if isinstance(self.train_model, FSDP) else None
+            cache_enabled=None
         ) if TrainerTools().use_amp else nullcontext()

         load_checkpoint(
@@ -176,12 +170,6 @@

         return model, optim

-    def _init_eval_model(self) -> Optional[nn.Module]:
-        if TrainerTools().parallel.is_main_process:
-            return self._new_model(self.train_config).to(device='cpu', dtype=TrainerTools().dtype)
-
-        return None
-
     def _init_lr_scheduler(self, initial_lr: float) -> LRScheduler:
         if self.train_config.lr_config.enable_lr_scheduler:
             min_lr = self.train_config.lr_config.min_lr
@@ -313,13 +301,6 @@
                 activation_checkpointing['number_checkpoints'] = activation_checkpointing_config.number_checkpoints

             parallel_kwargs['activation_checkpointing'] = activation_checkpointing
-        elif isinstance(TrainerTools().parallel, FsdpParallel) and self.train_config.fsdp_config:
-            parallel_kwargs = {
-                'transformer_layer_cls': self.train_config.fsdp_config.transformer_layer_cls,
-                'wrap_policy_num_params': self.train_config.fsdp_config.wrap_policy_num_params,
-                'cpu_offload': self.train_config.fsdp_config.cpu_offload,
-                'offload_params': self.train_config.fsdp_config.offload_params
-            }

         dataloader_args = self.train_config.data_loader_config
         data_loader_kwargs = {
@@ -441,54 +422,35 @@

                 raise e

-    def
-
-
-
-
+    def _eval(self, tag: str):
+        with unwrap_model_for_generation(self.train_model) as generate_model:
+            if TrainerTools().parallel.is_main_process:
+                generate_model.eval()
+                eval_prompt, eval_image_tag = self._get_eval_data()
+
+                if isinstance(self.train_config, VLMConfig) and self.pixel_values_provider and eval_image_tag:
+                    eval_pixel_values = self.pixel_values_provider([eval_image_tag])
+                else:
+                    eval_pixel_values = None
+
+                submit_gen_task(
+                    generate_model,
+                    self.train_config.eval_config,
+                    tag=tag,
+                    prompt=eval_prompt,
+                    pixel_values=eval_pixel_values,
+                    max_position_embeddings=self.train_config.model_config.max_position_embeddings,
+                    tokens_per_image=self.tokens_per_image
+                )
+                generate_model.train()

-        if TrainerTools().parallel.is_main_process:
-            eval_prompt, eval_image_tag = self._get_eval_data()
-            if isinstance(self.train_config, VLMConfig) and self.pixel_values_provider and eval_image_tag:
-                eval_pixel_values = self.pixel_values_provider([eval_image_tag])
-            else:
-                eval_pixel_values = None
-
-            submit_gen_task(
-                self.eval_model,
-                self.train_config.eval_config,
-                tag=f'sign:batch/{tag}',
-                prompt=eval_prompt,
-                pixel_values=eval_pixel_values,
-                max_position_embeddings=self.train_config.model_config.max_position_embeddings,
-                tokens_per_image=self.tokens_per_image
-            )
         TrainerTools().parallel.wait()

-    def
-
-        tag: str
-    ):
-        copy_model_params(_from=self.train_model, _to=self.eval_model)
-
-        if TrainerTools().parallel.is_main_process:
-            eval_prompt, eval_image_tag = self._get_eval_data()
-            if isinstance(self.train_config, VLMConfig) and self.pixel_values_provider and eval_image_tag:
-                eval_pixel_values = self.pixel_values_provider([eval_image_tag])
-            else:
-                eval_pixel_values = None
-
-            submit_gen_task(
-                self.eval_model,
-                self.train_config.eval_config,
-                tag=f'sign:epoch/{tag}',
-                prompt=eval_prompt,
-                pixel_values=eval_pixel_values,
-                max_position_embeddings=self.train_config.model_config.max_position_embeddings,
-                tokens_per_image=self.tokens_per_image
-            )
+    def _on_batch_end(self, tag: str):
+        self._eval(f'sign:batch/{tag}')

-
+    def _on_epoch_end(self, tag: str):
+        self._eval(f'sign:epoch/{tag}')

     def _on_file_start(
         self,
@@ -574,7 +536,6 @@
             if need_update_grad:
                 loss_tensor = torch.tensor(loss_accumulation, device=TrainerTools().parallel.device)

-                # todo check all_reduce??
                 if TrainerTools().parallel.parallel_train:
                     dist.all_reduce(loss_tensor, dist.ReduceOp.AVG)

project_llm_trainer-0.5.0.dist-info/RECORD
ADDED
@@ -0,0 +1,33 @@
+llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
+llm_trainer/checkpoint.py,sha256=xTmmQSJ_jQDVSTT3km1p_8eRrc7yE_dEsi92z9OX5ec,3251
+llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
+llm_trainer/dpo_trainer.py,sha256=wMREatLt0I8Ajdm_sI2U8Zj-IN1L6txP9s_tH1oI3-s,11431
+llm_trainer/ds_checkpoint.py,sha256=wz48HoLBBt8QGO1tXfvJwrXoiGtPG_gjwHfEqARllso,2175
+llm_trainer/eval.py,sha256=fjASCILU3fSPJxo9cP3rIXEEnkc5ZlUyHqXlZtUiHrw,888
+llm_trainer/generate_utils.py,sha256=2MoEGEpoTzx7khO3dPcC2akFLyjtbFFpdJtuB_QQ3OY,17708
+llm_trainer/grpo_trainer.py,sha256=qiC3KwxYPSB9UKqyk4eSRvORP3b6GM-2ozqI8u3QvI0,15568
+llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
+llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
+llm_trainer/parallel.py,sha256=DQu8GqEFxD99HQ6hKuIxxyKi-05dMO33eMhImYlPuOI,4468
+llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
+llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
+llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
+llm_trainer/partition_utils.py,sha256=xzv8kwlbKp3dai2pBwX89gN5ymeHk1bGbTkGru5H-UM,5167
+llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
+llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
+llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
+llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
+llm_trainer/train_configs.py,sha256=m57W71SI5VCCU9aJ_nJkB-3AJrSGiNXmV28rdpuYmLg,7332
+llm_trainer/trainer.py,sha256=zTJVyY1cAjJdTkyXCOy2ZPVP18SOMLdWhD54Mz2JRe4,25314
+llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
+project_llm_trainer-0.5.0.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.5.0.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.5.0.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.5.0.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.5.0.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.5.0.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.5.0.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.5.0.dist-info/METADATA,sha256=YDj-N4VL8O_AqNanwfU6Yt38J97p3RgtUSzmwl0Y-GM,195
+project_llm_trainer-0.5.0.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.5.0.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.5.0.dist-info/RECORD,,
llm_trainer/dcp.py
DELETED
@@ -1,93 +0,0 @@
-import os
-from typing import Optional, Dict, Any
-from torch import nn
-from torch.optim import Optimizer
-import torch.distributed.checkpoint as dcp
-from torch.distributed.checkpoint.stateful import Stateful
-from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
-from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp
-
-DEFAULT_CHECKPOINT_DIR = "checkpoint"
-
-class AppState(Stateful):
-    def __init__(self, model: nn.Module, optimizer: Optimizer):
-        self.model = model
-        self.optimizer = optimizer
-
-    def state_dict(self) -> Dict[str, Any]:
-        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
-        return {
-            'model_state_dict': model_state_dict,
-            'optim_state_dict': optimizer_state_dict
-        }
-
-    def load_state_dict(self, state_dict: Dict[str, Any]):
-        set_state_dict(
-            model=self.model,
-            optimizers=self.optimizer,
-            model_state_dict=state_dict['model_state_dict'],
-            optim_state_dict=state_dict['optim_state_dict']
-        )
-
-
-def save_dcp(
-        model: nn.Module,
-        optimizer: Optimizer,
-        suffix: Optional[str] = None
-):
-    checkpoint_id = os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR)
-    if suffix:
-        checkpoint_id = f"{checkpoint_id}_{suffix}"
-
-    state_dict = {'app': AppState(model, optimizer)}
-
-    # fs_storage_writer = dcp.FileSystemWriter(checkpoint_id, overwrite=True)
-    # dcp.save(state_dict=state_dict, storage_writer=fs_storage_writer)
-    dcp.save(state_dict=state_dict, checkpoint_id=checkpoint_id)
-
-
-def load_dcp(
-        model: nn.Module,
-        optimizer: Optional[Optimizer] = None,
-        suffix: Optional[str] = None
-):
-    checkpoint_id = os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR)
-    if suffix:
-        checkpoint_id = f"{checkpoint_id}_{suffix}"
-
-    if os.path.exists(checkpoint_id):
-        state_dict = {'app': AppState(model, optimizer)}
-        # AppState helps load into state_dict, which is then loaded into the model
-        dcp.load(state_dict=state_dict, checkpoint_id=checkpoint_id)
-
-    # if isinstance(model, FSDP):
-    #     state_dict = {'app': AppState(model, optimizer)}
-    #     # AppState helps load into state_dict, which is then loaded into the model
-    #     dcp.load(state_dict=state_dict, checkpoint_id=checkpoint_id)
-    # else:
-    #     state_dict = {"model_state_dict": model.state_dict()}
-    #
-    #     if optimizer:
-    #         state_dict.update({'optim_state_dict': optimizer.state_dict()})
-    #
-    #     # since no progress group is initialized, DCP will disable any collectives.
-    #     # load into state_dict, then load into the model via model.load_state_dict
-    #     dcp.load(
-    #         state_dict=state_dict,
-    #         checkpoint_id=checkpoint_id,
-    #     )
-    #
-    #     model.load_state_dict(state_dict["model_state_dict"])
-    #     if optimizer:
-    #         optimizer.load_state_dict(state_dict["optim_state_dict"])
-
-def convert_dcp_to_pth(pth_path: str):
-    dcp_path = os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR)
-    if os.path.exists(dcp_path):
-        # convert dcp model to torch.save (assumes checkpoint was generated as above)
-        dcp_to_torch_save(dcp_path, pth_path)
-
-def convert_pth_to_dcp(pth_path: str):
-    if os.path.exists(pth_path):
-        # converts the torch.save model back to DCP
-        torch_save_to_dcp(pth_path, os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR))
llm_trainer/fsdp_checkpoint.py
DELETED
@@ -1,87 +0,0 @@
-import os
-from typing import Optional, Union, Tuple
-import torch
-from torch import nn
-from torch.optim import Optimizer
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-import torch.distributed as dist
-
-from .tools import TrainerTools
-
-DEFAULT_CHECKPOINT_NAME = "checkpoint.pth"
-
-def save_fsdp_checkpoint(
-        model: nn.Module,
-        optimizer: Optional[Optimizer] = None,
-        suffix: Optional[str] = None
-):
-    # untested; see https://doc.hfai.high-flyer.cn/haiscale/haiscale_fsdp.html
-    # should rank0_only=True be used?
-    with FSDP.summon_full_params(
-        module=model,
-        rank0_only=True,
-        writeback=False,
-        offload_to_cpu=True
-    ):
-        if TrainerTools().parallel.is_main_process:
-            checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
-            if suffix:
-                checkpoint_name = f"{checkpoint_name}_{suffix}"
-
-            ckpt = {'model_state_dict': model.state_dict()}
-            if optimizer:
-                ckpt.update({'optim_state_dict': optimizer.state_dict()})
-
-            torch.save(ckpt, checkpoint_name)
-
-
-def load_fsdp_checkpoint(
-        model: nn.Module,
-        optimizer: Optional[Optimizer] = None,
-        device: Optional[Union[torch.device, str]] = None,
-        suffix: Optional[str] = None
-):
-    checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
-    if suffix:
-        checkpoint_name = f"{checkpoint_name}_{suffix}"
-
-    with FSDP.summon_full_params(module=model):
-        state_dict = torch.load(checkpoint_name, weights_only=True, map_location=device)
-        model.load_state_dict(state_dict['model_state_dict'])
-
-        if optimizer:
-            optimizer.load_state_dict(state_dict['optim_state_dict'])
-
-
-def _get_fsdp_full_state_dict_on_rank0(model: nn.Module) -> Optional[dict]:
-    """
-    Can be called on any rank; only rank 0 ends up with a value
-    """
-
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-    with FSDP.summon_full_params(model, writeback=False, offload_to_cpu=True):
-        if TrainerTools().parallel.is_main_process:
-            return {k: v.clone() for k, v in model.state_dict().items()}
-
-    return None
-
-
-def get_fsdp_model_params(model: nn.Module, only_rank0=False):
-    """
-    Efficiently extract the full FP32 state_dict from an FSDP-wrapped model.
-    This gathers all sharded parameters and makes sure every rank receives a complete copy.
-    """
-
-    state_dict = _get_fsdp_full_state_dict_on_rank0(model)
-
-    # At this point only the state_dict on rank 0 is a valid dict; on the other ranks it is None.
-    # It has to be broadcast to all processes.
-    if not only_rank0 and TrainerTools().parallel.world_size > 1:
-        # Prepare a list: rank 0 holds the data, the other ranks hold a placeholder
-        object_list = [state_dict] if TrainerTools().parallel.is_main_process else [None]
-        # Perform the broadcast; the call blocks and synchronises all processes
-        dist.broadcast_object_list(object_list, src=0)
-        # Every process takes its copy of the broadcast state_dict from the list
-        state_dict = object_list[0]
-
-    return state_dict
llm_trainer/parallel_fsdp.py
DELETED
@@ -1,121 +0,0 @@
-from typing import Optional, Tuple
-import functools
-import torch
-from torch import nn
-from torch.distributed.fsdp import (
-    FullyShardedDataParallel as FSDP,
-    MixedPrecision,
-    ShardingStrategy,
-    BackwardPrefetch,
-    CPUOffload,
-)
-
-from torch.distributed.fsdp.wrap import (
-    size_based_auto_wrap_policy,
-    transformer_auto_wrap_policy,
-    always_wrap_policy,
-    enable_wrap,
-    wrap,
-)
-
-from .parallel import Parallel
-
-class FsdpParallel(Parallel):
-    def __init__(self):
-        super().__init__()
-
-    def process(
-        self,
-        model: nn.Module,
-        optimizer: torch.optim.Optimizer,
-        kwargs: Optional[dict] = None,
-        save_instance: bool = True
-    ) -> Tuple[nn.Module, torch.optim.Optimizer]:
-        """
-        :param model:
-        :param optimizer:
-        :param kwargs:
-            "wrap_policy_num_params" int: the minimum parameter count for size_based_auto_wrap_policy
-            "cpu_offload" bool: whether to use CPU offload
-            "offload_params" bool: whether to offload parameters; only takes effect when cpu_offload is True
-        :param save_instance
-        :return:
-        """
-
-        model.to(self.device)
-
-        if self._use_compile:
-            model = torch.compile(model)
-
-        if self._use_parallel:
-            if 'transformer_layer_cls' in kwargs:
-                auto_wrap_policy = functools.partial(
-                    transformer_auto_wrap_policy,
-                    transformer_layer_cls=kwargs['transformer_layer_cls']
-                )
-            elif 'wrap_policy_num_params' in kwargs:
-                auto_wrap_policy = functools.partial(
-                    size_based_auto_wrap_policy,
-                    min_num_params=kwargs['wrap_policy_num_params']
-                )
-            else:
-                auto_wrap_policy = None
-
-            if 'cpu_offload' in kwargs:
-                offload_params = False
-                if 'offload_params' in kwargs:
-                    offload_params = kwargs['offload_params']
-
-                # Optionally configure cpu_offload so wrapped parameters are offloaded to the CPU when they are not used in computation.
-                # This can further improve memory efficiency, at the cost of host-device data transfer overhead.
-                cpu_offload = CPUOffload(offload_params=offload_params)
-            else:
-                cpu_offload = None
-
-            if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
-                mixed_precision = MixedPrecision(
-                    param_dtype=torch.bfloat16,
-                    # Gradient communication precision.
-                    reduce_dtype=torch.bfloat16,
-                    # Buffer precision.
-                    buffer_dtype=torch.bfloat16,
-                )
-            else:
-                mixed_precision = None
-
-            raw_model = model
-
-            # device_mesh = init_device_mesh("cuda", (self.world_size,))
-            # model = FSDP(
-            #     model,
-            #     auto_wrap_policy=auto_wrap_policy,
-            #     mixed_precision=mixed_precision,
-            #     cpu_offload=cpu_offload,
-            #     device_id=torch.cuda.current_device(),
-            #     device_mesh=device_mesh
-            # )
-
-            model = FSDP(
-                model,
-                sharding_strategy=ShardingStrategy.FULL_SHARD,
-                auto_wrap_policy=auto_wrap_policy,
-                mixed_precision=mixed_precision,
-                cpu_offload=cpu_offload,
-                device_id=torch.cuda.current_device(),
-                process_group=None,
-                # use_orig_params=True,
-                # backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # bit faster async comms, bit higher memory
-                # limit_all_gathers=False,
-                # forward_prefetch=True,
-            )
-        else:
-            model = model
-            raw_model = model
-
-        if save_instance:
-            self.raw_model = raw_model
-            self.model = model
-
-        return model, optimizer
-
-
project_llm_trainer-0.4.14.dist-info/RECORD
DELETED
@@ -1,35 +0,0 @@
-llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
-llm_trainer/checkpoint.py,sha256=ItDzuXVikk-0gWSw-IS7SrODEdlJEb5nZs15dBFkPdk,5793
-llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
-llm_trainer/dcp.py,sha256=PkD97DyrOtoTKn4FJsfL3VqAy4dxufgjdzJEz8-Cnoc,3635
-llm_trainer/dpo_trainer.py,sha256=djBhvI_ixTV1nLNg84tgCpfV--pu6IRiOhO28V-aANQ,11425
-llm_trainer/ds_checkpoint.py,sha256=fprJlbSgtyKmmpytyMOZBs3pcjZA13SeWao0llnLpNQ,4962
-llm_trainer/eval.py,sha256=NDm8PbXLch7xT81xPYPRCNrcrB_Xj5GDJSCxyVwUOp4,1524
-llm_trainer/fsdp_checkpoint.py,sha256=dAHIGHfuvTA6OC0jV9Ls-oD4ZR9CPGa31mjtoh-2dZE,3229
-llm_trainer/generate_utils.py,sha256=tSbA_tLqSq5qJGHSOlPv5T3iRDZkbFg5ZvDAgJ_i_SE,17946
-llm_trainer/grpo_trainer.py,sha256=bZPrxhyPQLAnFzWhI7hhA6fpuKVNwj7nOm9k0ku9aK4,15977
-llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
-llm_trainer/loss.py,sha256=Yv3fsaVuZ5AhnGPJOr5vEMb_tM2urR6mCb4DBbrHHI8,6030
-llm_trainer/parallel.py,sha256=DQu8GqEFxD99HQ6hKuIxxyKi-05dMO33eMhImYlPuOI,4468
-llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
-llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
-llm_trainer/parallel_fsdp.py,sha256=cQOdY8ou6m8OsR06PpFVn6GiyZlK9nefkcGyszUOIJk,4055
-llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
-llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
-llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
-llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
-llm_trainer/tools.py,sha256=O45-20wRmh-nyTfU-U-XtjbKAoe7boEIsUvWT_NaKx4,3041
-llm_trainer/train_configs.py,sha256=HKzH3nfMT1-SW4Htwa0KqYtMd6FAJcthR5IEo6di8us,8168
-llm_trainer/trainer.py,sha256=j5fDqMzvU6SYwxHsv9wX0UVX4JXS-8eP1AkHgVxKf9U,27119
-llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
-project_llm_trainer-0.4.14.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
-project_llm_trainer-0.4.14.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
-project_llm_trainer-0.4.14.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
-project_llm_trainer-0.4.14.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
-project_llm_trainer-0.4.14.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
-project_llm_trainer-0.4.14.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
-project_llm_trainer-0.4.14.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
-project_llm_trainer-0.4.14.dist-info/METADATA,sha256=VMEWVv8pBqFUAhIAiH4_S4ECUHln31gchHLhTtUAM1o,196
-project_llm_trainer-0.4.14.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-project_llm_trainer-0.4.14.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
-project_llm_trainer-0.4.14.dist-info/RECORD,,
{project_llm_trainer-0.4.14.data → project_llm_trainer-0.5.0.data}/scripts/calc_intermediate_size
RENAMED
File without changes
{project_llm_trainer-0.4.14.data → project_llm_trainer-0.5.0.data}/scripts/ddp_train
RENAMED
File without changes
{project_llm_trainer-0.4.14.data → project_llm_trainer-0.5.0.data}/scripts/ds_train
RENAMED
File without changes
{project_llm_trainer-0.4.14.data → project_llm_trainer-0.5.0.data}/scripts/plot_loss
RENAMED
File without changes
{project_llm_trainer-0.4.14.data → project_llm_trainer-0.5.0.data}/scripts/plot_lr
RENAMED
File without changes
{project_llm_trainer-0.4.14.data → project_llm_trainer-0.5.0.data}/scripts/py_train
RENAMED
File without changes
{project_llm_trainer-0.4.14.data → project_llm_trainer-0.5.0.data}/scripts/smart_train
RENAMED
File without changes
{project_llm_trainer-0.4.14.dist-info → project_llm_trainer-0.5.0.dist-info}/WHEEL
RENAMED
File without changes
{project_llm_trainer-0.4.14.dist-info → project_llm_trainer-0.5.0.dist-info}/top_level.txt
RENAMED
File without changes